# Source: 5240_project / app.py
# Author: miltonc — last commit 382ebde ("Update app.py", verified)
import streamlit as st
from transformers import pipeline
from translate import Translator
import re
# Load models
@st.cache_resource
def load_models():
    """Load and return the sentiment-analysis and summarization pipelines.

    Streamlit re-runs this whole script on every widget interaction. Without
    ``st.cache_resource`` both pipelines would be rebuilt on every rerun (each
    button click). With it they are built once and reused across reruns.

    Returns:
        tuple: ``(sentiment_analyzer, summarizer)`` — a ``text-classification``
        pipeline and a ``summarization`` pipeline.
    """
    sentiment_analyzer = pipeline("text-classification", model="miltonc/distilbert-base-uncased_ft_5")
    summarizer = pipeline("summarization", model="FelixChao/T5-Chinese-Summarization")
    return sentiment_analyzer, summarizer
# Input-length caps. main() compares them against len() of the raw input text,
# i.e. a character count. The underlying model limits are token counts, so the
# cap is only an approximation of what each model accepts.
SENTIMENT_MAX_LENGTH: int = 512  # DistilBERT sentiment model input limit
SUMMARIZER_MAX_LENGTH: int = 512  # T5 summarizer input limit (approximate)
def sentiment_analysis(text, sentiment_analyzer):
    """Classify the sentiment of *text* using *sentiment_analyzer*.

    Args:
        text: Text to classify. In this app, main() passes the English
            translation of the user's input.
        sentiment_analyzer: A text-classification pipeline (or any callable with
            the same call signature) returning a list of
            ``{"label": ..., "score": ...}`` dicts.

    Returns:
        The first result dict. If classification raises, returns
        ``{"label": "sentiment_analysis Failed", "score": 0.0}`` instead.
    """
    try:
        # Translations of long inputs can exceed the model's maximum sequence
        # length. truncation=True clips them instead of letting the call raise.
        return sentiment_analyzer(text, truncation=True)[0]
    except Exception as e:
        # Deliberate best-effort boundary: report and return a sentinel so the
        # UI can keep running.
        print(f"sentiment_analysis error for '{text}': {e}. Returning 'sentiment_analysis Failed'")
        return {"label": "sentiment_analysis Failed", "score": 0.0}
def summarize_news(text, summarizer, max_length=30, min_length=10):
    """Summarize *text* using *summarizer*.

    Args:
        text: News text to summarize.
        summarizer: A summarization pipeline (or compatible callable) returning a
            list of ``{"summary_text": ...}`` dicts.
        max_length: Upper bound on the generated summary length, forwarded to
            the pipeline (default 30, the previously hard-coded value).
        min_length: Lower bound on the generated summary length, forwarded to
            the pipeline (default 10, the previously hard-coded value).

    Returns:
        The summary string, or ``"Summarization Failed"`` if summarization raises.
    """
    try:
        # truncation=True clips inputs longer than the model's limit rather than
        # letting the pipeline raise on them.
        outputs = summarizer(text, max_length=max_length, min_length=min_length, truncation=True)
        return outputs[0]['summary_text']
    except Exception as e:
        # Best-effort boundary: log and hand back a sentinel for the UI.
        print(f"Summarization error for '{text}': {e}. Returning 'Summarization Failed'")
        return "Summarization Failed"
def translate_text(text, target_language='en', source_language='zh-TW'):
    """Translate *text* from *source_language* into *target_language*.

    The defaults translate Traditional Chinese ('zh-TW') into English ('en').

    Returns:
        The translated string, or ``"Translation Failed"`` if constructing the
        translator or translating raises.
    """
    try:
        # Construct inside the try so a failure here also produces the sentinel
        # instead of escaping to the caller.
        translator = Translator(to_lang=target_language, from_lang=source_language)
        # NOTE(review): some `translate` providers may report quota/length
        # problems as returned text rather than raising; such text would pass
        # through unflagged — confirm for the configured provider.
        return translator.translate(text)
    except Exception as e:
        print(f"Translation error for '{text}': {e}. Returning 'Translation Failed'")
        return "Translation Failed"
# Language gate for user input
def is_chinese(text):
    """Return True if *text* contains at least one CJK Unified Ideograph (U+4E00–U+9FFF)."""
    first_han = re.search(r'[\u4e00-\u9fff]', text)
    return first_han is not None
# Main Streamlit app
def _validation_error(text):
    """Return a user-facing error message if *text* cannot be analyzed, else None."""
    if not is_chinese(text):
        return "Error: Please enter Chinese text."
    # Limits are compared against the character count of the raw input.
    if len(text) > SENTIMENT_MAX_LENGTH:
        return (f"Error: Input text for sentiment analysis exceeds the maximum length of "
                f"{SENTIMENT_MAX_LENGTH} characters. Please shorten your text.")
    if len(text) > SUMMARIZER_MAX_LENGTH:
        return (f"Error: Input text for summarization exceeds the maximum length of "
                f"{SUMMARIZER_MAX_LENGTH} characters. Please shorten your text.")
    return None


def _show_sentiment(text, sentiment_analyzer):
    """Translate *text* to English, classify it, and render the sentiment."""
    with st.spinner("Analyzing sentiment..."):
        text_en = translate_text(text, target_language='en', source_language='zh-TW')
        if text_en == "Translation Failed":
            # Classifying the sentinel string itself would display a bogus sentiment.
            st.error("Error: Translation failed, so sentiment could not be analyzed.")
            return
        sentiment_result = sentiment_analysis(text_en, sentiment_analyzer)

    if sentiment_result.get('label') == "sentiment_analysis Failed":
        # Otherwise the failure sentinel would render as "Unknown" with score 0.0000.
        st.error("Error: Sentiment analysis failed.")
        return

    # NOTE(review): assumes the fine-tuned model emits LABEL_0/1/2 in
    # negative/neutral/positive order — confirm against the model config.
    label_map = {"LABEL_0": 'negative', "LABEL_1": 'neutral', "LABEL_2": 'positive'}
    try:
        sentiment_label = label_map.get(sentiment_result['label'], 'Unknown')
        sentiment_score = sentiment_result['score']
        st.write(f"### Sentiment: {sentiment_label}")
        st.write(f"Score: {sentiment_score:.4f}")
    except (KeyError, ValueError, TypeError):
        st.write("Error processing sentiment result.")
        st.write(sentiment_result)


def _show_summary(text, summarizer):
    """Summarize the original (untranslated) *text* and render the summary."""
    with st.spinner("Summarizing News..."):
        story = summarize_news(text, summarizer)
    if story == "Summarization Failed":
        st.error("Error: Summarization failed.")
        return
    st.write("### Summarized News:")
    st.write(story)


def main():
    """Render the app: read Chinese text, then show its sentiment and a summary."""
    st.title("AI-Powered Sentiment Analysis and Summarization")
    sentiment_analyzer, summarizer = load_models()
    # Input must pass every length check, so advertise the tightest limit.
    max_chars = min(SENTIMENT_MAX_LENGTH, SUMMARIZER_MAX_LENGTH)
    text = st.text_area(f"Enter the Chinese text here (Max {max_chars} characters).....", height=200)

    if not st.button("Analyze"):
        return
    if not text:
        st.warning("Please enter some text to analyze.")
        return
    error = _validation_error(text)
    if error:
        st.error(error)
        return

    # Sentiment runs on the English translation; summarization runs on the
    # original text, so a translation failure does not block the summary.
    _show_sentiment(text, sentiment_analyzer)
    _show_summary(text, summarizer)


if __name__ == "__main__":
    main()