manasagangotri committed on
Commit af77c21 · verified · 1 Parent(s): a6a2ff2

Update app.py

Files changed (1)
  1. app.py +73 -51
app.py CHANGED
@@ -2,13 +2,14 @@
  import gradio as gr
  import torch
  import requests
- from transformers import pipeline
- from sentence_transformers import SentenceTransformer
- from qdrant_client import QdrantClient
+ import re
  from datetime import datetime
- import dspy
  import json

+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+ from qdrant_client import QdrantClient
+
  # === Load Models ===
  print("Loading zero-shot classifier...")
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
@@ -16,17 +17,36 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
  print("Loading embedding model...")
  embedding_model = SentenceTransformer("intfloat/e5-large")

+ print("Loading WizardMath model...")
+ tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardMath-7B-V1.1")
+ model = AutoModelForCausalLM.from_pretrained(
+     "WizardLM/WizardMath-7B-V1.1", torch_dtype=torch.float16, device_map="auto"
+ )
+
  # === Qdrant Setup ===
  print("Connecting to Qdrant...")
  qdrant_client = QdrantClient(path="qdrant_data")
  collection_name = "math_problems"

- # === Guard Function ===
+ # === Guard Functions ===
  def is_valid_math_question(text):
      candidate_labels = ["math", "not math"]
      result = classifier(text, candidate_labels)
      return result['labels'][0] == "math" and result['scores'][0] > 0.7

+ def output_guardrails(answer):
+     if not answer or len(answer.strip()) < 10:
+         return False
+     math_keywords = ["solve", "equation", "integral", "derivative", "value", "expression", "steps", "solution"]
+     if not any(word in answer.lower() for word in math_keywords):
+         return False
+     banned_keywords = ["kill", "bomb", "hate", "politics", "violence"]
+     if any(word in answer.lower() for word in banned_keywords):
+         return False
+     if re.match(r"^\s*I'm just a model|Sorry, I can't|As an AI", answer, re.IGNORECASE):
+         return False
+     return True
+
  # === Retrieval ===
  def retrieve_from_qdrant(query):
      query_vector = embedding_model.encode(query).tolist()
@@ -42,35 +62,41 @@ def web_search_tavily(query):
      )
      return response.json().get("answer", "No answer found from Tavily.")

- # === DSPy Signature ===
- class MathAnswer(dspy.Signature):
-     question = dspy.InputField()
-     retrieved_context = dspy.InputField()
-     answer = dspy.OutputField()
-
- # === DSPy Programs ===
- class MathRetrievalQA(dspy.Program):
-     def forward(self, question):
-         context_items = retrieve_from_qdrant(question)
-         context = "\n".join([item["solution"] for item in context_items if "solution" in item])
-         if not context:
-             return {"answer": "", "retrieved_context": ""}
-         prompt = f"Question: {question}\nContext: {context}\nAnswer:"
-         return {"answer": prompt, "retrieved_context": context}
-
- class WebFallbackQA(dspy.Program):
-     def forward(self, question):
-         answer = web_search_tavily(question)
-         return {"answer": answer, "retrieved_context": "Tavily"}
-
- class MathRouter(dspy.Program):
-     def forward(self, question):
-         if not is_valid_math_question(question):
-             return {"answer": "❌ Only math questions are accepted. Please rephrase.", "retrieved_context": ""}
-         result = MathRetrievalQA().forward(question)
-         return result if result["answer"] else WebFallbackQA().forward(question)
-
- router = MathRouter()
+ # === Answer Generation ===
+ def generate_step_by_step_answer(question, context=""):
+     prompt = f"### Question:\n{question}\n"
+     if context:
+         prompt += f"### Context:\n{context}\n"
+     prompt += "### Let's solve it step by step:\n"
+
+     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=256,
+         temperature=0.7,
+         top_p=0.95,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id
+     )
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     answer = decoded.split("### Let's solve it step by step:")[-1].strip()
+     return answer
+
+ # === Router ===
+ def router(question):
+     if not is_valid_math_question(question):
+         return "❌ Only math questions are accepted. Please rephrase."
+
+     context_items = retrieve_from_qdrant(question)
+     context = "\n".join([item.get("solution", "") for item in context_items])
+
+     if context:
+         answer = generate_step_by_step_answer(question, context)
+         if output_guardrails(answer):
+             return answer
+
+     answer = web_search_tavily(question)
+     return answer if output_guardrails(answer) else "⚠️ No valid math answer found."

  # === Feedback Storage ===
  def store_feedback(question, answer, feedback, correct_answer):
@@ -84,39 +110,35 @@ def store_feedback(question, answer, feedback, correct_answer):
      with open("feedback.json", "a") as f:
          f.write(json.dumps(entry) + "\n")

- # === Gradio Functions ===
+ # === Gradio UI ===
  def ask_question(question):
-     result = router.forward(question)
-     return result["answer"], question, result["answer"]
+     answer = router(question)
+     return answer, question, answer

- def submit_feedback(question, model_answer, feedback, correct_answer):
-     store_feedback(question, model_answer, feedback, correct_answer)
+ def submit_feedback(question, model_answer, feedback):
+     store_feedback(question, model_answer, feedback, "")
      return "✅ Feedback received. Thank you!"

- # === Gradio UI ===
  with gr.Blocks() as demo:
-     gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")
+     gr.Markdown("## 🧮 Math Tutor with AI Guardrails + Feedback")

      with gr.Row():
          question_input = gr.Textbox(label="Enter your math question", lines=2)
+         submit_btn = gr.Button("Get Answer")

      answer_output = gr.Markdown()
      hidden_q = gr.Textbox(visible=False)
      hidden_a = gr.Textbox(visible=False)
-
-     submit_btn = gr.Button("Get Answer")
+
      submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])

-     # Feedback section on same page
-     gr.Markdown("### 💬 Give Feedback")
-     fb_correct = gr.Textbox(label="Correct Answer (optional)")
-     fb_like = gr.Radio(["👍", "👎"], label="Was the answer helpful?")
+     gr.Markdown("### 📝 Feedback")
+     fb_like = gr.Radio(["👍", "👎"], label="Was this answer helpful?")
      fb_submit_btn = gr.Button("Submit Feedback")
-     fb_status = gr.Markdown()
+     fb_status = gr.Textbox(label="Status", interactive=False)

      fb_submit_btn.click(fn=submit_feedback,
-                         inputs=[hidden_q, hidden_a, fb_like, fb_correct],
+                         inputs=[hidden_q, hidden_a, fb_like],
                          outputs=[fb_status])

- demo.launch(share=True, debug=True)
-
+ demo.launch(share=True, debug=True)