Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,11 +2,12 @@
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
import requests
|
5 |
-
import
|
6 |
-
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
7 |
from sentence_transformers import SentenceTransformer
|
8 |
from qdrant_client import QdrantClient
|
9 |
from datetime import datetime
|
|
|
|
|
10 |
|
11 |
# === Load Models ===
|
12 |
print("Loading zero-shot classifier...")
|
@@ -15,9 +16,9 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
|
|
15 |
print("Loading embedding model...")
|
16 |
embedding_model = SentenceTransformer("intfloat/e5-large")
|
17 |
|
18 |
-
print("Loading
|
19 |
-
|
20 |
-
|
21 |
|
22 |
# === Qdrant Setup ===
|
23 |
print("Connecting to Qdrant...")
|
@@ -28,20 +29,20 @@ collection_name = "math_problems"
|
|
28 |
def is_valid_math_question(text):
|
29 |
candidate_labels = ["math", "not math"]
|
30 |
result = classifier(text, candidate_labels)
|
|
|
31 |
return result['labels'][0] == "math" and result['scores'][0] > 0.7
|
32 |
|
33 |
-
# === Retrieval
|
34 |
def retrieve_from_qdrant(query):
|
|
|
35 |
query_vector = embedding_model.encode(query).tolist()
|
36 |
-
hits = qdrant_client.
|
37 |
-
|
38 |
-
query_vector=query_vector,
|
39 |
-
limit=3
|
40 |
-
)
|
41 |
return [hit.payload for hit in hits] if hits else []
|
42 |
|
43 |
-
# === Web Search
|
44 |
def web_search_tavily(query):
|
|
|
45 |
TAVILY_API_KEY = "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"
|
46 |
response = requests.post(
|
47 |
"https://api.tavily.com/search",
|
@@ -49,77 +50,147 @@ def web_search_tavily(query):
|
|
49 |
)
|
50 |
return response.json().get("answer", "No answer found from Tavily.")
|
51 |
|
52 |
-
# ===
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
answer
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
# === Feedback Storage ===
|
81 |
-
def store_feedback(question, answer, correct_answer):
|
82 |
entry = {
|
83 |
"question": question,
|
84 |
"model_answer": answer,
|
|
|
85 |
"correct_answer": correct_answer,
|
86 |
"timestamp": str(datetime.now())
|
87 |
}
|
|
|
88 |
with open("feedback.json", "a") as f:
|
89 |
f.write(json.dumps(entry) + "\n")
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
# === Gradio Functions ===
|
92 |
def ask_question(question):
|
93 |
-
|
94 |
-
|
|
|
|
|
95 |
|
96 |
-
|
97 |
-
|
|
|
|
|
98 |
return "โ
Feedback received. Thank you!"
|
99 |
|
100 |
# === Gradio UI ===
|
101 |
with gr.Blocks() as demo:
|
102 |
-
gr.Markdown("## ๐งฎ Math Question Answering with
|
103 |
-
|
104 |
-
with gr.
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
submit_btn = gr.Button("Get Answer")
|
|
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
gr.Markdown("### ๐ Submit Feedback")
|
115 |
fb_correct = gr.Textbox(label="Correct Answer (optional)")
|
116 |
-
|
117 |
fb_status = gr.Textbox(label="Status", interactive=False)
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
)
|
124 |
|
125 |
-
|
|
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
import requests
|
5 |
+
from transformers import pipeline
|
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from qdrant_client import QdrantClient
|
8 |
from datetime import datetime
|
9 |
+
import dspy
|
10 |
+
import json
|
11 |
|
12 |
# === Load Models ===
|
13 |
print("Loading zero-shot classifier...")
|
|
|
16 |
print("Loading embedding model...")
|
17 |
embedding_model = SentenceTransformer("intfloat/e5-large")
|
18 |
|
19 |
+
print("Loading text generation model...")
|
20 |
+
# Use a lighter model for testing
|
21 |
+
#qa_pipeline = pipeline("text-generation", model="gpt2")
|
22 |
|
23 |
# === Qdrant Setup ===
|
24 |
print("Connecting to Qdrant...")
|
|
|
29 |
def is_valid_math_question(text, threshold=0.7):
    """Return True when the zero-shot classifier labels ``text`` as math.

    Args:
        text: The user's question.
        threshold: Minimum score the first returned label must exceed.
            Defaults to 0.7, the value previously hard-coded here.

    Returns:
        True only if the first label returned by the module-level
        ``classifier`` is "math" and its score is above ``threshold``.
        Blank or non-string input is rejected without calling the model.
    """
    # An empty Gradio textbox yields ""; there is nothing to classify, so
    # short-circuit instead of handing the pipeline an empty sequence.
    if not isinstance(text, str) or not text.strip():
        return False

    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    return result['labels'][0] == "math" and result['scores'][0] > threshold
|
34 |
|
35 |
+
# === Retrieval ===
|
36 |
def retrieve_from_qdrant(query, limit=3):
    """Return the payloads of the nearest stored points for ``query``.

    Args:
        query: Question text to embed with the module-level
            ``embedding_model``.
        limit: Maximum number of hits to request (defaults to 3, the value
            previously hard-coded here).

    Returns:
        A list of hit payloads; empty when the query is blank or nothing
        is found. Hits that carry no payload are dropped so callers can
        safely treat every element as a mapping.
    """
    if not isinstance(query, str) or not query.strip():
        return []

    print("Retrieving context from Qdrant...")
    query_vector = embedding_model.encode(query).tolist()
    # NOTE(review): newer qdrant-client releases steer callers from
    # `search` towards `query_points` — confirm against the pinned version.
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    print("Retrieved hits:", hits)
    return [hit.payload for hit in (hits or []) if hit.payload]
|
42 |
|
43 |
+
# === Web Search ===
|
44 |
def web_search_tavily(query):
|
45 |
+
print("Calling Tavily...")
|
46 |
TAVILY_API_KEY = "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"
|
47 |
response = requests.post(
|
48 |
"https://api.tavily.com/search",
|
|
|
50 |
)
|
51 |
return response.json().get("answer", "No answer found from Tavily.")
|
52 |
|
53 |
+
# === DSPy Signature ===
|
54 |
+
class MathAnswer(dspy.Signature):
    # Declares two inputs (the question and the retrieved context) and one
    # output (the answer).
    # NOTE: intentionally no docstring — DSPy uses a Signature's docstring
    # as the task instructions, so adding one would change prompting.
    # NOTE(review): no program visible in this file references this
    # signature yet — confirm whether it is still needed.
    question = dspy.InputField()
    retrieved_context = dspy.InputField()
    answer = dspy.OutputField()
|
58 |
+
|
59 |
+
# === DSPy Programs ===
|
60 |
+
# === DSPy Programs with Output Guard ===
|
61 |
+
class MathRetrievalQA(dspy.Program):
    """Answer a question from solutions retrieved out of Qdrant.

    ``forward`` returns a dict with keys ``"answer"`` and
    ``"retrieved_context"``. An empty ``"answer"`` tells the caller that
    retrieval (or the output guard) did not produce a usable result.
    """

    def forward(self, question):
        """Build a context from retrieved solutions and produce an answer.

        Args:
            question: The user's question text.

        Returns:
            ``{"answer": str, "retrieved_context": str}``.
        """
        print("Inside MathRetrievalQA...")
        context_items = retrieve_from_qdrant(question)
        # Keep only payloads that are mappings with a non-empty "solution";
        # coerce to str so a numeric solution cannot break the join.
        solutions = [
            str(item["solution"])
            for item in context_items
            if isinstance(item, dict) and item.get("solution")
        ]
        context = "\n".join(solutions)
        print("Context for generation:", context)

        if not context:
            return {"answer": "", "retrieved_context": ""}

        # === Replace below with real model call when ready ===
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        print("Prompt for generation:", prompt)

        # TODO: placeholder until a generation model is wired in; the
        # prompt above is only logged, not sent anywhere.
        generated_answer = "This is a placeholder answer based on the context."
        print("Generated answer:", generated_answer)

        # === Output Guard ===
        # Reject empty, very short, or explicitly non-committal output.
        if (
            not generated_answer
            or len(generated_answer.strip()) < 10
            or "I don't know" in generated_answer
        ):
            return {"answer": "", "retrieved_context": context}

        return {"answer": generated_answer.strip(), "retrieved_context": context}
|
84 |
+
|
85 |
+
|
86 |
+
class WebFallbackQA(dspy.Program):
    """Answer a question through the Tavily web-search helper.

    ``forward`` returns a dict with keys ``"answer"`` and
    ``"retrieved_context"`` (always the literal ``"Tavily"``).
    """

    def forward(self, question):
        """Query Tavily and substitute an apology when nothing usable returns.

        Args:
            question: The user's question text.

        Returns:
            ``{"answer": str, "retrieved_context": "Tavily"}``.
        """
        print("Fallback to Tavily...")
        answer = web_search_tavily(question)
        # A missing, non-text, very short, or "No answer found" reply is
        # treated as no answer; the isinstance check also protects the
        # .strip() calls from a null/non-string JSON value.
        if (
            not isinstance(answer, str)
            or len(answer.strip()) < 10
            or "No answer found" in answer
        ):
            # NOTE(review): the leading symbol was mis-encoded in the
            # reviewed copy; a cross mark is assumed — confirm.
            answer = "❌ Sorry, I couldn't find a reliable answer."
        return {"answer": answer.strip(), "retrieved_context": "Tavily"}
|
93 |
+
|
94 |
+
|
95 |
+
class MathRouter(dspy.Program):
    """Route a question: validate it, try retrieval, then fall back to web.

    ``forward`` returns a dict with keys ``"answer"`` and
    ``"retrieved_context"``.
    """

    def __init__(self):
        super().__init__()
        # Build the sub-programs once instead of on every question.
        self.retrieval_qa = MathRetrievalQA()
        self.web_fallback = WebFallbackQA()

    def forward(self, question):
        """Answer ``question`` or explain why it was rejected.

        Args:
            question: The user's question text.

        Returns:
            The retrieval result when it contains a non-empty answer,
            otherwise the web-fallback result. Non-math questions get a
            fixed rejection message and an empty context.
        """
        print("Routing question:", question)
        if not is_valid_math_question(question):
            # NOTE(review): the leading symbol was mis-encoded in the
            # reviewed copy; a cross mark is assumed — confirm.
            return {
                "answer": "❌ Only math questions are accepted. Please rephrase.",
                "retrieved_context": "",
            }

        result = self.retrieval_qa.forward(question)
        if result["answer"]:
            return result
        return self.web_fallback.forward(question)
|
107 |
+
|
108 |
+
|
109 |
|
110 |
# === Feedback Storage ===
|
111 |
+
def store_feedback(question, answer, feedback, correct_answer, path="feedback.json"):
    """Append one feedback record to a JSON-lines file.

    Args:
        question: The question that was asked.
        answer: The answer the model produced.
        feedback: The user's rating of that answer.
        correct_answer: Optional correction supplied by the user.
        path: File to append to. Defaults to "feedback.json", the location
            previously hard-coded here.

    Each call writes a single line containing a JSON object with the keys
    "question", "model_answer", "feedback", "correct_answer" and
    "timestamp" (the local ``datetime.now()`` rendered with ``str``).
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),
    }
    print("Storing feedback:", entry)
    # Explicit encoding so the file reads back the same on every platform.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
|
122 |
|
123 |
+
def load_feedback_entries(path="feedback.json"):
    """Read every feedback record from a JSON-lines file.

    Args:
        path: File to read. Defaults to "feedback.json", the location
            previously hard-coded here.

    Returns:
        A list of dicts, one per well-formed line, in file order. A
        missing file yields an empty list. Blank lines, lines that are not
        valid JSON, and JSON values that are not objects are skipped so a
        single bad line cannot break the feedback display.
    """
    entries = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    print("Skipping malformed feedback line:", line)
                    continue
                if isinstance(entry, dict):
                    entries.append(entry)
    except FileNotFoundError:
        # No feedback has been stored yet.
        pass
    return entries
|
133 |
+
|
134 |
+
|
135 |
+
# === Gradio Functions ===
|
136 |
# === Gradio Functions ===
|
137 |
def ask_question(question):
    """Gradio callback: route ``question`` and return values for the UI.

    Args:
        question: Text from the question textbox.

    Returns:
        A 3-tuple ``(answer, question, answer)``: the answer for display,
        followed by the question and answer again for the two hidden
        fields wired to this callback.
    """
    print("ask_question() called with:", question)
    # The file never defines a module-level `router`, so the previous
    # `router.forward(...)` raised NameError; build the router here.
    result = MathRouter().forward(question)
    print("Result:", result)
    return result["answer"], question, result["answer"]
|
142 |
|
143 |
+
|
144 |
+
|
145 |
+
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Persist one feedback record and return a confirmation message.

    Args:
        question: The question that was asked.
        model_answer: The answer the model produced.
        feedback: The user's rating of that answer.
        correct_answer: Optional correction supplied by the user.

    Returns:
        The status string shown to the user.
    """
    store_feedback(question, model_answer, feedback, correct_answer)
    # The check-mark prefix was mis-decoded into a stray character plus a
    # line break, which split the literal across two lines; restored here.
    return "✅ Feedback received. Thank you!"
|
148 |
|
149 |
# === Gradio UI ===
|
150 |
# Two-tab Gradio app: one tab asks a question, the other records feedback
# and shows everything stored so far.
# NOTE(review): indentation was lost in the reviewed copy, so the nesting
# of components inside the Row/Tab containers is reconstructed — confirm
# the layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
        gr.Markdown("### 🧠 Answer:")
        answer_output = gr.Markdown()

        # Invisible fields that receive the last question/answer pair.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")
        submit_btn.click(
            fn=ask_question,
            inputs=[question_input],
            outputs=[answer_output, hidden_q, hidden_a],
        )

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        # NOTE(review): both choices were mis-encoded to the same visible
        # character in the reviewed copy; thumbs up/down is assumed —
        # confirm.
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)

        # NOTE(review): `max_rows` is reportedly rejected by newer Gradio
        # releases — confirm against the pinned Gradio version.
        feedback_display = gr.Dataframe(
            headers=["Question", "Answer", "Feedback", "Correct Answer", "Timestamp"],
            row_count=10,
            max_rows=50,
            wrap=True,
        )

        def feedback_submission_and_display(question, answer, feedback, correct_answer):
            """Store one feedback record and return (status, all rows)."""
            store_feedback(question, answer, feedback, correct_answer)
            entries = load_feedback_entries()
            # .get() keeps records written before the "feedback" field
            # existed (or otherwise incomplete) from raising KeyError.
            display_rows = [
                [
                    e.get("question", ""),
                    e.get("model_answer", ""),
                    e.get("feedback", ""),
                    e.get("correct_answer", ""),
                    e.get("timestamp", ""),
                ]
                for e in entries
            ]
            return "✅ Feedback received. Thank you!", display_rows

        fb_submit_btn.click(
            fn=feedback_submission_and_display,
            inputs=[fb_question, fb_answer, fb_like, fb_correct],
            outputs=[fb_status, feedback_display],
        )


demo.launch(share=True, debug=True)
|