# +------------------------+
#
#           LLM Guard
#   https://llm-guard.com/
#
# +------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan

## This provides an LLM Guard Integration for content moderation on the proxy
from typing import Literal, Optional

import aiohttp
from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.secret_managers.main import get_secret_str
from litellm.utils import get_formatted_prompt

litellm.set_verbose = True
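
# Rough wiring on the proxy (a sketch, not the authoritative config — the
# callback name below is an assumption; check the LiteLLM enterprise docs for
# the exact value). The only hard requirement read by this module is the
# LLM_GUARD_API_BASE environment variable:
#
#   litellm_settings:
#     callbacks: ["llm_guard_moderations"]
#     llm_guard_mode: "all"  # or "key-specific" / "request-specific"
#
#   export LLM_GUARD_API_BASE="http://<llm-guard-host>:<port>"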


class _ENTERPRISE_LLMGuard(CustomLogger):
    """
    Content-moderation hook: sends the formatted prompt to an LLM Guard API
    endpoint and rejects the request if it fails the safety check.
    """

    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
    ):
        self.mock_redacted_text = mock_redacted_text
        self.llm_guard_mode = litellm.llm_guard_mode
        if mock_testing is True:  # for testing purposes only
            return

        self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
        if self.llm_guard_api_base is None:
            raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
        elif not self.llm_guard_api_base.endswith("/"):
            self.llm_guard_api_base += "/"

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def moderation_check(self, text: str):
        """
        [TODO] make this more performant for high-throughput scenario
        """
        try:
            async with aiohttp.ClientSession() as session:
                if self.mock_redacted_text is not None:
                    redacted_text = self.mock_redacted_text
                else:
                    # Send the prompt to the LLM Guard /analyze/prompt endpoint
                    analyze_url = f"{self.llm_guard_api_base}analyze/prompt"
                    verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                    analyze_payload = {"prompt": text}
                    redacted_text = None

                    async with session.post(
                        analyze_url, json=analyze_payload
                    ) as response:
                        redacted_text = await response.json()

                verbose_proxy_logger.info(
                    f"LLM Guard: Received response - {redacted_text}"
                )
                if redacted_text is not None:
                    # Block the request if LLM Guard marked the prompt invalid
                    if (
                        redacted_text.get("is_valid", None) is not None
                        and redacted_text["is_valid"] is False
                    ):
                        raise HTTPException(
                            status_code=400,
                            detail={"error": "Violated content safety policy"},
                        )
                else:
                    raise HTTPException(
                        status_code=500,
                        detail={
                            "error": f"Invalid content moderation response: {redacted_text}"
                        },
                    )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.llm_guard::moderation_check - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e
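
    # Assumed LLM Guard API contract for moderation_check above (a sketch — this
    # hook only inspects "is_valid"; any other response fields are assumptions
    # based on typical LLM Guard deployments):
    #
    #   POST {LLM_GUARD_API_BASE}analyze/prompt
    #   request body:  {"prompt": "<formatted prompt>"}
    #   response body: {"is_valid": true, ...}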

    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
        if self.llm_guard_mode == "key-specific":
            # check if llm guard enabled for specific keys only
            self.print_verbose(
                f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
            )
            if (
                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
                is True
            ):
                return True
        elif self.llm_guard_mode == "all":
            return True
        elif self.llm_guard_mode == "request-specific":
            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
            metadata = data.get("metadata", {})
            permissions = metadata.get("permissions", {})
            if (
                "enable_llm_guard_check" in permissions
                and permissions["enable_llm_guard_check"] is True
            ):
                return True
        return False
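
    # How the check is enabled per llm_guard_mode (shapes inferred from
    # should_proceed above):
    #
    #   "all":              every request is checked.
    #   "key-specific":     the API key's permissions contain
    #                       {"enable_llm_guard_check": true}.
    #   "request-specific": the request body carries
    #                       {"metadata": {"permissions": {"enable_llm_guard_check": true}}}.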

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        """
        - Calls the LLM Guard endpoint
        - Rejects the request if it fails the safety check
        - Uses the sanitized prompt returned
        - LLM Guard can handle things like PII masking, etc.
        """
        self.print_verbose(
            f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
        )

        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
        if _proceed is False:
            return

        self.print_verbose("Making LLM Guard check")

        try:
            assert call_type in [
                "completion",
                "embeddings",
                "image_generation",
                "moderation",
                "audio_transcription",
            ]
        except Exception:
            self.print_verbose(
                f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
            )
            return data

        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
        self.print_verbose(f"LLM Guard, formatted_prompt: {formatted_prompt}")
        return await self.moderation_check(text=formatted_prompt)

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ):
        if response is not None:
            await self.moderation_check(text=response)

        return response


# llm_guard = _ENTERPRISE_LLMGuard()
# asyncio.run(
#     llm_guard.async_moderation_hook(
#         data={"messages": [{"role": "user", "content": "Hey how's it going?"}]},
#         user_api_key_dict=UserAPIKeyAuth(),
#         call_type="completion",
#     )
# )
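

# A minimal offline sketch (assumption: no LLM Guard server is running, so it
# relies on mock_testing + mock_redacted_text; the response dicts below only
# carry the "is_valid" field that moderation_check inspects).
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # A response marked valid passes the check without raising.
        allowed = _ENTERPRISE_LLMGuard(
            mock_testing=True, mock_redacted_text={"is_valid": True}
        )
        await allowed.moderation_check(text="Hey, how's it going?")
        print("allowed prompt passed the safety check")  # noqa

        # A response marked invalid is rejected with a 400 HTTPException.
        blocked = _ENTERPRISE_LLMGuard(
            mock_testing=True, mock_redacted_text={"is_valid": False}
        )
        try:
            await blocked.moderation_check(text="<some flagged prompt>")
        except HTTPException as e:
            print(f"blocked prompt rejected: {e.detail}")  # noqa

    asyncio.run(_demo())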