diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 836a783fba..23d58bf623 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -50,7 +50,7 @@ from pyrit.backend.models.common import PaginationInfo from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service -from pyrit.memory import CentralMemory +from pyrit.memory import CentralMemory, data_serializer_factory from pyrit.models import ( AttackOutcome, AttackResult, @@ -60,7 +60,6 @@ MessagePiece, PromptDataType, build_atomic_attack_identifier, - data_serializer_factory, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 4f09248aaa..8bd6199592 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -36,8 +36,8 @@ CreateConverterResponse, PreviewStep, ) +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType -from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.prompt_converter import PromptConverter from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries import ConverterRegistry diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index 20ff008332..6fef6337d6 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory # Supported image formats for Azure OpenAI GPT-4o, # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/use-your-image-data diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 53e63b6502..aa00ff279f 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -8,8 +8,8 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.memory import CentralMemory -from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece +from pyrit.memory import AzureBlobStorageIO, CentralMemory, DiskStorageIO +from pyrit.models import MessagePiece logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py index bdb502ad23..649641158c 100644 --- a/pyrit/datasets/seed_datasets/remote/_image_cache.py +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -17,7 +17,7 @@ from typing import Any from pyrit.common.net_utility import make_request_and_raise_if_error_async -from pyrit.models import data_serializer_factory +from pyrit.memory import data_serializer_factory logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/msts_dataset.py b/pyrit/datasets/seed_datasets/remote/msts_dataset.py index 317bcfbc73..c24e28a2b3 100644 --- a/pyrit/datasets/seed_datasets/remote/msts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/msts_dataset.py @@ -14,7 +14,8 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import SeedDataset, SeedPrompt if TYPE_CHECKING: from PIL.Image import Image as PILImage diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 2579dc1334..62d5887770 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -15,18 +15,52 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import AttackResultEntry, EmbeddingDataEntry, PromptMemoryEntry, SeedEntry from pyrit.memory.sqlite_memory import SQLiteMemory +from pyrit.memory.storage import ( + AllowedCategories, + AudioPathDataTypeSerializer, + AzureBlobStorageIO, + BinaryPathDataTypeSerializer, + DataTypeSerializer, + DiskStorageIO, + ErrorDataTypeSerializer, + ImagePathDataTypeSerializer, + StorageIO, + SupportedContentType, + TextDataTypeSerializer, + URLDataTypeSerializer, + VideoPathDataTypeSerializer, + data_serializer_factory, + set_message_piece_sha256_async, + set_seed_sha256_async, +) __all__ = [ + "AllowedCategories", "AttackResultEntry", + "AudioPathDataTypeSerializer", + "AzureBlobStorageIO", "AzureSQLMemory", + "BinaryPathDataTypeSerializer", "CentralMemory", - "SQLiteMemory", + "DataTypeSerializer", + "data_serializer_factory", + "DiskStorageIO", "EmbeddingDataEntry", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", "MemoryInterface", "MemoryEmbedding", "MemoryExporter", "PromptMemoryEntry", "SeedEntry", + "set_message_piece_sha256_async", + "set_seed_sha256_async", + "SQLiteMemory", + "StorageIO", + "SupportedContentType", + "TextDataTypeSerializer", + "URLDataTypeSerializer", + "VideoPathDataTypeSerializer", ] diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 81b97d8716..7d62ab5bd9 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -26,11 +26,8 @@ EmbeddingDataEntry, PromptMemoryEntry, ) -from pyrit.models import ( - AzureBlobStorageIO, - ConversationStats, - MessagePiece, -) +from pyrit.memory.storage import AzureBlobStorageIO +from pyrit.models import ConversationStats, MessagePiece if TYPE_CHECKING: from azure.core.credentials import AccessToken @@ -871,7 +868,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict # attributes from the (potentially stale) detached object # and silently overwrite concurrent updates to columns # that are NOT in update_fields. - entry_in_session = session.get(type(entry), entry.id) + entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute] if entry_in_session is None: entry_in_session = session.merge(entry) for field, value in update_fields.items(): diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index f0e5fdeb9b..ff4a6af97f 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -33,10 +33,15 @@ ScoreEntry, SeedEntry, ) +from pyrit.memory.storage import ( + DataTypeSerializer, + StorageIO, + data_serializer_factory, + set_seed_sha256_async, +) from pyrit.models import ( AttackResult, ConversationStats, - DataTypeSerializer, IdentifierFilter, IdentifierType, Message, @@ -47,8 +52,6 @@ SeedDataset, SeedGroup, SeedType, - StorageIO, - data_serializer_factory, group_conversation_message_pieces_by_sequence, sort_message_pieces, ) @@ -1395,7 +1398,7 @@ async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: st serialized_prompt_value = await self._serialize_seed_value_async(prompt=prompt) prompt.value = serialized_prompt_value - await prompt.set_sha256_value_async() + await set_seed_sha256_async(prompt) if prompt.value_sha256 and not self.get_seeds( value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 5f628ab075..61f556aee2 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -29,7 +29,8 @@ PromptMemoryEntry, ScenarioResultEntry, ) -from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece +from pyrit.memory.storage import DiskStorageIO +from pyrit.models import ConversationStats, MessagePiece logger = logging.getLogger(__name__) @@ -440,7 +441,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict # attributes from the (potentially stale) detached object # and silently overwrite concurrent updates to columns # that are NOT in update_fields. - entry_in_session = session.get(type(entry), entry.id) + entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute] if entry_in_session is None: entry_in_session = session.merge(entry) for field, value in update_fields.items(): @@ -614,7 +615,7 @@ def export_all_tables(self, *, export_type: str = "json") -> None: file_extension = f".{export_type}" file_path = DB_DATA_PATH / f"{table_name}{file_extension}" # Convert to list for exporter compatibility - self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) + self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) # type: ignore[ty:invalid-argument-type] def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ diff --git a/pyrit/memory/storage/__init__.py b/pyrit/memory/storage/__init__.py new file mode 100644 index 0000000000..b10fcb1d35 --- /dev/null +++ b/pyrit/memory/storage/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Storage layer for PyRIT: storage backends and multi-modal data serializers. + +Provides the disk and blob storage adapters (``StorageIO`` and its +implementations) and the data-type serializers (``data_serializer_factory`` and +the per-type ``*DataTypeSerializer`` classes) used to read and write prompt +payloads such as text, images, audio, and video. + +These serializers write payload files into the location configured on the active +memory instance (``results_path`` / ``results_storage_io``), which is why they +live alongside ``pyrit.memory``: the database holds the records and this package +holds the blob payloads those records point to. +""" + +from pyrit.memory.storage.serializers import ( + AllowedCategories, + AudioPathDataTypeSerializer, + BinaryPathDataTypeSerializer, + DataTypeSerializer, + ErrorDataTypeSerializer, + ImagePathDataTypeSerializer, + TextDataTypeSerializer, + URLDataTypeSerializer, + VideoPathDataTypeSerializer, + data_serializer_factory, + set_message_piece_sha256_async, + set_seed_sha256_async, +) +from pyrit.memory.storage.storage import ( + AzureBlobStorageIO, + DiskStorageIO, + StorageIO, + SupportedContentType, +) + +__all__ = [ + "AllowedCategories", + "AudioPathDataTypeSerializer", + "AzureBlobStorageIO", + "BinaryPathDataTypeSerializer", + "DataTypeSerializer", + "data_serializer_factory", + "DiskStorageIO", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", + "set_message_piece_sha256_async", + "set_seed_sha256_async", + "StorageIO", + "SupportedContentType", + "TextDataTypeSerializer", + "URLDataTypeSerializer", + "VideoPathDataTypeSerializer", +] diff --git a/pyrit/memory/storage/serializers.py b/pyrit/memory/storage/serializers.py new file mode 100644 index 0000000000..7a4e84ff14 --- /dev/null +++ b/pyrit/memory/storage/serializers.py @@ -0,0 +1,795 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import abc +import asyncio +import base64 +import hashlib +import tempfile +import time +import wave +from mimetypes import guess_type +from pathlib import Path +from typing import TYPE_CHECKING, Literal, get_args +from urllib.parse import urlparse + +import aiofiles + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.common.path import DB_DATA_PATH +from pyrit.memory.storage.storage import DiskStorageIO, StorageIO + +if TYPE_CHECKING: + from pyrit.memory import MemoryInterface + from pyrit.models.literals import PromptDataType + from pyrit.models.messages.message_piece import MessagePiece + from pyrit.models.seeds.seed import Seed + +# Define allowed categories for validation +AllowedCategories = Literal["seed-prompt-entries", "prompt-memory-entries"] + + +def _write_wav_sync( + path: str, + *, + num_channels: int, + sample_width: int, + sample_rate: int, + data: bytes, +) -> None: + """Write PCM audio bytes to a WAV file synchronously.""" + with wave.open(path, "wb") as wav_file: + wav_file.setnchannels(num_channels) + wav_file.setsampwidth(sample_width) + wav_file.setframerate(sample_rate) + wav_file.writeframes(data) + + +def data_serializer_factory( + *, + data_type: PromptDataType, + value: str | None = None, + extension: str | None = None, + category: AllowedCategories, +) -> DataTypeSerializer: + """ + Create a DataTypeSerializer instance. + + Args: + data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path'). + value (str): The data value to be serialized. + extension (Optional[str]): The file extension, if applicable. + category (AllowedCategories): The category or context for the data (e.g., 'seed-prompt-entries'). + + Returns: + DataTypeSerializer: An instance of the appropriate serializer. + + Raises: + ValueError: If the category is not provided or invalid. + + """ + if not category: + raise ValueError( + f"The 'category' argument is mandatory and must be one of the following: {get_args(AllowedCategories)}." + ) + if value is not None: + if data_type in ["text", "reasoning", "function_call", "tool_call", "function_call_output"]: + return TextDataTypeSerializer(prompt_text=value, data_type=data_type) + if data_type == "image_path": + return ImagePathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "audio_path": + return AudioPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "video_path": + return VideoPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "binary_path": + return BinaryPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "error": + return ErrorDataTypeSerializer(prompt_text=value) + if data_type == "url": + return URLDataTypeSerializer(category=category, prompt_text=value, extension=extension) + raise ValueError(f"Data type {data_type} not supported") + if data_type == "image_path": + return ImagePathDataTypeSerializer(category=category, extension=extension) + if data_type == "audio_path": + return AudioPathDataTypeSerializer(category=category, extension=extension) + if data_type == "video_path": + return VideoPathDataTypeSerializer(category=category, extension=extension) + if data_type == "binary_path": + return BinaryPathDataTypeSerializer(category=category, extension=extension) + if data_type == "error": + return ErrorDataTypeSerializer(prompt_text="") + raise ValueError(f"Data type {data_type} without prompt text not supported") + + +class DataTypeSerializer(abc.ABC): + """ + Abstract base class for data type normalizers. + + Responsible for reading and saving multi-modal data types to local disk or Azure Storage Account. + """ + + data_type: PromptDataType + value: str + category: str + data_sub_directory: str + file_extension: str + + _file_path: Path | str | None = None + + @property + def _memory(self) -> MemoryInterface: + from pyrit.memory import CentralMemory + + return CentralMemory.get_memory_instance() + + def _get_storage_io(self) -> StorageIO: + """ + Retrieve the input datasets storage handle. + + Returns: + StorageIO: An instance of DiskStorageIO or AzureBlobStorageIO based on the storage configuration. + + Raises: + ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set. + RuntimeError: If results_storage_io is not configured but Azure storage URL was detected. + + """ + if self._is_azure_storage_url(self.value): + # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact + # with an Azure Storage Account, ex., XPIAWorkflow. + if self._memory.results_storage_io is None: + raise RuntimeError("results_storage_io is not configured but Azure storage URL was detected") + return self._memory.results_storage_io + return DiskStorageIO() + + @abc.abstractmethod + def data_on_disk(self) -> bool: + """ + Indicate whether the data is stored on disk. + + Returns: + bool: True when data is persisted on disk. + + """ + + async def save_data_async(self, data: bytes, output_filename: str | None = None) -> None: + """ + Save data to storage. + + Arguments: + data: bytes: The data to be saved. + output_filename (optional, str): filename to store data as. Defaults to UUID if not provided + + Raises: + RuntimeError: If storage IO is not initialized. + """ + file_path = await self.get_data_filename_async(file_name=output_filename) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + await self._memory.results_storage_io.write_file_async(file_path, data) + self.value = str(file_path) + + async def save_b64_image_async(self, data: str | bytes, output_filename: str | None = None) -> None: + """ + Save a base64-encoded image to storage. + + Arguments: + data: string or bytes with base64 data + output_filename (optional, str): filename to store image as. Defaults to UUID if not provided + + Raises: + RuntimeError: If storage IO is not initialized. + """ + file_path = await self.get_data_filename_async(file_name=output_filename) + image_bytes = base64.b64decode(data) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + await self._memory.results_storage_io.write_file_async(file_path, image_bytes) + self.value = str(file_path) + + async def save_formatted_audio_async( + self, + data: bytes, + num_channels: int = 1, + sample_width: int = 2, + sample_rate: int = 16000, + output_filename: str | None = None, + ) -> None: + """ + Save PCM16 or similarly formatted audio data to storage. + + Arguments: + data: bytes with audio data + output_filename (optional, str): filename to store audio as. Defaults to UUID if not provided + num_channels (optional, int): number of channels in audio data. Defaults to 1 + sample_width (optional, int): sample width in bytes. Defaults to 2 + sample_rate (optional, int): sample rate in Hz. Defaults to 16000 + + Raises: + RuntimeError: If storage IO is not initialized. + """ + file_path = await self.get_data_filename_async(file_name=output_filename) + + # save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters + if self._is_azure_storage_url(str(file_path)): + with tempfile.NamedTemporaryFile(suffix=".wav", dir=DB_DATA_PATH, delete=False) as tmp: + local_temp_path = Path(tmp.name) + try: + await asyncio.to_thread( + _write_wav_sync, + str(local_temp_path), + num_channels=num_channels, + sample_width=sample_width, + sample_rate=sample_rate, + data=data, + ) + async with aiofiles.open(local_temp_path, "rb") as f: + audio_data = await f.read() + if self._memory.results_storage_io is None: + raise RuntimeError("self._memory.results_storage_io is not initialized") + await self._memory.results_storage_io.write_file_async(file_path, audio_data) + finally: + local_temp_path.unlink(missing_ok=True) + + # If local, we can just save straight to disk and do not need to delete temp file after + else: + await asyncio.to_thread( + _write_wav_sync, + str(file_path), + num_channels=num_channels, + sample_width=sample_width, + sample_rate=sample_rate, + data=data, + ) + + self.value = str(file_path) + + async def read_data_async(self) -> bytes: + """ + Read data from storage. + + Returns: + bytes: The data read from storage. + + Raises: + TypeError: If the serializer does not represent on-disk data. + RuntimeError: If no value is set. + FileNotFoundError: If the referenced file does not exist. + + """ + if not self.data_on_disk(): + raise TypeError(f"Data for data Type {self.data_type} is not stored on disk") + + if not self.value: + raise RuntimeError("Prompt text not set") + + storage_io = self._get_storage_io() + # Check if path exists + file_exists = await storage_io.path_exists_async(path=self.value) + if not file_exists: + raise FileNotFoundError(f"File not found: {self.value}") + # Read the contents from the path + return await storage_io.read_file_async(self.value) + + async def read_data_base64_async(self) -> str: + """ + Read data from storage and return it as a base64 string. + + Returns: + str: Base64-encoded data. + + """ + byte_array = await self.read_data_async() + return base64.b64encode(byte_array).decode("utf-8") + + async def get_sha256_async(self) -> str: + """ + Compute SHA256 hash for this serializer's current value. + + Returns: + str: Hex digest of the computed SHA256 hash. + + Raises: + FileNotFoundError: If on-disk data path does not exist. + ValueError: If in-memory data cannot be converted to bytes. + + """ + input_bytes: bytes | None = None + + if self.data_on_disk(): + storage_io = self._get_storage_io() + file_exists = await storage_io.path_exists_async(self.value) + if not file_exists: + raise FileNotFoundError(f"File not found: {self.value}") + + # Read the data from storage + input_bytes = await storage_io.read_file_async(self.value) + else: + if isinstance(self.value, str): + input_bytes = self.value.encode("utf-8") + else: + raise ValueError(f"Invalid data type {self.value}, expected str data type.") + + hash_object = hashlib.sha256(input_bytes) + return hash_object.hexdigest() + + async def get_data_filename_async(self, file_name: str | None = None) -> Path | str: + """ + Generate or retrieve a unique filename for the data file. + + Args: + file_name (Optional[str]): Optional file name override. + + Returns: + Union[Path, str]: Full storage path for the generated data file. + + Raises: + TypeError: If the serializer is not configured for on-disk data. + RuntimeError: If required data subdirectory information is missing. + + """ + if self._file_path: + return self._file_path + + if not self.data_on_disk(): + raise TypeError("Data is not stored on disk") + + if not self.data_sub_directory: + raise RuntimeError("Data sub directory not set") + + ticks = int(time.time() * 1_000_000) + if self._memory.results_path: + results_path = str(self._memory.results_path) + else: + from pyrit.common.path import DB_DATA_PATH + + results_path = str(DB_DATA_PATH) + file_name = file_name if file_name else str(ticks) + + if self._is_azure_storage_url(results_path): + full_data_directory_path = results_path + self.data_sub_directory + self._file_path = full_data_directory_path + f"/{file_name}.{self.file_extension}" + else: + full_data_directory_path = results_path + self.data_sub_directory + if self._memory.results_storage_io is None: + raise RuntimeError("self._memory.results_storage_io is not initialized") + await self._memory.results_storage_io.create_directory_if_not_exists_async(Path(full_data_directory_path)) + self._file_path = Path(full_data_directory_path, f"{file_name}.{self.file_extension}") + + return self._file_path + + async def save_data( # pyrit-async-suffix-exempt + self, data: bytes, output_filename: str | None = None + ) -> None: + """ + Save data to storage (deprecated alias of ``save_data_async``). + + Args: + data: The data to be saved. + output_filename: Optional filename to store data as. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_data", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_data_async", + removed_in="0.16.0", + ) + await self.save_data_async(data, output_filename) + + async def save_b64_image( # pyrit-async-suffix-exempt + self, data: str | bytes, output_filename: str | None = None + ) -> None: + """ + Save a base64-encoded image to storage (deprecated alias of ``save_b64_image_async``). + + Args: + data: String or bytes with base64 data. + output_filename: Optional filename to store image as. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_b64_image", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_b64_image_async", + removed_in="0.16.0", + ) + await self.save_b64_image_async(data, output_filename) + + async def save_formatted_audio( # pyrit-async-suffix-exempt + self, + data: bytes, + num_channels: int = 1, + sample_width: int = 2, + sample_rate: int = 16000, + output_filename: str | None = None, + ) -> None: + """ + Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). + + Args: + data: Audio data bytes. + num_channels: Number of channels in audio data. + sample_width: Sample width in bytes. + sample_rate: Sample rate in Hz. + output_filename: Optional filename to store audio as. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_formatted_audio", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_formatted_audio_async", + removed_in="0.16.0", + ) + await self.save_formatted_audio_async(data, num_channels, sample_width, sample_rate, output_filename) + + async def read_data(self) -> bytes: # pyrit-async-suffix-exempt + """ + Read data from storage (deprecated alias of ``read_data_async``). + + Returns: + bytes: The data read from storage. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data_async", + removed_in="0.16.0", + ) + return await self.read_data_async() + + async def read_data_base64(self) -> str: # pyrit-async-suffix-exempt + """ + Read data and return it as a base64 string (deprecated alias of ``read_data_base64_async``). + + Returns: + str: Base64-encoded data. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data_base64", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data_base64_async", + removed_in="0.16.0", + ) + return await self.read_data_base64_async() + + async def get_sha256(self) -> str: # pyrit-async-suffix-exempt + """ + Compute SHA256 hash for this serializer's current value (deprecated alias of ``get_sha256_async``). + + Returns: + str: Hex digest of the computed SHA256 hash. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_sha256", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_sha256_async", + removed_in="0.16.0", + ) + return await self.get_sha256_async() + + async def get_data_filename( # pyrit-async-suffix-exempt + self, file_name: str | None = None + ) -> Path | str: + """ + Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). + + Args: + file_name: Optional file name override. + + Returns: + Union[Path, str]: Full storage path for the generated data file. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_data_filename", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_data_filename_async", + removed_in="0.16.0", + ) + return await self.get_data_filename_async(file_name) + + @staticmethod + def get_extension(file_path: str) -> str | None: + """ + Get the file extension from the file path. + + Args: + file_path (str): Input file path. + + Returns: + str | None: File extension (including dot) or None if unavailable. + + """ + ext = Path(file_path).suffix + return ext or None + + @staticmethod + def get_mime_type(file_path: str) -> str | None: + """ + Get the MIME type of the file path. + + Args: + file_path (str): Input file path. + + Returns: + str | None: MIME type if detectable; otherwise None. + + """ + mime_type, _ = guess_type(file_path) + return mime_type + + def _is_azure_storage_url(self, path: str) -> bool: + """ + Validate whether the given path is an Azure Storage URL. + + Args: + path (str): Path or URL to check. + + Returns: + bool: True if the path is an Azure Blob Storage URL. + + """ + parsed = urlparse(path) + return parsed.scheme in ("http", "https") and "blob.core.windows.net" in parsed.netloc + + +class TextDataTypeSerializer(DataTypeSerializer): + """Serializer for text and text-like prompt values that stay in-memory.""" + + def __init__(self, *, prompt_text: str, data_type: PromptDataType = "text") -> None: + """ + Initialize a text serializer. + + Args: + prompt_text (str): Prompt value. + data_type (PromptDataType): Text-like prompt data type. + + """ + self.data_type = data_type + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always False for text serializers. + + """ + return False + + +class ErrorDataTypeSerializer(DataTypeSerializer): + """Serializer for error payloads stored as in-memory text.""" + + def __init__(self, *, prompt_text: str) -> None: + """ + Initialize an error serializer. + + Args: + prompt_text (str): Error payload text. + + """ + self.data_type = "error" + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always False for error serializers. + + """ + return False + + +class URLDataTypeSerializer(DataTypeSerializer): + """Serializer for URL values and URL-backed local file references.""" + + def __init__(self, *, category: str, prompt_text: str, extension: str | None = None) -> None: + """ + Initialize a URL serializer. + + Args: + category (str): Data category folder name. + prompt_text (str): URL or path value. + extension (Optional[str]): Optional extension for persisted content. + + """ + self.data_type = "url" + self.value = prompt_text + self.data_sub_directory = f"/{category}/urls" + self.file_extension = extension if extension else "txt" + self.on_disk = not (prompt_text.startswith(("http://", "https://"))) + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: True for non-http values, False for URL values. + + """ + return self.on_disk + + +class ImagePathDataTypeSerializer(DataTypeSerializer): + """Serializer for image path values stored on disk.""" + + def __init__(self, *, category: str, prompt_text: str | None = None, extension: str | None = None) -> None: + """ + Initialize an image-path serializer. + + Args: + category (str): Data category folder name. + prompt_text (Optional[str]): Optional existing image path. + extension (Optional[str]): Optional image extension. + + """ + self.data_type = "image_path" + self.data_sub_directory = f"/{category}/images" + self.file_extension = extension if extension else "png" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for image path serializers. + + """ + return True + + +class AudioPathDataTypeSerializer(DataTypeSerializer): + """Serializer for audio path values stored on disk.""" + + def __init__( + self, + *, + category: str, + prompt_text: str | None = None, + extension: str | None = None, + ) -> None: + """ + Initialize an audio-path serializer. + + Args: + category (str): Data category folder name. + prompt_text (Optional[str]): Optional existing audio path. + extension (Optional[str]): Optional audio extension. + + """ + self.data_type = "audio_path" + self.data_sub_directory = f"/{category}/audio" + self.file_extension = extension if extension else "mp3" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for audio path serializers. + + """ + return True + + +class VideoPathDataTypeSerializer(DataTypeSerializer): + """Serializer for video path values stored on disk.""" + + def __init__( + self, + *, + category: str, + prompt_text: str | None = None, + extension: str | None = None, + ) -> None: + """ + Initialize a video-path serializer. + + Args: + category (str): The category or context for the data. + prompt_text (Optional[str]): The video path or identifier. + extension (Optional[str]): The file extension, defaults to 'mp4'. + + """ + self.data_type = "video_path" + self.data_sub_directory = f"/{category}/videos" + self.file_extension = extension if extension else "mp4" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for video path serializers. + + """ + return True + + +class BinaryPathDataTypeSerializer(DataTypeSerializer): + """Serializer for generic binary path values stored on disk.""" + + def __init__( + self, + *, + category: str, + prompt_text: str | None = None, + extension: str | None = None, + ) -> None: + """ + Initialize a generic binary-path serializer. + + This serializer handles generic binary data that doesn't fit into specific + categories like images, audio, or video. Useful for XPIA attacks and + storing files like PDFs, documents, or other binary formats. + + Args: + category (str): The category or context for the data. + prompt_text (Optional[str]): The binary file path or identifier. + extension (Optional[str]): The file extension, defaults to 'bin'. + + """ + self.data_type = "binary_path" + self.data_sub_directory = f"/{category}/binaries" + self.file_extension = extension if extension else "bin" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for binary path serializers. + + """ + return True + + +async def set_message_piece_sha256_async(message_piece: MessagePiece) -> None: + """ + Compute and assign SHA256 hash values for a message piece's original and converted payloads. + + Async because blob payloads may need to be fetched. Must be called explicitly after + the message piece is constructed and its values are finalized. + + Args: + message_piece (MessagePiece): The message piece to populate with SHA256 values. + """ + original_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type=message_piece.original_value_data_type, + value=message_piece.original_value, + ) + message_piece.original_value_sha256 = await original_serializer.get_sha256_async() + + converted_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type=message_piece.converted_value_data_type, + value=message_piece.converted_value, + ) + message_piece.converted_value_sha256 = await converted_serializer.get_sha256_async() + + +async def set_seed_sha256_async(seed: Seed) -> None: + """ + Compute and assign the SHA256 hash value for a seed's value. + + Should be called after the seed ``value`` is serialized to text, as file paths used in + the ``value`` may have changed from local to memory storage paths. Async due to blob retrieval. + + Args: + seed (Seed): The seed to populate with its SHA256 value. + """ + serializer = data_serializer_factory( + category="seed-prompt-entries", + data_type=seed.data_type, + value=seed.value, + ) + seed.value_sha256 = await serializer.get_sha256_async() diff --git a/pyrit/memory/storage/storage.py b/pyrit/memory/storage/storage.py new file mode 100644 index 0000000000..aeb084c9c6 --- /dev/null +++ b/pyrit/memory/storage/storage.py @@ -0,0 +1,507 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +import aiofiles + +from pyrit.common.deprecation import print_deprecation_message + +if TYPE_CHECKING: + from azure.storage.blob.aio import ContainerClient as AsyncContainerClient + +logger = logging.getLogger(__name__) + + +class SupportedContentType(Enum): + """ + All supported content types for uploading blobs to provided storage account container. + See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml. + """ + + # TODO, add other media supported types + PLAIN_TEXT = "text/plain" + + +class StorageIO(ABC): + """ + Abstract interface for storage systems (local disk, Azure Storage Account, etc.). + """ + + @abstractmethod + async def read_file_async(self, path: Path | str) -> bytes: + """ + Asynchronously reads the file (or blob) from the given path. + """ + + @abstractmethod + async def write_file_async(self, path: Path | str, data: bytes) -> None: + """ + Asynchronously writes data to the given path. + """ + + @abstractmethod + async def path_exists_async(self, path: Path | str) -> bool: + """ + Asynchronously checks if a file or blob exists at the given path. + """ + + @abstractmethod + async def is_file_async(self, path: Path | str) -> bool: + """ + Asynchronously checks if the path refers to a file (not a directory or container). + """ + + @abstractmethod + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: + """ + Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. + """ + + async def read_file(self, path: Path | str) -> bytes: # pyrit-async-suffix-exempt + """ + Read a file from storage (deprecated alias of ``read_file_async``). + + Args: + path (Union[Path, str]): The path to the file. + + Returns: + bytes: The content of the file. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.storage.StorageIO.read_file", + new_item="pyrit.memory.storage.storage.StorageIO.read_file_async", + removed_in="0.16.0", + ) + return await self.read_file_async(path) + + async def write_file(self, path: Path | str, data: bytes) -> None: # pyrit-async-suffix-exempt + """ + Write data to storage (deprecated alias of ``write_file_async``). + + Args: + path (Union[Path, str]): The path to the file. + data (bytes): The content to write to the file. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.storage.StorageIO.write_file", + new_item="pyrit.memory.storage.storage.StorageIO.write_file_async", + removed_in="0.16.0", + ) + await self.write_file_async(path, data) + + async def path_exists(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt + """ + Check whether a path exists (deprecated alias of ``path_exists_async``). + + Args: + path (Union[Path, str]): The path to check. + + Returns: + bool: True if the path exists, False otherwise. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.storage.StorageIO.path_exists", + new_item="pyrit.memory.storage.storage.StorageIO.path_exists_async", + removed_in="0.16.0", + ) + return await self.path_exists_async(path) + + async def is_file(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt + """ + Check whether the given path is a file (deprecated alias of ``is_file_async``). + + Args: + path (Union[Path, str]): The path to check. + + Returns: + bool: True if the path is a file, False otherwise. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.storage.StorageIO.is_file", + new_item="pyrit.memory.storage.storage.StorageIO.is_file_async", + removed_in="0.16.0", + ) + return await self.is_file_async(path) + + async def create_directory_if_not_exists(self, path: Path | str) -> None: # pyrit-async-suffix-exempt + """ + Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). + + Args: + path (Union[Path, str]): The directory path to create. + """ + print_deprecation_message( + old_item="pyrit.memory.storage.storage.StorageIO.create_directory_if_not_exists", + new_item="pyrit.memory.storage.storage.StorageIO.create_directory_if_not_exists_async", + removed_in="0.16.0", + ) + await self.create_directory_if_not_exists_async(path) + + +class DiskStorageIO(StorageIO): + """ + Implementation of StorageIO for local disk storage. + """ + + async def read_file_async(self, path: Path | str) -> bytes: + """ + Asynchronously reads a file from the local disk. + + Args: + path (Union[Path, str]): The path to the file. + + Returns: + bytes: The content of the file. + + """ + path = self._convert_to_path(path) + async with aiofiles.open(path, "rb") as file: + return await file.read() + + async def write_file_async(self, path: Path | str, data: bytes) -> None: + """ + Asynchronously writes data to a file on the local disk. + + Args: + path (Path): The path to the file. + data (bytes): The content to write to the file. + + """ + path = self._convert_to_path(path) + async with aiofiles.open(path, "wb") as file: + await file.write(data) + + async def path_exists_async(self, path: Path | str) -> bool: + """ + Check whether a path exists on the local disk. + + Args: + path (Path): The path to check. + + Returns: + bool: True if the path exists, False otherwise. + + """ + path = self._convert_to_path(path) + return path.exists() + + async def is_file_async(self, path: Path | str) -> bool: + """ + Check whether the given path is a file (not a directory). + + Args: + path (Path): The path to check. + + Returns: + bool: True if the path is a file, False otherwise. + + """ + path = self._convert_to_path(path) + return path.is_file() + + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: + """ + Asynchronously creates a directory if it doesn't exist on the local disk. + + Args: + path (Path): The directory path to create. + + """ + directory_path = self._convert_to_path(path) + if not directory_path.exists(): + directory_path.mkdir(parents=True, exist_ok=True) + + def _convert_to_path(self, path: Path | str) -> Path: + """ + Convert an input path to a Path object. + + Args: + path (Union[Path, str]): Input path value. + + Returns: + Path: Normalized Path instance. + + """ + return Path(path) if isinstance(path, str) else path + + +class AzureBlobStorageIO(StorageIO): + """ + Implementation of StorageIO for Azure Blob Storage. + """ + + def __init__( + self, + *, + container_url: str | None = None, + sas_token: str | None = None, + blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, + ) -> None: + """ + Initialize an Azure Blob Storage I/O adapter. + + Args: + container_url (Optional[str]): Azure Blob container URL. + sas_token (Optional[str]): Optional SAS token. + blob_content_type (SupportedContentType): Blob content type for uploads. + + Raises: + ValueError: If container_url is missing. + + """ + self._blob_content_type: str = blob_content_type.value + if not container_url: + raise ValueError("Invalid Azure Storage Account Container URL.") + + self._container_url: str = container_url + self._sas_token = sas_token + self._client_async: AsyncContainerClient | None = None + + async def _create_container_client_async(self) -> AsyncContainerClient: + """ + Create an asynchronous ContainerClient for Azure Storage. + + If a SAS token is provided via the + AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used + for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. + + Returns: + AsyncContainerClient: The initialized container client. + """ + from azure.storage.blob.aio import ContainerClient as AsyncContainerClient + + from pyrit.auth import AzureStorageAuth + + sas_token = self._sas_token + if not self._sas_token: + logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") + sas_token = await AzureStorageAuth.get_sas_token_async(self._container_url) + + self._client_async = AsyncContainerClient.from_container_url( + container_url=self._container_url, + credential=sas_token, + ) + return self._client_async + + async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None: + """ + (Async) Handles uploading blob to given storage container. + + Args: + file_name (str): File name to assign to uploaded blob. + data (bytes): Byte representation of content to upload to container. + content_type (str): Content type to upload. + + Raises: + RuntimeError: If the Azure container client is not initialized. + """ + from azure.core.exceptions import ClientAuthenticationError + from azure.storage.blob import ContentSettings + + content_settings = ContentSettings(content_type=f"{content_type}") + logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) + + try: + if self._client_async is None: + raise RuntimeError("Azure container client not initialized") + await self._client_async.upload_blob( + name=file_name, + data=data, + content_settings=content_settings, + overwrite=True, + ) + except Exception as exc: + if isinstance(exc, ClientAuthenticationError): + logger.exception( + msg="Authentication failed. Please check that the container existence in the " + "Azure Storage Account and ensure the validity of the provided SAS token. If you " + "haven't set the SAS token as an environment variable use `az login` to " + "enable delegation-based SAS authentication to connect to the storage account" + ) + raise + logger.exception(msg=f"An unexpected error occurred: {exc}") + raise + + def parse_blob_url(self, file_path: str) -> tuple[str, str]: + """ + Parse a blob URL to extract the container and blob name. + + Args: + file_path (str): Full blob URL. + + Returns: + tuple[str, str]: Container name and blob name. + + Raises: + ValueError: If file_path is not a valid blob URL. + + """ + parsed_url = urlparse(file_path) + if parsed_url.scheme and parsed_url.netloc: + container_name = parsed_url.path.split("/")[1] + blob_name = "/".join(parsed_url.path.split("/")[2:]) + return container_name, blob_name + raise ValueError("Invalid blob URL") + + def _resolve_blob_name(self, path: Path | str) -> str: + """ + Resolve a blob name from either a full blob URL or a relative blob path. + + When a full URL is provided the blob name is extracted from it. The container + name embedded in the URL is intentionally discarded — operations always run + against the container configured in the constructor. + + Backslashes are normalized to forward slashes so that ``Path`` objects + created on Windows still produce valid blob names. + + Args: + path (Union[Path, str]): Blob URL or relative blob path. + + Returns: + str: The resolved blob name. + + """ + path_str = str(path).replace("\\", "/") + try: + # parse_blob_url validates scheme + netloc internally + _, blob_name = self.parse_blob_url(path_str) + return blob_name + except ValueError: + return path_str + + async def read_file_async(self, path: Path | str) -> bytes: + """ + Asynchronously reads the content of a file (blob) from Azure Blob Storage. + + If the provided ``path`` is a full URL + (e.g., ``https://account.blob.core.windows.net/container/dir1/dir2/sample.png``), + it extracts the relative blob path (e.g., ``dir1/dir2/sample.png``) to correctly access the blob. + If a relative path is provided, it will use it as-is. + + Args: + path (str): The path to the file (blob) in Azure Blob Storage. + This can be either a full URL or a relative path. + + Returns: + bytes: The content of the file (blob) as bytes. + + Example: + ``file_content = await read_file_async("https://account.blob.core.windows.net/container/dir2/1726627689003831.png")`` + + Or using a relative path: + + ``file_content = await read_file_async("dir1/dir2/1726627689003831.png")`` + + """ + if not self._client_async: + self._client_async = await self._create_container_client_async() + + blob_name = self._resolve_blob_name(path) + + try: + blob_client = self._client_async.get_blob_client(blob=blob_name) + + # Download the blob + blob_stream = await blob_client.download_blob() + return bytes(await blob_stream.readall()) # type: ignore[ty:invalid-argument-type] + + except Exception as exc: + logger.exception(f"Failed to read file at {blob_name}: {exc}") + raise + finally: + await self._client_async.close() + self._client_async = None + + async def write_file_async(self, path: Path | str, data: bytes) -> None: + """ + Write data to Azure Blob Storage at the specified path. + + If the provided ``path`` is a full URL, the blob name is extracted from it. + If a relative path is provided, it is used as the blob name directly. + + Args: + path (Union[Path, str]): Full blob URL or relative blob path. + data (bytes): The data to write. + """ + if not self._client_async: + self._client_async = await self._create_container_client_async() + blob_name = self._resolve_blob_name(path) + try: + await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) + except Exception as exc: + logger.exception(f"Failed to write file at {blob_name}: {exc}") + raise + finally: + await self._client_async.close() + self._client_async = None + + async def path_exists_async(self, path: Path | str) -> bool: + """ + Check whether a given path exists in the Azure Blob Storage container. + + Args: + path (Union[Path, str]): Blob URL or path to test. + + Returns: + bool: True when the path exists. + """ + from azure.core.exceptions import ResourceNotFoundError + + if not self._client_async: + self._client_async = await self._create_container_client_async() + try: + blob_name = self._resolve_blob_name(path) + blob_client = self._client_async.get_blob_client(blob=blob_name) + await blob_client.get_blob_properties() + return True + except ResourceNotFoundError: + return False + finally: + await self._client_async.close() + self._client_async = None + + async def is_file_async(self, path: Path | str) -> bool: + """ + Check whether the path refers to a file (blob) in Azure Blob Storage. + + Args: + path (Union[Path, str]): Blob URL or path to test. + + Returns: + bool: True when the blob exists and has non-zero content size. + """ + from azure.core.exceptions import ResourceNotFoundError + + if not self._client_async: + self._client_async = await self._create_container_client_async() + try: + blob_name = self._resolve_blob_name(path) + blob_client = self._client_async.get_blob_client(blob=blob_name) + blob_properties = await blob_client.get_blob_properties() + return bool(blob_properties.size > 0) + except ResourceNotFoundError: + return False + finally: + await self._client_async.close() + self._client_async = None + + async def create_directory_if_not_exists_async(self, directory_path: Path | str) -> None: # type: ignore[ty:invalid-method-override] + """ + Log a no-op directory creation for Azure Blob Storage. + + Args: + directory_path (Union[Path, str]): Requested directory path. + + """ + logger.info( + f"Directory creation is handled automatically during upload operations in Azure Blob Storage. " + f"Directory path: {directory_path}" + ) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index d3fcaade31..30ffcafb6b 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -9,14 +9,14 @@ import aiofiles from pyrit.common.data_url_converter import convert_local_image_to_data_url_async +from pyrit.memory import DataTypeSerializer from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, SystemMessageBehavior, apply_system_message_behavior_async, ) -from pyrit.models import ChatMessage, DataTypeSerializer, Message -from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models import ChatMessage, Message, MessagePiece if TYPE_CHECKING: from pyrit.models.literals import ChatMessageRole diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 2a9fb9aec1..6ad07cfcc9 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -17,6 +17,7 @@ a deprecation shim through ``0.16.0``. """ +import importlib from typing import TYPE_CHECKING, Any from pyrit.common.deprecation import print_deprecation_message @@ -27,17 +28,6 @@ ) from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.conversation_stats import ConversationStats -from pyrit.models.data_type_serializer import ( - AllowedCategories, - AudioPathDataTypeSerializer, - BinaryPathDataTypeSerializer, - DataTypeSerializer, - ErrorDataTypeSerializer, - ImagePathDataTypeSerializer, - TextDataTypeSerializer, - VideoPathDataTypeSerializer, - data_serializer_factory, -) from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions from pyrit.models.identifiers import ( @@ -102,10 +92,6 @@ SimulatedTargetSystemPromptPaths, ) -# Keep old module-level imports working (deprecated, will be removed) -# These are re-exported from the seeds submodule -from pyrit.models.storage_io import AzureBlobStorageIO, DiskStorageIO, StorageIO - __all__ = [ "ALLOWED_CHAT_MESSAGE_ROLES", "AllowedCategories", @@ -204,6 +190,24 @@ "ScorerIdentifier": ComponentIdentifier, } +# Names that moved to ``pyrit.memory.storage``. Served lazily via importlib so that +# importing ``pyrit.models`` stays import-boundary clean and fires no warning until a +# moved name is actually accessed. Will be removed in 0.17.0. +_MOVED_TO_MEMORY_STORAGE: dict[str, str] = { + "AllowedCategories": "pyrit.memory.storage.serializers", + "AudioPathDataTypeSerializer": "pyrit.memory.storage.serializers", + "BinaryPathDataTypeSerializer": "pyrit.memory.storage.serializers", + "DataTypeSerializer": "pyrit.memory.storage.serializers", + "ErrorDataTypeSerializer": "pyrit.memory.storage.serializers", + "ImagePathDataTypeSerializer": "pyrit.memory.storage.serializers", + "TextDataTypeSerializer": "pyrit.memory.storage.serializers", + "VideoPathDataTypeSerializer": "pyrit.memory.storage.serializers", + "data_serializer_factory": "pyrit.memory.storage.serializers", + "AzureBlobStorageIO": "pyrit.memory.storage.storage", + "DiskStorageIO": "pyrit.memory.storage.storage", + "StorageIO": "pyrit.memory.storage.storage", +} + _warned: set[str] = set() @@ -218,4 +222,14 @@ def __getattr__(name: str) -> Any: ) _warned.add(name) return target + if name in _MOVED_TO_MEMORY_STORAGE: + target_module = _MOVED_TO_MEMORY_STORAGE[name] + if name not in _warned: + print_deprecation_message( + old_item=f"{__name__}.{name}", + new_item=f"{target_module}.{name}", + removed_in="0.17.0", + ) + _warned.add(name) + return getattr(importlib.import_module(target_module), name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 2cf9e6593d..a2659204a3 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -1,749 +1,39 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations - -import abc -import asyncio -import base64 -import hashlib -import tempfile -import time -import wave -from mimetypes import guess_type -from pathlib import Path -from typing import TYPE_CHECKING, Literal, get_args -from urllib.parse import urlparse - -import aiofiles - -from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.path import DB_DATA_PATH -from pyrit.models.storage_io import DiskStorageIO, StorageIO - -if TYPE_CHECKING: - from pyrit.memory import MemoryInterface - from pyrit.models.literals import PromptDataType - -# Define allowed categories for validation -AllowedCategories = Literal["seed-prompt-entries", "prompt-memory-entries"] - - -def _write_wav_sync( - path: str, - *, - num_channels: int, - sample_width: int, - sample_rate: int, - data: bytes, -) -> None: - """Write PCM audio bytes to a WAV file synchronously.""" - with wave.open(path, "wb") as wav_file: - wav_file.setnchannels(num_channels) - wav_file.setsampwidth(sample_width) - wav_file.setframerate(sample_rate) - wav_file.writeframes(data) - - -def data_serializer_factory( - *, - data_type: PromptDataType, - value: str | None = None, - extension: str | None = None, - category: AllowedCategories, -) -> DataTypeSerializer: - """ - Create a DataTypeSerializer instance. - - Args: - data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path'). - value (str): The data value to be serialized. - extension (str | None): The file extension, if applicable. - category (AllowedCategories): The category or context for the data (e.g., 'seed-prompt-entries'). - - Returns: - DataTypeSerializer: An instance of the appropriate serializer. - - Raises: - ValueError: If the category is not provided or invalid. - - """ - if not category: - raise ValueError( - f"The 'category' argument is mandatory and must be one of the following: {get_args(AllowedCategories)}." - ) - if value is not None: - if data_type in ["text", "reasoning", "function_call", "tool_call", "function_call_output"]: - return TextDataTypeSerializer(prompt_text=value, data_type=data_type) - if data_type == "image_path": - return ImagePathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "audio_path": - return AudioPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "video_path": - return VideoPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "binary_path": - return BinaryPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "error": - return ErrorDataTypeSerializer(prompt_text=value) - if data_type == "url": - return URLDataTypeSerializer(category=category, prompt_text=value, extension=extension) - raise ValueError(f"Data type {data_type} not supported") - if data_type == "image_path": - return ImagePathDataTypeSerializer(category=category, extension=extension) - if data_type == "audio_path": - return AudioPathDataTypeSerializer(category=category, extension=extension) - if data_type == "video_path": - return VideoPathDataTypeSerializer(category=category, extension=extension) - if data_type == "binary_path": - return BinaryPathDataTypeSerializer(category=category, extension=extension) - if data_type == "error": - return ErrorDataTypeSerializer(prompt_text="") - raise ValueError(f"Data type {data_type} without prompt text not supported") - - -class DataTypeSerializer(abc.ABC): - """ - Abstract base class for data type normalizers. - - Responsible for reading and saving multi-modal data types to local disk or Azure Storage Account. - """ - - data_type: PromptDataType - value: str - category: str - data_sub_directory: str - file_extension: str - - _file_path: Path | str | None = None - - @property - def _memory(self) -> MemoryInterface: - from pyrit.memory import CentralMemory - - return CentralMemory.get_memory_instance() - - def _get_storage_io(self) -> StorageIO: - """ - Retrieve the input datasets storage handle. - - Returns: - StorageIO: An instance of DiskStorageIO or AzureBlobStorageIO based on the storage configuration. - - Raises: - ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set. - RuntimeError: If results_storage_io is not configured but Azure storage URL was detected. - - """ - if self._is_azure_storage_url(self.value): - # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact - # with an Azure Storage Account, ex., XPIAWorkflow. - if self._memory.results_storage_io is None: - raise RuntimeError("results_storage_io is not configured but Azure storage URL was detected") - return self._memory.results_storage_io - return DiskStorageIO() - - @abc.abstractmethod - def data_on_disk(self) -> bool: - """ - Indicate whether the data is stored on disk. - - Returns: - bool: True when data is persisted on disk. - - """ - - async def save_data_async(self, data: bytes, output_filename: str | None = None) -> None: - """ - Save data to storage. - - Arguments: - data: bytes: The data to be saved. - output_filename (optional, str): filename to store data as. Defaults to UUID if not provided - - Raises: - RuntimeError: If storage IO is not initialized. - """ - file_path = await self.get_data_filename_async(file_name=output_filename) - if self._memory.results_storage_io is None: - raise RuntimeError("Storage IO not initialized") - await self._memory.results_storage_io.write_file_async(file_path, data) - self.value = str(file_path) - - async def save_b64_image_async(self, data: str | bytes, output_filename: str | None = None) -> None: - """ - Save a base64-encoded image to storage. - - Arguments: - data: string or bytes with base64 data - output_filename (optional, str): filename to store image as. Defaults to UUID if not provided - - Raises: - RuntimeError: If storage IO is not initialized. - """ - file_path = await self.get_data_filename_async(file_name=output_filename) - image_bytes = base64.b64decode(data) - if self._memory.results_storage_io is None: - raise RuntimeError("Storage IO not initialized") - await self._memory.results_storage_io.write_file_async(file_path, image_bytes) - self.value = str(file_path) - - async def save_formatted_audio_async( - self, - data: bytes, - num_channels: int = 1, - sample_width: int = 2, - sample_rate: int = 16000, - output_filename: str | None = None, - ) -> None: - """ - Save PCM16 or similarly formatted audio data to storage. - - Arguments: - data: bytes with audio data - output_filename (optional, str): filename to store audio as. Defaults to UUID if not provided - num_channels (optional, int): number of channels in audio data. Defaults to 1 - sample_width (optional, int): sample width in bytes. Defaults to 2 - sample_rate (optional, int): sample rate in Hz. Defaults to 16000 - - Raises: - RuntimeError: If storage IO is not initialized. - """ - file_path = await self.get_data_filename_async(file_name=output_filename) - - # save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters - if self._is_azure_storage_url(str(file_path)): - with tempfile.NamedTemporaryFile(suffix=".wav", dir=DB_DATA_PATH, delete=False) as tmp: - local_temp_path = Path(tmp.name) - try: - await asyncio.to_thread( - _write_wav_sync, - str(local_temp_path), - num_channels=num_channels, - sample_width=sample_width, - sample_rate=sample_rate, - data=data, - ) - async with aiofiles.open(local_temp_path, "rb") as f: - audio_data = await f.read() - if self._memory.results_storage_io is None: - raise RuntimeError("self._memory.results_storage_io is not initialized") - await self._memory.results_storage_io.write_file_async(file_path, audio_data) - finally: - local_temp_path.unlink(missing_ok=True) - # If local, we can just save straight to disk and do not need to delete temp file after - else: - await asyncio.to_thread( - _write_wav_sync, - str(file_path), - num_channels=num_channels, - sample_width=sample_width, - sample_rate=sample_rate, - data=data, - ) - - self.value = str(file_path) - - async def read_data_async(self) -> bytes: - """ - Read data from storage. - - Returns: - bytes: The data read from storage. - - Raises: - TypeError: If the serializer does not represent on-disk data. - RuntimeError: If no value is set. - FileNotFoundError: If the referenced file does not exist. - - """ - if not self.data_on_disk(): - raise TypeError(f"Data for data Type {self.data_type} is not stored on disk") - - if not self.value: - raise RuntimeError("Prompt text not set") - - storage_io = self._get_storage_io() - # Check if path exists - file_exists = await storage_io.path_exists_async(path=self.value) - if not file_exists: - raise FileNotFoundError(f"File not found: {self.value}") - # Read the contents from the path - return await storage_io.read_file_async(self.value) - - async def read_data_base64_async(self) -> str: - """ - Read data from storage and return it as a base64 string. - - Returns: - str: Base64-encoded data. - - """ - byte_array = await self.read_data_async() - return base64.b64encode(byte_array).decode("utf-8") - - async def get_sha256_async(self) -> str: - """ - Compute SHA256 hash for this serializer's current value. - - Returns: - str: Hex digest of the computed SHA256 hash. - - Raises: - FileNotFoundError: If on-disk data path does not exist. - ValueError: If in-memory data cannot be converted to bytes. - - """ - input_bytes: bytes | None = None - - if self.data_on_disk(): - storage_io = self._get_storage_io() - file_exists = await storage_io.path_exists_async(self.value) - if not file_exists: - raise FileNotFoundError(f"File not found: {self.value}") - - # Read the data from storage - input_bytes = await storage_io.read_file_async(self.value) - else: - if isinstance(self.value, str): - input_bytes = self.value.encode("utf-8") - else: - raise ValueError(f"Invalid data type {self.value}, expected str data type.") - - hash_object = hashlib.sha256(input_bytes) - return hash_object.hexdigest() - - async def get_data_filename_async(self, file_name: str | None = None) -> Path | str: - """ - Generate or retrieve a unique filename for the data file. - - Args: - file_name (str | None): Optional file name override. - - Returns: - Path | str: Full storage path for the generated data file. - - Raises: - TypeError: If the serializer is not configured for on-disk data. - RuntimeError: If required data subdirectory information is missing. - - """ - if self._file_path: - return self._file_path - - if not self.data_on_disk(): - raise TypeError("Data is not stored on disk") - - if not self.data_sub_directory: - raise RuntimeError("Data sub directory not set") - - ticks = int(time.time() * 1_000_000) - if self._memory.results_path: - results_path = str(self._memory.results_path) - else: - from pyrit.common.path import DB_DATA_PATH - - results_path = str(DB_DATA_PATH) - file_name = file_name if file_name else str(ticks) - - if self._is_azure_storage_url(results_path): - full_data_directory_path = results_path + self.data_sub_directory - self._file_path = full_data_directory_path + f"/{file_name}.{self.file_extension}" - else: - full_data_directory_path = results_path + self.data_sub_directory - if self._memory.results_storage_io is None: - raise RuntimeError("self._memory.results_storage_io is not initialized") - await self._memory.results_storage_io.create_directory_if_not_exists_async(Path(full_data_directory_path)) - self._file_path = Path(full_data_directory_path, f"{file_name}.{self.file_extension}") - - return self._file_path - - async def save_data( # pyrit-async-suffix-exempt - self, data: bytes, output_filename: str | None = None - ) -> None: - """ - Save data to storage (deprecated alias of ``save_data_async``). - - Args: - data: The data to be saved. - output_filename: Optional filename to store data as. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_data", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_data_async", - removed_in="0.16.0", - ) - await self.save_data_async(data, output_filename) - - async def save_b64_image( # pyrit-async-suffix-exempt - self, data: str | bytes, output_filename: str | None = None - ) -> None: - """ - Save a base64-encoded image to storage (deprecated alias of ``save_b64_image_async``). - - Args: - data: String or bytes with base64 data. - output_filename: Optional filename to store image as. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_b64_image", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_b64_image_async", - removed_in="0.16.0", - ) - await self.save_b64_image_async(data, output_filename) - - async def save_formatted_audio( # pyrit-async-suffix-exempt - self, - data: bytes, - num_channels: int = 1, - sample_width: int = 2, - sample_rate: int = 16000, - output_filename: str | None = None, - ) -> None: - """ - Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). - - Args: - data: Audio data bytes. - num_channels: Number of channels in audio data. - sample_width: Sample width in bytes. - sample_rate: Sample rate in Hz. - output_filename: Optional filename to store audio as. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_formatted_audio", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_formatted_audio_async", - removed_in="0.16.0", - ) - await self.save_formatted_audio_async(data, num_channels, sample_width, sample_rate, output_filename) - - async def read_data(self) -> bytes: # pyrit-async-suffix-exempt - """ - Read data from storage (deprecated alias of ``read_data_async``). - - Returns: - bytes: The data read from storage. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data_async", - removed_in="0.16.0", - ) - return await self.read_data_async() - - async def read_data_base64(self) -> str: # pyrit-async-suffix-exempt - """ - Read data and return it as a base64 string (deprecated alias of ``read_data_base64_async``). - - Returns: - str: Base64-encoded data. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data_base64", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data_base64_async", - removed_in="0.16.0", - ) - return await self.read_data_base64_async() +""" +Deprecation shim — the data-type serializers now live in +``pyrit.memory.storage``. - async def get_sha256(self) -> str: # pyrit-async-suffix-exempt - """ - Compute SHA256 hash for this serializer's current value (deprecated alias of ``get_sha256_async``). +Importing names from ``pyrit.models.data_type_serializer`` still works for one +release but emits a one-time ``DeprecationWarning`` per name. Import from +``pyrit.memory.storage`` instead. This shim will be removed in 0.17.0. +""" - Returns: - str: Hex digest of the computed SHA256 hash. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_sha256", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_sha256_async", - removed_in="0.16.0", - ) - return await self.get_sha256_async() - - async def get_data_filename( # pyrit-async-suffix-exempt - self, file_name: str | None = None - ) -> Path | str: - """ - Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). - - Args: - file_name: Optional file name override. - - Returns: - Path | str: Full storage path for the generated data file. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_data_filename", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_data_filename_async", - removed_in="0.16.0", - ) - return await self.get_data_filename_async(file_name) - - @staticmethod - def get_extension(file_path: str) -> str | None: - """ - Get the file extension from the file path. - - Args: - file_path (str): Input file path. - - Returns: - str | None: File extension (including dot) or None if unavailable. - - """ - ext = Path(file_path).suffix - return ext or None - - @staticmethod - def get_mime_type(file_path: str) -> str | None: - """ - Get the MIME type of the file path. - - Args: - file_path (str): Input file path. - - Returns: - str | None: MIME type if detectable; otherwise None. - - """ - mime_type, _ = guess_type(file_path) - return mime_type - - def _is_azure_storage_url(self, path: str) -> bool: - """ - Validate whether the given path is an Azure Storage URL. - - Args: - path (str): Path or URL to check. - - Returns: - bool: True if the path is an Azure Blob Storage URL. - - """ - parsed = urlparse(path) - return parsed.scheme in ("http", "https") and "blob.core.windows.net" in parsed.netloc - - -class TextDataTypeSerializer(DataTypeSerializer): - """Serializer for text and text-like prompt values that stay in-memory.""" - - def __init__(self, *, prompt_text: str, data_type: PromptDataType = "text") -> None: - """ - Initialize a text serializer. - - Args: - prompt_text (str): Prompt value. - data_type (PromptDataType): Text-like prompt data type. - - """ - self.data_type = data_type - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always False for text serializers. - - """ - return False - - -class ErrorDataTypeSerializer(DataTypeSerializer): - """Serializer for error payloads stored as in-memory text.""" - - def __init__(self, *, prompt_text: str) -> None: - """ - Initialize an error serializer. - - Args: - prompt_text (str): Error payload text. - - """ - self.data_type = "error" - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always False for error serializers. - - """ - return False - - -class URLDataTypeSerializer(DataTypeSerializer): - """Serializer for URL values and URL-backed local file references.""" - - def __init__(self, *, category: str, prompt_text: str, extension: str | None = None) -> None: - """ - Initialize a URL serializer. - - Args: - category (str): Data category folder name. - prompt_text (str): URL or path value. - extension (str | None): Optional extension for persisted content. - - """ - self.data_type = "url" - self.value = prompt_text - self.data_sub_directory = f"/{category}/urls" - self.file_extension = extension if extension else "txt" - self.on_disk = not (prompt_text.startswith(("http://", "https://"))) - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: True for non-http values, False for URL values. - - """ - return self.on_disk - - -class ImagePathDataTypeSerializer(DataTypeSerializer): - """Serializer for image path values stored on disk.""" - - def __init__(self, *, category: str, prompt_text: str | None = None, extension: str | None = None) -> None: - """ - Initialize an image-path serializer. - - Args: - category (str): Data category folder name. - prompt_text (str | None): Optional existing image path. - extension (str | None): Optional image extension. - - """ - self.data_type = "image_path" - self.data_sub_directory = f"/{category}/images" - self.file_extension = extension if extension else "png" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for image path serializers. - - """ - return True - - -class AudioPathDataTypeSerializer(DataTypeSerializer): - """Serializer for audio path values stored on disk.""" - - def __init__( - self, - *, - category: str, - prompt_text: str | None = None, - extension: str | None = None, - ) -> None: - """ - Initialize an audio-path serializer. - - Args: - category (str): Data category folder name. - prompt_text (str | None): Optional existing audio path. - extension (str | None): Optional audio extension. - - """ - self.data_type = "audio_path" - self.data_sub_directory = f"/{category}/audio" - self.file_extension = extension if extension else "mp3" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for audio path serializers. - - """ - return True - - -class VideoPathDataTypeSerializer(DataTypeSerializer): - """Serializer for video path values stored on disk.""" - - def __init__( - self, - *, - category: str, - prompt_text: str | None = None, - extension: str | None = None, - ) -> None: - """ - Initialize a video-path serializer. - - Args: - category (str): The category or context for the data. - prompt_text (str | None): The video path or identifier. - extension (str | None): The file extension, defaults to 'mp4'. - - """ - self.data_type = "video_path" - self.data_sub_directory = f"/{category}/videos" - self.file_extension = extension if extension else "mp4" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for video path serializers. - - """ - return True - - -class BinaryPathDataTypeSerializer(DataTypeSerializer): - """Serializer for generic binary path values stored on disk.""" - - def __init__( - self, - *, - category: str, - prompt_text: str | None = None, - extension: str | None = None, - ) -> None: - """ - Initialize a generic binary-path serializer. - - This serializer handles generic binary data that doesn't fit into specific - categories like images, audio, or video. Useful for XPIA attacks and - storing files like PDFs, documents, or other binary formats. - - Args: - category (str): The category or context for the data. - prompt_text (str | None): The binary file path or identifier. - extension (str | None): The file extension, defaults to 'bin'. - - """ - self.data_type = "binary_path" - self.data_sub_directory = f"/{category}/binaries" - self.file_extension = extension if extension else "bin" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for binary path serializers. +from __future__ import annotations - """ - return True +from pyrit.common.deprecation import module_deprecation_getattr + +__all__ = [ + "AllowedCategories", + "AudioPathDataTypeSerializer", + "BinaryPathDataTypeSerializer", + "DataTypeSerializer", + "data_serializer_factory", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", + "TextDataTypeSerializer", + "URLDataTypeSerializer", + "VideoPathDataTypeSerializer", +] + +__getattr__ = module_deprecation_getattr( + old_module="pyrit.models.data_type_serializer", + target_module="pyrit.memory.storage.serializers", + names=__all__, + removed_in="0.17.0", +) + + +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 728f736f20..1ad0533231 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -17,7 +17,6 @@ ) from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) ChatMessageRole, PromptDataType, @@ -306,22 +305,19 @@ async def set_sha256_values_async(self) -> None: """ Compute SHA256 hash values for original and converted payloads. - Async because blob payloads may need to be fetched. Must be called - explicitly after construction. + .. deprecated:: 0.15.0 + Use ``pyrit.memory.storage.serializers.set_message_piece_sha256_async`` instead. + This method will be removed in 0.17.0. """ - original_serializer = data_serializer_factory( - category="prompt-memory-entries", - data_type=self.original_value_data_type, - value=self.original_value, - ) - self.original_value_sha256 = await original_serializer.get_sha256_async() + import importlib - converted_serializer = data_serializer_factory( - category="prompt-memory-entries", - data_type=self.converted_value_data_type, - value=self.converted_value, + print_deprecation_message( + old_item="pyrit.models.messages.message_piece.MessagePiece.set_sha256_values_async", + new_item="pyrit.memory.storage.serializers.set_message_piece_sha256_async", + removed_in="0.17.0", ) - self.converted_value_sha256 = await converted_serializer.get_sha256_async() + serializers = importlib.import_module("pyrit.memory.storage.serializers") + await serializers.set_message_piece_sha256_async(self) def sort_message_pieces(message_pieces: list[MessagePiece]) -> list[MessagePiece]: diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 353d5313c7..c8ee9588bf 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -232,19 +232,22 @@ def render_template_value_silent(self, **kwargs: Any) -> str: async def set_sha256_value_async(self) -> None: """ Compute the SHA256 hash value asynchronously. - It should be called after prompt `value` is serialized to text, - as file paths used in the `value` may have changed from local to memory storage paths. - Note, this method is async due to the blob retrieval. And because of that, we opted - to take it out of main and setter functions. The disadvantage is that it must be explicitly called. + .. deprecated:: 0.15.0 + Use ``pyrit.memory.storage.serializers.set_seed_sha256_async`` instead. + This method will be removed in 0.17.0. """ - from pyrit.models.data_type_serializer import data_serializer_factory + import importlib - original_serializer = data_serializer_factory( - category="seed-prompt-entries", data_type=self.data_type, value=self.value - ) + from pyrit.common.deprecation import print_deprecation_message - self.value_sha256 = await original_serializer.get_sha256_async() + print_deprecation_message( + old_item="pyrit.models.seeds.seed.Seed.set_sha256_value_async", + new_item="pyrit.memory.storage.serializers.set_seed_sha256_async", + removed_in="0.17.0", + ) + serializers = importlib.import_module("pyrit.memory.storage.serializers") + await serializers.set_seed_sha256_async(self) @staticmethod def escape_for_jinja(value: str) -> str: diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 9656211be8..ecb3e5b4f6 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -15,7 +15,6 @@ from tinytag import TinyTag from pyrit.common.path import PATHS_DICT -from pyrit.models.data_type_serializer import DataTypeSerializer from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) ChatMessageRole, PromptDataType, @@ -106,7 +105,7 @@ def set_encoding_metadata(self) -> None: return if self.metadata is None: self.metadata = {} - extension = DataTypeSerializer.get_extension(self.value) + extension = Path(self.value).suffix or None if extension: extension = extension.lstrip(".") self.metadata.update({"format": extension}) diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 15a7068c33..ba4b284e44 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -1,507 +1,33 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING -from urllib.parse import urlparse - -import aiofiles - -from pyrit.common.deprecation import print_deprecation_message - -if TYPE_CHECKING: - from azure.storage.blob.aio import ContainerClient as AsyncContainerClient - -logger = logging.getLogger(__name__) - - -class SupportedContentType(Enum): - """ - All supported content types for uploading blobs to provided storage account container. - See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml. - """ - - # TODO, add other media supported types - PLAIN_TEXT = "text/plain" - - -class StorageIO(ABC): - """ - Abstract interface for storage systems (local disk, Azure Storage Account, etc.). - """ - - @abstractmethod - async def read_file_async(self, path: Path | str) -> bytes: - """ - Asynchronously reads the file (or blob) from the given path. - """ - - @abstractmethod - async def write_file_async(self, path: Path | str, data: bytes) -> None: - """ - Asynchronously writes data to the given path. - """ - - @abstractmethod - async def path_exists_async(self, path: Path | str) -> bool: - """ - Asynchronously checks if a file or blob exists at the given path. - """ - - @abstractmethod - async def is_file_async(self, path: Path | str) -> bool: - """ - Asynchronously checks if the path refers to a file (not a directory or container). - """ - - @abstractmethod - async def create_directory_if_not_exists_async(self, path: Path | str) -> None: - """ - Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. - """ - - async def read_file(self, path: Path | str) -> bytes: # pyrit-async-suffix-exempt - """ - Read a file from storage (deprecated alias of ``read_file_async``). - - Args: - path (Path | str): The path to the file. - - Returns: - bytes: The content of the file. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.read_file", - new_item="pyrit.models.storage_io.StorageIO.read_file_async", - removed_in="0.16.0", - ) - return await self.read_file_async(path) - - async def write_file(self, path: Path | str, data: bytes) -> None: # pyrit-async-suffix-exempt - """ - Write data to storage (deprecated alias of ``write_file_async``). - - Args: - path (Path | str): The path to the file. - data (bytes): The content to write to the file. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.write_file", - new_item="pyrit.models.storage_io.StorageIO.write_file_async", - removed_in="0.16.0", - ) - await self.write_file_async(path, data) - - async def path_exists(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt - """ - Check whether a path exists (deprecated alias of ``path_exists_async``). - - Args: - path (Path | str): The path to check. - - Returns: - bool: True if the path exists, False otherwise. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.path_exists", - new_item="pyrit.models.storage_io.StorageIO.path_exists_async", - removed_in="0.16.0", - ) - return await self.path_exists_async(path) - - async def is_file(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt - """ - Check whether the given path is a file (deprecated alias of ``is_file_async``). - - Args: - path (Path | str): The path to check. - - Returns: - bool: True if the path is a file, False otherwise. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.is_file", - new_item="pyrit.models.storage_io.StorageIO.is_file_async", - removed_in="0.16.0", - ) - return await self.is_file_async(path) - - async def create_directory_if_not_exists(self, path: Path | str) -> None: # pyrit-async-suffix-exempt - """ - Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). - - Args: - path (Path | str): The directory path to create. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.create_directory_if_not_exists", - new_item="pyrit.models.storage_io.StorageIO.create_directory_if_not_exists_async", - removed_in="0.16.0", - ) - await self.create_directory_if_not_exists_async(path) - - -class DiskStorageIO(StorageIO): - """ - Implementation of StorageIO for local disk storage. - """ - - async def read_file_async(self, path: Path | str) -> bytes: - """ - Asynchronously reads a file from the local disk. - - Args: - path (Path | str): The path to the file. - - Returns: - bytes: The content of the file. - - """ - path = self._convert_to_path(path) - async with aiofiles.open(path, "rb") as file: - return await file.read() - - async def write_file_async(self, path: Path | str, data: bytes) -> None: - """ - Asynchronously writes data to a file on the local disk. - - Args: - path (Path): The path to the file. - data (bytes): The content to write to the file. - - """ - path = self._convert_to_path(path) - async with aiofiles.open(path, "wb") as file: - await file.write(data) - - async def path_exists_async(self, path: Path | str) -> bool: - """ - Check whether a path exists on the local disk. - - Args: - path (Path): The path to check. - - Returns: - bool: True if the path exists, False otherwise. - - """ - path = self._convert_to_path(path) - return path.exists() - - async def is_file_async(self, path: Path | str) -> bool: - """ - Check whether the given path is a file (not a directory). - - Args: - path (Path): The path to check. - - Returns: - bool: True if the path is a file, False otherwise. - - """ - path = self._convert_to_path(path) - return path.is_file() - - async def create_directory_if_not_exists_async(self, path: Path | str) -> None: - """ - Asynchronously creates a directory if it doesn't exist on the local disk. - - Args: - path (Path): The directory path to create. - - """ - directory_path = self._convert_to_path(path) - if not directory_path.exists(): - directory_path.mkdir(parents=True, exist_ok=True) - - def _convert_to_path(self, path: Path | str) -> Path: - """ - Convert an input path to a Path object. - - Args: - path (Path | str): Input path value. +""" +Deprecation shim — the storage I/O classes now live in +``pyrit.memory.storage``. - Returns: - Path: Normalized Path instance. +Importing names from ``pyrit.models.storage_io`` still works for one release but +emits a one-time ``DeprecationWarning`` per name. Import from +``pyrit.memory.storage`` instead. This shim will be removed in 0.17.0. +""" - """ - return Path(path) if isinstance(path, str) else path - - -class AzureBlobStorageIO(StorageIO): - """ - Implementation of StorageIO for Azure Blob Storage. - """ - - def __init__( - self, - *, - container_url: str | None = None, - sas_token: str | None = None, - blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, - ) -> None: - """ - Initialize an Azure Blob Storage I/O adapter. - - Args: - container_url (str | None): Azure Blob container URL. - sas_token (str | None): Optional SAS token. - blob_content_type (SupportedContentType): Blob content type for uploads. - - Raises: - ValueError: If container_url is missing. - - """ - self._blob_content_type: str = blob_content_type.value - if not container_url: - raise ValueError("Invalid Azure Storage Account Container URL.") - - self._container_url: str = container_url - self._sas_token = sas_token - self._client_async: AsyncContainerClient | None = None - - async def _create_container_client_async(self) -> AsyncContainerClient: - """ - Create an asynchronous ContainerClient for Azure Storage. - - If a SAS token is provided via the - AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used - for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. - - Returns: - AsyncContainerClient: The initialized container client. - """ - from azure.storage.blob.aio import ContainerClient as AsyncContainerClient - - from pyrit.auth import AzureStorageAuth - - sas_token = self._sas_token - if not self._sas_token: - logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") - sas_token = await AzureStorageAuth.get_sas_token_async(self._container_url) - - self._client_async = AsyncContainerClient.from_container_url( - container_url=self._container_url, - credential=sas_token, - ) - return self._client_async - - async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None: - """ - (Async) Handles uploading blob to given storage container. - - Args: - file_name (str): File name to assign to uploaded blob. - data (bytes): Byte representation of content to upload to container. - content_type (str): Content type to upload. - - Raises: - RuntimeError: If the Azure container client is not initialized. - """ - from azure.core.exceptions import ClientAuthenticationError - from azure.storage.blob import ContentSettings - - content_settings = ContentSettings(content_type=f"{content_type}") - logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) - - try: - if self._client_async is None: - raise RuntimeError("Azure container client not initialized") - await self._client_async.upload_blob( - name=file_name, - data=data, - content_settings=content_settings, - overwrite=True, - ) - except Exception as exc: - if isinstance(exc, ClientAuthenticationError): - logger.exception( - msg="Authentication failed. Please check that the container existence in the " - "Azure Storage Account and ensure the validity of the provided SAS token. If you " - "haven't set the SAS token as an environment variable use `az login` to " - "enable delegation-based SAS authentication to connect to the storage account" - ) - raise - logger.exception(msg=f"An unexpected error occurred: {exc}") - raise - - def parse_blob_url(self, file_path: str) -> tuple[str, str]: - """ - Parse a blob URL to extract the container and blob name. - - Args: - file_path (str): Full blob URL. - - Returns: - tuple[str, str]: Container name and blob name. - - Raises: - ValueError: If file_path is not a valid blob URL. - - """ - parsed_url = urlparse(file_path) - if parsed_url.scheme and parsed_url.netloc: - container_name = parsed_url.path.split("/")[1] - blob_name = "/".join(parsed_url.path.split("/")[2:]) - return container_name, blob_name - raise ValueError("Invalid blob URL") - - def _resolve_blob_name(self, path: Path | str) -> str: - """ - Resolve a blob name from either a full blob URL or a relative blob path. - - When a full URL is provided the blob name is extracted from it. The container - name embedded in the URL is intentionally discarded — operations always run - against the container configured in the constructor. - - Backslashes are normalized to forward slashes so that ``Path`` objects - created on Windows still produce valid blob names. - - Args: - path (Path | str): Blob URL or relative blob path. - - Returns: - str: The resolved blob name. - - """ - path_str = str(path).replace("\\", "/") - try: - # parse_blob_url validates scheme + netloc internally - _, blob_name = self.parse_blob_url(path_str) - return blob_name - except ValueError: - return path_str - - async def read_file_async(self, path: Path | str) -> bytes: - """ - Asynchronously reads the content of a file (blob) from Azure Blob Storage. - - If the provided ``path`` is a full URL - (e.g., ``https://account.blob.core.windows.net/container/dir1/dir2/sample.png``), - it extracts the relative blob path (e.g., ``dir1/dir2/sample.png``) to correctly access the blob. - If a relative path is provided, it will use it as-is. - - Args: - path (str): The path to the file (blob) in Azure Blob Storage. - This can be either a full URL or a relative path. - - Returns: - bytes: The content of the file (blob) as bytes. - - Example: - ``file_content = await read_file_async("https://account.blob.core.windows.net/container/dir2/1726627689003831.png")`` - - Or using a relative path: - - ``file_content = await read_file_async("dir1/dir2/1726627689003831.png")`` - - """ - if not self._client_async: - self._client_async = await self._create_container_client_async() - - blob_name = self._resolve_blob_name(path) - - try: - blob_client = self._client_async.get_blob_client(blob=blob_name) - - # Download the blob - blob_stream = await blob_client.download_blob() - return bytes(await blob_stream.readall()) - - except Exception as exc: - logger.exception(f"Failed to read file at {blob_name}: {exc}") - raise - finally: - await self._client_async.close() - self._client_async = None - - async def write_file_async(self, path: Path | str, data: bytes) -> None: - """ - Write data to Azure Blob Storage at the specified path. - - If the provided ``path`` is a full URL, the blob name is extracted from it. - If a relative path is provided, it is used as the blob name directly. - - Args: - path (Path | str): Full blob URL or relative blob path. - data (bytes): The data to write. - """ - if not self._client_async: - self._client_async = await self._create_container_client_async() - blob_name = self._resolve_blob_name(path) - try: - await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) - except Exception as exc: - logger.exception(f"Failed to write file at {blob_name}: {exc}") - raise - finally: - await self._client_async.close() - self._client_async = None - - async def path_exists_async(self, path: Path | str) -> bool: - """ - Check whether a given path exists in the Azure Blob Storage container. - - Args: - path (Path | str): Blob URL or path to test. - - Returns: - bool: True when the path exists. - """ - from azure.core.exceptions import ResourceNotFoundError - - if not self._client_async: - self._client_async = await self._create_container_client_async() - try: - blob_name = self._resolve_blob_name(path) - blob_client = self._client_async.get_blob_client(blob=blob_name) - await blob_client.get_blob_properties() - return True - except ResourceNotFoundError: - return False - finally: - await self._client_async.close() - self._client_async = None - - async def is_file_async(self, path: Path | str) -> bool: - """ - Check whether the path refers to a file (blob) in Azure Blob Storage. - - Args: - path (Path | str): Blob URL or path to test. +from __future__ import annotations - Returns: - bool: True when the blob exists and has non-zero content size. - """ - from azure.core.exceptions import ResourceNotFoundError +from pyrit.common.deprecation import module_deprecation_getattr - if not self._client_async: - self._client_async = await self._create_container_client_async() - try: - blob_name = self._resolve_blob_name(path) - blob_client = self._client_async.get_blob_client(blob=blob_name) - blob_properties = await blob_client.get_blob_properties() - return bool(blob_properties.size > 0) - except ResourceNotFoundError: - return False - finally: - await self._client_async.close() - self._client_async = None +__all__ = [ + "AzureBlobStorageIO", + "DiskStorageIO", + "StorageIO", + "SupportedContentType", +] - async def create_directory_if_not_exists_async(self, directory_path: Path | str) -> None: # type: ignore[ty:invalid-method-override] - """ - Log a no-op directory creation for Azure Blob Storage. +__getattr__ = module_deprecation_getattr( + old_module="pyrit.models.storage_io", + target_module="pyrit.memory.storage.storage", + names=__all__, + removed_in="0.17.0", +) - Args: - directory_path (Path | str): Requested directory path. - """ - logger.info( - f"Directory creation is handled automatically during upload operations in Azure Blob Storage. " - f"Directory path: {directory_path}" - ) +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/pyrit/output/conversation/pretty.py b/pyrit/output/conversation/pretty.py index 7af6250c1f..10aff03afb 100644 --- a/pyrit/output/conversation/pretty.py +++ b/pyrit/output/conversation/pretty.py @@ -338,7 +338,7 @@ async def _display_image_async(self, piece: MessagePiece) -> None: if not is_in_ipython_session(): return - from pyrit.models.data_type_serializer import ImagePathDataTypeSerializer + from pyrit.memory import ImagePathDataTypeSerializer try: serializer = ImagePathDataTypeSerializer(category="", prompt_text=piece.converted_value) diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index d4ee0809c8..2f5678b021 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -10,7 +10,8 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult @@ -192,7 +193,7 @@ def _load_font_at_size(self, size: int) -> FreeTypeFont: if self._font_load_failed: return cast("FreeTypeFont", ImageFont.load_default(size=size)) try: - return ImageFont.truetype(self._font_name, size) + return ImageFont.truetype(self._font_name, size) # type: ignore[ty:invalid-argument-type] except OSError: logger.warning(f"Cannot open font resource: {self._font_name}. Using Pillow built-in default font.") self._font_load_failed = True diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 73ccc9220d..74af2499f9 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -9,7 +9,8 @@ import numpy as np from pyrit.common.path import DB_DATA_PATH -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 57964c8e75..a4576f8e41 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -11,7 +11,8 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/audio_echo_converter.py b/pyrit/prompt_converter/audio_echo_converter.py index 73a40385d4..176b8fa219 100644 --- a/pyrit/prompt_converter/audio_echo_converter.py +++ b/pyrit/prompt_converter/audio_echo_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_frequency_converter.py b/pyrit/prompt_converter/audio_frequency_converter.py index 65050ea1f1..cada5a407e 100644 --- a/pyrit/prompt_converter/audio_frequency_converter.py +++ b/pyrit/prompt_converter/audio_frequency_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_speed_converter.py b/pyrit/prompt_converter/audio_speed_converter.py index 9a7a8053e6..1eb81538cd 100644 --- a/pyrit/prompt_converter/audio_speed_converter.py +++ b/pyrit/prompt_converter/audio_speed_converter.py @@ -9,7 +9,8 @@ from scipy.interpolate import interp1d from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_volume_converter.py b/pyrit/prompt_converter/audio_volume_converter.py index 40e8e2a340..5bc088af33 100644 --- a/pyrit/prompt_converter/audio_volume_converter.py +++ b/pyrit/prompt_converter/audio_volume_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_white_noise_converter.py b/pyrit/prompt_converter/audio_white_noise_converter.py index 63726ce356..187a0de019 100644 --- a/pyrit/prompt_converter/audio_white_noise_converter.py +++ b/pyrit/prompt_converter/audio_white_noise_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index d4c6698aae..3d4fd20260 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -13,7 +13,8 @@ from pyrit.auth.azure_auth import get_speech_config, get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index 98b107f7b3..66e42407a5 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -11,7 +11,8 @@ from pyrit.auth.azure_auth import get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/base_image_to_image_converter.py b/pyrit/prompt_converter/base_image_to_image_converter.py index 0351e9a6ad..87434ec0e4 100644 --- a/pyrit/prompt_converter/base_image_to_image_converter.py +++ b/pyrit/prompt_converter/base_image_to_image_converter.py @@ -11,7 +11,8 @@ import aiohttp from PIL import Image -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/image_compression_converter.py b/pyrit/prompt_converter/image_compression_converter.py index da8c01e91f..d250d002a1 100644 --- a/pyrit/prompt_converter/image_compression_converter.py +++ b/pyrit/prompt_converter/image_compression_converter.py @@ -10,7 +10,8 @@ import aiohttp from PIL import Image -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/image_overlay_converter.py b/pyrit/prompt_converter/image_overlay_converter.py index cd0d1d0c3a..ef416cfb02 100644 --- a/pyrit/prompt_converter/image_overlay_converter.py +++ b/pyrit/prompt_converter/image_overlay_converter.py @@ -7,8 +7,8 @@ from PIL import Image -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory -from pyrit.models.data_type_serializer import DataTypeSerializer +from pyrit.memory import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index fd75b05242..3b9dcfc12a 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -13,8 +13,8 @@ from reportlab.pdfgen import canvas from pyrit.common.logger import logger -from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt, data_serializer_factory -from pyrit.models.data_type_serializer import DataTypeSerializer +from pyrit.memory import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/qr_code_converter.py b/pyrit/prompt_converter/qr_code_converter.py index b0515fd92d..bfba824ca8 100644 --- a/pyrit/prompt_converter/qr_code_converter.py +++ b/pyrit/prompt_converter/qr_code_converter.py @@ -4,7 +4,8 @@ import segno -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/transparency_attack_converter.py b/pyrit/prompt_converter/transparency_attack_converter.py index c02f1be008..7301504450 100644 --- a/pyrit/prompt_converter/transparency_attack_converter.py +++ b/pyrit/prompt_converter/transparency_attack_converter.py @@ -10,7 +10,8 @@ import numpy as np from PIL import Image -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index b811ab1b3f..4766254c12 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -12,14 +12,14 @@ from docx import Document from pyrit.common.logger import logger -from pyrit.models import PromptDataType, SeedPrompt, data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter if TYPE_CHECKING: from pathlib import Path - from pyrit.models import ComponentIdentifier - from pyrit.models.data_type_serializer import DataTypeSerializer + from pyrit.memory import DataTypeSerializer + from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt @dataclass diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index c089930778..360dcde35c 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -19,7 +19,7 @@ execution_context, get_execution_context, ) -from pyrit.memory import CentralMemory, MemoryInterface +from pyrit.memory import CentralMemory, MemoryInterface, set_message_piece_sha256_async from pyrit.models import ( ComponentIdentifier, Message, @@ -377,7 +377,7 @@ async def convert_audio_async( async def _calc_hash_async(self, request: Message) -> None: """Add a request to the memory.""" - tasks = [asyncio.create_task(piece.set_sha256_values_async()) for piece in request.message_pieces] + tasks = [asyncio.create_task(set_message_piece_sha256_async(piece)) for piece in request.message_pieces] await asyncio.gather(*tasks) async def hash_and_persist_message_async(self, *, message: Message) -> None: diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 1fd2400ca9..9ff713ab86 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -13,15 +13,8 @@ PyritException, pyrit_target_retry, ) -from pyrit.models import ( - ChatMessage, - ComponentIdentifier, - DataTypeSerializer, - Message, - MessagePiece, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.memory import DataTypeSerializer, data_serializer_factory +from pyrit.models import ChatMessage, ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 5d03f985c3..749e47b672 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -11,12 +11,8 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.models import ( - ComponentIdentifier, - Message, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index e5ed421385..bb8c305f47 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -15,12 +15,8 @@ pyrit_target_retry, ) from pyrit.exceptions.exception_classes import ServerErrorException -from pyrit.models import ( - ComponentIdentifier, - Message, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.realtime_audio import ( RealtimeTargetResult, ServerVadConfig, diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index eaa6101c75..03602c31a4 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -7,12 +7,8 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.models import ( - ComponentIdentifier, - Message, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -145,8 +141,8 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me model=str(body_parameters["model"]), voice=str(body_parameters["voice"]), input=str(body_parameters["input"]), - response_format=body_parameters.get("response_format"), - speed=body_parameters.get("speed"), + response_format=body_parameters.get("response_format"), # type: ignore[ty:invalid-argument-type] + speed=body_parameters.get("speed"), # type: ignore[ty:invalid-argument-type] ), request=message, ) diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index db8deadf35..0d83d7ed74 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -11,14 +11,8 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.models import ( - ComponentIdentifier, - DataTypeSerializer, - Message, - MessagePiece, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.memory import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index fa87ed9e4a..99a625cae6 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -9,13 +9,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any -from pyrit.models import ( - ComponentIdentifier, - Message, - MessagePiece, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.memory import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.literals import PromptDataType from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import TargetCapabilities diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 41ffaa99a5..98b380e35b 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -19,7 +19,8 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.models import ComponentIdentifier, DataTypeSerializer, Message, MessagePiece, construct_response_from_request +from pyrit.memory import DataTypeSerializer +from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index ff8747e298..99d24ef06a 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -19,14 +19,8 @@ from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider from pyrit.common import default_values -from pyrit.models import ( - ComponentIdentifier, - DataTypeSerializer, - Message, - MessagePiece, - Score, - data_serializer_factory, -) +from pyrit.memory import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( FloatScaleScorerByCategory, ) @@ -150,7 +144,7 @@ def __init__( if callable(self._api_key): # Token provider - create an AsyncTokenCredential wrapper credential = AsyncTokenProviderCredential(self._api_key) # type: ignore[ty:invalid-argument-type] - self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) + self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) # type: ignore[ty:invalid-argument-type] else: # String API key if not isinstance(self._api_key, str): diff --git a/tests/unit/common/test_convert_local_image_to_data_url.py b/tests/unit/common/test_convert_local_image_to_data_url.py index bebbd6c67e..502cc1df95 100644 --- a/tests/unit/common/test_convert_local_image_to_data_url.py +++ b/tests/unit/common/test_convert_local_image_to_data_url.py @@ -50,7 +50,7 @@ async def test_convert_local_image_to_data_url_missing_file(): @patch("os.path.exists", return_value=True) @patch("mimetypes.guess_type", return_value=("image/jpg", None)) -@patch("pyrit.models.data_type_serializer.ImagePathDataTypeSerializer") +@patch("pyrit.memory.storage.serializers.ImagePathDataTypeSerializer") @patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=SQLiteMemory(db_path=":memory:")) async def test_convert_image_to_data_url_success( mock_get_memory_instance, mock_serializer_class, mock_guess_type, mock_exists diff --git a/tests/unit/common/test_display_response.py b/tests/unit/common/test_display_response.py index ff2ac368af..faac0e90ce 100644 --- a/tests/unit/common/test_display_response.py +++ b/tests/unit/common/test_display_response.py @@ -98,7 +98,7 @@ async def test_display_image_logs_error_when_storage_io_is_none(mock_ipython, ca @patch("pyrit.common.display_response.display", create=True) async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mock_disk_io_cls, mock_ipython): """Test that when AzureBlobStorageIO read fails, it falls back to DiskStorageIO.""" - from pyrit.models import AzureBlobStorageIO + from pyrit.memory import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) @@ -126,7 +126,7 @@ async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mo @patch("pyrit.common.display_response.DiskStorageIO") async def test_display_image_azure_and_disk_both_fail(mock_disk_io_cls, mock_ipython, caplog): """Test that when both AzureBlobStorageIO and DiskStorageIO fail, error is logged and returns.""" - from pyrit.models import AzureBlobStorageIO + from pyrit.memory import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) diff --git a/tests/unit/memory/storage/__init__.py b/tests/unit/memory/storage/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/memory/storage/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/memory/storage/test_deprecation_shims.py b/tests/unit/memory/storage/test_deprecation_shims.py new file mode 100644 index 0000000000..3d5a748508 --- /dev/null +++ b/tests/unit/memory/storage/test_deprecation_shims.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the Phase 9 deprecation shims. + +``pyrit.models.storage_io`` and ``pyrit.models.data_type_serializer`` moved to +``pyrit.memory.storage.storage`` / ``pyrit.memory.storage.serializers``. The old module paths, the +``pyrit.models`` package-root re-exports, and the +``MessagePiece.set_sha256_values_async`` / ``Seed.set_sha256_value_async`` +method shims all still work but emit a ``DeprecationWarning`` pointing at the +new ``pyrit.memory.storage`` location. These tests pin that contract. The shims will be +removed in 0.17.0. +""" + +from __future__ import annotations + +import importlib +import subprocess +import sys +import warnings +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import pyrit.memory.storage.serializers as new_serializers +import pyrit.memory.storage.storage as new_storage +import pyrit.models as models_pkg +import pyrit.models.data_type_serializer as serializer_shim +import pyrit.models.storage_io as storage_shim +from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models.seeds.seed import Seed + +MODULE_SHIM_PAIRS = [ + (storage_shim, new_storage, "pyrit.models.storage_io", "pyrit.memory.storage.storage"), + (serializer_shim, new_serializers, "pyrit.models.data_type_serializer", "pyrit.memory.storage.serializers"), +] + + +@pytest.fixture(autouse=True) +def _reset_models_warned(): + """Reset the ``pyrit.models`` package-root warn-once cache so each test starts clean.""" + saved = set(models_pkg._warned) + models_pkg._warned.clear() + try: + yield + finally: + models_pkg._warned.clear() + models_pkg._warned.update(saved) + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_forwards_every_name(shim_mod, new_mod, old_path, new_path): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + for name in shim_mod.__all__: + assert getattr(shim_mod, name) is getattr(new_mod, name), f"{old_path}.{name} did not forward" + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_warns_once_per_name(shim_mod, new_mod, old_path, new_path): + # Reload the shim to reset its internal warn-once closure for a clean count. + shim_mod = importlib.reload(shim_mod) + for name in shim_mod.__all__: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + getattr(shim_mod, name) + getattr(shim_mod, name) + + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1, f"Expected 1 DeprecationWarning for {old_path}.{name}, got {len(dep)}" + message = str(dep[0].message) + assert f"{old_path}.{name}" in message + assert f"{new_path}.{name}" in message + assert "0.17.0" in message + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_attribute_error_for_unknown_name(shim_mod, new_mod, old_path, new_path): + with pytest.raises(AttributeError, match=f"module {old_path!r} has no attribute"): + _ = shim_mod.definitely_not_a_real_name + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_dir_returns_sorted_all(shim_mod, new_mod, old_path, new_path): + assert dir(shim_mod) == sorted(shim_mod.__all__) + + +def test_moved_to_memory_storage_contains_expected_root_exports(): + # Guards against accidentally dropping a previously root-importable name from the + # forwarding table. These are exactly the names that used to be importable from + # ``pyrit.models`` and now live in ``pyrit.memory.storage``. URLDataTypeSerializer and + # SupportedContentType were never root-exported, so they are intentionally absent. + expected = { + "AllowedCategories", + "AudioPathDataTypeSerializer", + "BinaryPathDataTypeSerializer", + "DataTypeSerializer", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", + "TextDataTypeSerializer", + "VideoPathDataTypeSerializer", + "data_serializer_factory", + "AzureBlobStorageIO", + "DiskStorageIO", + "StorageIO", + } + assert set(models_pkg._MOVED_TO_MEMORY_STORAGE) == expected + + +@pytest.mark.parametrize("name", sorted(models_pkg._MOVED_TO_MEMORY_STORAGE)) +def test_models_package_root_forwards_and_warns_once(name): + target_module = models_pkg._MOVED_TO_MEMORY_STORAGE[name] + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + first = getattr(models_pkg, name) + second = getattr(models_pkg, name) + + assert first is second + assert first is getattr(importlib.import_module(target_module), name) + + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1, f"Expected 1 DeprecationWarning for pyrit.models.{name}, got {len(dep)}" + message = str(dep[0].message) + assert f"pyrit.models.{name}" in message + assert f"{target_module}.{name}" in message + assert "0.17.0" in message + + +def test_importing_pyrit_models_does_not_warn(): + # Use a subprocess so the import is genuinely fresh and reloading the core + # package can't contaminate other tests in this worker. Filter to warnings + # that reference the moved paths so unrelated third-party DeprecationWarnings + # emitted at import time don't make this flaky. + script = ( + "import warnings\n" + "with warnings.catch_warnings(record=True) as caught:\n" + " warnings.simplefilter('always')\n" + " import pyrit.models\n" + "offenders = [str(w.message) for w in caught\n" + " if issubclass(w.category, DeprecationWarning)\n" + " and ('pyrit.memory.storage' in str(w.message) or 'pyrit.models.storage_io' in str(w.message)\n" + " or 'pyrit.models.data_type_serializer' in str(w.message))]\n" + "assert not offenders, offenders\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True) + assert result.returncode == 0, f"Importing pyrit.models warned about moved names:\n{result.stderr}" + + +async def test_message_piece_method_shim_warns_and_delegates(): + fake_self = MagicMock(spec=MessagePiece) + delegate = AsyncMock() + with patch.object(new_serializers, "set_message_piece_sha256_async", delegate): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + await MessagePiece.set_sha256_values_async(fake_self) + + delegate.assert_awaited_once_with(fake_self) + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1 + message = str(dep[0].message) + assert "MessagePiece.set_sha256_values_async" in message + assert "pyrit.memory.storage.serializers.set_message_piece_sha256_async" in message + assert "0.17.0" in message + + +async def test_seed_method_shim_warns_and_delegates(): + fake_self = MagicMock(spec=Seed) + delegate = AsyncMock() + with patch.object(new_serializers, "set_seed_sha256_async", delegate): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + await Seed.set_sha256_value_async(fake_self) + + delegate.assert_awaited_once_with(fake_self) + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1 + message = str(dep[0].message) + assert "Seed.set_sha256_value_async" in message + assert "pyrit.memory.storage.serializers.set_seed_sha256_async" in message + assert "0.17.0" in message diff --git a/tests/unit/models/test_data_type_serializer.py b/tests/unit/memory/storage/test_serializers.py similarity index 94% rename from tests/unit/models/test_data_type_serializer.py rename to tests/unit/memory/storage/test_serializers.py index 085575c5f7..4cd7daf704 100644 --- a/tests/unit/models/test_data_type_serializer.py +++ b/tests/unit/memory/storage/test_serializers.py @@ -11,7 +11,7 @@ import pytest from PIL import Image -from pyrit.models import ( +from pyrit.memory.storage import ( AllowedCategories, BinaryPathDataTypeSerializer, DataTypeSerializer, @@ -19,7 +19,10 @@ ImagePathDataTypeSerializer, TextDataTypeSerializer, data_serializer_factory, + set_message_piece_sha256_async, + set_seed_sha256_async, ) +from pyrit.models import MessagePiece, SeedPrompt def test_allowed_categories(): @@ -252,7 +255,7 @@ async def test_read_data_local_file_with_dummy_image(sqlite_instance): with open(image_path, "rb") as f: mock_storage_io.read_file_async.return_value = f.read() - with patch("pyrit.models.data_type_serializer.DiskStorageIO", return_value=mock_storage_io): + with patch("pyrit.memory.storage.serializers.DiskStorageIO", return_value=mock_storage_io): serializer = data_serializer_factory( category="prompt-memory-entries", data_type="image_path", value=image_path ) @@ -385,7 +388,7 @@ async def test_save_b64_image_raises_when_results_storage_io_none(): async def test_save_formatted_audio_raises_when_results_storage_io_none(): - from pyrit.models import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") mock_memory = MagicMock() @@ -408,7 +411,7 @@ async def test_save_formatted_audio_writes_local_wav_via_to_thread(sqlite_instan """save_formatted_audio (local-disk path) should produce a readable WAV via _write_wav_sync.""" import wave - from pyrit.models import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") output_path = tmp_path / "out.wav" @@ -434,7 +437,7 @@ def test_write_wav_sync_produces_readable_wav(tmp_path): """_write_wav_sync should produce a WAV file readable by wave.open with the same metadata and frames.""" import wave - from pyrit.models.data_type_serializer import _write_wav_sync + from pyrit.memory.storage.serializers import _write_wav_sync out_path = tmp_path / "direct.wav" pcm = b"\x10\x00\x20\x00\x30\x00\x40\x00" @@ -459,7 +462,7 @@ async def test_save_formatted_audio_writes_azure_wav_via_storage_io(sqlite_insta import wave from pyrit.common import path as common_path - from pyrit.models import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory captured: dict[str, bytes] = {} @@ -480,7 +483,7 @@ async def _capture_write(file_path, data): with patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url): # Redirect so the temp_audio.wav write lands in tmp_path with patch.object(common_path, "DB_DATA_PATH", str(tmp_path)): - from pyrit.models import data_type_serializer as dts_module + from pyrit.memory.storage import serializers as dts_module with patch.object(dts_module, "DB_DATA_PATH", str(tmp_path)): await serializer.save_formatted_audio_async( @@ -600,7 +603,7 @@ async def test_get_data_filename_emits_deprecation_warning_and_delegates(sqlite_ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): """save_formatted_audio_async cleans up the local temp WAV after writing to Azure storage.""" - from pyrit.models import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") mock_memory = MagicMock() @@ -611,7 +614,7 @@ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): with ( patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory), patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url), - patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path), + patch("pyrit.memory.storage.serializers.DB_DATA_PATH", tmp_path), ): await serializer.save_formatted_audio_async(data=b"\x00\x01\x02\x03") @@ -622,7 +625,26 @@ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): assert serializer.value == azure_url -@pytest.mark.asyncio +async def test_set_message_piece_sha256_async_sets_text_hashes(sqlite_instance): + piece = MessagePiece(role="user", original_value="Hello") + piece.original_value = "newvalue" + piece.converted_value = "newvalue" + + await set_message_piece_sha256_async(piece) + + expected = "70e01503173b8e904d53b40b3ebb3bded5e5d3add087d3463a4b1abe92f1a8ca" + assert piece.original_value_sha256 == expected + assert piece.converted_value_sha256 == expected + + +async def test_set_seed_sha256_async_sets_text_hash(sqlite_instance): + seed = SeedPrompt(value="Hello1", data_type="text") + + await set_seed_sha256_async(seed) + + assert seed.value_sha256 == "948edbe7ede5aa7423476ae29dcd7d61e7711a071aea0d83698377effa896525" + + async def test_save_formatted_audio_async_cleans_up_temp_file_on_azure_upload_failure(tmp_path): """Regression test: temp file must be deleted even when Azure upload fails.""" serializer = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path") @@ -639,7 +661,7 @@ async def test_save_formatted_audio_async_cleans_up_temp_file_on_azure_upload_fa with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory): with patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url): - with patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path): + with patch("pyrit.memory.storage.serializers.DB_DATA_PATH", tmp_path): with pytest.raises(RuntimeError, match="Azure upload failed"): await serializer.save_formatted_audio_async(data=b"\x00\x01\x02") diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/memory/storage/test_storage.py similarity index 99% rename from tests/unit/models/test_storage_io.py rename to tests/unit/memory/storage/test_storage.py index 0adde24a75..6a29821e5f 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/memory/storage/test_storage.py @@ -6,7 +6,7 @@ import pytest -from pyrit.models.storage_io import ( +from pyrit.memory.storage.storage import ( AzureBlobStorageIO, DiskStorageIO, SupportedContentType, diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index 79495a5f99..bcb52c1cd6 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -56,22 +56,16 @@ # ratchet, tracked separately so the phase that removes the lazy workaround is # explicit. KNOWN_LAZY_VIOLATIONS: dict[str, dict[str, str]] = { - "pyrit.models.data_type_serializer": { - "pyrit.memory": "phase-9", - }, "pyrit.models.identifiers.evaluation_identifier": { "pyrit.executor.attack.core.attack_strategy": "phase-7", }, - "pyrit.models.storage_io": { - "pyrit.auth": "phase-9", - }, } # Reverse-guard violations: pyrit.common modules that still reach up into higher # layers. These are slated to relocate; the ratchet forces them to shrink. KNOWN_COMMON_VIOLATIONS: dict[str, dict[str, str]] = { "pyrit.common.data_url_converter": { - "pyrit.models": "relocate", + "pyrit.memory": "relocate", }, "pyrit.common.display_response": { "pyrit.memory": "relocate", diff --git a/tests/unit/output/test_blur_images.py b/tests/unit/output/test_blur_images.py index 538e3fb8a0..1e60046c8e 100644 --- a/tests/unit/output/test_blur_images.py +++ b/tests/unit/output/test_blur_images.py @@ -53,7 +53,7 @@ async def test_pretty_blurs_image_bytes_before_display(tmp_path, patch_central_d with ( patch("pyrit.common.notebook_utils.is_in_ipython_session", return_value=True), patch( - "pyrit.models.data_type_serializer.ImagePathDataTypeSerializer", + "pyrit.memory.storage.serializers.ImagePathDataTypeSerializer", return_value=fake_serializer, ), patch( @@ -93,7 +93,7 @@ async def test_pretty_does_not_blur_by_default(tmp_path, patch_central_database) with ( patch("pyrit.common.notebook_utils.is_in_ipython_session", return_value=True), patch( - "pyrit.models.data_type_serializer.ImagePathDataTypeSerializer", + "pyrit.memory.storage.serializers.ImagePathDataTypeSerializer", return_value=fake_serializer, ), patch( diff --git a/tests/unit/prompt_converter/test_pdf_converter.py b/tests/unit/prompt_converter/test_pdf_converter.py index b40daf88c5..9ba93066d4 100644 --- a/tests/unit/prompt_converter/test_pdf_converter.py +++ b/tests/unit/prompt_converter/test_pdf_converter.py @@ -11,7 +11,8 @@ from reportlab.lib.pagesizes import A4 from reportlab.pdfgen import canvas -from pyrit.models import DataTypeSerializer, SeedPrompt +from pyrit.memory import DataTypeSerializer +from pyrit.models import SeedPrompt from pyrit.prompt_converter import ConverterResult, PDFConverter diff --git a/uv.lock b/uv.lock index 31454866e2..e4ecfb9887 100644 --- a/uv.lock +++ b/uv.lock @@ -5150,7 +5150,7 @@ wheels = [ [[package]] name = "pyrit" -version = "0.14.0.dev0" +version = "0.15.0.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" },