Medresearch / components /model_utils.py
mgbam's picture
Add untracked files and synchronize with remote
9c7387c
raw
history blame contribute delete
658 Bytes
from transformers import pipeline
import os
def load_summarization_model(model_name: str = "facebook/bart-large-cnn"):
    """Load a Hugging Face summarization pipeline.

    The ``HUGGINGFACE_API_TOKEN`` environment variable is checked first; if it
    is unset, empty, or whitespace-only, no model is loaded.

    Args:
        model_name: Model id passed to ``transformers.pipeline``. Defaults to
            ``"facebook/bart-large-cnn"``, the model this loader always used.

    Returns:
        The object returned by ``pipeline("summarization", ...)``, or ``None``
        when the token is missing or loading the model raised an exception.
    """
    # Strip so a token that is only whitespace (e.g. a stray newline from a
    # .env file) is treated as missing instead of being sent as a credential.
    api_token = (os.environ.get("HUGGINGFACE_API_TOKEN") or "").strip()
    if not api_token:
        # NOTE(review): the default model may be downloadable without a token;
        # this gate is kept to preserve existing behavior — confirm it is wanted.
        print("HUGGINGFACE_API_TOKEN not found. Summarization will not work.")
        return None

    try:
        summarizer = pipeline("summarization", model=model_name, token=api_token)
    except Exception as exc:
        # Catch everything so the loader keeps its "return None on failure"
        # contract; the error is reported rather than silently swallowed.
        print(f"Model load error: {exc}")
        return None

    print(f"Summarization Model {model_name} Loaded...")
    return summarizer