mobenta commited on
Commit
7496630
·
verified ·
1 Parent(s): 8bf799e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -76
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
  import gradio as gr
7
  import datetime
8
  import logging
9
- from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
10
  import spaces
11
  import pandas as pd
12
 
@@ -14,8 +14,8 @@ import pandas as pd
14
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
 
16
  # Load the ChartGemma model and processor
17
- model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
18
- processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
19
 
20
  @spaces.GPU
21
  def predict(image, input_text):
@@ -59,11 +59,6 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
59
  # Calculate indicators if selected
60
  addplot = []
61
  if indicators:
62
- if 'VWAP' in indicators:
63
- vwap = (data['Close'] * data['Volume']).cumsum() / data['Volume'].cumsum()
64
- addplot.append(mpf.make_addplot(vwap, color='purple'))
65
- if 'Volume' in indicators:
66
- addplot.append(mpf.make_addplot(data['Volume'], panel=1, type='bar', color='g', ylabel='Volume'))
67
  if 'RSI' in indicators:
68
  delta = data['Close'].diff(1)
69
  gain = delta.where(delta > 0, 0)
@@ -73,53 +68,6 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
73
  rs = avg_gain / avg_loss
74
  rsi = 100 - (100 / (1 + rs))
75
  addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))
76
- if 'Ichimoku Cloud' in indicators:
77
- logging.debug("Calculating Ichimoku Cloud")
78
- nine_period_high = data['High'].rolling(window=9).max()
79
- nine_period_low = data['Low'].rolling(window=9).min()
80
- data['tenkan_sen'] = (nine_period_high + nine_period_low) / 2
81
-
82
- period26_high = data['High'].rolling(window=26).max()
83
- period26_low = data['Low'].rolling(window=26).min()
84
- data['kijun_sen'] = (period26_high + period26_low) / 2
85
-
86
- data['senkou_span_a'] = ((data['tenkan_sen'] + data['kijun_sen']) / 2).shift(26)
87
-
88
- period52_high = data['High'].rolling(window=52).max()
89
- period52_low = data['Low'].rolling(window=52).min()
90
- data['senkou_span_b'] = ((period52_high + period52_low) / 2).shift(26)
91
-
92
- data['chikou_span'] = data['Close'].shift(-26)
93
-
94
- addplot.append(mpf.make_addplot(data['tenkan_sen'], color='red'))
95
- addplot.append(mpf.make_addplot(data['kijun_sen'], color='blue'))
96
- addplot.append(mpf.make_addplot(data['senkou_span_a'], color='green'))
97
- addplot.append(mpf.make_addplot(data['senkou_span_b'], color='brown'))
98
- addplot.append(mpf.make_addplot(data['chikou_span'], color='purple'))
99
- if 'Bollinger Bands' in indicators:
100
- logging.debug("Calculating Bollinger Bands")
101
- rolling_mean = data['Close'].rolling(window=20).mean()
102
- rolling_std = data['Close'].rolling(window=20).std()
103
- data['upper_band'] = rolling_mean + (rolling_std * 2)
104
- data['lower_band'] = rolling_mean - (rolling_std * 2)
105
- addplot.append(mpf.make_addplot(data['upper_band'], color='blue'))
106
- addplot.append(mpf.make_addplot(data['lower_band'], color='blue'))
107
- if 'Pivot Levels' in indicators:
108
- logging.debug("Calculating Pivot Levels")
109
- data['pivot'] = (data['High'] + data['Low'] + data['Close']) / 3
110
- data['r1'] = (2 * data['pivot']) - data['Low']
111
- data['s1'] = (2 * data['pivot']) - data['High']
112
- data['r2'] = data['pivot'] + (data['High'] - data['Low'])
113
- data['s2'] = data['pivot'] - (data['High'] - data['Low'])
114
- data['r3'] = data['High'] + 2 * (data['pivot'] - data['Low'])
115
- data['s3'] = data['Low'] - 2 * (data['High'] - data['pivot'])
116
- addplot.append(mpf.make_addplot(data['pivot'], color='black'))
117
- addplot.append(mpf.make_addplot(data['r1'], color='green'))
118
- addplot.append(mpf.make_addplot(data['s1'], color='red'))
119
- addplot.append(mpf.make_addplot(data['r2'], color='green', linestyle='dashed'))
120
- addplot.append(mpf.make_addplot(data['s2'], color='red', linestyle='dashed'))
121
- addplot.append(mpf.make_addplot(data['r3'], color='green', linestyle='dotted'))
122
- addplot.append(mpf.make_addplot(data['s3'], color='red', linestyle='dotted'))
123
  if 'SMA21' in indicators:
124
  logging.debug("Calculating SMA 21")
125
  sma_21 = data['Close'].rolling(window=21).mean()
@@ -132,6 +80,18 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
132
  logging.debug("Calculating SMA 200")
133
  sma_200 = data['Close'].rolling(window=200).mean()
134
  addplot.append(mpf.make_addplot(sma_200, color='brown', linestyle='dashed'))
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
137
  fig.suptitle(title, y=0.98)
@@ -224,6 +184,7 @@ with gr.Blocks() as interface:
224
  label="Second Ticker (Optional)",
225
  )
226
 
 
227
  start_date_input = gr.Textbox(
228
  lines=1,
229
  placeholder="Enter start date (e.g., 2010-01-01)",
@@ -233,53 +194,107 @@ with gr.Blocks() as interface:
233
 
234
  end_date_input = gr.Textbox(
235
  lines=1,
236
- placeholder=f"Enter end date (e.g., {default_end_date})",
237
  value=default_end_date,
238
  label="End Date",
239
  )
240
 
241
- query_input = gr.Textbox(
242
- lines=2,
243
- placeholder="Enter your question here...",
 
244
  label="Input Text",
245
  )
246
 
247
  analysis_type_input = gr.Textbox(
248
  lines=1,
 
 
 
 
 
 
 
 
 
249
  visible=False,
250
- label="Analysis Type"
251
  )
252
 
 
253
  interval_input = gr.Dropdown(
254
- choices=['1m', '5m', '15m', '30m', '60m', '1d', '1wk', '1mo', '3mo'],
255
- value='1d',
256
  label="Select Time Frame",
257
  )
258
 
 
259
  indicator_input = gr.CheckboxGroup(
260
- choices=['VWAP', 'Volume', 'RSI', 'Ichimoku Cloud', 'Bollinger Bands', 'Pivot Levels', 'SMA21', 'SMA50', 'SMA200'],
261
  label="Select Indicators",
262
- value=['Volume'] # Default to show Volume
263
  )
264
 
265
  with gr.Row():
266
  trend_button = gr.Button("Trend Analysis")
267
  comparative_button = gr.Button("Comparative Analysis")
268
  forecasting_button = gr.Button("Forecasting")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- output_text = gr.Textbox(
271
- lines=10,
272
- label="Generated Insights"
 
 
 
 
 
 
 
 
 
273
  )
274
- output_image = gr.Image(label="Price Chart")
275
 
276
- trend_button.click(lambda: ("Trend Analysis", set_query_trend()), outputs=[analysis_type_input, query_input])
277
- comparative_button.click(lambda: ("Comparative Analysis", set_query_comparative()), outputs=[analysis_type_input, query_input])
278
- forecasting_button.click(lambda: ("Forecasting", set_query_forecasting()), outputs=[analysis_type_input, query_input])
 
 
279
 
280
- gr.Interface(gradio_interface,
281
- inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
282
- outputs=[output_text, output_image])
 
 
283
 
284
- if __name__ == "__main__":
285
- interface.launch()
 
6
  import gradio as gr
7
  import datetime
8
  import logging
9
+ from transformers import AutoProcessor, AutoModelForPreTraining
10
  import spaces
11
  import pandas as pd
12
 
 
14
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
 
16
  # Load the ChartGemma model and processor
17
+ processor = AutoProcessor.from_pretrained("mobenta/chart_analysis")
18
+ model = AutoModelForPreTraining.from_pretrained("mobenta/chart_analysis")
19
 
20
  @spaces.GPU
21
  def predict(image, input_text):
 
59
  # Calculate indicators if selected
60
  addplot = []
61
  if indicators:
 
 
 
 
 
62
  if 'RSI' in indicators:
63
  delta = data['Close'].diff(1)
64
  gain = delta.where(delta > 0, 0)
 
68
  rs = avg_gain / avg_loss
69
  rsi = 100 - (100 / (1 + rs))
70
  addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  if 'SMA21' in indicators:
72
  logging.debug("Calculating SMA 21")
73
  sma_21 = data['Close'].rolling(window=21).mean()
 
80
  logging.debug("Calculating SMA 200")
81
  sma_200 = data['Close'].rolling(window=200).mean()
82
  addplot.append(mpf.make_addplot(sma_200, color='brown', linestyle='dashed'))
83
+ if 'VWAP' in indicators:
84
+ logging.debug("Calculating VWAP")
85
+ vwap = (data['Volume'] * (data['High'] + data['Low'] + data['Close']) / 3).cumsum() / data['Volume'].cumsum()
86
+ addplot.append(mpf.make_addplot(vwap, color='blue', linestyle='dashed'))
87
+ if 'Bollinger Bands' in indicators:
88
+ logging.debug("Calculating Bollinger Bands")
89
+ sma = data['Close'].rolling(window=20).mean()
90
+ std = data['Close'].rolling(window=20).std()
91
+ upper_band = sma + (std * 2)
92
+ lower_band = sma - (std * 2)
93
+ addplot.append(mpf.make_addplot(upper_band, color='green', linestyle='dashed'))
94
+ addplot.append(mpf.make_addplot(lower_band, color='green', linestyle='dashed'))
95
 
96
  fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
97
  fig.suptitle(title, y=0.98)
 
184
  label="Second Ticker (Optional)",
185
  )
186
 
187
+ with gr.Row():
188
  start_date_input = gr.Textbox(
189
  lines=1,
190
  placeholder="Enter start date (e.g., 2010-01-01)",
 
194
 
195
  end_date_input = gr.Textbox(
196
  lines=1,
197
+ placeholder=f"Enter end date (default: {default_end_date})",
198
  value=default_end_date,
199
  label="End Date",
200
  )
201
 
202
+ with gr.Row():
203
+ input_text = gr.Textbox(
204
+ lines=3,
205
+ placeholder="Enter your input text",
206
  label="Input Text",
207
  )
208
 
209
  analysis_type_input = gr.Textbox(
210
  lines=1,
211
+ placeholder="Analysis Type",
212
+ label="",
213
+ visible=False,
214
+ )
215
+
216
+ query_input = gr.Textbox(
217
+ lines=3,
218
+ placeholder="Query",
219
+ label="",
220
  visible=False,
 
221
  )
222
 
223
+ with gr.Row():
224
  interval_input = gr.Dropdown(
225
+ choices=["1d", "1wk", "1mo"],
226
+ value="1d",
227
  label="Select Time Frame",
228
  )
229
 
230
+ with gr.Row():
231
  indicator_input = gr.CheckboxGroup(
232
+ choices=["VWAP", "Volume", "RSI", "Ichimoku Cloud", "Bollinger Bands", "Pivot Levels", "SMA21", "SMA50", "SMA200"],
233
  label="Select Indicators",
 
234
  )
235
 
236
  with gr.Row():
237
  trend_button = gr.Button("Trend Analysis")
238
  comparative_button = gr.Button("Comparative Analysis")
239
  forecasting_button = gr.Button("Forecasting")
240
+ submit_button = gr.Button("Submit")
241
+ clear_button = gr.Button("Clear")
242
+
243
+ output_text = gr.Textbox(lines=5, label="Generated Insights")
244
+ output_image = gr.Image(type="filepath", label="Price Chart")
245
+
246
+ trend_button.click(
247
+ fn=lambda: "Trend Analysis",
248
+ inputs=[],
249
+ outputs=[analysis_type_input],
250
+ ).then(
251
+ set_query_trend,
252
+ inputs=[],
253
+ outputs=[query_input],
254
+ ).then(
255
+ gradio_interface,
256
+ inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
257
+ outputs=[output_text, output_image],
258
+ )
259
+
260
+ comparative_button.click(
261
+ fn=lambda: "Comparative Analysis",
262
+ inputs=[],
263
+ outputs=[analysis_type_input],
264
+ ).then(
265
+ set_query_comparative,
266
+ inputs=[],
267
+ outputs=[query_input],
268
+ ).then(
269
+ gradio_interface,
270
+ inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
271
+ outputs=[output_text, output_image],
272
+ )
273
 
274
+ forecasting_button.click(
275
+ fn=lambda: "Forecasting",
276
+ inputs=[],
277
+ outputs=[analysis_type_input],
278
+ ).then(
279
+ set_query_forecasting,
280
+ inputs=[],
281
+ outputs=[query_input],
282
+ ).then(
283
+ gradio_interface,
284
+ inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
285
+ outputs=[output_text, output_image],
286
  )
 
287
 
288
+ submit_button.click(
289
+ gradio_interface,
290
+ inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, input_text, analysis_type_input, interval_input, indicator_input],
291
+ outputs=[output_text, output_image],
292
+ )
293
 
294
+ clear_button.click(
295
+ fn=lambda: ("", "", "", "", "", "", "", []),
296
+ inputs=[],
297
+ outputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, input_text, analysis_type_input, interval_input, indicator_input],
298
+ )
299
 
300
+ interface.launch(debug=True)