import os
import uuid

import chromadb
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModel
from groq import Groq
import gradio as gr
import httpx  # Used to make async HTTP requests to FastAPI

# Load environment variables
load_dotenv()

# List of API keys for Groq (unset variables are filtered out so we never
# hand a None key to the Groq client)
api_keys = [
    key
    for key in (
        os.getenv("GROQ_API_KEY"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
        os.getenv("GROQ_API_KEY_4"),
    )
    if key
]

if not api_keys:
    raise ValueError("At least one GROQ_API_KEY environment variable must be set.")

# FastAPI app
app = FastAPI()


# Groq-based chatbot with fallback across multiple API keys
class GroqChatbot:
    def __init__(self, api_keys):
        self.api_keys = api_keys
        self.current_key_index = 0
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        self.tokenizer = None  # Loaded lazily on the first embedding request
        self.model = None

    def switch_key(self):
        """Switch to the next API key in the list."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        print(f"Switched to API key index {self.current_key_index}")

    def get_response(self, prompt):
        """Get a response from the API, switching keys on failure."""
        # Try each key at most once per call so we can't loop forever.
        for _ in range(len(self.api_keys)):
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt},
                    ],
                    model="llama3-70b-8192",
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
                self.switch_key()
        return "All API keys have been exhausted. Please try again later."

    def text_to_embedding(self, text):
        """Convert text to an embedding via mean pooling over the model's last hidden state."""
        try:
            # Load the model and tokenizer once and cache them on the instance,
            # instead of reloading them on every call
            if self.model is None:
                self.tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
                self.model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B")

                # Move model to GPU if available
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                self.model = self.model.to(device)
                self.model.eval()

                # Ensure tokenizer has a padding token
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

            device = next(self.model.parameters()).device

            # Tokenize the text
            encoded_input = self.tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ).to(device)

            # Generate embeddings
            with torch.no_grad():
                model_output = self.model(**encoded_input)
                sentence_embeddings = model_output.last_hidden_state

            # Mean pooling over non-padding tokens only
            attention_mask = encoded_input['attention_mask']
            mask = attention_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
            masked_embeddings = sentence_embeddings * mask
            summed = torch.sum(masked_embeddings, dim=1)
            summed_mask = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
            mean_pooled = (summed / summed_mask).squeeze()

            # Move to CPU and convert to numpy
            embedding = mean_pooled.cpu().numpy()

            # Normalize the embedding vector
            embedding = embedding / np.linalg.norm(embedding)

            print(f"Generated embedding for text: {text}")
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return None
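
# A minimal usage sketch for GroqChatbot (illustrative only; assumes the
# GROQ_API_KEY variables above are set and the model weights can be fetched
# from the Hugging Face Hub):
#
#   bot = GroqChatbot(api_keys=api_keys)
#   print(bot.get_response("Say hello in one sentence."))
#   vec = bot.text_to_embedding("Say hello in one sentence.")
#   # vec is a 1-D, L2-normalized numpy array (hidden-size dimensional)
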
# Local embedding store backed by ChromaDB
class LocalEmbeddingStore:
    def __init__(self, storage_dir="./chromadb_storage"):
        self.client = chromadb.PersistentClient(path=storage_dir)  # ChromaDB client with persistent storage
        self.collection_name = "chatbot_docs"  # Collection for storing embeddings
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB."""
        self.collection.add(
            documents=[doc_id],  # The document ID doubles as the stored text here
            embeddings=[np.asarray(embedding).tolist()],  # Embedding as a plain list of floats
            metadatas=[metadata],  # Optional metadata
            ids=[doc_id]  # Same ID as document ID
        )
        print(f"Added embedding for document ID: {doc_id}")

    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents based on embedding similarity."""
        results = self.collection.query(
            query_embeddings=[np.asarray(query_embedding).tolist()],
            n_results=num_results
        )
        print(f"Search results: {results}")
        return results['documents'], results['distances']  # Both are nested lists, one inner list per query


# RAG system integrating ChromaDB search with Groq completions
class RAGSystem:
    def __init__(self, groq_client, embedding_store):
        self.groq_client = groq_client
        self.embedding_store = embedding_store

    def get_most_relevant_document(self, query_embedding):
        """Retrieve the most relevant document based on embedding similarity."""
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        if docs and docs[0]:
            # Results are nested per query, so index twice for the top hit
            return docs[0][0], distances[0][0]
        return None, None

    def chat_with_rag(self, user_input):
        """Handle the RAG process: embed, retrieve, then generate."""
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."

        context_document_id, similarity_score = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."

        # Assuming metadata retrieval works
        context_metadata = f"Metadata for {context_document_id}"  # Placeholder, implement as needed

        prompt = f"""Context (similarity score {similarity_score:.2f}): {context_metadata}
User: {user_input}
AI:"""
        return self.groq_client.get_response(prompt)


# Initialize components
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
chatbot = GroqChatbot(api_keys=api_keys)
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)


# Pydantic models for API request and response
class UserInput(BaseModel):
    input_text: str


class ChatResponse(BaseModel):
    response: str


@app.get("/")
async def read_root():
    return {"message": "Welcome to the Groq and ChromaDB integration API!"}


@app.post("/chat", response_model=ChatResponse)
async def chat(user_input: UserInput):
    """Handle chat interactions with Groq and ChromaDB."""
    ai_response = rag_system.chat_with_rag(user_input.input_text)
    return ChatResponse(response=ai_response)


@app.post("/embed", response_model=ChatResponse)
async def embed_text(user_input: UserInput):
    """Handle text embedding."""
    embedding = chatbot.text_to_embedding(user_input.input_text)
    if embedding is not None:
        return ChatResponse(response="Text embedded successfully.")
    raise HTTPException(status_code=400, detail="Embedding generation failed.")


@app.post("/add_document", response_model=ChatResponse)
async def add_document(user_input: UserInput):
    """Add a document embedding to ChromaDB."""
    embedding = chatbot.text_to_embedding(user_input.input_text)
    if embedding is not None:
        doc_id = str(uuid.uuid4())  # Unique ID per document so repeated adds don't collide
        embedding_store.add_embedding(doc_id, embedding, metadata={"source": "user_input"})
        return ChatResponse(response="Document added to the database.")
    raise HTTPException(status_code=400, detail="Embedding generation failed.")
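
# Illustrative request flow once the script is running (the __main__ block
# below serves the API on http://127.0.0.1:8000; paths and payloads follow
# the endpoint definitions above):
#
#   curl -X POST http://127.0.0.1:8000/add_document \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "ChromaDB persists embeddings under ./chromadb_storage."}'
#
#   curl -X POST http://127.0.0.1:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Where are embeddings persisted?"}'
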
# Gradio interface for querying the FastAPI /chat endpoint
async def gradio_chatbot(input_text: str):
    # No timeout: the first call may need to download and load the embedding model
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.post(
            "http://127.0.0.1:8000/chat",  # FastAPI endpoint served by uvicorn in __main__
            json={"input_text": input_text}
        )
        response_data = response.json()
        return response_data["response"]


# Gradio Interface
iface = gr.Interface(fn=gradio_chatbot, inputs="text", outputs="text")

if __name__ == "__main__":
    import threading

    import uvicorn

    # Serve the FastAPI app on port 8000 in a background thread so that the
    # API and the Gradio UI (default port 7860) run side by side. This assumes
    # a recent uvicorn, which skips signal handling off the main thread.
    threading.Thread(
        target=lambda: uvicorn.run(app, host="127.0.0.1", port=8000),
        daemon=True,
    ).start()

    # Launch the Gradio interface
    iface.launch()
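
# Alternative single-process deployment (a sketch, not required): Gradio can
# be mounted onto the FastAPI app and both served by one uvicorn process via
# the public gr.mount_gradio_app API:
#
#   app = gr.mount_gradio_app(app, iface, path="/ui")
#   # run with: uvicorn main:app --port 8000   (the module name is an assumption)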