mobenta commited on
Commit
4165151
·
verified ·
1 Parent(s): 31a0c50

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -0
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import yfinance as yf
3
+ import matplotlib.pyplot as plt
4
+ import mplfinance as mpf
5
+ from PIL import Image
6
+ import gradio as gr
7
+ import logging
8
+
9
+
10
+ import torch
11
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
12
+ from PIL import Image
13
+ import gradio as gr
14
+
15
+ # Load the ChartGemma model and processor
16
+ model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma", torch_dtype=torch.float16)
17
+ processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
18
+
19
+
20
+ # Configure logging to write to a file
21
+ logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
22
+
23
+ # Use GPU if available
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ # model = YourModel.from_pretrained('your-model') # Uncomment and update this line with the actual model loading
26
+ # processor = YourProcessor.from_pretrained('your-processor') # Uncomment and update this line with the actual processor loading
27
+ # model = model.to(device)
28
+
29
+ # Function to fetch stock data with different intervals
30
+ def fetch_stock_data(ticker='TSLA', start='2023-01-01', end='2024-01-01', interval='1d'):
31
+ try:
32
+ logging.debug(f"Fetching data for {ticker} from {start} to {end} with interval {interval}")
33
+ stock = yf.Ticker(ticker)
34
+ data = stock.history(start=start, end=end, interval=interval)
35
+ logging.debug(f"Fetched data with {len(data)} rows")
36
+ return data
37
+ except Exception as e:
38
+ logging.error(f"Error fetching data: {e}")
39
+ raise
40
+
41
+ # Function to create a candlestick chart with increased size and add timeframe and ticker information
42
+ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d'):
43
+ try:
44
+ logging.debug(f"Creating chart for {ticker} with timeframe {timeframe} and saving to {filename}")
45
+ title = f"{ticker.upper()} Price Data (Timeframe: {timeframe})"
46
+
47
+ # Set the title font size using rcParams
48
+ plt.rcParams["axes.titlesize"] = 10
49
+
50
+ # Define the style for the chart
51
+ my_style = mpf.make_mpf_style(base_mpf_style='charles')
52
+
53
+ fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, returnfig=True)
54
+
55
+ # Set the title for the figure with padding
56
+ fig.suptitle(title, y=0.98) # Adjust the y parameter to move the title down
57
+
58
+ fig.savefig(filename, dpi=300) # Increased DPI
59
+ plt.close(fig)
60
+
61
+ # Resize image to 3 times its original size
62
+ image = Image.open(filename)
63
+ new_size = (image.width * 3, image.height * 3)
64
+ resized_image = image.resize(new_size, Image.LANCZOS)
65
+ resized_image.save(filename)
66
+ logging.debug(f"Resized image with timeframe {timeframe} and ticker {ticker} saved to {filename}")
67
+ except Exception as e:
68
+ logging.error(f"Error creating or resizing chart: {e}")
69
+ raise
70
+
71
+ # Function to combine two images side by side with increased size
72
+ def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
73
+ try:
74
+ logging.debug(f"Combining images {image1_path} and {image2_path} into {output_path}")
75
+ image1 = Image.open(image1_path)
76
+ image2 = Image.open(image2_path)
77
+
78
+ total_width = image1.width + image2.width
79
+ max_height = max(image1.height, image2.height)
80
+
81
+ combined_image = Image.new('RGB', (total_width, max_height))
82
+ combined_image.paste(image1, (0, 0))
83
+ combined_image.paste(image2, (image1.width, 0))
84
+
85
+ combined_image.save(output_path)
86
+ logging.debug(f"Combined image saved to {output_path}")
87
+ return output_path
88
+ except Exception as e:
89
+ logging.error(f"Error combining images: {e}")
90
+ raise
91
+
92
+ # Function to generate insights
93
+ def generate_insights(image, query, ticker1=None, ticker2=None):
94
+ try:
95
+ logging.debug(f"Generating insights for query: {query}")
96
+
97
+ # Open and process the image
98
+ image = Image.open(image).convert('RGB')
99
+ inputs = processor(text=query, images=image, return_tensors="pt")
100
+ logging.debug(f"Inputs prepared with shapes {inputs['input_ids'].shape} and {inputs['pixel_values'].shape}")
101
+
102
+ prompt_length = inputs['input_ids'].shape[1]
103
+ inputs = {k: v.to(device) for k, v in inputs.items()}
104
+
105
+ # Generate insights using the model
106
+ generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
107
+ output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
108
+
109
+ # Replace placeholders with actual ticker names in the insights
110
+ if ticker1:
111
+ output_text = output_text.replace("[First Ticker]", ticker1)
112
+ if ticker2:
113
+ output_text = output_text.replace("[Second Ticker]", ticker2)
114
+
115
+ logging.debug(f"Generated insights: {output_text}")
116
+
117
+ return output_text
118
+ except Exception as e:
119
+ logging.error(f"Error generating insights: {e}")
120
+ return f"Error generating insights: {e}"
121
+
122
+ # Function to handle the Gradio interface
123
+ def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_type, interval):
124
+ try:
125
+ logging.debug(f"Starting gradio_interface with ticker1: {ticker1}, start_date: {start_date}, end_date: {end_date}, ticker2: {ticker2}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
126
+
127
+ # Fetch data and create charts
128
+ data1 = fetch_stock_data(ticker1, start=start_date, end=end_date, interval=interval)
129
+ chart_path1 = '/tmp/chart1.png'
130
+ create_stock_chart(data1, ticker1, chart_path1, timeframe=interval)
131
+
132
+ if analysis_type == 'Comparative Analysis' and ticker2:
133
+ data2 = fetch_stock_data(ticker2, start=start_date, end=end_date, interval=interval)
134
+ chart_path2 = '/tmp/chart2.png'
135
+ create_stock_chart(data2, ticker2, chart_path2, timeframe=interval)
136
+
137
+ # Combine the two charts into one image
138
+ combined_chart_path = combine_images(chart_path1, chart_path2)
139
+ insights = generate_insights(combined_chart_path, query, ticker1, ticker2)
140
+ return insights, combined_chart_path
141
+
142
+ insights = generate_insights(chart_path1, query, ticker1)
143
+ return insights, chart_path1
144
+ except Exception as e:
145
+ logging.error(f"Error processing image or query: {e}")
146
+ return f"Error processing image or query: {e}", None
147
+
148
+ # Button callback functions
149
+ def set_query_trend():
150
+ return "What are the key trends shown in this chart?"
151
+
152
+ def set_query_comparative():
153
+ return "How does [First Ticker] compare to [Second Ticker]?"
154
+
155
+ def set_query_forecasting():
156
+ return "Based on the current data, what are the projected trends?"
157
+
158
+ # Create the Gradio interface
159
+ with gr.Blocks() as interface:
160
+ gr.Markdown(
161
+ """
162
+ # 📈 Price Market Analysis Tool
163
+ Welcome to the Price Market Analysis Tool! This interface helps you generate insightful analyses of market data. Choose between trend analysis, comparative analysis, and forecasting based on your needs.
164
+ """
165
+ )
166
+
167
+ with gr.Row():
168
+ # Input box for first ticker
169
+ ticker1_input = gr.Textbox(
170
+ lines=1,
171
+ placeholder="Enter first ticker (e.g., TSLA)",
172
+ label="First Ticker",
173
+ )
174
+
175
+ # Input box for second ticker
176
+ ticker2_input = gr.Textbox(
177
+ lines=1,
178
+ placeholder="Enter second ticker for comparative analysis (optional)",
179
+ label="Second Ticker (Optional)",
180
+ )
181
+
182
+ # Input box for start date
183
+ start_date_input = gr.Textbox(
184
+ lines=1,
185
+ placeholder="Enter start date (e.g., 2023-01-01)",
186
+ label="Start Date",
187
+ )
188
+
189
+ # Input box for end date
190
+ end_date_input = gr.Textbox(
191
+ lines=1,
192
+ placeholder="Enter end date (e.g., 2024-01-01)",
193
+ label="End Date",
194
+ )
195
+
196
+ # Input box for text query
197
+ query_input = gr.Textbox(
198
+ lines=2,
199
+ placeholder="Enter your question here...",
200
+ label="Input Text",
201
+ )
202
+
203
+ # Hidden input for analysis type
204
+ analysis_type_input = gr.Textbox(
205
+ lines=1,
206
+ visible=False,
207
+ label="Analysis Type"
208
+ )
209
+
210
+ # Dropdown for selecting time frame
211
+ interval_input = gr.Dropdown(
212
+ choices=['1m', '5m', '15m', '30m', '60m', '1d', '1wk', '1mo', '3mo'],
213
+ value='1d',
214
+ label="Select Time Frame",
215
+ )
216
+
217
+ with gr.Row():
218
+ trend_button = gr.Button("Trend Analysis")
219
+ comparative_button = gr.Button("Comparative Analysis")
220
+ forecasting_button = gr.Button("Forecasting")
221
+
222
+ # Output areas for insights and chart
223
+ output_text = gr.Textbox(
224
+ lines=10,
225
+ label="Generated Insights"
226
+ )
227
+ output_image = gr.Image(label="Price Chart")
228
+
229
+ # Button actions to set query text and analysis type
230
+ trend_button.click(lambda: ("Trend Analysis", set_query_trend()), outputs=[analysis_type_input, query_input])
231
+ comparative_button.click(lambda: ("Comparative Analysis", set_query_comparative()), outputs=[analysis_type_input, query_input])
232
+ forecasting_button.click(lambda: ("Forecasting", set_query_forecasting()), outputs=[analysis_type_input, query_input])
233
+
234
+ # Process inputs and generate insights, display chart(s)
235
+ gr.Interface(gradio_interface,
236
+ inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input],
237
+ outputs=[output_text, output_image]).launch()
238
+
239
+ # Launch Gradio interface
240
+ if __name__ == "__main__":
241
+ interface.launch()