# f1-ai/llm_manager.py
import os
import json
import requests
from typing import List, Dict, Any
from langchain_community.embeddings import HuggingFaceEmbeddings  # HuggingFace replaces the earlier Ollama embeddings
from dotenv import load_dotenv
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

class LLMManager:
    """
    Manager class for handling HuggingFace embeddings and the OpenRouter LLM.
    """

    def __init__(self, provider: str = "openrouter"):
        """
        Initialize the LLM Manager.

        Args:
            provider (str): Provider for the LLM ("openrouter" is the default
                and the only provider this class currently implements).
        """
        self.provider = provider

        # Initialize HuggingFace embeddings (instead of Ollama)
        self.embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'}
        )

        # Initialize the OpenRouter client
        self.openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
        if not self.openrouter_api_key:
            raise ValueError(
                "OpenRouter API key not found. "
                "Set OPENROUTER_API_KEY in environment variables."
            )

        # Set up OpenRouter API details
        self.openrouter_url = "https://openrouter.ai/api/v1/chat/completions"
        self.openrouter_model = "mistralai/mistral-7b-instruct:free"
        self.openrouter_headers = {
            "Authorization": f"Bearer {self.openrouter_api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://f1-ai.app",  # Replace with your app's URL
            "X-Title": "F1-AI Application"  # Replace with your app's name
        }
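
    # Illustrative note (not in the original file): self.embeddings follows
    # LangChain's Embeddings interface, so downstream code can call, e.g.:
    #   LLMManager().get_embeddings().embed_query("pole position, Monza 2023")
    # which returns a 384-dimensional vector for all-MiniLM-L6-v2.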

    # LLM methods for compatibility with LangChain
    def get_llm(self):
        """
        Return a callable function that serves as the LLM interface.
        """
        def llm_function(prompt, **kwargs):
            try:
                logger.info(f"Sending prompt to OpenRouter (length: {len(prompt)})")
                # Format the messages for the OpenRouter chat API
                messages = [{"role": "user", "content": prompt}]
                # Set up the request payload
                payload = {
                    "model": self.openrouter_model,
                    "messages": messages,
                    "temperature": kwargs.get("temperature", 0.7),
                    "max_tokens": kwargs.get("max_tokens", 1024),
                    "top_p": kwargs.get("top_p", 0.9),
                    "stream": False
                }
                # Send the request to OpenRouter
                response = requests.post(
                    self.openrouter_url,
                    headers=self.openrouter_headers,
                    json=payload,
                    timeout=60
                )
                # Process the response
                if response.status_code == 200:
                    response_json = response.json()
                    if "choices" in response_json and len(response_json["choices"]) > 0:
                        generated_text = response_json["choices"][0]["message"]["content"]
                        logger.info(f"Received response from OpenRouter (length: {len(generated_text)})")
                        return generated_text
                    else:
                        logger.warning("Unexpected response format from OpenRouter")
                        return "I couldn't generate a proper response based on the available information."
                else:
                    logger.error(f"Error from OpenRouter API: {response.status_code} - {response.text}")
                    return f"Error from LLM API: {response.status_code}"
            except Exception as e:
                logger.error(f"Error during LLM inference: {str(e)}")
                return f"Error generating response: {str(e)}"

        # Async variant of the same call
        async def allm_function(prompt, **kwargs):
            # Imported here so aiohttp is only required when the async path is used
            import aiohttp
            try:
                # Format the messages for the OpenRouter chat API
                messages = [{"role": "user", "content": prompt}]
                # Set up the request payload
                payload = {
                    "model": self.openrouter_model,
                    "messages": messages,
                    "temperature": kwargs.get("temperature", 0.7),
                    "max_tokens": kwargs.get("max_tokens", 1024),
                    "top_p": kwargs.get("top_p", 0.9),
                    "stream": False
                }
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        self.openrouter_url,
                        headers=self.openrouter_headers,
                        json=payload,
                        timeout=aiohttp.ClientTimeout(total=60)
                    ) as response:
                        if response.status == 200:
                            response_json = await response.json()
                            if "choices" in response_json and len(response_json["choices"]) > 0:
                                generated_text = response_json["choices"][0]["message"]["content"]
                                return generated_text
                            else:
                                logger.warning("Unexpected response format from OpenRouter")
                                return "I couldn't generate a proper response based on the available information."
                        else:
                            error_text = await response.text()
                            logger.error(f"Error from OpenRouter API: {response.status} - {error_text}")
                            return f"Error from LLM API: {response.status}"
            except Exception as e:
                logger.error(f"Error during async LLM inference: {str(e)}")
                return f"Error generating response: {str(e)}"

        # Expose LangChain-style entry points on the callable
        llm_function.ainvoke = allm_function  # async method
        llm_function.invoke = llm_function  # sync alias for compatibility
        return llm_function
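
    # Usage sketch (illustrative, not part of the original module): the callable
    # returned by get_llm() mimics a LangChain LLM closely enough for simple chains:
    #   llm = LLMManager().get_llm()
    #   text = llm("Explain DRS in one sentence.")           # direct call
    #   text = llm.invoke("Explain DRS in one sentence.")    # LangChain-style sync
    #   text = await llm.ainvoke("Explain DRS ...")          # async; needs aiohttp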

    # Embeddings methods for compatibility with LangChain
    def get_embeddings(self):
        """Return the embeddings instance."""
        return self.embeddings
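

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative, not part of the original module).
# Assumes OPENROUTER_API_KEY is set, network access is available, and the
# sentence-transformers model can be downloaded; the prompts are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = LLMManager()

    # Embeddings path: embed_query returns a list[float]
    # (384 dimensions for all-MiniLM-L6-v2).
    vector = manager.get_embeddings().embed_query("fastest lap at Silverstone")
    print(f"Embedding dimensions: {len(vector)}")

    # LLM path: a single-turn chat completion via OpenRouter; on failure the
    # function returns an error string rather than raising.
    llm = manager.get_llm()
    print(llm("Summarize the 2021 Abu Dhabi Grand Prix in one sentence.",
              temperature=0.3))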