Upload 2 files
Browse files- app.py +315 -0
- requirements.txt +9 -3
app.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import nltk
|
6 |
+
import spacy
|
7 |
+
import re
|
8 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoModelForSeq2SeqLM
|
9 |
+
|
10 |
+
# Make sure the NLTK sentence-tokenizer data is available before it is used.
# Recent NLTK releases load the "punkt_tab" resource in sent_tokenize, while
# older releases use the pickled "punkt" models, so both are fetched if missing.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)
|
15 |
+
|
16 |
+
# JSON report that load_models_and_components() reads to locate the best model.
SUMMARY_FILE = "training_summary.json"

# Class index -> sentiment name (assumed consistent with the training scripts).
LABEL_MAP = dict(enumerate(("Negative", "Neutral", "Positive")))

# Sentiment name -> CSS colour used when rendering analysed sentences.
COLOR_MAP = dict(Negative="red", Neutral="blue", Positive="green")

# Shared handles, filled in by load_models_and_components(); None means
# "not loaded".
loaded_model = loaded_tokenizer = None
best_model_summary = None
summarizer = None
nlp = None  # spaCy pipeline used for NER
|
32 |
+
|
33 |
+
def load_models_and_components():
    """Load the sentiment model, the summarizer and the NER pipeline into module globals.

    The sentiment model is mandatory: its location is read from ``SUMMARY_FILE``
    and any problem with that file or with the model directory raises. The
    summarization and spaCy NER models are optional; if either fails to load,
    a warning is printed and the matching global is left as ``None``.

    Raises:
        FileNotFoundError: ``SUMMARY_FILE`` or the best-model path does not exist.
        ValueError: ``SUMMARY_FILE`` has no usable ``best_model_details`` entry.
        RuntimeError: the tokenizer or classification model could not be loaded.
    """
    global loaded_model, loaded_tokenizer, best_model_summary, summarizer, nlp

    # Load sentiment analysis model from training
    if not os.path.exists(SUMMARY_FILE):
        raise FileNotFoundError(f"Error: Could not find training summary file {SUMMARY_FILE}. Please run the fine-tuning and testing scripts first.")

    # Explicit encoding so the result does not depend on the platform default.
    with open(SUMMARY_FILE, 'r', encoding="utf-8") as f:
        summary_data = json.load(f)

    if "best_model_details" not in summary_data or not summary_data["best_model_details"]:
        raise ValueError(f"Error: Best model information not found or incomplete in {SUMMARY_FILE}.")

    best_model_summary = summary_data["best_model_details"]
    best_model_path = best_model_summary.get("best_model_path")

    if not best_model_path:
        best_model_path = summary_data.get("best_model_path")  # Compatible with older format

    if not best_model_path or not os.path.exists(best_model_path):
        raise FileNotFoundError(f"Error: Best model path {best_model_path} not found or invalid.")

    # .get() keeps a missing "model_name" key from aborting the load just
    # because of a log message.
    print(f"Loading sentiment model {best_model_summary.get('model_name', 'N/A')} from {best_model_path}...")
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
        loaded_model.eval()  # Set to evaluation mode
        print("Sentiment model loaded successfully.")
    except Exception as e:
        # Chain the original exception so its traceback is not lost.
        raise RuntimeError(f"Failed to load sentiment model: {e}") from e

    # Load summarization model (optional component)
    print("Loading summarization model...")
    try:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        print("Summarization model loaded successfully.")
    except Exception as e:
        print(f"Warning: Failed to load summarization model: {e}")
        print("Will continue without summarization capability.")
        summarizer = None

    # Load spaCy model for NER (Named Entity Recognition) (optional component)
    print("Loading NER model...")
    try:
        # Download the model if it's not already downloaded
        if not spacy.util.is_package("en_core_web_sm"):
            spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
        print("NER model loaded successfully.")
    except Exception as e:
        print(f"Warning: Failed to load NER model: {e}")
        print("Will continue without NER capability.")
        nlp = None
|
86 |
+
|
87 |
+
def predict_sentiment(text):
    """Classify one piece of text with the loaded sentiment model.

    Returns a ``(label, index)`` pair. ``label`` is the ``LABEL_MAP`` name of
    the highest-scoring class, or ``"Unknown (<index>)"`` when the index is
    not in the map. If the model is not loaded, the text is blank, or
    inference fails, ``label`` is a message string and ``index`` is ``None``.
    """
    global loaded_model, loaded_tokenizer

    if not (loaded_model and loaded_tokenizer):
        return "Error: Model not loaded.", None

    if not (text and text.strip()):
        return "Please enter text for analysis.", None

    try:
        encoded = loaded_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )

        # Run on the GPU when one is available, otherwise on the CPU.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_model.to(target_device)
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = loaded_model(**encoded).logits

        prediction_idx = torch.argmax(logits, dim=-1).item()
        return LABEL_MAP.get(prediction_idx, f"Unknown ({prediction_idx})"), prediction_idx
    except Exception as exc:
        print(f"Error during sentiment prediction: {exc}")
        return f"Error: {str(exc)}", None
|
112 |
+
|
113 |
+
def generate_summary(text):
    """Generate a summary for longer text.

    Returns the summary text, or a message string when the summarizer is not
    loaded, the stripped text is shorter than 50 characters, or the
    summarization call raises.
    """
    global summarizer
    if not summarizer:
        return "Summarization model not available."

    if not text or len(text.strip()) < 50:
        return "Text too short for summarization."

    min_summary_length = 30
    try:
        # Aim for about a quarter of the (capped) word count, but never less
        # than min_length: the original max_length // 4 dropped below
        # min_length for inputs under 120 words, giving contradictory limits.
        capped_word_count = min(1024, len(text.split()))
        max_summary_length = max(capped_word_count // 4, min_summary_length)

        summary = summarizer(
            text,
            max_length=max_summary_length,
            min_length=min_summary_length,
            do_sample=False,
            # Clip inputs longer than the model's input limit instead of
            # failing; the original promised truncation but never requested it.
            truncation=True,
        )
        return summary[0]['summary_text']
    except Exception as e:
        print(f"Error during summarization: {e}")
        return f"Summarization error: {str(e)}"
|
130 |
+
|
131 |
+
def identify_entities(text):
    """Collect the location and organization names mentioned in the text.

    Returns a dict with ``"locations"`` (entities labelled GPE or LOC) and
    ``"organizations"`` (entities labelled ORG), each a sorted list without
    duplicates. A message string is returned instead when the NER pipeline is
    not loaded, the text is blank, or processing raises.
    """
    global nlp
    if not nlp:
        return "NER model not available."

    if not (text and text.strip()):
        return "Please enter text for entity analysis."

    try:
        found_entities = nlp(text).ents

        # Sets drop repeated mentions; sorted() turns them into ordered lists.
        location_names = {ent.text for ent in found_entities if ent.label_ in ("GPE", "LOC")}
        organization_names = {ent.text for ent in found_entities if ent.label_ == "ORG"}

        return {
            "locations": sorted(location_names),
            "organizations": sorted(organization_names),
        }
    except Exception as exc:
        print(f"Error during entity identification: {exc}")
        return f"Entity identification error: {str(exc)}"
|
162 |
+
|
163 |
+
def format_entities(entities):
    """Format identified entities as an HTML fragment for display.

    Args:
        entities: either a message string, which is returned unchanged, or a
            dict with ``"locations"`` and ``"organizations"`` lists of names.

    Returns:
        An HTML string with an "Interested Parties" heading, locations in red
        and organizations in green; an empty list is shown as
        "None identified".
    """
    from html import escape

    if isinstance(entities, str):  # Error message
        return entities

    parts = ["<h3>Interested Parties</h3>"]

    sections = (
        ("Locations", entities["locations"], "red"),
        ("Organizations", entities["organizations"], "green"),
    )
    for title, names, color in sections:
        if names:
            # Entity names come from user-supplied text, so escape them before
            # embedding them in markup.
            spans = ", ".join(
                f"<span style='color: {color}'>{escape(name)}</span>" for name in names
            )
            parts.append(f"<p><b>{title}:</b> {spans}</p>")
        else:
            parts.append(f"<p><b>{title}:</b> None identified</p>")

    return "".join(parts)
|
187 |
+
|
188 |
+
def analyze_text_sentiment_by_sentence(text):
    """Analyze the sentiment of each sentence and return colour-coded HTML.

    The text is split with ``nltk.sent_tokenize``; each sentence of at least
    three non-blank characters is classified with ``predict_sentiment`` and
    wrapped in a ``<span>`` coloured through ``COLOR_MAP`` (black when the
    result is not a known sentiment name).

    Returns:
        The concatenated HTML, or a message string when the text is blank, no
        sentence qualifies, or processing raises.
    """
    from html import escape

    if not text or not text.strip():
        return "Please enter text for analysis."

    try:
        # Split text into sentences
        sentences = nltk.sent_tokenize(text)
        spans = []

        for sentence in sentences:
            if len(sentence.strip()) < 3:  # Skip very short sentences
                continue

            sentiment, _ = predict_sentiment(sentence)
            color = COLOR_MAP.get(sentiment, "black")

            # The sentence is raw user input: escape it so it cannot inject
            # markup into the HTML output.
            spans.append(f"<span style='color: {color}'>{escape(sentence)}</span> ")

        formatted_result = "".join(spans)
        return formatted_result if formatted_result else "No valid sentences found for analysis."
    except Exception as e:
        print(f"Error during sentence-level sentiment analysis: {e}")
        return f"Error: {str(e)}"
|
211 |
+
|
212 |
+
def analyze_financial_text(text):
    """Run the summary, sentence-level sentiment and entity analyses on the text.

    Returns a 3-tuple ``(sentiment_html, summary, entities_html)``. For blank
    input, three fixed placeholder messages are returned and no analysis runs.
    """
    if not (text and text.strip()):
        return (
            "Please enter text for analysis.",
            "No summary available.",
            "No entities identified.",
        )

    # Same order as the individual steps are reported: summary, sentiment, entities.
    summary_text = generate_summary(text)
    sentiment_html = analyze_text_sentiment_by_sentence(text)
    entities_html = format_entities(identify_entities(text))

    return sentiment_html, summary_text, entities_html
|
228 |
+
|
229 |
+
# Load all models once at startup. A failure is only reported here, so the
# Gradio interface is still built afterwards, with limited functionality.
try:
    load_models_and_components()
except Exception as startup_error:
    print(f"Initial model loading failed: {startup_error}")
|
235 |
+
|
236 |
+
# Markdown text describing the loaded sentiment model, shown in the interface.
if best_model_summary:
    model_name = best_model_summary.get("model_name", "N/A")
    accuracy = best_model_summary.get("accuracy_percent", "N/A")
    run_time = best_model_summary.get("run_time_sec", "N/A")
    hyperparams = best_model_summary.get("hyperparameters", {})

    model_info = "".join([
        "### Model Information\n",
        f"- **Model Name**: {model_name}\n",
        f"- **Model Accuracy**: {accuracy}%\n",
        "- **Description**: The model is trained and fine-tuned using the financial news dataset to improve its sensitivity in recognizing financial sentiment.\n",
        # Hyperparameters section
        "\n### Hyperparameters\n",
        f"- **Learning Rate**: {hyperparams.get('learning_rate', 'N/A')}\n",
        f"- **Batch Size**: {hyperparams.get('batch_size', 'N/A')}\n",
        f"- **Number of Epochs**: {hyperparams.get('num_epochs', 'N/A')}\n",
    ])
else:
    # No summary was loaded: show a pointer to the likely cause instead.
    model_info = (
        "### Model Information\n"
        "Model information loading failed. Please check the `training_summary.json` file and backend logs."
    )
|
255 |
+
|
256 |
+
# Gradio interface definition
app_title = "ISOM5240_financial_tone"
app_description = (
    "Analyze financial news text to extract summary, sentiment, and identify interested parties. "
    "The sentiment analysis model is fine-tuned on financial news data."
)

with gr.Blocks(title=app_title) as iface:
    gr.Markdown(f"# {app_title}")
    gr.Markdown(app_description)

    # Top row: text input with its button on the left, model details on the right.
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                lines=10,
                label="Financial News Text",
                placeholder="Enter a longer financial news text here for analysis...",
            )
            analyze_btn = gr.Button("Start Analysis", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown(model_info)

    # Summary output
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Text Summary")
            summary_output = gr.Textbox(label="Summary", lines=3)

    # Sentence-level sentiment output, preceded by the colour legend
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Sentiment Analysis (Sentence-level)")
            for legend_line in (
                "- <span style='color: green'>Green</span>: Positive",
                "- <span style='color: blue'>Blue</span>: Neutral",
                "- <span style='color: red'>Red</span>: Negative",
            ):
                gr.Markdown(legend_line)
            sentiment_output = gr.HTML(label="Sentiment")

    # Entities output
    with gr.Row():
        with gr.Column():
            entities_output = gr.HTML(label="Interested Parties")

    # Wire the button to the analysis; the output order matches the tuple
    # returned by analyze_financial_text.
    analyze_btn.click(
        fn=analyze_financial_text,
        inputs=[input_text],
        outputs=[sentiment_output, summary_output, entities_output],
    )

    # Sample inputs offered to the user
    gr.Examples(
        [
            ["The Federal Reserve announced today that interest rates will remain unchanged. Markets responded positively, with the S&P 500 gaining 1.2%. However, smaller tech companies in Silicon Valley expressed concerns about potential future rate hikes affecting their access to capital."],
            ["Apple Inc. reported record quarterly revenue of $91.8 billion, an increase of 9% from the year-ago quarter. The company's CEO Tim Cook attributed this success to strong international sales, particularly in European markets and China. However, supply chain disruptions in Taiwan may impact future quarters."],
        ],
        inputs=input_text,
    )

if __name__ == "__main__":
    print("Starting Gradio application...")
    # share=True will generate a public link
    iface.launch(share=True)
|
requirements.txt
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
datasets
|
3 |
+
torch
|
4 |
+
evaluate
|
5 |
+
scikit-learn
|
6 |
+
gradio
|
7 |
+
accelerate
|
8 |
+
nltk
|
9 |
+
spacy
|