import html
from collections import OrderedDict
from urllib.parse import urlparse

import altair as alt
import pandas as pd
import requests
import streamlit as st
import torch
import trafilatura
import validators
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Load the punkt tokenizer from nltk (the sentence-splitting data used by sent_tokenize)
# NOTE(review): recent NLTK releases may look for the "punkt_tab" resource
# instead of "punkt" -- confirm against the installed NLTK version.
import nltk

nltk.download('punkt')
# Load model and tokenizer
model_name = 'dejanseo/sentiment'  # model identifier handed to from_pretrained


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load the sequence-classification model and its tokenizer once.

    Streamlit re-executes the whole script on every user interaction;
    caching the loaded objects avoids re-loading the model on each rerun.

    Args:
        name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    return (
        AutoModelForSequenceClassification.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


model, tokenizer = _load_model_and_tokenizer(model_name)
# Sentiment labels as textual descriptions, keyed by the model's class index
sentiment_labels = {
    0: "very positive",
    1: "positive",
    2: "somewhat positive",
    3: "neutral",
    4: "somewhat negative",
    5: "negative",
    6: "very negative",
}
# Background colors for sentiments: green shades for positive, grey for
# neutral, red shades for negative; opacity grows with intensity.
background_colors = {
    "very positive": "rgba(0, 255, 0, 0.5)",
    "positive": "rgba(0, 255, 0, 0.3)",
    "somewhat positive": "rgba(0, 255, 0, 0.1)",
    "neutral": "rgba(128, 128, 128, 0.1)",
    "somewhat negative": "rgba(255, 0, 0, 0.1)",
    "negative": "rgba(255, 0, 0, 0.3)",
    "very negative": "rgba(255, 0, 0, 0.5)",
}
# Function to classify text and return sentiment scores
def classify_text(text, max_length):
    """Run the sentiment model on ``text`` and return its class probabilities.

    Args:
        text: Text to classify; passed straight to the module-level
            ``tokenizer``.
        max_length: Maximum token count; the tokenizer truncates longer
            inputs.

    Returns:
        The softmax of the model's logits as plain Python floats, with
        size-1 dimensions squeezed away.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.softmax(logits, dim=-1).squeeze().tolist()
# Function to get text content from a URL, restricted to Medium stories/articles
def get_text_from_url(url):
    """Download a Medium page and extract its main text.

    Args:
        url: Address entered by the user.

    Returns:
        A ``(text, error_message)`` pair. On success ``error_message`` is
        ``None``; on any failure ``text`` is ``None`` and ``error_message``
        describes the problem.
    """
    if not validators.url(url):
        return None, "Invalid URL"

    # Compare the parsed host rather than searching the raw string, so that
    # "medium.com" appearing in a path or query string does not pass.
    host = (urlparse(url).hostname or "").lower()
    if host != "medium.com" and not host.endswith(".medium.com"):
        return None, "URL is not a Medium story/article."

    try:
        downloaded = trafilatura.fetch_url(url)
        if not downloaded:
            return None, "Could not download content from URL."
        extracted = trafilatura.extract(downloaded)
    except Exception as e:  # boundary: report any fetch/extract failure to the caller
        return None, f"Error extracting text: {e}"

    # Extraction can yield nothing; report that instead of returning (None, None).
    if not extracted:
        return None, "Could not extract text from URL."
    return extracted, None
# Function to handle long texts
def classify_long_text(text):
    """Classify ``text`` piece by piece and average the per-piece scores.

    The text is cut into consecutive slices of ``tokenizer.model_max_length``
    characters and each slice is scored with ``classify_text``.

    Args:
        text: Full text to classify.

    Returns:
        A tuple ``(aggregate_scores, chunk_scores_list, chunks)``: the mean
        score per sentiment class over all chunks, the score list of every
        chunk, and the chunk strings (the last two in matching order).
        For empty ``text`` the aggregate is all zeros and both lists are
        empty.
    """
    max_length = tokenizer.model_max_length
    # Split the text into chunks. Slices are measured in characters, while
    # max_length is also passed to classify_text as the token limit.
    chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
    if not chunks:
        # Nothing to classify; avoid dividing by zero below.
        return [0.0] * len(sentiment_labels), [], []

    # Keep every chunk's scores so callers can show per-chunk results.
    chunk_scores_list = [classify_text(chunk, max_length) for chunk in chunks]

    # Average the scores class by class.
    aggregate_scores = [
        sum(class_scores) / len(chunks)
        for class_scores in zip(*chunk_scores_list)
    ]
    return aggregate_scores, chunk_scores_list, chunks
# Function to classify each sentence in the text
def classify_sentences(text):
    """Label every sentence of ``text`` with its highest-scoring sentiment.

    Args:
        text: Text to split with ``sent_tokenize`` and classify.

    Returns:
        A list of ``(sentence, sentiment)`` tuples in sentence order, where
        ``sentiment`` is the ``sentiment_labels`` entry for the class with
        the largest score.
    """
    max_length = tokenizer.model_max_length
    sentence_scores = []
    for sentence in sent_tokenize(text):
        scores = classify_text(sentence, max_length)
        # Index of the largest score (first one on ties).
        top_idx = max(range(len(scores)), key=scores.__getitem__)
        sentence_scores.append((sentence, sentiment_labels[top_idx]))
    return sentence_scores
# Streamlit UI

# Exact left-to-right order of the labels in every chart.
sentiment_order = [
    "very positive", "positive", "somewhat positive",
    "neutral",
    "somewhat negative", "negative", "very negative",
]


def _sentiment_bar_chart(scores):
    """Build an Altair bar chart of likelihood per sentiment label.

    Args:
        scores: Sequence of scores indexed by the keys of ``sentiment_labels``.

    Returns:
        An Altair bar chart with the labels on the x axis, in
        ``sentiment_order``, and the scores on the y axis.
    """
    scores_by_label = {label: scores[idx] for idx, label in sentiment_labels.items()}
    chart_data = pd.DataFrame({
        "index": sentiment_order,
        "Likelihood": [scores_by_label[label] for label in sentiment_order],
    })
    return alt.Chart(chart_data).mark_bar().encode(
        x=alt.X('index', sort=sentiment_order, title='Sentiment'),
        y=alt.Y('Likelihood'),
    )


st.title("Sentiment Classification Model (Medium Only)")
url = st.text_input("Enter Medium URL:")
if url:
    text, error_message = get_text_from_url(url)
    if error_message:
        st.error(error_message)  # Display error message
    elif text:
        scores, chunk_scores_list, chunks = classify_long_text(text)

        # Overall sentiment distribution for the whole text
        st.altair_chart(_sentiment_bar_chart(scores), use_container_width=True)

        # Display each chunk and its own chart
        for i, (chunk_scores, chunk) in enumerate(zip(chunk_scores_list, chunks), start=1):
            st.write(f"Chunk {i}:")
            st.write(chunk)
            st.altair_chart(_sentiment_bar_chart(chunk_scores), use_container_width=True)

        # Sentence-level classification with background colors
        st.write("Extracted Text with Sentiment Highlights:")
        for sentence, sentiment in classify_sentences(text):
            bg_color = background_colors[sentiment]
            # The sentence comes from a downloaded web page: escape it before
            # embedding it in raw HTML.
            st.markdown(
                f'<span style="background-color: {bg_color}">{html.escape(sentence)}</span>',
                unsafe_allow_html=True,
            )
    else:
        # Neither text nor an error came back; tell the user instead of showing nothing.
        st.warning("No text could be extracted from the URL.")