# Import the necessary libraries
import ast
import json
import os
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from huggingface_hub import CommitScheduler
from openai import OpenAI
from sentence_transformers import SentenceTransformer

# Load the sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client
client = chromadb.Client()

def build_database():
    # Read the CSV file
    df = pd.read_csv('collection_data.csv')
    
    # Create a collection
    collection_name = 'Dataset-10k-companies'
    
    # Delete the existing collection if it exists (delete_collection raises
    # when the name is unknown, so just attempt it and ignore that case)
    try:
        client.delete_collection(name=collection_name)
    except Exception:
        pass  # no existing collection to remove
    
    # Create a new collection
    collection = client.create_collection(name=collection_name)
    
    # Parse embeddings that were serialized as strings in the CSV
    def process_embedding(x):
        if isinstance(x, str):
            # literal_eval is safer than eval for parsing list literals;
            # collapse doubled commas left over from serialization first
            return ast.literal_eval(x.replace(',,', ','))
        elif isinstance(x, float):
            return []  # NaN cell: fall back to an empty embedding
        else:
            return x

    # Add the data from the DataFrame to the collection
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(process_embedding).tolist()
    )
    
    return collection

# Build the database when the app starts
collection = build_database()

# Function to perform similarity search and return relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
    
    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))
    
    return relevant_chunks
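
# Client for the LLM that generates the final answers. Shown here with the
# OpenAI SDK as a stand-in for whichever provider the app uses; any
# chat-completion endpoint would slot in the same way. Assumes
# OPENAI_API_KEY is set in the environment.
llm_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))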

# Function to get the LLM response, asking the model to continue if the
# answer was cut off mid-sentence
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = llm_client.chat.completions.create(
                model="gpt-4o-mini",  # placeholder model name
                messages=[{"role": "user", "content": prompt}],
                max_tokens=1000,
            )
            chunk = response.choices[0].message.content.strip()
            full_response += chunk
            # Terminal punctuation suggests the answer is complete
            if chunk.endswith((".", "!", "?")):
                break
            else:
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
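
# Query/response logging setup: predict() below appends each interaction to a
# local JSON-lines file under the scheduler's lock, and a huggingface_hub
# CommitScheduler periodically pushes the log folder to a dataset repo.
# The repo_id is a placeholder; point it at your own dataset.
log_folder = Path("logs")
log_folder.mkdir(exist_ok=True)
log_file = log_folder / "log.json"

scheduler = CommitScheduler(
    repo_id="your-username/rag-query-logs",  # placeholder dataset repo
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=10,  # minutes between background commits
)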

# Prediction function
def predict(company, user_query):
    # Modify the query to include the company name
    modified_query = f"{user_query} for {company}"
    
    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)
    
    # Prepare the context string
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"
    
    # Generate answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    prediction = get_llm_response(prompt)
    
    # Log both the inputs and outputs to a local log file, holding the
    # commit scheduler's lock so the background commit thread cannot read
    # the file while it is being written
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_query,
                    'retrieved_context': context,
                    'model_response': prediction
                }
            ))
            f.write("\n")
    return prediction

# Create Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.queue()
iface.launch()