# Import the necessary libraries
import gradio as gr
from sentence_transformers import SentenceTransformer
import chromadb
import pandas as pd
import os
import json
import ast
# Load the sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')
# Initialize the ChromaDB client
client = chromadb.Client()
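
# NOTE (assumption): the logging block in predict() references `scheduler` and
# `log_file`, but this file never defines them. Below is a minimal sketch of one
# possible setup following the huggingface_hub CommitScheduler pattern the
# comments refer to; the repo_id is a placeholder and must point at a dataset
# repo you can write to.
from pathlib import Path
from huggingface_hub import CommitScheduler

log_folder = Path("logs")
log_folder.mkdir(parents=True, exist_ok=True)
log_file = log_folder / "qa_log.jsonl"

scheduler = CommitScheduler(
    repo_id="your-username/company-reports-logs",  # placeholder repo id (assumption)
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=10  # push accumulated logs every 10 minutes
)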
# Build (or rebuild) the ChromaDB collection from the pre-computed CSV export
def build_database():
    # Read the CSV file containing documents, ids, metadatas and embeddings
    df = pd.read_csv('collection_data.csv')

    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists
    # (older ChromaDB releases return Collection objects from list_collections(),
    # newer ones return plain names, so handle both)
    existing_names = [getattr(c, "name", c) for c in client.list_collections()]
    if collection_name in existing_names:
        client.delete_collection(name=collection_name)

    # Create a new collection
    collection = client.create_collection(name=collection_name)

    # Safely convert a stored embedding back into a list of floats
    def process_embedding(x):
        if isinstance(x, str):
            # Embeddings are stored as stringified lists; repair doubled commas
            # before parsing the literal
            return ast.literal_eval(x.replace(',,', ','))
        elif isinstance(x, float):
            return []  # missing embedding: fall back to an empty list
        else:
            return x

    # Add the data from the DataFrame to the collection
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(process_embedding).tolist()
    )
    return collection

# Build the database when the app starts
collection = build_database()
# Function to perform similarity search and return the relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    relevant_chunks = []
    for chunk, metadata in zip(results['documents'][0], results['metadatas'][0]):
        relevant_chunks.append((chunk, metadata['source'], metadata['page']))
    return relevant_chunks
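
# NOTE (assumption): get_llm_response() below calls .complete(prompt, max_tokens=...)
# on a text-generation client, but the original file reused the ChromaDB `client`
# name, which has no such method. A minimal sketch of one possibility, wrapping an
# OpenAI-compatible completions endpoint behind the call signature the code expects;
# the model name is a placeholder and OPENAI_API_KEY is expected in the environment.
from openai import OpenAI

class CompletionClient:
    """Thin wrapper exposing the complete(prompt, max_tokens) call used below."""

    def __init__(self, model="gpt-3.5-turbo-instruct"):
        self._client = OpenAI()  # reads OPENAI_API_KEY from the environment
        self._model = model

    def complete(self, prompt, max_tokens=1000):
        response = self._client.completions.create(
            model=self._model, prompt=prompt, max_tokens=max_tokens
        )
        # Return the first choice, which carries the generated text in `.text`
        return response.choices[0]

llm_client = CompletionClient()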
# Function to get the LLM response, retrying when the completion is cut off mid-sentence
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            # Use the text-generation client (not the ChromaDB client)
            response = llm_client.complete(prompt, max_tokens=1000)
            chunk = response.text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")):
                break
            else:
                # Ask the model to continue from the tail of the previous chunk
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
# Prediction function
def predict(company, user_query):
    # Modify the query to include the company name
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks from the vector store
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string with source attribution
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate the answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    prediction = get_llm_response(prompt)

    # Log both the inputs and outputs to a local log file.
    # Hold the commit scheduler's lock while writing to avoid parallel access
    # during a background upload.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_query,
                    'retrieved_context': context,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
# Create the Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.queue()
iface.launch()