mobenta commited on
Commit
da9ddef
·
verified ·
1 Parent(s): de206ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image, ImageDraw, ImageFont
6
  import gradio as gr
7
  import datetime
8
  import logging
9
- from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
10
  import tempfile
11
  import os
12
 
@@ -14,8 +14,9 @@ import os
14
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
 
16
  # Load the ChartGemma model and processor
17
- model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
18
- processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
 
19
 
20
  def predict(image, input_text):
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -153,31 +154,27 @@ def combine_images(image_paths, output_path='combined_chart.png'):
153
 
154
  combined_image.save(output_path)
155
  logging.debug(f"Combined image saved to {output_path}")
 
156
  except Exception as e:
157
  logging.error(f"Error combining images: {e}")
158
  raise
159
 
160
  def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
161
  try:
162
- logging.debug("Starting Gradio interface analysis")
163
- tickers = [ticker1, ticker2, ticker3, ticker4]
164
- tickers = [ticker for ticker in tickers if ticker] # Remove empty tickers
165
- logging.debug(f"Tickers provided for analysis: {tickers}")
166
 
 
167
  chart_paths = []
168
 
169
- # Create and save charts for each ticker
170
- for ticker in tickers:
171
- try:
172
- data = fetch_stock_data(ticker, start_date, end_date, interval)
173
- chart_file = f"{ticker}_chart.png"
174
- create_stock_chart(data, ticker, chart_file, interval, indicators)
175
- chart_paths.append(chart_file)
176
- except Exception as e:
177
- logging.error(f"Error processing ticker {ticker}: {e}")
178
-
179
- # If comparative analysis is selected, combine charts
180
- if analysis_type == "Comparative Analysis" and len(chart_paths) > 1:
181
  with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_combined_chart:
182
  combined_chart_path = temp_combined_chart.name
183
  combine_images(chart_paths, combined_chart_path)
 
6
  import gradio as gr
7
  import datetime
8
  import logging
9
+ from transformers import AutoProcessor, AutoModelForPreTraining
10
  import tempfile
11
  import os
12
 
 
14
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
 
16
  # Load the ChartGemma model and processor
17
+ processor = AutoProcessor.from_pretrained("mobenta/chart_analysis")
18
+ model = AutoModelForPreTraining.from_pretrained("mobenta/chart_analysis")
19
+
20
 
21
  def predict(image, input_text):
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
154
 
155
  combined_image.save(output_path)
156
  logging.debug(f"Combined image saved to {output_path}")
157
+ return output_path
158
  except Exception as e:
159
  logging.error(f"Error combining images: {e}")
160
  raise
161
 
162
  def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
163
  try:
164
+ logging.debug(f"Starting gradio_interface with tickers: {ticker1}, {ticker2}, {ticker3}, {ticker4}, start_date: {start_date}, end_date: {end_date}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
 
 
 
165
 
166
+ tickers = [ticker1, ticker2, ticker3, ticker4]
167
  chart_paths = []
168
 
169
+ for i, ticker in enumerate(tickers):
170
+ if ticker:
171
+ data = fetch_stock_data(ticker, start=start_date, end=end_date, interval=interval)
172
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_chart:
173
+ chart_path = temp_chart.name
174
+ create_stock_chart(data, ticker, chart_path, timeframe=interval, indicators=indicators)
175
+ chart_paths.append(chart_path)
176
+
177
+ if analysis_type == 'Comparative Analysis' and len(chart_paths) > 1:
 
 
 
178
  with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_combined_chart:
179
  combined_chart_path = temp_combined_chart.name
180
  combine_images(chart_paths, combined_chart_path)