raghavNCI
committed on
Commit
·
d606723
1
Parent(s):
ccf9b0b
redis correction in question.py
Browse files- routes/question.py +53 -54
routes/question.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
import os
|
2 |
-
import requests
|
3 |
-
import datetime
|
4 |
import json
|
|
|
|
|
|
|
|
|
5 |
from fastapi import APIRouter
|
6 |
from pydantic import BaseModel
|
7 |
-
from clients.redis_client import redis_client as r
|
8 |
from dotenv import load_dotenv
|
9 |
|
|
|
10 |
from models_initialization.mistral_registry import mistral_generate
|
11 |
from nuse_modules.classifier import classify_question, REVERSE_MAP
|
12 |
from nuse_modules.keyword_extracter import keywords_extractor
|
@@ -16,63 +18,75 @@ load_dotenv()
|
|
16 |
|
17 |
askMe = APIRouter()
|
18 |
|
|
|
|
|
|
|
19 |
class QuestionInput(BaseModel):
|
20 |
question: str
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
23 |
def should_extract_keywords(type_id: int) -> bool:
|
|
|
24 |
return type_id in {1, 2, 3, 4, 5, 6, 7, 10, 11, 12}
|
25 |
|
26 |
|
27 |
def extract_answer_after_label(text: str) -> str:
|
28 |
-
"""
|
29 |
-
Extracts everything after the first 'Answer:' label.
|
30 |
-
Assumes 'Answer:' appears once and is followed by the relevant content.
|
31 |
-
"""
|
32 |
if "Answer:" in text:
|
33 |
return text.split("Answer:", 1)[1].strip()
|
34 |
return text.strip()
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
37 |
@askMe.post("/ask")
|
38 |
async def ask_question(input: QuestionInput):
|
39 |
-
question = input.question
|
40 |
|
41 |
-
#
|
42 |
qid = classify_question(question)
|
43 |
print("Intent ID:", qid)
|
44 |
print("Category:", REVERSE_MAP.get(qid, "unknown"))
|
45 |
|
46 |
-
|
47 |
-
sources = []
|
48 |
-
|
49 |
if qid == 13:
|
50 |
date_str = datetime.datetime.utcnow().strftime("%Y-%m-%d")
|
51 |
categories = ["world", "india", "finance", "sports", "entertainment"]
|
52 |
-
all_headlines = []
|
53 |
|
54 |
for cat in categories:
|
55 |
-
|
56 |
-
cached = _r.get(
|
57 |
if cached:
|
58 |
-
|
|
|
|
|
|
|
59 |
for art in articles:
|
60 |
all_headlines.append({
|
61 |
-
"title":
|
62 |
-
"summary": art
|
63 |
-
"url":
|
64 |
-
"image":
|
65 |
"category": cat,
|
66 |
})
|
67 |
|
68 |
return {
|
69 |
"question": question,
|
70 |
"answer": "Here are todayβs top headlines:",
|
71 |
-
"headlines": all_headlines
|
72 |
}
|
73 |
-
|
74 |
|
75 |
-
#
|
|
|
|
|
|
|
76 |
if should_extract_keywords(qid):
|
77 |
keywords = keywords_extractor(question)
|
78 |
print("Raw extracted keywords:", keywords)
|
@@ -80,60 +94,45 @@ async def ask_question(input: QuestionInput):
|
|
80 |
if not keywords:
|
81 |
return {"error": "Keyword extraction failed."}
|
82 |
|
83 |
-
#
|
84 |
results = search_google_news(keywords)
|
85 |
print("Found articles:", results)
|
86 |
|
87 |
-
# for r in results:
|
88 |
-
# print(r["title"], r["link"])
|
89 |
-
|
90 |
-
# Build context from snippet/description
|
91 |
context = "\n\n".join([
|
92 |
-
r.get("snippet") or r.get("description", "")
|
93 |
-
for r in results
|
94 |
])[:15000]
|
95 |
|
96 |
-
sources = [
|
97 |
-
{"title": r["title"], "url": r["link"]}
|
98 |
-
for r in results
|
99 |
-
]
|
100 |
|
101 |
if not context.strip():
|
102 |
return {
|
103 |
"question": question,
|
104 |
"answer": "Cannot answer β no relevant context found.",
|
105 |
-
"sources": sources
|
106 |
}
|
107 |
-
|
108 |
-
# Step 3: Ask Mistral to answer
|
109 |
answer_prompt = (
|
110 |
-
|
111 |
-
|
112 |
-
f"If the context is not helpful, you may rely on your own knowledge, but do not mention the context or question again.\n\n"
|
113 |
f"Context:\n{context}\n\n"
|
114 |
-
f"Question: {question}\n\
|
115 |
-
f"Answer:"
|
116 |
)
|
117 |
answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)
|
118 |
-
|
119 |
-
else:
|
120 |
|
|
|
121 |
answer_prompt = (
|
122 |
-
|
123 |
-
f"
|
124 |
-
f"Question: {question}\n\n"
|
125 |
-
f"Answer:"
|
126 |
)
|
127 |
-
|
128 |
answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
|
135 |
return {
|
136 |
"question": question,
|
137 |
"answer": final_answer.strip(),
|
138 |
-
"sources": sources
|
139 |
}
|
|
|
1 |
import os
|
|
|
|
|
2 |
import json
|
3 |
+
import datetime
|
4 |
+
from typing import List, Dict
|
5 |
+
|
6 |
+
import requests
|
7 |
from fastapi import APIRouter
|
8 |
from pydantic import BaseModel
|
|
|
9 |
from dotenv import load_dotenv
|
10 |
|
11 |
+
from clients.redis_client import redis_client as _r
|
12 |
from models_initialization.mistral_registry import mistral_generate
|
13 |
from nuse_modules.classifier import classify_question, REVERSE_MAP
|
14 |
from nuse_modules.keyword_extracter import keywords_extractor
|
|
|
18 |
|
19 |
askMe = APIRouter()
|
20 |
|
21 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
22 |
+
# Pydantic schema
|
23 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
24 |
class QuestionInput(BaseModel):
|
25 |
question: str
|
26 |
|
27 |
|
28 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
29 |
+
# Helper functions
|
30 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
31 |
+
|
32 |
def should_extract_keywords(type_id: int) -> bool:
    """Return True when the given intent id is answered via keyword search."""
    # Intent ids whose answers are built from news-search context.
    keyword_driven_intents = {1, 2, 3, 4, 5, 6, 7, 10, 11, 12}
    return type_id in keyword_driven_intents
|
35 |
|
36 |
|
37 |
def extract_answer_after_label(text: str) -> str:
    """Return the text after the first 'Answer:' label, stripped.

    If no label is present, the whole input is returned stripped.
    """
    _, label_found, tail = text.partition("Answer:")
    return tail.strip() if label_found else text.strip()
|
42 |
|
43 |
|
44 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
45 |
+
# FastAPI route
|
46 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
47 |
+
|
48 |
@askMe.post("/ask")
async def ask_question(input: QuestionInput):
    """Answer a user question.

    Flow:
      1. Classify the question into an intent id.
      2. Intent 13 short-circuits to today's cached headlines from Redis.
      3. Keyword-driven intents search Google News and feed the snippets
         to the LLM as context; all other intents are answered from the
         model's own knowledge.

    Returns a dict with the question, the answer, and either the cached
    headlines (intent 13) or the article sources (possibly empty).
    """
    question = input.question.strip()

    # 1. Classify intent
    qid = classify_question(question)
    print("Intent ID:", qid)
    print("Category:", REVERSE_MAP.get(qid, "unknown"))

    # Special case: intent 13 -> return today's cached headlines.
    if qid == 13:
        # Timezone-aware "now" instead of the deprecated, naive
        # datetime.utcnow(); the "%Y-%m-%d" output is identical, so the
        # Redis key format is unchanged.
        date_str = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
        categories = ["world", "india", "finance", "sports", "entertainment"]
        all_headlines: List[Dict] = []

        for cat in categories:
            redis_key = f"headlines:{date_str}:{cat}"
            cached = _r.get(redis_key)
            if not cached:
                # No headlines cached for this category today.
                continue
            try:
                articles = json.loads(cached)
            except json.JSONDecodeError:
                # Corrupt cache entry: skip this category, don't fail the request.
                continue
            for art in articles:
                all_headlines.append({
                    "title": art.get("title"),
                    "summary": art.get("summary"),
                    "url": art.get("url"),
                    "image": art.get("image"),
                    "category": cat,
                })

        return {
            "question": question,
            "answer": "Here are today's top headlines:",
            "headlines": all_headlines,
        }

    # 2. Keyword-based flow for the remaining intents.
    context = ""
    sources: List[Dict] = []

    if should_extract_keywords(qid):
        keywords = keywords_extractor(question)
        print("Raw extracted keywords:", keywords)

        if not keywords:
            return {"error": "Keyword extraction failed."}

        # Google News search; each result is expected to carry
        # title/link plus snippet or description — TODO confirm schema.
        results = search_google_news(keywords)
        print("Found articles:", results)

        # Cap the concatenated snippets so the prompt stays bounded.
        context = "\n\n".join([
            r.get("snippet") or r.get("description", "") for r in results
        ])[:15000]

        sources = [{"title": r["title"], "url": r["link"]} for r in results]

        if not context.strip():
            return {
                "question": question,
                "answer": "Cannot answer — no relevant context found.",
                "sources": sources,
            }

        answer_prompt = (
            "You are a concise news assistant. Answer the user's question clearly using the provided context if relevant. "
            "If the context is not helpful, rely on your own knowledge but do not mention the context.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {question}\n\nAnswer:"
        )
        answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)

    else:
        # No search context: answer purely from model knowledge.
        answer_prompt = (
            "You are a concise news assistant. Answer the user's question clearly and accurately.\n\n"
            f"Question: {question}\n\nAnswer:"
        )
        answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)

    # 3. Post-process model output: strip the prompt echo up to 'Answer:'
    # and fall back to an explicit message on an empty generation.
    final_answer = extract_answer_after_label(answer_raw or "") or (
        "Cannot answer — model did not return a valid response."
    )

    return {
        "question": question,
        "answer": final_answer.strip(),
        "sources": sources,
    }
|