raghavNCI
commited on
Commit
·
aefa1e1
1
Parent(s):
326a8da
added keyword_extractor
Browse files- nuse_modules/keyword_extracter.py +65 -0
- question.py +2 -27
nuse_modules/keyword_extracter.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# nuse_modules/keyword_extractor.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
import requests
|
5 |
+
import json
|
6 |
+
|
7 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
8 |
+
|
9 |
+
HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
|
10 |
+
HEADERS = {
|
11 |
+
"Authorization": f"Bearer {HF_TOKEN}",
|
12 |
+
"Content-Type": "application/json"
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
def mistral_generate(prompt: str, max_new_tokens=128) -> str:
|
17 |
+
payload = {
|
18 |
+
"inputs": prompt,
|
19 |
+
"parameters": {
|
20 |
+
"max_new_tokens": max_new_tokens,
|
21 |
+
"temperature": 0.7
|
22 |
+
}
|
23 |
+
}
|
24 |
+
try:
|
25 |
+
response = requests.post(HF_API_URL, headers=HEADERS, data=json.dumps(payload), timeout=30)
|
26 |
+
response.raise_for_status()
|
27 |
+
result = response.json()
|
28 |
+
if isinstance(result, list) and len(result) > 0:
|
29 |
+
return result[0].get("generated_text", "").strip()
|
30 |
+
except Exception as e:
|
31 |
+
print("[mistral_generate error]", str(e))
|
32 |
+
|
33 |
+
return ""
|
34 |
+
|
35 |
+
|
36 |
+
def extract_last_keywords(raw: str, max_keywords: int = 8) -> list[str]:
|
37 |
+
segments = raw.strip().split("\n")
|
38 |
+
|
39 |
+
for line in reversed(segments):
|
40 |
+
line = line.strip()
|
41 |
+
if line.lower().startswith("extract") or not line or len(line) < 10:
|
42 |
+
continue
|
43 |
+
|
44 |
+
if line.count(",") >= 2:
|
45 |
+
parts = [kw.strip().strip('"') for kw in line.split(",") if kw.strip()]
|
46 |
+
if all(len(p.split()) <= 3 for p in parts) and 1 <= len(parts) <= max_keywords:
|
47 |
+
return parts
|
48 |
+
|
49 |
+
return []
|
50 |
+
|
51 |
+
|
52 |
+
def keywords_extractor(question: str) -> list[str]:
|
53 |
+
prompt = (
|
54 |
+
f"Extract the 3–6 most important keywords from the following question. "
|
55 |
+
f"Return only the keywords, comma-separated (no explanations):\n\n"
|
56 |
+
f"{question}"
|
57 |
+
)
|
58 |
+
|
59 |
+
raw_output = mistral_generate(prompt, max_new_tokens=32)
|
60 |
+
keywords = extract_last_keywords(raw_output)
|
61 |
+
|
62 |
+
print("Raw extracted keywords:", raw_output)
|
63 |
+
print("Parsed keywords:", keywords)
|
64 |
+
|
65 |
+
return keywords
|
question.py
CHANGED
@@ -9,6 +9,7 @@ from dotenv import load_dotenv
|
|
9 |
from urllib.parse import quote
|
10 |
import json
|
11 |
from nuse_modules.classifier import classify_question, REVERSE_MAP
|
|
|
12 |
|
13 |
load_dotenv()
|
14 |
|
@@ -26,25 +27,6 @@ HEADERS = {
|
|
26 |
"Content-Type": "application/json"
|
27 |
}
|
28 |
|
29 |
-
def extract_last_keywords(raw: str, max_keywords=8):
|
30 |
-
segments = raw.strip().split("\n")
|
31 |
-
|
32 |
-
# Ignore quoted or prompt lines
|
33 |
-
for line in reversed(segments):
|
34 |
-
line = line.strip()
|
35 |
-
if line.lower().startswith("extract") or not line or len(line) < 10:
|
36 |
-
continue
|
37 |
-
|
38 |
-
# Look for lines with multiple comma-separated items
|
39 |
-
if line.count(",") >= 2:
|
40 |
-
parts = [kw.strip().strip('"') for kw in line.split(",") if kw.strip()]
|
41 |
-
# Ensure they're not just long phrases or sentence fragments
|
42 |
-
if all(len(p.split()) <= 3 for p in parts) and 1 <= len(parts) <= max_keywords:
|
43 |
-
return parts
|
44 |
-
|
45 |
-
return []
|
46 |
-
|
47 |
-
|
48 |
def is_relevant(article, keywords):
|
49 |
text = f"{article.get('title', '')} {article.get('content', '')}".lower()
|
50 |
return any(kw.lower() in text for kw in keywords)
|
@@ -97,14 +79,7 @@ async def ask_question(input: QuestionInput):
|
|
97 |
print("Intent ID:", qid)
|
98 |
print("Category:", REVERSE_MAP.get(qid, "unknown"))
|
99 |
|
100 |
-
|
101 |
-
keyword_prompt = (
|
102 |
-
f"Extract the 3–6 most important keywords from the following question. "
|
103 |
-
f"Return only the keywords, comma-separated (no explanations):\n\n"
|
104 |
-
f"{question}"
|
105 |
-
)
|
106 |
-
raw_keywords = mistral_generate(keyword_prompt, max_new_tokens=32)
|
107 |
-
keywords = extract_last_keywords(raw_keywords)
|
108 |
|
109 |
print("Raw extracted keywords:", keywords)
|
110 |
|
|
|
9 |
from urllib.parse import quote
|
10 |
import json
|
11 |
from nuse_modules.classifier import classify_question, REVERSE_MAP
|
12 |
+
from nuse_modules.keyword_extracter import keywords_extractor
|
13 |
|
14 |
load_dotenv()
|
15 |
|
|
|
27 |
"Content-Type": "application/json"
|
28 |
}
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def is_relevant(article, keywords):
|
31 |
text = f"{article.get('title', '')} {article.get('content', '')}".lower()
|
32 |
return any(kw.lower() in text for kw in keywords)
|
|
|
79 |
print("Intent ID:", qid)
|
80 |
print("Category:", REVERSE_MAP.get(qid, "unknown"))
|
81 |
|
82 |
+
keywords = keywords_extractor(question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
print("Raw extracted keywords:", keywords)
|
85 |
|