Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from pyrit.backend.models.common import PaginationInfo
from pyrit.backend.services.converter_service import get_converter_service
from pyrit.backend.services.target_service import get_target_service
from pyrit.memory import CentralMemory
from pyrit.memory import CentralMemory, data_serializer_factory
from pyrit.models import (
AttackOutcome,
AttackResult,
Expand All @@ -60,7 +60,6 @@
MessagePiece,
PromptDataType,
build_atomic_attack_identifier,
data_serializer_factory,
)
from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer

Expand Down
2 changes: 1 addition & 1 deletion pyrit/backend/services/converter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
CreateConverterResponse,
PreviewStep,
)
from pyrit.memory import data_serializer_factory
from pyrit.models import PromptDataType
from pyrit.models.data_type_serializer import data_serializer_factory
from pyrit.prompt_converter import PromptConverter
from pyrit.prompt_target import PromptTarget
from pyrit.registry.object_registries import ConverterRegistry
Expand Down
2 changes: 1 addition & 1 deletion pyrit/common/data_url_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

from pyrit.common.deprecation import print_deprecation_message
from pyrit.models import DataTypeSerializer, data_serializer_factory
from pyrit.memory import DataTypeSerializer, data_serializer_factory

# Supported image formats for Azure OpenAI GPT-4o,
# https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/use-your-image-data
Expand Down
4 changes: 2 additions & 2 deletions pyrit/common/display_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from pyrit.common.deprecation import print_deprecation_message
from pyrit.common.notebook_utils import is_in_ipython_session
from pyrit.memory import CentralMemory
from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece
from pyrit.memory import AzureBlobStorageIO, CentralMemory, DiskStorageIO
from pyrit.models import MessagePiece

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion pyrit/datasets/seed_datasets/remote/_image_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.models import data_serializer_factory
from pyrit.memory import data_serializer_factory

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion pyrit/datasets/seed_datasets/remote/msts_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory
from pyrit.memory import data_serializer_factory
Comment thread
rlundeen2 marked this conversation as resolved.
from pyrit.models import SeedDataset, SeedPrompt

if TYPE_CHECKING:
from PIL.Image import Image as PILImage
Expand Down
36 changes: 35 additions & 1 deletion pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,52 @@
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.memory.memory_models import AttackResultEntry, EmbeddingDataEntry, PromptMemoryEntry, SeedEntry
from pyrit.memory.sqlite_memory import SQLiteMemory
from pyrit.memory.storage import (
AllowedCategories,
AudioPathDataTypeSerializer,
AzureBlobStorageIO,
BinaryPathDataTypeSerializer,
DataTypeSerializer,
DiskStorageIO,
ErrorDataTypeSerializer,
ImagePathDataTypeSerializer,
StorageIO,
SupportedContentType,
TextDataTypeSerializer,
URLDataTypeSerializer,
VideoPathDataTypeSerializer,
data_serializer_factory,
set_message_piece_sha256_async,
set_seed_sha256_async,
)

# Public API of ``pyrit.memory``. Each exported name is listed exactly once;
# the storage and serializer names are re-exported from ``pyrit.memory.storage``.
__all__ = [
    "AllowedCategories",
    "AttackResultEntry",
    "AudioPathDataTypeSerializer",
    "AzureBlobStorageIO",
    "AzureSQLMemory",
    "BinaryPathDataTypeSerializer",
    "CentralMemory",
    "DataTypeSerializer",
    "data_serializer_factory",
    "DiskStorageIO",
    "EmbeddingDataEntry",
    "ErrorDataTypeSerializer",
    "ImagePathDataTypeSerializer",
    "MemoryInterface",
    "MemoryEmbedding",
    "MemoryExporter",
    "PromptMemoryEntry",
    "SeedEntry",
    "set_message_piece_sha256_async",
    "set_seed_sha256_async",
    "SQLiteMemory",
    "StorageIO",
    "SupportedContentType",
    "TextDataTypeSerializer",
    "URLDataTypeSerializer",
    "VideoPathDataTypeSerializer",
]


Expand Down
9 changes: 3 additions & 6 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,8 @@
EmbeddingDataEntry,
PromptMemoryEntry,
)
from pyrit.models import (
AzureBlobStorageIO,
ConversationStats,
MessagePiece,
)
from pyrit.memory.storage import AzureBlobStorageIO
from pyrit.models import ConversationStats, MessagePiece

if TYPE_CHECKING:
from azure.core.credentials import AccessToken
Expand Down Expand Up @@ -871,7 +868,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict
# attributes from the (potentially stale) detached object
# and silently overwrite concurrent updates to columns
# that are NOT in update_fields.
entry_in_session = session.get(type(entry), entry.id)
entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute]
if entry_in_session is None:
entry_in_session = session.merge(entry)
for field, value in update_fields.items():
Expand Down
11 changes: 7 additions & 4 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@
ScoreEntry,
SeedEntry,
)
from pyrit.memory.storage import (
DataTypeSerializer,
StorageIO,
data_serializer_factory,
set_seed_sha256_async,
)
from pyrit.models import (
AttackResult,
ConversationStats,
DataTypeSerializer,
IdentifierFilter,
IdentifierType,
Message,
Expand All @@ -47,8 +52,6 @@
SeedDataset,
SeedGroup,
SeedType,
StorageIO,
data_serializer_factory,
group_conversation_message_pieces_by_sequence,
sort_message_pieces,
)
Expand Down Expand Up @@ -1395,7 +1398,7 @@ async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: st
serialized_prompt_value = await self._serialize_seed_value_async(prompt=prompt)
prompt.value = serialized_prompt_value

await prompt.set_sha256_value_async()
await set_seed_sha256_async(prompt)

if prompt.value_sha256 and not self.get_seeds(
value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name
Expand Down
7 changes: 4 additions & 3 deletions pyrit/memory/sqlite_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
PromptMemoryEntry,
ScenarioResultEntry,
)
from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece
from pyrit.memory.storage import DiskStorageIO
from pyrit.models import ConversationStats, MessagePiece

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -440,7 +441,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict
# attributes from the (potentially stale) detached object
# and silently overwrite concurrent updates to columns
# that are NOT in update_fields.
entry_in_session = session.get(type(entry), entry.id)
entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute]
if entry_in_session is None:
entry_in_session = session.merge(entry)
for field, value in update_fields.items():
Expand Down Expand Up @@ -614,7 +615,7 @@ def export_all_tables(self, *, export_type: str = "json") -> None:
file_extension = f".{export_type}"
file_path = DB_DATA_PATH / f"{table_name}{file_extension}"
# Convert to list for exporter compatibility
self.exporter.export_data(list(data), file_path=file_path, export_type=export_type)
self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) # type: ignore[ty:invalid-argument-type]

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
"""
Expand Down
56 changes: 56 additions & 0 deletions pyrit/memory/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Storage layer for PyRIT: storage backends and multi-modal data serializers.

Provides the disk and blob storage adapters (``StorageIO`` and its
implementations) and the data-type serializers (``data_serializer_factory`` and
the per-type ``*DataTypeSerializer`` classes) used to read and write prompt
payloads such as text, images, audio, and video.

These serializers write payload files into the location configured on the active
memory instance (``results_path`` / ``results_storage_io``), which is why they
live alongside ``pyrit.memory``: the database holds the records and this package
holds the blob payloads those records point to.
"""

from pyrit.memory.storage.serializers import (
AllowedCategories,
AudioPathDataTypeSerializer,
BinaryPathDataTypeSerializer,
DataTypeSerializer,
ErrorDataTypeSerializer,
ImagePathDataTypeSerializer,
TextDataTypeSerializer,
URLDataTypeSerializer,
VideoPathDataTypeSerializer,
data_serializer_factory,
set_message_piece_sha256_async,
set_seed_sha256_async,
)
from pyrit.memory.storage.storage import (
AzureBlobStorageIO,
DiskStorageIO,
StorageIO,
SupportedContentType,
)

# Public API of the storage package: exactly the names imported above from
# ``pyrit.memory.storage.serializers`` and ``pyrit.memory.storage.storage``.
__all__ = [
"AllowedCategories",
"AudioPathDataTypeSerializer",
"AzureBlobStorageIO",
"BinaryPathDataTypeSerializer",
"DataTypeSerializer",
"data_serializer_factory",
"DiskStorageIO",
"ErrorDataTypeSerializer",
"ImagePathDataTypeSerializer",
"set_message_piece_sha256_async",
"set_seed_sha256_async",
"StorageIO",
"SupportedContentType",
"TextDataTypeSerializer",
"URLDataTypeSerializer",
"VideoPathDataTypeSerializer",
]
Loading
Loading