Skip to content

Commit

Permalink
Fix XPIAOrchestrator Blob Not Found Exception (#694)
Browse files Browse the repository at this point in the history
Co-authored-by: Raja Sekhar Rao Dheekonda <[email protected]>
Co-authored-by: Roman Lutz <[email protected]>
  • Loading branch information
3 people authored Feb 12, 2025
1 parent 6d07f5b commit 3dbd738
Show file tree
Hide file tree
Showing 7 changed files with 484 additions and 738 deletions.
1,105 changes: 398 additions & 707 deletions doc/code/orchestrators/3_xpia_orchestrator.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions doc/code/orchestrators/3_xpia_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
# This is to simulate a processing target with a plugin similar to what one might expect in an XPIA-oriented AI red teaming operation.

# %%

from xpia_helpers import AzureStoragePlugin, SemanticKernelPluginAzureOpenAIPromptTarget

from pyrit.common import IN_MEMORY, initialize_pyrit
Expand Down Expand Up @@ -82,7 +81,6 @@
#
# Finally, we can put all the pieces together:
# %%

from pyrit.orchestrator import XPIATestOrchestrator
from pyrit.prompt_target import AzureBlobStorageTarget
from pyrit.score import SubStringScorer
Expand All @@ -107,7 +105,6 @@
# Clean up storage container

# %%

import os

from xpia_helpers import AzureStoragePlugin
Expand All @@ -120,3 +117,5 @@

memory = CentralMemory.get_memory_instance()
memory.dispose_engine()

# %%
60 changes: 47 additions & 13 deletions doc/code/orchestrators/xpia_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Any, Optional
from urllib.parse import urlparse

from azure.storage.blob.aio import ContainerClient as AsyncContainerClient
from openai import AsyncAzureOpenAI
Expand Down Expand Up @@ -84,19 +85,19 @@ def __init__(

self._kernel = Kernel()

service_id = "chat"
self._service_id = "chat"

self._kernel.add_service(
AzureChatCompletion(
service_id=service_id, deployment_name=self._deployment_name, async_client=self._async_client
service_id=self._service_id, deployment_name=self._deployment_name, async_client=self._async_client
),
)

self._plugin_name = plugin_name
self._kernel.import_plugin_from_object(plugin, plugin_name)
self._kernel.add_plugin(plugin, plugin_name)

self._execution_settings = AzureChatPromptExecutionSettings(
service_id=service_id,
service_id=self._service_id,
ai_model_id=self._deployment_name,
max_tokens=max_tokens,
temperature=temperature,
Expand Down Expand Up @@ -136,16 +137,36 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
template=request.converted_value,
name=self._plugin_name,
template_format="semantic-kernel",
execution_settings=self._execution_settings,
execution_settings={self._service_id: self._execution_settings},
)
processing_function = self._kernel.create_function_from_prompt(
processing_function = self._kernel.add_function(
function_name="processingFunc", plugin_name=self._plugin_name, prompt_template_config=prompt_template_config
)
processing_output = await self._kernel.invoke(processing_function)
processing_output = str(processing_output)
processing_output = await self._kernel.invoke(processing_function) # type: ignore
if processing_output is None:
raise ValueError("Processing function returned None unexpectedly.")
try:
inner_content = processing_output.get_inner_content()

if (
not hasattr(inner_content, "choices")
or not isinstance(inner_content.choices, list)
or not inner_content.choices
):
raise ValueError("Invalid response: 'choices' is missing or empty.")

first_choice = inner_content.choices[0]

if not hasattr(first_choice, "message") or not hasattr(first_choice.message, "content"):
raise ValueError("Invalid response: 'message' or 'content' is missing in choices[0].")

processing_output = first_choice.message.content

except AttributeError as e:
raise ValueError(f"Unexpected structure in processing_output: {e}")
logger.info(f'Received the following response from the prompt target "{processing_output}"')

response = construct_response_from_request(request=request, response_text_pieces=[processing_output])
response = construct_response_from_request(request=request, response_text_pieces=[str(processing_output)])
return response

def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
Expand Down Expand Up @@ -182,16 +203,17 @@ async def _create_container_client_async(self) -> None:
"""Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the
AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used
for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication."""
container_url, _ = self._parse_url()
try:
sas_token: str = default_values.get_required_value(
env_var_name=self.SAS_TOKEN_ENVIRONMENT_VARIABLE, passed_value=self._sas_token
)
logger.info("Using SAS token from environment variable or passed parameter.")
except ValueError:
logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.")
sas_token = await AzureStorageAuth.get_sas_token(self._container_url)
sas_token = await AzureStorageAuth.get_sas_token(container_url)
self._storage_client = AsyncContainerClient.from_container_url(
container_url=self._container_url,
container_url=container_url,
credential=sas_token,
)

Expand All @@ -204,8 +226,10 @@ async def download_async(self) -> str:
await self._create_container_client_async()

all_blobs = ""
# Parse the Azure Storage Blob URL to extract components
_, blob_prefix = self._parse_url()
async with self._storage_client as client:
async for blob in client.list_blobs():
async for blob in client.list_blobs(name_starts_with=blob_prefix):
logger.info(f"Downloading Azure storage blob {blob.name}")
blob_client = client.get_blob_client(blob=blob.name)
blob_data = await blob_client.download_blob()
Expand All @@ -223,11 +247,21 @@ async def delete_blobs_async(self):
await self._create_container_client_async()
logger.info("Deleting all blobs in the container.")
try:
_, blob_prefix = self._parse_url()
async with self._storage_client as client:
async for blob in client.list_blobs():
async for blob in client.list_blobs(name_starts_with=blob_prefix):
print("blob name is given as", blob.name)
await client.get_blob_client(blob=blob.name).delete_blob()
logger.info(f"Deleted blob: {blob.name}")
except Exception as ex:
logger.exception(msg=f"An error occurred while deleting blobs: {ex}")
raise

def _parse_url(self):
"""Parses the Azure Storage Blob URL to extract components."""
parsed_url = urlparse(self._container_url)
path_parts = parsed_url.path.split("/")
container_name = path_parts[1]
blob_prefix = "/".join(path_parts[2:])
container_url = f"https://{parsed_url.netloc}/{container_name}"
return container_url, blob_prefix
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ dev = [
"pytest-asyncio>=0.23.5",
"pytest-cov>=4.0.0",
"respx>=0.22.0",
"semantic-kernel==0.9.4b1",
"semantic-kernel>=1.20.0",
"types-PyYAML>=6.0.12.9",
]
torch = [
Expand Down Expand Up @@ -131,7 +131,7 @@ all = [
"pytest-asyncio>=0.23.5",
"pytest-cov>=4.0.0",
"respx>=0.20.2",
"semantic-kernel==0.9.4b1",
"semantic-kernel>=1.20.0",
"sentencepiece==0.2.0",
"torch>=2.3.0",
"playwright==1.49.0",
Expand Down
5 changes: 5 additions & 0 deletions pyrit/models/data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def _get_storage_io(self):
ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set.
"""
if self._is_azure_storage_url(self.value):
# Scenarios where a user utilizes an in-memory DuckDB but also needs to interact
# with an Azure Storage Account, ex., XPIAOrchestrator.
from pyrit.common import AZURE_SQL, initialize_pyrit

initialize_pyrit(memory_db_type=AZURE_SQL)
return self._memory.results_storage_io
return DiskStorageIO()

Expand Down
28 changes: 18 additions & 10 deletions pyrit/prompt_target/azure_blob_storage_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from enum import Enum
from typing import Optional
from urllib.parse import urlparse

from azure.core.exceptions import ClientAuthenticationError
from azure.storage.blob import ContentSettings
Expand Down Expand Up @@ -69,17 +70,17 @@ async def _create_container_client_async(self) -> None:
"""Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the
AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used
for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication."""
container_url, _ = self._parse_url()
try:
sas_token: str = default_values.get_required_value(
env_var_name=self.SAS_TOKEN_ENVIRONMENT_VARIABLE, passed_value=self._sas_token
)
logger.info("Using SAS token from environment variable or passed parameter.")
except ValueError:
logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.")
sas_token = await AzureStorageAuth.get_sas_token(self._container_url)

sas_token = await AzureStorageAuth.get_sas_token(container_url)
self._client_async = AsyncContainerClient.from_container_url(
container_url=self._container_url,
container_url=container_url,
credential=sas_token,
)

Expand All @@ -98,14 +99,12 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st

if not self._client_async:
await self._create_container_client_async()

# Parse the Azure Storage Blob URL to extract components
_, blob_prefix = self._parse_url()
blob_path = f"{blob_prefix}/{file_name}"
try:
await self._client_async.upload_blob(
name=file_name,
data=data,
content_settings=content_settings,
overwrite=True,
)
blob_client = self._client_async.get_blob_client(blob=blob_path)
await blob_client.upload_blob(data=data, content_settings=content_settings)
except Exception as exc:
if isinstance(exc, ClientAuthenticationError):
logger.exception(
Expand All @@ -119,6 +118,15 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st
logger.exception(msg=f"An unexpected error occurred: {exc}")
raise

def _parse_url(self):
"""Parses the Azure Storage Blob URL to extract components."""
parsed_url = urlparse(self._container_url)
path_parts = parsed_url.path.split("/")
container_name = path_parts[1]
blob_prefix = "/".join(path_parts[2:])
container_url = f"https://{parsed_url.netloc}/{container_name}"
return container_url, blob_prefix

@limit_requests_per_minute
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
"""
Expand Down
15 changes: 12 additions & 3 deletions tests/unit/target/test_prompt_target_azure_blob_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# Licensed under the MIT license.

import os
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from azure.storage.blob.aio import BlobClient as AsyncBlobClient
from azure.storage.blob.aio import ContainerClient as AsyncContainerClient
from unit.mocks import get_sample_conversations

Expand Down Expand Up @@ -94,15 +95,22 @@ async def test_azure_blob_storage_validate_prev_convs(


@pytest.mark.asyncio
@patch.object(AsyncContainerClient, "upload_blob", new_callable=AsyncMock)
@patch.object(AzureBlobStorageTarget, "_create_container_client_async", new_callable=AsyncMock)
@patch.object(AsyncBlobClient, "upload_blob", new_callable=AsyncMock)
@patch.object(AsyncContainerClient, "get_blob_client", new_callable=MagicMock)
async def test_send_prompt_async(
mock_create_client,
mock_get_blob_client,
mock_upload_blob,
mock_create_client,
azure_blob_storage_target: AzureBlobStorageTarget,
sample_entries: list[PromptRequestPiece],
):
mock_blob_client = AsyncMock()
mock_get_blob_client.return_value = mock_blob_client

mock_blob_client.upload_blob = mock_upload_blob
mock_upload_blob.return_value = None

azure_blob_storage_target._client_async = AsyncContainerClient.from_container_url(
container_url=azure_blob_storage_target._container_url, credential="mocked_sas_token"
)
Expand All @@ -112,6 +120,7 @@ async def test_send_prompt_async(
request = PromptRequestResponse([request_piece])

response = await azure_blob_storage_target.send_prompt_async(prompt_request=request)

assert response
blob_url = response.request_pieces[0].converted_value
assert azure_blob_storage_target._container_url in blob_url
Expand Down

0 comments on commit 3dbd738

Please sign in to comment.