AdityaAdaki committed
Commit · 180a8b0
1 Parent(s): 5f1dc39

initial deployment

Files changed:
- .gitignore +5 -0
- app.py +121 -0
- f1_ai.py +285 -0
- llm_manager.py +195 -0
- packages.txt +3 -0
- requirements.txt +15 -0
- setup.sh +4 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+.env
+__pycache__/
+*.py[cod]
+*$py.class
+.streamlit/secrets.toml
app.py
ADDED
@@ -0,0 +1,121 @@
+import streamlit as st
+import asyncio
+import os
+from f1_ai import F1AI
+from dotenv import load_dotenv
+
+# Load environment variables from .streamlit/secrets.toml into os.environ
+for key, value in st.secrets.items():
+    os.environ[key] = value
+
+# Initialize session state
+if 'f1_ai' not in st.session_state:
+    # Use HuggingFace by default for Spaces deployment
+    st.session_state.f1_ai = F1AI(llm_provider="huggingface")
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = []
+
+# Set page config
+st.set_page_config(page_title="F1-AI: Formula 1 RAG Application", layout="wide")
+
+# Title and description
+st.title("F1-AI: Formula 1 RAG Application")
+st.markdown("""
+This application uses Retrieval-Augmented Generation (RAG) to answer questions about Formula 1.
+""")
+
+# Add tabs
+tab1, tab2 = st.tabs(["Chat", "Add Content"])
+
+with tab1:
+    # Custom CSS for better styling
+    st.markdown("""
+    <style>
+    .stChatMessage {
+        padding: 1rem;
+        border-radius: 0.5rem;
+        margin-bottom: 1rem;
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+    }
+    .stChatMessage.user {
+        background-color: #f0f2f6;
+    }
+    .stChatMessage.assistant {
+        background-color: #ffffff;
+    }
+    .source-link {
+        font-size: 0.8rem;
+        color: #666;
+        text-decoration: none;
+    }
+    </style>
+    """, unsafe_allow_html=True)
+
+    # Display chat history with enhanced formatting
+    for message in st.session_state.chat_history:
+        with st.chat_message(message["role"]):
+            if message["role"] == "assistant" and isinstance(message["content"], dict):
+                st.markdown(message["content"]["answer"])
+                if message["content"]["sources"]:
+                    st.markdown("---")
+                    st.markdown("**Sources:**")
+                    for source in message["content"]["sources"]:
+                        st.markdown(f"- [{source['url']}]({source['url']})")
+            else:
+                st.markdown(message["content"])
+
+    # Question input
+    if question := st.chat_input("Ask a question about Formula 1"):
+        # Add user question to chat history
+        st.session_state.chat_history.append({"role": "user", "content": question})
+
+        # Display user question
+        with st.chat_message("user"):
+            st.write(question)
+
+        # Generate and display response with enhanced formatting
+        with st.chat_message("assistant"):
+            with st.spinner("🤔 Analyzing Formula 1 knowledge..."):
+                response = asyncio.run(st.session_state.f1_ai.ask_question(question))
+                st.markdown(response["answer"])
+
+                # Display sources if available
+                if response["sources"]:
+                    st.markdown("---")
+                    st.markdown("**Sources:**")
+                    for source in response["sources"]:
+                        st.markdown(f"- [{source['url']}]({source['url']})")
+
+        # Add assistant response to chat history
+        st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+with tab2:
+    st.header("Add Content to Knowledge Base")
+
+    urls_input = st.text_area("Enter URLs (one per line)",
+                              placeholder="https://en.wikipedia.org/wiki/Formula_One\nhttps://www.formula1.com/en/latest/article....")
+
+    max_chunks = st.slider("Maximum chunks per URL", min_value=10, max_value=500, value=100, step=10)
+
+    if st.button("Ingest Data"):
+        if urls_input:
+            urls = [url.strip() for url in urls_input.split("\n") if url.strip()]
+            if urls:
+                with st.spinner(f"Ingesting data from {len(urls)} URLs... This may take several minutes."):
+                    progress_bar = st.progress(0)
+
+                    # Process URLs one by one for better UI feedback
+                    for i, url in enumerate(urls):
+                        st.write(f"Processing: {url}")
+                        asyncio.run(st.session_state.f1_ai.ingest([url], max_chunks_per_url=max_chunks))
+                        progress_bar.progress((i + 1) / len(urls))
+
+                    st.success("✅ Data ingestion complete!")
+            else:
+                st.error("Please enter at least one valid URL.")
+        else:
+            st.error("Please enter at least one URL to ingest.")
+
+# Add a footer with credits
+st.markdown("---")
+st.markdown("F1-AI: A Formula 1 RAG Application • Powered by Hugging Face, Pinecone, and LangChain")
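Note: the chat tab above assumes ask_question() returns a dict with an "answer" string and a "sources" list, as implemented in f1_ai.py below. A minimal sketch of that shape (the values here are purely illustrative) can help when testing the UI before any data has been ingested:

    # Illustrative only: the response shape that tab1 renders.
    response = {
        "answer": "Example answer with a citation [1].",
        "sources": [
            {"url": "https://en.wikipedia.org/wiki/Formula_One",
             "chunk_index": 0,
             "timestamp": "2024-01-01T00:00:00"},
        ],
    }

Locally, the app starts with `streamlit run app.py` once the keys referenced in f1_ai.py and llm_manager.py are present in `.streamlit/secrets.toml`.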
f1_ai.py
ADDED
@@ -0,0 +1,285 @@
+import os
+import argparse
+import logging
+from datetime import datetime
+from dotenv import load_dotenv
+from typing import List, Dict, Any, Optional, Tuple
+from rich.console import Console
+from rich.markdown import Markdown
+from pinecone import Pinecone
+from langchain_pinecone import Pinecone as LangchainPinecone
+
+# Import our custom LLM Manager
+from llm_manager import LLMManager
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+console = Console()
+
+# Load environment variables
+load_dotenv()
+
+class F1AI:
+    def __init__(self, index_name: str = "f12", llm_provider: str = "huggingface"):
+        """
+        Initialize the F1-AI RAG application.
+
+        Args:
+            index_name (str): Name of the Pinecone index to use
+            llm_provider (str): Provider for LLM and embeddings.
+                Options: "ollama", "huggingface", "huggingface-openai"
+        """
+        self.index_name = index_name
+
+        # Initialize LLM and embeddings via manager
+        self.llm_manager = LLMManager(provider=llm_provider)
+        self.llm = self.llm_manager.get_llm()
+        self.embeddings = self.llm_manager.get_embeddings()
+
+        # Load Pinecone API Key
+        pinecone_api_key = os.getenv("PINECONE_API_KEY")
+        if not pinecone_api_key:
+            raise ValueError("❌ Pinecone API key missing! Set PINECONE_API_KEY in environment variables.")
+
+        # Initialize the Pinecone client (pinecone-client v3 API)
+
+        # Create the index if it does not already exist, then connect to it
+        try:
+            self.pc = Pinecone(api_key=pinecone_api_key)
+
+            # Check existing indexes
+            existing_indexes = [idx['name'] for idx in self.pc.list_indexes()]
+
+            if index_name not in existing_indexes:
+                console.log(f"🚀 Creating Pinecone index: {index_name}")
+                # The dimension must match the embedding model
+                self.pc.create_index(
+                    name=index_name,
+                    dimension=384,  # Match embedding dimensions of the model
+                    metric="cosine"
+                )
+
+            # Connect to Pinecone index
+            index = self.pc.Index(index_name)
+            self.vectordb = LangchainPinecone.from_existing_index(
+                index_name=index_name,
+                text_key="text",
+                embedding=self.embeddings
+            )
+
+            print(f"✅ Successfully connected to Pinecone index: {index_name}")
+        except Exception as e:
+            import traceback
+            print(f"⚠️ Error connecting to Pinecone: {str(e)}")
+            print(traceback.format_exc())
+            # Set vectordb to None; the application will handle this gracefully
+            self.vectordb = None
+
+    async def scrape(self, url: str, max_chunks: int = 100) -> List[Dict[str, Any]]:
+        """Scrape content from a URL and split into chunks with improved error handling."""
+        from playwright.async_api import async_playwright, TimeoutError
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+        from bs4 import BeautifulSoup
+
+        try:
+            async with async_playwright() as p:
+                browser = await p.chromium.launch()
+                page = await browser.new_page()
+                console.log(f"[blue]Loading {url}...[/blue]")
+
+                try:
+                    await page.goto(url, timeout=30000)
+                    # Get HTML content
+                    html_content = await page.content()
+                    soup = BeautifulSoup(html_content, 'html.parser')
+
+                    # Remove unwanted elements
+                    for element in soup.find_all(['script', 'style', 'nav', 'footer']):
+                        element.decompose()
+
+                    text = soup.get_text(separator=' ', strip=True)
+                except TimeoutError:
+                    logger.error(f"Timeout while loading {url}")
+                    return []
+                finally:
+                    await browser.close()
+
+            console.log(f"[green]Processing text ({len(text)} characters)...[/green]")
+
+            # Enhanced text cleaning
+            text = ' '.join(text.split())  # Normalize whitespace
+
+            # Improved text splitting with semantic boundaries
+            splitter = RecursiveCharacterTextSplitter(
+                chunk_size=512,
+                chunk_overlap=50,
+                separators=["\n\n", "\n", ".", "!", "?", ",", " "],
+                length_function=len
+            )
+
+            docs = splitter.create_documents([text])
+
+            # Limit the number of chunks
+            limited_docs = docs[:max_chunks]
+            console.log(f"[yellow]Using {len(limited_docs)} chunks out of {len(docs)} total chunks[/yellow]")
+
+            # Enhanced metadata
+            timestamp = datetime.now().isoformat()
+            return [{
+                "page_content": doc.page_content,
+                "metadata": {
+                    "source": url,
+                    "chunk_index": i,
+                    "total_chunks": len(limited_docs),
+                    "timestamp": timestamp
+                }
+            } for i, doc in enumerate(limited_docs)]
+
+        except Exception as e:
+            logger.error(f"Error scraping {url}: {str(e)}")
+            return []
+
+    async def ingest(self, urls: List[str], max_chunks_per_url: int = 100) -> None:
+        """Ingest data from URLs into the vector database."""
+        from langchain_community.vectorstores import Pinecone as LangchainPinecone
+        from tqdm import tqdm
+
+        # Create empty list to store documents
+        all_docs = []
+
+        # Scrape and process each URL with progress bar
+        for url in tqdm(urls, desc="Scraping URLs"):
+            chunks = await self.scrape(url, max_chunks=max_chunks_per_url)
+            all_docs.extend(chunks)
+
+        # Create or update vector database
+        total_docs = len(all_docs)
+        print(f"\nCreating vector database with {total_docs} documents...")
+        texts = [doc["page_content"] for doc in all_docs]
+        metadatas = [doc["metadata"] for doc in all_docs]
+
+        print("Starting embedding generation and uploading to Pinecone (this might take several minutes)...")
+        self.vectordb = LangchainPinecone.from_texts(
+            texts=texts,
+            embedding=self.embeddings,
+            index_name=self.index_name,
+            metadatas=metadatas,
+            text_key="text"
+        )
+
+        print("✅ Documents successfully uploaded to Pinecone!")
+
+    async def ask_question(self, question: str) -> Dict[str, Any]:
+        """Ask a question and get a response using RAG."""
+        if not self.vectordb:
+            return {"answer": "Error: Vector database not initialized. Please ingest data first.", "sources": []}
+
+        try:
+            # Retrieve relevant documents with similarity search
+            retriever = self.vectordb.as_retriever(
+                search_type="similarity",
+                search_kwargs={"k": 5}
+            )
+
+            # Get relevant documents
+            docs = retriever.get_relevant_documents(question)
+
+            if not docs:
+                return {
+                    "answer": "I couldn't find any relevant information in my knowledge base. Please try a different question or ingest more relevant data.",
+                    "sources": []
+                }
+
+            # Format context from documents
+            context = "\n\n".join([f"Document {i+1}: {doc.page_content}" for i, doc in enumerate(docs)])
+
+            # Create prompt for the LLM
+            prompt = f"""
+            Answer the question based on the provided context. Include relevant citations using [1], [2], etc.
+            If you're unsure or if the context doesn't contain the information, acknowledge the uncertainty.
+
+            Context:
+            {context}
+
+            Question: {question}
+
+            Answer with citations:
+            """
+
+            # Get response from LLM
+            response_text = ""
+            if hasattr(self.llm, "__call__"):  # Direct inference client wrapped function
+                response_text = self.llm(prompt)
+                # Debug response
+                logger.info(f"Raw LLM response type: {type(response_text)}")
+                if not response_text or response_text.strip() == "":
+                    logger.error("Empty response from LLM")
+                    response_text = "I apologize, but I couldn't generate a response. This might be due to an issue with the language model."
+            else:  # LangChain LLM
+                response_text = self.llm.invoke(prompt)
+
+            # Format sources
+            sources = [{
+                "url": doc.metadata["source"],
+                "chunk_index": doc.metadata.get("chunk_index", 0),
+                "timestamp": doc.metadata.get("timestamp", "")
+            } for doc in docs]
+
+            # Format response
+            formatted_response = {
+                "answer": response_text,
+                "sources": sources
+            }
+
+            return formatted_response
+
+        except Exception as e:
+            logger.error(f"Error processing question: {str(e)}")
+            return {
+                "answer": f"I apologize, but I encountered an error while processing your question: {str(e)}",
+                "sources": []
+            }
+
+async def main():
+    """Main function to run the application."""
+    import asyncio
+
+    parser = argparse.ArgumentParser(description="F1-AI: RAG Application for Formula 1 information")
+    subparsers = parser.add_subparsers(dest="command", help="Command to run")
+
+    # Ingest command
+    ingest_parser = subparsers.add_parser("ingest", help="Ingest data from URLs")
+    ingest_parser.add_argument("--urls", nargs="+", required=True, help="URLs to scrape")
+    ingest_parser.add_argument("--max-chunks", type=int, default=100, help="Maximum chunks per URL")
+
+    # Ask command
+    ask_parser = subparsers.add_parser("ask", help="Ask a question")
+    ask_parser.add_argument("question", help="Question to ask")
+
+    # Provider for LLM and embeddings (applies to both commands)
+    parser.add_argument("--provider", choices=["ollama", "huggingface", "huggingface-openai"], default="huggingface",
+                        help="Provider for LLM and embeddings (default: huggingface)")
+
+    args = parser.parse_args()
+
+    f1_ai = F1AI(llm_provider=args.provider)
+
+    if args.command == "ingest":
+        await f1_ai.ingest(args.urls, max_chunks_per_url=args.max_chunks)
+    elif args.command == "ask":
+        response = await f1_ai.ask_question(args.question)
+        console.print("\n[bold green]Answer:[/bold green]")
+        # Format as markdown to make it prettier
+        console.print(Markdown(response['answer']))
+
+        console.print("\n[bold yellow]Sources:[/bold yellow]")
+        for i, source in enumerate(response['sources']):
+            console.print(f"[{i+1}] {source['url']}")
+    else:
+        parser.print_help()
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(main())
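Note: besides the CLI entry point above, the F1AI class can be driven programmatically. A minimal sketch, assuming PINECONE_API_KEY and HUGGINGFACE_API_KEY are already set in the environment; the URL and question are illustrative:

    # Programmatic use of F1AI (sketch; not part of the commit).
    import asyncio
    from f1_ai import F1AI

    async def demo():
        f1 = F1AI(llm_provider="huggingface")
        # Ingest one page, capped at 50 chunks, then query the index.
        await f1.ingest(["https://en.wikipedia.org/wiki/Formula_One"], max_chunks_per_url=50)
        result = await f1.ask_question("What is DRS and when is it allowed?")
        print(result["answer"])
        for source in result["sources"]:
            print("-", source["url"])

    if __name__ == "__main__":
        asyncio.run(demo())

The equivalent CLI calls are `python f1_ai.py ingest --urls <url>` and `python f1_ai.py ask "<question>"`, with `--provider` selecting the backend.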
llm_manager.py
ADDED
@@ -0,0 +1,195 @@
+import os
+from typing import List, Dict, Any
+from huggingface_hub import InferenceClient
+from langchain_ollama import OllamaEmbeddings, OllamaLLM
+from dotenv import load_dotenv
+import numpy as np
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+load_dotenv()
+
+class LLMManager:
+    """
+    Manager class for handling different LLM and embedding models.
+    Uses HuggingFace's InferenceClient directly for HuggingFace models.
+    """
+    def __init__(self, provider: str = "huggingface"):
+        """
+        Initialize the LLM Manager.
+
+        Args:
+            provider (str): The provider for LLM and embeddings.
+                Options: "ollama", "huggingface", "huggingface-openai"
+        """
+        self.provider = provider
+        self.llm_client = None
+        self.embedding_client = None
+
+        # Initialize models based on the provider
+        if provider == "ollama":
+            self._init_ollama()
+        elif provider == "huggingface" or provider == "huggingface-openai":
+            self._init_huggingface()
+        else:
+            raise ValueError(f"Unsupported provider: {provider}. Choose 'ollama', 'huggingface', or 'huggingface-openai'")
+
+    def _init_ollama(self):
+        """Initialize Ollama models."""
+        self.llm = OllamaLLM(model="phi4-mini:3.8b")
+        self.embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")
+
+    def _init_huggingface(self):
+        """Initialize HuggingFace models using InferenceClient directly."""
+        # Get API key from environment
+        api_key = os.getenv("HUGGINGFACE_API_KEY")
+        if not api_key:
+            raise ValueError("HuggingFace API key not found. Set HUGGINGFACE_API_KEY in environment variables.")
+
+        llm_endpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+        embedding_endpoint = "sentence-transformers/all-MiniLM-L6-v2"
+
+        # Initialize InferenceClient for LLM
+        self.llm_client = InferenceClient(
+            model=llm_endpoint,
+            token=api_key
+        )
+
+        # Initialize InferenceClient for embeddings
+        self.embedding_client = InferenceClient(
+            model=embedding_endpoint,
+            token=api_key
+        )
+
+        # Store generation parameters
+        self.generation_kwargs = {
+            "temperature": 0.7,
+            "max_new_tokens": 512,  # Reduced to avoid potential token limit issues
+            "repetition_penalty": 1.1,
+            "do_sample": True,
+            "top_k": 50,
+            "top_p": 0.9,
+            "return_full_text": False  # Only return the generated text, not the prompt
+        }
+
+    # LLM methods for compatibility with LangChain
+    def get_llm(self):
+        """
+        Return a callable object that mimics the LangChain LLM interface.
+        For huggingface providers, this returns a function that calls the InferenceClient.
+        """
+        if self.provider == "ollama":
+            return self.llm
+        else:
+            # Return a function that wraps the InferenceClient for LLM
+            def llm_function(prompt, **kwargs):
+                params = {**self.generation_kwargs, **kwargs}
+                try:
+                    logger.info(f"Sending prompt to HuggingFace (length: {len(prompt)})")
+                    response = self.llm_client.text_generation(
+                        prompt,
+                        details=True,  # Get detailed response
+                        **params
+                    )
+                    # Extract generated text from the detailed response
+                    if isinstance(response, dict) and 'generated_text' in response:
+                        response = response['generated_text']
+                    elif hasattr(response, 'generated_text'):
+                        # With details=True the client returns a response object rather than a plain string
+                        response = response.generated_text
+                    logger.info(f"Received response from HuggingFace (length: {len(response) if response else 0})")
+
+                    # Ensure we get a valid string response
+                    if not response or not isinstance(response, str) or response.strip() == "":
+                        logger.warning("Empty or invalid response from HuggingFace, using fallback")
+                        return "I couldn't generate a proper response based on the available information."
+
+                    return response
+                except Exception as e:
+                    logger.error(f"Error during LLM inference: {str(e)}")
+                    return f"Error generating response: {str(e)}"
+
+            # Add async capability
+            async def allm_function(prompt, **kwargs):
+                params = {**self.generation_kwargs, **kwargs}
+                try:
+                    # InferenceClient.text_generation is synchronous, so it is not awaited here;
+                    # a truly non-blocking call would require AsyncInferenceClient.
+                    response = self.llm_client.text_generation(
+                        prompt,
+                        **params,
+                        stream=False
+                    )
+
+                    # Ensure we get a valid string response
+                    if not response or not isinstance(response, str) or response.strip() == "":
+                        logger.warning("Empty or invalid response from HuggingFace async, using fallback")
+                        return "I couldn't generate a proper response based on the available information."
+
+                    return response
+                except Exception as e:
+                    logger.error(f"Error during async LLM inference: {str(e)}")
+                    return f"Error generating response: {str(e)}"
+
+            llm_function.ainvoke = allm_function
+            return llm_function
+
+    # Embeddings methods for compatibility with LangChain
+    def get_embeddings(self):
+        """
+        Return a callable object that mimics the LangChain Embeddings interface.
+        For huggingface providers, this returns an object with embed_documents and embed_query methods.
+        """
+        if self.provider == "ollama":
+            return self.embeddings
+        else:
+            # Create a wrapper object that has the expected methods
+            class EmbeddingsWrapper:
+                def __init__(self, client):
+                    self.client = client
+
+                def embed_documents(self, texts: List[str]) -> List[List[float]]:
+                    """Embed multiple documents."""
+                    embeddings = []
+                    # Process in batches to avoid overwhelming the API
+                    batch_size = 8
+
+                    for i in range(0, len(texts), batch_size):
+                        batch = texts[i:i+batch_size]
+                        try:
+                            batch_embeddings = self.client.feature_extraction(batch)
+                            # Convert to standard Python list format
+                            batch_results = [list(map(float, embedding)) for embedding in batch_embeddings]
+                            embeddings.extend(batch_results)
+                        except Exception as e:
+                            logger.error(f"Error embedding batch {i}: {str(e)}")
+                            # Return zero vectors as fallback
+                            for _ in range(len(batch)):
+                                embeddings.append([0.0] * 384)  # Use correct dimension
+
+                    return embeddings
+
+                def embed_query(self, text: str) -> List[float]:
+                    """Embed a single query."""
+                    try:
+                        embedding = self.client.feature_extraction(text)
+                        if isinstance(embedding, list) and len(embedding) > 0:
+                            # If it returns a batch (list of embeddings) for a single input
+                            return list(map(float, embedding[0]))
+                        # If it returns a single embedding
+                        return list(map(float, embedding))
+                    except Exception as e:
+                        logger.error(f"Error embedding query: {str(e)}")
+                        # Return zero vector as fallback
+                        return [0.0] * 384  # Use correct dimension
+
+                # Make the class callable to fix the TypeError
+                def __call__(self, texts):
+                    """Make the object callable for compatibility with LangChain."""
+                    if isinstance(texts, str):
+                        return self.embed_query(texts)
+                    elif isinstance(texts, list):
+                        return self.embed_documents(texts)
+                    else:
+                        raise ValueError(f"Unsupported input type: {type(texts)}")
+
+            return EmbeddingsWrapper(self.embedding_client)
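Note: LLMManager can also be exercised on its own, which is useful for checking the HuggingFace credentials before wiring it into the RAG pipeline. A minimal sketch, assuming HUGGINGFACE_API_KEY is set; the prompt and query text are illustrative:

    # Standalone check of the LLM and embedding wrappers (sketch; not part of the commit).
    from llm_manager import LLMManager

    manager = LLMManager(provider="huggingface")
    llm = manager.get_llm()                # callable: llm(prompt) -> str
    embeddings = manager.get_embeddings()  # exposes embed_query / embed_documents

    vector = embeddings.embed_query("Formula 1 pit stop strategy")
    print(len(vector))  # 384 for sentence-transformers/all-MiniLM-L6-v2
    print(llm("In one sentence, what is a Formula 1 sprint race?"))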
packages.txt
ADDED
@@ -0,0 +1,3 @@
+chromium
+wget
+ca-certificates
requirements.txt
ADDED
@@ -0,0 +1,15 @@
+langchain==0.1.0
+langchain-community==0.0.10
+langchain-core==0.1.0
+langchain-pinecone==0.0.1
+langchain-ollama==0.0.1
+pinecone-client==3.0.1
+huggingface-hub==0.20.1
+streamlit==1.29.0
+playwright==1.40.0
+beautifulsoup4==4.12.2
+tqdm==4.66.1
+python-dotenv==1.0.0
+typing-extensions==4.8.0
+rich==13.7.0
+# asyncio is not pinned here because it ships with the Python standard library
setup.sh
ADDED
@@ -0,0 +1,4 @@
+#!/bin/bash
+# Install playwright browsers
+pip install playwright
+playwright install chromium