ragV98 commited on
Commit
8e4c661
·
1 Parent(s): c8b3b66

switching to bart

Browse files
components/LLMs/Bart.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # components/LLMs/bart.py
2
+
3
+ import os
4
+ import requests
5
+ from typing import Optional
6
+
7
+ HF_TOKEN = os.environ.get("HF_TOKEN")
8
+ BART_URL = "https://wq3d0mr9dcpcldou.us-east-1.aws.endpoints.huggingface.cloud"
9
+
10
+ HEADERS = {
11
+ "Authorization": f"Bearer {HF_TOKEN}",
12
+ "Content-Type": "application/json"
13
+ }
14
+
15
+
16
+ def call_bart_summarizer(base_prompt: str, tail_prompt: str, max_length: int = 130) -> Optional[str]:
17
+ """
18
+ Calls facebook/bart-large-cnn using HF Inference API with composed prompt.
19
+
20
+ Args:
21
+ base_prompt (str): Instruction or high-level instruction.
22
+ tail_prompt (str): News content or body.
23
+ max_length (int): Output summary length limit.
24
+
25
+ Returns:
26
+ str: Cleaned summary string, or None if error.
27
+ """
28
+ full_input = f"{base_prompt.strip()}\n\n{tail_prompt.strip()}"
29
+ payload = {
30
+ "inputs": full_input,
31
+ "parameters": {
32
+ "max_length": max_length,
33
+ "do_sample": False
34
+ }
35
+ }
36
+
37
+ try:
38
+ response = requests.post(BART_URL, headers=HEADERS, json=payload, timeout=30)
39
+ response.raise_for_status()
40
+ result = response.json()
41
+
42
+ if isinstance(result, list) and result and "summary_text" in result[0]:
43
+ return result[0]["summary_text"].strip()
44
+ else:
45
+ print("⚠️ Unexpected BART response format:", result)
46
+ return None
47
+
48
+ except Exception as e:
49
+ print(f"⚠️ BART API call failed: {e}")
50
+ return None
components/generators/daily_feed.py CHANGED
@@ -10,6 +10,7 @@ from llama_index.core.schema import Document
10
  from llama_index.core.settings import Settings
11
  from components.LLMs.Mistral import call_mistral
12
  from components.LLMs.TinyLLama import call_tinyllama
 
13
 
14
  # ✅ Disable implicit LLM usage
15
  Settings.llm = None
@@ -49,7 +50,7 @@ def summarize_topic(docs: List[str], topic: str) -> List[Dict]:
49
  for doc in docs[:5]:
50
  tail_prompt = f"Topic: {topic}\n\n{doc.strip()}"
51
  print(f"\n📤 Prompt tail for Mistral:\n{tail_prompt[:300]}...\n")
52
- summary_block = call_tinyllama(base_prompt=BASE_PROMPT, tail_prompt=tail_prompt)
53
 
54
  if summary_block:
55
  for line in summary_block.splitlines():
 
10
  from llama_index.core.settings import Settings
11
  from components.LLMs.Mistral import call_mistral
12
  from components.LLMs.TinyLLama import call_tinyllama
13
+ from components.LLMs.Bart import call_bart_summarizer
14
 
15
  # ✅ Disable implicit LLM usage
16
  Settings.llm = None
 
50
  for doc in docs[:5]:
51
  tail_prompt = f"Topic: {topic}\n\n{doc.strip()}"
52
  print(f"\n📤 Prompt tail for Mistral:\n{tail_prompt[:300]}...\n")
53
+ summary_block = call_bart_summarizer(base_prompt=BASE_PROMPT, tail_prompt=tail_prompt)
54
 
55
  if summary_block:
56
  for line in summary_block.splitlines():