import ast
import json
import threading
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer
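
# Embedding model, used both to index documents and to encode queries at
# request time; the two must match for nearest-neighbour search to make sense.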
model = SentenceTransformer('all-MiniLM-L6-v2')
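
# In-memory Chroma client: the index vanishes on restart. Swap in
# chromadb.PersistentClient(path=...) to keep it on disk instead.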
client = chromadb.Client()
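

# Rebuild the vector store from the exported CSV on startup.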
def build_database():
    df = pd.read_csv('collection_data.csv')

    collection_name = 'Dataset-10k-companies'
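
    # list_collections() returns Collection objects in older Chroma releases
    # and bare names in newer ones; normalize before the membership test.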
    existing = [getattr(c, 'name', c) for c in client.list_collections()]
    if collection_name in existing:
        client.delete_collection(name=collection_name)

    collection = client.create_collection(name=collection_name)
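
    # The CSV stores embeddings as stringified Python lists; missing rows
    # surface as NaN floats, and the ',,' fix appears to patch an export glitch.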
    def process_embedding(x):
        if isinstance(x, str):
            # ast.literal_eval parses the list literal without eval's
            # arbitrary-code-execution risk.
            return ast.literal_eval(x.replace(',,', ','))
        elif isinstance(x, float):
            # NaN (missing embedding) is read by pandas as a float.
            return []
        else:
            return x

    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(process_embedding).tolist()
    )

    return collection


collection = build_database()
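

# Embed the query with the same model and return the top_n nearest chunks
# as (text, source, page) triples so answers can be cited.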
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
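
    # Chroma nests results per query; index [0] selects our single query.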
    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))

    return relevant_chunks
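

# Hedged sketch of the generation step: `llm_client` is an assumed stand-in
# for your LLM provider's SDK, not a real API; wire .complete() to an actual
# completion call before running. (The original reused `client`, which is the
# Chroma client and has no complete() method.) The retry loop re-prompts the
# model to continue whenever a completion stops without terminal punctuation.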
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = llm_client.complete(prompt, max_tokens=1000)  # assumed API
            chunk = response.text.strip()
            full_response += chunk
            # Terminal punctuation is a heuristic for a complete answer.
            if chunk.endswith((".", "!", "?")):
                break
            else:
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
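

# Log destination and lock (assumed names) so concurrent Gradio requests
# don't interleave writes. The original referenced an undefined `scheduler`;
# if you log to a Hugging Face dataset via huggingface_hub.CommitScheduler,
# its .lock attribute serves the same purpose.
log_lock = threading.Lock()
log_file = Path("query_log.jsonl")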


def predict(company, user_query):
    modified_query = f"{user_query} for {company}"

    relevant_chunks = get_relevant_chunks(modified_query, collection)
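
    # Stitch the retrieved chunks into one context block, each followed by
    # its citation so the model can ground its answer.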
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    prediction = get_llm_response(prompt)
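
    # Append one JSON record per request for later inspection.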
    with log_lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': modified_query,
                    'retrieved_context': context,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction


company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Queue requests on the Interface actually being launched; the original
# called .queue() on an undefined `demo`.
iface.queue()
iface.launch()