ylingag commited on
Commit
c726c9b
·
verified ·
1 Parent(s): 5c7aa28

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +315 -0
  2. requirements.txt +9 -3
app.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ import torch
5
+ import nltk
6
+ import spacy
7
+ import re
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoModelForSeq2SeqLM
9
+
10
+ # Download necessary NLTK data for sentence tokenization
11
+ try:
12
+ nltk.data.find('tokenizers/punkt')
13
+ except LookupError:
14
+ nltk.download('punkt')
15
+
16
+ SUMMARY_FILE = "training_summary.json"
17
+ # Assume label meanings are consistent with previous files
18
+ LABEL_MAP = {0: "Negative", 1: "Neutral", 2: "Positive"}
19
+ # Color coding for sentiment
20
+ COLOR_MAP = {
21
+ "Negative": "red",
22
+ "Neutral": "blue",
23
+ "Positive": "green"
24
+ }
25
+
26
+ # Global loading of models and NLP components
27
+ loaded_model = None
28
+ loaded_tokenizer = None
29
+ best_model_summary = None
30
+ summarizer = None
31
+ nlp = None # For NER
32
+
33
+ def load_models_and_components():
34
+ global loaded_model, loaded_tokenizer, best_model_summary, summarizer, nlp
35
+
36
+ # Load sentiment analysis model from training
37
+ if not os.path.exists(SUMMARY_FILE):
38
+ raise FileNotFoundError(f"Error: Could not find training summary file {SUMMARY_FILE}. Please run the fine-tuning and testing scripts first.")
39
+
40
+ with open(SUMMARY_FILE, 'r') as f:
41
+ summary_data = json.load(f)
42
+
43
+ if "best_model_details" not in summary_data or not summary_data["best_model_details"]:
44
+ raise ValueError(f"Error: Best model information not found or incomplete in {SUMMARY_FILE}.")
45
+
46
+ best_model_summary = summary_data["best_model_details"]
47
+ best_model_path = best_model_summary.get("best_model_path")
48
+
49
+ if not best_model_path:
50
+ best_model_path = summary_data.get("best_model_path") # Compatible with older format
51
+
52
+ if not best_model_path or not os.path.exists(best_model_path):
53
+ raise FileNotFoundError(f"Error: Best model path {best_model_path} not found or invalid.")
54
+
55
+ print(f"Loading sentiment model {best_model_summary['model_name']} from {best_model_path}...")
56
+ try:
57
+ loaded_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
58
+ loaded_model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
59
+ loaded_model.eval() # Set to evaluation mode
60
+ print("Sentiment model loaded successfully.")
61
+ except Exception as e:
62
+ raise RuntimeError(f"Failed to load sentiment model: {e}")
63
+
64
+ # Load summarization model
65
+ print("Loading summarization model...")
66
+ try:
67
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
68
+ print("Summarization model loaded successfully.")
69
+ except Exception as e:
70
+ print(f"Warning: Failed to load summarization model: {e}")
71
+ print("Will continue without summarization capability.")
72
+ summarizer = None
73
+
74
+ # Load spaCy model for NER (Named Entity Recognition)
75
+ print("Loading NER model...")
76
+ try:
77
+ # Download the model if it's not already downloaded
78
+ if not spacy.util.is_package("en_core_web_sm"):
79
+ spacy.cli.download("en_core_web_sm")
80
+ nlp = spacy.load("en_core_web_sm")
81
+ print("NER model loaded successfully.")
82
+ except Exception as e:
83
+ print(f"Warning: Failed to load NER model: {e}")
84
+ print("Will continue without NER capability.")
85
+ nlp = None
86
+
87
+ def predict_sentiment(text):
88
+ """Predict sentiment for a single piece of text"""
89
+ global loaded_model, loaded_tokenizer
90
+ if not loaded_model or not loaded_tokenizer:
91
+ return "Error: Model not loaded.", None
92
+
93
+ if not text or not text.strip():
94
+ return "Please enter text for analysis.", None
95
+
96
+ try:
97
+ inputs = loaded_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
98
+
99
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100
+ loaded_model.to(device)
101
+ inputs = {k: v.to(device) for k, v in inputs.items()}
102
+
103
+ with torch.no_grad():
104
+ outputs = loaded_model(**inputs)
105
+
106
+ prediction_idx = torch.argmax(outputs.logits, dim=-1).item()
107
+ sentiment = LABEL_MAP.get(prediction_idx, f"Unknown ({prediction_idx})")
108
+ return sentiment, prediction_idx
109
+ except Exception as e:
110
+ print(f"Error during sentiment prediction: {e}")
111
+ return f"Error: {str(e)}", None
112
+
113
+ def generate_summary(text):
114
+ """Generate a summary for longer text"""
115
+ global summarizer
116
+ if not summarizer:
117
+ return "Summarization model not available."
118
+
119
+ if not text or len(text.strip()) < 50:
120
+ return "Text too short for summarization."
121
+
122
+ try:
123
+ # BART has a max length, so we'll truncate if needed
124
+ max_length = min(1024, len(text.split()))
125
+ summary = summarizer(text, max_length=max_length//4, min_length=30, do_sample=False)
126
+ return summary[0]['summary_text']
127
+ except Exception as e:
128
+ print(f"Error during summarization: {e}")
129
+ return f"Summarization error: {str(e)}"
130
+
131
+ def identify_entities(text):
132
+ """Identify locations and organizations in the text"""
133
+ global nlp
134
+ if not nlp:
135
+ return "NER model not available."
136
+
137
+ if not text or not text.strip():
138
+ return "Please enter text for entity analysis."
139
+
140
+ try:
141
+ doc = nlp(text)
142
+ locations = []
143
+ organizations = []
144
+
145
+ for ent in doc.ents:
146
+ if ent.label_ == "GPE" or ent.label_ == "LOC": # Geopolitical entity or Location
147
+ locations.append(ent.text)
148
+ elif ent.label_ == "ORG": # Organization
149
+ organizations.append(ent.text)
150
+
151
+ # Remove duplicates and sort
152
+ locations = sorted(list(set(locations)))
153
+ organizations = sorted(list(set(organizations)))
154
+
155
+ return {
156
+ "locations": locations,
157
+ "organizations": organizations
158
+ }
159
+ except Exception as e:
160
+ print(f"Error during entity identification: {e}")
161
+ return f"Entity identification error: {str(e)}"
162
+
163
+ def format_entities(entities):
164
+ """Format identified entities for display"""
165
+ if isinstance(entities, str): # Error message
166
+ return entities
167
+
168
+ formatted = "<h3>Interested Parties</h3>"
169
+
170
+ # Add locations in red
171
+ if entities["locations"]:
172
+ formatted += "<p><b>Locations:</b> "
173
+ formatted += ", ".join([f"<span style='color: red'>{loc}</span>" for loc in entities["locations"]])
174
+ formatted += "</p>"
175
+ else:
176
+ formatted += "<p><b>Locations:</b> None identified</p>"
177
+
178
+ # Add organizations in green
179
+ if entities["organizations"]:
180
+ formatted += "<p><b>Organizations:</b> "
181
+ formatted += ", ".join([f"<span style='color: green'>{org}</span>" for org in entities["organizations"]])
182
+ formatted += "</p>"
183
+ else:
184
+ formatted += "<p><b>Organizations:</b> None identified</p>"
185
+
186
+ return formatted
187
+
188
+ def analyze_text_sentiment_by_sentence(text):
189
+ """Analyze sentiment of each sentence in the text and format with colors"""
190
+ if not text or not text.strip():
191
+ return "Please enter text for analysis."
192
+
193
+ try:
194
+ # Split text into sentences
195
+ sentences = nltk.sent_tokenize(text)
196
+ formatted_result = ""
197
+
198
+ for sentence in sentences:
199
+ if len(sentence.strip()) < 3: # Skip very short sentences
200
+ continue
201
+
202
+ sentiment, _ = predict_sentiment(sentence)
203
+ color = COLOR_MAP.get(sentiment, "black")
204
+
205
+ formatted_result += f"<span style='color: {color}'>{sentence}</span> "
206
+
207
+ return formatted_result if formatted_result else "No valid sentences found for analysis."
208
+ except Exception as e:
209
+ print(f"Error during sentence-level sentiment analysis: {e}")
210
+ return f"Error: {str(e)}"
211
+
212
+ def analyze_financial_text(text):
213
+ """Master function that performs all analysis tasks"""
214
+ if not text or not text.strip():
215
+ return "Please enter text for analysis.", "No summary available.", "No entities identified."
216
+
217
+ # Generate summary
218
+ summary = generate_summary(text)
219
+
220
+ # Perform sentence-level sentiment analysis
221
+ sentiment_analysis = analyze_text_sentiment_by_sentence(text)
222
+
223
+ # Identify entities
224
+ entities = identify_entities(text)
225
+ formatted_entities = format_entities(entities)
226
+
227
+ return sentiment_analysis, summary, formatted_entities
228
+
229
+ # Try to load models at app startup
230
+ try:
231
+ load_models_and_components()
232
+ except Exception as e:
233
+ print(f"Initial model loading failed: {e}")
234
+ # Gradio interface will still start, but functionality will be limited
235
+
236
+ # Build Gradio interface
237
+ model_info = "### Model Information\n"
238
+ if best_model_summary:
239
+ model_name = best_model_summary.get("model_name", "N/A")
240
+ accuracy = best_model_summary.get("accuracy_percent", "N/A")
241
+ run_time = best_model_summary.get("run_time_sec", "N/A")
242
+ hyperparams = best_model_summary.get("hyperparameters", {})
243
+
244
+ model_info += f"- **Model Name**: {model_name}\n"
245
+ model_info += f"- **Model Accuracy**: {accuracy}%\n"
246
+ model_info += f"- **Description**: The model is trained and fine-tuned using the financial news dataset to improve its sensitivity in recognizing financial sentiment.\n"
247
+
248
+ # Add hyperparameters
249
+ model_info += "\n### Hyperparameters\n"
250
+ model_info += f"- **Learning Rate**: {hyperparams.get('learning_rate', 'N/A')}\n"
251
+ model_info += f"- **Batch Size**: {hyperparams.get('batch_size', 'N/A')}\n"
252
+ model_info += f"- **Number of Epochs**: {hyperparams.get('num_epochs', 'N/A')}\n"
253
+ else:
254
+ model_info += "Model information loading failed. Please check the `training_summary.json` file and backend logs."
255
+
256
+ # Gradio interface definition
257
+ app_title = "ISOM5240_financial_tone"
258
+ app_description = (
259
+ "Analyze financial news text to extract summary, sentiment, and identify interested parties. "
260
+ "The sentiment analysis model is fine-tuned on financial news data."
261
+ )
262
+
263
+ with gr.Blocks(title=app_title) as iface:
264
+ gr.Markdown(f"# {app_title}")
265
+ gr.Markdown(app_description)
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=2):
269
+ input_text = gr.Textbox(
270
+ lines=10,
271
+ label="Financial News Text",
272
+ placeholder="Enter a longer financial news text here for analysis..."
273
+ )
274
+ analyze_btn = gr.Button("Start Analysis", variant="primary")
275
+
276
+ with gr.Column(scale=1):
277
+ gr.Markdown(model_info)
278
+
279
+ with gr.Row():
280
+ with gr.Column():
281
+ gr.Markdown("### Text Summary")
282
+ summary_output = gr.Textbox(label="Summary", lines=3)
283
+
284
+ with gr.Row():
285
+ with gr.Column():
286
+ gr.Markdown("### Sentiment Analysis (Sentence-level)")
287
+ gr.Markdown("- <span style='color: green'>Green</span>: Positive")
288
+ gr.Markdown("- <span style='color: blue'>Blue</span>: Neutral")
289
+ gr.Markdown("- <span style='color: red'>Red</span>: Negative")
290
+ sentiment_output = gr.HTML(label="Sentiment")
291
+
292
+ with gr.Row():
293
+ with gr.Column():
294
+ entities_output = gr.HTML(label="Interested Parties")
295
+
296
+ # Set up the click event for the analyze button
297
+ analyze_btn.click(
298
+ fn=analyze_financial_text,
299
+ inputs=[input_text],
300
+ outputs=[sentiment_output, summary_output, entities_output]
301
+ )
302
+
303
+ # Add examples
304
+ gr.Examples(
305
+ [
306
+ ["The Federal Reserve announced today that interest rates will remain unchanged. Markets responded positively, with the S&P 500 gaining 1.2%. However, smaller tech companies in Silicon Valley expressed concerns about potential future rate hikes affecting their access to capital."],
307
+ ["Apple Inc. reported record quarterly revenue of $91.8 billion, an increase of 9% from the year-ago quarter. The company's CEO Tim Cook attributed this success to strong international sales, particularly in European markets and China. However, supply chain disruptions in Taiwan may impact future quarters."]
308
+ ],
309
+ inputs=input_text
310
+ )
311
+
312
+ if __name__ == "__main__":
313
+ print("Starting Gradio application...")
314
+ # share=True will generate a public link
315
+ iface.launch(share=True)
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
1
+ transformers
2
+ datasets
3
+ torch
4
+ evaluate
5
+ scikit-learn
6
+ gradio
7
+ accelerate
8
+ nltk
9
+ spacy