manasagangotri committed on
Commit 5eea801 · verified · 1 Parent(s): af77c21

Update app.py

Files changed (1):
  1. app.py +41 -60
app.py CHANGED
@@ -2,13 +2,11 @@
 import gradio as gr
 import torch
 import requests
-import re
-from datetime import datetime
 import json
-
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from sentence_transformers import SentenceTransformer
 from qdrant_client import QdrantClient
+from datetime import datetime
 
 # === Load Models ===
 print("Loading zero-shot classifier...")
@@ -17,43 +15,32 @@ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnl
 print("Loading embedding model...")
 embedding_model = SentenceTransformer("intfloat/e5-large")
 
-print("Loading WizardMath model...")
-tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardMath-7B-V1.1")
-model = AutoModelForCausalLM.from_pretrained(
-    "WizardLM/WizardMath-7B-V1.1", torch_dtype=torch.float16, device_map="auto"
-)
+print("Loading step-by-step generator...")
+tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
+model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
 
 # === Qdrant Setup ===
 print("Connecting to Qdrant...")
 qdrant_client = QdrantClient(path="qdrant_data")
 collection_name = "math_problems"
 
-# === Guard Functions ===
+# === Guard Function ===
 def is_valid_math_question(text):
     candidate_labels = ["math", "not math"]
     result = classifier(text, candidate_labels)
     return result['labels'][0] == "math" and result['scores'][0] > 0.7
 
-def output_guardrails(answer):
-    if not answer or len(answer.strip()) < 10:
-        return False
-    math_keywords = ["solve", "equation", "integral", "derivative", "value", "expression", "steps", "solution"]
-    if not any(word in answer.lower() for word in math_keywords):
-        return False
-    banned_keywords = ["kill", "bomb", "hate", "politics", "violence"]
-    if any(word in answer.lower() for word in banned_keywords):
-        return False
-    if re.match(r"^\s*I'm just a model|Sorry, I can't|As an AI", answer, re.IGNORECASE):
-        return False
-    return True
-
-# === Retrieval ===
+# === Retrieval from Qdrant ===
 def retrieve_from_qdrant(query):
     query_vector = embedding_model.encode(query).tolist()
-    hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
+    hits = qdrant_client.query_points(
+        collection_name=collection_name,
+        query_vector=query_vector,
+        limit=3
+    )
     return [hit.payload for hit in hits] if hits else []
 
-# === Web Search ===
+# === Web Search Fallback ===
 def web_search_tavily(query):
     TAVILY_API_KEY = "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"
     response = requests.post(
@@ -62,14 +49,10 @@ def web_search_tavily(query):
     )
     return response.json().get("answer", "No answer found from Tavily.")
 
-# === Answer Generation ===
-def generate_step_by_step_answer(question, context=""):
-    prompt = f"### Question:\n{question}\n"
-    if context:
-        prompt += f"### Context:\n{context}\n"
-    prompt += "### Let's solve it step by step:\n"
-
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+# === Generator ===
+def generate_step_by_step_answer(question, context):
+    prompt = f"Answer the following math question step-by-step:\nQuestion: {question}\nContext: {context}\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt")
     outputs = model.generate(
         **inputs,
         max_new_tokens=256,
@@ -78,67 +61,65 @@ def generate_step_by_step_answer(question, context=""):
         do_sample=True,
         pad_token_id=tokenizer.eos_token_id
     )
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    answer = decoded.split("### Let's solve it step by step:")[-1].strip()
-    return answer
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 # === Router ===
 def router(question):
     if not is_valid_math_question(question):
-        return "❌ Only math questions are accepted. Please rephrase."
-
-    context_items = retrieve_from_qdrant(question)
-    context = "\n".join([item.get("solution", "") for item in context_items])
+        return "❌ Only math questions are accepted. Please rephrase.", ""
 
+    retrieved = retrieve_from_qdrant(question)
+    context = "\n".join([item["solution"] for item in retrieved if "solution" in item])
     if context:
         answer = generate_step_by_step_answer(question, context)
-        if output_guardrails(answer):
-            return answer
-
-    answer = web_search_tavily(question)
-    return answer if output_guardrails(answer) else "⚠️ No valid math answer found."
+        return answer, context
+    else:
+        fallback = web_search_tavily(question)
+        return fallback, "Tavily Search"
 
 # === Feedback Storage ===
-def store_feedback(question, answer, feedback, correct_answer):
+def store_feedback(question, answer, correct_answer):
     entry = {
         "question": question,
         "model_answer": answer,
-        "feedback": feedback,
         "correct_answer": correct_answer,
         "timestamp": str(datetime.now())
     }
     with open("feedback.json", "a") as f:
         f.write(json.dumps(entry) + "\n")
 
-# === Gradio UI ===
+# === Gradio Functions ===
 def ask_question(question):
-    answer = router(question)
+    answer, context = router(question)
    return answer, question, answer
 
-def submit_feedback(question, model_answer, feedback):
-    store_feedback(question, model_answer, feedback, "")
+def submit_feedback(question, model_answer, correct_answer):
+    store_feedback(question, model_answer, correct_answer)
     return "✅ Feedback received. Thank you!"
 
+# === Gradio UI ===
 with gr.Blocks() as demo:
-    gr.Markdown("## 🧮 Math Tutor with AI Guardrails + Feedback")
+    gr.Markdown("## 🧮 Math Question Answering with Retrieval + Feedback")
 
     with gr.Row():
         question_input = gr.Textbox(label="Enter your math question", lines=2)
         submit_btn = gr.Button("Get Answer")
 
-    answer_output = gr.Markdown()
+    answer_output = gr.Markdown(label="Answer")
     hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)
 
     submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])
 
-    gr.Markdown("### 📝 Feedback")
-    fb_like = gr.Radio(["👍", "👎"], label="Was this answer helpful?")
-    fb_submit_btn = gr.Button("Submit Feedback")
+    gr.Markdown("### 📝 Submit Feedback")
+    fb_correct = gr.Textbox(label="Correct Answer (optional)")
+    fb_submit = gr.Button("Submit Feedback")
     fb_status = gr.Textbox(label="Status", interactive=False)
 
-    fb_submit_btn.click(fn=submit_feedback,
-                        inputs=[hidden_q, hidden_a, fb_like],
-                        outputs=[fb_status])
+    fb_submit.click(
+        fn=submit_feedback,
+        inputs=[hidden_q, hidden_a, fb_correct],
+        outputs=[fb_status]
+    )
 
-demo.launch(share=True, debug=True)
+demo.launch(share=True)
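A few notes on the new code. First, the retrieval change: in the qdrant-client releases that provide query_points() (1.10 and later, where search() is deprecated), the vector is passed via the query parameter, not query_vector, and the method returns a QueryResponse whose hits live in its .points attribute rather than being iterable directly, so the call as committed would raise a TypeError. A minimal sketch of the likely intent, reusing the app's existing globals:

def retrieve_from_qdrant(query):
    query_vector = embedding_model.encode(query).tolist()
    # query_points() takes `query=`, not `query_vector=`, and wraps
    # its hits in a QueryResponse object.
    response = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=3,
    )
    # Each hit in response.points is a ScoredPoint carrying .payload.
    return [hit.payload for hit in response.points]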
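Second, the generator now returns the full decoded sequence. The removed WizardMath path split the decoded text on its prompt marker; the phi-2 path decodes outputs[0] directly, which for a causal LM includes the prompt, so every answer will open by echoing the question and context. A common fix, sketched here rather than taken from the commit, is to decode only the newly generated tokens:

# Inside generate_step_by_step_answer(), after model.generate():
# slice off the prompt tokens so the answer is not prefixed with them.
prompt_len = inputs["input_ids"].shape[1]
return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)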
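Third, the model swap also drops torch_dtype=torch.float16 and device_map="auto", so phi-2 now loads in float32 on CPU. That is at least self-consistent, since the tokenized inputs are no longer moved to "cuda"; but if the Space has a GPU, a sketch like the following (device_map requires the accelerate package) would restore the old placement:

# Sketch: optional GPU placement for phi-2, mirroring the removed setup.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.float16,  # halves memory versus the float32 default
    device_map="auto",          # requires accelerate to be installed
)
# ...and in generate_step_by_step_answer():
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)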
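Finally, the Tavily API key is hard-coded in web_search_tavily() in both the old and new versions. Once pushed to a public repository it should be treated as leaked and rotated; the usual pattern is to read it from the environment (on Hugging Face Spaces, a repository secret), for example:

import os

# Sketch: TAVILY_API_KEY is assumed to be set as a Space secret /
# environment variable instead of living in source control.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY", "")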