mobenta commited on
Commit
454c94e
·
verified ·
1 Parent(s): 5996956

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -40
app.py CHANGED
@@ -4,45 +4,36 @@ import matplotlib.pyplot as plt
4
  import mplfinance as mpf
5
  from PIL import Image
6
  import gradio as gr
 
7
  import logging
8
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
9
  import spaces
10
 
11
- # Configure logging to write to a file
12
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
13
 
14
- # Load the ChartGemma model and processor outside the GPU context
15
  model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
16
  processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
17
 
18
  @spaces.GPU
19
- def predict(image, input_text, ticker1=None, ticker2=None):
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  model.to(device)
22
 
23
  image = image.convert("RGB")
24
-
25
  inputs = processor(text=input_text, images=image, return_tensors="pt")
26
  inputs = {k: v.to(device) for k, v in inputs.items()}
27
-
28
  prompt_length = inputs['input_ids'].shape[1]
29
-
30
- # Generate insights using the model
31
- generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
32
  output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
33
 
34
- # Replace placeholders with actual ticker names in the insights
35
- if ticker1:
36
- output_text = output_text.replace("[First Ticker]", ticker1)
37
- if ticker2:
38
- output_text = output_text.replace("[Second Ticker]", ticker2)
39
-
40
- logging.debug(f"Generated insights: {output_text}")
41
-
42
  return output_text
43
 
44
- # Function to fetch stock data with different intervals
45
- def fetch_stock_data(ticker='TSLA', start='2023-01-01', end='2024-01-01', interval='1d'):
 
46
  try:
47
  logging.debug(f"Fetching data for {ticker} from {start} to {end} with interval {interval}")
48
  stock = yf.Ticker(ticker)
@@ -53,7 +44,6 @@ def fetch_stock_data(ticker='TSLA', start='2023-01-01', end='2024-01-01', interv
53
  logging.error(f"Error fetching data: {e}")
54
  raise
55
 
56
- # Function to create a candlestick chart with increased size and add timeframe and ticker information
57
  def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d'):
58
  try:
59
  logging.debug(f"Creating chart for {ticker} with timeframe {timeframe} and saving to {filename}")
@@ -66,7 +56,6 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d'):
66
  fig.savefig(filename, dpi=300)
67
  plt.close(fig)
68
 
69
- # Resize image to 3 times its original size
70
  image = Image.open(filename)
71
  new_size = (image.width * 3, image.height * 3)
72
  resized_image = image.resize(new_size, Image.LANCZOS)
@@ -76,7 +65,6 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d'):
76
  logging.error(f"Error creating or resizing chart: {e}")
77
  raise
78
 
79
- # Function to combine two images side by side with increased size
80
  def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
81
  try:
82
  logging.debug(f"Combining images {image1_path} and {image2_path} into {output_path}")
@@ -97,7 +85,6 @@ def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
97
  logging.error(f"Error combining images: {e}")
98
  raise
99
 
100
- # Function to handle the Gradio interface
101
  def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_type, interval):
102
  try:
103
  logging.debug(f"Starting gradio_interface with ticker1: {ticker1}, start_date: {start_date}, end_date: {end_date}, ticker2: {ticker2}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
@@ -112,18 +99,16 @@ def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_typ
112
  chart_path2 = '/tmp/chart2.png'
113
  create_stock_chart(data2, ticker2, chart_path2, timeframe=interval)
114
 
115
- # Combine the two charts into one image
116
  combined_chart_path = combine_images(chart_path1, chart_path2)
117
- insights = predict(Image.open(combined_chart_path), query, ticker1, ticker2)
118
  return insights, combined_chart_path
119
 
120
- insights = predict(Image.open(chart_path1), query, ticker1)
121
  return insights, chart_path1
122
  except Exception as e:
123
  logging.error(f"Error processing image or query: {e}")
124
  return f"Error processing image or query: {e}", None
125
 
126
- # Button callback functions
127
  def set_query_trend():
128
  return "What are the key trends shown in this chart?"
129
 
@@ -133,7 +118,10 @@ def set_query_comparative():
133
  def set_query_forecasting():
134
  return "Based on the current data, what are the projected trends?"
135
 
136
- # Create the Gradio interface
 
 
 
137
  with gr.Blocks() as interface:
138
  gr.Markdown(
139
  """
@@ -143,49 +131,44 @@ with gr.Blocks() as interface:
143
  )
144
 
145
  with gr.Row():
146
- # Input box for first ticker
147
  ticker1_input = gr.Textbox(
148
  lines=1,
149
  placeholder="Enter first ticker (e.g., TSLA)",
150
  label="First Ticker",
151
  )
152
 
153
- # Input box for second ticker
154
  ticker2_input = gr.Textbox(
155
  lines=1,
156
  placeholder="Enter second ticker for comparative analysis (optional)",
157
  label="Second Ticker (Optional)",
158
  )
159
 
160
- # Input box for start date
161
  start_date_input = gr.Textbox(
162
  lines=1,
163
- placeholder="Enter start date (e.g., 2023-01-01)",
 
164
  label="Start Date",
165
  )
166
 
167
- # Input box for end date
168
  end_date_input = gr.Textbox(
169
  lines=1,
170
- placeholder="Enter end date (e.g., 2024-01-01)",
 
171
  label="End Date",
172
  )
173
 
174
- # Input box for text query
175
  query_input = gr.Textbox(
176
  lines=2,
177
  placeholder="Enter your question here...",
178
  label="Input Text",
179
  )
180
 
181
- # Hidden input for analysis type
182
  analysis_type_input = gr.Textbox(
183
  lines=1,
184
  visible=False,
185
  label="Analysis Type"
186
  )
187
 
188
- # Dropdown for selecting time frame
189
  interval_input = gr.Dropdown(
190
  choices=['1m', '5m', '15m', '30m', '60m', '1d', '1wk', '1mo', '3mo'],
191
  value='1d',
@@ -197,23 +180,19 @@ with gr.Blocks() as interface:
197
  comparative_button = gr.Button("Comparative Analysis")
198
  forecasting_button = gr.Button("Forecasting")
199
 
200
- # Output areas for insights and chart
201
  output_text = gr.Textbox(
202
  lines=10,
203
  label="Generated Insights"
204
  )
205
  output_image = gr.Image(label="Price Chart")
206
 
207
- # Button actions to set query text and analysis type
208
  trend_button.click(lambda: ("Trend Analysis", set_query_trend()), outputs=[analysis_type_input, query_input])
209
  comparative_button.click(lambda: ("Comparative Analysis", set_query_comparative()), outputs=[analysis_type_input, query_input])
210
  forecasting_button.click(lambda: ("Forecasting", set_query_forecasting()), outputs=[analysis_type_input, query_input])
211
 
212
- # Process inputs and generate insights, display chart(s)
213
  gr.Interface(gradio_interface,
214
  inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input],
215
  outputs=[output_text, output_image])
216
 
217
- # Launch Gradio interface
218
  if __name__ == "__main__":
219
  interface.launch()
 
4
  import mplfinance as mpf
5
  from PIL import Image
6
  import gradio as gr
7
+ import datetime
8
  import logging
9
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
10
  import spaces
11
 
12
+ # Configure logging
13
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
+ # Load the ChartGemma model and processor
16
  model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
17
  processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
18
 
19
  @spaces.GPU
20
+ def predict(image, input_text):
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  model.to(device)
23
 
24
  image = image.convert("RGB")
 
25
  inputs = processor(text=input_text, images=image, return_tensors="pt")
26
  inputs = {k: v.to(device) for k, v in inputs.items()}
27
+
28
  prompt_length = inputs['input_ids'].shape[1]
29
+ generate_ids = model.generate(**inputs, max_new_tokens=512)
 
 
30
  output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
31
 
 
 
 
 
 
 
 
 
32
  return output_text
33
 
34
+ def fetch_stock_data(ticker='TSLA', start='2010-01-01', end=None, interval='1d'):
35
+ if end is None:
36
+ end = datetime.datetime.now().strftime('%Y-%m-%d')
37
  try:
38
  logging.debug(f"Fetching data for {ticker} from {start} to {end} with interval {interval}")
39
  stock = yf.Ticker(ticker)
 
44
  logging.error(f"Error fetching data: {e}")
45
  raise
46
 
 
47
  def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d'):
48
  try:
49
  logging.debug(f"Creating chart for {ticker} with timeframe {timeframe} and saving to {filename}")
 
56
  fig.savefig(filename, dpi=300)
57
  plt.close(fig)
58
 
 
59
  image = Image.open(filename)
60
  new_size = (image.width * 3, image.height * 3)
61
  resized_image = image.resize(new_size, Image.LANCZOS)
 
65
  logging.error(f"Error creating or resizing chart: {e}")
66
  raise
67
 
 
68
  def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
69
  try:
70
  logging.debug(f"Combining images {image1_path} and {image2_path} into {output_path}")
 
85
  logging.error(f"Error combining images: {e}")
86
  raise
87
 
 
88
  def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_type, interval):
89
  try:
90
  logging.debug(f"Starting gradio_interface with ticker1: {ticker1}, start_date: {start_date}, end_date: {end_date}, ticker2: {ticker2}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
 
99
  chart_path2 = '/tmp/chart2.png'
100
  create_stock_chart(data2, ticker2, chart_path2, timeframe=interval)
101
 
 
102
  combined_chart_path = combine_images(chart_path1, chart_path2)
103
+ insights = predict(Image.open(combined_chart_path), query)
104
  return insights, combined_chart_path
105
 
106
+ insights = predict(Image.open(chart_path1), query)
107
  return insights, chart_path1
108
  except Exception as e:
109
  logging.error(f"Error processing image or query: {e}")
110
  return f"Error processing image or query: {e}", None
111
 
 
112
  def set_query_trend():
113
  return "What are the key trends shown in this chart?"
114
 
 
118
  def set_query_forecasting():
119
  return "Based on the current data, what are the projected trends?"
120
 
121
+ # Default dates
122
+ default_start_date = '2010-01-01'
123
+ default_end_date = datetime.datetime.now().strftime('%Y-%m-%d')
124
+
125
  with gr.Blocks() as interface:
126
  gr.Markdown(
127
  """
 
131
  )
132
 
133
  with gr.Row():
 
134
  ticker1_input = gr.Textbox(
135
  lines=1,
136
  placeholder="Enter first ticker (e.g., TSLA)",
137
  label="First Ticker",
138
  )
139
 
 
140
  ticker2_input = gr.Textbox(
141
  lines=1,
142
  placeholder="Enter second ticker for comparative analysis (optional)",
143
  label="Second Ticker (Optional)",
144
  )
145
 
 
146
  start_date_input = gr.Textbox(
147
  lines=1,
148
+ placeholder="Enter start date (e.g., 2010-01-01)",
149
+ value=default_start_date,
150
  label="Start Date",
151
  )
152
 
 
153
  end_date_input = gr.Textbox(
154
  lines=1,
155
+ placeholder=f"Enter end date (e.g., {default_end_date})",
156
+ value=default_end_date,
157
  label="End Date",
158
  )
159
 
 
160
  query_input = gr.Textbox(
161
  lines=2,
162
  placeholder="Enter your question here...",
163
  label="Input Text",
164
  )
165
 
 
166
  analysis_type_input = gr.Textbox(
167
  lines=1,
168
  visible=False,
169
  label="Analysis Type"
170
  )
171
 
 
172
  interval_input = gr.Dropdown(
173
  choices=['1m', '5m', '15m', '30m', '60m', '1d', '1wk', '1mo', '3mo'],
174
  value='1d',
 
180
  comparative_button = gr.Button("Comparative Analysis")
181
  forecasting_button = gr.Button("Forecasting")
182
 
 
183
  output_text = gr.Textbox(
184
  lines=10,
185
  label="Generated Insights"
186
  )
187
  output_image = gr.Image(label="Price Chart")
188
 
 
189
  trend_button.click(lambda: ("Trend Analysis", set_query_trend()), outputs=[analysis_type_input, query_input])
190
  comparative_button.click(lambda: ("Comparative Analysis", set_query_comparative()), outputs=[analysis_type_input, query_input])
191
  forecasting_button.click(lambda: ("Forecasting", set_query_forecasting()), outputs=[analysis_type_input, query_input])
192
 
 
193
  gr.Interface(gradio_interface,
194
  inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input],
195
  outputs=[output_text, output_image])
196
 
 
197
  if __name__ == "__main__":
198
  interface.launch()