Michaeldavidstein committed on
Commit
7225d45
·
verified ·
1 Parent(s): 2f3bcdf

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import ast
import json
import os

import chromadb
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer
7
+
8
+ # Load the sentence transformer model
9
+ model = SentenceTransformer('all-MiniLM-L6-v2')
10
+
11
+ # Initialize the ChromaDB client
12
+ client = chromadb.Client()
13
+
14
# Function to build the database from CSV
def build_database():
    """(Re)build the ChromaDB collection from ``collection_data.csv``.

    The CSV is expected to contain the columns ``documents``, ``ids``,
    ``metadatas`` (stringified dicts) and ``embeddings`` (stringified
    lists of floats) — TODO confirm against the actual file.

    Returns:
        The freshly created chromadb collection, fully populated.
    """
    # Read the pre-chunked / pre-embedded data
    df = pd.read_csv('collection_data.csv')

    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists.
    # BUGFIX: list_collections() yields Collection objects (names only in
    # newer chromadb), so the original `name in client.list_collections()`
    # compared a string to objects, never matched, and create_collection()
    # then crashed on any rebuild. Compare by name, tolerating both APIs.
    existing_names = [getattr(c, "name", c) for c in client.list_collections()]
    if collection_name in existing_names:
        client.delete_collection(name=collection_name)

    # Create a new, empty collection
    collection = client.create_collection(name=collection_name)

    # Add the data from the DataFrame to the collection.
    # SECURITY: ast.literal_eval replaces eval() — it parses only Python
    # literals, so malformed or malicious CSV content cannot execute code.
    # The ',,' -> ',' replace works around a known artifact in the
    # embeddings column of the source CSV.
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings']
        .apply(lambda x: ast.literal_eval(x.replace(',,', ',')))
        .tolist(),
    )

    return collection
38
+
39
# Build the database when the app starts (module-level side effect:
# reads collection_data.csv and populates the in-memory chromadb client).
collection = build_database()
41
+
42
# Function to get relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
    """Embed *query* and return the best-matching chunks from *collection*.

    Args:
        query: Natural-language query string.
        collection: chromadb collection to search.
        top_n: Number of results to retrieve.

    Returns:
        List of ``(chunk_text, source, page)`` tuples, best match first.
        Assumes every metadata dict carries 'source' and 'page' keys —
        guaranteed by build_database()'s input CSV, TODO confirm.
    """
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    # query() returns parallel lists (one inner list per query); iterate
    # documents and metadatas in lockstep instead of indexing by position.
    relevant_chunks = [
        (chunk, meta['source'], meta['page'])
        for chunk, meta in zip(results['documents'][0], results['metadatas'][0])
    ]
    return relevant_chunks
55
+
56
# Function to get LLM response
def get_llm_response(prompt, max_attempts=3):
    """Query the LLM, continuing/retrying up to ``max_attempts`` calls.

    Accumulates partial completions; when a chunk does not end with
    sentence-final punctuation, the model is asked to continue from the
    tail of the previous chunk.

    Args:
        prompt: Prompt string sent to the model.
        max_attempts: Maximum number of completion calls.

    Returns:
        The concatenated response text ("" if every attempt failed).
    """
    full_response = ""
    for attempt in range(max_attempts):
        try:
            # NOTE(review): `client` here is the module-level chromadb.Client,
            # which exposes no `complete()` method — this likely raises on
            # every attempt (so the function returns ""). Presumably a
            # separate LLM client was intended; confirm and wire it in.
            response = client.complete(prompt, max_tokens=1000) # Increase max_tokens if possible
            chunk = response.text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")): # Check if response seems complete
                break
            else:
                # Not obviously complete: ask for a continuation, anchored on
                # the last 100 chars so the model knows where it left off.
                prompt = "Please continue from where you left off:\n" + chunk[-100:] # Use the last 100 chars as context
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
71
+
72
# Prediction function
def predict(company, user_query):
    """Answer *user_query* about *company* using retrieved context + LLM.

    Args:
        company: Company ticker/name selected in the UI.
        user_query: Free-text question from the user.

    Returns:
        The LLM-generated answer string.
    """
    # Scope the query to the selected company before retrieval
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string, citing source and page after each chunk
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    answer = get_llm_response(prompt)

    # Log inputs and outputs to a local log file, holding the commit
    # scheduler's lock to avoid parallel access.
    # BUGFIX: the original wrote undefined names (user_input,
    # context_for_query, prediction) and never imported json, so every
    # call died with NameError. `scheduler` and `log_file` are still not
    # defined anywhere in this file (presumably a HF CommitScheduler and a
    # Path were meant to be configured at module level — confirm), so
    # logging is best-effort and must never break inference.
    try:
        with scheduler.lock:
            with log_file.open("a") as f:
                f.write(json.dumps(
                    {
                        'user_input': user_query,
                        'retrieved_context': context,
                        'model_response': answer
                    }
                ))
                f.write("\n")
    except NameError:
        # scheduler / log_file not configured in this deployment — skip logging
        pass

    return answer
105
+
106
# Create the Gradio interface: a company selector and a free-text query
# box wired to predict(), with the generated answer shown in a textbox.
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]

company_selector = gr.Radio(company_list, label="Select Company")
query_box = gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
answer_box = gr.Textbox(label="Generated Answer")

iface = gr.Interface(
    fn=predict,
    inputs=[company_selector, query_box],
    outputs=answer_box,
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.launch()