File size: 3,357 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import os
import sys
import uuid
from unittest.mock import MagicMock, patch

import pytest

# Prepend a directory to the import path, presumably so that ``litellm`` is
# imported from the source checkout rather than an installed copy.
# NOTE(review): os.path.abspath resolves "../../../../.." against the current
# working directory, not this file's location, so this only reaches the
# intended directory (presumably the repository root) when pytest is launched
# from the matching subdirectory — confirm.
sys.path.insert(
    0, os.path.abspath("../../../../..")
)  # Prepends the directory five levels above the CWD to sys.path

"""
Unit tests for OllamaModelInfo.get_models functionality.
"""
# Ensure a dummy httpx module is available for import in tests
import sys
import types

# Provide a dummy httpx module for import in get_models
try:
    # Prefer the real library whenever it is installed. Checking only
    # ``sys.modules`` would shadow an installed-but-not-yet-imported httpx
    # with the stub below, and every later ``import httpx`` (including the
    # ones made by litellm) would then receive the stub instead.
    import httpx  # noqa: F401
except ImportError:
    # Minimal stand-in exposing only what these tests touch.
    httpx_mod = types.ModuleType("httpx")

    class _StubHTTPStatusError(Exception):
        """Stand-in for ``httpx.HTTPStatusError``.

        Accepts the keyword-only ``request``/``response`` arguments used by
        ``DummyResponse.raise_for_status``; a bare ``Exception`` rejects
        keyword arguments with ``TypeError``.
        """

        def __init__(self, message, *, request=None, response=None):
            super().__init__(message)
            self.request = request
            self.response = response

    def _stub_get(url, *args, **kwargs):
        """Placeholder so ``monkeypatch.setattr(httpx, "get", ...)`` works.

        pytest's ``monkeypatch.setattr`` raises ``AttributeError`` when the
        target attribute does not already exist, so the stub must define it.
        """
        raise RuntimeError(
            f"stub httpx.get({url!r}) called without being monkeypatched"
        )

    httpx_mod.HTTPStatusError = _StubHTTPStatusError
    httpx_mod.get = _stub_get
    sys.modules["httpx"] = httpx_mod

import httpx

from litellm.llms.ollama.common_utils import OllamaModelInfo


class DummyResponse:
    """
    Minimal stand-in for an ``httpx`` response, returned by the mocked
    ``httpx.get`` in these tests.

    Exposes only ``status_code``, ``raise_for_status()`` and ``json()``.
    """

    def __init__(self, json_data, status_code=200):
        # Record the status first, then the canned payload served by json().
        self.status_code = status_code
        self._json = json_data

    def raise_for_status(self):
        # Anything below 400 counts as success: nothing to report.
        if self.status_code < 400:
            return
        # Client/server error codes surface as an HTTPStatusError.
        raise httpx.HTTPStatusError(
            "Error status code", request=None, response=None
        )

    def json(self):
        # The payload is handed back as-is; no copying or parsing happens.
        return self._json


class TestOllamaModelInfo:
    """Tests for ``OllamaModelInfo.get_models`` with ``httpx.get`` mocked out.

    Each mock accepts arbitrary extra positional/keyword arguments so that the
    tests exercise the response-parsing logic rather than failing with a
    ``TypeError`` if ``get_models`` passes e.g. headers or a timeout.
    """

    def test_get_models_from_dict_response(self, monkeypatch):
        """
        When the /api/tags endpoint returns a dict with a 'models' list,
        get_models should extract and return sorted unique model names.
        """
        calls = []
        sample = {
            "models": [
                {"name": "zeta"},
                {"model": "alpha"},
                {"name": 123},  # non-str should be ignored
                "invalid",  # non-dict should be ignored
            ]
        }

        def mock_get(url, *args, **kwargs):
            # Only the URL matters for the assertions below.
            calls.append(url)
            return DummyResponse(sample, status_code=200)

        monkeypatch.setattr(httpx, "get", mock_get)
        info = OllamaModelInfo()
        models = info.get_models()
        # Only 'alpha' and 'zeta' should be returned, sorted alphabetically
        assert models == ["alpha", "zeta"]
        # Ensure correct endpoint was called
        assert calls and calls[0].endswith("/api/tags")

    def test_get_models_from_list_response(self, monkeypatch):
        """
        When the /api/tags endpoint returns a list of dicts,
        get_models should extract and return sorted unique model names.
        """
        calls = []
        sample = [
            {"name": "m1"},
            {"model": "m2"},
            {},  # no name/model key should be ignored
        ]

        def mock_get(url, *args, **kwargs):
            calls.append(url)
            return DummyResponse(sample, status_code=200)

        monkeypatch.setattr(httpx, "get", mock_get)
        info = OllamaModelInfo()
        models = info.get_models()
        assert models == ["m1", "m2"]
        # Same endpoint as the dict-shaped response case.
        assert calls and calls[0].endswith("/api/tags")

    def test_get_models_fallback_on_error(self, monkeypatch):
        """
        If the httpx.get call raises an exception, get_models should
        fall back to the static models_by_provider list prefixed by 'ollama/'.
        """

        def mock_get(url, *args, **kwargs):
            raise Exception("connection failure")

        monkeypatch.setattr(httpx, "get", mock_get)
        info = OllamaModelInfo()
        models = info.get_models()
        # NOTE(review): relies on litellm's static ollama model list being
        # exactly ['llama2']; update if that list changes.
        assert models == ["ollama/llama2"]