raghavNCI committed on
Commit
eafca75
·
1 Parent(s): f5e2bc7

added a classifier v1

Browse files
nuse_modules/__init__.py ADDED
File without changes
nuse_modules/classifier.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import re
import requests
from dotenv import load_dotenv

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
HF_ZERO_SHOT_MODEL = "facebook/bart-large-mnli"

# Readable category names in ID order; numeric IDs are assigned 1..N below.
_CATEGORY_NAMES = (
    "recent_update",
    "explainer",
    "timeline",
    "person_in_news",
    "policy_or_law",
    "election_poll",
    "business_market",
    "sports_event",
    "pop_culture",
    "science_health",
    "fact_check",
    "compare_entities",
)

# Map readable categories to numeric IDs (1-based).
QUESTION_TYPES = {name: idx for idx, name in enumerate(_CATEGORY_NAMES, start=1)}

# Inverse lookup: numeric ID -> readable category name.
REVERSE_MAP = {idx: name for name, idx in QUESTION_TYPES.items()}
28
+
29
+
30
# ---------- Step 1: Fast Rule-Based Classification ----------

def _contains_word(text: str, word: str) -> bool:
    """Return True when *word* occurs as a whole word in *text*.

    Avoids substring false positives such as "vs" matching inside
    "canvas" or "advisor".
    """
    return re.search(rf"\b{re.escape(word)}\b", text) is not None


def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* into a numeric question-type ID via keyword rules.

    Returns the matching QUESTION_TYPES value, or -1 when no rule fires
    (callers then fall back to the zero-shot model).
    """
    p = prompt.lower()

    if any(x in p for x in ["latest", "recent", "update", "breaking", "happened today"]):
        return QUESTION_TYPES["recent_update"]
    # Check the specific "why ... in the news" pattern BEFORE the generic
    # explainer rule: the bare "why" keyword below would otherwise shadow
    # it and this regex could never match.
    if re.search(r"why .* in the news", p) or "trending" in p:
        return QUESTION_TYPES["person_in_news"]
    if any(x in p for x in ["explain", "why", "what is", "background", "summary"]):
        return QUESTION_TYPES["explainer"]
    if "timeline" in p or re.search(r"how .* changed", p):
        return QUESTION_TYPES["timeline"]
    if any(x in p for x in ["bill", "policy", "law", "executive order", "passed", "signed"]):
        return QUESTION_TYPES["policy_or_law"]
    if any(x in p for x in ["election", "poll", "vote", "candidate", "ballot"]):
        return QUESTION_TYPES["election_poll"]
    if any(x in p for x in ["stock", "inflation", "economy", "market", "job report"]):
        return QUESTION_TYPES["business_market"]
    if any(x in p for x in ["score", "match", "tournament", "game", "league"]):
        return QUESTION_TYPES["sports_event"]
    if any(x in p for x in ["celebrity", "actor", "album", "movie", "music", "show"]):
        return QUESTION_TYPES["pop_culture"]
    if any(x in p for x in ["health", "covid", "science", "study", "research", "doctor"]):
        return QUESTION_TYPES["science_health"]
    if any(x in p for x in ["true", "false", "hoax", "real", "claim", "fact check"]):
        return QUESTION_TYPES["fact_check"]
    # Whole-word match for "vs" so e.g. "canvas" does not trigger this rule.
    if any(x in p for x in ["compare", "difference between"]) or _contains_word(p, "vs"):
        return QUESTION_TYPES["compare_entities"]

    # Sentinel: no rule matched.
    return -1
60
+
61
+
62
# ---------- Step 2: HF Zero-Shot Fallback ----------
def zero_shot_classify(prompt: str) -> int:
    """Classify *prompt* via the Hugging Face zero-shot inference API.

    Sends the prompt to HF_ZERO_SHOT_MODEL with every QUESTION_TYPES key as
    a candidate label and returns the numeric ID of the top-ranked label.
    Returns -1 on any request/response failure or unexpected payload shape
    (best-effort fallback, errors are logged and swallowed).
    """
    endpoint = f"https://api-inference.huggingface.co/models/{HF_ZERO_SHOT_MODEL}"
    request_body = {
        "inputs": prompt,
        "parameters": {
            "candidate_labels": list(QUESTION_TYPES.keys()),
        },
    }
    request_headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(endpoint, headers=request_headers, json=request_body, timeout=20)
        response.raise_for_status()
        parsed = response.json()
        if isinstance(parsed, dict) and "labels" in parsed:
            # Labels come back ranked; the first is the model's best guess.
            return QUESTION_TYPES.get(parsed["labels"][0], -1)
    except Exception as e:
        # Deliberate best-effort: any failure degrades to "unknown".
        print("[HF Classifier Error]", str(e))

    return -1
90
+
91
+
92
# ---------- Public Hybrid Classifier ----------
def classify_question(prompt: str) -> int:
    """Return the numeric question-type ID for *prompt*.

    Tries the cheap keyword rules first; only when they return the -1
    sentinel does it fall back to the zero-shot model API call.
    """
    rule_hit = rule_based_classify(prompt)
    return rule_hit if rule_hit != -1 else zero_shot_classify(prompt)
question.py CHANGED
@@ -8,6 +8,7 @@ from redis_client import redis_client as r
8
  from dotenv import load_dotenv
9
  from urllib.parse import quote
10
  import json
 
11
 
12
  load_dotenv()
13
 
@@ -92,6 +93,10 @@ def fetch_gnews_articles(query: str) -> List[dict]:
92
  async def ask_question(input: QuestionInput):
93
  question = input.question
94
 
 
 
 
 
95
  # Step 1: Ask Mistral to extract keywords
96
  keyword_prompt = (
97
  f"Extract the 3–6 most important keywords from the following question. "
@@ -114,8 +119,6 @@ async def ask_question(input: QuestionInput):
114
  query_or = " OR ".join(f'"{kw}"' for kw in keywords)
115
  articles = fetch_gnews_articles(query_or)
116
 
117
- print("Fetched articles:", articles)
118
-
119
  relevant_articles = [a for a in articles if is_relevant(a, keywords)]
120
 
121
  context = "\n\n".join([
@@ -144,8 +147,6 @@ async def ask_question(input: QuestionInput):
144
 
145
  final_answer = extract_answer_after_label(answer)
146
 
147
- print("Mistral Answer:", answer)
148
-
149
  return {
150
  "question": question,
151
  "answer": final_answer.strip(),
 
8
  from dotenv import load_dotenv
9
  from urllib.parse import quote
10
  import json
11
+ from nuse_modules.classifier import classify_question, REVERSE_MAP
12
 
13
  load_dotenv()
14
 
 
93
  async def ask_question(input: QuestionInput):
94
  question = input.question
95
 
96
+ qid = classify_question(question)
97
+ print("Intent ID:", qid)
98
+ print("Category:", REVERSE_MAP.get(qid, "unknown"))
99
+
100
  # Step 1: Ask Mistral to extract keywords
101
  keyword_prompt = (
102
  f"Extract the 3–6 most important keywords from the following question. "
 
119
  query_or = " OR ".join(f'"{kw}"' for kw in keywords)
120
  articles = fetch_gnews_articles(query_or)
121
 
 
 
122
  relevant_articles = [a for a in articles if is_relevant(a, keywords)]
123
 
124
  context = "\n\n".join([
 
147
 
148
  final_answer = extract_answer_after_label(answer)
149
 
 
 
150
  return {
151
  "question": question,
152
  "answer": final_answer.strip(),