Michaeldavidstein committed on
Commit
434ef73
·
verified ·
1 Parent(s): b634bff

Upload 3 files

Browse files
Files changed (3) hide show
  1. app (2).py +129 -0
  2. collection_data.csv +0 -0
  3. requirements (2).txt +4 -0
app (2).py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Import the necessary Libraries
3
+ import gradio as gr
4
+ from sentence_transformers import SentenceTransformer
5
+ import chromadb
6
+ import pandas as pd
7
+ import os
8
+
9
# Load the sentence transformer model used to embed user queries.
# all-MiniLM-L6-v2 is a small, fast model producing 384-dim embeddings.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client.
# NOTE(review): chromadb.Client() is the in-memory/ephemeral client — the
# collection is rebuilt from CSV on every app start; confirm persistence
# is not required here.
client = chromadb.Client()
15
# Function to build the database from CSV
def build_database():
    """Build (or rebuild) the ChromaDB collection from collection_data.csv.

    Reads the exported collection dump, drops columns that cannot be
    re-ingested, recreates the 'Dataset-10k-companies' collection, and loads
    documents, ids, metadatas, and embeddings into it.

    Returns:
        The freshly populated chromadb collection, ready for queries.
    """
    import ast  # local import: only needed while parsing the CSV dump

    # Read the CSV file
    df = pd.read_csv('collection_data.csv')

    # Drop unnecessary columns if they exist
    for col in ('uris', 'data'):
        if col in df.columns:
            df.drop(col, axis=1, inplace=True)

    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists.
    # Fix: the original test `collection_name in client.list_collections()`
    # compared a string against Collection objects and never matched on the
    # chromadb releases of this era, so create_collection could raise on a
    # rebuild.  try/except works across client versions.
    try:
        client.delete_collection(name=collection_name)
    except Exception:
        pass  # collection did not exist yet

    # Create a new collection
    collection = client.create_collection(name=collection_name)

    # Add the data from the DataFrame to the collection.
    # Fix: ast.literal_eval replaces eval() — the CSV stores Python literals,
    # and literal_eval parses them without executing arbitrary code from an
    # untrusted file.
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings']
            .apply(lambda x: ast.literal_eval(x.replace(',,', ',')))
            .tolist(),
    )

    return collection

# Build the database when the app starts
collection = build_database()
49
# Function to perform similarity search and return relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
    """Embed *query*, search *collection*, and return the top matches.

    Returns a list of (chunk, source, page) tuples for the top_n results,
    where source/page come from each hit's metadata.
    """
    embedding = model.encode(query).tolist()
    hits = collection.query(query_embeddings=[embedding], n_results=top_n)

    # Walk documents and metadatas of the single query in lockstep.
    docs = hits['documents'][0]
    metas = hits['metadatas'][0]
    return [(doc, meta['source'], meta['page']) for doc, meta in zip(docs, metas)]
63
# Function to get LLM response
def get_llm_response(prompt, max_attempts=3):
    """Query the LLM, re-prompting for a continuation until a sentence ends.

    Accumulates chunks across up to max_attempts calls; stops early once a
    chunk ends in '.', '!' or '?'.  Returns whatever text was gathered,
    which may be the empty string if every attempt raised.
    """
    # NOTE(review): `client` here resolves to the module-level
    # chromadb.Client(), which has no .complete() method — as written, every
    # attempt raises AttributeError, prints the error, and this returns "".
    # An actual LLM client exposing .complete(prompt, max_tokens=...) needs
    # to be wired in; confirm which provider was intended.
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = client.complete(prompt, max_tokens=1000)
            chunk = response.text.strip()
            full_response += chunk
            # Heuristic: a chunk ending in sentence punctuation is complete.
            if chunk.endswith((".", "!", "?")):
                break
            else:
                # Ask the model to continue from the tail of the last chunk.
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            # Best-effort retry: report the failure and try again.
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
79
# Prediction function
def predict(company, user_query):
    """Answer *user_query* about *company* via retrieval-augmented generation.

    Retrieves the most similar chunks for the company-scoped query, builds a
    cited context block, asks the LLM, logs the exchange, and returns the
    generated answer string.
    """
    import json  # local import: only used to serialize the log record

    # Modify the query to include the company name
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string, citing source and page for each chunk
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    prediction = get_llm_response(prompt)

    # Log both the inputs and outputs to a local log file, holding the
    # commit scheduler's lock to avoid parallel access.
    # Fixes vs original: `json` was never imported; `user_input` and
    # `context_for_query` were undefined (typos for the values computed
    # above), so every call raised NameError.  `scheduler` and `log_file`
    # (the HF-Spaces CommitScheduler pattern) are not defined anywhere in
    # this file, so logging is best-effort and must not break predictions.
    try:
        with scheduler.lock:
            with log_file.open("a") as f:
                f.write(json.dumps(
                    {
                        'user_input': modified_query,
                        'retrieved_context': context,
                        'model_response': prediction
                    }
                ))
                f.write("\n")
    except NameError:
        pass  # logging infrastructure not configured; skip silently

    return prediction
114
# Create Gradio interface: pick a company, type a question, get the
# retrieval-augmented answer back in a textbox.
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface.
# Fix: the original called demo.queue() on an undefined name `demo`
# (NameError at startup); the interface object created above is `iface`.
iface.queue()
iface.launch()
collection_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements (2).txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ sentence-transformers
3
+ chromadb
4
+ pandas