import json
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer

# Load the sentence transformer model used to embed user queries
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client
client = chromadb.Client()

# Logging setup (assumed): predict() below writes each query/response pair to a local
# JSON-lines file while holding `scheduler.lock`. On Hugging Face Spaces this is
# typically a huggingface_hub.CommitScheduler; the repo id here is a placeholder.
from huggingface_hub import CommitScheduler

log_folder = Path("logs")
log_folder.mkdir(exist_ok=True)
log_file = log_folder / "query_log.json"
scheduler = CommitScheduler(
    repo_id="<your-username>/rag-query-logs",  # placeholder dataset repo
    repo_type="dataset",
    folder_path=str(log_folder),
    path_in_repo="data",
)

# Function to build the database from the CSV of pre-computed documents and embeddings
def build_database():
    # Read the CSV file
    df = pd.read_csv('collection_data.csv')

    # Create a collection
    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists
    # (list_collections may return names or Collection objects depending on the chromadb version)
    existing_collections = [getattr(c, "name", c) for c in client.list_collections()]
    if collection_name in existing_collections:
        client.delete_collection(name=collection_name)

    # Create a new collection
    collection = client.create_collection(name=collection_name)

    # Add the data from the DataFrame to the collection
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(eval).tolist(),
        # Embeddings are stored as strings; clean up doubled commas before parsing
        embeddings=df['embeddings'].apply(lambda x: eval(x.replace(',,', ','))).tolist()
    )

    return collection

# Build the database when the app starts
collection = build_database()

# Function to get relevant chunks for a query
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))

    return relevant_chunks

# Function to get the LLM response.
# `llm_client` is a placeholder for the app's LLM client: it is assumed to expose a
# .complete(prompt, max_tokens=...) call returning a response with a .text attribute.
# (The original code called .complete() on the ChromaDB client by mistake.)
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = llm_client.complete(prompt, max_tokens=1000)  # Increase max_tokens if possible
            chunk = response.text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")):  # Check if the response seems complete
                break
            else:
                # Use the last 100 characters as context to continue the response
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response

# Prediction function
def predict(company, user_query):
    # Modify the query to include the company name
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate the answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    answer = get_llm_response(prompt)

    # While the prediction is made, log both the inputs and outputs to a local log file.
    # While writing to the log file, hold the commit scheduler's lock to avoid parallel access.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': modified_query,
                    'retrieved_context': context,
                    'model_response': answer
                }
            ))
            f.write("\n")

    return answer

# Create the Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.launch()