AdityaAdaki committed on
Commit 180a8b0 · 1 Parent(s): 5f1dc39

initial deployment

Files changed (7)
  1. .gitignore +5 -0
  2. app.py +121 -0
  3. f1_ai.py +285 -0
  4. llm_manager.py +195 -0
  5. packages.txt +3 -0
  6. requirements.txt +15 -0
  7. setup.sh +4 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .env
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ .streamlit/secrets.toml
app.py ADDED
@@ -0,0 +1,121 @@
+ import streamlit as st
+ import asyncio
+ import os
+ from f1_ai import F1AI
+ from dotenv import load_dotenv
+
+ # Load environment variables from .streamlit/secrets.toml into os.environ
+ for key, value in st.secrets.items():
+     os.environ[key] = value
+
+ # Initialize session state
+ if 'f1_ai' not in st.session_state:
+     # Use HuggingFace by default for Spaces deployment
+     st.session_state.f1_ai = F1AI(llm_provider="huggingface")
+ if 'chat_history' not in st.session_state:
+     st.session_state.chat_history = []
+
+ # Set page config
+ st.set_page_config(page_title="F1-AI: Formula 1 RAG Application", layout="wide")
+
+ # Title and description
+ st.title("F1-AI: Formula 1 RAG Application")
+ st.markdown("""
+ This application uses Retrieval-Augmented Generation (RAG) to answer questions about Formula 1.
+ """)
+
+ # Add tabs
+ tab1, tab2 = st.tabs(["Chat", "Add Content"])
+
+ with tab1:
+     # Custom CSS for better styling
+     st.markdown("""
+     <style>
+     .stChatMessage {
+         padding: 1rem;
+         border-radius: 0.5rem;
+         margin-bottom: 1rem;
+         box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+     }
+     .stChatMessage.user {
+         background-color: #f0f2f6;
+     }
+     .stChatMessage.assistant {
+         background-color: #ffffff;
+     }
+     .source-link {
+         font-size: 0.8rem;
+         color: #666;
+         text-decoration: none;
+     }
+     </style>
+     """, unsafe_allow_html=True)
+
+     # Display chat history with enhanced formatting
+     for message in st.session_state.chat_history:
+         with st.chat_message(message["role"]):
+             if message["role"] == "assistant" and isinstance(message["content"], dict):
+                 st.markdown(message["content"]["answer"])
+                 if message["content"]["sources"]:
+                     st.markdown("---")
+                     st.markdown("**Sources:**")
+                     for source in message["content"]["sources"]:
+                         st.markdown(f"- [{source['url']}]({source['url']})")
+             else:
+                 st.markdown(message["content"])
+
+     # Question input
+     if question := st.chat_input("Ask a question about Formula 1"):
+         # Add user question to chat history
+         st.session_state.chat_history.append({"role": "user", "content": question})
+
+         # Display user question
+         with st.chat_message("user"):
+             st.write(question)
+
+         # Generate and display response with enhanced formatting
+         with st.chat_message("assistant"):
+             with st.spinner("🤔 Analyzing Formula 1 knowledge..."):
+                 response = asyncio.run(st.session_state.f1_ai.ask_question(question))
+                 st.markdown(response["answer"])
+
+                 # Display sources if available
+                 if response["sources"]:
+                     st.markdown("---")
+                     st.markdown("**Sources:**")
+                     for source in response["sources"]:
+                         st.markdown(f"- [{source['url']}]({source['url']})")
+
+         # Add assistant response to chat history
+         st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+ with tab2:
+     st.header("Add Content to Knowledge Base")
+
+     urls_input = st.text_area("Enter URLs (one per line)",
+                               placeholder="https://en.wikipedia.org/wiki/Formula_One\nhttps://www.formula1.com/en/latest/article....")
+
+     max_chunks = st.slider("Maximum chunks per URL", min_value=10, max_value=500, value=100, step=10)
+
+     if st.button("Ingest Data"):
+         if urls_input:
+             urls = [url.strip() for url in urls_input.split("\n") if url.strip()]
+             if urls:
+                 with st.spinner(f"Ingesting data from {len(urls)} URLs... This may take several minutes."):
+                     progress_bar = st.progress(0)
+
+                     # Process URLs one by one for better UI feedback
+                     for i, url in enumerate(urls):
+                         st.write(f"Processing: {url}")
+                         asyncio.run(st.session_state.f1_ai.ingest([url], max_chunks_per_url=max_chunks))
+                         progress_bar.progress((i + 1) / len(urls))
+
+                     st.success("✅ Data ingestion complete!")
+             else:
+                 st.error("Please enter at least one valid URL.")
+         else:
+             st.error("Please enter at least one URL to ingest.")
+
+ # Add a footer with credits
+ st.markdown("---")
+ st.markdown("F1-AI: A Formula 1 RAG Application • Powered by Hugging Face, Pinecone, and LangChain")
f1_ai.py ADDED
@@ -0,0 +1,285 @@
+ import os
+ import argparse
+ import logging
+ from datetime import datetime
+ from dotenv import load_dotenv
+ from typing import List, Dict, Any, Optional, Tuple
+ from rich.console import Console
+ from rich.markdown import Markdown
+ from pinecone import Pinecone
+ from langchain_pinecone import Pinecone as LangchainPinecone
+
+ # Import our custom LLM Manager
+ from llm_manager import LLMManager
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ console = Console()
+
+ # Load environment variables
+ load_dotenv()
+
+ class F1AI:
+     def __init__(self, index_name: str = "f12", llm_provider: str = "huggingface"):
+         """
+         Initialize the F1-AI RAG application.
+
+         Args:
+             index_name (str): Name of the Pinecone index to use
+             llm_provider (str): Provider for LLM and embeddings.
+                 Options: "ollama", "huggingface", "huggingface-openai"
+         """
+         self.index_name = index_name
+
+         # Initialize LLM and embeddings via manager
+         self.llm_manager = LLMManager(provider=llm_provider)
+         self.llm = self.llm_manager.get_llm()
+         self.embeddings = self.llm_manager.get_embeddings()
+
+         # Load Pinecone API Key
+         pinecone_api_key = os.getenv("PINECONE_API_KEY")
+         if not pinecone_api_key:
+             raise ValueError("❌ Pinecone API key missing! Set PINECONE_API_KEY in environment variables.")
+
+         # Modify this part in f1_ai.py
+
+         # Initialize Pinecone with v2 client
+         try:
+             self.pc = Pinecone(api_key=pinecone_api_key)
+
+             # Check existing indexes
+             existing_indexes = [idx['name'] for idx in self.pc.list_indexes()]
+
+             if index_name not in existing_indexes:
+                 console.log(f"🚀 Creating Pinecone index: {index_name}")
+                 # Update the dimension to match your embedding model
+                 self.pc.create_index(
+                     name=index_name,
+                     dimension=384,  # Match embedding dimensions of the model
+                     metric="cosine"
+                 )
+
+             # Connect to Pinecone index
+             index = self.pc.Index(index_name)
+             self.vectordb = LangchainPinecone.from_existing_index(
+                 index_name=index_name,
+                 text_key="text",
+                 embedding=self.embeddings
+             )
+
+             print(f"✅ Successfully connected to Pinecone index: {index_name}")
+         except Exception as e:
+             import traceback
+             print(f"⚠️ Error connecting to Pinecone: {str(e)}")
+             print(traceback.format_exc())
+             # Set vectordb to None, the application will handle this gracefully
+             self.vectordb = None
+
+
+     async def scrape(self, url: str, max_chunks: int = 100) -> List[Dict[str, Any]]:
+         """Scrape content from a URL and split into chunks with improved error handling."""
+         from playwright.async_api import async_playwright, TimeoutError
+         from langchain.text_splitter import RecursiveCharacterTextSplitter
+         from bs4 import BeautifulSoup
+
+         try:
+             async with async_playwright() as p:
+                 browser = await p.chromium.launch()
+                 page = await browser.new_page()
+                 console.log(f"[blue]Loading {url}...[/blue]")
+
+                 try:
+                     await page.goto(url, timeout=30000)
+                     # Get HTML content
+                     html_content = await page.content()
+                     soup = BeautifulSoup(html_content, 'html.parser')
+
+                     # Remove unwanted elements
+                     for element in soup.find_all(['script', 'style', 'nav', 'footer']):
+                         element.decompose()
+
+                     text = soup.get_text(separator=' ', strip=True)
+                 except TimeoutError:
+                     logger.error(f"Timeout while loading {url}")
+                     return []
+                 finally:
+                     await browser.close()
+
+             console.log(f"[green]Processing text ({len(text)} characters)...[/green]")
+
+             # Enhanced text cleaning
+             text = ' '.join(text.split())  # Normalize whitespace
+
+             # Improved text splitting with semantic boundaries
+             splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=512,
+                 chunk_overlap=50,
+                 separators=["\n\n", "\n", ".", "!", "?", ",", " "],
+                 length_function=len
+             )
+
+             docs = splitter.create_documents([text])
+
+             # Limit the number of chunks
+             limited_docs = docs[:max_chunks]
+             console.log(f"[yellow]Using {len(limited_docs)} chunks out of {len(docs)} total chunks[/yellow]")
+
+             # Enhanced metadata
+             timestamp = datetime.now().isoformat()
+             return [{
+                 "page_content": doc.page_content,
+                 "metadata": {
+                     "source": url,
+                     "chunk_index": i,
+                     "total_chunks": len(limited_docs),
+                     "timestamp": timestamp
+                 }
+             } for i, doc in enumerate(limited_docs)]
+
+         except Exception as e:
+             logger.error(f"Error scraping {url}: {str(e)}")
+             return []
+
+     async def ingest(self, urls: List[str], max_chunks_per_url: int = 100) -> None:
+         """Ingest data from URLs into the vector database."""
+         from langchain_community.vectorstores import Pinecone as LangchainPinecone
+         from tqdm import tqdm
+
+         # Create empty list to store documents
+         all_docs = []
+
+         # Scrape and process each URL with progress bar
+         for url in tqdm(urls, desc="Scraping URLs"):
+             chunks = await self.scrape(url, max_chunks=max_chunks_per_url)
+             all_docs.extend(chunks)
+
+         # Create or update vector database
+         total_docs = len(all_docs)
+         print(f"\nCreating vector database with {total_docs} documents...")
+         texts = [doc["page_content"] for doc in all_docs]
+         metadatas = [doc["metadata"] for doc in all_docs]
+
+         print("Starting embedding generation and uploading to Pinecone (this might take several minutes)...")
+         self.vectordb = LangchainPinecone.from_texts(
+             texts=texts,
+             embedding=self.embeddings,
+             index_name=self.index_name,
+             metadatas=metadatas,
+             text_key="text"
+         )
+
+         print("✅ Documents successfully uploaded to Pinecone!")
+
+     async def ask_question(self, question: str) -> Dict[str, Any]:
+         """Ask a question and get a response using RAG."""
+         if not self.vectordb:
+             return {"answer": "Error: Vector database not initialized. Please ingest data first.", "sources": []}
+
+         try:
+             # Retrieve relevant documents with similarity search
+             retriever = self.vectordb.as_retriever(
+                 search_type="similarity",
+                 search_kwargs={"k": 5}
+             )
+
+             # Get relevant documents
+             docs = retriever.get_relevant_documents(question)
+
+             if not docs:
+                 return {
+                     "answer": "I couldn't find any relevant information in my knowledge base. Please try a different question or ingest more relevant data.",
+                     "sources": []
+                 }
+
+             # Format context from documents
+             context = "\n\n".join([f"Document {i+1}: {doc.page_content}" for i, doc in enumerate(docs)])
+
+             # Create prompt for the LLM
+             prompt = f"""
+             Answer the question based on the provided context. Include relevant citations using [1], [2], etc.
+             If you're unsure or if the context doesn't contain the information, acknowledge the uncertainty.
+
+             Context:
+             {context}
+
+             Question: {question}
+
+             Answer with citations:
+             """
+
+             # Get response from LLM
+             response_text = ""
+             if hasattr(self.llm, "__call__"):  # Direct inference client wrapped function
+                 response_text = self.llm(prompt)
+                 # Debug response
+                 logger.info(f"Raw LLM response type: {type(response_text)}")
+                 if not response_text or response_text.strip() == "":
+                     logger.error("Empty response from LLM")
+                     response_text = "I apologize, but I couldn't generate a response. This might be due to an issue with the language model."
+             else:  # LangChain LLM
+                 response_text = self.llm.invoke(prompt)
+
+             # Format sources
+             sources = [{
+                 "url": doc.metadata["source"],
+                 "chunk_index": doc.metadata.get("chunk_index", 0),
+                 "timestamp": doc.metadata.get("timestamp", "")
+             } for doc in docs]
+
+             # Format response
+             formatted_response = {
+                 "answer": response_text,
+                 "sources": sources
+             }
+
+             return formatted_response
+
+         except Exception as e:
+             logger.error(f"Error processing question: {str(e)}")
+             return {
+                 "answer": f"I apologize, but I encountered an error while processing your question: {str(e)}",
+                 "sources": []
+             }
+
+ async def main():
+     """Main function to run the application."""
+     import asyncio
+
+     parser = argparse.ArgumentParser(description="F1-AI: RAG Application for Formula 1 information")
+     subparsers = parser.add_subparsers(dest="command", help="Command to run")
+
+     # Ingest command
+     ingest_parser = subparsers.add_parser("ingest", help="Ingest data from URLs")
+     ingest_parser.add_argument("--urls", nargs="+", required=True, help="URLs to scrape")
+     ingest_parser.add_argument("--max-chunks", type=int, default=100, help="Maximum chunks per URL")
+
+     # Ask command
+     ask_parser = subparsers.add_parser("ask", help="Ask a question")
+     ask_parser.add_argument("question", help="Question to ask")
+
+     # Added provider argument with the new option
+     parser.add_argument("--provider", choices=["ollama", "huggingface", "huggingface-openai"], default="huggingface",
+                         help="Provider for LLM and embeddings (default: huggingface)")
+
+     args = parser.parse_args()
+
+     f1_ai = F1AI(llm_provider=args.provider)
+
+     if args.command == "ingest":
+         await f1_ai.ingest(args.urls, max_chunks_per_url=args.max_chunks)
+     elif args.command == "ask":
+         response = await f1_ai.ask_question(args.question)
+         console.print("\n[bold green]Answer:[/bold green]")
+         # Format as markdown to make it prettier
+         console.print(Markdown(response['answer']))
+
+         console.print("\n[bold yellow]Sources:[/bold yellow]")
+         for i, source in enumerate(response['sources']):
+             console.print(f"[{i+1}] {source['url']}")
+     else:
+         parser.print_help()
+
+ if __name__ == "__main__":
+     import asyncio
+     asyncio.run(main())
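As a usage reference, a minimal sketch of driving the F1AI class directly from Python (it assumes PINECONE_API_KEY and HUGGINGFACE_API_KEY are already set, e.g. via .env):

import asyncio
from f1_ai import F1AI

# Build the pipeline against the default "f12" Pinecone index with the HuggingFace provider.
f1 = F1AI(llm_provider="huggingface")

# Both ingest() and ask_question() are async, so run them through asyncio.
asyncio.run(f1.ingest(["https://en.wikipedia.org/wiki/Formula_One"], max_chunks_per_url=50))
result = asyncio.run(f1.ask_question("Who won the first Formula One World Championship?"))
print(result["answer"])
for source in result["sources"]:
    print(source["url"])

The same flows are exposed on the command line through main(), e.g. python f1_ai.py ingest --urls https://en.wikipedia.org/wiki/Formula_One and python f1_ai.py --provider ollama ask "Who won the first Formula One World Championship?" (note that --provider is defined on the top-level parser, so it goes before the subcommand).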
llm_manager.py ADDED
@@ -0,0 +1,195 @@
+ import os
+ from typing import List, Dict, Any
+ from huggingface_hub import InferenceClient
+ from langchain_ollama import OllamaEmbeddings, OllamaLLM
+ from dotenv import load_dotenv
+ import numpy as np
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+
+ class LLMManager:
+     """
+     Manager class for handling different LLM and embedding models.
+     Uses HuggingFace's InferenceClient directly for HuggingFace models.
+     """
+     def __init__(self, provider: str = "huggingface"):
+         """
+         Initialize the LLM Manager.
+
+         Args:
+             provider (str): The provider for LLM and embeddings.
+                 Options: "ollama", "huggingface", "huggingface-openai"
+         """
+         self.provider = provider
+         self.llm_client = None
+         self.embedding_client = None
+
+         # Initialize models based on the provider
+         if provider == "ollama":
+             self._init_ollama()
+         elif provider == "huggingface" or provider == "huggingface-openai":
+             self._init_huggingface()
+         else:
+             raise ValueError(f"Unsupported provider: {provider}. Choose 'ollama', 'huggingface', or 'huggingface-openai'")
+
+     def _init_ollama(self):
+         """Initialize Ollama models."""
+         self.llm = OllamaLLM(model="phi4-mini:3.8b")
+         self.embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")
+
+     def _init_huggingface(self):
+         """Initialize HuggingFace models using InferenceClient directly."""
+         # Get API key from environment
+         api_key = os.getenv("HUGGINGFACE_API_KEY")
+         if not api_key:
+             raise ValueError("HuggingFace API key not found. Set HUGGINGFACE_API_KEY in environment variables.")
+
+         llm_endpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+         embedding_endpoint = "sentence-transformers/all-MiniLM-L6-v2"
+
+         # Initialize InferenceClient for LLM
+         self.llm_client = InferenceClient(
+             model=llm_endpoint,
+             token=api_key
+         )
+
+         # Initialize InferenceClient for embeddings
+         self.embedding_client = InferenceClient(
+             model=embedding_endpoint,
+             token=api_key
+         )
+
+         # Store generation parameters
+         self.generation_kwargs = {
+             "temperature": 0.7,
+             "max_new_tokens": 512,  # Reduced to avoid potential token limit issues
+             "repetition_penalty": 1.1,
+             "do_sample": True,
+             "top_k": 50,
+             "top_p": 0.9,
+             "return_full_text": False  # Only return the generated text, not the prompt
+         }
+
+     # LLM methods for compatibility with LangChain
+     def get_llm(self):
+         """
+         Return a callable object that mimics LangChain LLM interface.
+         For huggingface providers, this returns a function that calls the InferenceClient.
+         """
+         if self.provider == "ollama":
+             return self.llm
+         else:
+             # Return a function that wraps the InferenceClient for LLM
+             def llm_function(prompt, **kwargs):
+                 params = {**self.generation_kwargs, **kwargs}
+                 try:
+                     logger.info(f"Sending prompt to HuggingFace (length: {len(prompt)})")
+                     response = self.llm_client.text_generation(
+                         prompt,
+                         details=True,  # Get detailed response
+                         **params
+                     )
+                     # Extract generated text from response
+                     if isinstance(response, dict) and 'generated_text' in response:
+                         response = response['generated_text']
+                     logger.info(f"Received response from HuggingFace (length: {len(response) if response else 0})")
+
+                     # Ensure we get a valid string response
+                     if not response or not isinstance(response, str) or response.strip() == "":
+                         logger.warning("Empty or invalid response from HuggingFace, using fallback")
+                         return "I couldn't generate a proper response based on the available information."
+
+                     return response
+                 except Exception as e:
+                     logger.error(f"Error during LLM inference: {str(e)}")
+                     return f"Error generating response: {str(e)}"
+
+             # Add async capability
+             async def allm_function(prompt, **kwargs):
+                 params = {**self.generation_kwargs, **kwargs}
+                 try:
+                     response = await self.llm_client.text_generation(
+                         prompt,
+                         **params,
+                         stream=False
+                     )
+
+                     # Ensure we get a valid string response
+                     if not response or not isinstance(response, str) or response.strip() == "":
+                         logger.warning("Empty or invalid response from HuggingFace async, using fallback")
+                         return "I couldn't generate a proper response based on the available information."
+
+                     return response
+                 except Exception as e:
+                     logger.error(f"Error during async LLM inference: {str(e)}")
+                     return f"Error generating response: {str(e)}"
+
+             llm_function.ainvoke = allm_function
+             return llm_function
+
+     # Embeddings methods for compatibility with LangChain
+     def get_embeddings(self):
+         """
+         Return a callable object that mimics LangChain Embeddings interface.
+         For huggingface providers, this returns an object with embed_documents and embed_query methods.
+         """
+         if self.provider == "ollama":
+             return self.embeddings
+         else:
+             # Create a wrapper object that has the expected methods
+             class EmbeddingsWrapper:
+                 def __init__(self, client):
+                     self.client = client
+
+                 def embed_documents(self, texts: List[str]) -> List[List[float]]:
+                     """Embed multiple documents."""
+                     embeddings = []
+                     # Process in batches to avoid overwhelming the API
+                     batch_size = 8
+
+                     for i in range(0, len(texts), batch_size):
+                         batch = texts[i:i+batch_size]
+                         try:
+                             batch_embeddings = self.client.feature_extraction(batch)
+                             # Convert to standard Python list format
+                             batch_results = [list(map(float, embedding)) for embedding in batch_embeddings]
+                             embeddings.extend(batch_results)
+                         except Exception as e:
+                             logger.error(f"Error embedding batch {i}: {str(e)}")
+                             # Return zero vectors as fallback
+                             for _ in range(len(batch)):
+                                 embeddings.append([0.0] * 384)  # Use correct dimension
+
+                     return embeddings
+
+                 def embed_query(self, text: str) -> List[float]:
+                     """Embed a single query."""
+                     try:
+                         embedding = self.client.feature_extraction(text)
+                         if isinstance(embedding, list) and len(embedding) > 0:
+                             # If it returns a batch (list of embeddings) for a single input
+                             return list(map(float, embedding[0]))
+                         # If it returns a single embedding
+                         return list(map(float, embedding))
+                     except Exception as e:
+                         logger.error(f"Error embedding query: {str(e)}")
+                         # Return zero vector as fallback
+                         return [0.0] * 384  # Use correct dimension
+
+                 # Make the class callable to fix the TypeError
+                 def __call__(self, texts):
+                     """Make the object callable for compatibility with LangChain."""
+                     if isinstance(texts, str):
+                         return self.embed_query(texts)
+                     elif isinstance(texts, list):
+                         return self.embed_documents(texts)
+                     else:
+                         raise ValueError(f"Unsupported input type: {type(texts)}")
+
+             return EmbeddingsWrapper(self.embedding_client)
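For reference, a minimal sketch of using LLMManager on its own (assumes HUGGINGFACE_API_KEY is set; both calls go through the Hugging Face Inference API):

from llm_manager import LLMManager

manager = LLMManager(provider="huggingface")
llm = manager.get_llm()                 # callable: prompt -> generated text
embeddings = manager.get_embeddings()   # wrapper exposing embed_query / embed_documents

print(llm("In one sentence, what is DRS in Formula 1?"))
vector = embeddings.embed_query("Circuit de Monaco")
print(len(vector))  # 384 dimensions for sentence-transformers/all-MiniLM-L6-v2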
packages.txt ADDED
@@ -0,0 +1,3 @@
+ chromium
+ wget
+ ca-certificates
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ langchain==0.1.0
+ langchain-community==0.0.10
+ langchain-core==0.1.0
+ langchain-pinecone==0.0.1
+ langchain-ollama==0.0.1
+ pinecone-client==3.0.1
+ huggingface-hub==0.20.1
+ streamlit==1.29.0
+ playwright==1.40.0
+ beautifulsoup4==4.12.2
+ tqdm==4.66.1
+ python-dotenv==1.0.0
+ typing-extensions==4.8.0
+ rich==13.7.0
+ # Remove asyncio package as it's part of Python standard library
setup.sh ADDED
@@ -0,0 +1,4 @@
+ #!/bin/bash
+ # Install playwright browsers
+ pip install playwright
+ playwright install chromium