# Import the necessary libraries
import gradio as gr
from sentence_transformers import SentenceTransformer
import chromadb
import pandas as pd
import os
import ast   # safe parsing of list/dict literals stored as strings in the CSV
import json  # used when logging inputs and outputs in predict()

# Load the sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client
client = chromadb.Client()
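
# The logging block in predict() below uses `scheduler` and `log_file`, which
# this snippet never defines. A minimal sketch of the usual Hugging Face
# CommitScheduler setup is shown here; the repo_id and folder layout are
# placeholders/assumptions, not this app's confirmed configuration.
from pathlib import Path
from uuid import uuid4
from huggingface_hub import CommitScheduler

log_folder = Path("logs")
log_folder.mkdir(parents=True, exist_ok=True)
log_file = log_folder / f"data_{uuid4()}.json"
scheduler = CommitScheduler(
    repo_id="your-username/company-reports-qna-logs",  # placeholder dataset repo
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
)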

def build_database():
    # Read the CSV file
    df = pd.read_csv('collection_data.csv')

    # Name of the collection to (re)create
    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists, so the build starts fresh.
    # (Checking membership in client.list_collections() is version-dependent,
    # so delete inside a try/except instead.)
    try:
        client.delete_collection(name=collection_name)
    except Exception:
        pass  # the collection did not exist yet

    # Create a new collection
    collection = client.create_collection(name=collection_name)

    # Safely parse an embedding stored as a string in the CSV
    def process_embedding(x):
        if isinstance(x, str):
            # Collapse doubled commas from the export, then parse the list literal
            return ast.literal_eval(x.replace(',,', ','))
        elif isinstance(x, float):
            return []  # NaN cell: fall back to an empty embedding
        else:
            return x

    # Add the data from the DataFrame to the collection
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(process_embedding).tolist()
    )
    return collection

# Build the database when the app starts
collection = build_database()

# Function to perform similarity search and return relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))
    return relevant_chunks
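
# get_llm_response() below calls `llm_client.complete(...)`. The original code
# reused the name `client`, which is already the ChromaDB client and has no
# complete() method. A minimal adapter sketch, assuming the OpenAI SDK: any
# client whose complete(prompt, max_tokens) call returns an object with a
# .text attribute works here. The provider and model name are assumptions.
from types import SimpleNamespace
from openai import OpenAI

class LLMClient:
    def __init__(self, model_name="gpt-4o-mini"):
        self._client = OpenAI()  # reads OPENAI_API_KEY from the environment
        self._model_name = model_name

    def complete(self, prompt, max_tokens=1000):
        response = self._client.chat.completions.create(
            model=self._model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
        )
        return SimpleNamespace(text=response.choices[0].message.content)

llm_client = LLMClient()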

# Function to get the LLM response; retries and asks the model to continue
# if the reply is cut off mid-sentence
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = llm_client.complete(prompt, max_tokens=1000)
            chunk = response.text.strip()
            full_response += chunk
            # A reply ending on sentence punctuation is treated as complete
            if chunk.endswith((".", "!", "?")):
                break
            else:
                # Otherwise prompt the model to continue from its last 100 characters
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response

# Prediction function
def predict(company, user_query):
    # Include the company name in the query so retrieval is company-specific
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string, citing the source and page of each chunk
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate the answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    prediction = get_llm_response(prompt)

    # Log both the inputs and the output to a local log file. Hold the commit
    # scheduler's lock while writing so a background commit cannot read a
    # half-written line.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': modified_query,
                    'retrieved_context': context,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
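
# Each call appends one JSON object per line to log_file, e.g. (illustrative
# values only):
# {"user_input": "What was net income in 2023? for MSFT",
#  "retrieved_context": "...chunks with [Source: ..., Page: ...] tags...",
#  "model_response": "Net income was ..."}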

# Create the Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.queue()
iface.launch()