manasagangotri committed
Commit 13c672e · verified · 1 Parent(s): 564d0c6

Update app.py

Files changed (1):
  1. app.py +27 -13
app.py CHANGED
@@ -8,6 +8,13 @@ from qdrant_client import QdrantClient
 from datetime import datetime
 import dspy
 import json
+import google.generativeai as genai
+
+# Configure Gemini API
+genai.configure(api_key="AIzaSyBO3-HG-WcITn58PdpK7mMyvFQitoH00qA")  # Replace with your actual Gemini API key
+
+# Load Gemini model
+gemini_model = genai.GenerativeModel('gemini-pro')
 
 # === Load Models ===
 print("Loading zero-shot classifier...")
@@ -20,12 +27,6 @@ print("Loading text generation model...")
 # Use a lighter model for testing
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 
-print("Loading text generation model...")
-model_name = "google/flan-t5-large"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
-
 
 # === Qdrant Setup ===
 print("Connecting to Qdrant...")
@@ -64,6 +65,8 @@ class MathAnswer(dspy.Signature):
     answer = dspy.OutputField()
 
 # === DSPy Programs ===
+
+# return dspy.Output(answer=answer, retrieved_context=context)
 class MathRetrievalQA(dspy.Program):
     def forward(self, question):
         print("Inside MathRetrievalQA...")
@@ -71,16 +74,27 @@ class MathRetrievalQA(dspy.Program):
         context = "\n".join([item["solution"] for item in context_items if "solution" in item])
         print("Context for generation:", context)
         if not context:
-            return dspy.Output(answer="", retrieved_context="")
+            return {"answer": "", "retrieved_context": ""}
 
-        prompt = f"Question: {question}\nStep-by-step solution:\n{context}\nAnswer:"
-        print("Generating answer...")
-        answer = qa_pipeline(prompt, max_new_tokens=150)[0]["generated_text"]
+        # Step 1: Generate raw answer (e.g., using GPT2 or any pipeline)
+        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
+        raw_answer = qa_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
 
-        print("Generated answer:", answer)
-        return {"answer": answer, "retrieved_context": context}
+        # Step 2: Send raw answer to Gemini for formatting
+        format_prompt = f"""You are a helpful math assistant. Please format the following answer into a clear, step-by-step solution for better readability.
 
-        # return dspy.Output(answer=answer, retrieved_context=context)
+Question: {question}
+
+Raw Answer:
+{raw_answer}
+
+Formatted Step-by-Step Answer:"""
+
+        response = gemini_model.generate_content(format_prompt)
+        formatted_answer = response.text
+
+        print("Formatted answer:", formatted_answer)
+        return {"answer": formatted_answer, "retrieved_context": context}
 
 class WebFallbackQA(dspy.Program):
     def forward(self, question):
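
Note on the first hunk: it checks a literal Gemini API key into source, and its own inline comment ("Replace with your actual Gemini API key") flags the problem. A minimal sketch of the safer pattern, assuming the key is supplied through a hypothetical GEMINI_API_KEY environment variable rather than hardcoded:

import os

import google.generativeai as genai

# Assumed: the key lives in a GEMINI_API_KEY environment variable
# (hypothetical name) instead of being committed to the repository.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    raise RuntimeError("GEMINI_API_KEY is not set")

genai.configure(api_key=api_key)
gemini_model = genai.GenerativeModel("gemini-pro")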
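Note on the second hunk: it deletes the flan-t5 pipeline setup, but the rewritten forward() in the last hunk still calls qa_pipeline, which will now raise NameError at query time. A sketch of the definition that needs to survive the commit, taken from the removed lines (any text2text model that fits the hardware would do):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# forward() still references qa_pipeline, so keep a definition at module level.
# google/flan-t5-large mirrors the lines this commit removed.
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)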
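Note on the last hunk: response.text is read unconditionally, but with google-generativeai that attribute access can raise when a response is safety-blocked or empty, and the network call itself can fail. A hedged sketch that falls back to the raw pipeline answer on any failure (format_with_gemini is a hypothetical helper, not part of this commit; it assumes the module-level gemini_model from the first hunk):

def format_with_gemini(question: str, raw_answer: str) -> str:
    """Format raw_answer with Gemini, falling back to it on any failure."""
    format_prompt = f"""You are a helpful math assistant. Please format the following answer into a clear, step-by-step solution for better readability.

Question: {question}

Raw Answer:
{raw_answer}

Formatted Step-by-Step Answer:"""
    try:
        response = gemini_model.generate_content(format_prompt)
        return response.text  # can raise if the response was blocked or empty
    except Exception as exc:
        print("Gemini formatting failed, returning raw answer:", exc)
        return raw_answer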