patrickbdevaney committed
Commit c901ab2 · verified · 1 Parent(s): a866fda

Create app.py

Files changed (1): app.py +232 -0
app.py ADDED
@@ -0,0 +1,232 @@
+ import os
+ import uuid  # Unique document IDs for ChromaDB
+ import chromadb
+ import numpy as np
+ from dotenv import load_dotenv
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import torch
+ import uvicorn  # Serves the FastAPI app (see bottom of file)
+ from transformers import AutoTokenizer, AutoModel
+ from groq import Groq
+ import gradio as gr
+ import httpx  # Used to make async HTTP requests to the FastAPI endpoints
+
+ # Load environment variables
+ load_dotenv()
+
+ # List of API keys for Groq (unset variables are filtered out)
+ api_keys = [
+     key
+     for key in (
+         os.getenv("GROQ_API_KEY"),
+         os.getenv("GROQ_API_KEY_2"),
+         os.getenv("GROQ_API_KEY_3"),
+         os.getenv("GROQ_API_KEY_4"),
+     )
+     if key
+ ]
+
+ if not api_keys:
+     raise ValueError("At least one GROQ_API_KEY environment variable must be set.")
+
+ # FastAPI app
+ app = FastAPI()
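+
+ # Expected .env layout (variable names as read above; values are placeholders):
+ #   GROQ_API_KEY=<your-key>
+ #   GROQ_API_KEY_2=<optional-second-key>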
+
+ # Define Groq-based model with fallback across API keys
+ class GroqChatbot:
+     def __init__(self, api_keys):
+         self.api_keys = api_keys
+         self.current_key_index = 0
+         self.client = Groq(api_key=self.api_keys[self.current_key_index])
+
+     def switch_key(self):
+         """Switch to the next API key in the list."""
+         self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
+         self.client = Groq(api_key=self.api_keys[self.current_key_index])
+         print(f"Switched to API key index {self.current_key_index}")
+
+     def get_response(self, prompt):
+         """Get a response from the API, switching keys on failure."""
+         while True:
+             try:
+                 response = self.client.chat.completions.create(
+                     messages=[
+                         {"role": "system", "content": "You are a helpful AI assistant."},
+                         {"role": "user", "content": prompt}
+                     ],
+                     model="llama3-70b-8192",
+                 )
+                 return response.choices[0].message.content
+             except Exception as e:
+                 print(f"Error: {e}")
+                 self.switch_key()
+                 if self.current_key_index == 0:
+                     return "All API keys have been exhausted. Please try again later."
+
+     def text_to_embedding(self, text):
+         """Convert text to an embedding with a local Transformer model."""
+         try:
+             # Load the model and tokenizer once and cache them on the
+             # instance; reloading them on every request would be very slow.
+             if not getattr(self, "_embed_ready", False):
+                 self._tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
+                 self._model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B")
+
+                 # Move model to GPU if available
+                 self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                 self._model = self._model.to(self._device)
+                 self._model.eval()
+
+                 # Ensure tokenizer has a padding token
+                 if self._tokenizer.pad_token is None:
+                     self._tokenizer.pad_token = self._tokenizer.eos_token
+                 self._embed_ready = True
+
+             # Tokenize the text
+             encoded_input = self._tokenizer(
+                 text,
+                 padding=True,
+                 truncation=True,
+                 max_length=512,
+                 return_tensors='pt'
+             ).to(self._device)
+
+             # Generate embeddings
+             with torch.no_grad():
+                 model_output = self._model(**encoded_input)
+                 sentence_embeddings = model_output.last_hidden_state
+
+             # Mean pooling over non-padding tokens
+             attention_mask = encoded_input['attention_mask']
+             mask = attention_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
+             masked_embeddings = sentence_embeddings * mask
+             summed = torch.sum(masked_embeddings, dim=1)
+             summed_mask = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
+             mean_pooled = (summed / summed_mask).squeeze()
+
+             # Move to CPU and convert to numpy
+             embedding = mean_pooled.cpu().numpy()
+
+             # Normalize the embedding vector to unit length
+             embedding = embedding / np.linalg.norm(embedding)
+
+             print(f"Generated embedding for text: {text}")
+             return embedding
+         except Exception as e:
+             print(f"Error generating embedding: {e}")
+             return None
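+
+     # Shape note for the mean pooling above: last_hidden_state is
+     # (batch, seq_len, hidden); the mask zeroes padding positions, the
+     # masked sum is divided by the real token count, and squeeze() leaves
+     # a single (hidden,)-dimensional vector for a single input text.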
+
+ # LocalEmbeddingStore persists embeddings in ChromaDB
+ class LocalEmbeddingStore:
+     def __init__(self, storage_dir="./chromadb_storage"):
+         self.client = chromadb.PersistentClient(path=storage_dir)  # ChromaDB client with persistent storage
+         self.collection_name = "chatbot_docs"  # Collection for storing embeddings
+         self.collection = self.client.get_or_create_collection(name=self.collection_name)
+
+     def add_embedding(self, doc_id, embedding, metadata):
+         """Add a document and its embedding to ChromaDB."""
+         self.collection.add(
+             documents=[doc_id],  # The document ID doubles as the stored text
+             embeddings=[embedding.tolist()],  # Plain list for ChromaDB compatibility
+             metadatas=[metadata],  # Optional metadata
+             ids=[doc_id]  # Same ID as document ID
+         )
+         print(f"Added embedding for document ID: {doc_id}")
+
+     def search_embedding(self, query_embedding, num_results=3):
+         """Search for the most relevant documents by embedding similarity."""
+         results = self.collection.query(
+             query_embeddings=[query_embedding.tolist()],
+             n_results=num_results
+         )
+         print(f"Search results: {results}")
+         # Both are nested lists with one inner list per query
+         return results['documents'], results['distances']
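+
+ # Usage sketch for the store (hypothetical IDs and vectors):
+ #   store = LocalEmbeddingStore()
+ #   store.add_embedding("doc-1", some_embedding, {"source": "notes"})
+ #   docs, distances = store.search_embedding(query_embedding, num_results=1)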
+
+ # RAGSystem ties ChromaDB retrieval to the Groq chat model
+ class RAGSystem:
+     def __init__(self, groq_client, embedding_store):
+         self.groq_client = groq_client
+         self.embedding_store = embedding_store
+
+     def get_most_relevant_document(self, query_embedding):
+         """Retrieve the most relevant document by embedding distance."""
+         docs, distances = self.embedding_store.search_embedding(query_embedding)
+         # ChromaDB nests results one list per query: take the top hit of the
+         # first (only) query rather than the whole inner list.
+         if docs and docs[0]:
+             return docs[0][0], distances[0][0]
+         return None, None
+
+     def chat_with_rag(self, user_input):
+         """Handle the RAG process."""
+         query_embedding = self.groq_client.text_to_embedding(user_input)
+         if query_embedding is None or query_embedding.size == 0:
+             return "Failed to generate embeddings."
+
+         context_document_id, distance = self.get_most_relevant_document(query_embedding)
+         if not context_document_id:
+             return "No relevant documents found."
+
+         # Assuming metadata retrieval works
+         context_metadata = f"Metadata for {context_document_id}"  # Placeholder, implement as needed
+
+         prompt = f"""Context (retrieval distance {distance:.2f}):
+ {context_metadata}
+
+ User: {user_input}
+ AI:"""
+         return self.groq_client.get_response(prompt)
+
+ # Initialize components
+ embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
+ chatbot = GroqChatbot(api_keys=api_keys)
+ rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)
+
+ # Pydantic models for API request and response
+ class UserInput(BaseModel):
+     input_text: str
+
+ class ChatResponse(BaseModel):
+     response: str
186
+
187
+ @app.get("/")
188
+ async def read_root():
189
+ return {"message": "Welcome to the Groq and ChromaDB integration API!"}
190
+
191
+ @app.post("/chat", response_model=ChatResponse)
192
+ async def chat(user_input: UserInput):
193
+ """Handle chat interactions with Groq and ChromaDB."""
194
+ ai_response = rag_system.chat_with_rag(user_input.input_text)
195
+ return ChatResponse(response=ai_response)
196
+
197
+ @app.post("/embed", response_model=ChatResponse)
198
+ async def embed_text(user_input: UserInput):
199
+ """Handle text embedding."""
200
+ embedding = chatbot.text_to_embedding(user_input.input_text)
201
+ if embedding is not None:
202
+ return ChatResponse(response="Text embedded successfully.")
203
+ else:
204
+ raise HTTPException(status_code=400, detail="Embedding generation failed.")
+
+ @app.post("/add_document", response_model=ChatResponse)
+ async def add_document(user_input: UserInput):
+     """Add a document embedding to ChromaDB."""
+     embedding = chatbot.text_to_embedding(user_input.input_text)
+     if embedding is not None:
+         # Generate a unique ID; a fixed ID would collide on repeated adds,
+         # since ChromaDB rejects duplicate IDs.
+         doc_id = str(uuid.uuid4())
+         embedding_store.add_embedding(doc_id, embedding, metadata={"source": "user_input"})
+         return ChatResponse(response="Document added to the database.")
+     else:
+         raise HTTPException(status_code=400, detail="Embedding generation failed.")
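+
+ # Example request against the endpoint above (hypothetical input text):
+ #   curl -X POST http://127.0.0.1:7860/add_document \
+ #        -H "Content-Type: application/json" \
+ #        -d '{"input_text": "Groq serves llama3-70b-8192."}'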
+
+ # Gradio interface function that queries the FastAPI /chat endpoint
+ async def gradio_chatbot(input_text: str):
+     # No timeout: the first request loads the embedding model, which can
+     # take far longer than httpx's 5-second default.
+     async with httpx.AsyncClient(timeout=None) as client:
+         response = await client.post(
+             "http://127.0.0.1:7860/chat",  # FastAPI endpoint (served below)
+             json={"input_text": input_text}
+         )
+         response_data = response.json()
+         return response_data["response"]
+
+ # Gradio Interface
+ iface = gr.Interface(fn=gradio_chatbot, inputs="text", outputs="text")
+
+ # Mount the Gradio UI onto the FastAPI app so the API routes and the
+ # interface share one server; launching Gradio alone would leave the
+ # FastAPI endpoints unserved on port 7860.
+ app = gr.mount_gradio_app(app, iface, path="/ui")
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
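+
+ # With the mount above, the Gradio UI is reachable at /ui and the JSON API
+ # at /chat, /embed, and /add_document, all on port 7860.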