Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import pandas as pd | |
import torch | |
import logging | |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer | |
import gc | |
# Logging: configure the root logger once, at import time, and create the
# module-level logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Device selection: use CUDA when torch reports it is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
def clear_gpu_memory():
    """Release memory held by objects that are no longer referenced.

    Runs the Python garbage collector and then, when ``DEVICE`` is
    ``"cuda"``, asks PyTorch to release its unused cached GPU memory.
    """
    # Collect first: empty_cache() can only release blocks that no live
    # tensor still occupies, so unreachable objects must be freed beforehand.
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
class ModelManager:
    """Load, cache and release models by name.

    A ``"sentiment"`` model is stored as a sequence-classification model in
    ``models`` plus a tokenizer in ``tokenizers``. Any other ``model_type``
    is stored as a ``text-generation`` pipeline in ``models`` only.
    """

    def __init__(self):
        self.device = DEVICE
        self.models = {}      # model name -> model or pipeline
        self.tokenizers = {}  # model name -> tokenizer (sentiment models only)

    def load_model(self, model_name, model_type="sentiment"):
        """Load ``model_name`` unless it is already cached.

        Args:
            model_name: Model identifier passed to ``from_pretrained`` /
                ``pipeline``.
            model_type: ``"sentiment"`` loads a classifier and its tokenizer;
                anything else loads a text-generation pipeline.

        Raises:
            Exception: Whatever the underlying loader raised; it is logged
                and re-raised unchanged.
        """
        if model_name in self.models:
            return

        # Half precision on GPU, full precision on CPU.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        try:
            if model_type == "sentiment":
                # Load both parts into locals first so a failure while
                # loading the model cannot leave an orphan tokenizer cached.
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_name,
                    torch_dtype=dtype,
                ).to(self.device)
                self.tokenizers[model_name] = tokenizer
                self.models[model_name] = model
            else:
                self.models[model_name] = pipeline(
                    "text-generation",
                    model=model_name,
                    device_map="auto" if self.device == "cuda" else None,
                    torch_dtype=dtype,
                )
            logger.info(f"Loaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            raise

    def unload_model(self, model_name):
        """Drop ``model_name`` from both caches and free memory.

        Best effort: errors are logged, never raised. Unloading a name that
        was never loaded is a no-op apart from the memory cleanup.
        """
        try:
            self.models.pop(model_name, None)
            self.tokenizers.pop(model_name, None)
            clear_gpu_memory()
            logger.info(f"Unloaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error unloading model {model_name}: {str(e)}")

    def get_model(self, model_name):
        """Return the cached model/pipeline for ``model_name``, or ``None``."""
        return self.models.get(model_name)

    def get_tokenizer(self, model_name):
        """Return the cached tokenizer for ``model_name``, or ``None``."""
        return self.tokenizers.get(model_name)
class FinancialAnalyzer:
    """Main analyzer class for financial statements.

    Holds a ``ModelManager`` and the names of three models: a sentiment
    classifier (loaded in ``__init__``) and two text-generation models that
    are loaded for a single call and unloaded again afterwards.
    """

    def __init__(self):
        self.model_manager = ModelManager()
        self.models = {
            "sentiment": "ProsusAI/finbert",
            "analysis": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "recommendation": "tiiuae/falcon-rw-1b",
        }
        # Load sentiment model at initialization
        try:
            self.model_manager.load_model(self.models["sentiment"], "sentiment")
        except Exception as e:
            logger.error(f"Failed to initialize sentiment model: {str(e)}")
            raise

    def read_csv(self, file_obj):
        """Read a CSV and return its ``DataFrame.describe()`` summary.

        Args:
            file_obj: Anything ``pandas.read_csv`` accepts (path or buffer).

        Returns:
            The summary-statistics DataFrame of the parsed file.

        Raises:
            ValueError: If ``file_obj`` is ``None`` or the parsed frame is
                empty. Any pandas error is logged and re-raised as well.
        """
        try:
            if file_obj is None:
                raise ValueError("No file provided")
            df = pd.read_csv(file_obj)
            if df.empty:
                raise ValueError("Empty CSV file")
            return df.describe()
        except Exception as e:
            logger.error(f"Error reading CSV: {str(e)}")
            raise

    def analyze_sentiment(self, text):
        """Score ``text`` with the sentiment model.

        Returns:
            ``[[{'label': str, 'score': float}, ...]]`` - one inner list with
            an entry per class. On failure the same nested shape is returned
            with a single ``{"label": "error", "score": 1.0}`` entry.
        """
        try:
            model_name = self.models["sentiment"]
            model = self.model_manager.get_model(model_name)
            tokenizer = self.model_manager.get_tokenizer(model_name)
            # Inputs go to the device the manager placed the model on.
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            ).to(self.model_manager.device)
            with torch.no_grad():
                outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            scores = probabilities[0].cpu().tolist()
            # Take class names from the checkpoint's own index->label map;
            # a hard-coded order can silently disagree with it and mislabel
            # every score.
            id2label = model.config.id2label
            results = [
                {'label': id2label[index], 'score': score}
                for index, score in enumerate(scores)
            ]
            return [results]
        except Exception as e:
            logger.error(f"Sentiment analysis error: {str(e)}")
            # Same nesting as the success path so consumers can treat both alike.
            return [[{"label": "error", "score": 1.0}]]

    def _generate(self, model_key, prompt, temperature):
        """Run ``prompt`` through the generation model registered as ``model_key``.

        The model is loaded for this call and always unloaded afterwards.
        """
        # Resolved outside the try so the finally clause can rely on it.
        model_name = self.models[model_key]
        try:
            self.model_manager.load_model(model_name, "generation")
            generator = self.model_manager.get_model(model_name)
            response = generator(
                prompt,
                max_length=1000,
                temperature=temperature,
                do_sample=True,
                num_return_sequences=1,
                truncation=True,
                # Return only the continuation; otherwise the prompt itself
                # is echoed back at the start of the generated text.
                return_full_text=False,
            )
            return response[0]['generated_text']
        finally:
            self.model_manager.unload_model(model_name)

    def generate_analysis(self, financial_data):
        """Generate a strategic analysis of ``financial_data``.

        Returns:
            The generated text, or ``"Error in analysis generation"`` if
            anything fails (the error is logged).
        """
        prompt = (
            "[INST] Analyze these financial statements:\n"
            f"{financial_data}\n"
            "Provide:\n"
            "1. Business Health Assessment\n"
            "2. Key Strategic Insights\n"
            "3. Market Position\n"
            "4. Growth Opportunities\n"
            "5. Risk Factors [/INST]"
        )
        try:
            return self._generate("analysis", prompt, temperature=0.7)
        except Exception as e:
            logger.error(f"Analysis generation error: {str(e)}")
            return "Error in analysis generation"

    def generate_recommendations(self, analysis):
        """Generate recommendations based on ``analysis``.

        Returns:
            The generated text, or ``"Error generating recommendations"`` if
            anything fails (the error is logged).
        """
        prompt = (
            "Based on this analysis:\n"
            f"{analysis}\n"
            "Provide actionable recommendations for:\n"
            "1. Strategic Initiatives\n"
            "2. Operational Improvements\n"
            "3. Financial Management\n"
            "4. Risk Mitigation\n"
            "5. Growth Strategy"
        )
        try:
            return self._generate("recommendation", prompt, temperature=0.6)
        except Exception as e:
            logger.error(f"Recommendations generation error: {str(e)}")
            return "Error generating recommendations"
def analyze_financial_statements(income_statement, balance_sheet):
    """Run the full pipeline on two uploaded CSV files and return a report.

    Args:
        income_statement: Uploaded income-statement CSV (falsy if missing).
        balance_sheet: Uploaded balance-sheet CSV (falsy if missing).

    Returns:
        The Markdown report produced by ``format_results``, or an error
        message string if an input is missing or any step raises.
    """
    # Validate inputs before building the analyzer: constructing it loads
    # the sentiment model, which is wasted work when an upload is missing.
    if not income_statement or not balance_sheet:
        return "Error: Please provide both income statement and balance sheet files"

    try:
        analyzer = FinancialAnalyzer()

        # Process financial statements
        logger.info("Processing financial statements...")
        income_summary = analyzer.read_csv(income_statement)
        balance_summary = analyzer.read_csv(balance_sheet)
        financial_data = (
            "\n"
            "Income Statement Summary:\n"
            f"{income_summary.to_string()}\n"
            "Balance Sheet Summary:\n"
            f"{balance_summary.to_string()}\n"
        )

        # Generate analysis
        logger.info("Generating analysis...")
        analysis = analyzer.generate_analysis(financial_data)

        # Analyze sentiment
        logger.info("Analyzing sentiment...")
        sentiment = analyzer.analyze_sentiment(analysis)

        # Generate recommendations
        logger.info("Generating recommendations...")
        recommendations = analyzer.generate_recommendations(analysis)

        # Format results
        return format_results(analysis, sentiment, recommendations)
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        return (
            "Analysis Error:\n"
            f"{str(e)}\n"
            "Please verify:\n"
            "1. Files are valid CSV format\n"
            "2. Files contain required financial data\n"
            "3. File size is within limits"
        )
def format_results(analysis, sentiment, recommendations):
    """Assemble the Markdown report.

    Args:
        analysis: Strategic-analysis text (must be ``str``).
        sentiment: Sentiment entries, either nested
            (``[[{'label': ..., 'score': ...}, ...]]``) or flat
            (``[{'label': ..., 'score': ...}, ...]``). Entries that are not
            dicts with both keys are skipped.
        recommendations: Recommendation text (must be ``str``).

    Returns:
        The report as a single string, or ``"Error formatting results"`` if
        the inputs are invalid or formatting fails (the error is logged).
    """
    try:
        if not isinstance(analysis, str) or not isinstance(recommendations, str):
            raise ValueError("Invalid input types")

        output = [
            "# Financial Analysis Report\n\n",
            "## Strategic Analysis\n\n",
            f"{analysis.strip()}\n\n",
            "## Market Sentiment\n\n",
        ]

        entries = []
        if isinstance(sentiment, list) and sentiment:
            first = sentiment[0]
            # Accept a flat list of entries as well as the nested shape;
            # iterating a dict here would walk its keys and drop the entry.
            entries = sentiment if isinstance(first, dict) else first
        output.extend(
            f"- {entry['label']}: {entry['score']:.2%}\n"
            for entry in entries
            if isinstance(entry, dict) and 'label' in entry and 'score' in entry
        )
        output.append("\n")

        output.append("## Strategic Recommendations\n\n")
        output.append(f"{recommendations.strip()}")
        return "".join(output)
    except Exception as e:
        logger.error(f"Formatting error: {str(e)}")
        return "Error formatting results"
# Gradio UI: two file uploads in, one Markdown report out.
iface = gr.Interface(
    fn=analyze_financial_statements,
    inputs=[
        gr.File(label=upload_label)
        for upload_label in ("Income Statement (CSV)", "Balance Sheet (CSV)")
    ],
    outputs=gr.Markdown(),
    title="Financial Statement Analyzer",
    description=(
        "Upload financial statements for AI-powered analysis:\n"
        "- Strategic Analysis (TinyLlama)\n"
        "- Sentiment Analysis (FinBERT)\n"
        "- Strategic Recommendations (Falcon)\n"
        "Note: Please ensure files are in CSV format."
    ),
    flagging_mode="never",
)
if __name__ == "__main__":
    # Imported here because the failure path below calls sys.exit and the
    # module does not import sys at the top of the file.
    import sys

    try:
        # Enable request queuing, then serve on all interfaces at port 7860.
        iface.queue()
        iface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
        )
    except Exception as e:
        logger.error(f"Launch error: {str(e)}")
        sys.exit(1)