Update app.py
app.py CHANGED
@@ -1,5 +1,5 @@
-import gradio as gr
 
+import gradio as gr
 import torch
 import requests
 from transformers import pipeline
@@ -10,11 +10,18 @@ import dspy
 import json
 
 # === Load Models ===
+print("Loading zero-shot classifier...")
 classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+print("Loading embedding model...")
 embedding_model = SentenceTransformer("intfloat/e5-large")
-
+
+print("Loading text generation model...")
+# Use a lighter model for testing
+qa_pipeline = pipeline("text-generation", model="gpt2")
 
 # === Qdrant Setup ===
+print("Connecting to Qdrant...")
 qdrant_client = QdrantClient(path="qdrant_data")
 collection_name = "math_problems"
 
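A note on the embedding model: the intfloat/e5 family is trained with instruction prefixes, so retrieval usually works better when queries and indexed passages are encoded with "query: " and "passage: " markers. A minimal sketch of that convention (the helper names are illustrative, not part of this commit):

from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("intfloat/e5-large")  # same model the app loads

def embed_query(text):
    # e5 models expect a "query: " prefix on search queries
    return embedding_model.encode("query: " + text).tolist()

def embed_passage(text):
    # ...and a "passage: " prefix on documents being indexed
    return embedding_model.encode("passage: " + text).tolist()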
@@ -22,16 +29,20 @@ collection_name = "math_problems"
 def is_valid_math_question(text):
     candidate_labels = ["math", "not math"]
     result = classifier(text, candidate_labels)
+    print("Classifier result:", result)
     return result['labels'][0] == "math" and result['scores'][0] > 0.7
 
 # === Retrieval ===
 def retrieve_from_qdrant(query):
+    print("Retrieving context from Qdrant...")
     query_vector = embedding_model.encode(query).tolist()
     hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
+    print("Retrieved hits:", hits)
     return [hit.payload for hit in hits] if hits else []
 
 # === Web Search ===
 def web_search_tavily(query):
+    print("Calling Tavily...")
     TAVILY_API_KEY = "your_tavily_api_key"
     response = requests.post(
         "https://api.tavily.com/search",
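retrieve_from_qdrant assumes the math_problems collection already exists and that each point's payload carries a "solution" field; with a fresh qdrant_data directory, search returns nothing and the retrieval program yields an empty answer. A minimal seeding sketch under those assumptions (e5-large embeddings are 1024-dimensional; the sample point is illustrative):

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("intfloat/e5-large")
client = QdrantClient(path="qdrant_data")

# Create the collection sized to the embedding dimension
client.recreate_collection(
    collection_name="math_problems",
    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)

# Insert one problem/solution pair so search has something to hit
client.upsert(
    collection_name="math_problems",
    points=[PointStruct(
        id=1,
        vector=embedding_model.encode("passage: Solve x + 2 = 5.").tolist(),
        payload={"problem": "Solve x + 2 = 5.", "solution": "x = 3"},
    )],
)

Also, the Tavily key is left as a hardcoded placeholder; reading it from the environment (os.environ["TAVILY_API_KEY"]) would keep the real key out of the Space's source.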
@@ -48,21 +59,27 @@ class MathAnswer(dspy.Signature):
 # === DSPy Programs ===
 class MathRetrievalQA(dspy.Program):
     def forward(self, question):
+        print("Inside MathRetrievalQA...")
         context_items = retrieve_from_qdrant(question)
         context = "\n".join([item["solution"] for item in context_items if "solution" in item])
+        print("Context for generation:", context)
         if not context:
             return dspy.Output(answer="", retrieved_context="")
         prompt = f"Question: {question}\nContext: {context}\nAnswer:"
-
+        print("Generating answer...")
+        answer = qa_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
+        print("Generated answer:", answer)
         return dspy.Output(answer=answer, retrieved_context=context)
 
 class WebFallbackQA(dspy.Program):
     def forward(self, question):
+        print("Fallback to Tavily...")
         answer = web_search_tavily(question)
         return dspy.Output(answer=answer, retrieved_context="Tavily")
 
 class MathRouter(dspy.Program):
     def forward(self, question):
+        print("Routing question:", question)
         if not is_valid_math_question(question):
             return dspy.Output(answer="❌ Only math questions are accepted. Please rephrase.", retrieved_context="")
         result = MathRetrievalQA().forward(question)
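Two caveats on the new generation path. First, the transformers text-generation pipeline includes the prompt in "generated_text" by default, so the displayed answer will begin with the entire prompt; passing return_full_text=False returns only the continuation:

# Hypothetical variant of line 70: keep only the newly generated tokens
answer = qa_pipeline(prompt, max_new_tokens=100, return_full_text=False)[0]["generated_text"]

Second, dspy.Program and dspy.Output are not names recent DSPy releases export (dspy.Module and dspy.Prediction are the usual counterparts), so this code likely depends on whichever dspy version the Space pins.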
@@ -79,12 +96,15 @@ def store_feedback(question, answer, feedback, correct_answer):
         "correct_answer": correct_answer,
         "timestamp": str(datetime.now())
     }
+    print("Storing feedback:", entry)
     with open("feedback.json", "a") as f:
         f.write(json.dumps(entry) + "\n")
 
 # === Gradio Functions ===
 def ask_question(question):
+    print("ask_question() called with:", question)
     result = router.forward(question)
+    print("Result:", result)
     return result.answer, question, result.answer
 
 def submit_feedback(question, model_answer, feedback, correct_answer):
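Since store_feedback appends one JSON object per line (JSON Lines), the log is easy to load back for review. A small reader sketch, assuming the format written above:

import json

def load_feedback(path="feedback.json"):
    # One JSON object per line, as appended by store_feedback
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]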
@@ -116,4 +136,4 @@ with gr.Blocks() as demo:
         inputs=[fb_question, fb_answer, fb_like, fb_correct],
         outputs=[fb_status])
 
-demo.launch()
+demo.launch(share=True, debug=True)
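On the final change: share=True asks Gradio for a temporary public tunnel URL, which matters when running locally (a hosted Space already has a public URL), and debug=True keeps launch() blocking and surfaces worker errors in the console, matching the print-based debugging added throughout this commit.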