raghavNCI
committed on
Commit
·
eafca75
1
Parent(s):
f5e2bc7
added a classifier v1
Browse files- nuse_modules/__init__.py +0 -0
- nuse_modules/classifier.py +97 -0
- question.py +5 -4
nuse_modules/__init__.py
ADDED
File without changes
|
nuse_modules/classifier.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import requests
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
|
6 |
+
# Load environment variables from a local .env file, if present.
load_dotenv()

# Hugging Face API token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Zero-shot classification model used as a fallback when rules don't match.
HF_ZERO_SHOT_MODEL = "facebook/bart-large-mnli"
|
10 |
+
|
11 |
+
# Map readable categories to numeric IDs
QUESTION_TYPES = {
    "recent_update": 1,
    "explainer": 2,
    "timeline": 3,
    "person_in_news": 4,
    "policy_or_law": 5,
    "election_poll": 6,
    "business_market": 7,
    "sports_event": 8,
    "pop_culture": 9,
    "science_health": 10,
    "fact_check": 11,
    "compare_entities": 12
}

# Inverse lookup: numeric ID -> readable category name.
REVERSE_MAP = {v: k for k, v in QUESTION_TYPES.items()}


# ---------- Step 1: Fast Rule-Based Classification ----------
def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* with fast keyword/regex heuristics.

    Rules are checked in a fixed priority order; the first match wins.

    Args:
        prompt: The user's question, any casing.

    Returns:
        The numeric ID from QUESTION_TYPES of the first matching category,
        or -1 when no rule matches (caller should fall back to the
        model-based classifier).
    """
    p = prompt.lower()

    def has(*keywords: str) -> bool:
        # Whole-word (or whole-phrase) match, allowing an optional plural
        # "s"/"es" suffix. The previous plain-substring test produced false
        # positives such as "billionaire" -> "bill" (policy_or_law) or
        # "underscore" -> "score" (sports_event).
        return any(
            re.search(rf"\b{re.escape(k)}(?:es|s)?\b", p) for k in keywords
        )

    if has("latest", "recent", "update", "breaking", "happened today"):
        return QUESTION_TYPES["recent_update"]
    if has("explain", "why", "what is", "background", "summary"):
        return QUESTION_TYPES["explainer"]
    if has("timeline") or re.search(r"how .* changed", p):
        return QUESTION_TYPES["timeline"]
    if re.search(r"why .* in the news", p) or has("trending"):
        return QUESTION_TYPES["person_in_news"]
    if has("bill", "policy", "law", "executive order", "passed", "signed"):
        return QUESTION_TYPES["policy_or_law"]
    if has("election", "poll", "vote", "candidate", "ballot"):
        return QUESTION_TYPES["election_poll"]
    if has("stock", "inflation", "economy", "market", "job report"):
        return QUESTION_TYPES["business_market"]
    if has("score", "match", "tournament", "game", "league"):
        return QUESTION_TYPES["sports_event"]
    if has("celebrity", "actor", "album", "movie", "music", "show"):
        return QUESTION_TYPES["pop_culture"]
    if has("health", "covid", "science", "study", "research", "doctor"):
        return QUESTION_TYPES["science_health"]
    if has("true", "false", "hoax", "real", "claim", "fact check"):
        return QUESTION_TYPES["fact_check"]
    if has("compare", "vs", "difference between"):
        return QUESTION_TYPES["compare_entities"]

    # No heuristic matched; signal the caller to use the zero-shot fallback.
    return -1
|
60 |
+
|
61 |
+
|
62 |
+
# ---------- Step 2: HF Zero-Shot Fallback ----------
def zero_shot_classify(prompt: str) -> int:
    """Classify *prompt* via the Hugging Face Inference API zero-shot model.

    Sends the prompt with every QUESTION_TYPES key as a candidate label and
    maps the top-scoring label back to its numeric ID.

    Args:
        prompt: The user's question.

    Returns:
        The numeric ID of the best-scoring label, or -1 on any request/parse
        failure or when the response carries no usable labels.
    """
    endpoint = f"https://api-inference.huggingface.co/models/{HF_ZERO_SHOT_MODEL}"
    request_body = {
        "inputs": prompt,
        "parameters": {"candidate_labels": list(QUESTION_TYPES.keys())},
    }
    request_headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(
            endpoint, headers=request_headers, json=request_body, timeout=20
        )
        response.raise_for_status()
        payload = response.json()
        # A successful zero-shot response is a dict whose "labels" list is
        # ordered by score; the first entry is the best match.
        if isinstance(payload, dict) and "labels" in payload:
            return QUESTION_TYPES.get(payload["labels"][0], -1)
    except Exception as exc:  # best-effort: any failure falls through to -1
        print("[HF Classifier Error]", str(exc))

    return -1
|
90 |
+
|
91 |
+
|
92 |
+
# ---------- Public Hybrid Classifier ----------
|
93 |
+
def classify_question(prompt: str) -> int:
|
94 |
+
rule_result = rule_based_classify(prompt)
|
95 |
+
if rule_result != -1:
|
96 |
+
return rule_result
|
97 |
+
return zero_shot_classify(prompt)
|
question.py
CHANGED
@@ -8,6 +8,7 @@ from redis_client import redis_client as r
|
|
8 |
from dotenv import load_dotenv
|
9 |
from urllib.parse import quote
|
10 |
import json
|
|
|
11 |
|
12 |
load_dotenv()
|
13 |
|
@@ -92,6 +93,10 @@ def fetch_gnews_articles(query: str) -> List[dict]:
|
|
92 |
async def ask_question(input: QuestionInput):
|
93 |
question = input.question
|
94 |
|
|
|
|
|
|
|
|
|
95 |
# Step 1: Ask Mistral to extract keywords
|
96 |
keyword_prompt = (
|
97 |
f"Extract the 3–6 most important keywords from the following question. "
|
@@ -114,8 +119,6 @@ async def ask_question(input: QuestionInput):
|
|
114 |
query_or = " OR ".join(f'"{kw}"' for kw in keywords)
|
115 |
articles = fetch_gnews_articles(query_or)
|
116 |
|
117 |
-
print("Fetched articles:", articles)
|
118 |
-
|
119 |
relevant_articles = [a for a in articles if is_relevant(a, keywords)]
|
120 |
|
121 |
context = "\n\n".join([
|
@@ -144,8 +147,6 @@ async def ask_question(input: QuestionInput):
|
|
144 |
|
145 |
final_answer = extract_answer_after_label(answer)
|
146 |
|
147 |
-
print("Mistral Answer:", answer)
|
148 |
-
|
149 |
return {
|
150 |
"question": question,
|
151 |
"answer": final_answer.strip(),
|
|
|
8 |
from dotenv import load_dotenv
|
9 |
from urllib.parse import quote
|
10 |
import json
|
11 |
+
from nuse_modules.classifier import classify_question, REVERSE_MAP
|
12 |
|
13 |
load_dotenv()
|
14 |
|
|
|
93 |
async def ask_question(input: QuestionInput):
|
94 |
question = input.question
|
95 |
|
96 |
+
qid = classify_question(question)
|
97 |
+
print("Intent ID:", qid)
|
98 |
+
print("Category:", REVERSE_MAP.get(qid, "unknown"))
|
99 |
+
|
100 |
# Step 1: Ask Mistral to extract keywords
|
101 |
keyword_prompt = (
|
102 |
f"Extract the 3–6 most important keywords from the following question. "
|
|
|
119 |
query_or = " OR ".join(f'"{kw}"' for kw in keywords)
|
120 |
articles = fetch_gnews_articles(query_or)
|
121 |
|
|
|
|
|
122 |
relevant_articles = [a for a in articles if is_relevant(a, keywords)]
|
123 |
|
124 |
context = "\n\n".join([
|
|
|
147 |
|
148 |
final_answer = extract_answer_after_label(answer)
|
149 |
|
|
|
|
|
150 |
return {
|
151 |
"question": question,
|
152 |
"answer": final_answer.strip(),
|