manasagangotri committed on
Commit
9ec24d8
·
verified ·
1 Parent(s): a1b29fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -1,5 +1,5 @@
1
- import gradio as gr
2
 
 
3
  import torch
4
  import requests
5
  from transformers import pipeline
@@ -10,11 +10,18 @@ import dspy
10
  import json
11
 
12
  # === Load Models ===
 
13
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 
14
  embedding_model = SentenceTransformer("intfloat/e5-large")
15
- qa_pipeline = pipeline("text-generation", model="WizardLM/WizardMath-7B-V1.0", device_map="auto", torch_dtype=torch.float16)
 
 
 
16
 
17
  # === Qdrant Setup ===
 
18
  qdrant_client = QdrantClient(path="qdrant_data")
19
  collection_name = "math_problems"
20
 
@@ -22,16 +29,20 @@ collection_name = "math_problems"
22
  def is_valid_math_question(text):
23
  candidate_labels = ["math", "not math"]
24
  result = classifier(text, candidate_labels)
 
25
  return result['labels'][0] == "math" and result['scores'][0] > 0.7
26
 
27
  # === Retrieval ===
28
  def retrieve_from_qdrant(query):
 
29
  query_vector = embedding_model.encode(query).tolist()
30
  hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
 
31
  return [hit.payload for hit in hits] if hits else []
32
 
33
  # === Web Search ===
34
  def web_search_tavily(query):
 
35
  TAVILY_API_KEY = "your_tavily_api_key"
36
  response = requests.post(
37
  "https://api.tavily.com/search",
@@ -48,21 +59,27 @@ class MathAnswer(dspy.Signature):
48
  # === DSPy Programs ===
49
  class MathRetrievalQA(dspy.Program):
50
  def forward(self, question):
 
51
  context_items = retrieve_from_qdrant(question)
52
  context = "\n".join([item["solution"] for item in context_items if "solution" in item])
 
53
  if not context:
54
  return dspy.Output(answer="", retrieved_context="")
55
  prompt = f"Question: {question}\nContext: {context}\nAnswer:"
56
- answer = qa_pipeline(prompt, max_new_tokens=512)[0]["generated_text"]
 
 
57
  return dspy.Output(answer=answer, retrieved_context=context)
58
 
59
  class WebFallbackQA(dspy.Program):
60
  def forward(self, question):
 
61
  answer = web_search_tavily(question)
62
  return dspy.Output(answer=answer, retrieved_context="Tavily")
63
 
64
  class MathRouter(dspy.Program):
65
  def forward(self, question):
 
66
  if not is_valid_math_question(question):
67
  return dspy.Output(answer="❌ Only math questions are accepted. Please rephrase.", retrieved_context="")
68
  result = MathRetrievalQA().forward(question)
@@ -79,12 +96,15 @@ def store_feedback(question, answer, feedback, correct_answer):
79
  "correct_answer": correct_answer,
80
  "timestamp": str(datetime.now())
81
  }
 
82
  with open("feedback.json", "a") as f:
83
  f.write(json.dumps(entry) + "\n")
84
 
85
  # === Gradio Functions ===
86
  def ask_question(question):
 
87
  result = router.forward(question)
 
88
  return result.answer, question, result.answer
89
 
90
  def submit_feedback(question, model_answer, feedback, correct_answer):
@@ -116,4 +136,4 @@ with gr.Blocks() as demo:
116
  inputs=[fb_question, fb_answer, fb_like, fb_correct],
117
  outputs=[fb_status])
118
 
119
- demo.launch()
 
 
1
 
2
+ import gradio as gr
3
  import torch
4
  import requests
5
  from transformers import pipeline
 
10
  import json
11
 
12
  # === Load Models ===
13
+ print("Loading zero-shot classifier...")
14
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
15
+
16
+ print("Loading embedding model...")
17
  embedding_model = SentenceTransformer("intfloat/e5-large")
18
+
19
+ print("Loading text generation model...")
20
+ # Use a lighter model for testing
21
+ qa_pipeline = pipeline("text-generation", model="gpt2")
22
 
23
  # === Qdrant Setup ===
24
+ print("Connecting to Qdrant...")
25
  qdrant_client = QdrantClient(path="qdrant_data")
26
  collection_name = "math_problems"
27
 
 
29
  def is_valid_math_question(text):
30
  candidate_labels = ["math", "not math"]
31
  result = classifier(text, candidate_labels)
32
+ print("Classifier result:", result)
33
  return result['labels'][0] == "math" and result['scores'][0] > 0.7
34
 
35
  # === Retrieval ===
36
  def retrieve_from_qdrant(query):
37
+ print("Retrieving context from Qdrant...")
38
  query_vector = embedding_model.encode(query).tolist()
39
  hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
40
+ print("Retrieved hits:", hits)
41
  return [hit.payload for hit in hits] if hits else []
42
 
43
  # === Web Search ===
44
  def web_search_tavily(query):
45
+ print("Calling Tavily...")
46
  TAVILY_API_KEY = "your_tavily_api_key"
47
  response = requests.post(
48
  "https://api.tavily.com/search",
 
59
  # === DSPy Programs ===
60
  class MathRetrievalQA(dspy.Program):
61
  def forward(self, question):
62
+ print("Inside MathRetrievalQA...")
63
  context_items = retrieve_from_qdrant(question)
64
  context = "\n".join([item["solution"] for item in context_items if "solution" in item])
65
+ print("Context for generation:", context)
66
  if not context:
67
  return dspy.Output(answer="", retrieved_context="")
68
  prompt = f"Question: {question}\nContext: {context}\nAnswer:"
69
+ print("Generating answer...")
70
+ answer = qa_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
71
+ print("Generated answer:", answer)
72
  return dspy.Output(answer=answer, retrieved_context=context)
73
 
74
  class WebFallbackQA(dspy.Program):
75
  def forward(self, question):
76
+ print("Fallback to Tavily...")
77
  answer = web_search_tavily(question)
78
  return dspy.Output(answer=answer, retrieved_context="Tavily")
79
 
80
  class MathRouter(dspy.Program):
81
  def forward(self, question):
82
+ print("Routing question:", question)
83
  if not is_valid_math_question(question):
84
  return dspy.Output(answer="❌ Only math questions are accepted. Please rephrase.", retrieved_context="")
85
  result = MathRetrievalQA().forward(question)
 
96
  "correct_answer": correct_answer,
97
  "timestamp": str(datetime.now())
98
  }
99
+ print("Storing feedback:", entry)
100
  with open("feedback.json", "a") as f:
101
  f.write(json.dumps(entry) + "\n")
102
 
103
  # === Gradio Functions ===
104
  def ask_question(question):
105
+ print("ask_question() called with:", question)
106
  result = router.forward(question)
107
+ print("Result:", result)
108
  return result.answer, question, result.answer
109
 
110
  def submit_feedback(question, model_answer, feedback, correct_answer):
 
136
  inputs=[fb_question, fb_answer, fb_like, fb_correct],
137
  outputs=[fb_status])
138
 
139
+ demo.launch(share=True, debug=True)