manasagangotri committed on
Commit
564d0c6
·
verified ·
1 Parent(s): 9805a18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -18,7 +18,14 @@ embedding_model = SentenceTransformer("intfloat/e5-large")
18
 
19
  print("Loading text generation model...")
20
  # Use a lighter model for testing
21
- qa_pipeline = pipeline("text-generation", model="gpt2")
 
 
 
 
 
 
 
22
 
23
  # === Qdrant Setup ===
24
  print("Connecting to Qdrant...")
@@ -65,9 +72,11 @@ class MathRetrievalQA(dspy.Program):
65
  print("Context for generation:", context)
66
  if not context:
67
  return dspy.Output(answer="", retrieved_context="")
68
- prompt = f"Question: {question}\nContext: {context}\nAnswer:"
 
69
  print("Generating answer...")
70
- answer = qa_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
 
71
  print("Generated answer:", answer)
72
  return {"answer": answer, "retrieved_context": context}
73
 
 
18
 
19
  print("Loading text generation model...")
20
  # Use a lighter model for testing
21
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
22
+
23
+ print("Loading text generation model...")
24
+ model_name = "google/flan-t5-large"
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
27
+ qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
28
+
29
 
30
  # === Qdrant Setup ===
31
  print("Connecting to Qdrant...")
 
72
  print("Context for generation:", context)
73
  if not context:
74
  return dspy.Output(answer="", retrieved_context="")
75
+
76
+ prompt = f"Question: {question}\nStep-by-step solution:\n{context}\nAnswer:"
77
  print("Generating answer...")
78
+ answer = qa_pipeline(prompt, max_new_tokens=150)[0]["generated_text"]
79
+
80
  print("Generated answer:", answer)
81
  return {"answer": answer, "retrieved_context": context}
82