File size: 4,080 Bytes
7225d45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

import ast
import json
import os

import chromadb
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer

# Sentence-embedding model; it turns user queries into vectors at request time.
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Default ChromaDB client (no persistence path is configured here). It owns the
# document collection that build_database() creates.
client = chromadb.Client()

# Function to build the database from CSV
def build_database():
    # Read the CSV file
    df = pd.read_csv('collection_data.csv')
    
    # Create a collection
    collection_name = 'Dataset-10k-companies'
    
    # Delete the existing collection if it exists
    if collection_name in client.list_collections():
        client.delete_collection(name=collection_name)
    
    # Create a new collection
    collection = client.create_collection(name=collection_name)
    
    # Add the data from the DataFrame to the collection
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(eval).tolist(),
        embeddings=df['embeddings'].apply(lambda x: eval(x.replace(',,', ','))).tolist()
    )
    
    return collection

# Build the vector store once, at import time. Every request handled by
# predict() then queries this module-level `collection`.
collection = build_database()

def get_relevant_chunks(query, collection, top_n=3):
    """Retrieve the ``top_n`` stored chunks nearest to ``query``.

    The query is embedded with the module-level ``model`` and looked up in
    ``collection`` by embedding similarity.

    Args:
        query: Natural-language search text.
        collection: ChromaDB collection to search.
        top_n: Maximum number of chunks to return.

    Returns:
        A list of ``(chunk_text, source, page)`` tuples, in the order the
        collection query returned them.
    """
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
    return _unpack_query_results(results)


def _unpack_query_results(results):
    """Turn a single-query ChromaDB result into ``(chunk, source, page)`` tuples.

    Query results hold one inner list per query embedding. Exactly one embedding
    was sent, so only index 0 is read. Each metadata dict must provide
    ``source`` and ``page`` keys.
    """
    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    return [
        (document, metadata['source'], metadata['page'])
        for document, metadata in zip(documents, metadatas)
    ]

def get_llm_response(prompt, max_attempts=3, llm_client=None):
    """Query an LLM, asking it to continue while the reply looks truncated.

    Each attempt calls ``llm_client.complete(prompt, max_tokens=1000)`` and keeps
    the stripped ``.text`` of the response. A reply ending in ``.``, ``!`` or
    ``?`` is treated as complete. Otherwise the next attempt asks the model to
    continue from the last 100 characters received. Any error is printed and
    the same prompt is retried on the next attempt.

    Args:
        prompt: Prompt text for the first request.
        max_attempts: Maximum number of requests. Failed requests count too.
        llm_client: Object exposing ``complete(prompt, max_tokens=...)`` whose
            return value has a ``.text`` attribute. Defaults to the
            module-level ``client``.

    Returns:
        The non-empty reply pieces joined with single spaces, or ``""`` if
        every attempt failed.
    """
    # NOTE(review): the module-level ``client`` is a ``chromadb.Client()``, which
    # has no ``complete`` method. Without an explicit ``llm_client``, every
    # attempt fails and this returns "". Pass a real LLM client.
    if llm_client is None:
        llm_client = client

    pieces = []
    for attempt in range(max_attempts):
        try:
            response = llm_client.complete(prompt, max_tokens=1000)
            chunk = response.text.strip()
        except Exception as e:
            # Best effort: report the failure and retry with the same prompt.
            print(f"Attempt {attempt + 1} failed with error: {e}")
            continue
        if chunk:
            pieces.append(chunk)
        if chunk.endswith((".", "!", "?")):  # reply looks complete
            break
        # Send back only the tail, so the model has something to continue from.
        prompt = "Please continue from where you left off:\n" + chunk[-100:]

    # strip() removed the whitespace at piece boundaries. Re-insert a separator
    # so continuation text does not fuse with the previous piece's last word.
    return " ".join(pieces)

def predict(company, user_query):
    """Answer ``user_query`` about ``company`` using retrieved report chunks.

    The steps are:

    1. Retrieve the chunks closest to the company-qualified query.
    2. Build a prompt containing them, each tagged with its source and page.
    3. Ask the LLM.
    4. Append the interaction as one JSON line to the log file.
    5. Return the answer.

    Args:
        company: Company selected in the UI. If it is falsy (e.g. nothing
            selected), the query is used unqualified.
        user_query: Free-text question typed by the user.

    Returns:
        The LLM's answer text.
    """
    # Qualify the question with the selected company, if there is one.
    modified_query = f"{user_query} for {company}" if company else user_query

    relevant_chunks = get_relevant_chunks(modified_query, collection)
    context = _format_context(relevant_chunks)

    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    answer = get_llm_response(prompt)

    # Log inputs and outputs as one JSON record per line. The commit
    # scheduler's lock is held while writing to avoid parallel access to the
    # log file.
    # NOTE(review): ``scheduler`` and ``log_file`` are not defined in the
    # visible code. They must exist at module level (an object exposing
    # ``.lock`` and a pathlib-style path with ``.open``), or this raises
    # NameError.
    record = {
        'user_input': user_query,
        'company': company,
        'retrieved_context': context,
        'model_response': answer,
    }
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(record))
            f.write("\n")

    return answer


def _format_context(relevant_chunks):
    """Render ``(chunk, source, page)`` tuples as prompt context text.

    Each chunk is followed by a ``[Source: ..., Page: ...]`` tag and a blank line.
    """
    return "".join(
        f"{chunk}\n[Source: {source}, Page: {page}]\n\n"
        for chunk, source, page in relevant_chunks
    )

# Gradio UI: the user picks one company, types a question, and gets the
# generated answer back from predict().
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]

company_selector = gr.Radio(company_list, label="Select Company")
query_box = gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
answer_box = gr.Textbox(label="Generated Answer")

iface = gr.Interface(
    fn=predict,
    inputs=[company_selector, query_box],
    outputs=answer_box,
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection.",
)

# Start the Gradio server.
iface.launch()