File size: 14,656 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
# litellm/proxy/guardrails/guardrail_hooks/pangea.py
import os
from typing import Any, Optional, Protocol

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)

from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import LLMResponseTypes, ModelResponse, TextCompletionResponse


class PangeaGuardrailMissingSecrets(Exception):
    """Raised when the Pangea API key cannot be resolved from arguments or the environment."""


class _Transformer(Protocol):
    def get_messages(self) -> list[dict]:
        ...

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        ...


class _TextCompletionRequest:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        return [{"role": "user", "content": self.body["prompt"]}]

    # This mutates the original dict, but we'll still return it anyways
    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == 1
        self.body["prompt"] = prompt_messages[0]["content"]
        return self.body


class _TextCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for choice in self.body["choices"]:
            messages.append({"role": "assistant", "content": choice["text"]})

        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])

        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["text"] = prompt_message["content"]

        return self.body


class _ChatCompletionRequest:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []

        for message in self.body["messages"]:
            role = message["role"]
            content = message["content"]
            if isinstance(content, str):
                messages.append({"role": role, "content": content})
            if isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        messages.append({"role": role, "content": content_part["text"]})

        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        count = 0

        for message in self.body["messages"]:
            content = message["content"]
            if isinstance(content, str):
                message["content"] = prompt_messages[count]["content"]
                count += 1
            if isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        content_part["text"] = prompt_messages[count]["content"]
                        count += 1

        assert len(prompt_messages) == count
        return self.body


class _ChatCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []

        for choice in self.body["choices"]:
            messages.append(
                {
                    "role": choice["message"]["role"],
                    "content": choice["message"]["content"],
                }
            )

        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])

        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["message"]["content"] = prompt_message["content"]

        return self.body


def _get_transformer_for_request(body, call_type) -> Optional[_Transformer]:
    """Pick a request transformer for the given proxy call type.

    Returns None when the call type is not one this guardrail knows how
    to scan (the caller then skips the guardrail).
    """
    if call_type in ("text_completion", "atext_completion"):
        return _TextCompletionRequest(body)
    if call_type in ("completion", "acompletion"):
        return _ChatCompletionRequest(body)
    return None


def _get_transformer_for_response(body) -> Optional[_Transformer]:
    """Pick a response transformer based on the response object's type.

    Returns None for response types this guardrail does not know how to
    scan (the caller then skips the guardrail).
    """
    if isinstance(body, TextCompletionResponse):
        return _TextCompletionResponse(body)
    if isinstance(body, ModelResponse):
        return _ChatCompletionResponse(body)
    return None


class PangeaHandler(CustomGuardrail):
    """
    Pangea AI Guardrail handler to interact with the Pangea AI Guard service.

    This class implements the necessary hooks to call the Pangea AI Guard API
    for input and output scanning based on the configured recipes. A blocked
    verdict from Pangea is surfaced as an HTTP 400; redacted message content
    returned by Pangea is written back into the request/response body.
    """

    def __init__(
        self,
        guardrail_name: str,
        pangea_input_recipe: Optional[str] = None,
        pangea_output_recipe: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the PangeaHandler.

        Args:
            guardrail_name (str): The name of the guardrail instance.
            pangea_input_recipe (Optional[str]): Pangea recipe key used when scanning request input.
            pangea_output_recipe (Optional[str]): Pangea recipe key used when scanning LLM output.
            api_key (Optional[str]): The Pangea API key. Reads from PANGEA_API_KEY env var if None.
            api_base (Optional[str]): The Pangea API base URL. Reads from PANGEA_API_BASE env var or uses default if None.
            **kwargs: Additional arguments passed to the CustomGuardrail base class.

        Raises:
            PangeaGuardrailMissingSecrets: If no API key is provided and PANGEA_API_KEY is unset.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("PANGEA_API_KEY")
        if not self.api_key:
            raise PangeaGuardrailMissingSecrets(
                "Pangea API Key not found. Set PANGEA_API_KEY environment variable or pass it in litellm_params."
            )

        # Default Pangea base URL if not provided
        self.api_base = (
            api_base
            or os.environ.get("PANGEA_API_BASE")
            or "https://ai-guard.aws.us.pangea.cloud"
        )
        self.pangea_input_recipe = pangea_input_recipe
        self.pangea_output_recipe = pangea_output_recipe
        self.guardrail_endpoint = f"{self.api_base}/v1/text/guard"

        # Pass relevant kwargs to the parent class
        super().__init__(guardrail_name=guardrail_name, **kwargs)
        verbose_proxy_logger.info(
            f"Initialized Pangea Guardrail: name={guardrail_name}, recipe={pangea_input_recipe}, api_base={self.api_base}"
        )

    async def _call_pangea_guard(self, payload: dict, hook_name: str) -> dict:
        """
        Makes the API call to the Pangea AI Guard endpoint.
        The function itself will raise an error in the case that a response
        should be blocked, but will return the full parsed response body so
        the caller can act on any redacted messages it contains.

        Args:
            payload (dict): The request payload.
            hook_name (str): Name of the hook calling this function (for logging).

        Raises:
            HTTPException: 400 if the Pangea API returns a 'blocked: true' verdict;
                500 for any other API call failure.

        Returns:
            dict: The parsed Pangea response body.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        try:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Calling endpoint {self.guardrail_endpoint} with payload: {payload}"
            )
            response = await self.async_handler.post(
                url=self.guardrail_endpoint, json=payload, headers=headers
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)

            result = response.json()
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Received response: {result}"
            )

            # Check if the request was blocked
            if result.get("result", {}).get("blocked") is True:
                verbose_proxy_logger.warning(
                    f"Pangea Guardrail ({hook_name}): Request blocked. Response: {result}"
                )
                raise HTTPException(
                    status_code=400,  # Bad Request, indicating violation
                    detail={
                        "error": "Violated Pangea guardrail policy",
                        "guardrail_name": self.guardrail_name,
                        "pangea_response": result.get("result"),
                    },
                )
            else:
                verbose_proxy_logger.info(
                    f"Pangea Guardrail ({hook_name}): Request passed. Response: {result.get('result', {}).get('detectors')}"
                )

            return result

        except HTTPException as e:
            # Re-raise HTTPException if it's the one we raised for blocking
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                f"Pangea Guardrail ({hook_name}): Error calling API: {e}. Response text: {getattr(e, 'response', None) and getattr(e.response, 'text', None)}"  # type: ignore
            )
            # Decide if you want to block by default on error, or allow through
            # Raising an exception here will block the request.
            # To allow through on error, you might just log and return.
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Error communicating with Pangea Guardrail",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        Guardrail hook run before the LLM call (scans input).

        Args:
            user_api_key_dict (UserAPIKeyAuth): User API key details.
            cache (DualCache): Proxy cache (unused here, required by the hook signature).
            data (dict): The original request data; mutated in place with redacted content.
            call_type (str): The proxy call type, used to pick a request transformer.

        Returns:
            The updated request data, or None when the call type is unsupported
            or there is nothing to scan.
        """
        event_type = GuardrailEventHooks.pre_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_pre_call_hook): Guardrail is disabled {self.guardrail_name}."
            )
            return data

        transformer = _get_transformer_for_request(data, call_type)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                f" because we cannot determine type of request: call_type '{call_type}'"
            )
            return

        messages = transformer.get_messages()
        if not messages:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                " because messages is empty."
            )
            return

        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_input_recipe:
            ai_guard_payload["recipe"] = self.pangea_input_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "async_pre_call_hook"
        )
        # Add guardrail name to header if passed
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )
        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])

        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original request body",
                    "guardrail_name": self.guardrail_name,
                    "exceptions": str(e),
                },
            ) from e

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        # This union isn't actually correct -- it can get other response types depending on the API called
        response: LLMResponseTypes,
    ):
        """
        Guardrail hook run after a successful LLM call (scans output).

        Args:
            data (dict): The original request data.
            user_api_key_dict (UserAPIKeyAuth): User API key details.
            response (LLMResponseTypes): The response object from the LLM call;
                mutated in place with redacted content.

        Returns:
            The updated response object, or None when the response type is unsupported.
        """
        event_type = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            # Fixed: log previously claimed async_pre_call_hook.
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_post_call_success_hook): Guardrail is disabled {self.guardrail_name}."
            )
            return data

        transformer = _get_transformer_for_response(response)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_post_call_success_hook): Skipping guardrail {self.guardrail_name}"
                " because we cannot determine type of request"
            )
            return

        messages = transformer.get_messages()
        # Demoted from warning: this is routine trace output, not a problem report.
        verbose_proxy_logger.debug(
            f"Pangea Guardrail (async_post_call_success_hook): messages to scan: {messages}"
        )
        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_output_recipe:
            ai_guard_payload["recipe"] = self.pangea_output_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "post_call_success_hook"
        )
        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])

        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original response body",
                    "guardrail_name": self.guardrail_name,
                    "exceptions": str(e),
                },
            ) from e