"""
Redis Semantic Cache implementation for LiteLLM

The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
This cache stores responses based on the semantic similarity of prompts rather than
exact matching, allowing for more flexible caching of LLM responses.

This implementation uses RedisVL's SemanticCache to find semantically similar prompts
and their cached responses.
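
Example (a minimal usage sketch; assumes a reachable Redis Stack instance at the
given URL and credentials for the default OpenAI embedding model, e.g. an
OPENAI_API_KEY environment variable):

    cache = RedisSemanticCache(
        redis_url="redis://localhost:6379",
        similarity_threshold=0.8,  # accept hits with cosine similarity >= 0.8
    )
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    cache.set_cache(key="unused", value={"answer": "Paris"}, messages=messages)
    hit = cache.get_cache(key="unused", messages=messages)  # semantic lookup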
"""

import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple, cast

import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse

from .base_cache import BaseCache


class RedisSemanticCache(BaseCache):
    """
    Redis-backed semantic cache for LLM responses.

    This cache uses vector similarity to find semantically similar prompts that have been
    previously sent to the LLM, allowing for cache hits even when prompts are not identical
    but carry similar meaning.
    """

    DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        redis_url: Optional[str] = None,
        similarity_threshold: Optional[float] = None,
        embedding_model: str = "text-embedding-ada-002",
        index_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the Redis Semantic Cache.

        Args:
            host: Redis host address
            port: Redis port
            password: Redis password
            redis_url: Full Redis URL (alternative to separate host/port/password)
            similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
                where 1.0 requires exact matches and 0.0 accepts any match
            embedding_model: Model to use for generating embeddings
            index_name: Name for the Redis index
            **kwargs: Additional keyword arguments; a per-entry 'ttl' (in
                seconds) is supplied to set_cache/async_set_cache rather than here

        Raises:
            ValueError: If similarity_threshold is not provided or required Redis
                connection information is missing
        """
        from redisvl.extensions.llmcache import SemanticCache
        from redisvl.utils.vectorize import CustomTextVectorizer

        if index_name is None:
            index_name = self.DEFAULT_REDIS_INDEX_NAME

        print_verbose(f"Redis semantic-cache initializing index - {index_name}")

        # Validate similarity threshold
        if similarity_threshold is None:
            raise ValueError("similarity_threshold must be provided, got None")

        # Store configuration
        self.similarity_threshold = similarity_threshold

        # Convert the similarity threshold [0,1] to a cosine distance threshold.
        # Cosine distance: 0 = most similar, 2 = least similar;
        # cosine similarity: 1 = most similar, 0 = least similar.
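        # e.g. similarity_threshold=0.8 -> distance_threshold=0.2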
        self.distance_threshold = 1 - similarity_threshold
        self.embedding_model = embedding_model

        # Set up Redis connection
        if redis_url is None:
            try:
                # Attempt to use provided parameters or fallback to environment variables
                host = host or os.environ["REDIS_HOST"]
                port = port or os.environ["REDIS_PORT"]
                password = password or os.environ["REDIS_PASSWORD"]
            except KeyError as e:
                # Raise a more informative exception if any of the required keys are missing
                missing_var = e.args[0]
                raise ValueError(
                    f"Missing required Redis configuration: {missing_var}. "
                    f"Provide {missing_var} or redis_url."
                ) from e

            redis_url = f"redis://:{password}@{host}:{port}"

        print_verbose(f"Redis semantic-cache redis_url: {redis_url}")

        # Initialize the Redis vectorizer and cache
        cache_vectorizer = CustomTextVectorizer(self._get_embedding)

        self.llmcache = SemanticCache(
            name=index_name,
            redis_url=redis_url,
            vectorizer=cache_vectorizer,
            distance_threshold=self.distance_threshold,
            overwrite=False,
        )

    def _get_ttl(self, **kwargs) -> Optional[int]:
        """
        Get the TTL (time-to-live) value for cache entries.

        Args:
            **kwargs: Keyword arguments that may contain a custom TTL

        Returns:
            Optional[int]: The TTL value in seconds, or None if no TTL should be applied
        """
        ttl = kwargs.get("ttl")
        if ttl is not None:
            ttl = int(ttl)
        return ttl

    def _get_embedding(self, prompt: str) -> List[float]:
        """
        Generate an embedding vector for the given prompt using the configured embedding model.

        Args:
            prompt: The text to generate an embedding for

        Returns:
            List[float]: The embedding vector
        """
        # Create an embedding from prompt
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )
        embedding = embedding_response["data"][0]["embedding"]
        return embedding

    def _get_cache_logic(self, cached_response: Any) -> Any:
        """
        Process the cached response to prepare it for use.

        Args:
            cached_response: The raw cached response

        Returns:
            The processed cache response, or None if input was None
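
        Example (illustrative): b'{"id": "chatcmpl-123"}' is decoded to a string
        and parsed into {"id": "chatcmpl-123"}.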
        """
        if cached_response is None:
            return cached_response

        # Convert bytes to string if needed
        if isinstance(cached_response, bytes):
            cached_response = cached_response.decode("utf-8")

        # Convert string representation to Python object
        try:
            cached_response = json.loads(cached_response)
        except json.JSONDecodeError:
            try:
                cached_response = ast.literal_eval(cached_response)
            except (ValueError, SyntaxError) as e:
                print_verbose(f"Error parsing cached response: {str(e)}")
                return None

        return cached_response

    def set_cache(self, key: str, value: Any, **kwargs) -> None:
        """
        Store a value in the semantic cache.

        Args:
            key: The cache key (not directly used in semantic caching)
            value: The response value to cache
            **kwargs: Additional arguments including 'messages' for the prompt
                and optional 'ttl' for time-to-live
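
        Example (illustrative values):
            cache.set_cache(
                key="unused",
                value=response_dict,  # hypothetical LLM response object/dict
                messages=[{"role": "user", "content": "hi"}],
                ttl=60,  # expire after 60 seconds
            )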
        """
        print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")

        value_str: Optional[str] = None
        try:
            # Extract the prompt from messages
            messages = kwargs.get("messages", [])
            if not messages:
                print_verbose("No messages provided for semantic caching")
                return

            prompt = get_str_from_messages(messages)
            value_str = str(value)

            # Get TTL and store in Redis semantic cache
            ttl = self._get_ttl(**kwargs)
            if ttl is not None:
                self.llmcache.store(prompt, value_str, ttl=int(ttl))
            else:
                self.llmcache.store(prompt, value_str)
        except Exception as e:
            print_verbose(
                f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
            )

    def get_cache(self, key: str, **kwargs) -> Any:
        """
        Retrieve a semantically similar cached response.

        Args:
            key: The cache key (not directly used in semantic caching)
            **kwargs: Additional arguments including 'messages' for the prompt

        Returns:
            The cached response if a semantically similar prompt is found, else None
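
        Example (illustrative): a lookup with
        messages=[{"role": "user", "content": "hi there"}] may return the value
        stored for a semantically similar prompt such as "hello there".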
        """
        print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")

        try:
            # Extract the prompt from messages
            messages = kwargs.get("messages", [])
            if not messages:
                print_verbose("No messages provided for semantic cache lookup")
                return None

            prompt = get_str_from_messages(messages)
            # Check the cache for semantically similar prompts
            results = self.llmcache.check(prompt=prompt)

            # Return None if no similar prompts found
            if not results:
                return None

            # Process the best matching result
            cache_hit = results[0]
            vector_distance = float(cache_hit["vector_distance"])

            # Convert vector distance back to similarity score
            # For cosine distance: 0 = most similar, 2 = least similar
            # While similarity: 1 = most similar, 0 = least similar
            similarity = 1 - vector_distance

            cached_prompt = cache_hit["prompt"]
            cached_response = cache_hit["response"]

            print_verbose(
                f"Cache hit: similarity threshold: {self.similarity_threshold}, "
                f"actual similarity: {similarity}, "
                f"current prompt: {prompt}, "
                f"cached prompt: {cached_prompt}"
            )

            return self._get_cache_logic(cached_response=cached_response)
        except Exception as e:
            print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")

    async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
        """
        Asynchronously generate an embedding for the given prompt.

        Args:
            prompt: The text to generate an embedding for
            **kwargs: Additional arguments that may contain metadata

        Returns:
            List[float]: The embedding vector
        """
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        # Route the embedding request through the proxy if appropriate
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )

        try:
            if llm_router is not None and self.embedding_model in router_model_names:
                # Use the router for embedding generation
                user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
                embedding_response = await llm_router.aembedding(
                    model=self.embedding_model,
                    input=prompt,
                    cache={"no-store": True, "no-cache": True},
                    metadata={
                        "user_api_key": user_api_key,
                        "semantic-cache-embedding": True,
                        "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                    },
                )
            else:
                # Generate embedding directly
                embedding_response = await litellm.aembedding(
                    model=self.embedding_model,
                    input=prompt,
                    cache={"no-store": True, "no-cache": True},
                )

            # Extract and return the embedding vector
            return embedding_response["data"][0]["embedding"]
        except Exception as e:
            print_verbose(f"Error generating async embedding: {str(e)}")
            raise ValueError(f"Failed to generate embedding: {str(e)}") from e

    async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
        """
        Asynchronously store a value in the semantic cache.

        Args:
            key: The cache key (not directly used in semantic caching)
            value: The response value to cache
            **kwargs: Additional arguments including 'messages' for the prompt
                and optional 'ttl' for time-to-live
        """
        print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")

        try:
            # Extract the prompt from messages
            messages = kwargs.get("messages", [])
            if not messages:
                print_verbose("No messages provided for semantic caching")
                return

            prompt = get_str_from_messages(messages)
            value_str = str(value)

            # Generate an embedding for the prompt; it is stored alongside the response
            prompt_embedding = await self._get_async_embedding(prompt, **kwargs)

            # Get TTL and store in Redis semantic cache
            ttl = self._get_ttl(**kwargs)
            if ttl is not None:
                await self.llmcache.astore(
                    prompt,
                    value_str,
                    vector=prompt_embedding,  # Pass through custom embedding
                    ttl=ttl,
                )
            else:
                await self.llmcache.astore(
                    prompt,
                    value_str,
                    vector=prompt_embedding,  # Pass through custom embedding
                )
        except Exception as e:
            print_verbose(f"Error in async_set_cache: {str(e)}")

    async def async_get_cache(self, key: str, **kwargs) -> Any:
        """
        Asynchronously retrieve a semantically similar cached response.

        Args:
            key: The cache key (not directly used in semantic caching)
            **kwargs: Additional arguments including 'messages' for the prompt

        Returns:
            The cached response if a semantically similar prompt is found, else None
        """
        print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")

        try:
            # Extract the prompt from messages
            messages = kwargs.get("messages", [])
            if not messages:
                print_verbose("No messages provided for semantic cache lookup")
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

            prompt = get_str_from_messages(messages)

            # Generate embedding for the prompt
            prompt_embedding = await self._get_async_embedding(prompt, **kwargs)

            # Check the cache for semantically similar prompts
            results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)

            # Cache miss: record zero similarity in the request metadata
            # TODO: the sync get_cache path does not record this metadata; align them?
            if not results:
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

            cache_hit = results[0]
            vector_distance = float(cache_hit["vector_distance"])

            # Convert vector distance back to similarity
            # For cosine distance: 0 = most similar, 2 = least similar
            # While similarity: 1 = most similar, 0 = least similar
            similarity = 1 - vector_distance

            cached_prompt = cache_hit["prompt"]
            cached_response = cache_hit["response"]

            # update kwargs["metadata"] with similarity, don't rewrite the original metadata
            kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

            print_verbose(
                f"Cache hit: similarity threshold: {self.similarity_threshold}, "
                f"actual similarity: {similarity}, "
                f"current prompt: {prompt}, "
                f"cached prompt: {cached_prompt}"
            )

            return self._get_cache_logic(cached_response=cached_response)
        except Exception as e:
            print_verbose(f"Error in async_get_cache: {str(e)}")
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0

    async def _index_info(self) -> Dict[str, Any]:
        """
        Get information about the Redis index.

        Returns:
            Dict[str, Any]: Information about the Redis index
        """
        aindex = await self.llmcache._get_async_index()
        return await aindex.info()

    async def async_set_cache_pipeline(
        self, cache_list: List[Tuple[str, Any]], **kwargs
    ) -> None:
        """
        Asynchronously store multiple values in the semantic cache.

        Args:
            cache_list: List of (key, value) tuples to cache
            **kwargs: Additional arguments
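
        Example (illustrative): with cache_list=[("k1", resp_1), ("k2", resp_2)],
        both entries are written concurrently via async_set_cache, using any
        'messages'/'ttl' present in kwargs.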
        """
        try:
            tasks = []
            for val in cache_list:
                tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
            await asyncio.gather(*tasks)
        except Exception as e:
            print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")