Michaeldavidstein committed on
Commit
2f3bcdf
·
verified ·
1 Parent(s): 5ce514e

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -130
app.py DELETED
@@ -1,130 +0,0 @@
1
-
2
- # Import the necessary Libraries
3
- import gradio as gr
4
- from sentence_transformers import SentenceTransformer
5
- import chromadb
6
- import pandas as pd
7
- import os
8
-
9
- # Load the sentence transformer model
10
- model = SentenceTransformer('all-MiniLM-L6-v2')
11
-
12
- # Initialize the ChromaDB client
13
- client = chromadb.Client()
14
-
15
- def build_database():
16
- # Read the CSV file
17
- df = pd.read_csv('collection_data.csv')
18
-
19
- # Create a collection
20
- collection_name = 'Dataset-10k-companies'
21
-
22
- # Delete the existing collection if it exists
23
- if collection_name in client.list_collections():
24
- client.delete_collection(name=collection_name)
25
-
26
- # Create a new collection
27
- collection = client.create_collection(name=collection_name)
28
-
29
- # Function to safely process embeddings
30
- def process_embedding(x):
31
- if isinstance(x, str):
32
- return eval(x.replace(',,', ','))
33
- elif isinstance(x, float):
34
- return [] # or some default value
35
- else:
36
- return x
37
-
38
- # Add the data from the DataFrame to the collection
39
- collection.add(
40
- documents=df['documents'].tolist(),
41
- ids=df['ids'].tolist(),
42
- metadatas=df['metadatas'].apply(eval).tolist(),
43
- embeddings=df['embeddings'].apply(process_embedding).tolist()
44
- )
45
-
46
- return collection
47
-
48
- # Build the database when the app starts
49
- collection = build_database()
50
-
51
- # Function to perform similarity search and return relevant chunks
52
- def get_relevant_chunks(query, collection, top_n=3):
53
- query_embedding = model.encode(query).tolist()
54
- results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
55
-
56
- relevant_chunks = []
57
- for i in range(len(results['documents'][0])):
58
- chunk = results['documents'][0][i]
59
- source = results['metadatas'][0][i]['source']
60
- page = results['metadatas'][0][i]['page']
61
- relevant_chunks.append((chunk, source, page))
62
-
63
- return relevant_chunks
64
-
65
- # Function to get LLM response
66
- def get_llm_response(prompt, max_attempts=3):
67
- full_response = ""
68
- for attempt in range(max_attempts):
69
- try:
70
- response = client.complete(prompt, max_tokens=1000)
71
- chunk = response.text.strip()
72
- full_response += chunk
73
- if chunk.endswith((".", "!", "?")):
74
- break
75
- else:
76
- prompt = "Please continue from where you left off:\n" + chunk[-100:]
77
- except Exception as e:
78
- print(f"Attempt {attempt + 1} failed with error: {e}")
79
- return full_response
80
-
81
- # Prediction function
82
- def predict(company, user_query):
83
- # Modify the query to include the company name
84
- modified_query = f"{user_query} for {company}"
85
-
86
- # Get relevant chunks
87
- relevant_chunks = get_relevant_chunks(modified_query, collection)
88
-
89
- # Prepare the context string
90
- context = ""
91
- for chunk, source, page in relevant_chunks:
92
- context += chunk + "\n"
93
- context += f"[Source: {source}, Page: {page}]\n\n"
94
-
95
- # Generate answer
96
- prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
97
- prediction = get_llm_response(prompt)
98
-
99
- # While the prediction is made, log both the inputs and outputs to a local log file
100
- # While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
101
- # access
102
-
103
- with scheduler.lock:
104
- with log_file.open("a") as f:
105
- f.write(json.dumps(
106
- {
107
- 'user_input': user_input,
108
- 'retrieved_context': context_for_query,
109
- 'model_response': prediction
110
- }
111
- ))
112
- f.write("\n")
113
-
114
- return prediction
115
-
116
- # Create Gradio interface
117
- company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
118
- iface = gr.Interface(
119
- fn=predict,
120
- inputs=[
121
- gr.Radio(company_list, label="Select Company"),
122
- gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
123
- ],
124
- outputs=gr.Textbox(label="Generated Answer"),
125
- title="Company Reports Q&A",
126
- description="Query the vector database and get an LLM response based on the documents in the collection."
127
- )
128
-
129
- # Launch the interface
130
- iface.launch()