import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pandas as pd
import altair as alt
from collections import OrderedDict
from nltk.tokenize import sent_tokenize
import trafilatura
import validators

# Download the punkt sentence tokenizer data for nltk
# (newer NLTK releases may additionally require the 'punkt_tab' package)
import nltk
nltk.download('punkt')
# Load model and tokenizer
model_name = 'dejanseo/sentiment'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Sentiment labels as textual descriptions
sentiment_labels = {
    0: "very positive",
    1: "positive",
    2: "somewhat positive",
    3: "neutral",
    4: "somewhat negative",
    5: "negative",
    6: "very negative"
}
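# Optional sanity check (sketch): the checkpoint should expose one logit per
# label above; transformers reports this via model.config.num_labels:
#   assert model.config.num_labels == len(sentiment_labels)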
# Background colors for sentiments (RGBA; alpha scales with intensity)
background_colors = {
    "very positive": "rgba(0, 255, 0, 0.5)",
    "positive": "rgba(0, 255, 0, 0.3)",
    "somewhat positive": "rgba(0, 255, 0, 0.1)",
    "neutral": "rgba(128, 128, 128, 0.1)",
    "somewhat negative": "rgba(255, 0, 0, 0.1)",
    "negative": "rgba(255, 0, 0, 0.3)",
    "very negative": "rgba(255, 0, 0, 0.5)"
}
# Function to classify text and return sentiment scores
def classify_text(text, max_length):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1).squeeze().tolist()
    return probabilities
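# Usage sketch (illustrative): the returned list is a probability distribution
# over the seven classes, indexed to match sentiment_labels:
#   probs = classify_text("A short review sentence.", tokenizer.model_max_length)
#   top = sentiment_labels[probs.index(max(probs))]  # most likely sentiment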
# Function to get text content from a URL, restricted to Medium stories/articles
def get_text_from_url(url):
    if not validators.url(url):
        return None, "Invalid URL"
    if "medium.com/" not in url:  # Loose substring check for Medium URLs (also matches subdomains such as author.medium.com)
        return None, "URL is not a Medium story/article."
    try:
        downloaded = trafilatura.fetch_url(url)
        if downloaded:
            return trafilatura.extract(downloaded), None
        else:
            return None, "Could not download content from URL."
    except Exception as e:
        return None, f"Error extracting text: {e}"
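# Illustrative behaviour (return values follow the branches above):
#   get_text_from_url("https://medium.com/@author/story")  -> (extracted_text, None)
#   get_text_from_url("https://example.com/post")          -> (None, "URL is not a Medium story/article.")
#   get_text_from_url("not a url")                         -> (None, "Invalid URL")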
# Function to handle long texts by chunking, classifying each chunk, and averaging
def classify_long_text(text):
    max_length = tokenizer.model_max_length
    # Split the text into fixed-size character chunks (note: max_length counts
    # tokens, so this character slice only approximates the model's limit)
    chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
    aggregate_scores = [0] * len(sentiment_labels)
    chunk_scores_list = []
    for chunk in chunks:
        chunk_scores = classify_text(chunk, max_length)
        chunk_scores_list.append(chunk_scores)
        aggregate_scores = [x + y for x, y in zip(aggregate_scores, chunk_scores)]
    # Average the per-chunk scores into a document-level distribution
    aggregate_scores = [x / len(chunks) for x in aggregate_scores]
    return aggregate_scores, chunk_scores_list, chunks
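# A token-aware alternative (sketch, using the standard tokenizer API) would
# chunk on input IDs instead, reserving room for special tokens:
#   ids = tokenizer(text, add_special_tokens=False)["input_ids"]
#   chunks = [tokenizer.decode(ids[i:i + max_length - 2])
#             for i in range(0, len(ids), max_length - 2)]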
# Function to classify each sentence in the text
def classify_sentences(text):
    sentences = sent_tokenize(text)
    sentence_scores = []
    for sentence in sentences:
        scores = classify_text(sentence, tokenizer.model_max_length)
        sentiment_idx = scores.index(max(scores))
        sentiment = sentiment_labels[sentiment_idx]
        sentence_scores.append((sentence, sentiment))
    return sentence_scores
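# Illustrative return shape:
#   [("First sentence.", "positive"), ("Second sentence.", "neutral"), ...]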
# Streamlit UI
st.title("Sentiment Classification Model (Medium Only)")

url = st.text_input("Enter Medium URL:")

if url:
    text, error_message = get_text_from_url(url)
    if error_message:
        st.error(error_message)  # Display error message
    elif text:
        scores, chunk_scores_list, chunks = classify_long_text(text)
        scores_dict = {sentiment_labels[i]: scores[i] for i in range(len(sentiment_labels))}
        # Ensure the exact order of labels in the graph
        sentiment_order = [
            "very positive", "positive", "somewhat positive",
            "neutral",
            "somewhat negative", "negative", "very negative"
        ]
        ordered_scores_dict = OrderedDict((label, scores_dict[label]) for label in sentiment_order)

        # Prepare the DataFrame and reindex to the fixed label order
        df = pd.DataFrame.from_dict(ordered_scores_dict, orient='index', columns=['Likelihood']).reindex(sentiment_order)

        # Use Altair to plot the document-level bar chart
        chart = alt.Chart(df.reset_index()).mark_bar().encode(
            x=alt.X('index', sort=sentiment_order, title='Sentiment'),
            y='Likelihood'
        ).properties(
            width=600,
            height=400
        )
        st.altair_chart(chart, use_container_width=True)
        # Display each chunk and its own chart
        for i, (chunk_scores, chunk) in enumerate(zip(chunk_scores_list, chunks)):
            chunk_scores_dict = {sentiment_labels[j]: chunk_scores[j] for j in range(len(sentiment_labels))}
            ordered_chunk_scores_dict = OrderedDict((label, chunk_scores_dict[label]) for label in sentiment_order)
            df_chunk = pd.DataFrame.from_dict(ordered_chunk_scores_dict, orient='index', columns=['Likelihood']).reindex(sentiment_order)
            chunk_chart = alt.Chart(df_chunk.reset_index()).mark_bar().encode(
                x=alt.X('index', sort=sentiment_order, title='Sentiment'),
                y='Likelihood'
            ).properties(
                width=600,
                height=400
            )
            st.write(f"Chunk {i + 1}:")
            st.altair_chart(chunk_chart, use_container_width=True)  # Render the per-chunk chart
        # Sentence-level classification with background colors
        st.write("Extracted Text with Sentiment Highlights:")
        sentence_scores = classify_sentences(text)
        for sentence, sentiment in sentence_scores:
            bg_color = background_colors[sentiment]
            st.markdown(f'<span style="background-color: {bg_color}">{sentence}</span>', unsafe_allow_html=True)
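# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py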