import traceback
from typing import Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """Proxy hook that screens request and response text with Azure Content Safety.

    Text is sent to the Azure service, the per-category severities it reports
    are compared against configurable thresholds, and an ``HTTPException``
    (status 400) is raised as soon as one category reaches its threshold.
    """

    # Class variables or attributes
    def __init__(self, endpoint, api_key, thresholds=None):
        """Create the Azure client and resolve the per-category thresholds.

        Args:
            endpoint: Endpoint passed to ``ContentSafetyClient``.
            api_key: Key wrapped in an ``AzureKeyCredential`` for the client.
            thresholds: Optional mapping of text category -> severity at or
                above which content is filtered. Categories that are missing
                fall back to the default of 4.

        Raises:
            Exception: If the Azure Content Safety SDK cannot be imported.
        """
        super().__init__()
        # The Azure SDK is an optional dependency, so it is imported lazily and
        # the pieces needed later are kept on the instance.
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            ) from e

        self.endpoint = endpoint
        self.api_key = api_key
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        """Return the effective thresholds: defaults overlaid with ``thresholds``.

        A new dict is always returned; the mapping passed by the caller is
        never modified.
        """
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        # Caller-supplied values win; defaults only fill in the gaps.
        return {**default_thresholds, **thresholds}

    def _compute_result(self, response):
        """Map each category present in ``response`` to its verdict.

        Returns:
            A dict ``{category: {"filtered": bool, "severity": severity}}``
            holding one entry per known text category for which the response
            carries a non-``None`` severity.
        """
        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }

        result = {}
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }

        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        """Analyze ``content`` and raise if any category reaches its threshold.

        Args:
            content: Text to analyze. Empty text is skipped.
            source: Label copied into the error detail (the hooks in this
                class pass ``"input"`` or ``"output"``).

        Raises:
            HTTPException: Status 400 for the first category whose severity
                reaches its threshold.
            HttpResponseError: Re-raised unchanged when the Azure call fails.
        """
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Nothing to analyze: skip the service round-trip entirely.
        if not content:
            return

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Screen every plain-string message of a completion request.

        Only ``call_type == "completion"`` requests carrying ``"messages"`` are
        inspected, and only messages whose ``"content"`` is a ``str``.

        Raises:
            HTTPException: Propagated from ``test_violation`` on a violation.
                Any other exception is logged and suppressed, so the request
                continues.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")

        except HTTPException:
            # Policy violations must reach the caller unchanged.
            raise
        except Exception as e:
            # Deliberate best-effort: other failures are logged, not raised.
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Screen the first choice of a successful model response.

        Responses that are not a ``litellm.ModelResponse``, that have no
        choices, or whose first choice is not a ``litellm.utils.Choices`` are
        ignored, as is a first choice without any text content.

        Raises:
            HTTPException: Propagated from ``test_violation`` on a violation.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if not isinstance(response, litellm.ModelResponse):
            return
        # Guard against an empty choices list before indexing into it.
        if not response.choices:
            return

        first_choice = response.choices[0]
        if not isinstance(first_choice, litellm.utils.Choices):
            return

        content = first_choice.message.content or ""
        if content:
            await self.test_violation(content=content, source="output")

    # async def async_post_call_streaming_hook(
    #    self,
    #    user_api_key_dict: UserAPIKeyAuth,
    #    response: str,
    # ):
    #    verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #    await self.test_violation(content=response, source="output")