Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
|
|
6 |
import gradio as gr
|
7 |
import datetime
|
8 |
import logging
|
9 |
-
from transformers import AutoProcessor,
|
10 |
import spaces
|
11 |
import pandas as pd
|
12 |
|
@@ -14,8 +14,8 @@ import pandas as pd
|
|
14 |
logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
15 |
|
16 |
# Load the ChartGemma model and processor
|
17 |
-
|
18 |
-
|
19 |
|
20 |
@spaces.GPU
|
21 |
def predict(image, input_text):
|
@@ -59,11 +59,6 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
|
|
59 |
# Calculate indicators if selected
|
60 |
addplot = []
|
61 |
if indicators:
|
62 |
-
if 'VWAP' in indicators:
|
63 |
-
vwap = (data['Close'] * data['Volume']).cumsum() / data['Volume'].cumsum()
|
64 |
-
addplot.append(mpf.make_addplot(vwap, color='purple'))
|
65 |
-
if 'Volume' in indicators:
|
66 |
-
addplot.append(mpf.make_addplot(data['Volume'], panel=1, type='bar', color='g', ylabel='Volume'))
|
67 |
if 'RSI' in indicators:
|
68 |
delta = data['Close'].diff(1)
|
69 |
gain = delta.where(delta > 0, 0)
|
@@ -73,53 +68,6 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
|
|
73 |
rs = avg_gain / avg_loss
|
74 |
rsi = 100 - (100 / (1 + rs))
|
75 |
addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))
|
76 |
-
if 'Ichimoku Cloud' in indicators:
|
77 |
-
logging.debug("Calculating Ichimoku Cloud")
|
78 |
-
nine_period_high = data['High'].rolling(window=9).max()
|
79 |
-
nine_period_low = data['Low'].rolling(window=9).min()
|
80 |
-
data['tenkan_sen'] = (nine_period_high + nine_period_low) / 2
|
81 |
-
|
82 |
-
period26_high = data['High'].rolling(window=26).max()
|
83 |
-
period26_low = data['Low'].rolling(window=26).min()
|
84 |
-
data['kijun_sen'] = (period26_high + period26_low) / 2
|
85 |
-
|
86 |
-
data['senkou_span_a'] = ((data['tenkan_sen'] + data['kijun_sen']) / 2).shift(26)
|
87 |
-
|
88 |
-
period52_high = data['High'].rolling(window=52).max()
|
89 |
-
period52_low = data['Low'].rolling(window=52).min()
|
90 |
-
data['senkou_span_b'] = ((period52_high + period52_low) / 2).shift(26)
|
91 |
-
|
92 |
-
data['chikou_span'] = data['Close'].shift(-26)
|
93 |
-
|
94 |
-
addplot.append(mpf.make_addplot(data['tenkan_sen'], color='red'))
|
95 |
-
addplot.append(mpf.make_addplot(data['kijun_sen'], color='blue'))
|
96 |
-
addplot.append(mpf.make_addplot(data['senkou_span_a'], color='green'))
|
97 |
-
addplot.append(mpf.make_addplot(data['senkou_span_b'], color='brown'))
|
98 |
-
addplot.append(mpf.make_addplot(data['chikou_span'], color='purple'))
|
99 |
-
if 'Bollinger Bands' in indicators:
|
100 |
-
logging.debug("Calculating Bollinger Bands")
|
101 |
-
rolling_mean = data['Close'].rolling(window=20).mean()
|
102 |
-
rolling_std = data['Close'].rolling(window=20).std()
|
103 |
-
data['upper_band'] = rolling_mean + (rolling_std * 2)
|
104 |
-
data['lower_band'] = rolling_mean - (rolling_std * 2)
|
105 |
-
addplot.append(mpf.make_addplot(data['upper_band'], color='blue'))
|
106 |
-
addplot.append(mpf.make_addplot(data['lower_band'], color='blue'))
|
107 |
-
if 'Pivot Levels' in indicators:
|
108 |
-
logging.debug("Calculating Pivot Levels")
|
109 |
-
data['pivot'] = (data['High'] + data['Low'] + data['Close']) / 3
|
110 |
-
data['r1'] = (2 * data['pivot']) - data['Low']
|
111 |
-
data['s1'] = (2 * data['pivot']) - data['High']
|
112 |
-
data['r2'] = data['pivot'] + (data['High'] - data['Low'])
|
113 |
-
data['s2'] = data['pivot'] - (data['High'] - data['Low'])
|
114 |
-
data['r3'] = data['High'] + 2 * (data['pivot'] - data['Low'])
|
115 |
-
data['s3'] = data['Low'] - 2 * (data['High'] - data['pivot'])
|
116 |
-
addplot.append(mpf.make_addplot(data['pivot'], color='black'))
|
117 |
-
addplot.append(mpf.make_addplot(data['r1'], color='green'))
|
118 |
-
addplot.append(mpf.make_addplot(data['s1'], color='red'))
|
119 |
-
addplot.append(mpf.make_addplot(data['r2'], color='green', linestyle='dashed'))
|
120 |
-
addplot.append(mpf.make_addplot(data['s2'], color='red', linestyle='dashed'))
|
121 |
-
addplot.append(mpf.make_addplot(data['r3'], color='green', linestyle='dotted'))
|
122 |
-
addplot.append(mpf.make_addplot(data['s3'], color='red', linestyle='dotted'))
|
123 |
if 'SMA21' in indicators:
|
124 |
logging.debug("Calculating SMA 21")
|
125 |
sma_21 = data['Close'].rolling(window=21).mean()
|
@@ -132,6 +80,18 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
|
|
132 |
logging.debug("Calculating SMA 200")
|
133 |
sma_200 = data['Close'].rolling(window=200).mean()
|
134 |
addplot.append(mpf.make_addplot(sma_200, color='brown', linestyle='dashed'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
|
137 |
fig.suptitle(title, y=0.98)
|
@@ -224,6 +184,7 @@ with gr.Blocks() as interface:
|
|
224 |
label="Second Ticker (Optional)",
|
225 |
)
|
226 |
|
|
|
227 |
start_date_input = gr.Textbox(
|
228 |
lines=1,
|
229 |
placeholder="Enter start date (e.g., 2010-01-01)",
|
@@ -233,53 +194,107 @@ with gr.Blocks() as interface:
|
|
233 |
|
234 |
end_date_input = gr.Textbox(
|
235 |
lines=1,
|
236 |
-
placeholder=f"Enter end date (
|
237 |
value=default_end_date,
|
238 |
label="End Date",
|
239 |
)
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
244 |
label="Input Text",
|
245 |
)
|
246 |
|
247 |
analysis_type_input = gr.Textbox(
|
248 |
lines=1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
visible=False,
|
250 |
-
label="Analysis Type"
|
251 |
)
|
252 |
|
|
|
253 |
interval_input = gr.Dropdown(
|
254 |
-
choices=[
|
255 |
-
value=
|
256 |
label="Select Time Frame",
|
257 |
)
|
258 |
|
|
|
259 |
indicator_input = gr.CheckboxGroup(
|
260 |
-
choices=[
|
261 |
label="Select Indicators",
|
262 |
-
value=['Volume'] # Default to show Volume
|
263 |
)
|
264 |
|
265 |
with gr.Row():
|
266 |
trend_button = gr.Button("Trend Analysis")
|
267 |
comparative_button = gr.Button("Comparative Analysis")
|
268 |
forecasting_button = gr.Button("Forecasting")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
-
|
271 |
-
|
272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
)
|
274 |
-
output_image = gr.Image(label="Price Chart")
|
275 |
|
276 |
-
|
277 |
-
|
278 |
-
|
|
|
|
|
279 |
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
283 |
|
284 |
-
|
285 |
-
interface.launch()
|
|
|
6 |
import gradio as gr
|
7 |
import datetime
|
8 |
import logging
|
9 |
+
from transformers import AutoProcessor, AutoModelForPreTraining
|
10 |
import spaces
|
11 |
import pandas as pd
|
12 |
|
|
|
14 |
logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
15 |
|
16 |
# Load the ChartGemma model and processor
|
17 |
+
processor = AutoProcessor.from_pretrained("mobenta/chart_analysis")
|
18 |
+
model = AutoModelForPreTraining.from_pretrained("mobenta/chart_analysis")
|
19 |
|
20 |
@spaces.GPU
|
21 |
def predict(image, input_text):
|
|
|
59 |
# Calculate indicators if selected
|
60 |
addplot = []
|
61 |
if indicators:
|
|
|
|
|
|
|
|
|
|
|
62 |
if 'RSI' in indicators:
|
63 |
delta = data['Close'].diff(1)
|
64 |
gain = delta.where(delta > 0, 0)
|
|
|
68 |
rs = avg_gain / avg_loss
|
69 |
rsi = 100 - (100 / (1 + rs))
|
70 |
addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
if 'SMA21' in indicators:
|
72 |
logging.debug("Calculating SMA 21")
|
73 |
sma_21 = data['Close'].rolling(window=21).mean()
|
|
|
80 |
logging.debug("Calculating SMA 200")
|
81 |
sma_200 = data['Close'].rolling(window=200).mean()
|
82 |
addplot.append(mpf.make_addplot(sma_200, color='brown', linestyle='dashed'))
|
83 |
+
if 'VWAP' in indicators:
|
84 |
+
logging.debug("Calculating VWAP")
|
85 |
+
vwap = (data['Volume'] * (data['High'] + data['Low'] + data['Close']) / 3).cumsum() / data['Volume'].cumsum()
|
86 |
+
addplot.append(mpf.make_addplot(vwap, color='blue', linestyle='dashed'))
|
87 |
+
if 'Bollinger Bands' in indicators:
|
88 |
+
logging.debug("Calculating Bollinger Bands")
|
89 |
+
sma = data['Close'].rolling(window=20).mean()
|
90 |
+
std = data['Close'].rolling(window=20).std()
|
91 |
+
upper_band = sma + (std * 2)
|
92 |
+
lower_band = sma - (std * 2)
|
93 |
+
addplot.append(mpf.make_addplot(upper_band, color='green', linestyle='dashed'))
|
94 |
+
addplot.append(mpf.make_addplot(lower_band, color='green', linestyle='dashed'))
|
95 |
|
96 |
fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
|
97 |
fig.suptitle(title, y=0.98)
|
|
|
184 |
label="Second Ticker (Optional)",
|
185 |
)
|
186 |
|
187 |
+
with gr.Row():
|
188 |
start_date_input = gr.Textbox(
|
189 |
lines=1,
|
190 |
placeholder="Enter start date (e.g., 2010-01-01)",
|
|
|
194 |
|
195 |
end_date_input = gr.Textbox(
|
196 |
lines=1,
|
197 |
+
placeholder=f"Enter end date (default: {default_end_date})",
|
198 |
value=default_end_date,
|
199 |
label="End Date",
|
200 |
)
|
201 |
|
202 |
+
with gr.Row():
|
203 |
+
input_text = gr.Textbox(
|
204 |
+
lines=3,
|
205 |
+
placeholder="Enter your input text",
|
206 |
label="Input Text",
|
207 |
)
|
208 |
|
209 |
analysis_type_input = gr.Textbox(
|
210 |
lines=1,
|
211 |
+
placeholder="Analysis Type",
|
212 |
+
label="",
|
213 |
+
visible=False,
|
214 |
+
)
|
215 |
+
|
216 |
+
query_input = gr.Textbox(
|
217 |
+
lines=3,
|
218 |
+
placeholder="Query",
|
219 |
+
label="",
|
220 |
visible=False,
|
|
|
221 |
)
|
222 |
|
223 |
+
with gr.Row():
|
224 |
interval_input = gr.Dropdown(
|
225 |
+
choices=["1d", "1wk", "1mo"],
|
226 |
+
value="1d",
|
227 |
label="Select Time Frame",
|
228 |
)
|
229 |
|
230 |
+
with gr.Row():
|
231 |
indicator_input = gr.CheckboxGroup(
|
232 |
+
choices=["VWAP", "Volume", "RSI", "Ichimoku Cloud", "Bollinger Bands", "Pivot Levels", "SMA21", "SMA50", "SMA200"],
|
233 |
label="Select Indicators",
|
|
|
234 |
)
|
235 |
|
236 |
with gr.Row():
|
237 |
trend_button = gr.Button("Trend Analysis")
|
238 |
comparative_button = gr.Button("Comparative Analysis")
|
239 |
forecasting_button = gr.Button("Forecasting")
|
240 |
+
submit_button = gr.Button("Submit")
|
241 |
+
clear_button = gr.Button("Clear")
|
242 |
+
|
243 |
+
output_text = gr.Textbox(lines=5, label="Generated Insights")
|
244 |
+
output_image = gr.Image(type="filepath", label="Price Chart")
|
245 |
+
|
246 |
+
trend_button.click(
|
247 |
+
fn=lambda: "Trend Analysis",
|
248 |
+
inputs=[],
|
249 |
+
outputs=[analysis_type_input],
|
250 |
+
).then(
|
251 |
+
set_query_trend,
|
252 |
+
inputs=[],
|
253 |
+
outputs=[query_input],
|
254 |
+
).then(
|
255 |
+
gradio_interface,
|
256 |
+
inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
|
257 |
+
outputs=[output_text, output_image],
|
258 |
+
)
|
259 |
+
|
260 |
+
comparative_button.click(
|
261 |
+
fn=lambda: "Comparative Analysis",
|
262 |
+
inputs=[],
|
263 |
+
outputs=[analysis_type_input],
|
264 |
+
).then(
|
265 |
+
set_query_comparative,
|
266 |
+
inputs=[],
|
267 |
+
outputs=[query_input],
|
268 |
+
).then(
|
269 |
+
gradio_interface,
|
270 |
+
inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
|
271 |
+
outputs=[output_text, output_image],
|
272 |
+
)
|
273 |
|
274 |
+
forecasting_button.click(
|
275 |
+
fn=lambda: "Forecasting",
|
276 |
+
inputs=[],
|
277 |
+
outputs=[analysis_type_input],
|
278 |
+
).then(
|
279 |
+
set_query_forecasting,
|
280 |
+
inputs=[],
|
281 |
+
outputs=[query_input],
|
282 |
+
).then(
|
283 |
+
gradio_interface,
|
284 |
+
inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
|
285 |
+
outputs=[output_text, output_image],
|
286 |
)
|
|
|
287 |
|
288 |
+
submit_button.click(
|
289 |
+
gradio_interface,
|
290 |
+
inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, input_text, analysis_type_input, interval_input, indicator_input],
|
291 |
+
outputs=[output_text, output_image],
|
292 |
+
)
|
293 |
|
294 |
+
clear_button.click(
|
295 |
+
fn=lambda: ("", "", "", "", "", "", "", []),
|
296 |
+
inputs=[],
|
297 |
+
outputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, input_text, analysis_type_input, interval_input, indicator_input],
|
298 |
+
)
|
299 |
|
300 |
+
interface.launch(debug=True)
|
|