import ast
import json
import os
import uuid
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from huggingface_hub import CommitScheduler
from sentence_transformers import SentenceTransformer

# Load the sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')
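# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings, so the
# pre-computed embeddings stored in the CSV must have the same
# dimensionality for collection.query() to match against them.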
# Initialize the ChromaDB client
client = chromadb.Client()
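
# Local log file plus a CommitScheduler that periodically commits the logs
# to a Hugging Face dataset repo, so predict() can append records safely.
# Minimal sketch: the repo_id below is a placeholder; point it at your own
# dataset repo, or drop the scheduler entirely to keep logs purely local.
log_file = Path("logs") / f"data_{uuid.uuid4()}.json"
log_file.parent.mkdir(parents=True, exist_ok=True)
scheduler = CommitScheduler(
    repo_id="<your-username>/rag-query-logs",  # placeholder dataset repo
    repo_type="dataset",
    folder_path=log_file.parent,
    path_in_repo="data",
)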

# Function to build the ChromaDB collection from the pre-computed CSV
def build_database():
    # Read the CSV file (expected columns: documents, ids, metadatas, embeddings)
    df = pd.read_csv('collection_data.csv')

    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists; list_collections() returns
    # Collection objects in older chromadb versions and plain names in newer
    # ones, so normalize before the membership check
    existing = [getattr(c, 'name', c) for c in client.list_collections()]
    if collection_name in existing:
        client.delete_collection(name=collection_name)

    # Create a new collection
    collection = client.create_collection(name=collection_name)

    # Add the data from the DataFrame; metadatas and embeddings are stored as
    # stringified Python literals in the CSV, so parse them with the safer
    # ast.literal_eval rather than eval (the replace() cleans up doubled commas)
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(
            lambda x: ast.literal_eval(x.replace(',,', ','))
        ).tolist(),
    )
    return collection

# Build the database when the app starts
collection = build_database()

# Function to retrieve the top_n most relevant chunks for a query
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))
    return relevant_chunks
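
# Illustrative only: the return value is a list of (chunk, source, page)
# tuples, e.g.
#   get_relevant_chunks("What was total revenue?", collection)
#   -> [("Total revenue increased ...", "msft-10k", 4), ...]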

# Function to get the LLM response, continuing generation when the output
# looks truncated. NOTE: the original code called complete() on `client`,
# which is the ChromaDB client and has no such method. `llm_client` below is
# assumed to be a separately configured completion-style client whose
# complete() returns an object with a .text attribute (e.g. a llama_index
# LLM); configure it for your provider before running.
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = llm_client.complete(prompt, max_tokens=1000)  # Increase max_tokens if possible
            chunk = response.text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")):  # Heuristic: response looks complete
                break
            else:
                # Ask the model to continue, using the last 100 chars as context
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response

# Prediction function
def predict(company, user_query):
    # Scope the query to the selected company
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string, citing source and page after each chunk
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate the answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    answer = get_llm_response(prompt)

    # Log both the inputs and outputs to a local log file, holding the commit
    # scheduler's lock so a scheduled push never sees a half-written record
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps({
                'user_input': user_query,
                'retrieved_context': context,
                'model_response': answer,
            }))
            f.write("\n")

    return answer
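
# Each call appends one JSON line to the log file; a record looks like
# (illustrative): {"user_input": "...", "retrieved_context": "...",
# "model_response": "..."}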

# Create the Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.launch()