# f1-ai / llm_manager.py
import asyncio
import os
from typing import List
from huggingface_hub import InferenceClient
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from dotenv import load_dotenv
import numpy as np
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
class LLMManager:
"""
Manager class for handling different LLM and embedding models.
Uses HuggingFace's InferenceClient directly for HuggingFace models.
"""
def __init__(self, provider: str = "huggingface"):
"""
Initialize the LLM Manager.
Args:
provider (str): The provider for LLM and embeddings.
Options: "ollama", "huggingface", "huggingface-openai"
"""
        self.provider = provider
        self.llm = None
        self.embeddings = None
        self.llm_client = None
        self.embedding_client = None
# Initialize models based on the provider
if provider == "ollama":
self._init_ollama()
        elif provider in ("huggingface", "huggingface-openai"):
self._init_huggingface()
else:
raise ValueError(f"Unsupported provider: {provider}. Choose 'ollama', 'huggingface', or 'huggingface-openai'")
def _init_ollama(self):
"""Initialize Ollama models."""
self.llm = OllamaLLM(model="phi4-mini:3.8b")
self.embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")
def _init_huggingface(self):
"""Initialize HuggingFace models using InferenceClient directly."""
# Get API key from environment
api_key = os.getenv("HUGGINGFACE_API_KEY")
if not api_key:
raise ValueError("HuggingFace API key not found. Set HUGGINGFACE_API_KEY in environment variables.")
        # Hosted model IDs used via the HF Inference API (hardcoded for this deployment)
        llm_endpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        embedding_endpoint = "sentence-transformers/all-MiniLM-L6-v2"
# Initialize InferenceClient for LLM
self.llm_client = InferenceClient(
model=llm_endpoint,
token=api_key
)
# Initialize InferenceClient for embeddings
self.embedding_client = InferenceClient(
model=embedding_endpoint,
token=api_key
)
# Store generation parameters
self.generation_kwargs = {
"temperature": 0.7,
"max_new_tokens": 512, # Reduced to avoid potential token limit issues
"repetition_penalty": 1.1,
"do_sample": True,
"top_k": 50,
"top_p": 0.9,
"return_full_text": False # Only return the generated text, not the prompt
}
# LLM methods for compatibility with LangChain
def get_llm(self):
"""
Return a callable object that mimics LangChain LLM interface.
For huggingface providers, this returns a function that calls the InferenceClient.
"""
if self.provider == "ollama":
return self.llm
else:
# Return a function that wraps the InferenceClient for LLM
def llm_function(prompt, **kwargs):
params = {**self.generation_kwargs, **kwargs}
try:
logger.info(f"Sending prompt to HuggingFace (length: {len(prompt)})")
                    response = self.llm_client.text_generation(
                        prompt,
                        details=True,  # Ask for the detailed response object
                        **params
                    )
                    # With details=True the client returns a TextGenerationOutput
                    # (not a dict), so unwrap the generated text via attribute
                    # access, falling back to dict access for older client versions
                    if hasattr(response, "generated_text"):
                        response = response.generated_text
                    elif isinstance(response, dict) and "generated_text" in response:
                        response = response["generated_text"]
logger.info(f"Received response from HuggingFace (length: {len(response) if response else 0})")
# Ensure we get a valid string response
if not response or not isinstance(response, str) or response.strip() == "":
logger.warning("Empty or invalid response from HuggingFace, using fallback")
return "I couldn't generate a proper response based on the available information."
return response
except Exception as e:
logger.error(f"Error during LLM inference: {str(e)}")
return f"Error generating response: {str(e)}"
# Add async capability
            async def allm_function(prompt, **kwargs):
                params = {**self.generation_kwargs, **kwargs}
                try:
                    # InferenceClient.text_generation is synchronous, so it cannot
                    # be awaited directly; run it in a worker thread instead
                    # (huggingface_hub's AsyncInferenceClient is the native
                    # async alternative)
                    response = await asyncio.to_thread(
                        self.llm_client.text_generation,
                        prompt,
                        stream=False,
                        **params
                    )
# Ensure we get a valid string response
if not response or not isinstance(response, str) or response.strip() == "":
logger.warning("Empty or invalid response from HuggingFace async, using fallback")
return "I couldn't generate a proper response based on the available information."
return response
except Exception as e:
logger.error(f"Error during async LLM inference: {str(e)}")
return f"Error generating response: {str(e)}"
            # Attach LangChain-style entry points to the wrapper function
            llm_function.ainvoke = allm_function
            llm_function.invoke = llm_function
return llm_function
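
    # Usage sketch (illustrative, not part of the original module): the callable
    # returned above can stand in where a LangChain-style LLM is expected, e.g.
    #   llm = LLMManager("huggingface").get_llm()
    #   text = llm("What is pole position?")
    #   text = await llm.ainvoke("What is pole position?")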
# Embeddings methods for compatibility with LangChain
def get_embeddings(self):
"""
Return a callable object that mimics LangChain Embeddings interface.
For huggingface providers, this returns an object with embed_documents and embed_query methods.
"""
if self.provider == "ollama":
return self.embeddings
else:
# Create a wrapper object that has the expected methods
class EmbeddingsWrapper:
def __init__(self, client):
self.client = client
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed multiple documents."""
embeddings = []
# Process in batches to avoid overwhelming the API
batch_size = 8
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
try:
                            batch_embeddings = self.client.feature_extraction(batch)
                            # Normalize to plain Python lists of floats (the client
                            # may return a numpy array or nested lists)
                            batch_results = np.asarray(batch_embeddings, dtype=float).tolist()
                            embeddings.extend(batch_results)
                        except Exception as e:
                            logger.error(f"Error embedding batch {i}: {str(e)}")
                            # Return zero vectors as fallback
                            for _ in range(len(batch)):
                                embeddings.append([0.0] * 384)  # all-MiniLM-L6-v2 is 384-dimensional
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed a single query."""
try:
                        embedding = self.client.feature_extraction(text)
                        arr = np.asarray(embedding, dtype=float)
                        # The API may return a batch of one embedding for a single
                        # input; flatten it to a single vector
                        if arr.ndim == 2:
                            arr = arr[0]
                        return arr.tolist()
                    except Exception as e:
                        logger.error(f"Error embedding query: {str(e)}")
                        # Return zero vector as fallback
                        return [0.0] * 384  # all-MiniLM-L6-v2 is 384-dimensional
                # Some LangChain components call the embeddings object directly,
                # so make the wrapper callable as well
                def __call__(self, texts):
                    """Dispatch to embed_query or embed_documents based on input type."""
if isinstance(texts, str):
return self.embed_query(texts)
elif isinstance(texts, list):
return self.embed_documents(texts)
else:
raise ValueError(f"Unsupported input type: {type(texts)}")
return EmbeddingsWrapper(self.embedding_client)
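

# Minimal smoke test (a sketch, not part of the deployed app): exercises both
# wrappers end to end. Assumes HUGGINGFACE_API_KEY is set in the environment;
# switch provider to "ollama" if a local Ollama server is running instead.
if __name__ == "__main__":
    manager = LLMManager(provider="huggingface")

    llm = manager.get_llm()
    print(llm("Briefly explain what DRS is in Formula 1.", max_new_tokens=128))

    embedder = manager.get_embeddings()
    docs = ["Monza is a high-speed circuit.", "Monaco rewards qualifying pace."]
    doc_vectors = embedder.embed_documents(docs)
    query_vector = embedder.embed_query("Which track favors top speed?")
    print(f"Embedded {len(doc_vectors)} documents at dimension {len(doc_vectors[0])}")
    print(f"Query vector dimension: {len(query_vector)}")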