marianeft committed
Commit f90a027 · verified · 1 Parent(s): d9b5c9b

Update app.py

Files changed (1):
  1. app.py +36 -5
app.py CHANGED

@@ -14,7 +14,7 @@ import nltk
 nltk.download('punkt')
 
 # Load model and tokenizer
-model_name = 'dejanseo/sentiment'
+model_name = 'dejanseo/sentiment' #Load model adapted from
 model = AutoModelForSequenceClassification.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -40,6 +40,15 @@ background_colors = {
     "very negative": "rgba(255, 0, 0, 0.5)"
 }
 
+# Function to classify text and return sentiment scores
+def classify_text(text, max_length):
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    logits = outputs.logits
+    probabilities = torch.softmax(logits, dim=-1).squeeze().tolist()
+    return probabilities
+
 # Function to get text content from a URL, restricted to Medium stories/articles
 def get_text_from_url(url):
     if not validators.url(url):
@@ -57,7 +66,31 @@ def get_text_from_url(url):
     except Exception as e:
         return None, f"Error extracting text: {e}"
 
-# ... (rest of the functions: classify_text, classify_long_text, classify_sentences remain the same)
+# Function to handle long texts
+def classify_long_text(text):
+    max_length = tokenizer.model_max_length
+    # Split the text into chunks
+    chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
+    aggregate_scores = [0] * len(sentiment_labels)
+    chunk_scores_list = []
+    for chunk in chunks:
+        chunk_scores = classify_text(chunk, max_length)
+        chunk_scores_list.append(chunk_scores)
+        aggregate_scores = [x + y for x, y in zip(aggregate_scores, chunk_scores)]
+    # Average the scores
+    aggregate_scores = [x / len(chunks) for x in aggregate_scores]
+    return aggregate_scores, chunk_scores_list, chunks
+
+# Function to classify each sentence in the text
+def classify_sentences(text):
+    sentences = sent_tokenize(text)
+    sentence_scores = []
+    for sentence in sentences:
+        scores = classify_text(sentence, tokenizer.model_max_length)
+        sentiment_idx = scores.index(max(scores))
+        sentiment = sentiment_labels[sentiment_idx]
+        sentence_scores.append((sentence, sentiment))
+    return sentence_scores
 
 # Streamlit UI
 st.title("Sentiment Classification Model (Medium Only)")
@@ -111,9 +144,7 @@ if url:
         )
 
         st.write(f"Chunk {i + 1}:")
-        st.write(chunk)
-        st.altair_chart(chunk_chart, use_container_width=True)
-
+
     # Sentence-level classification with background colors
     st.write("Extracted Text with Sentiment Highlights:")
     sentence_scores = classify_sentences(text)
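
For reference, the helpers introduced in this commit compose as a simple pipeline: classify_text() returns per-class probabilities, classify_long_text() averages them over chunks, and classify_sentences() labels each sentence. The sketch below exercises classify_text() and the sentence-level pass outside Streamlit. It is a minimal example under stated assumptions, not the app itself: the sentiment_labels list does not appear in this diff, so the five labels and their order are inferred from the background_colors keys (the authoritative mapping is the model config's id2label), and the sample sentence is invented.

import nltk
import torch
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForSequenceClassification, AutoTokenizer

nltk.download('punkt')

# Same model name as app.py
model_name = 'dejanseo/sentiment'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Assumed label set and order, inferred from the background_colors keys above;
# check model.config.id2label for the actual ordering.
sentiment_labels = ["very positive", "positive", "neutral", "negative", "very negative"]

def classify_text(text, max_length):
    # Tokenize, run a forward pass, and return softmax probabilities per class.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.softmax(outputs.logits, dim=-1).squeeze().tolist()

sample = "The article was thoughtful and well written, though the ending felt rushed."

# Document-level scores
scores = classify_text(sample, tokenizer.model_max_length)
print(dict(zip(sentiment_labels, scores)))

# Sentence-level pass, mirroring classify_sentences() from the commit
for sentence in sent_tokenize(sample):
    probs = classify_text(sentence, tokenizer.model_max_length)
    print(sentence, "->", sentiment_labels[probs.index(max(probs))])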