manasagangotri commited on
Commit
d16f9ab
ยท
verified ยท
1 Parent(s): 5eea801

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -61
app.py CHANGED
@@ -2,11 +2,12 @@
2
  import gradio as gr
3
  import torch
4
  import requests
5
- import json
6
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
  from qdrant_client import QdrantClient
9
  from datetime import datetime
 
 
10
 
11
  # === Load Models ===
12
  print("Loading zero-shot classifier...")
@@ -15,9 +16,9 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
15
  print("Loading embedding model...")
16
  embedding_model = SentenceTransformer("intfloat/e5-large")
17
 
18
- print("Loading step-by-step generator...")
19
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
20
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
21
 
22
  # === Qdrant Setup ===
23
  print("Connecting to Qdrant...")
@@ -28,20 +29,20 @@ collection_name = "math_problems"
28
  def is_valid_math_question(text):
29
  candidate_labels = ["math", "not math"]
30
  result = classifier(text, candidate_labels)
 
31
  return result['labels'][0] == "math" and result['scores'][0] > 0.7
32
 
33
- # === Retrieval from Qdrant ===
34
  def retrieve_from_qdrant(query):
 
35
  query_vector = embedding_model.encode(query).tolist()
36
- hits = qdrant_client.query_points(
37
- collection_name=collection_name,
38
- query_vector=query_vector,
39
- limit=3
40
- )
41
  return [hit.payload for hit in hits] if hits else []
42
 
43
- # === Web Search Fallback ===
44
  def web_search_tavily(query):
 
45
  TAVILY_API_KEY = "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"
46
  response = requests.post(
47
  "https://api.tavily.com/search",
@@ -49,77 +50,147 @@ def web_search_tavily(query):
49
  )
50
  return response.json().get("answer", "No answer found from Tavily.")
51
 
52
- # === Generator ===
53
- def generate_step_by_step_answer(question, context):
54
- prompt = f"Answer the following math question step-by-step:\nQuestion: {question}\nContext: {context}\nAnswer:"
55
- inputs = tokenizer(prompt, return_tensors="pt")
56
- outputs = model.generate(
57
- **inputs,
58
- max_new_tokens=256,
59
- temperature=0.7,
60
- top_p=0.95,
61
- do_sample=True,
62
- pad_token_id=tokenizer.eos_token_id
63
- )
64
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
65
-
66
- # === Router ===
67
- def router(question):
68
- if not is_valid_math_question(question):
69
- return "โŒ Only math questions are accepted. Please rephrase.", ""
70
-
71
- retrieved = retrieve_from_qdrant(question)
72
- context = "\n".join([item["solution"] for item in retrieved if "solution" in item])
73
- if context:
74
- answer = generate_step_by_step_answer(question, context)
75
- return answer, context
76
- else:
77
- fallback = web_search_tavily(question)
78
- return fallback, "Tavily Search"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # === Feedback Storage ===
81
- def store_feedback(question, answer, correct_answer):
82
  entry = {
83
  "question": question,
84
  "model_answer": answer,
 
85
  "correct_answer": correct_answer,
86
  "timestamp": str(datetime.now())
87
  }
 
88
  with open("feedback.json", "a") as f:
89
  f.write(json.dumps(entry) + "\n")
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # === Gradio Functions ===
92
  def ask_question(question):
93
- answer, context = router(question)
94
- return answer, question, answer
 
 
95
 
96
- def submit_feedback(question, model_answer, correct_answer):
97
- store_feedback(question, model_answer, correct_answer)
 
 
98
  return "โœ… Feedback received. Thank you!"
99
 
100
  # === Gradio UI ===
101
  with gr.Blocks() as demo:
102
- gr.Markdown("## ๐Ÿงฎ Math Question Answering with Retrieval + Feedback")
103
-
104
- with gr.Row():
105
- question_input = gr.Textbox(label="Enter your math question", lines=2)
 
 
 
 
 
 
 
106
  submit_btn = gr.Button("Get Answer")
 
107
 
108
- answer_output = gr.Markdown(label="Answer")
109
- hidden_q = gr.Textbox(visible=False)
110
- hidden_a = gr.Textbox(visible=False)
111
-
112
- submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])
113
-
114
- gr.Markdown("### ๐Ÿ“ Submit Feedback")
115
  fb_correct = gr.Textbox(label="Correct Answer (optional)")
116
- fb_submit = gr.Button("Submit Feedback")
117
  fb_status = gr.Textbox(label="Status", interactive=False)
118
 
119
- fb_submit.click(
120
- fn=submit_feedback,
121
- inputs=[hidden_q, hidden_a, fb_correct],
122
- outputs=[fb_status]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  )
124
 
125
- demo.launch(share=True)
 
 
2
  import gradio as gr
3
  import torch
4
  import requests
5
+ from transformers import pipeline
 
6
  from sentence_transformers import SentenceTransformer
7
  from qdrant_client import QdrantClient
8
  from datetime import datetime
9
+ import dspy
10
+ import json
11
 
12
  # === Load Models ===
13
  print("Loading zero-shot classifier...")
 
16
  print("Loading embedding model...")
17
  embedding_model = SentenceTransformer("intfloat/e5-large")
18
 
19
+ print("Loading text generation model...")
20
+ # Use a lighter model for testing
21
+ #qa_pipeline = pipeline("text-generation", model="gpt2")
22
 
23
  # === Qdrant Setup ===
24
  print("Connecting to Qdrant...")
 
29
  def is_valid_math_question(text):
30
  candidate_labels = ["math", "not math"]
31
  result = classifier(text, candidate_labels)
32
+ print("Classifier result:", result)
33
  return result['labels'][0] == "math" and result['scores'][0] > 0.7
34
 
35
+ # === Retrieval ===
36
  def retrieve_from_qdrant(query):
37
+ print("Retrieving context from Qdrant...")
38
  query_vector = embedding_model.encode(query).tolist()
39
+ hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
40
+ print("Retrieved hits:", hits)
 
 
 
41
  return [hit.payload for hit in hits] if hits else []
42
 
43
+ # === Web Search ===
44
  def web_search_tavily(query):
45
+ print("Calling Tavily...")
46
  TAVILY_API_KEY = "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"
47
  response = requests.post(
48
  "https://api.tavily.com/search",
 
50
  )
51
  return response.json().get("answer", "No answer found from Tavily.")
52
 
53
+ # === DSPy Signature ===
54
+ class MathAnswer(dspy.Signature):
55
+ question = dspy.InputField()
56
+ retrieved_context = dspy.InputField()
57
+ answer = dspy.OutputField()
58
+
59
+ # === DSPy Programs ===
60
+ # === DSPy Programs with Output Guard ===
61
+ class MathRetrievalQA(dspy.Program):
62
+ def forward(self, question):
63
+ print("Inside MathRetrievalQA...")
64
+ context_items = retrieve_from_qdrant(question)
65
+ context = "\n".join([item["solution"] for item in context_items if "solution" in item])
66
+ print("Context for generation:", context)
67
+
68
+ if not context:
69
+ return {"answer": "", "retrieved_context": ""}
70
+
71
+ # === Replace below with real model call when ready ===
72
+ prompt = f"Question: {question}\nContext: {context}\nAnswer:"
73
+ print("Prompt for generation:", prompt)
74
+
75
+ # TEMP answer (replace with real generated output)
76
+ generated_answer = "This is a placeholder answer based on the context." # Simulated generation
77
+ print("Generated answer:", generated_answer)
78
+
79
+ # === Output Guard ===
80
+ if not generated_answer or len(generated_answer.strip()) < 10 or "I don't know" in generated_answer:
81
+ return {"answer": "", "retrieved_context": context}
82
+
83
+ return {"answer": generated_answer.strip(), "retrieved_context": context}
84
+
85
+
86
+ class WebFallbackQA(dspy.Program):
87
+ def forward(self, question):
88
+ print("Fallback to Tavily...")
89
+ answer = web_search_tavily(question)
90
+ if not answer or len(answer.strip()) < 10 or "No answer found" in answer:
91
+ answer = "โŒ Sorry, I couldn't find a reliable answer."
92
+ return {"answer": answer.strip(), "retrieved_context": "Tavily"}
93
+
94
+
95
+ class MathRouter(dspy.Program):
96
+ def forward(self, question):
97
+ print("Routing question:", question)
98
+ if not is_valid_math_question(question):
99
+ return {"answer": "โŒ Only math questions are accepted. Please rephrase.", "retrieved_context": ""}
100
+
101
+ result = MathRetrievalQA().forward(question)
102
+
103
+ if result["answer"]:
104
+ return result
105
+ else:
106
+ return WebFallbackQA().forward(question)
107
+
108
+
109
 
110
  # === Feedback Storage ===
111
+ def store_feedback(question, answer, feedback, correct_answer):
112
  entry = {
113
  "question": question,
114
  "model_answer": answer,
115
+ "feedback": feedback,
116
  "correct_answer": correct_answer,
117
  "timestamp": str(datetime.now())
118
  }
119
+ print("Storing feedback:", entry)
120
  with open("feedback.json", "a") as f:
121
  f.write(json.dumps(entry) + "\n")
122
 
123
+ def load_feedback_entries():
124
+ entries = []
125
+ try:
126
+ with open("feedback.json", "r") as f:
127
+ for line in f:
128
+ entry = json.loads(line)
129
+ entries.append(entry)
130
+ except FileNotFoundError:
131
+ pass
132
+ return entries
133
+
134
+
135
+ # === Gradio Functions ===
136
  # === Gradio Functions ===
137
  def ask_question(question):
138
+ print("ask_question() called with:", question)
139
+ result = router.forward(question)
140
+ print("Result:", result)
141
+ return result["answer"], question, result["answer"]
142
 
143
+
144
+
145
+ def submit_feedback(question, model_answer, feedback, correct_answer):
146
+ store_feedback(question, model_answer, feedback, correct_answer)
147
  return "โœ… Feedback received. Thank you!"
148
 
149
  # === Gradio UI ===
150
  with gr.Blocks() as demo:
151
+ gr.Markdown("## ๐Ÿงฎ Math Question Answering with DSPy + Feedback")
152
+
153
+ with gr.Tab("Ask a Math Question"):
154
+ with gr.Row():
155
+ question_input = gr.Textbox(label="Enter your math question", lines=2)
156
+ gr.Markdown("### ๐Ÿง  Answer:")
157
+ answer_output = gr.Markdown()
158
+
159
+ #answer_output = gr.Markdown(label="Answer")
160
+ hidden_q = gr.Textbox(visible=False)
161
+ hidden_a = gr.Textbox(visible=False)
162
  submit_btn = gr.Button("Get Answer")
163
+ submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])
164
 
165
+ with gr.Tab("Submit Feedback"):
166
+ gr.Markdown("### Was the answer helpful?")
167
+ fb_question = gr.Textbox(label="Original Question")
168
+ fb_answer = gr.Textbox(label="Model's Answer")
169
+ fb_like = gr.Radio(["๐Ÿ‘", "๐Ÿ‘Ž"], label="Your Feedback")
 
 
170
  fb_correct = gr.Textbox(label="Correct Answer (optional)")
171
+ fb_submit_btn = gr.Button("Submit Feedback")
172
  fb_status = gr.Textbox(label="Status", interactive=False)
173
 
174
+ feedback_display = gr.Dataframe(headers=["Question", "Answer", "Feedback", "Correct Answer", "Timestamp"],
175
+ row_count=10, max_rows=50, wrap=True)
176
+
177
+ def feedback_submission_and_display(question, answer, feedback, correct_answer):
178
+ store_feedback(question, answer, feedback, correct_answer)
179
+ entries = load_feedback_entries()
180
+ display_rows = [[
181
+ e["question"],
182
+ e["model_answer"],
183
+ e["feedback"],
184
+ e["correct_answer"],
185
+ e["timestamp"]
186
+ ] for e in entries]
187
+ return "โœ… Feedback received. Thank you!", display_rows
188
+
189
+ fb_submit_btn.click(
190
+ fn=feedback_submission_and_display,
191
+ inputs=[fb_question, fb_answer, fb_like, fb_correct],
192
+ outputs=[fb_status, feedback_display]
193
  )
194
 
195
+
196
+ demo.launch(share=True, debug=True)