File size: 9,038 Bytes
c901ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import os
import json
import chromadb
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModel
from groq import Groq
import gradio as gr
import httpx  # Used to make async HTTP requests to FastAPI

# Load environment variables (including any defined in a local .env file)
load_dotenv()

# List of API keys for Groq.
# Unset or empty variables are dropped here so that every entry in `api_keys`
# is a usable key. Previously a missing GROQ_API_KEY left `None` at index 0,
# which was then handed straight to `Groq(api_key=...)` below and to any code
# that rotates through this list.
api_keys = [
    key
    for key in (
        os.getenv("GROQ_API_KEY"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
        os.getenv("GROQ_API_KEY_4"),
    )
    if key
]

if not api_keys:
    raise ValueError("At least one GROQ_API_KEY environment variable must be set.")

# Initialize Groq client with the first available API key
current_key_index = 0
client = Groq(api_key=api_keys[current_key_index])

# FastAPI app
app = FastAPI()

# Define Groq-based model with fallback
class GroqChatbot:
    """Groq chat client with API-key rotation, plus a local text embedder.

    Chat completions go through the Groq API; when a request fails the
    instance rotates to the next configured key. Embeddings are computed
    locally with a Hugging Face model that is loaded once, on first use,
    and then reused for every later call.
    """

    CHAT_MODEL_NAME = "llama3-70b-8192"
    EMBEDDING_MODEL_NAME = "NousResearch/Llama-3.2-1B"

    # Populated lazily by _load_embedding_model(); class-level defaults keep
    # the attributes defined even before the first embedding request.
    _tokenizer = None
    _model = None
    _device = None

    def __init__(self, api_keys):
        """Store the usable API keys and create a client for the first one.

        Args:
            api_keys: Iterable of Groq API keys. Entries that are ``None`` or
                empty (e.g. unset environment variables) are ignored.

        Raises:
            ValueError: If no usable key remains after filtering.
        """
        self.api_keys = [key for key in api_keys if key]
        if not self.api_keys:
            raise ValueError("At least one GROQ_API_KEY environment variable must be set.")
        self.current_key_index = 0
        self.client = self._create_client(self.api_keys[self.current_key_index])

    def _create_client(self, api_key):
        """Build the Groq client for a single API key."""
        return Groq(api_key=api_key)

    def switch_key(self):
        """Switch to the next API key in the list."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.client = self._create_client(self.api_keys[self.current_key_index])
        print(f"Switched to API key index {self.current_key_index}")

    def get_response(self, prompt):
        """Get a response from the API, switching keys on failure.

        Every configured key is tried at most once per call, starting from
        the key that is currently active. (Stopping as soon as the index
        wrapped back to 0 skipped the earlier keys whenever a call started
        from a later index.)

        Args:
            prompt: User message sent along with a fixed system prompt.

        Returns:
            The assistant's reply text, or a fixed "exhausted" message when
            the request failed with every key.
        """
        for _ in range(len(self.api_keys)):
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt},
                    ],
                    model=self.CHAT_MODEL_NAME,
                )
                return response.choices[0].message.content
            except Exception as e:
                # Any failure (quota, auth, network, malformed reply) moves
                # on to the next key rather than aborting the request.
                print(f"Error: {e}")
                self.switch_key()
        return "All API keys have been exhausted. Please try again later."

    def _load_embedding_model(self):
        """Load the tokenizer/model on first use and return the cached trio."""
        if self._model is None:
            tokenizer = AutoTokenizer.from_pretrained(self.EMBEDDING_MODEL_NAME)
            # Ensure tokenizer has a padding token (needed for padding=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Move model to GPU if available
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = AutoModel.from_pretrained(self.EMBEDDING_MODEL_NAME).to(device)
            model.eval()

            # Assign the model last so a failed load never leaves a
            # half-initialised cache behind.
            self._tokenizer = tokenizer
            self._device = device
            self._model = model
        return self._tokenizer, self._model, self._device

    def text_to_embedding(self, text):
        """Convert text to a unit-length embedding using mean pooling.

        Args:
            text: Input text; truncated to 512 tokens by the tokenizer.

        Returns:
            A numpy array holding the L2-normalised, mean-pooled last hidden
            state, or ``None`` if anything goes wrong (the error is printed).
        """
        try:
            tokenizer, model, device = self._load_embedding_model()

            # Tokenize the text
            encoded_input = tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ).to(device)

            # Generate embeddings
            with torch.no_grad():
                token_embeddings = model(**encoded_input).last_hidden_state

                # Mean pooling over real (non-padding) tokens only
                mask = (
                    encoded_input['attention_mask']
                    .unsqueeze(-1)
                    .expand(token_embeddings.size())
                    .float()
                )
                summed = torch.sum(token_embeddings * mask, dim=1)
                token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
                mean_pooled = (summed / token_counts).squeeze()

            # Move to CPU and convert to numpy
            embedding = mean_pooled.cpu().numpy()

            # Normalize along the feature axis; the floor avoids a division
            # by zero (and the resulting NaNs) for an all-zero vector.
            norm = np.linalg.norm(embedding, axis=-1, keepdims=True)
            embedding = embedding / np.maximum(norm, 1e-12)

            print(f"Generated embedding for text: {text}")
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return None

# Modify LocalEmbeddingStore to use ChromaDB
class LocalEmbeddingStore:
    """Thin wrapper around a persistent ChromaDB collection of embeddings."""

    def __init__(self, storage_dir="./chromadb_storage"):
        """Open (or create) the ``chatbot_docs`` collection under *storage_dir*."""
        self.client = chromadb.PersistentClient(path=storage_dir)  # Use ChromaDB client with persistent storage
        self.collection_name = "chatbot_docs"  # Collection for storing embeddings
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    @staticmethod
    def _as_plain_list(embedding):
        """Return *embedding* as a plain Python list.

        Array-like objects exposing ``tolist()`` (e.g. numpy arrays) are
        converted through it so the values become built-in numbers; any other
        iterable is copied into a list.
        """
        if hasattr(embedding, "tolist"):
            return embedding.tolist()
        return list(embedding)

    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB.

        Args:
            doc_id: Identifier of the record; it is also stored as the
                record's document text.
            embedding: Embedding vector (list or array-like).
            metadata: Metadata dict stored alongside the record.
        """
        self.collection.add(
            documents=[doc_id],  # Document ID doubles as the stored document text
            embeddings=[self._as_plain_list(embedding)],  # Embedding for the document
            metadatas=[metadata],  # Optional metadata
            ids=[doc_id]  # Same ID as document ID
        )
        print(f"Added embedding for document ID: {doc_id}")

    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents based on embedding similarity.

        Args:
            query_embedding: Query vector (list or array-like).
            num_results: Maximum number of hits requested from the collection.

        Returns:
            A ``(documents, distances)`` tuple taken unchanged from the query
            result.
        """
        results = self.collection.query(
            query_embeddings=[self._as_plain_list(query_embedding)],
            n_results=num_results
        )
        print(f"Search results: {results}")
        return results['documents'], results['distances']  # Returning both document IDs and distances

# Modify RAGSystem to integrate ChromaDB search
class RAGSystem:
    """Retrieval-augmented chat: embed the query, retrieve context, ask the LLM."""

    def __init__(self, groq_client, embedding_store):
        """Keep references to the chat/embedding client and the vector store.

        Args:
            groq_client: Object providing ``text_to_embedding(text)`` and
                ``get_response(prompt)``.
            embedding_store: Object providing ``search_embedding(embedding)``
                that returns ``(documents, distances)``.
        """
        self.groq_client = groq_client
        self.embedding_store = embedding_store

    def get_most_relevant_document(self, query_embedding):
        """Retrieve the single best-matching document and its distance.

        The store returns one list of hits per query embedding, so the best
        match for our single query is the first entry of the first list.

        Returns:
            ``(document, distance)`` for the top hit, or ``(None, None)``
            when the search produced no hits (e.g. an empty collection).
        """
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        hits = docs[0] if docs else []
        if not hits:
            return None, None

        hit_distances = distances[0] if distances else []
        best_distance = hit_distances[0] if hit_distances else None
        return hits[0], best_distance

    def chat_with_rag(self, user_input):
        """Handle the RAG process.

        Args:
            user_input: The user's message.

        Returns:
            The model's reply, or a short explanatory message when the
            embedding step fails or no document matches.
        """
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."

        context_document_id, similarity_score = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."

        # Assuming metadata retrieval works
        context_metadata = f"Metadata for {context_document_id}"  # Placeholder, implement as needed

        # The score is optional; formatting None with ':.2f' would raise.
        score_text = "n/a" if similarity_score is None else f"{similarity_score:.2f}"

        prompt = (
            f"Context (similarity score {score_text}):\n"
            f"{context_metadata}\n"
            "\n"
            f"User: {user_input}\n"
            "AI:"
        )
        return self.groq_client.get_response(prompt)

# Initialize components
# These module-level singletons are created at import time and shared by all
# request handlers below.
# Vector store persisted under ./chromadb_storage.
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
# Chat + embedding client built from the configured Groq API keys.
chatbot = GroqChatbot(api_keys=api_keys)
# RAG pipeline wiring the chatbot and the vector store together.
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)

# Pydantic models for API request and response

# Request body accepted by the /chat, /embed and /add_document endpoints.
class UserInput(BaseModel):
    input_text: str  # text supplied by the caller

# Response body returned by the /chat, /embed and /add_document endpoints.
class ChatResponse(BaseModel):
    response: str  # reply or status message for the caller

@app.get("/")
async def read_root():
    """Return the static welcome payload for the API root."""
    welcome = {"message": "Welcome to the Groq and ChromaDB integration API!"}
    return welcome

@app.post("/chat", response_model=ChatResponse)
async def chat(user_input: UserInput):
    """Handle chat interactions with Groq and ChromaDB.

    The RAG pipeline is synchronous (local model inference plus blocking
    network calls), so it runs in the worker threadpool instead of directly
    on the event loop, which would stall every other request meanwhile.

    Args:
        user_input: Request body carrying the user's message.

    Returns:
        ChatResponse wrapping the generated reply.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    ai_response = await run_in_threadpool(
        rag_system.chat_with_rag, user_input.input_text
    )
    return ChatResponse(response=ai_response)

@app.post("/embed", response_model=ChatResponse)
async def embed_text(user_input: UserInput):
    """Handle text embedding.

    Embedding generation is blocking model inference, so it runs in the
    worker threadpool rather than on the event loop.

    Args:
        user_input: Request body carrying the text to embed.

    Returns:
        ChatResponse confirming that the text was embedded.

    Raises:
        HTTPException: 400 when the embedding could not be generated.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    embedding = await run_in_threadpool(
        chatbot.text_to_embedding, user_input.input_text
    )
    if embedding is None:
        raise HTTPException(status_code=400, detail="Embedding generation failed.")
    return ChatResponse(response="Text embedded successfully.")

@app.post("/add_document", response_model=ChatResponse)
async def add_document(user_input: UserInput):
    """Add a document embedding to ChromaDB.

    Each request is stored under a freshly generated ID. A fixed ID made
    every document collide with the previous one, so at most one record
    could ever exist in the collection.

    Args:
        user_input: Request body carrying the document text.

    Returns:
        ChatResponse confirming that the document was stored.

    Raises:
        HTTPException: 400 when the embedding could not be generated.
    """
    # Function-scope imports: uuid is stdlib and fastapi is already a
    # dependency of this module.
    import uuid

    from fastapi.concurrency import run_in_threadpool

    # Model inference and the database write are blocking; keep them off the
    # event loop.
    embedding = await run_in_threadpool(
        chatbot.text_to_embedding, user_input.input_text
    )
    if embedding is None:
        raise HTTPException(status_code=400, detail="Embedding generation failed.")

    doc_id = f"doc-{uuid.uuid4().hex}"  # unique per request
    await run_in_threadpool(
        embedding_store.add_embedding,
        doc_id,
        embedding,
        metadata={"source": "user_input"},
    )
    return ChatResponse(response="Document added to the database.")

# Gradio Interface for querying the FastAPI /chat endpoint
async def gradio_chatbot(input_text: str):
    """Forward *input_text* to the /chat endpoint and return the reply text.

    The endpoint URL defaults to ``http://127.0.0.1:7860/chat`` and can be
    overridden with the ``CHAT_API_URL`` environment variable.

    Args:
        input_text: Message typed into the Gradio text box.

    Returns:
        The ``response`` field of the endpoint's JSON reply, or a short
        error description when the request fails or the reply has an
        unexpected shape.
    """
    chat_url = os.getenv("CHAT_API_URL", "http://127.0.0.1:7860/chat")  # FastAPI endpoint

    try:
        # Answering involves model inference and an LLM call, so allow far
        # more time than the HTTP client's default timeout.
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(chat_url, json={"input_text": input_text})
            # Surface 4xx/5xx replies here instead of failing later on a
            # missing "response" key.
            response.raise_for_status()
            response_data = response.json()
    except httpx.HTTPError as exc:
        return f"Chat request failed: {exc}"
    except ValueError:
        # The body was not valid JSON.
        return "Chat request failed: the server returned an invalid response."

    if not isinstance(response_data, dict) or "response" not in response_data:
        return "Chat request failed: the server returned an unexpected response."
    return response_data["response"]

# Gradio Interface
# Single text-in / text-out UI whose handler is gradio_chatbot.
iface = gr.Interface(fn=gradio_chatbot, inputs="text", outputs="text")

if __name__ == "__main__":
    # Launch the Gradio interface
    # NOTE(review): nothing in this file starts the FastAPI `app` itself;
    # presumably it is served separately (e.g. by an ASGI server) so that the
    # /chat endpoint called by gradio_chatbot is reachable — confirm.
    iface.launch()