AdityaAdaki committed
Commit 79f7264 · Parent: 34a1969

refactor(llm_manager): switch from HuggingFace to OpenRouter for LLM


The LLM provider has been changed from HuggingFace to OpenRouter to improve reliability and reduce API complexity. This change also removes redundant code related to HuggingFace's InferenceClient and simplifies the LLM interface. The default provider is now set to OpenRouter, and the embeddings are handled separately using Ollama. This refactor improves maintainability and aligns with the new architecture.
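For orientation, a minimal usage sketch of the refactored setup (not part of the commit; the prompt is a placeholder, and it assumes OPENROUTER_API_KEY and PINECONE_API_KEY are set, a local Ollama server is available for embeddings, and the Pinecone index already exists):

from f1_ai import F1AI

# F1AI wires LLMManager (OpenRouter LLM + Ollama embeddings) to an existing Pinecone index.
app = F1AI(index_name="f12", llm_provider="openrouter")

# app.llm is the plain callable returned by LLMManager.get_llm().
print(app.llm("Summarise the most recent Formula 1 season in two sentences."))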

Files changed (2)
  1. f1_ai.py +25 -42
  2. llm_manager.py +113 -160
f1_ai.py CHANGED
@@ -21,60 +21,46 @@ console = Console()
 load_dotenv()
 
 class F1AI:
-    def __init__(self, index_name: str = "f12", llm_provider: str = "huggingface"):
+    def __init__(self, index_name: str = "f12", llm_provider: str = "openrouter"):
         """
         Initialize the F1-AI RAG application.
 
         Args:
             index_name (str): Name of the Pinecone index to use
-            llm_provider (str): Provider for LLM and embeddings.
-                                Options: "ollama", "huggingface", "huggingface-openai"
+            llm_provider (str): Provider for LLM. "openrouter" is used by default.
         """
         self.index_name = index_name
 
-        # Initialize LLM and embeddings via manager
+        # Initialize LLM via manager
         self.llm_manager = LLMManager(provider=llm_provider)
         self.llm = self.llm_manager.get_llm()
-        self.embeddings = self.llm_manager.get_embeddings()
 
         # Load Pinecone API Key
         pinecone_api_key = os.getenv("PINECONE_API_KEY")
         if not pinecone_api_key:
             raise ValueError("❌ Pinecone API key missing! Set PINECONE_API_KEY in environment variables.")
 
-        # Modify this part in f1_ai.py
-
         # Initialize Pinecone with v2 client
-        try:
-            self.pc = Pinecone(api_key=pinecone_api_key)
+        self.pc = Pinecone(api_key=pinecone_api_key)
 
-            # Check existing indexes
-            existing_indexes = [idx['name'] for idx in self.pc.list_indexes()]
+        # Check existing indexes
+        existing_indexes = [idx['name'] for idx in self.pc.list_indexes()]
 
-            if index_name not in existing_indexes:
-                console.log(f"🚀 Creating Pinecone index: {index_name}")
-                # Update the dimension to match your embedding model
-                self.pc.create_index(
-                    name=index_name,
-                    dimension=384,  # Match embedding dimensions of the model
-                    metric="cosine"
-                )
+        if index_name not in existing_indexes:
+            raise ValueError(f" Pinecone index '{index_name}' does not exist! Please create it first.")
 
-            # Connect to Pinecone index
-            index = self.pc.Index(index_name)
-            self.vectordb = LangchainPinecone.from_existing_index(
-                index_name=index_name,
-                text_key="text",
-                embedding=self.embeddings
-            )
+        # Connect to Pinecone index
+        index = self.pc.Index(index_name)
+
+        # Use the existing pre-configured Pinecone index
+        # Note: We're using the embeddings that Pinecone already has configured
+        self.vectordb = LangchainPinecone(
+            index=index,
+            text_key="text",
+            embedding=self.llm_manager.get_embeddings()  # This will only be used for new queries
+        )
 
-            print(f"✅ Successfully connected to Pinecone index: {index_name}")
-        except Exception as e:
-            import traceback
-            print(f"⚠️ Error connecting to Pinecone: {str(e)}")
-            print(traceback.format_exc())
-            # Set vectordb to None, the application will handle this gracefully
-            self.vectordb = None
+        print(f"✅ Successfully connected to Pinecone index: {index_name}")
 
 
     async def scrape(self, url: str, max_chunks: int = 100) -> List[Dict[str, Any]]:
@@ -143,7 +129,6 @@ class F1AI:
 
     async def ingest(self, urls: List[str], max_chunks_per_url: int = 100) -> None:
         """Ingest data from URLs into the vector database."""
-        from langchain_community.vectorstores import Pinecone as LangchainPinecone
         from tqdm import tqdm
 
         # Create empty list to store documents
@@ -161,12 +146,10 @@ class F1AI:
         metadatas = [doc["metadata"] for doc in all_docs]
 
         print("Starting embedding generation and uploading to Pinecone (this might take several minutes)...")
-        self.vectordb = LangchainPinecone.from_texts(
+        # Use the existing vectordb to add documents
+        self.vectordb.add_texts(
             texts=texts,
-            embedding=self.embeddings,
-            index_name=self.index_name,
-            metadatas=metadatas,
-            text_key="text"
+            metadatas=metadatas
         )
 
         print("✅ Documents successfully uploaded to Pinecone!")
@@ -258,9 +241,9 @@ async def main():
     ask_parser = subparsers.add_parser("ask", help="Ask a question")
     ask_parser.add_argument("question", help="Question to ask")
 
-    # Added provider argument with the new option
-    parser.add_argument("--provider", choices=["ollama", "huggingface", "huggingface-openai"], default="huggingface",
-                        help="Provider for LLM and embeddings (default: huggingface)")
+    # Provider argument
+    parser.add_argument("--provider", choices=["ollama", "openrouter"], default="openrouter",
+                        help="Provider for LLM (default: openrouter)")
 
     args = parser.parse_args()
 
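A hedged sketch of how the new ingestion path shown above would be exercised (the URL and chunk count are placeholder values; the constructor now raises if the index is missing instead of creating it):

import asyncio
from f1_ai import F1AI

async def run():
    # Connects to the pre-existing Pinecone index, then appends documents
    # to it via vectordb.add_texts() instead of Pinecone.from_texts().
    app = F1AI(index_name="f12", llm_provider="openrouter")
    await app.ingest(["https://en.wikipedia.org/wiki/Formula_One"], max_chunks_per_url=50)

asyncio.run(run())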
llm_manager.py CHANGED
@@ -1,9 +1,9 @@
 import os
+import json
+import requests
 from typing import List, Dict, Any
-from huggingface_hub import InferenceClient
-from langchain_ollama import OllamaEmbeddings, OllamaLLM
+from langchain_ollama import OllamaEmbeddings
 from dotenv import load_dotenv
-import numpy as np
 import logging
 
 # Configure logging
@@ -15,181 +15,134 @@ load_dotenv()
 
 class LLMManager:
     """
-    Manager class for handling different LLM and embedding models.
-    Uses HuggingFace's InferenceClient directly for HuggingFace models.
+    Manager class for handling Ollama embeddings and OpenRouter LLM.
     """
-    def __init__(self, provider: str = "huggingface"):
+    def __init__(self, provider: str = "openrouter"):
         """
         Initialize the LLM Manager.
 
         Args:
-            provider (str): The provider for LLM and embeddings.
-                            Options: "ollama", "huggingface", "huggingface-openai"
+            provider (str): "ollama" for embeddings, OpenRouter is used for LLM regardless
         """
         self.provider = provider
-        self.llm_client = None
-        self.embedding_client = None
 
-        # Initialize models based on the provider
-        if provider == "ollama":
-            self._init_ollama()
-        elif provider == "huggingface" or provider == "huggingface-openai":
-            self._init_huggingface()
-        else:
-            raise ValueError(f"Unsupported provider: {provider}. Choose 'ollama', 'huggingface', or 'huggingface-openai'")
-
-    def _init_ollama(self):
-        """Initialize Ollama models."""
-        self.llm = OllamaLLM(model="phi4-mini:3.8b")
-        self.embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")
-
-    def _init_huggingface(self):
-        """Initialize HuggingFace models using InferenceClient directly."""
-        # Get API key from environment
-        api_key = os.getenv("HUGGINGFACE_API_KEY")
-        if not api_key:
-            raise ValueError("HuggingFace API key not found. Set HUGGINGFACE_API_KEY in environment variables.")
-
-        llm_endpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-        embedding_endpoint = "sentence-transformers/all-MiniLM-L6-v2"
-
-        # Initialize InferenceClient for LLM
-        self.llm_client = InferenceClient(
-            model=llm_endpoint,
-            token=api_key
-        )
+        # Initialize Ollama embeddings
+        self.embeddings = OllamaEmbeddings(model="tazarov/all-minilm-l6-v2-f32:latest")
 
-        # Initialize InferenceClient for embeddings
-        self.embedding_client = InferenceClient(
-            model=embedding_endpoint,
-            token=api_key
-        )
+        # Initialize OpenRouter client
+        self.openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+        if not self.openrouter_api_key:
+            raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY in environment variables.")
 
-        # Store generation parameters
-        self.generation_kwargs = {
-            "temperature": 0.7,
-            "max_new_tokens": 512,  # Reduced to avoid potential token limit issues
-            "repetition_penalty": 1.1,
-            "do_sample": True,
-            "top_k": 50,
-            "top_p": 0.9,
-            "return_full_text": False  # Only return the generated text, not the prompt
+        # Set up OpenRouter API details
+        self.openrouter_url = "https://openrouter.ai/api/v1/chat/completions"
+        self.openrouter_model = "mistralai/mistral-7b-instruct:free"
+        self.openrouter_headers = {
+            "Authorization": f"Bearer {self.openrouter_api_key}",
+            "Content-Type": "application/json",
+            "HTTP-Referer": "https://f1-ai.app",  # Replace with your app's URL
+            "X-Title": "F1-AI Application"  # Replace with your app's name
         }
 
     # LLM methods for compatibility with LangChain
     def get_llm(self):
         """
-        Return a callable object that mimics LangChain LLM interface.
-        For huggingface providers, this returns a function that calls the InferenceClient.
-        """
-        if self.provider == "ollama":
-            return self.llm
-        else:
-            # Return a function that wraps the InferenceClient for LLM
-            def llm_function(prompt, **kwargs):
-                params = {**self.generation_kwargs, **kwargs}
-                try:
-                    logger.info(f"Sending prompt to HuggingFace (length: {len(prompt)})")
-                    response = self.llm_client.text_generation(
-                        prompt,
-                        details=True,  # Get detailed response
-                        **params
-                    )
-                    # Extract generated text from response
-                    if isinstance(response, dict) and 'generated_text' in response:
-                        response = response['generated_text']
-                    logger.info(f"Received response from HuggingFace (length: {len(response) if response else 0})")
-
-                    # Ensure we get a valid string response
-                    if not response or not isinstance(response, str) or response.strip() == "":
-                        logger.warning("Empty or invalid response from HuggingFace, using fallback")
-                        return "I couldn't generate a proper response based on the available information."
-
-                    return response
-                except Exception as e:
-                    logger.error(f"Error during LLM inference: {str(e)}")
-                    return f"Error generating response: {str(e)}"
-
-            # Add async capability
-            async def allm_function(prompt, **kwargs):
-                params = {**self.generation_kwargs, **kwargs}
-                try:
-                    response = await self.llm_client.text_generation(
-                        prompt,
-                        **params,
-                        stream=False
-                    )
-
-                    # Ensure we get a valid string response
-                    if not response or not isinstance(response, str) or response.strip() == "":
-                        logger.warning("Empty or invalid response from HuggingFace async, using fallback")
-                        return "I couldn't generate a proper response based on the available information."
-
-                    return response
-                except Exception as e:
-                    logger.error(f"Error during async LLM inference: {str(e)}")
-                    return f"Error generating response: {str(e)}"
-
-            llm_function.ainvoke = allm_function
-            return llm_function
-
-    # Embeddings methods for compatibility with LangChain
-    def get_embeddings(self):
-        """
-        Return a callable object that mimics LangChain Embeddings interface.
-        For huggingface providers, this returns an object with embed_documents and embed_query methods.
+        Return a callable function that serves as the LLM interface.
         """
-        if self.provider == "ollama":
-            return self.embeddings
-        else:
-            # Create a wrapper object that has the expected methods
-            class EmbeddingsWrapper:
-                def __init__(self, client):
-                    self.client = client
+        def llm_function(prompt, **kwargs):
+            try:
+                logger.info(f"Sending prompt to OpenRouter (length: {len(prompt)})")
 
-                def embed_documents(self, texts: List[str]) -> List[List[float]]:
-                    """Embed multiple documents."""
-                    embeddings = []
-                    # Process in batches to avoid overwhelming the API
-                    batch_size = 8
-
-                    for i in range(0, len(texts), batch_size):
-                        batch = texts[i:i+batch_size]
-                        try:
-                            batch_embeddings = self.client.feature_extraction(batch)
-                            # Convert to standard Python list format
-                            batch_results = [list(map(float, embedding)) for embedding in batch_embeddings]
-                            embeddings.extend(batch_results)
-                        except Exception as e:
-                            logger.error(f"Error embedding batch {i}: {str(e)}")
-                            # Return zero vectors as fallback
-                            for _ in range(len(batch)):
-                                embeddings.append([0.0] * 384)  # Use correct dimension
-
-                    return embeddings
+                # Format the messages for OpenRouter API
+                messages = [{"role": "user", "content": prompt}]
 
-                def embed_query(self, text: str) -> List[float]:
-                    """Embed a single query."""
-                    try:
-                        embedding = self.client.feature_extraction(text)
-                        if isinstance(embedding, list) and len(embedding) > 0:
-                            # If it returns a batch (list of embeddings) for a single input
-                            return list(map(float, embedding[0]))
-                        # If it returns a single embedding
-                        return list(map(float, embedding))
-                    except Exception as e:
-                        logger.error(f"Error embedding query: {str(e)}")
-                        # Return zero vector as fallback
-                        return [0.0] * 384  # Use correct dimension
+                # Set up request payload
+                payload = {
+                    "model": self.openrouter_model,
+                    "messages": messages,
+                    "temperature": kwargs.get("temperature", 0.7),
+                    "max_tokens": kwargs.get("max_tokens", 1024),
+                    "top_p": kwargs.get("top_p", 0.9),
+                    "stream": False
+                }
 
-                # Make the class callable to fix the TypeError
-                def __call__(self, texts):
-                    """Make the object callable for compatibility with LangChain."""
-                    if isinstance(texts, str):
-                        return self.embed_query(texts)
-                    elif isinstance(texts, list):
-                        return self.embed_documents(texts)
+                # Send request to OpenRouter
+                response = requests.post(
+                    self.openrouter_url,
+                    headers=self.openrouter_headers,
+                    json=payload,
+                    timeout=60
+                )
+
+                # Process the response
+                if response.status_code == 200:
+                    response_json = response.json()
+                    if "choices" in response_json and len(response_json["choices"]) > 0:
+                        generated_text = response_json["choices"][0]["message"]["content"]
+                        logger.info(f"Received response from OpenRouter (length: {len(generated_text)})")
+                        return generated_text
                     else:
-                        raise ValueError(f"Unsupported input type: {type(texts)}")
+                        logger.warning("Unexpected response format from OpenRouter")
+                        return "I couldn't generate a proper response based on the available information."
+                else:
+                    logger.error(f"Error from OpenRouter API: {response.status_code} - {response.text}")
+                    return f"Error from LLM API: {response.status_code}"
+
+            except Exception as e:
+                logger.error(f"Error during LLM inference: {str(e)}")
+                return f"Error generating response: {str(e)}"
+
+        # Add async capability
+        async def allm_function(prompt, **kwargs):
+            import aiohttp
 
-            return EmbeddingsWrapper(self.embedding_client)
+            try:
+                # Format the messages for OpenRouter API
+                messages = [{"role": "user", "content": prompt}]
+
+                # Set up request payload
+                payload = {
+                    "model": self.openrouter_model,
+                    "messages": messages,
+                    "temperature": kwargs.get("temperature", 0.7),
+                    "max_tokens": kwargs.get("max_tokens", 1024),
+                    "top_p": kwargs.get("top_p", 0.9),
+                    "stream": False
+                }
+
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        self.openrouter_url,
+                        headers=self.openrouter_headers,
+                        json=payload,
+                        timeout=aiohttp.ClientTimeout(total=60)
+                    ) as response:
+                        if response.status == 200:
+                            response_json = await response.json()
+                            if "choices" in response_json and len(response_json["choices"]) > 0:
+                                generated_text = response_json["choices"][0]["message"]["content"]
+                                return generated_text
+                            else:
+                                logger.warning("Unexpected response format from OpenRouter")
+                                return "I couldn't generate a proper response based on the available information."
+                        else:
+                            error_text = await response.text()
+                            logger.error(f"Error from OpenRouter API: {response.status} - {error_text}")
+                            return f"Error from LLM API: {response.status}"
+
+            except Exception as e:
+                logger.error(f"Error during async LLM inference: {str(e)}")
+                return f"Error generating response: {str(e)}"
+
+        # Add async method to the function
+        llm_function.ainvoke = allm_function
+
+        # Add invoke method for compatibility
+        llm_function.invoke = llm_function
+
+        return llm_function
+
+    # Embeddings methods for compatibility with LangChain
+    def get_embeddings(self):
+        """Return the embeddings instance."""
+        return self.embeddings
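A short, hedged sketch of how the new manager interface is meant to be called (prompts are placeholders; assumes OPENROUTER_API_KEY is set and an Ollama server with the tazarov/all-minilm-l6-v2-f32 model is running locally):

import asyncio
from llm_manager import LLMManager

manager = LLMManager()            # OpenRouter-backed LLM, Ollama-backed embeddings
llm = manager.get_llm()           # plain function with .invoke and .ainvoke attached

print(llm("Name three current Formula 1 constructors."))                       # sync path (requests)
print(asyncio.run(llm.ainvoke("Name three current Formula 1 constructors.")))  # async path (aiohttp)

embeddings = manager.get_embeddings()                     # langchain_ollama OllamaEmbeddings instance
print(len(embeddings.embed_query("Monaco Grand Prix")))   # embedding dimension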