import os
import uuid
import chromadb
import numpy as np
import torch
import httpx  # Async HTTP client used by the Gradio UI to call the FastAPI endpoint
import uvicorn
import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from groq import Groq
# Load environment variables
load_dotenv()
# Gather every configured Groq API key, skipping any that are unset
api_keys = [
    key for key in (
        os.getenv("GROQ_API_KEY"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
        os.getenv("GROQ_API_KEY_4"),
    )
    if key
]
if not api_keys:
    raise ValueError("At least one GROQ_API_KEY environment variable must be set.")
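# Expected .env layout (sketch; values are placeholders, the extra keys are
# optional fallbacks used when the primary key is rate-limited or rejected):
#
#   GROQ_API_KEY=gsk_your_primary_key
#   GROQ_API_KEY_2=gsk_fallback_key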
# Key selection and rotation are handled by GroqChatbot below
# FastAPI app
app = FastAPI()
# Groq-backed chatbot with API-key fallback
class GroqChatbot:
def __init__(self, api_keys):
self.api_keys = api_keys
self.current_key_index = 0
self.client = Groq(api_key=self.api_keys[self.current_key_index])
def switch_key(self):
"""Switch to the next API key in the list."""
self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
self.client = Groq(api_key=self.api_keys[self.current_key_index])
print(f"Switched to API key index {self.current_key_index}")
    def get_response(self, prompt):
        """Get a response from the API, rotating to the next key on failure."""
        # Give each key exactly one attempt per call, so a bad starting
        # index can neither skip keys nor loop forever
        for _ in range(len(self.api_keys)):
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt}
                    ],
                    model="llama3-70b-8192",
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
                self.switch_key()
        return "All API keys have been exhausted. Please try again later."
    def text_to_embedding(self, text):
        """Convert text to an embedding using a cached model and tokenizer."""
        try:
            # Load the model and tokenizer on first use and cache them on the
            # instance; reloading them on every request would dominate latency
            if not hasattr(self, "_embed_model"):
                self._embed_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
                self._embed_model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B")
                # Move the model to GPU if available
                self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                self._embed_model = self._embed_model.to(self._device)
                self._embed_model.eval()
                # Ensure the tokenizer has a padding token
                if self._embed_tokenizer.pad_token is None:
                    self._embed_tokenizer.pad_token = self._embed_tokenizer.eos_token
            tokenizer, model, device = self._embed_tokenizer, self._embed_model, self._device
# Tokenize the text
encoded_input = tokenizer(
text,
padding=True,
truncation=True,
max_length=512,
return_tensors='pt'
).to(device)
# Generate embeddings
with torch.no_grad():
model_output = model(**encoded_input)
sentence_embeddings = model_output.last_hidden_state
# Mean pooling
attention_mask = encoded_input['attention_mask']
mask = attention_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
masked_embeddings = sentence_embeddings * mask
summed = torch.sum(masked_embeddings, dim=1)
summed_mask = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
mean_pooled = (summed / summed_mask).squeeze()
# Move to CPU and convert to numpy
embedding = mean_pooled.cpu().numpy()
# Normalize the embedding vector
embedding = embedding / np.linalg.norm(embedding)
print(f"Generated embedding for text: {text}")
return embedding
except Exception as e:
print(f"Error generating embedding: {e}")
return None
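# Shape sketch for the mean pooling above (illustrative sizes): one input of
# 7 tokens with hidden size 2048 gives last_hidden_state of shape (1, 7, 2048);
# the attention mask (1, 7) zeroes padding positions before the sum over the
# token axis, and dividing by the real token count yields one (2048,) vector,
# which is then L2-normalized so nearest-neighbor distances compare fairly.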
# Embedding store backed by ChromaDB's persistent client
class LocalEmbeddingStore:
def __init__(self, storage_dir="./chromadb_storage"):
self.client = chromadb.PersistentClient(path=storage_dir) # Use ChromaDB client with persistent storage
self.collection_name = "chatbot_docs" # Collection for storing embeddings
self.collection = self.client.get_or_create_collection(name=self.collection_name)
    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB."""
        self.collection.add(
            documents=[doc_id],                          # Stored text (the ID doubles as the document here)
            embeddings=[np.asarray(embedding).tolist()], # Plain lists keep ChromaDB happy across versions
            metadatas=[metadata],                        # Optional metadata
            ids=[doc_id]                                 # Unique ID for upserts and lookups
        )
        print(f"Added embedding for document ID: {doc_id}")
    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents by embedding similarity."""
        results = self.collection.query(
            query_embeddings=[np.asarray(query_embedding).tolist()],
            n_results=num_results
        )
        print(f"Search results: {results}")
        # Unwrap the first (and only) query's results; see the shape note below
        return results['documents'][0], results['distances'][0]
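# Result shape note (assuming a single query embedding): collection.query
# returns {'ids': [[...]], 'documents': [[...]], 'distances': [[...]], ...},
# one inner list per query, ordered from closest to farthest match; the [0]
# above unwraps our only query so callers receive plain flat lists.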
# RAG pipeline: retrieve the closest document from ChromaDB, then prompt Groq with it
class RAGSystem:
def __init__(self, groq_client, embedding_store):
self.groq_client = groq_client
self.embedding_store = embedding_store
    def get_most_relevant_document(self, query_embedding):
        """Retrieve the most relevant document and its distance score."""
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        if docs:
            return docs[0], distances[0]  # Closest match and its distance
        return None, None
    def chat_with_rag(self, user_input):
        """Handle the full RAG round trip: embed, retrieve, prompt."""
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."
        context_document_id, distance = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."
        # Placeholder metadata; swap in a real lookup once documents carry content
        context_metadata = f"Metadata for {context_document_id}"
        prompt = f"""Context (retrieval distance {distance:.2f}):
{context_metadata}
User: {user_input}
AI:"""
        return self.groq_client.get_response(prompt)
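# Prompt sketch (illustrative values): for a stored document "doc-123" retrieved
# at distance 0.42, the model receives roughly:
#
#   Context (retrieval distance 0.42):
#   Metadata for doc-123
#   User: What does Groq serve?
#   AI: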
# Initialize components
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
chatbot = GroqChatbot(api_keys=api_keys)
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)
# Pydantic models for API request and response
class UserInput(BaseModel):
input_text: str
class ChatResponse(BaseModel):
response: str
@app.get("/")
async def read_root():
return {"message": "Welcome to the Groq and ChromaDB integration API!"}
@app.post("/chat", response_model=ChatResponse)
async def chat(user_input: UserInput):
"""Handle chat interactions with Groq and ChromaDB."""
ai_response = rag_system.chat_with_rag(user_input.input_text)
return ChatResponse(response=ai_response)
@app.post("/embed", response_model=ChatResponse)
async def embed_text(user_input: UserInput):
"""Handle text embedding."""
embedding = chatbot.text_to_embedding(user_input.input_text)
if embedding is not None:
return ChatResponse(response="Text embedded successfully.")
else:
raise HTTPException(status_code=400, detail="Embedding generation failed.")
@app.post("/add_document", response_model=ChatResponse)
async def add_document(user_input: UserInput):
"""Add a document embedding to ChromaDB."""
embedding = chatbot.text_to_embedding(user_input.input_text)
if embedding is not None:
doc_id = "sample_document" # You can generate or pass a doc ID
embedding_store.add_embedding(doc_id, embedding, metadata={"source": "user_input"})
return ChatResponse(response="Document added to the database.")
else:
raise HTTPException(status_code=400, detail="Embedding generation failed.")
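# Smoke test once the server is up (hypothetical inputs; the port matches the
# uvicorn call at the bottom of this file):
#
#   curl -X POST http://127.0.0.1:7860/add_document \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Groq serves Llama 3 models with low latency."}'
#
#   curl -X POST http://127.0.0.1:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "What does Groq serve?"}'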
# Gradio Interface for querying the FastAPI /chat endpoint
async def gradio_chatbot(input_text: str):
async with httpx.AsyncClient() as client:
response = await client.post(
"http://127.0.0.1:7860/chat", # FastAPI endpoint
json={"input_text": input_text}
)
response_data = response.json()
return response_data["response"]
# Build the Gradio interface and mount it onto the FastAPI app so a single
# server exposes both the API routes and the UI (served under /ui); launching
# Gradio alone would leave /chat unreachable on port 7860
iface = gr.Interface(fn=gradio_chatbot, inputs="text", outputs="text")
app = gr.mount_gradio_app(app, iface, path="/ui")
if __name__ == "__main__":
    # Serve on the port the Gradio client above targets for /chat
    uvicorn.run(app, host="0.0.0.0", port=7860)