# Page banner captured with the file: "Spaces: Configuration error"
import json | |
import os | |
import sys | |
import uuid | |
from unittest.mock import MagicMock, patch | |
import pytest | |
# Put the directory five levels up at the front of the import search path.
# The relative path is resolved against the current working directory.
_extra_import_root = os.path.abspath("../../../../..")
sys.path.insert(0, _extra_import_root)
""" | |
Unit tests for OllamaModelInfo.get_models functionality. | |
""" | |
# Ensure a dummy httpx module is available for import in tests | |
import sys | |
import types | |
# Provide a dummy httpx module for import in get_models when the real
# package has not been imported yet.
if "httpx" not in sys.modules:

    class _StubHTTPStatusError(Exception):
        """Stand-in for ``httpx.HTTPStatusError``.

        Accepts the ``request`` and ``response`` keyword arguments used when
        the error is raised in this file; a bare ``Exception`` rejects
        keyword arguments with ``TypeError``.
        """

        def __init__(self, message, *, request=None, response=None):
            super().__init__(message)
            self.request = request
            self.response = response

    def _stub_get(*args, **kwargs):
        """Placeholder for ``httpx.get``; tests are expected to replace it."""
        raise RuntimeError("httpx stub: get() must be monkeypatched by the test")

    httpx_mod = types.ModuleType("httpx")
    httpx_mod.HTTPStatusError = _StubHTTPStatusError
    # monkeypatch.setattr(httpx, "get", ...) raises AttributeError when the
    # target attribute does not already exist, so the stub must define it.
    httpx_mod.get = _stub_get
    sys.modules["httpx"] = httpx_mod
import httpx | |
from litellm.llms.ollama.common_utils import OllamaModelInfo | |
class DummyResponse:
    """
    Lightweight stand-in for an httpx response.

    Exposes the three members this file relies on: ``status_code``,
    ``json()`` and ``raise_for_status()``.
    """

    def __init__(self, json_data, status_code=200):
        """Store the payload to hand back from ``json()`` and the status code."""
        self.status_code = status_code
        self._json = json_data

    def json(self):
        """Return the payload given at construction, unchanged."""
        return self._json

    def raise_for_status(self):
        """Raise ``httpx.HTTPStatusError`` when the status code is 400 or above."""
        is_error = self.status_code >= 400
        if not is_error:
            return
        # Simulate an HTTP status error
        raise httpx.HTTPStatusError(
            "Error status code", request=None, response=None
        )
class TestOllamaModelInfo:
    """
    Tests for ``OllamaModelInfo.get_models`` with ``httpx.get`` monkeypatched.

    Each replacement for ``httpx.get`` accepts arbitrary extra positional and
    keyword arguments. The fallback test below shows that an exception raised
    by the ``httpx.get`` call leads to the static fallback list, so a mock
    that rejected an extra argument (for example ``headers=`` or ``timeout=``)
    with ``TypeError`` would surface as a confusing assertion failure rather
    than as a signature error.
    """

    def test_get_models_from_dict_response(self, monkeypatch):
        """
        When the /api/tags endpoint returns a dict with a 'models' list,
        get_models should extract and return sorted unique model names.
        """
        calls = []
        sample = {
            "models": [
                {"name": "zeta"},
                {"model": "alpha"},
                {"name": 123},  # non-str should be ignored
                "invalid",  # non-dict should be ignored
            ]
        }

        def mock_get(url, *args, **kwargs):
            # Record the requested URL so the endpoint can be checked below.
            calls.append(url)
            return DummyResponse(sample, status_code=200)

        monkeypatch.setattr(httpx, "get", mock_get)
        info = OllamaModelInfo()
        models = info.get_models()
        # Only 'alpha' and 'zeta' should be returned, sorted alphabetically
        assert models == ["alpha", "zeta"]
        # Ensure correct endpoint was called
        assert calls and calls[0].endswith("/api/tags")

    def test_get_models_from_list_response(self, monkeypatch):
        """
        When the /api/tags endpoint returns a list of dicts,
        get_models should extract and return sorted unique model names.
        """
        sample = [
            {"name": "m1"},
            {"model": "m2"},
            {},  # no name/model key should be ignored
        ]

        def mock_get(url, *args, **kwargs):
            return DummyResponse(sample, status_code=200)

        monkeypatch.setattr(httpx, "get", mock_get)
        info = OllamaModelInfo()
        models = info.get_models()
        assert models == ["m1", "m2"]

    def test_get_models_fallback_on_error(self, monkeypatch):
        """
        If the httpx.get call raises an exception, get_models should
        fall back to the static models_by_provider list prefixed by 'ollama/'.
        """

        def mock_get(url, *args, **kwargs):
            raise Exception("connection failure")

        monkeypatch.setattr(httpx, "get", mock_get)
        info = OllamaModelInfo()
        models = info.get_models()
        # Default static ollama_models is ['llama2'], so expect ['ollama/llama2']
        assert models == ["ollama/llama2"]