"""Add support for llama-cpp-python models to LiteLLM.""" | |
import asyncio | |
import logging | |
import warnings | |
from collections.abc import AsyncIterator, Callable, Iterator | |
from functools import cache | |
from typing import Any, ClassVar, cast | |
import httpx | |
import litellm | |
from litellm import ( # type: ignore[attr-defined] | |
CustomLLM, | |
GenericStreamingChunk, | |
ModelResponse, | |
convert_to_model_response_object, | |
) | |
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler | |
from llama_cpp import ( # type: ignore[attr-defined] | |
ChatCompletionRequestMessage, | |
CreateChatCompletionResponse, | |
CreateChatCompletionStreamResponse, | |
Llama, | |
LlamaRAMCache, | |
) | |
# Reduce the logging level for LiteLLM and flashrank. | |
logging.getLogger("litellm").setLevel(logging.WARNING) | |
logging.getLogger("flashrank").setLevel(logging.WARNING) | |


class LlamaCppPythonLLM(CustomLLM):
    """A llama-cpp-python provider for LiteLLM.

    This provider enables using llama-cpp-python models with LiteLLM. The LiteLLM model
    specification is "llama-cpp-python/<hugging_face_repo_id>/<filename>@<n_ctx>", where n_ctx is
    an optional parameter that specifies the context size of the model. If n_ctx is not provided or
    if it's set to 0, the model's default context size is used.

    Example usage:

    ```python
    from litellm import completion

    response = completion(
        model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
        messages=[{"role": "user", "content": "Hello world!"}],
        # stream=True
    )
    ```
    """

    # Create a lock to prevent concurrent access to llama-cpp-python models.
    streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included:
    # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers.
    # [1] https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion
    # [2] https://docs.litellm.ai/docs/completion/input
    supported_openai_params: ClassVar[list[str]] = [
        "functions",  # Deprecated
        "function_call",  # Deprecated
        "tools",
        "tool_choice",
        "temperature",
        "top_p",
        "top_k",
        "min_p",
        "typical_p",
        "stop",
        "seed",
        "response_format",
        "max_tokens",
        "presence_penalty",
        "frequency_penalty",
        "repeat_penalty",
        "tfs_z",
        "mirostat_mode",
        "mirostat_tau",
        "mirostat_eta",
        "logit_bias",
    ]

    @staticmethod
    @cache
    def llm(model: str, **kwargs: Any) -> Llama:
        # Drop the llama-cpp-python prefix from the model.
        repo_id_filename = model.replace("llama-cpp-python/", "")
        # Convert the LiteLLM model string to repo_id, filename, and n_ctx.
        repo_id, filename = repo_id_filename.rsplit("/", maxsplit=1)
        n_ctx = 0
        if len(filename_n_ctx := filename.rsplit("@", maxsplit=1)) == 2:  # noqa: PLR2004
            filename, n_ctx_str = filename_n_ctx
            n_ctx = int(n_ctx_str)
        # Load the LLM.
        with warnings.catch_warnings():  # Filter huggingface_hub warning about HF_TOKEN.
            warnings.filterwarnings("ignore", category=UserWarning)
            llm = Llama.from_pretrained(
                repo_id=repo_id,
                filename=filename,
                n_ctx=n_ctx,
                n_gpu_layers=-1,
                verbose=False,
                **kwargs,
            )
        # Enable caching.
        llm.set_cache(LlamaRAMCache())
        # Register the model info with LiteLLM.
        litellm.register_model(  # type: ignore[attr-defined]
            {
                model: {
                    "max_tokens": llm.n_ctx(),
                    "max_input_tokens": llm.n_ctx(),
                    "max_output_tokens": None,
                    "input_cost_per_token": 0.0,
                    "output_cost_per_token": 0.0,
                    "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
                    "litellm_provider": "llama-cpp-python",
                    "mode": "embedding" if kwargs.get("embedding") else "completion",
                    "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
                    "supports_function_calling": True,
                    "supports_parallel_function_calling": True,
                    "supports_vision": False,
                }
            }
        )
        return llm
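
    # Note on the loader above: because `llm` is wrapped in `functools.cache`, repeated calls with
    # the same LiteLLM model string (e.g. the one from the class docstring) return the same
    # in-memory Llama instance instead of reloading the weights. A minimal sketch:
    #
    #   model = "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092"
    #   assert LlamaCppPythonLLM.llm(model) is LlamaCppPythonLLM.llm(model)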

    def completion(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> ModelResponse:
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        response = cast(
            CreateChatCompletionResponse,
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
        )
        litellm_model_response: ModelResponse = convert_to_model_response_object(
            response_object=response,
            model_response_object=model_response,
            response_type="completion",
            stream=False,
        )
        return litellm_model_response

    def streaming(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> Iterator[GenericStreamingChunk]:
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        stream = cast(
            Iterator[CreateChatCompletionStreamResponse],
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
        )
        for chunk in stream:
            choices = chunk.get("choices", [])
            for choice in choices:
                text = choice.get("delta", {}).get("content", None)
                finish_reason = choice.get("finish_reason")
                litellm_generic_streaming_chunk = GenericStreamingChunk(
                    text=text,  # type: ignore[typeddict-item]
                    is_finished=bool(finish_reason),
                    finish_reason=finish_reason,  # type: ignore[typeddict-item]
                    usage=None,
                    index=choice.get("index"),  # type: ignore[typeddict-item]
                    provider_specific_fields={
                        "id": chunk.get("id"),
                        "model": chunk.get("model"),
                        "created": chunk.get("created"),
                        "object": chunk.get("object"),
                    },
                )
                yield litellm_generic_streaming_chunk

    async def astreaming(  # type: ignore[misc,override]  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,  # noqa: ASYNC109
        client: AsyncHTTPHandler | None = None,
    ) -> AsyncIterator[GenericStreamingChunk]:
        # Start a synchronous stream.
        stream = self.streaming(
            model,
            messages,
            api_base,
            custom_prompt_dict,
            model_response,
            print_verbose,
            encoding,
            api_key,
            logging_obj,
            optional_params,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            timeout,
        )
        await asyncio.sleep(0)  # Yield control to the event loop after initialising the context.
        # Wrap the synchronous stream in an asynchronous stream.
        async with LlamaCppPythonLLM.streaming_lock:
            for litellm_generic_streaming_chunk in stream:
                yield litellm_generic_streaming_chunk
                await asyncio.sleep(0)  # Yield control to the event loop after each token.
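

# A minimal async usage sketch (illustrative; mirrors the class docstring example): once the
# provider is registered below, `litellm.acompletion(..., stream=True)` should be served by
# `astreaming` above and can be consumed with `async for`, e.g.
#
#   response = await litellm.acompletion(
#       model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
#       messages=[{"role": "user", "content": "Hello world!"}],
#       stream=True,
#   )
#   async for chunk in response:
#       print(chunk.choices[0].delta.content or "", end="")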

# Register the LlamaCppPythonLLM provider.
if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map):
    litellm.custom_provider_map.append(
        {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
    )

litellm.suppress_debug_info = True
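
# A minimal synchronous streaming sketch (illustrative; complements the non-streaming call shown
# in the class docstring):
#
#   from litellm import completion
#
#   for chunk in completion(
#       model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
#       messages=[{"role": "user", "content": "Hello world!"}],
#       stream=True,
#   ):
#       print(chunk.choices[0].delta.content or "", end="")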