File size: 5,389 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


# Tests for RedisSemanticCache
def test_redis_semantic_cache_initialization(monkeypatch):
    """Constructor derives distance_threshold from similarity_threshold and
    refuses to build without a similarity_threshold."""
    # Stub the optional redisvl dependency before the cache module is imported.
    fake_semantic_cache = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=fake_semantic_cache),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock()),
    }
    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        # Supply the Redis connection settings the constructor reads from env.
        for name, value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(name, value)

        cache = RedisSemanticCache(similarity_threshold=0.8)

        assert cache.similarity_threshold == 0.8
        # distance = 1 - similarity; approx to sidestep float precision.
        assert cache.distance_threshold == pytest.approx(0.2, abs=1e-10)
        assert cache.embedding_model == "text-embedding-ada-002"

        # Omitting similarity_threshold must raise.
        with pytest.raises(ValueError, match="similarity_threshold must be provided"):
            RedisSemanticCache()


def test_redis_semantic_cache_get_cache(monkeypatch):
    """get_cache returns the parsed cached response when llmcache.check hits."""
    # Stub the optional redisvl dependency before the cache module is imported.
    fake_semantic_cache = MagicMock()
    fake_vectorizer = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=fake_semantic_cache),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=fake_vectorizer),
    }
    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        # Supply the Redis connection settings the constructor reads from env.
        for name, value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(name, value)

        cache = RedisSemanticCache(similarity_threshold=0.8)

        # Pretend the semantic index found a close-enough prior prompt.
        cache_hit = [
            {
                "prompt": "What is the capital of France?",
                "response": '{"content": "Paris is the capital of France."}',
                "vector_distance": 0.1,  # Distance of 0.1 means similarity of 0.9
            }
        ]
        cache.llmcache.check = MagicMock(return_value=cache_hit)

        # Avoid a real embedding call.
        fake_embedding = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        with patch("litellm.embedding", return_value=fake_embedding):
            result = cache.get_cache(
                key="test_key", messages=[{"content": "What is the capital of France?"}]
            )

            # The JSON-encoded response string should come back parsed.
            assert result == {"content": "Paris is the capital of France."}
            cache.llmcache.check.assert_called_once()


@pytest.mark.asyncio
async def test_redis_semantic_cache_async_get_cache(monkeypatch):
    """async_get_cache embeds the prompt, checks the index, and parses the hit."""
    # Stub the optional redisvl dependency before the cache module is imported.
    fake_semantic_cache = MagicMock()
    fake_vectorizer = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=fake_semantic_cache),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=fake_vectorizer),
    }
    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        # Supply the Redis connection settings the constructor reads from env.
        for name, value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(name, value)

        cache = RedisSemanticCache(similarity_threshold=0.8)

        # Pretend the semantic index found a close-enough prior prompt.
        cache_hit = [
            {
                "prompt": "What is the capital of France?",
                "response": '{"content": "Paris is the capital of France."}',
                "vector_distance": 0.1,  # Distance of 0.1 means similarity of 0.9
            }
        ]
        cache.llmcache.acheck = AsyncMock(return_value=cache_hit)
        # Avoid a real async embedding call.
        cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3])

        result = await cache.async_get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of France?"}],
            metadata={},
        )

        # The JSON-encoded response string should come back parsed.
        assert result == {"content": "Paris is the capital of France."}

        # Both the embedding step and the index lookup should have run once.
        cache._get_async_embedding.assert_called_once()
        cache.llmcache.acheck.assert_called_once()