walaa2022 committed on
Commit
2ac4fcc
·
verified ·
1 Parent(s): 37d8fd3

Create app1.py

Browse files
Files changed (1) hide show
  1. app1.py +311 -0
app1.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import torch
5
+ import logging
6
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
7
+ import gc
8
+
9
# Configure logging once at import time: INFO and above, with a
# "timestamp - level - message" layout.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Pick the compute device: CUDA when torch reports it available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
19
+
20
def clear_gpu_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Safe to call on CPU-only hosts: the garbage collection always runs and
    the CUDA cache is only touched when ``DEVICE`` is ``"cuda"``.
    """
    # Collect first: objects that are only kept alive by reference cycles
    # still own their memory until the collector runs, so emptying the CUDA
    # cache before collecting cannot hand that memory back.
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
25
+
26
class ModelManager:
    """Registry of loaded models and tokenizers, keyed by model name."""

    def __init__(self):
        self.device = DEVICE
        self.models = {}
        self.tokenizers = {}

    def load_model(self, model_name, model_type="sentiment"):
        """Load ``model_name`` into the registry unless it is already there.

        ``model_type == "sentiment"`` loads a tokenizer plus a
        sequence-classification model moved to ``self.device``; any other
        value builds a ``"text-generation"`` pipeline. Failures are logged
        and re-raised.
        """
        try:
            if model_name in self.models:
                return

            on_gpu = self.device == "cuda"
            # Half precision on CUDA, full precision on CPU.
            dtype = torch.float16 if on_gpu else torch.float32

            if model_type != "sentiment":
                self.models[model_name] = pipeline(
                    "text-generation",
                    model=model_name,
                    device_map="auto" if on_gpu else None,
                    torch_dtype=dtype,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.tokenizers[model_name] = tokenizer
                classifier = AutoModelForSequenceClassification.from_pretrained(
                    model_name,
                    torch_dtype=dtype,
                )
                self.models[model_name] = classifier.to(self.device)

            logger.info(f"Loaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            raise

    def unload_model(self, model_name):
        """Drop ``model_name`` from both registries and free memory.

        Errors are logged and swallowed; this method does not raise them.
        """
        try:
            for registry in (self.models, self.tokenizers):
                registry.pop(model_name, None)
            clear_gpu_memory()
            logger.info(f"Unloaded model: {model_name}")
        except Exception as e:
            logger.error(f"Error unloading model {model_name}: {str(e)}")

    def get_model(self, model_name):
        """Return the loaded model for ``model_name``, or ``None``."""
        if model_name in self.models:
            return self.models[model_name]
        return None

    def get_tokenizer(self, model_name):
        """Return the loaded tokenizer for ``model_name``, or ``None``."""
        if model_name in self.tokenizers:
            return self.tokenizers[model_name]
        return None
75
+
76
class FinancialAnalyzer:
    """Main analyzer class for financial statements.

    Holds a ``ModelManager`` and the names of three models: a sentiment
    classifier (loaded eagerly in ``__init__``) and two text generators that
    are loaded on demand and unloaded after each call.
    """

    def __init__(self):
        """Create the model manager and load the sentiment model.

        Raises:
            Exception: whatever ``ModelManager.load_model`` raises; it is
                logged and re-raised.
        """
        self.model_manager = ModelManager()
        self.models = {
            "sentiment": "ProsusAI/finbert",
            "analysis": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "recommendation": "tiiuae/falcon-rw-1b",
        }

        # Load sentiment model at initialization
        try:
            self.model_manager.load_model(self.models["sentiment"], "sentiment")
        except Exception as e:
            logger.error(f"Failed to initialize sentiment model: {str(e)}")
            raise

    def read_csv(self, file_obj):
        """Read a CSV and return its ``DataFrame.describe()`` summary.

        Args:
            file_obj: anything ``pandas.read_csv`` accepts (path or buffer).

        Returns:
            The summary-statistics DataFrame, not the raw data.

        Raises:
            ValueError: if ``file_obj`` is ``None`` or the parsed frame is empty.
            Exception: any error from ``pandas.read_csv``; logged and re-raised.
        """
        try:
            if file_obj is None:
                raise ValueError("No file provided")

            df = pd.read_csv(file_obj)

            if df.empty:
                raise ValueError("Empty CSV file")

            return df.describe()
        except Exception as e:
            logger.error(f"Error reading CSV: {str(e)}")
            raise

    def analyze_sentiment(self, text):
        """Score ``text`` with the sentiment model.

        Returns:
            ``[[{'label': str, 'score': float}, ...]]`` with one entry per
            class. On any failure the same nested shape is returned with a
            single ``{"label": "error", "score": 1.0}`` entry, so callers can
            treat both outcomes uniformly.
        """
        try:
            model_name = self.models["sentiment"]
            model = self.model_manager.get_model(model_name)
            tokenizer = self.model_manager.get_tokenizer(model_name)

            # Put the inputs on whatever device the model actually lives on.
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            ).to(model.device)

            with torch.no_grad():
                outputs = model(**inputs)
            # Softmax in float32; the model may have been loaded in float16.
            probabilities = torch.nn.functional.softmax(outputs.logits.float(), dim=1)
            scores = probabilities[0].cpu().tolist()

            # The checkpoint's config defines the index -> label mapping; a
            # hard-coded label order can disagree with it and mislabel scores.
            id2label = getattr(model.config, "id2label", None) or {}
            results = [
                {'label': id2label.get(index, f"LABEL_{index}"), 'score': score}
                for index, score in enumerate(scores)
            ]

            return [results]
        except Exception as e:
            logger.error(f"Sentiment analysis error: {str(e)}")
            # Same nesting as the success path.
            return [[{"label": "error", "score": 1.0}]]

    def _run_generation(self, task, prompt, temperature, log_label, failure_text):
        """Load the generator for ``task``, run ``prompt``, then unload it.

        Returns the generated continuation, or ``failure_text`` if loading or
        generation raises (the error is logged under ``log_label``).
        """
        # Bound before ``try`` so the ``finally`` clause can always refer to it.
        model_name = self.models[task]
        try:
            self.model_manager.load_model(model_name, "generation")
            generator = self.model_manager.get_model(model_name)
            response = generator(
                prompt,
                max_length=1000,
                temperature=temperature,
                do_sample=True,
                num_return_sequences=1,
                truncation=True,
                # Return only the continuation; by default the pipeline
                # prepends the prompt to the generated text.
                return_full_text=False,
            )
            return response[0]['generated_text']
        except Exception as e:
            logger.error(f"{log_label}: {str(e)}")
            return failure_text
        finally:
            self.model_manager.unload_model(model_name)

    def generate_analysis(self, financial_data):
        """Generate a strategic analysis of ``financial_data``.

        Returns the generated text, or ``"Error in analysis generation"`` on
        failure.
        """
        prompt = (
            "[INST] Analyze these financial statements:\n"
            f"{financial_data}\n"
            "Provide:\n"
            "1. Business Health Assessment\n"
            "2. Key Strategic Insights\n"
            "3. Market Position\n"
            "4. Growth Opportunities\n"
            "5. Risk Factors [/INST]"
        )
        return self._run_generation(
            "analysis",
            prompt,
            0.7,
            "Analysis generation error",
            "Error in analysis generation",
        )

    def generate_recommendations(self, analysis):
        """Generate recommendations from a previously produced ``analysis``.

        Returns the generated text, or ``"Error generating recommendations"``
        on failure.
        """
        prompt = (
            "Based on this analysis:\n"
            f"{analysis}\n"
            "\n"
            "Provide actionable recommendations for:\n"
            "1. Strategic Initiatives\n"
            "2. Operational Improvements\n"
            "3. Financial Management\n"
            "4. Risk Mitigation\n"
            "5. Growth Strategy"
        )
        return self._run_generation(
            "recommendation",
            prompt,
            0.6,
            "Recommendations generation error",
            "Error generating recommendations",
        )
205
+
206
+
207
+
208
def analyze_financial_statements(income_statement, balance_sheet):
    """Run the full pipeline on two uploaded CSVs and return a report.

    Args:
        income_statement: income-statement upload (anything
            ``FinancialAnalyzer.read_csv`` accepts).
        balance_sheet: balance-sheet upload, same kind of value.

    Returns:
        The formatted report on success, a short error string when either
        input is missing, or an "Analysis Error" message when any step raises.
        This function does not raise.
    """
    # Validate inputs before constructing the analyzer: construction loads
    # the sentiment model, which is wasted work when a file is missing.
    if not income_statement or not balance_sheet:
        return "Error: Please provide both income statement and balance sheet files"

    try:
        analyzer = FinancialAnalyzer()

        # Process financial statements
        logger.info("Processing financial statements...")
        income_summary = analyzer.read_csv(income_statement)
        balance_summary = analyzer.read_csv(balance_sheet)

        financial_data = (
            "\nIncome Statement Summary:\n"
            f"{income_summary.to_string()}\n"
            "\nBalance Sheet Summary:\n"
            f"{balance_summary.to_string()}\n"
        )

        # Generate analysis
        logger.info("Generating analysis...")
        analysis = analyzer.generate_analysis(financial_data)

        # Analyze sentiment
        logger.info("Analyzing sentiment...")
        sentiment = analyzer.analyze_sentiment(analysis)

        # Generate recommendations
        logger.info("Generating recommendations...")
        recommendations = analyzer.generate_recommendations(analysis)

        # Format results
        return format_results(analysis, sentiment, recommendations)

    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        return (
            "Analysis Error:\n"
            "\n"
            f"{str(e)}\n"
            "\n"
            "Please verify:\n"
            "1. Files are valid CSV format\n"
            "2. Files contain required financial data\n"
            "3. File size is within limits"
        )
255
+
256
def format_results(analysis, sentiment, recommendations):
    """Assemble the Markdown report from the three pipeline outputs.

    Args:
        analysis: generated analysis text.
        sentiment: sentiment scores, either nested
            (``[[{'label': ..., 'score': ...}, ...]]``) or as a flat list of
            such dicts. Entries that are not dicts with both keys are skipped.
        recommendations: generated recommendations text.

    Returns:
        The Markdown report, or ``"Error formatting results"`` if
        ``analysis``/``recommendations`` are not strings or formatting fails.
    """
    try:
        if not isinstance(analysis, str) or not isinstance(recommendations, str):
            raise ValueError("Invalid input types")

        output = [
            "# Financial Analysis Report\n\n",
            "## Strategic Analysis\n\n",
            f"{analysis.strip()}\n\n",
            "## Market Sentiment\n\n",
        ]

        if isinstance(sentiment, list) and sentiment:
            # Accept both the nested shape and a flat list of score dicts, so
            # a flat error entry is shown instead of being silently dropped.
            first = sentiment[0]
            entries = first if isinstance(first, (list, tuple)) else sentiment
            output.extend(
                f"- {entry['label']}: {entry['score']:.2%}\n"
                for entry in entries
                if isinstance(entry, dict) and 'label' in entry and 'score' in entry
            )
            output.append("\n")

        output.append("## Strategic Recommendations\n\n")
        output.append(f"{recommendations.strip()}")

        return "".join(output)
    except Exception as e:
        logger.error(f"Formatting error: {str(e)}")
        return "Error formatting results"
282
+
283
# Build the Gradio UI: two file uploads in, one Markdown report out.
iface = gr.Interface(
    fn=analyze_financial_statements,
    inputs=[
        gr.File(label=caption)
        for caption in ("Income Statement (CSV)", "Balance Sheet (CSV)")
    ],
    outputs=gr.Markdown(),
    title="Financial Statement Analyzer",
    description=(
        "Upload financial statements for AI-powered analysis:\n"
        "- Strategic Analysis (TinyLlama)\n"
        "- Sentiment Analysis (FinBERT)\n"
        "- Strategic Recommendations (Falcon)\n"
        "\n"
        "Note: Please ensure files are in CSV format."
    ),
    flagging_mode="never",
)
300
+
301
if __name__ == "__main__":
    # ``sys`` is not among this file's module-level imports; import it here so
    # the failure path below can exit instead of raising NameError.
    import sys

    try:
        # Enable request queuing, then serve on all interfaces, port 7860.
        iface.queue()
        iface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
        )
    except Exception as e:
        logger.error(f"Launch error: {str(e)}")
        # Non-zero exit status signals the failed launch to the caller.
        sys.exit(1)