walaa2022 committed on
Commit
324809c
·
verified ·
1 Parent(s): 91033f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -87
app.py CHANGED
@@ -4,9 +4,8 @@ import pandas as pd
4
  import torch
5
  import logging
6
  import gc
7
- import signal
8
- from contextlib import contextmanager
9
- import psutil
10
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  # Setup logging
@@ -20,33 +19,15 @@ logger = logging.getLogger(__name__)
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  logger.info(f"Using device: {DEVICE}")
22
 
23
- def monitor_memory():
24
- """Monitor system memory usage"""
25
- try:
26
- process = psutil.Process(os.getpid())
27
- memory_info = process.memory_info()
28
- logger.info(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
29
- if DEVICE == "cuda":
30
- logger.info(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f} MB")
31
- except Exception as e:
32
- logger.error(f"Error monitoring memory: {str(e)}")
33
-
34
  def clear_gpu_memory():
35
  """Utility function to clear GPU memory"""
36
  if DEVICE == "cuda":
37
  torch.cuda.empty_cache()
38
  gc.collect()
39
 
40
- @contextmanager
41
- def timeout_context(seconds):
42
- def signal_handler(signum, frame):
43
- raise TimeoutError(f"Operation timed out after {seconds} seconds")
44
- signal.signal(signal.SIGALRM, signal_handler)
45
- signal.alarm(seconds)
46
- try:
47
- yield
48
- finally:
49
- signal.alarm(0)
50
 
51
  class ModelManager:
52
  """Handles model loading and inference"""
@@ -59,40 +40,46 @@ class ModelManager:
59
  self.max_cache_size = 2
60
 
61
  def load_model(self, model_name, model_type="sentiment", timeout=300):
62
- """Load model and tokenizer with timeout"""
63
  try:
64
  if model_name in self.model_cache:
65
  self.models[model_name] = self.model_cache[model_name]
66
  logger.info(f"Loaded {model_name} from cache")
67
  return
68
 
69
- with timeout_context(timeout):
70
- if model_name not in self.models:
71
- if model_type == "sentiment":
72
- self.tokenizers[model_name] = AutoTokenizer.from_pretrained(
73
- model_name,
74
- use_fast=True
75
- )
76
- self.models[model_name] = AutoModelForSequenceClassification.from_pretrained(
77
- model_name,
78
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
79
- ).to(self.device)
80
- else:
81
- self.models[model_name] = pipeline(
82
- "text-generation",
83
- model=model_name,
84
- device_map="auto" if self.device == "cuda" else None,
85
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
86
- )
87
-
88
- # Cache the model
89
- self.cache_model(model_name, self.models[model_name])
90
- logger.info(f"Successfully loaded model: {model_name}")
91
- monitor_memory()
 
 
 
 
 
 
92
 
93
  except Exception as e:
94
  logger.error(f"Error loading model {model_name}: {str(e)}")
95
- raise
96
 
97
  def cache_model(self, model_name, model):
98
  """Cache model for faster reloading"""
@@ -132,9 +119,10 @@ class FinancialAnalyzer:
132
  "recommendation": "tiiuae/falcon-rw-1b"
133
  }
134
 
135
- # Load sentiment model at initialization
136
  try:
137
- self.model_manager.load_model(self.models["sentiment"], "sentiment")
 
138
  except Exception as e:
139
  logger.error(f"Failed to initialize sentiment model: {str(e)}")
140
  raise
@@ -186,12 +174,17 @@ class FinancialAnalyzer:
186
  if len(text) == 0:
187
  raise ValueError("Empty text input")
188
 
 
 
 
 
 
189
  # Tokenize with proper padding and truncation
190
  inputs = tokenizer(
191
  text,
192
  return_tensors="pt",
193
  truncation=True,
194
- max_length=512,
195
  padding=True
196
  ).to(DEVICE)
197
 
@@ -217,10 +210,16 @@ class FinancialAnalyzer:
217
  return [{"label": "error", "score": 1.0}]
218
 
219
  def generate_analysis(self, financial_data):
220
- """Generate strategic analysis with improved prompting"""
221
  try:
222
  model_name = self.models["analysis"]
223
- self.model_manager.load_model(model_name, "generation")
 
 
 
 
 
 
224
 
225
  prompt = f"""[INST] As a senior financial analyst, provide a detailed analysis of these financial statements:
226
 
@@ -256,6 +255,7 @@ class FinancialAnalyzer:
256
 
257
  Provide specific metrics and detailed explanations for each section. [/INST]"""
258
 
 
259
  response = self.model_manager.get_model(model_name)(
260
  prompt,
261
  max_length=2000,
@@ -283,12 +283,27 @@ class FinancialAnalyzer:
283
  sections = text.split('\n\n')
284
  formatted_sections = []
285
 
 
286
  for section in sections:
287
- if section.strip():
288
- if any(section.startswith(str(i)) for i in range(1, 6)):
289
- formatted_sections.append(f"### {section}")
290
- else:
291
- formatted_sections.append(section)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  return '\n\n'.join(formatted_sections)
294
  except Exception as e:
@@ -296,10 +311,16 @@ class FinancialAnalyzer:
296
  return text
297
 
298
  def generate_recommendations(self, analysis):
299
- """Generate recommendations with comprehensive prompting"""
300
  try:
301
  model_name = self.models["recommendation"]
302
- self.model_manager.load_model(model_name, "generation")
 
 
 
 
 
 
303
 
304
  prompt = f"""Based on this financial analysis, provide detailed strategic recommendations:
305
 
@@ -341,6 +362,7 @@ class FinancialAnalyzer:
341
 
342
  Format each section with clear, actionable bullet points."""
343
 
 
344
  response = self.model_manager.get_model(model_name)(
345
  prompt,
346
  max_length=2000,
@@ -368,12 +390,27 @@ class FinancialAnalyzer:
368
  sections = text.split('\n\n')
369
  formatted_sections = []
370
 
 
371
  for section in sections:
372
- if section.strip():
373
- if any(section.startswith(str(i)) for i in range(1, 6)):
374
- formatted_sections.append(f"### {section}")
375
- else:
376
- formatted_sections.append(section)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
  return '\n\n'.join(formatted_sections)
379
  except Exception as e:
@@ -383,7 +420,8 @@ class FinancialAnalyzer:
383
  def analyze_financial_statements(income_statement, balance_sheet):
384
  """Main analysis function with improved error handling and logging"""
385
  try:
386
- monitor_memory()
 
387
  analyzer = FinancialAnalyzer()
388
 
389
  # Validate inputs
@@ -391,8 +429,9 @@ def analyze_financial_statements(income_statement, balance_sheet):
391
  return "Error: Please provide both income statement and balance sheet files"
392
 
393
  # Process financial statements
394
- logger.info("Processing financial statements...")
395
  income_summary = analyzer.read_csv(income_statement)
 
396
  balance_summary = analyzer.read_csv(balance_sheet)
397
 
398
  financial_data = f"""
@@ -404,20 +443,32 @@ def analyze_financial_statements(income_statement, balance_sheet):
404
  """
405
 
406
  # Generate analysis
407
- logger.info("Generating analysis...")
408
  analysis = analyzer.generate_analysis(financial_data)
 
 
 
409
 
410
  # Analyze sentiment
411
- logger.info("Analyzing sentiment...")
412
  sentiment = analyzer.analyze_sentiment(analysis)
 
 
 
413
 
414
  # Generate recommendations
415
- logger.info("Generating recommendations...")
416
  recommendations = analyzer.generate_recommendations(analysis)
 
 
 
417
 
418
  # Format results
 
419
  result = format_results(analysis, sentiment, recommendations)
420
- monitor_memory()
 
 
421
  return result
422
 
423
  except Exception as e:
@@ -429,7 +480,9 @@ def analyze_financial_statements(income_statement, balance_sheet):
429
  Please verify:
430
  1. Files are valid CSV format
431
  2. Files contain required financial data
432
- 3. File size is within limits"""
 
 
433
 
434
  def format_results(analysis, sentiment, recommendations):
435
  """Format analysis results with improved validation and formatting"""
@@ -458,34 +511,67 @@ def format_results(analysis, sentiment, recommendations):
458
  logger.error(f"Formatting error: {str(e)}")
459
  return "Error formatting results"
460
 
461
- # Create Gradio interface with improved error handling
462
  iface = gr.Interface(
463
  fn=analyze_financial_statements,
464
  inputs=[
465
- gr.File(label="Income Statement (CSV)"),
466
- gr.File(label="Balance Sheet (CSV)")
 
 
 
 
 
 
467
  ],
468
  outputs=gr.Markdown(),
469
- title="Financial Statement Analyzer",
470
- description="""Upload financial statements for AI-powered analysis:
471
- - Strategic Analysis (TinyLlama)
472
- - Sentiment Analysis (FinBERT)
473
- - Strategic Recommendations (Falcon)
474
-
475
- Note:
476
- - Files must be in CSV format
477
- - Each file should contain financial data in columns
478
- - Maximum file size: 10MB""",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  flagging_mode="never"
480
  )
481
 
 
482
  if __name__ == "__main__":
483
  try:
 
484
  iface.queue()
 
 
485
  iface.launch(
486
  share=False,
487
  server_name="0.0.0.0",
488
- server_port=7860
 
 
489
  )
490
  except Exception as e:
491
  logger.error(f"Launch error: {str(e)}")
 
4
  import torch
5
  import logging
6
  import gc
7
+ import threading
8
+ import concurrent.futures
 
9
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
10
 
11
  # Setup logging
 
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
  logger.info(f"Using device: {DEVICE}")
21
 
 
 
 
 
 
 
 
 
 
 
 
22
  def clear_gpu_memory():
23
  """Release memory: calls torch.cuda.empty_cache() when DEVICE is "cuda" and runs gc.collect()."""
24
  if DEVICE == "cuda":
25
  torch.cuda.empty_cache()
26
  # NOTE(review): indentation is lost in this diff rendering, so it is unclear whether gc.collect() belongs to the CUDA branch or runs unconditionally - confirm against app.py.
  gc.collect()
27
 
28
+ class ModelLoadingError(Exception):
29
+ """Custom exception for model loading errors"""
30
+ pass
 
 
 
 
 
 
 
31
 
32
  class ModelManager:
33
  """Handles model loading and inference"""
 
40
  self.max_cache_size = 2
41
 
42
  def load_model(self, model_name, model_type="sentiment", timeout=300):
43
+ """Load model and tokenizer with thread-safe timeout"""
44
  try:
45
  if model_name in self.model_cache:
46
  self.models[model_name] = self.model_cache[model_name]
47
  logger.info(f"Loaded {model_name} from cache")
48
  return
49
 
50
+ def load_model_task():
51
+ if model_type == "sentiment":
52
+ self.tokenizers[model_name] = AutoTokenizer.from_pretrained(
53
+ model_name,
54
+ use_fast=True
55
+ )
56
+ self.models[model_name] = AutoModelForSequenceClassification.from_pretrained(
57
+ model_name,
58
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
59
+ ).to(self.device)
60
+ else:
61
+ self.models[model_name] = pipeline(
62
+ "text-generation",
63
+ model=model_name,
64
+ device_map="auto" if self.device == "cuda" else None,
65
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
66
+ )
67
+
68
+ # Use ThreadPoolExecutor for timeout
69
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
70
+ future = executor.submit(load_model_task)
71
+ try:
72
+ future.result(timeout=timeout)
73
+ except concurrent.futures.TimeoutError:
74
+ raise ModelLoadingError(f"Model loading timed out after {timeout} seconds")
75
+
76
+ # Cache the model
77
+ self.cache_model(model_name, self.models[model_name])
78
+ logger.info(f"Successfully loaded model: {model_name}")
79
 
80
  except Exception as e:
81
  logger.error(f"Error loading model {model_name}: {str(e)}")
82
+ raise ModelLoadingError(f"Failed to load model {model_name}: {str(e)}")
83
 
84
  def cache_model(self, model_name, model):
85
  """Cache model for faster reloading"""
 
119
  "recommendation": "tiiuae/falcon-rw-1b"
120
  }
121
 
122
+ # Load sentiment model at initialization with longer timeout
123
  try:
124
+ self.model_manager.load_model(self.models["sentiment"], "sentiment", timeout=600)
125
+ logger.info("Sentiment model initialized successfully")
126
  except Exception as e:
127
  logger.error(f"Failed to initialize sentiment model: {str(e)}")
128
  raise
 
174
  if len(text) == 0:
175
  raise ValueError("Empty text input")
176
 
177
+ # Truncate text if too long
178
+ max_length = 512
179
+ if len(text.split()) > max_length:
180
+ logger.warning(f"Text length exceeds {max_length} tokens. Truncating...")
181
+
182
  # Tokenize with proper padding and truncation
183
  inputs = tokenizer(
184
  text,
185
  return_tensors="pt",
186
  truncation=True,
187
+ max_length=max_length,
188
  padding=True
189
  ).to(DEVICE)
190
 
 
210
  return [{"label": "error", "score": 1.0}]
211
 
212
  def generate_analysis(self, financial_data):
213
+ """Generate strategic analysis with improved prompting and error handling"""
214
  try:
215
  model_name = self.models["analysis"]
216
+ self.model_manager.load_model(model_name, "generation", timeout=600)
217
+
218
+ # Truncate financial data if too long
219
+ max_data_length = 1000
220
+ if len(financial_data.split()) > max_data_length:
221
+ logger.warning(f"Financial data too long. Truncating to {max_data_length} tokens...")
222
+ financial_data = ' '.join(financial_data.split()[:max_data_length])
223
 
224
  prompt = f"""[INST] As a senior financial analyst, provide a detailed analysis of these financial statements:
225
 
 
255
 
256
  Provide specific metrics and detailed explanations for each section. [/INST]"""
257
 
258
+ logger.info("Generating analysis...")
259
  response = self.model_manager.get_model(model_name)(
260
  prompt,
261
  max_length=2000,
 
283
  sections = text.split('\n\n')
284
  formatted_sections = []
285
 
286
+ current_section = None
287
  for section in sections:
288
+ section = section.strip()
289
+ if not section:
290
+ continue
291
+
292
+ # Check if this is a new section
293
+ if any(section.startswith(str(i)) for i in range(1, 6)):
294
+ current_section = f"### {section}"
295
+ formatted_sections.append(current_section)
296
+ elif current_section:
297
+ # Add bullet points to content under sections
298
+ lines = section.split('\n')
299
+ formatted_lines = []
300
+ for line in lines:
301
+ line = line.strip()
302
+ if line:
303
+ if not line.startswith('- '):
304
+ line = f"- {line}"
305
+ formatted_lines.append(line)
306
+ formatted_sections.append('\n'.join(formatted_lines))
307
 
308
  return '\n\n'.join(formatted_sections)
309
  except Exception as e:
 
311
  return text
312
 
313
  def generate_recommendations(self, analysis):
314
+ """Generate recommendations with improved prompting and error handling"""
315
  try:
316
  model_name = self.models["recommendation"]
317
+ self.model_manager.load_model(model_name, "generation", timeout=600)
318
+
319
+ # Truncate analysis if too long
320
+ max_analysis_length = 1000
321
+ if len(analysis.split()) > max_analysis_length:
322
+ logger.warning(f"Analysis too long. Truncating to {max_analysis_length} tokens...")
323
+ analysis = ' '.join(analysis.split()[:max_analysis_length])
324
 
325
  prompt = f"""Based on this financial analysis, provide detailed strategic recommendations:
326
 
 
362
 
363
  Format each section with clear, actionable bullet points."""
364
 
365
+ logger.info("Generating recommendations...")
366
  response = self.model_manager.get_model(model_name)(
367
  prompt,
368
  max_length=2000,
 
390
  sections = text.split('\n\n')
391
  formatted_sections = []
392
 
393
+ current_section = None
394
  for section in sections:
395
+ section = section.strip()
396
+ if not section:
397
+ continue
398
+
399
+ # Check if this is a new section
400
+ if any(section.startswith(str(i)) for i in range(1, 6)):
401
+ current_section = f"### {section}"
402
+ formatted_sections.append(current_section)
403
+ elif current_section:
404
+ # Add bullet points to content under sections
405
+ lines = section.split('\n')
406
+ formatted_lines = []
407
+ for line in lines:
408
+ line = line.strip()
409
+ if line:
410
+ if not line.startswith('- '):
411
+ line = f"- {line}"
412
+ formatted_lines.append(line)
413
+ formatted_sections.append('\n'.join(formatted_lines))
414
 
415
  return '\n\n'.join(formatted_sections)
416
  except Exception as e:
 
420
  def analyze_financial_statements(income_statement, balance_sheet):
421
  """Main analysis function with improved error handling and logging"""
422
  try:
423
+ clear_gpu_memory()
424
+ logger.info("Starting financial analysis...")
425
  analyzer = FinancialAnalyzer()
426
 
427
  # Validate inputs
 
429
  return "Error: Please provide both income statement and balance sheet files"
430
 
431
  # Process financial statements
432
+ logger.info("Processing income statement...")
433
  income_summary = analyzer.read_csv(income_statement)
434
+ logger.info("Processing balance sheet...")
435
  balance_summary = analyzer.read_csv(balance_sheet)
436
 
437
  financial_data = f"""
 
443
  """
444
 
445
  # Generate analysis
446
+ logger.info("Starting strategic analysis generation...")
447
  analysis = analyzer.generate_analysis(financial_data)
448
+ if "Error" in analysis:
449
+ logger.error("Strategic analysis generation failed")
450
+ return "Error: Failed to generate strategic analysis. Please try again."
451
 
452
  # Analyze sentiment
453
+ logger.info("Starting sentiment analysis...")
454
  sentiment = analyzer.analyze_sentiment(analysis)
455
+ if sentiment[0][0]['label'] == "error":
456
+ logger.error("Sentiment analysis failed")
457
+ return "Error: Failed to analyze sentiment. Please try again."
458
 
459
  # Generate recommendations
460
+ logger.info("Starting recommendations generation...")
461
  recommendations = analyzer.generate_recommendations(analysis)
462
+ if "Error" in recommendations:
463
+ logger.error("Recommendations generation failed")
464
+ return "Error: Failed to generate recommendations. Please try again."
465
 
466
  # Format results
467
+ logger.info("Formatting final results...")
468
  result = format_results(analysis, sentiment, recommendations)
469
+ clear_gpu_memory()
470
+
471
+ logger.info("Analysis completed successfully")
472
  return result
473
 
474
  except Exception as e:
 
480
  Please verify:
481
  1. Files are valid CSV format
482
  2. Files contain required financial data
483
+ 3. File size is within limits (max 10MB)
484
+ 4. Data contains numeric columns
485
+ 5. Files are not corrupted"""
486
 
487
  def format_results(analysis, sentiment, recommendations):
488
  """Format analysis results with improved validation and formatting"""
 
511
  logger.error(f"Formatting error: {str(e)}")
512
  return "Error formatting results"
513
 
514
+ # Create Gradio interface with improved error handling and guidance
515
  iface = gr.Interface(
516
  fn=analyze_financial_statements,
517
  inputs=[
518
+ gr.File(
519
+ label="Income Statement (CSV)",
520
+ info="Upload income statement in CSV format with numeric data columns"
521
+ ),
522
+ gr.File(
523
+ label="Balance Sheet (CSV)",
524
+ info="Upload balance sheet in CSV format with numeric data columns"
525
+ )
526
  ],
527
  outputs=gr.Markdown(),
528
+ title="AI-Powered Financial Statement Analyzer",
529
+ description="""## Financial Statement Analysis Tool
530
+
531
+ This tool provides comprehensive financial analysis using advanced AI models:
532
+ - Strategic Analysis: In-depth analysis of financial position and trends
533
+ - Sentiment Analysis: Assessment of financial health sentiment
534
+ - Strategic Recommendations: Actionable insights and recommendations
535
+
536
+ Requirements:
537
+ - Files must be in CSV format
538
+ - Must contain numeric data columns
539
+ - Maximum file size: 10MB
540
+ - Standard financial statement format preferred
541
+
542
+ Note: Analysis may take a few minutes to complete.""",
543
+ article="""### Usage Tips:
544
+ 1. Ensure your CSV files have clear column headers
545
+ 2. Verify that numeric data is properly formatted
546
+ 3. Wait for the analysis to complete - it may take several minutes
547
+ 4. The more detailed your financial data, the better the analysis
548
+
549
+ For optimal results, include key financial metrics such as:
550
+ - Revenue
551
+ - Expenses
552
+ - Profits/Losses
553
+ - Assets
554
+ - Liabilities
555
+ - Equity""",
556
+ examples=[
557
+ ["example_income_statement.csv", "example_balance_sheet.csv"]
558
+ ],
559
  flagging_mode="never"
560
  )
561
 
562
+ # Launch the interface with proper error handling
563
  if __name__ == "__main__":
564
  try:
565
+ # Enable queue for better handling of multiple requests
566
  iface.queue()
567
+
568
+ # Launch with specific server configuration
569
  iface.launch(
570
  share=False,
571
  server_name="0.0.0.0",
572
+ server_port=7860,
573
+ show_error=True,
574
+ max_threads=4
575
  )
576
  except Exception as e:
577
  logger.error(f"Launch error: {str(e)}")