Michaeldavidstein committed
Commit d28a1c7 · verified · 1 Parent(s): 7225d45

Delete app.py

Files changed (1)
  1. app.py +0 -120
app.py DELETED
@@ -1,120 +0,0 @@
import ast
import json

import chromadb
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer

# Load the sentence transformer model used to embed documents and queries
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client (renamed from `client` so it cannot be
# confused with the LLM client used for generation below)
chroma_client = chromadb.Client()

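# NOTE (assumption): the original file called `client.complete(...)` for text
# generation but never created an LLM client, and the ChromaDB client has no
# such method. A minimal sketch assuming an OpenAI-compatible completions
# backend; `llm_client` and LLM_MODEL are stand-ins, not from the original.
from openai import OpenAI

llm_client = OpenAI()  # reads OPENAI_API_KEY from the environment
LLM_MODEL = "gpt-3.5-turbo-instruct"  # placeholder completions-capable model
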
# Function to build the database from CSV
def build_database():
    # Read the CSV file
    df = pd.read_csv('collection_data.csv')

    # Create a collection
    collection_name = 'Dataset-10k-companies'

    # Delete the existing collection if it exists. list_collections() returns
    # Collection objects (plain names in newer chromadb releases), so compare
    # against the names rather than the raw list.
    existing = [getattr(c, 'name', c) for c in chroma_client.list_collections()]
    if collection_name in existing:
        chroma_client.delete_collection(name=collection_name)

    # Create a new collection
    collection = chroma_client.create_collection(name=collection_name)

    # Add the data from the DataFrame to the collection. Metadatas and
    # embeddings are stored as string literals in the CSV; ast.literal_eval
    # parses them (a safer drop-in for the original eval), after collapsing
    # the stray double commas in the embedding strings.
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(lambda x: ast.literal_eval(x.replace(',,', ','))).tolist()
    )

    return collection

# Build the database when the app starts
collection = build_database()

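# NOTE (assumption): predict() below logs through `scheduler` and `log_file`,
# which the original file never defined. A minimal sketch of the standard
# huggingface_hub CommitScheduler pattern; the repo_id is a placeholder.
from pathlib import Path

from huggingface_hub import CommitScheduler

log_folder = Path("logs")
log_folder.mkdir(exist_ok=True)
log_file = log_folder / "query_log.json"

scheduler = CommitScheduler(
    repo_id="<username>/company-reports-logs",  # placeholder dataset repo
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
)
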
# Function to get the most relevant chunks for a query
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    # Pair each chunk with the source document and page from its metadata
    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))

    return relevant_chunks

# Function to get the LLM response, retrying until the answer looks complete
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            # Query the completions backend (assumed OpenAI-compatible; see
            # the llm_client setup above). Increase max_tokens if possible.
            response = llm_client.completions.create(
                model=LLM_MODEL,
                prompt=prompt,
                max_tokens=1000,
            )
            chunk = response.choices[0].text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")):  # check if response seems complete
                break
            else:
                # Use the last 100 chars as context for a continuation prompt
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response

# Prediction function
def predict(company, user_query):
    # Scope the query to the selected company
    modified_query = f"{user_query} for {company}"

    # Get relevant chunks
    relevant_chunks = get_relevant_chunks(modified_query, collection)

    # Prepare the context string, citing the source and page of each chunk
    context = ""
    for chunk, source, page in relevant_chunks:
        context += chunk + "\n"
        context += f"[Source: {source}, Page: {page}]\n\n"

    # Generate the answer
    prompt = f"Based on the following context, answer the question: {modified_query}\n\nContext:\n{context}"
    answer = get_llm_response(prompt)

    # Log both the inputs and outputs to a local log file. Hold the commit
    # scheduler's lock while writing so a background commit cannot read a
    # partially written record.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': modified_query,
                    'retrieved_context': context,
                    'model_response': answer
                }
            ))
            f.write("\n")

    return answer

# Create the Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

# Launch the interface
iface.launch()