walaa2022 commited on
Commit
91033f9
·
verified ·
1 Parent(s): eac8dde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -46
app.py CHANGED
@@ -3,8 +3,11 @@ import gradio as gr
3
  import pandas as pd
4
  import torch
5
  import logging
6
- from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
7
  import gc
 
 
 
 
8
 
9
  # Setup logging
10
  logging.basicConfig(
@@ -17,12 +20,34 @@ logger = logging.getLogger(__name__)
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
  logger.info(f"Using device: {DEVICE}")
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  def clear_gpu_memory():
21
  """Utility function to clear GPU memory"""
22
  if DEVICE == "cuda":
23
  torch.cuda.empty_cache()
24
  gc.collect()
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  class ModelManager:
27
  """Handles model loading and inference"""
28
 
@@ -30,29 +55,52 @@ class ModelManager:
30
  self.device = DEVICE
31
  self.models = {}
32
  self.tokenizers = {}
 
 
33
 
34
- def load_model(self, model_name, model_type="sentiment"):
35
- """Load model and tokenizer"""
36
  try:
37
- if model_name not in self.models:
38
- if model_type == "sentiment":
39
- self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)
40
- self.models[model_name] = AutoModelForSequenceClassification.from_pretrained(
41
- model_name,
42
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
43
- ).to(self.device)
44
- else:
45
- self.models[model_name] = pipeline(
46
- "text-generation",
47
- model=model_name,
48
- device_map="auto" if self.device == "cuda" else None,
49
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
50
- )
51
- logger.info(f"Loaded model: {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
  logger.error(f"Error loading model {model_name}: {str(e)}")
54
  raise
55
 
 
 
 
 
 
 
 
56
  def unload_model(self, model_name):
57
  """Unload model and tokenizer"""
58
  try:
@@ -92,29 +140,53 @@ class FinancialAnalyzer:
92
  raise
93
 
94
  def read_csv(self, file_obj):
95
- """Read and validate CSV file"""
96
  try:
97
  if file_obj is None:
98
  raise ValueError("No file provided")
99
 
100
- df = pd.read_csv(file_obj)
 
101
 
102
  if df.empty:
103
  raise ValueError("Empty CSV file")
104
-
105
- return df.describe()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  except Exception as e:
107
  logger.error(f"Error reading CSV: {str(e)}")
108
  raise
109
 
110
-
111
  def analyze_sentiment(self, text):
112
- """Analyze sentiment using FinBERT"""
113
  try:
114
  model_name = self.models["sentiment"]
115
  model = self.model_manager.get_model(model_name)
116
  tokenizer = self.model_manager.get_tokenizer(model_name)
117
 
 
 
 
 
 
 
 
 
 
 
118
  inputs = tokenizer(
119
  text,
120
  return_tensors="pt",
@@ -123,10 +195,12 @@ class FinancialAnalyzer:
123
  padding=True
124
  ).to(DEVICE)
125
 
 
126
  with torch.no_grad():
127
  outputs = model(**inputs)
128
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
129
 
 
130
  labels = ['negative', 'neutral', 'positive']
131
  scores = probabilities[0].cpu().tolist()
132
 
@@ -135,79 +209,181 @@ class FinancialAnalyzer:
135
  for label, score in zip(labels, scores)
136
  ]
137
 
 
138
  return [results]
 
139
  except Exception as e:
140
  logger.error(f"Sentiment analysis error: {str(e)}")
141
  return [{"label": "error", "score": 1.0}]
142
 
143
  def generate_analysis(self, financial_data):
144
- """Generate strategic analysis"""
145
  try:
146
  model_name = self.models["analysis"]
147
  self.model_manager.load_model(model_name, "generation")
148
 
149
- prompt = f"""[INST] Analyze these financial statements:
 
 
150
  {financial_data}
151
- Provide:
 
 
152
  1. Business Health Assessment
 
 
 
 
153
  2. Key Strategic Insights
 
 
 
 
154
  3. Market Position
 
 
 
 
155
  4. Growth Opportunities
156
- 5. Risk Factors [/INST]"""
 
 
 
 
 
 
 
 
 
157
 
158
  response = self.model_manager.get_model(model_name)(
159
  prompt,
160
- max_length=1000,
 
161
  temperature=0.7,
162
  do_sample=True,
163
  num_return_sequences=1,
164
- truncation=True
 
 
165
  )
166
 
167
- return response[0]['generated_text']
 
 
168
  except Exception as e:
169
  logger.error(f"Analysis generation error: {str(e)}")
170
  return "Error in analysis generation"
171
  finally:
172
  self.model_manager.unload_model(model_name)
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def generate_recommendations(self, analysis):
175
- """Generate recommendations"""
176
  try:
177
  model_name = self.models["recommendation"]
178
  self.model_manager.load_model(model_name, "generation")
179
 
180
- prompt = f"""Based on this analysis:
 
 
181
  {analysis}
182
-
183
- Provide actionable recommendations for:
 
184
  1. Strategic Initiatives
 
 
 
 
185
  2. Operational Improvements
 
 
 
 
186
  3. Financial Management
 
 
 
 
187
  4. Risk Mitigation
188
- 5. Growth Strategy"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  response = self.model_manager.get_model(model_name)(
191
  prompt,
192
- max_length=1000,
193
- temperature=0.6,
 
194
  do_sample=True,
195
  num_return_sequences=1,
196
- truncation=True
 
 
197
  )
198
 
199
- return response[0]['generated_text']
 
 
200
  except Exception as e:
201
  logger.error(f"Recommendations generation error: {str(e)}")
202
  return "Error generating recommendations"
203
  finally:
204
  self.model_manager.unload_model(model_name)
205
 
206
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  def analyze_financial_statements(income_statement, balance_sheet):
209
- """Main analysis function"""
210
  try:
 
211
  analyzer = FinancialAnalyzer()
212
 
213
  # Validate inputs
@@ -240,7 +416,9 @@ def analyze_financial_statements(income_statement, balance_sheet):
240
  recommendations = analyzer.generate_recommendations(analysis)
241
 
242
  # Format results
243
- return format_results(analysis, sentiment, recommendations)
 
 
244
 
245
  except Exception as e:
246
  logger.error(f"Analysis error: {str(e)}")
@@ -254,7 +432,7 @@ def analyze_financial_statements(income_statement, balance_sheet):
254
  3. File size is within limits"""
255
 
256
  def format_results(analysis, sentiment, recommendations):
257
- """Format analysis results"""
258
  try:
259
  if not isinstance(analysis, str) or not isinstance(recommendations, str):
260
  raise ValueError("Invalid input types")
@@ -280,7 +458,7 @@ def format_results(analysis, sentiment, recommendations):
280
  logger.error(f"Formatting error: {str(e)}")
281
  return "Error formatting results"
282
 
283
- # Create Gradio interface
284
  iface = gr.Interface(
285
  fn=analyze_financial_statements,
286
  inputs=[
@@ -294,7 +472,10 @@ iface = gr.Interface(
294
  - Sentiment Analysis (FinBERT)
295
  - Strategic Recommendations (Falcon)
296
 
297
- Note: Please ensure files are in CSV format.""",
 
 
 
298
  flagging_mode="never"
299
  )
300
 
 
3
  import pandas as pd
4
  import torch
5
  import logging
 
6
  import gc
7
+ import signal
8
+ from contextlib import contextmanager
9
+ import psutil
10
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  # Setup logging
13
  logging.basicConfig(
 
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  logger.info(f"Using device: {DEVICE}")
22
 
23
+ def monitor_memory():
24
+ """Monitor system memory usage"""
25
+ try:
26
+ process = psutil.Process(os.getpid())
27
+ memory_info = process.memory_info()
28
+ logger.info(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
29
+ if DEVICE == "cuda":
30
+ logger.info(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f} MB")
31
+ except Exception as e:
32
+ logger.error(f"Error monitoring memory: {str(e)}")
33
+
34
  def clear_gpu_memory():
35
  """Utility function to clear GPU memory"""
36
  if DEVICE == "cuda":
37
  torch.cuda.empty_cache()
38
  gc.collect()
39
 
40
+ @contextmanager
41
+ def timeout_context(seconds):
42
+ def signal_handler(signum, frame):
43
+ raise TimeoutError(f"Operation timed out after {seconds} seconds")
44
+ signal.signal(signal.SIGALRM, signal_handler)
45
+ signal.alarm(seconds)
46
+ try:
47
+ yield
48
+ finally:
49
+ signal.alarm(0)
50
+
51
  class ModelManager:
52
  """Handles model loading and inference"""
53
 
 
55
  self.device = DEVICE
56
  self.models = {}
57
  self.tokenizers = {}
58
+ self.model_cache = {}
59
+ self.max_cache_size = 2
60
 
61
+ def load_model(self, model_name, model_type="sentiment", timeout=300):
62
+ """Load model and tokenizer with timeout"""
63
  try:
64
+ if model_name in self.model_cache:
65
+ self.models[model_name] = self.model_cache[model_name]
66
+ logger.info(f"Loaded {model_name} from cache")
67
+ return
68
+
69
+ with timeout_context(timeout):
70
+ if model_name not in self.models:
71
+ if model_type == "sentiment":
72
+ self.tokenizers[model_name] = AutoTokenizer.from_pretrained(
73
+ model_name,
74
+ use_fast=True
75
+ )
76
+ self.models[model_name] = AutoModelForSequenceClassification.from_pretrained(
77
+ model_name,
78
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
79
+ ).to(self.device)
80
+ else:
81
+ self.models[model_name] = pipeline(
82
+ "text-generation",
83
+ model=model_name,
84
+ device_map="auto" if self.device == "cuda" else None,
85
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
86
+ )
87
+
88
+ # Cache the model
89
+ self.cache_model(model_name, self.models[model_name])
90
+ logger.info(f"Successfully loaded model: {model_name}")
91
+ monitor_memory()
92
+
93
  except Exception as e:
94
  logger.error(f"Error loading model {model_name}: {str(e)}")
95
  raise
96
 
97
+ def cache_model(self, model_name, model):
98
+ """Cache model for faster reloading"""
99
+ if len(self.model_cache) >= self.max_cache_size:
100
+ oldest_model = next(iter(self.model_cache))
101
+ del self.model_cache[oldest_model]
102
+ self.model_cache[model_name] = model
103
+
104
  def unload_model(self, model_name):
105
  """Unload model and tokenizer"""
106
  try:
 
140
  raise
141
 
142
  def read_csv(self, file_obj):
143
+ """Read and validate CSV file with better error handling"""
144
  try:
145
  if file_obj is None:
146
  raise ValueError("No file provided")
147
 
148
+ # Read CSV with explicit encoding and error handling
149
+ df = pd.read_csv(file_obj, encoding='utf-8', on_bad_lines='skip')
150
 
151
  if df.empty:
152
  raise ValueError("Empty CSV file")
153
+
154
+ # Log CSV information
155
+ logger.info(f"CSV Preview:\n{df.head()}")
156
+ logger.info(f"CSV Columns: {df.columns.tolist()}")
157
+
158
+ # Validate numeric columns
159
+ numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns
160
+ if len(numeric_cols) == 0:
161
+ raise ValueError("No numeric columns found in CSV")
162
+
163
+ # Generate statistical summary
164
+ summary = df[numeric_cols].describe()
165
+ logger.info(f"Statistical Summary:\n{summary}")
166
+
167
+ return summary
168
+
169
  except Exception as e:
170
  logger.error(f"Error reading CSV: {str(e)}")
171
  raise
172
 
 
173
  def analyze_sentiment(self, text):
174
+ """Analyze sentiment using FinBERT with improved error handling"""
175
  try:
176
  model_name = self.models["sentiment"]
177
  model = self.model_manager.get_model(model_name)
178
  tokenizer = self.model_manager.get_tokenizer(model_name)
179
 
180
+ # Validate input
181
+ if not text or not isinstance(text, str):
182
+ raise ValueError("Invalid input text")
183
+
184
+ # Preprocess text
185
+ text = text.strip()
186
+ if len(text) == 0:
187
+ raise ValueError("Empty text input")
188
+
189
+ # Tokenize with proper padding and truncation
190
  inputs = tokenizer(
191
  text,
192
  return_tensors="pt",
 
195
  padding=True
196
  ).to(DEVICE)
197
 
198
+ # Get prediction
199
  with torch.no_grad():
200
  outputs = model(**inputs)
201
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
202
 
203
+ # Process results
204
  labels = ['negative', 'neutral', 'positive']
205
  scores = probabilities[0].cpu().tolist()
206
 
 
209
  for label, score in zip(labels, scores)
210
  ]
211
 
212
+ logger.info(f"Sentiment analysis results: {results}")
213
  return [results]
214
+
215
  except Exception as e:
216
  logger.error(f"Sentiment analysis error: {str(e)}")
217
  return [{"label": "error", "score": 1.0}]
218
 
219
  def generate_analysis(self, financial_data):
220
+ """Generate strategic analysis with improved prompting"""
221
  try:
222
  model_name = self.models["analysis"]
223
  self.model_manager.load_model(model_name, "generation")
224
 
225
+ prompt = f"""[INST] As a senior financial analyst, provide a detailed analysis of these financial statements:
226
+
227
+ Financial Data:
228
  {financial_data}
229
+
230
+ Please provide a comprehensive analysis covering:
231
+
232
  1. Business Health Assessment
233
+ - Current financial position
234
+ - Key performance indicators
235
+ - Trend analysis
236
+
237
  2. Key Strategic Insights
238
+ - Major financial trends
239
+ - Performance drivers
240
+ - Areas of concern
241
+
242
  3. Market Position
243
+ - Competitive advantages
244
+ - Market share indicators
245
+ - Industry comparison
246
+
247
  4. Growth Opportunities
248
+ - Expansion potential
249
+ - Investment opportunities
250
+ - Revenue growth areas
251
+
252
+ 5. Risk Factors
253
+ - Financial risks
254
+ - Operational risks
255
+ - Market risks
256
+
257
+ Provide specific metrics and detailed explanations for each section. [/INST]"""
258
 
259
  response = self.model_manager.get_model(model_name)(
260
  prompt,
261
+ max_length=2000,
262
+ min_length=800,
263
  temperature=0.7,
264
  do_sample=True,
265
  num_return_sequences=1,
266
+ truncation=True,
267
+ repetition_penalty=1.2,
268
+ no_repeat_ngram_size=3
269
  )
270
 
271
+ analysis_text = response[0]['generated_text']
272
+ return self.format_analysis_text(analysis_text)
273
+
274
  except Exception as e:
275
  logger.error(f"Analysis generation error: {str(e)}")
276
  return "Error in analysis generation"
277
  finally:
278
  self.model_manager.unload_model(model_name)
279
 
280
+ def format_analysis_text(self, text):
281
+ """Format the analysis text for better readability"""
282
+ try:
283
+ sections = text.split('\n\n')
284
+ formatted_sections = []
285
+
286
+ for section in sections:
287
+ if section.strip():
288
+ if any(section.startswith(str(i)) for i in range(1, 6)):
289
+ formatted_sections.append(f"### {section}")
290
+ else:
291
+ formatted_sections.append(section)
292
+
293
+ return '\n\n'.join(formatted_sections)
294
+ except Exception as e:
295
+ logger.error(f"Error formatting analysis text: {str(e)}")
296
+ return text
297
+
298
  def generate_recommendations(self, analysis):
299
+ """Generate recommendations with comprehensive prompting"""
300
  try:
301
  model_name = self.models["recommendation"]
302
  self.model_manager.load_model(model_name, "generation")
303
 
304
+ prompt = f"""Based on this financial analysis, provide detailed strategic recommendations:
305
+
306
+ Analysis Context:
307
  {analysis}
308
+
309
+ Please provide specific, actionable recommendations for each area:
310
+
311
  1. Strategic Initiatives
312
+ - Detail specific actions for business growth
313
+ - Identify market expansion opportunities
314
+ - Outline product/service development strategies
315
+
316
  2. Operational Improvements
317
+ - Specify efficiency enhancement measures
318
+ - Recommend process optimization steps
319
+ - Suggest cost reduction strategies
320
+
321
  3. Financial Management
322
+ - Provide cash flow optimization tactics
323
+ - Prioritize investment opportunities
324
+ - Detail risk management approaches
325
+
326
  4. Risk Mitigation
327
+ - Address identified risks
328
+ - Outline specific mitigation strategies
329
+ - Suggest monitoring mechanisms
330
+
331
+ 5. Growth Strategy
332
+ - Identify market opportunities
333
+ - Detail expansion plans
334
+ - Specify resource requirements
335
+
336
+ For each recommendation:
337
+ - Include implementation timeline
338
+ - Specify resource requirements
339
+ - Define success metrics
340
+ - List potential challenges
341
+
342
+ Format each section with clear, actionable bullet points."""
343
 
344
  response = self.model_manager.get_model(model_name)(
345
  prompt,
346
+ max_length=2000,
347
+ min_length=800,
348
+ temperature=0.7,
349
  do_sample=True,
350
  num_return_sequences=1,
351
+ truncation=True,
352
+ repetition_penalty=1.2,
353
+ no_repeat_ngram_size=3
354
  )
355
 
356
+ recommendations_text = response[0]['generated_text']
357
+ return self.format_recommendation_text(recommendations_text)
358
+
359
  except Exception as e:
360
  logger.error(f"Recommendations generation error: {str(e)}")
361
  return "Error generating recommendations"
362
  finally:
363
  self.model_manager.unload_model(model_name)
364
 
365
+ def format_recommendation_text(self, text):
366
+ """Format the recommendation text for better readability"""
367
+ try:
368
+ sections = text.split('\n\n')
369
+ formatted_sections = []
370
+
371
+ for section in sections:
372
+ if section.strip():
373
+ if any(section.startswith(str(i)) for i in range(1, 6)):
374
+ formatted_sections.append(f"### {section}")
375
+ else:
376
+ formatted_sections.append(section)
377
+
378
+ return '\n\n'.join(formatted_sections)
379
+ except Exception as e:
380
+ logger.error(f"Error formatting recommendation text: {str(e)}")
381
+ return text
382
 
383
  def analyze_financial_statements(income_statement, balance_sheet):
384
+ """Main analysis function with improved error handling and logging"""
385
  try:
386
+ monitor_memory()
387
  analyzer = FinancialAnalyzer()
388
 
389
  # Validate inputs
 
416
  recommendations = analyzer.generate_recommendations(analysis)
417
 
418
  # Format results
419
+ result = format_results(analysis, sentiment, recommendations)
420
+ monitor_memory()
421
+ return result
422
 
423
  except Exception as e:
424
  logger.error(f"Analysis error: {str(e)}")
 
432
  3. File size is within limits"""
433
 
434
  def format_results(analysis, sentiment, recommendations):
435
+ """Format analysis results with improved validation and formatting"""
436
  try:
437
  if not isinstance(analysis, str) or not isinstance(recommendations, str):
438
  raise ValueError("Invalid input types")
 
458
  logger.error(f"Formatting error: {str(e)}")
459
  return "Error formatting results"
460
 
461
+ # Create Gradio interface with improved error handling
462
  iface = gr.Interface(
463
  fn=analyze_financial_statements,
464
  inputs=[
 
472
  - Sentiment Analysis (FinBERT)
473
  - Strategic Recommendations (Falcon)
474
 
475
+ Note:
476
+ - Files must be in CSV format
477
+ - Each file should contain financial data in columns
478
+ - Maximum file size: 10MB""",
479
  flagging_mode="never"
480
  )
481