import asyncio
import os
from typing import List, Dict, Any
from huggingface_hub import InferenceClient
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from dotenv import load_dotenv
import numpy as np
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
class LLMManager:
"""
Manager class for handling different LLM and embedding models.
Uses HuggingFace's InferenceClient directly for HuggingFace models.
"""
def __init__(self, provider: str = "huggingface"):
"""
Initialize the LLM Manager.
Args:
provider (str): The provider for LLM and embeddings.
Options: "ollama", "huggingface", "huggingface-openai"
"""
self.provider = provider
self.llm_client = None
self.embedding_client = None
# Initialize models based on the provider
if provider == "ollama":
self._init_ollama()
elif provider == "huggingface" or provider == "huggingface-openai":
self._init_huggingface()
else:
raise ValueError(f"Unsupported provider: {provider}. Choose 'ollama', 'huggingface', or 'huggingface-openai'")
def _init_ollama(self):
"""Initialize Ollama models."""
self.llm = OllamaLLM(model="phi4-mini:3.8b")
self.embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")
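        # Note: this assumes the models above have already been pulled into the local
        # Ollama instance (e.g. via `ollama pull`); langchain_ollama does not download them here.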
    def _init_huggingface(self):
        """Initialize HuggingFace models using InferenceClient directly."""
        # Get API key from environment
        api_key = os.getenv("HUGGINGFACE_API_KEY")
        if not api_key:
            raise ValueError("HuggingFace API key not found. Set HUGGINGFACE_API_KEY in environment variables.")

        llm_endpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        embedding_endpoint = "sentence-transformers/all-MiniLM-L6-v2"

        # Initialize InferenceClient for LLM
        self.llm_client = InferenceClient(
            model=llm_endpoint,
            token=api_key
        )

        # Initialize InferenceClient for embeddings
        self.embedding_client = InferenceClient(
            model=embedding_endpoint,
            token=api_key
        )

        # Store generation parameters
        self.generation_kwargs = {
            "temperature": 0.7,
            "max_new_tokens": 512,  # Reduced to avoid potential token limit issues
            "repetition_penalty": 1.1,
            "do_sample": True,
            "top_k": 50,
            "top_p": 0.9,
            "return_full_text": False  # Only return the generated text, not the prompt
        }
    # LLM methods for compatibility with LangChain
    def get_llm(self):
        """
        Return a callable object that mimics the LangChain LLM interface.
        For huggingface providers, this returns a function that calls the InferenceClient.
        """
        if self.provider == "ollama":
            return self.llm
        else:
            # Return a function that wraps the InferenceClient for LLM
            def llm_function(prompt, **kwargs):
                params = {**self.generation_kwargs, **kwargs}
                try:
                    logger.info(f"Sending prompt to HuggingFace (length: {len(prompt)})")
                    response = self.llm_client.text_generation(
                        prompt,
                        details=True,  # Get detailed response
                        **params
                    )
                    # With details=True the client returns a structured result (object or dict);
                    # extract the generated text from either form
                    if hasattr(response, "generated_text"):
                        response = response.generated_text
                    elif isinstance(response, dict) and "generated_text" in response:
                        response = response["generated_text"]
                    logger.info(f"Received response from HuggingFace (length: {len(response) if response else 0})")
                    # Ensure we get a valid string response
                    if not response or not isinstance(response, str) or response.strip() == "":
                        logger.warning("Empty or invalid response from HuggingFace, using fallback")
                        return "I couldn't generate a proper response based on the available information."
                    return response
                except Exception as e:
                    logger.error(f"Error during LLM inference: {str(e)}")
                    return f"Error generating response: {str(e)}"

            # Add async capability; InferenceClient.text_generation is synchronous,
            # so run it in a worker thread rather than awaiting it directly
            async def allm_function(prompt, **kwargs):
                params = {**self.generation_kwargs, **kwargs}
                try:
                    response = await asyncio.to_thread(
                        self.llm_client.text_generation,
                        prompt,
                        **params
                    )
                    # Ensure we get a valid string response
                    if not response or not isinstance(response, str) or response.strip() == "":
                        logger.warning("Empty or invalid response from HuggingFace async, using fallback")
                        return "I couldn't generate a proper response based on the available information."
                    return response
                except Exception as e:
                    logger.error(f"Error during async LLM inference: {str(e)}")
                    return f"Error generating response: {str(e)}"

            llm_function.ainvoke = allm_function
            return llm_function
    # Embeddings methods for compatibility with LangChain
    def get_embeddings(self):
        """
        Return a callable object that mimics the LangChain Embeddings interface.
        For huggingface providers, this returns an object with embed_documents and embed_query methods.
        """
        if self.provider == "ollama":
            return self.embeddings
        else:
            # Create a wrapper object that has the expected methods
            class EmbeddingsWrapper:
                def __init__(self, client):
                    self.client = client

                def embed_documents(self, texts: List[str]) -> List[List[float]]:
                    """Embed multiple documents."""
                    embeddings = []
                    # Process in batches to avoid overwhelming the API
                    batch_size = 8
                    for i in range(0, len(texts), batch_size):
                        batch = texts[i:i + batch_size]
                        try:
                            batch_embeddings = self.client.feature_extraction(batch)
                            # Convert to standard Python list format
                            batch_results = [list(map(float, embedding)) for embedding in batch_embeddings]
                            embeddings.extend(batch_results)
                        except Exception as e:
                            logger.error(f"Error embedding batch {i}: {str(e)}")
                            # Return zero vectors as fallback
                            for _ in range(len(batch)):
                                embeddings.append([0.0] * 384)  # all-MiniLM-L6-v2 embeddings are 384-dimensional
                    return embeddings

                def embed_query(self, text: str) -> List[float]:
                    """Embed a single query."""
                    try:
                        embedding = np.asarray(self.client.feature_extraction(text), dtype=float)
                        # The API may return a batch of one embedding for a single input; flatten it
                        if embedding.ndim > 1:
                            embedding = embedding[0]
                        return embedding.tolist()
                    except Exception as e:
                        logger.error(f"Error embedding query: {str(e)}")
                        # Return zero vector as fallback
                        return [0.0] * 384  # all-MiniLM-L6-v2 embeddings are 384-dimensional

                # Make the class callable to fix the TypeError
                def __call__(self, texts):
                    """Make the object callable for compatibility with LangChain."""
                    if isinstance(texts, str):
                        return self.embed_query(texts)
                    elif isinstance(texts, list):
                        return self.embed_documents(texts)
                    else:
                        raise ValueError(f"Unsupported input type: {type(texts)}")

            return EmbeddingsWrapper(self.embedding_client)
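
# Minimal usage sketch: it assumes HUGGINGFACE_API_KEY is set in the environment (or a local
# Ollama server is running if the "ollama" provider is chosen) and simply exercises
# get_llm() and get_embeddings() end to end.
if __name__ == "__main__":
    manager = LLMManager(provider="huggingface")

    # The LLM wrapper is a plain callable; extra generation kwargs override the stored defaults
    llm = manager.get_llm()
    print(llm("Explain retrieval-augmented generation in one sentence.", max_new_tokens=64))

    # The same wrapper exposes an ainvoke coroutine for async callers
    print(asyncio.run(llm.ainvoke("What is a vector store?")))

    # The embeddings wrapper offers LangChain-style embed_query / embed_documents methods
    embedder = manager.get_embeddings()
    query_vector = embedder.embed_query("What is retrieval-augmented generation?")
    doc_vectors = embedder.embed_documents(["First document.", "Second document."])
    print(len(query_vector), len(doc_vectors))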