humanist96 committed
Commit: e248cd9
Parent: e0ae9ac

Upload 12 files

Files changed (12)
  1. README.md +110 -12
  2. app.py +320 -0
  3. config.json +33 -0
  4. demo.ipynb +0 -0
  5. figs/interface.png +0 -0
  6. figs/response.png +0 -0
  7. figs/title.png +0 -0
  8. prepare_data.ipynb +1545 -0
  9. requirements.txt +7 -0
  10. train.sh +21 -0
  11. train_lora.py +221 -0
  12. utils.py +162 -0
README.md CHANGED
@@ -1,12 +1,110 @@
- ---
- title: FinGPT Forecaster
- emoji: 🦀
- colorFrom: green
- colorTo: yellow
- sdk: gradio
- sdk_version: 4.12.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ![title](figs/title.png)
+
+ ## What is FinGPT-Forecaster?
+ - FinGPT-Forecaster takes market news and optional basic financials related to a specified company from the past few weeks as input and responds with the company's **positive developments** and **potential concerns**. It then gives a **prediction** of the stock price movement for the coming week, along with an **analysis** summary.
+ - FinGPT-Forecaster is Llama-2-7b-chat-hf finetuned with LoRA on the past year's DOW30 market data, and it has also shown strong generalization to other ticker symbols.
+ - FinGPT-Forecaster is an easy-to-deploy junior robo-advisor, a milestone towards our goal.
+
+ ## Try out the demo!
+
+ Try our demo at <https://huggingface.co/spaces/FinGPT/FinGPT-Forecaster>
+
+ ![demo_interface](figs/interface.png)
+
+ Enter the following inputs:
+
+ 1) ticker symbol (e.g. AAPL, MSFT, NVDA)
+ 2) the day from which you want the prediction to be made (yyyy-mm-dd)
+ 3) the number of past weeks from which market news is retrieved
+ 4) whether to include the latest basic financials as additional information
+
+ Then click Submit! You'll get a response like this:
+
+ ![demo_response](figs/response.png)
+
+ This is just a demo of what the model can do. Results inferred from randomly sampled news can be strongly biased.
+ For more detailed and customized usage, read on.
+
+ ## Deploy FinGPT-Forecaster
+
+ We have released FinGPT-Forecaster, trained on DOW30 market data from 2022-12-30 to 2023-09-01, on HuggingFace: [fingpt-forecaster_dow30_llama2-7b_lora](https://huggingface.co/FinGPT/fingpt-forecaster_dow30_llama2-7b_lora)
+
+ The key requirements are listed in `requirements.txt`. Before you start, run `pip install -r requirements.txt`. Then refer to `demo.ipynb` for our deployment and evaluation script.
+
+ First, let's load the model:
+
+ ```
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     'meta-llama/Llama-2-7b-chat-hf',
+     trust_remote_code=True,
+     device_map="auto",
+     torch_dtype=torch.float16,  # optional if you have enough VRAM
+ )
+ tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
+
+ model = PeftModel.from_pretrained(base_model, 'FinGPT/fingpt-forecaster_dow30_llama2-7b_lora')
+ model = model.eval()
+ ```
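+
+ If you don't have enough VRAM for fp16, the base model can also be loaded quantized. This is a minimal sketch, assuming `bitsandbytes` is installed (it is not in `requirements.txt`):
+
+ ```
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     'meta-llama/Llama-2-7b-chat-hf',
+     trust_remote_code=True,
+     device_map="auto",
+     quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # 8-bit weights
+ )
+ ```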
+
+ Then you're ready to go. Prepare your prompt with news & stock price movements in the Llama-2 chat format (the prompt content is described in the next section), and generate your own forecasting results:
+ ```
+ import re
+
+ B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+ # SYSTEM_PROMPT and YOUR_PROMPT are the strings you prepared
+ prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + YOUR_PROMPT + E_INST
+ inputs = tokenizer(
+     prompt, return_tensors='pt'
+ )
+ inputs = {key: value.to(model.device) for key, value in inputs.items()}
+
+ res = model.generate(
+     **inputs, max_length=4096, do_sample=True,
+     eos_token_id=tokenizer.eos_token_id,
+     use_cache=True
+ )
+ output = tokenizer.decode(res[0], skip_special_tokens=True)
+ answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)  # keep only the model's answer
+ ```
+
+ ## Data Preparation
+ Company profiles, market news, basic financials, and stock prices are retrieved using **yfinance** and **finnhub**.
+
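+ Below is a minimal sketch of that retrieval, condensed from `app.py` and `prepare_data.ipynb` (the API key and the example symbol and dates are placeholders):
+
+ ```
+ import finnhub
+ import yfinance as yf
+
+ finnhub_client = finnhub.Client(api_key="your finnhub key")
+
+ symbol = "AAPL"
+
+ # Company profile: fills the [Company Introduction] part of the prompt
+ profile = finnhub_client.company_profile2(symbol=symbol)
+
+ # One week of company news: fills the [Headline]/[Summary] entries
+ news = finnhub_client.company_news(symbol, _from="2023-08-01", to="2023-08-08")
+
+ # Daily close prices over the same week: fill the price-movement sentence
+ prices = yf.download(symbol, "2023-08-01", "2023-08-08")['Close']
+
+ print(profile['name'], len(news))
+ print(prices)
+ ```
+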
+ Prompts used are organized as below:
+
+ ```
+ SYSTEM_PROMPT = "You are a seasoned stock market analyst. Your task is to list the positive developments and potential concerns for companies based on relevant news and basic financials from the past weeks, then provide an analysis and prediction for the companies' stock price movement for the upcoming week. Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n[Prediction & Analysis]:\n...\n"
+
+ prompt = """
+ [Company Introduction]:
+
+ {name} is a leading entity in the {finnhubIndustry} sector. Incorporated and publicly traded since {ipo}, the company has established its reputation as one of the key players in the market. As of today, {name} has a market capitalization of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding. {name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive progress within the industry.
+
+ From {startDate} to {endDate}, {name}'s stock price {increase/decrease} from {startPrice} to {endPrice}. Company news during this period are listed below:
+
+ [Headline]: ...
+ [Summary]: ...
+
+ [Headline]: ...
+ [Summary]: ...
+
+ Some recent basic financials of {name}, reported at {date}, are presented below:
+
+ [Basic Financials]:
+ {attr1}: {value1}
+ {attr2}: {value2}
+ ...
+
+ Based on all the information before {curday}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company-related news. Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction.
+
+ """
+ ```
+ ## Train your own FinGPT-Forecaster
+
+ Data preparation is covered in `prepare_data.ipynb`; the training entry points included in this repo are `train.sh` and `train_lora.py`.
+
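+ The exact hyperparameters live in `train.sh` and `train_lora.py` and are not reproduced here; the following is only an illustrative sketch of wrapping Llama-2 with a LoRA adapter via `peft`, with assumed values for the rank and target modules:
+
+ ```
+ import torch
+ from transformers import AutoModelForCausalLM
+ from peft import LoraConfig, get_peft_model, TaskType
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     'meta-llama/Llama-2-7b-chat-hf',
+     torch_dtype=torch.float16,
+     device_map="auto",
+ )
+
+ # Assumed LoRA hyperparameters, for illustration only; see train_lora.py for the real ones.
+ lora_config = LoraConfig(
+     task_type=TaskType.CAUSAL_LM,
+     r=8,
+     lora_alpha=16,
+     lora_dropout=0.05,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ )
+
+ model = get_peft_model(base_model, lora_config)
+ model.print_trainable_parameters()  # only the small adapter matrices are trainable
+ ```
+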
+
+ **Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real money. Please use common sense and always first consult a professional before trading or investing.**
app.py ADDED
@@ -0,0 +1,320 @@
+ import os
+ import re
+ import time
+ import json
+ import random
+ import finnhub
+ import torch
+ import gradio as gr
+ import pandas as pd
+ import yfinance as yf
+ from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+ from peft import PeftModel
+ from collections import defaultdict
+ from datetime import date, datetime, timedelta
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+
+ access_token = os.environ["HF_TOKEN"]
+ finnhub_client = finnhub.Client(api_key=os.environ["FINNHUB_API_KEY"])
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     'meta-llama/Llama-2-7b-chat-hf',
+     token=access_token,
+     trust_remote_code=True,
+     device_map="auto",
+     torch_dtype=torch.float16,
+     offload_folder="offload/"
+ )
+ model = PeftModel.from_pretrained(
+     base_model,
+     'FinGPT/fingpt-forecaster_dow30_llama2-7b_lora',
+     offload_folder="offload/"
+ )
+ model = model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     'meta-llama/Llama-2-7b-chat-hf',
+     token=access_token
+ )
+
+ streamer = TextStreamer(tokenizer)
+
+ B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+ SYSTEM_PROMPT = "You are a seasoned stock market analyst. Your task is to list the positive developments and potential concerns for companies based on relevant news and basic financials from the past weeks, then provide an analysis and prediction for the companies' stock price movement for the upcoming week. " \
+     "Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n[Prediction & Analysis]\nPrediction: ...\nAnalysis: ..."
+
+
+ def print_gpu_utilization():
+     # Report current GPU memory usage via NVML.
+     nvmlInit()
+     handle = nvmlDeviceGetHandleByIndex(0)
+     info = nvmlDeviceGetMemoryInfo(handle)
+     print(f"GPU memory occupied: {info.used//1024**2} MB.")
+
+
+ def get_curday():
+
+     return date.today().strftime("%Y-%m-%d")
+
+
+ def n_weeks_before(date_string, n):
+     # Date n weeks before the given yyyy-mm-dd string (n may be negative for "after").
+     date = datetime.strptime(date_string, "%Y-%m-%d") - timedelta(days=7*n)
+
+     return date.strftime("%Y-%m-%d")
+
+
+ def get_stock_data(stock_symbol, steps):
+     # For each week boundary in `steps`, keep the first available close on/after it.
+     stock_data = yf.download(stock_symbol, steps[0], steps[-1])
+     if len(stock_data) == 0:
+         raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
+
+     # print(stock_data)
+
+     dates, prices = [], []
+     available_dates = stock_data.index.format()
+
+     for date in steps[:-1]:
+         for i in range(len(stock_data)):
+             if available_dates[i] >= date:
+                 prices.append(stock_data['Close'][i])
+                 dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
+                 break
+
+     dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
+     prices.append(stock_data['Close'][-1])
+
+     return pd.DataFrame({
+         "Start Date": dates[:-1], "End Date": dates[1:],
+         "Start Price": prices[:-1], "End Price": prices[1:]
+     })
+
+
+ def get_news(symbol, data):
+     # Fetch company news from finnhub for each weekly range in `data`.
+     news_list = []
+
+     for end_date, row in data.iterrows():
+         start_date = row['Start Date'].strftime('%Y-%m-%d')
+         end_date = row['End Date'].strftime('%Y-%m-%d')
+         # print(symbol, ': ', start_date, ' - ', end_date)
+         time.sleep(1)  # control qpm
+         weekly_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
+         if len(weekly_news) == 0:
+             raise gr.Error(f"No company news found for symbol {symbol} from finnhub!")
+         weekly_news = [
+             {
+                 "date": datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S'),
+                 "headline": n['headline'],
+                 "summary": n['summary'],
+             } for n in weekly_news
+         ]
+         weekly_news.sort(key=lambda x: x['date'])
+         news_list.append(json.dumps(weekly_news))
+
+     data['News'] = news_list
+
+     return data
+
+
+ def get_company_prompt(symbol):
+
+     profile = finnhub_client.company_profile2(symbol=symbol)
+     if not profile:
+         raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
+
+     company_template = "[Company Introduction]:\n\n{name} is a leading entity in the {finnhubIndustry} sector. Incorporated and publicly traded since {ipo}, the company has established its reputation as one of the key players in the market. As of today, {name} has a market capitalization of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding." \
+         "\n\n{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive progress within the industry."
+
+     formatted_str = company_template.format(**profile)
+
+     return formatted_str
+
+
+ def get_prompt_by_row(symbol, row):
+
+     start_date = row['Start Date'] if isinstance(row['Start Date'], str) else row['Start Date'].strftime('%Y-%m-%d')
+     end_date = row['End Date'] if isinstance(row['End Date'], str) else row['End Date'].strftime('%Y-%m-%d')
+     term = 'increased' if row['End Price'] > row['Start Price'] else 'decreased'
+     head = "From {} to {}, {}'s stock price {} from {:.2f} to {:.2f}. Company news during this period are listed below:\n\n".format(
+         start_date, end_date, symbol, term, row['Start Price'], row['End Price'])
+
+     news = json.loads(row["News"])
+     # Drop news dated after the week's end, plus a recurring Zacks promo blurb
+     # (the "proves results" typo below is in the upstream text and is matched verbatim).
+     news = ["[Headline]: {}\n[Summary]: {}\n".format(
+         n['headline'], n['summary']) for n in news if n['date'][:8] <= end_date.replace('-', '') and \
+         not n['summary'].startswith("Looking for stock market analysis and research with proves results?")]
+
+     basics = json.loads(row['Basics'])
+     if basics:
+         basics = "Some recent basic financials of {}, reported at {}, are presented below:\n\n[Basic Financials]:\n\n".format(
+             symbol, basics['period']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
+     else:
+         basics = "[Basic Financials]:\n\nNo basic financial reported."
+
+     return head, news, basics
+
+
+ def sample_news(news, k=5):
+     # Randomly sample k news items, preserving chronological order.
+     return [news[i] for i in sorted(random.sample(range(len(news)), k))]
+
+
+ def get_current_basics(symbol, curday):
+
+     basic_financials = finnhub_client.company_basic_financials(symbol, 'all')
+     if not basic_financials['series']:
+         raise gr.Error(f"Failed to find basic financials for symbol {symbol} from finnhub!")
+
+     final_basics, basic_list, basic_dict = [], [], defaultdict(dict)
+
+     # Regroup the per-metric quarterly series into one dict per reporting period.
+     for metric, value_list in basic_financials['series']['quarterly'].items():
+         for value in value_list:
+             basic_dict[value['period']].update({metric: value['v']})
+
+     for k, v in basic_dict.items():
+         v.update({'period': k})
+         basic_list.append(v)
+
+     basic_list.sort(key=lambda x: x['period'])
+
+     # Walk backwards to the most recent period on or before curday.
+     for basic in basic_list[::-1]:
+         if basic['period'] <= curday:
+             break
+
+     return basic
+
+
+ def get_all_prompts_online(symbol, data, curday, with_basics=True):
+
+     company_prompt = get_company_prompt(symbol)
+
+     prev_rows = []
+
+     for row_idx, row in data.iterrows():
+         head, news, _ = get_prompt_by_row(symbol, row)
+         prev_rows.append((head, news, None))
+
+     prompt = ""
+     for i in range(-len(prev_rows), 0):
+         prompt += "\n" + prev_rows[i][0]
+         sampled_news = sample_news(
+             prev_rows[i][1],
+             min(5, len(prev_rows[i][1]))
+         )
+         if sampled_news:
+             prompt += "\n".join(sampled_news)
+         else:
+             prompt += "No relative news reported."
+
+     # Prediction window: curday to one week after (n = -1 steps one week forward).
+     period = "{} to {}".format(curday, n_weeks_before(curday, -1))
+
+     if with_basics:
+         basics = get_current_basics(symbol, curday)
+         basics = "Some recent basic financials of {}, reported at {}, are presented below:\n\n[Basic Financials]:\n\n".format(
+             symbol, basics['period']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
+     else:
+         basics = "[Basic Financials]:\n\nNo basic financial reported."
+
+     info = company_prompt + '\n' + prompt + '\n' + basics
+     prompt = info + f"\n\nBased on all the information before {curday}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company related news. " \
+         f"Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction."
+
+     return info, prompt
+
+
+ def construct_prompt(ticker, curday, n_weeks, use_basics):
+
+     try:
+         steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
+     except Exception:
+         raise gr.Error(f"Invalid date {curday}!")
+
+     data = get_stock_data(ticker, steps)
+     data = get_news(ticker, data)
+     data['Basics'] = [json.dumps({})] * len(data)
+     # print(data)
+
+     info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)
+
+     # Wrap the prompt in the Llama-2 chat format.
+     prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
+     # print(prompt)
+
+     return info, prompt
+
+
+ def predict(ticker, date, n_weeks, use_basics):
+
+     print_gpu_utilization()
+
+     info, prompt = construct_prompt(ticker, date, n_weeks, use_basics)
+
+     inputs = tokenizer(
+         prompt, return_tensors='pt', padding=False
+     )
+     inputs = {key: value.to(model.device) for key, value in inputs.items()}
+
+     print("Inputs loaded onto devices.")
+
+     res = model.generate(
+         **inputs, max_length=4096, do_sample=True,
+         eos_token_id=tokenizer.eos_token_id,
+         use_cache=True, streamer=streamer
+     )
+     output = tokenizer.decode(res[0], skip_special_tokens=True)
+     answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)  # keep only the answer
+
+     torch.cuda.empty_cache()
+
+     return info, answer
+
+
+ demo = gr.Interface(
+     predict,
+     inputs=[
+         gr.Textbox(
+             label="Ticker",
+             value="AAPL",
+             info="Companies from the Dow 30 are recommended"
+         ),
+         gr.Textbox(
+             label="Date",
+             value=get_curday,
+             info="Date from which the prediction is made, use format yyyy-mm-dd"
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=4,
+             value=3,
+             step=1,
+             label="n_weeks",
+             info="Information from the past n weeks will be utilized, choose between 1 and 4"
+         ),
+         gr.Checkbox(
+             label="Use Latest Basic Financials",
+             value=False,
+             info="If checked, the latest quarterly reported basic financials of the company are taken into account."
+         )
+     ],
+     outputs=[
+         gr.Textbox(
+             label="Information"
+         ),
+         gr.Textbox(
+             label="Response"
+         )
+     ],
+     title="FinGPT-Forecaster",
+     description="""FinGPT-Forecaster takes randomly sampled market news and optional basic financials related to the specified company from the past few weeks as input and responds with the company's **positive developments** and **potential concerns**. Then it gives out a **prediction** of stock price movement for the coming week and its **analysis** summary.
+ This model is finetuned on Llama2-7b-chat-hf with LoRA on the past year's DOW30 market data. Inference in this demo uses fp16 and **welcomes any ticker symbol**.
+ Company profiles, market news, basic financials and stock prices are retrieved using **yfinance & finnhub**.
+ This is just a demo showing what this model is capable of. Results inferred from randomly chosen news can be strongly biased.
+ For more detailed and customized implementation, refer to our FinGPT project: <https://github.com/AI4Finance-Foundation/FinGPT>
+ **Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real money. Please use common sense and always first consult a professional before trading or investing.**
+ """
+ )
+
+ demo.launch()
config.json ADDED
@@ -0,0 +1,33 @@
+ {
+     "train_micro_batch_size_per_gpu": "auto",
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "optimizer": {
+         "type": "ZeroOneAdam",
+         "params": {
+             "lr": "auto",
+             "weight_decay": "auto",
+             "bias_correction": false,
+             "var_freeze_step": 1000,
+             "var_update_scaler": 16,
+             "local_step_scaler": 1000,
+             "local_step_clipper": 16,
+             "cuda_aware": true,
+             "comm_backend_name": "nccl"
+         }
+     },
+     "scheduler": {
+         "type": "WarmupLR",
+         "params": {
+             "warmup_min_lr": 0,
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto"
+         }
+     },
+     "fp16": {
+         "enabled": true
+     },
+     "zero_optimization": {
+         "stage": 0
+     }
+ }
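
`config.json` is a DeepSpeed configuration (ZeroOneAdam optimizer, WarmupLR schedule, fp16, ZeRO stage 0) whose "auto" fields are left for the launcher to resolve. For illustration only (the values below are hypothetical; the actual wiring lives in `train.sh` and `train_lora.py`), the transformers DeepSpeed integration fills those "auto" entries from `TrainingArguments` when the config is passed via its `deepspeed` argument:

```
from transformers import TrainingArguments

# "auto" entries in config.json are filled from these arguments by the
# transformers/DeepSpeed integration; all values here are hypothetical.
training_args = TrainingArguments(
    output_dir="finetuned_models",     # hypothetical output path
    per_device_train_batch_size=4,     # -> train_micro_batch_size_per_gpu
    gradient_accumulation_steps=8,     # -> gradient_accumulation_steps
    learning_rate=1e-4,                # -> optimizer lr / warmup_max_lr
    warmup_steps=100,                  # -> warmup_num_steps
    fp16=True,                         # matches the fp16 section
    deepspeed="config.json",           # hand DeepSpeed this config
)
```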
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
figs/interface.png ADDED
figs/response.png ADDED
figs/title.png ADDED
prepare_data.ipynb ADDED
@@ -0,0 +1,1545 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "3c4d096e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import re\n",
+ "import csv\n",
+ "import math\n",
+ "import time\n",
+ "import json\n",
+ "import random\n",
+ "import finnhub\n",
+ "import datasets\n",
+ "import pandas as pd\n",
+ "import yfinance as yf\n",
+ "from datetime import datetime\n",
+ "from collections import defaultdict\n",
+ "from datasets import Dataset\n",
+ "from openai import OpenAI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "ace9fdb4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "START_DATE = \"2022-12-31\"\n",
+ "END_DATE = \"2023-05-31\"\n",
+ "\n",
+ "DATA_DIR = f\"./{START_DATE}_{END_DATE}\"\n",
+ "os.makedirs(DATA_DIR, exist_ok=True)\n",
+ "\n",
+ "finnhub_client = finnhub.Client(api_key=\"your finnhub key\")\n",
+ "\n",
+ "client = OpenAI(api_key = 'your openai key')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2fce2503",
+ "metadata": {},
+ "source": [
+ "# Raw Financial Data Acquisition"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "c6564114",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def bin_mapping(ret):\n",
+ "    \n",
+ "    up_down = 'U' if ret >= 0 else 'D'\n",
+ "    integer = math.ceil(abs(100 * ret))\n",
+ "    \n",
+ "    return up_down + (str(integer) if integer <= 5 else '5+')\n",
+ "\n",
+ "\n",
+ "def get_returns(stock_symbol):\n",
+ "    \n",
+ "    # Download historical stock data\n",
+ "    stock_data = yf.download(stock_symbol, start=START_DATE, end=END_DATE)\n",
+ "    \n",
+ "    weekly_data = stock_data['Adj Close'].resample('W').ffill()\n",
+ "    weekly_returns = weekly_data.pct_change()[1:]\n",
+ "    weekly_start_prices = weekly_data[:-1]\n",
+ "    weekly_end_prices = weekly_data[1:]\n",
+ "\n",
+ "    weekly_data = pd.DataFrame({\n",
+ "        'Start Date': weekly_start_prices.index,\n",
+ "        'Start Price': weekly_start_prices.values,\n",
+ "        'End Date': weekly_end_prices.index,\n",
+ "        'End Price': weekly_end_prices.values,\n",
+ "        'Weekly Returns': weekly_returns.values\n",
+ "    })\n",
+ "    \n",
+ "    weekly_data['Bin Label'] = weekly_data['Weekly Returns'].map(bin_mapping)\n",
+ "\n",
+ "    return weekly_data\n",
+ "\n",
+ "\n",
+ "def get_news(symbol, data):\n",
+ "    \n",
+ "    news_list = []\n",
+ "    \n",
+ "    for end_date, row in data.iterrows():\n",
+ "        start_date = row['Start Date'].strftime('%Y-%m-%d')\n",
+ "        end_date = row['End Date'].strftime('%Y-%m-%d')\n",
+ "        print(symbol, ': ', start_date, ' - ', end_date)\n",
+ "        time.sleep(1) # control qpm\n",
+ "        weekly_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)\n",
+ "        weekly_news = [\n",
+ "            {\n",
+ "                \"date\": datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S'),\n",
+ "                \"headline\": n['headline'],\n",
+ "                \"summary\": n['summary'],\n",
+ "            } for n in weekly_news\n",
+ "        ]\n",
+ "        weekly_news.sort(key=lambda x: x['date'])\n",
+ "        news_list.append(json.dumps(weekly_news))\n",
+ "    \n",
+ "    data['News'] = news_list\n",
+ "    \n",
+ "    return data\n",
+ "\n",
+ "\n",
+ "def get_basics(symbol, data, always=False):\n",
+ "    \n",
+ "    basic_financials = finnhub_client.company_basic_financials(symbol, 'all')\n",
+ "    \n",
+ "    final_basics, basic_list, basic_dict = [], [], defaultdict(dict)\n",
+ "    \n",
+ "    for metric, value_list in basic_financials['series']['quarterly'].items():\n",
+ "        for value in value_list:\n",
+ "            basic_dict[value['period']].update({metric: value['v']})\n",
+ "\n",
+ "    for k, v in basic_dict.items():\n",
+ "        v.update({'period': k})\n",
+ "        basic_list.append(v)\n",
+ "    \n",
+ "    basic_list.sort(key=lambda x: x['period'])\n",
+ "    \n",
+ "    for i, row in data.iterrows():\n",
+ "        \n",
+ "        start_date = row['End Date'].strftime('%Y-%m-%d')\n",
+ "        last_start_date = START_DATE if i < 2 else data.loc[i-2, 'Start Date'].strftime('%Y-%m-%d')\n",
+ "        \n",
+ "        used_basic = {}\n",
+ "        for basic in basic_list[::-1]:\n",
+ "            if (always and basic['period'] < start_date) or (last_start_date <= basic['period'] < start_date):\n",
+ "                used_basic = basic\n",
+ "                break\n",
+ "        final_basics.append(json.dumps(used_basic))\n",
+ "    \n",
+ "    data['Basics'] = final_basics\n",
+ "    \n",
+ "    return data\n",
+ "    \n",
+ "\n",
+ "def prepare_data_for_company(symbol, with_basics=True):\n",
+ "    \n",
+ "    data = get_returns(symbol)\n",
+ "    data = get_news(symbol, data)\n",
+ "    \n",
+ "    if with_basics:\n",
+ "        data = get_basics(symbol, data)\n",
+ "        data.to_csv(f\"{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}.csv\")\n",
+ "    else:\n",
+ "        data['Basics'] = [json.dumps({})] * len(data)\n",
+ "        data.to_csv(f\"{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}_nobasics.csv\")\n",
+ "    \n",
+ "    return data\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "id": "caf02ab7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DOW_30 = [\n",
+ "    \"AXP\", \"AMGN\", \"AAPL\", \"BA\", \"CAT\", \"CSCO\", \"CVX\", \"GS\", \"HD\", \"HON\",\n",
+ "    \"IBM\", \"INTC\", \"JNJ\", \"KO\", \"JPM\", \"MCD\", \"MMM\", \"MRK\", \"MSFT\", \"NKE\",\n",
+ "    \"PG\", \"TRV\", \"UNH\", \"CRM\", \"VZ\", \"V\", \"WBA\", \"WMT\", \"DIS\", \"DOW\"\n",
+ "]\n",
+ "\n",
+ "# prepare_data_for_company(\"DOW\", False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "id": "43d65960",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
192
+ "[*********************100%%**********************] 1 of 1 completed\n",
193
+ "AXP : 2023-01-08 - 2023-01-15\n",
194
+ "AXP : 2023-01-15 - 2023-01-22\n",
195
+ "AXP : 2023-01-22 - 2023-01-29\n",
196
+ "AXP : 2023-01-29 - 2023-02-05\n",
197
+ "AXP : 2023-02-05 - 2023-02-12\n",
198
+ "AXP : 2023-02-12 - 2023-02-19\n",
199
+ "AXP : 2023-02-19 - 2023-02-26\n",
200
+ "AXP : 2023-02-26 - 2023-03-05\n",
201
+ "AXP : 2023-03-05 - 2023-03-12\n",
202
+ "AXP : 2023-03-12 - 2023-03-19\n",
203
+ "AXP : 2023-03-19 - 2023-03-26\n",
204
+ "AXP : 2023-03-26 - 2023-04-02\n",
205
+ "AXP : 2023-04-02 - 2023-04-09\n",
206
+ "AXP : 2023-04-09 - 2023-04-16\n",
207
+ "AXP : 2023-04-16 - 2023-04-23\n",
208
+ "AXP : 2023-04-23 - 2023-04-30\n",
209
+ "AXP : 2023-04-30 - 2023-05-07\n",
210
+ "AXP : 2023-05-07 - 2023-05-14\n",
211
+ "AXP : 2023-05-14 - 2023-05-21\n",
212
+ "AXP : 2023-05-21 - 2023-05-28\n",
213
+ "AXP : 2023-05-28 - 2023-06-04\n",
214
+ "[*********************100%%**********************] 1 of 1 completed\n",
215
+ "AMGN : 2023-01-08 - 2023-01-15\n",
216
+ "AMGN : 2023-01-15 - 2023-01-22\n",
217
+ "AMGN : 2023-01-22 - 2023-01-29\n",
218
+ "AMGN : 2023-01-29 - 2023-02-05\n",
219
+ "AMGN : 2023-02-05 - 2023-02-12\n",
220
+ "AMGN : 2023-02-12 - 2023-02-19\n",
221
+ "AMGN : 2023-02-19 - 2023-02-26\n",
222
+ "AMGN : 2023-02-26 - 2023-03-05\n",
223
+ "AMGN : 2023-03-05 - 2023-03-12\n",
224
+ "AMGN : 2023-03-12 - 2023-03-19\n",
225
+ "AMGN : 2023-03-19 - 2023-03-26\n",
226
+ "AMGN : 2023-03-26 - 2023-04-02\n",
227
+ "AMGN : 2023-04-02 - 2023-04-09\n",
228
+ "AMGN : 2023-04-09 - 2023-04-16\n",
229
+ "AMGN : 2023-04-16 - 2023-04-23\n",
230
+ "AMGN : 2023-04-23 - 2023-04-30\n",
231
+ "AMGN : 2023-04-30 - 2023-05-07\n",
232
+ "AMGN : 2023-05-07 - 2023-05-14\n",
233
+ "AMGN : 2023-05-14 - 2023-05-21\n",
234
+ "AMGN : 2023-05-21 - 2023-05-28\n",
235
+ "AMGN : 2023-05-28 - 2023-06-04\n",
236
+ "[*********************100%%**********************] 1 of 1 completed\n",
237
+ "AAPL : 2023-01-08 - 2023-01-15\n",
238
+ "AAPL : 2023-01-15 - 2023-01-22\n",
239
+ "AAPL : 2023-01-22 - 2023-01-29\n",
240
+ "AAPL : 2023-01-29 - 2023-02-05\n",
241
+ "AAPL : 2023-02-05 - 2023-02-12\n",
242
+ "AAPL : 2023-02-12 - 2023-02-19\n",
243
+ "AAPL : 2023-02-19 - 2023-02-26\n",
244
+ "AAPL : 2023-02-26 - 2023-03-05\n",
245
+ "AAPL : 2023-03-05 - 2023-03-12\n",
246
+ "AAPL : 2023-03-12 - 2023-03-19\n",
247
+ "AAPL : 2023-03-19 - 2023-03-26\n",
248
+ "AAPL : 2023-03-26 - 2023-04-02\n",
249
+ "AAPL : 2023-04-02 - 2023-04-09\n",
250
+ "AAPL : 2023-04-09 - 2023-04-16\n",
251
+ "AAPL : 2023-04-16 - 2023-04-23\n",
252
+ "AAPL : 2023-04-23 - 2023-04-30\n",
253
+ "AAPL : 2023-04-30 - 2023-05-07\n",
254
+ "AAPL : 2023-05-07 - 2023-05-14\n",
255
+ "AAPL : 2023-05-14 - 2023-05-21\n",
256
+ "AAPL : 2023-05-21 - 2023-05-28\n",
257
+ "AAPL : 2023-05-28 - 2023-06-04\n",
258
+ "[*********************100%%**********************] 1 of 1 completed\n",
259
+ "BA : 2023-01-08 - 2023-01-15\n",
260
+ "BA : 2023-01-15 - 2023-01-22\n",
261
+ "BA : 2023-01-22 - 2023-01-29\n",
262
+ "BA : 2023-01-29 - 2023-02-05\n",
263
+ "BA : 2023-02-05 - 2023-02-12\n",
264
+ "BA : 2023-02-12 - 2023-02-19\n",
265
+ "BA : 2023-02-19 - 2023-02-26\n",
266
+ "BA : 2023-02-26 - 2023-03-05\n",
267
+ "BA : 2023-03-05 - 2023-03-12\n",
268
+ "BA : 2023-03-12 - 2023-03-19\n",
269
+ "BA : 2023-03-19 - 2023-03-26\n",
270
+ "BA : 2023-03-26 - 2023-04-02\n",
271
+ "BA : 2023-04-02 - 2023-04-09\n",
272
+ "BA : 2023-04-09 - 2023-04-16\n",
273
+ "BA : 2023-04-16 - 2023-04-23\n",
274
+ "BA : 2023-04-23 - 2023-04-30\n",
275
+ "BA : 2023-04-30 - 2023-05-07\n",
276
+ "BA : 2023-05-07 - 2023-05-14\n",
277
+ "BA : 2023-05-14 - 2023-05-21\n",
278
+ "BA : 2023-05-21 - 2023-05-28\n",
279
+ "BA : 2023-05-28 - 2023-06-04\n",
280
+ "[*********************100%%**********************] 1 of 1 completed\n",
281
+ "CAT : 2023-01-08 - 2023-01-15\n",
282
+ "CAT : 2023-01-15 - 2023-01-22\n",
283
+ "CAT : 2023-01-22 - 2023-01-29\n",
284
+ "CAT : 2023-01-29 - 2023-02-05\n",
285
+ "CAT : 2023-02-05 - 2023-02-12\n",
286
+ "CAT : 2023-02-12 - 2023-02-19\n",
287
+ "CAT : 2023-02-19 - 2023-02-26\n",
288
+ "CAT : 2023-02-26 - 2023-03-05\n",
289
+ "CAT : 2023-03-05 - 2023-03-12\n",
290
+ "CAT : 2023-03-12 - 2023-03-19\n",
291
+ "CAT : 2023-03-19 - 2023-03-26\n",
292
+ "CAT : 2023-03-26 - 2023-04-02\n",
293
+ "CAT : 2023-04-02 - 2023-04-09\n",
294
+ "CAT : 2023-04-09 - 2023-04-16\n",
295
+ "CAT : 2023-04-16 - 2023-04-23\n",
296
+ "CAT : 2023-04-23 - 2023-04-30\n",
297
+ "CAT : 2023-04-30 - 2023-05-07\n",
298
+ "CAT : 2023-05-07 - 2023-05-14\n",
299
+ "CAT : 2023-05-14 - 2023-05-21\n",
300
+ "CAT : 2023-05-21 - 2023-05-28\n",
301
+ "CAT : 2023-05-28 - 2023-06-04\n",
302
+ "[*********************100%%**********************] 1 of 1 completed\n",
303
+ "CSCO : 2023-01-08 - 2023-01-15\n",
304
+ "CSCO : 2023-01-15 - 2023-01-22\n",
305
+ "CSCO : 2023-01-22 - 2023-01-29\n",
306
+ "CSCO : 2023-01-29 - 2023-02-05\n",
307
+ "CSCO : 2023-02-05 - 2023-02-12\n",
308
+ "CSCO : 2023-02-12 - 2023-02-19\n",
309
+ "CSCO : 2023-02-19 - 2023-02-26\n",
310
+ "CSCO : 2023-02-26 - 2023-03-05\n",
311
+ "CSCO : 2023-03-05 - 2023-03-12\n",
312
+ "CSCO : 2023-03-12 - 2023-03-19\n",
313
+ "CSCO : 2023-03-19 - 2023-03-26\n",
314
+ "CSCO : 2023-03-26 - 2023-04-02\n",
315
+ "CSCO : 2023-04-02 - 2023-04-09\n",
316
+ "CSCO : 2023-04-09 - 2023-04-16\n",
317
+ "CSCO : 2023-04-16 - 2023-04-23\n",
318
+ "CSCO : 2023-04-23 - 2023-04-30\n",
319
+ "CSCO : 2023-04-30 - 2023-05-07\n",
320
+ "CSCO : 2023-05-07 - 2023-05-14\n",
321
+ "CSCO : 2023-05-14 - 2023-05-21\n",
322
+ "CSCO : 2023-05-21 - 2023-05-28\n",
323
+ "CSCO : 2023-05-28 - 2023-06-04\n",
324
+ "[*********************100%%**********************] 1 of 1 completed\n",
325
+ "CVX : 2023-01-08 - 2023-01-15\n",
326
+ "CVX : 2023-01-15 - 2023-01-22\n",
327
+ "CVX : 2023-01-22 - 2023-01-29\n",
328
+ "CVX : 2023-01-29 - 2023-02-05\n",
329
+ "CVX : 2023-02-05 - 2023-02-12\n",
330
+ "CVX : 2023-02-12 - 2023-02-19\n",
331
+ "CVX : 2023-02-19 - 2023-02-26\n",
332
+ "CVX : 2023-02-26 - 2023-03-05\n",
333
+ "CVX : 2023-03-05 - 2023-03-12\n",
334
+ "CVX : 2023-03-12 - 2023-03-19\n",
335
+ "CVX : 2023-03-19 - 2023-03-26\n",
336
+ "CVX : 2023-03-26 - 2023-04-02\n",
337
+ "CVX : 2023-04-02 - 2023-04-09\n",
338
+ "CVX : 2023-04-09 - 2023-04-16\n",
339
+ "CVX : 2023-04-16 - 2023-04-23\n",
340
+ "CVX : 2023-04-23 - 2023-04-30\n",
341
+ "CVX : 2023-04-30 - 2023-05-07\n",
342
+ "CVX : 2023-05-07 - 2023-05-14\n",
343
+ "CVX : 2023-05-14 - 2023-05-21\n",
344
+ "CVX : 2023-05-21 - 2023-05-28\n",
345
+ "CVX : 2023-05-28 - 2023-06-04\n",
346
+ "[*********************100%%**********************] 1 of 1 completed\n",
347
+ "GS : 2023-01-08 - 2023-01-15\n",
348
+ "GS : 2023-01-15 - 2023-01-22\n",
349
+ "GS : 2023-01-22 - 2023-01-29\n",
350
+ "GS : 2023-01-29 - 2023-02-05\n",
351
+ "GS : 2023-02-05 - 2023-02-12\n",
352
+ "GS : 2023-02-12 - 2023-02-19\n",
353
+ "GS : 2023-02-19 - 2023-02-26\n",
354
+ "GS : 2023-02-26 - 2023-03-05\n",
355
+ "GS : 2023-03-05 - 2023-03-12\n",
356
+ "GS : 2023-03-12 - 2023-03-19\n",
357
+ "GS : 2023-03-19 - 2023-03-26\n",
358
+ "GS : 2023-03-26 - 2023-04-02\n",
359
+ "GS : 2023-04-02 - 2023-04-09\n",
360
+ "GS : 2023-04-09 - 2023-04-16\n",
361
+ "GS : 2023-04-16 - 2023-04-23\n",
362
+ "GS : 2023-04-23 - 2023-04-30\n",
363
+ "GS : 2023-04-30 - 2023-05-07\n",
364
+ "GS : 2023-05-07 - 2023-05-14\n",
365
+ "GS : 2023-05-14 - 2023-05-21\n",
366
+ "GS : 2023-05-21 - 2023-05-28\n",
367
+ "GS : 2023-05-28 - 2023-06-04\n",
368
+ "[*********************100%%**********************] 1 of 1 completed\n",
369
+ "HD : 2023-01-08 - 2023-01-15\n",
370
+ "HD : 2023-01-15 - 2023-01-22\n",
371
+ "HD : 2023-01-22 - 2023-01-29\n",
372
+ "HD : 2023-01-29 - 2023-02-05\n",
373
+ "HD : 2023-02-05 - 2023-02-12\n",
374
+ "HD : 2023-02-12 - 2023-02-19\n",
375
+ "HD : 2023-02-19 - 2023-02-26\n",
376
+ "HD : 2023-02-26 - 2023-03-05\n",
377
+ "HD : 2023-03-05 - 2023-03-12\n",
378
+ "HD : 2023-03-12 - 2023-03-19\n",
379
+ "HD : 2023-03-19 - 2023-03-26\n",
380
+ "HD : 2023-03-26 - 2023-04-02\n",
381
+ "HD : 2023-04-02 - 2023-04-09\n",
382
+ "HD : 2023-04-09 - 2023-04-16\n",
383
+ "HD : 2023-04-16 - 2023-04-23\n",
384
+ "HD : 2023-04-23 - 2023-04-30\n",
385
+ "HD : 2023-04-30 - 2023-05-07\n",
386
+ "HD : 2023-05-07 - 2023-05-14\n",
387
+ "HD : 2023-05-14 - 2023-05-21\n",
388
+ "HD : 2023-05-21 - 2023-05-28\n",
389
+ "HD : 2023-05-28 - 2023-06-04\n",
390
+ "[*********************100%%**********************] 1 of 1 completed\n",
391
+ "HON : 2023-01-08 - 2023-01-15\n",
392
+ "HON : 2023-01-15 - 2023-01-22\n",
393
+ "HON : 2023-01-22 - 2023-01-29\n",
394
+ "HON : 2023-01-29 - 2023-02-05\n",
395
+ "HON : 2023-02-05 - 2023-02-12\n",
396
+ "HON : 2023-02-12 - 2023-02-19\n",
397
+ "HON : 2023-02-19 - 2023-02-26\n",
398
+ "HON : 2023-02-26 - 2023-03-05\n",
399
+ "HON : 2023-03-05 - 2023-03-12\n",
400
+ "HON : 2023-03-12 - 2023-03-19\n",
401
+ "HON : 2023-03-19 - 2023-03-26\n",
402
+ "HON : 2023-03-26 - 2023-04-02\n",
403
+ "HON : 2023-04-02 - 2023-04-09\n",
404
+ "HON : 2023-04-09 - 2023-04-16\n",
405
+ "HON : 2023-04-16 - 2023-04-23\n",
406
+ "HON : 2023-04-23 - 2023-04-30\n",
407
+ "HON : 2023-04-30 - 2023-05-07\n",
408
+ "HON : 2023-05-07 - 2023-05-14\n",
409
+ "HON : 2023-05-14 - 2023-05-21\n",
410
+ "HON : 2023-05-21 - 2023-05-28\n",
411
+ "HON : 2023-05-28 - 2023-06-04\n",
412
+ "[*********************100%%**********************] 1 of 1 completed\n",
413
+ "IBM : 2023-01-08 - 2023-01-15\n",
414
+ "IBM : 2023-01-15 - 2023-01-22\n",
415
+ "IBM : 2023-01-22 - 2023-01-29\n",
416
+ "IBM : 2023-01-29 - 2023-02-05\n",
417
+ "IBM : 2023-02-05 - 2023-02-12\n",
418
+ "IBM : 2023-02-12 - 2023-02-19\n",
419
+ "IBM : 2023-02-19 - 2023-02-26\n",
420
+ "IBM : 2023-02-26 - 2023-03-05\n",
421
+ "IBM : 2023-03-05 - 2023-03-12\n",
422
+ "IBM : 2023-03-12 - 2023-03-19\n",
423
+ "IBM : 2023-03-19 - 2023-03-26\n",
424
+ "IBM : 2023-03-26 - 2023-04-02\n",
425
+ "IBM : 2023-04-02 - 2023-04-09\n",
426
+ "IBM : 2023-04-09 - 2023-04-16\n",
427
+ "IBM : 2023-04-16 - 2023-04-23\n",
428
+ "IBM : 2023-04-23 - 2023-04-30\n"
429
+ ]
430
+ },
431
+ {
432
+ "name": "stdout",
433
+ "output_type": "stream",
434
+ "text": [
435
+ "IBM : 2023-04-30 - 2023-05-07\n",
436
+ "IBM : 2023-05-07 - 2023-05-14\n",
437
+ "IBM : 2023-05-14 - 2023-05-21\n",
438
+ "IBM : 2023-05-21 - 2023-05-28\n",
439
+ "IBM : 2023-05-28 - 2023-06-04\n",
440
+ "[*********************100%%**********************] 1 of 1 completed\n",
441
+ "INTC : 2023-01-08 - 2023-01-15\n",
442
+ "INTC : 2023-01-15 - 2023-01-22\n",
443
+ "INTC : 2023-01-22 - 2023-01-29\n",
444
+ "INTC : 2023-01-29 - 2023-02-05\n",
445
+ "INTC : 2023-02-05 - 2023-02-12\n",
446
+ "INTC : 2023-02-12 - 2023-02-19\n",
447
+ "INTC : 2023-02-19 - 2023-02-26\n",
448
+ "INTC : 2023-02-26 - 2023-03-05\n",
449
+ "INTC : 2023-03-05 - 2023-03-12\n",
450
+ "INTC : 2023-03-12 - 2023-03-19\n",
451
+ "INTC : 2023-03-19 - 2023-03-26\n",
452
+ "INTC : 2023-03-26 - 2023-04-02\n",
453
+ "INTC : 2023-04-02 - 2023-04-09\n",
454
+ "INTC : 2023-04-09 - 2023-04-16\n",
455
+ "INTC : 2023-04-16 - 2023-04-23\n",
456
+ "INTC : 2023-04-23 - 2023-04-30\n",
457
+ "INTC : 2023-04-30 - 2023-05-07\n",
458
+ "INTC : 2023-05-07 - 2023-05-14\n",
459
+ "INTC : 2023-05-14 - 2023-05-21\n",
460
+ "INTC : 2023-05-21 - 2023-05-28\n",
461
+ "INTC : 2023-05-28 - 2023-06-04\n",
462
+ "[*********************100%%**********************] 1 of 1 completed\n",
463
+ "JNJ : 2023-01-08 - 2023-01-15\n",
464
+ "JNJ : 2023-01-15 - 2023-01-22\n",
465
+ "JNJ : 2023-01-22 - 2023-01-29\n",
466
+ "JNJ : 2023-01-29 - 2023-02-05\n",
467
+ "JNJ : 2023-02-05 - 2023-02-12\n",
468
+ "JNJ : 2023-02-12 - 2023-02-19\n",
469
+ "JNJ : 2023-02-19 - 2023-02-26\n",
470
+ "JNJ : 2023-02-26 - 2023-03-05\n",
471
+ "JNJ : 2023-03-05 - 2023-03-12\n",
472
+ "JNJ : 2023-03-12 - 2023-03-19\n",
473
+ "JNJ : 2023-03-19 - 2023-03-26\n",
474
+ "JNJ : 2023-03-26 - 2023-04-02\n",
475
+ "JNJ : 2023-04-02 - 2023-04-09\n",
476
+ "JNJ : 2023-04-09 - 2023-04-16\n",
477
+ "JNJ : 2023-04-16 - 2023-04-23\n",
478
+ "JNJ : 2023-04-23 - 2023-04-30\n",
479
+ "JNJ : 2023-04-30 - 2023-05-07\n",
480
+ "JNJ : 2023-05-07 - 2023-05-14\n",
481
+ "JNJ : 2023-05-14 - 2023-05-21\n",
482
+ "JNJ : 2023-05-21 - 2023-05-28\n",
483
+ "JNJ : 2023-05-28 - 2023-06-04\n",
484
+ "[*********************100%%**********************] 1 of 1 completed\n",
485
+ "KO : 2023-01-08 - 2023-01-15\n",
486
+ "KO : 2023-01-15 - 2023-01-22\n",
487
+ "KO : 2023-01-22 - 2023-01-29\n",
488
+ "KO : 2023-01-29 - 2023-02-05\n",
489
+ "KO : 2023-02-05 - 2023-02-12\n",
490
+ "KO : 2023-02-12 - 2023-02-19\n",
491
+ "KO : 2023-02-19 - 2023-02-26\n",
492
+ "KO : 2023-02-26 - 2023-03-05\n",
493
+ "KO : 2023-03-05 - 2023-03-12\n",
494
+ "KO : 2023-03-12 - 2023-03-19\n",
495
+ "KO : 2023-03-19 - 2023-03-26\n",
496
+ "KO : 2023-03-26 - 2023-04-02\n",
497
+ "KO : 2023-04-02 - 2023-04-09\n",
498
+ "KO : 2023-04-09 - 2023-04-16\n",
499
+ "KO : 2023-04-16 - 2023-04-23\n",
500
+ "KO : 2023-04-23 - 2023-04-30\n",
501
+ "KO : 2023-04-30 - 2023-05-07\n",
502
+ "KO : 2023-05-07 - 2023-05-14\n",
503
+ "KO : 2023-05-14 - 2023-05-21\n",
504
+ "KO : 2023-05-21 - 2023-05-28\n",
505
+ "KO : 2023-05-28 - 2023-06-04\n",
506
+ "[*********************100%%**********************] 1 of 1 completed\n",
507
+ "JPM : 2023-01-08 - 2023-01-15\n",
508
+ "JPM : 2023-01-15 - 2023-01-22\n",
509
+ "JPM : 2023-01-22 - 2023-01-29\n",
510
+ "JPM : 2023-01-29 - 2023-02-05\n",
511
+ "JPM : 2023-02-05 - 2023-02-12\n",
512
+ "JPM : 2023-02-12 - 2023-02-19\n",
513
+ "JPM : 2023-02-19 - 2023-02-26\n",
514
+ "JPM : 2023-02-26 - 2023-03-05\n",
515
+ "JPM : 2023-03-05 - 2023-03-12\n",
516
+ "JPM : 2023-03-12 - 2023-03-19\n",
517
+ "JPM : 2023-03-19 - 2023-03-26\n",
518
+ "JPM : 2023-03-26 - 2023-04-02\n",
519
+ "JPM : 2023-04-02 - 2023-04-09\n",
520
+ "JPM : 2023-04-09 - 2023-04-16\n",
521
+ "JPM : 2023-04-16 - 2023-04-23\n",
522
+ "JPM : 2023-04-23 - 2023-04-30\n",
523
+ "JPM : 2023-04-30 - 2023-05-07\n",
524
+ "JPM : 2023-05-07 - 2023-05-14\n",
525
+ "JPM : 2023-05-14 - 2023-05-21\n",
526
+ "JPM : 2023-05-21 - 2023-05-28\n",
527
+ "JPM : 2023-05-28 - 2023-06-04\n",
528
+ "[*********************100%%**********************] 1 of 1 completed\n",
529
+ "MCD : 2023-01-08 - 2023-01-15\n",
530
+ "MCD : 2023-01-15 - 2023-01-22\n",
531
+ "MCD : 2023-01-22 - 2023-01-29\n",
532
+ "MCD : 2023-01-29 - 2023-02-05\n",
533
+ "MCD : 2023-02-05 - 2023-02-12\n",
534
+ "MCD : 2023-02-12 - 2023-02-19\n",
535
+ "MCD : 2023-02-19 - 2023-02-26\n",
536
+ "MCD : 2023-02-26 - 2023-03-05\n",
537
+ "MCD : 2023-03-05 - 2023-03-12\n",
538
+ "MCD : 2023-03-12 - 2023-03-19\n",
539
+ "MCD : 2023-03-19 - 2023-03-26\n",
540
+ "MCD : 2023-03-26 - 2023-04-02\n",
541
+ "MCD : 2023-04-02 - 2023-04-09\n",
542
+ "MCD : 2023-04-09 - 2023-04-16\n",
543
+ "MCD : 2023-04-16 - 2023-04-23\n",
544
+ "MCD : 2023-04-23 - 2023-04-30\n",
545
+ "MCD : 2023-04-30 - 2023-05-07\n",
546
+ "MCD : 2023-05-07 - 2023-05-14\n",
547
+ "MCD : 2023-05-14 - 2023-05-21\n",
548
+ "MCD : 2023-05-21 - 2023-05-28\n",
549
+ "MCD : 2023-05-28 - 2023-06-04\n",
550
+ "[*********************100%%**********************] 1 of 1 completed\n",
551
+ "MMM : 2023-01-08 - 2023-01-15\n",
552
+ "MMM : 2023-01-15 - 2023-01-22\n",
553
+ "MMM : 2023-01-22 - 2023-01-29\n",
554
+ "MMM : 2023-01-29 - 2023-02-05\n",
555
+ "MMM : 2023-02-05 - 2023-02-12\n",
556
+ "MMM : 2023-02-12 - 2023-02-19\n",
557
+ "MMM : 2023-02-19 - 2023-02-26\n",
558
+ "MMM : 2023-02-26 - 2023-03-05\n",
559
+ "MMM : 2023-03-05 - 2023-03-12\n",
560
+ "MMM : 2023-03-12 - 2023-03-19\n",
561
+ "MMM : 2023-03-19 - 2023-03-26\n",
562
+ "MMM : 2023-03-26 - 2023-04-02\n",
563
+ "MMM : 2023-04-02 - 2023-04-09\n",
564
+ "MMM : 2023-04-09 - 2023-04-16\n",
565
+ "MMM : 2023-04-16 - 2023-04-23\n",
566
+ "MMM : 2023-04-23 - 2023-04-30\n",
567
+ "MMM : 2023-04-30 - 2023-05-07\n",
568
+ "MMM : 2023-05-07 - 2023-05-14\n",
569
+ "MMM : 2023-05-14 - 2023-05-21\n",
570
+ "MMM : 2023-05-21 - 2023-05-28\n",
571
+ "MMM : 2023-05-28 - 2023-06-04\n",
572
+ "[*********************100%%**********************] 1 of 1 completed\n",
573
+ "MRK : 2023-01-08 - 2023-01-15\n",
574
+ "MRK : 2023-01-15 - 2023-01-22\n",
575
+ "MRK : 2023-01-22 - 2023-01-29\n",
576
+ "MRK : 2023-01-29 - 2023-02-05\n",
577
+ "MRK : 2023-02-05 - 2023-02-12\n",
578
+ "MRK : 2023-02-12 - 2023-02-19\n",
579
+ "MRK : 2023-02-19 - 2023-02-26\n",
580
+ "MRK : 2023-02-26 - 2023-03-05\n",
581
+ "MRK : 2023-03-05 - 2023-03-12\n",
582
+ "MRK : 2023-03-12 - 2023-03-19\n",
583
+ "MRK : 2023-03-19 - 2023-03-26\n",
584
+ "MRK : 2023-03-26 - 2023-04-02\n",
585
+ "MRK : 2023-04-02 - 2023-04-09\n",
586
+ "MRK : 2023-04-09 - 2023-04-16\n",
587
+ "MRK : 2023-04-16 - 2023-04-23\n",
588
+ "MRK : 2023-04-23 - 2023-04-30\n",
589
+ "MRK : 2023-04-30 - 2023-05-07\n",
590
+ "MRK : 2023-05-07 - 2023-05-14\n",
591
+ "MRK : 2023-05-14 - 2023-05-21\n",
592
+ "MRK : 2023-05-21 - 2023-05-28\n",
593
+ "MRK : 2023-05-28 - 2023-06-04\n",
594
+ "[*********************100%%**********************] 1 of 1 completed\n",
595
+ "MSFT : 2023-01-08 - 2023-01-15\n",
596
+ "MSFT : 2023-01-15 - 2023-01-22\n",
597
+ "MSFT : 2023-01-22 - 2023-01-29\n",
598
+ "MSFT : 2023-01-29 - 2023-02-05\n",
599
+ "MSFT : 2023-02-05 - 2023-02-12\n",
600
+ "MSFT : 2023-02-12 - 2023-02-19\n",
601
+ "MSFT : 2023-02-19 - 2023-02-26\n",
602
+ "MSFT : 2023-02-26 - 2023-03-05\n",
603
+ "MSFT : 2023-03-05 - 2023-03-12\n",
604
+ "MSFT : 2023-03-12 - 2023-03-19\n",
605
+ "MSFT : 2023-03-19 - 2023-03-26\n",
606
+ "MSFT : 2023-03-26 - 2023-04-02\n",
607
+ "MSFT : 2023-04-02 - 2023-04-09\n",
608
+ "MSFT : 2023-04-09 - 2023-04-16\n",
609
+ "MSFT : 2023-04-16 - 2023-04-23\n",
610
+ "MSFT : 2023-04-23 - 2023-04-30\n",
611
+ "MSFT : 2023-04-30 - 2023-05-07\n",
612
+ "MSFT : 2023-05-07 - 2023-05-14\n",
613
+ "MSFT : 2023-05-14 - 2023-05-21\n",
614
+ "MSFT : 2023-05-21 - 2023-05-28\n",
615
+ "MSFT : 2023-05-28 - 2023-06-04\n",
616
+ "[*********************100%%**********************] 1 of 1 completed\n",
617
+ "NKE : 2023-01-08 - 2023-01-15\n",
618
+ "NKE : 2023-01-15 - 2023-01-22\n",
619
+ "NKE : 2023-01-22 - 2023-01-29\n",
620
+ "NKE : 2023-01-29 - 2023-02-05\n",
621
+ "NKE : 2023-02-05 - 2023-02-12\n",
622
+ "NKE : 2023-02-12 - 2023-02-19\n",
623
+ "NKE : 2023-02-19 - 2023-02-26\n",
624
+ "NKE : 2023-02-26 - 2023-03-05\n",
625
+ "NKE : 2023-03-05 - 2023-03-12\n",
626
+ "NKE : 2023-03-12 - 2023-03-19\n",
627
+ "NKE : 2023-03-19 - 2023-03-26\n",
628
+ "NKE : 2023-03-26 - 2023-04-02\n",
629
+ "NKE : 2023-04-02 - 2023-04-09\n",
630
+ "NKE : 2023-04-09 - 2023-04-16\n",
631
+ "NKE : 2023-04-16 - 2023-04-23\n",
632
+ "NKE : 2023-04-23 - 2023-04-30\n",
633
+ "NKE : 2023-04-30 - 2023-05-07\n",
634
+ "NKE : 2023-05-07 - 2023-05-14\n",
635
+ "NKE : 2023-05-14 - 2023-05-21\n",
636
+ "NKE : 2023-05-21 - 2023-05-28\n",
637
+ "NKE : 2023-05-28 - 2023-06-04\n",
638
+ "[*********************100%%**********************] 1 of 1 completed\n",
639
+ "PG : 2023-01-08 - 2023-01-15\n",
640
+ "PG : 2023-01-15 - 2023-01-22\n",
641
+ "PG : 2023-01-22 - 2023-01-29\n",
642
+ "PG : 2023-01-29 - 2023-02-05\n",
643
+ "PG : 2023-02-05 - 2023-02-12\n",
644
+ "PG : 2023-02-12 - 2023-02-19\n",
645
+ "PG : 2023-02-19 - 2023-02-26\n",
646
+ "PG : 2023-02-26 - 2023-03-05\n",
647
+ "PG : 2023-03-05 - 2023-03-12\n",
648
+ "PG : 2023-03-12 - 2023-03-19\n",
649
+ "PG : 2023-03-19 - 2023-03-26\n",
650
+ "PG : 2023-03-26 - 2023-04-02\n",
651
+ "PG : 2023-04-02 - 2023-04-09\n",
652
+ "PG : 2023-04-09 - 2023-04-16\n",
653
+ "PG : 2023-04-16 - 2023-04-23\n",
654
+ "PG : 2023-04-23 - 2023-04-30\n",
655
+ "PG : 2023-04-30 - 2023-05-07\n",
656
+ "PG : 2023-05-07 - 2023-05-14\n",
657
+ "PG : 2023-05-14 - 2023-05-21\n",
658
+ "PG : 2023-05-21 - 2023-05-28\n",
659
+ "PG : 2023-05-28 - 2023-06-04\n",
660
+ "[*********************100%%**********************] 1 of 1 completed\n",
661
+ "TRV : 2023-01-08 - 2023-01-15\n",
662
+ "TRV : 2023-01-15 - 2023-01-22\n",
663
+ "TRV : 2023-01-22 - 2023-01-29\n",
664
+ "TRV : 2023-01-29 - 2023-02-05\n",
665
+ "TRV : 2023-02-05 - 2023-02-12\n",
666
+ "TRV : 2023-02-12 - 2023-02-19\n",
667
+ "TRV : 2023-02-19 - 2023-02-26\n",
668
+ "TRV : 2023-02-26 - 2023-03-05\n",
669
+ "TRV : 2023-03-05 - 2023-03-12\n",
670
+ "TRV : 2023-03-12 - 2023-03-19\n",
671
+ "TRV : 2023-03-19 - 2023-03-26\n"
672
+ ]
673
+ },
674
+ {
675
+ "name": "stdout",
676
+ "output_type": "stream",
677
+ "text": [
678
+ "TRV : 2023-03-26 - 2023-04-02\n",
679
+ "TRV : 2023-04-02 - 2023-04-09\n",
680
+ "TRV : 2023-04-09 - 2023-04-16\n",
681
+ "TRV : 2023-04-16 - 2023-04-23\n",
682
+ "TRV : 2023-04-23 - 2023-04-30\n",
683
+ "TRV : 2023-04-30 - 2023-05-07\n",
684
+ "TRV : 2023-05-07 - 2023-05-14\n",
685
+ "TRV : 2023-05-14 - 2023-05-21\n",
686
+ "TRV : 2023-05-21 - 2023-05-28\n",
687
+ "TRV : 2023-05-28 - 2023-06-04\n",
688
+ "[*********************100%%**********************] 1 of 1 completed\n",
689
+ "UNH : 2023-01-08 - 2023-01-15\n",
690
+ "UNH : 2023-01-15 - 2023-01-22\n",
691
+ "UNH : 2023-01-22 - 2023-01-29\n",
692
+ "UNH : 2023-01-29 - 2023-02-05\n",
693
+ "UNH : 2023-02-05 - 2023-02-12\n",
694
+ "UNH : 2023-02-12 - 2023-02-19\n",
695
+ "UNH : 2023-02-19 - 2023-02-26\n",
696
+ "UNH : 2023-02-26 - 2023-03-05\n",
697
+ "UNH : 2023-03-05 - 2023-03-12\n",
698
+ "UNH : 2023-03-12 - 2023-03-19\n",
699
+ "UNH : 2023-03-19 - 2023-03-26\n",
700
+ "UNH : 2023-03-26 - 2023-04-02\n",
701
+ "UNH : 2023-04-02 - 2023-04-09\n",
702
+ "UNH : 2023-04-09 - 2023-04-16\n",
703
+ "UNH : 2023-04-16 - 2023-04-23\n",
704
+ "UNH : 2023-04-23 - 2023-04-30\n",
705
+ "UNH : 2023-04-30 - 2023-05-07\n",
706
+ "UNH : 2023-05-07 - 2023-05-14\n",
707
+ "UNH : 2023-05-14 - 2023-05-21\n",
708
+ "UNH : 2023-05-21 - 2023-05-28\n",
709
+ "UNH : 2023-05-28 - 2023-06-04\n",
710
+ "[*********************100%%**********************] 1 of 1 completed\n",
711
+ "CRM : 2023-01-08 - 2023-01-15\n",
712
+ "CRM : 2023-01-15 - 2023-01-22\n",
713
+ "CRM : 2023-01-22 - 2023-01-29\n",
714
+ "CRM : 2023-01-29 - 2023-02-05\n",
715
+ "CRM : 2023-02-05 - 2023-02-12\n",
716
+ "CRM : 2023-02-12 - 2023-02-19\n",
717
+ "CRM : 2023-02-19 - 2023-02-26\n",
718
+ "CRM : 2023-02-26 - 2023-03-05\n",
719
+ "CRM : 2023-03-05 - 2023-03-12\n",
720
+ "CRM : 2023-03-12 - 2023-03-19\n",
721
+ "CRM : 2023-03-19 - 2023-03-26\n",
722
+ "CRM : 2023-03-26 - 2023-04-02\n",
723
+ "CRM : 2023-04-02 - 2023-04-09\n",
724
+ "CRM : 2023-04-09 - 2023-04-16\n",
725
+ "CRM : 2023-04-16 - 2023-04-23\n",
726
+ "CRM : 2023-04-23 - 2023-04-30\n",
727
+ "CRM : 2023-04-30 - 2023-05-07\n",
728
+ "CRM : 2023-05-07 - 2023-05-14\n",
729
+ "CRM : 2023-05-14 - 2023-05-21\n",
730
+ "CRM : 2023-05-21 - 2023-05-28\n",
731
+ "CRM : 2023-05-28 - 2023-06-04\n",
732
+ "[*********************100%%**********************] 1 of 1 completed\n",
733
+ "VZ : 2023-01-08 - 2023-01-15\n",
734
+ "VZ : 2023-01-15 - 2023-01-22\n",
735
+ "VZ : 2023-01-22 - 2023-01-29\n",
736
+ "VZ : 2023-01-29 - 2023-02-05\n",
737
+ "VZ : 2023-02-05 - 2023-02-12\n",
738
+ "VZ : 2023-02-12 - 2023-02-19\n",
739
+ "VZ : 2023-02-19 - 2023-02-26\n",
740
+ "VZ : 2023-02-26 - 2023-03-05\n",
741
+ "VZ : 2023-03-05 - 2023-03-12\n",
742
+ "VZ : 2023-03-12 - 2023-03-19\n",
743
+ "VZ : 2023-03-19 - 2023-03-26\n",
744
+ "VZ : 2023-03-26 - 2023-04-02\n",
745
+ "VZ : 2023-04-02 - 2023-04-09\n",
746
+ "VZ : 2023-04-09 - 2023-04-16\n",
747
+ "VZ : 2023-04-16 - 2023-04-23\n",
748
+ "VZ : 2023-04-23 - 2023-04-30\n",
749
+ "VZ : 2023-04-30 - 2023-05-07\n",
750
+ "VZ : 2023-05-07 - 2023-05-14\n",
751
+ "VZ : 2023-05-14 - 2023-05-21\n",
752
+ "VZ : 2023-05-21 - 2023-05-28\n",
753
+ "VZ : 2023-05-28 - 2023-06-04\n",
754
+ "[*********************100%%**********************] 1 of 1 completed\n",
755
+ "V : 2023-01-08 - 2023-01-15\n",
756
+ "V : 2023-01-15 - 2023-01-22\n",
757
+ "V : 2023-01-22 - 2023-01-29\n",
758
+ "V : 2023-01-29 - 2023-02-05\n",
759
+ "V : 2023-02-05 - 2023-02-12\n",
760
+ "V : 2023-02-12 - 2023-02-19\n",
761
+ "V : 2023-02-19 - 2023-02-26\n",
762
+ "V : 2023-02-26 - 2023-03-05\n",
763
+ "V : 2023-03-05 - 2023-03-12\n",
764
+ "V : 2023-03-12 - 2023-03-19\n",
765
+ "V : 2023-03-19 - 2023-03-26\n",
766
+ "V : 2023-03-26 - 2023-04-02\n",
767
+ "V : 2023-04-02 - 2023-04-09\n",
768
+ "V : 2023-04-09 - 2023-04-16\n",
769
+ "V : 2023-04-16 - 2023-04-23\n",
770
+ "V : 2023-04-23 - 2023-04-30\n",
771
+ "V : 2023-04-30 - 2023-05-07\n",
772
+ "V : 2023-05-07 - 2023-05-14\n",
773
+ "V : 2023-05-14 - 2023-05-21\n",
774
+ "V : 2023-05-21 - 2023-05-28\n",
775
+ "V : 2023-05-28 - 2023-06-04\n",
776
+ "[*********************100%%**********************] 1 of 1 completed\n",
777
+ "WBA : 2023-01-08 - 2023-01-15\n",
778
+ "WBA : 2023-01-15 - 2023-01-22\n",
779
+ "WBA : 2023-01-22 - 2023-01-29\n",
780
+ "WBA : 2023-01-29 - 2023-02-05\n",
781
+ "WBA : 2023-02-05 - 2023-02-12\n",
782
+ "WBA : 2023-02-12 - 2023-02-19\n",
783
+ "WBA : 2023-02-19 - 2023-02-26\n",
784
+ "WBA : 2023-02-26 - 2023-03-05\n",
785
+ "WBA : 2023-03-05 - 2023-03-12\n",
786
+ "WBA : 2023-03-12 - 2023-03-19\n",
787
+ "WBA : 2023-03-19 - 2023-03-26\n",
788
+ "WBA : 2023-03-26 - 2023-04-02\n",
789
+ "WBA : 2023-04-02 - 2023-04-09\n",
790
+ "WBA : 2023-04-09 - 2023-04-16\n",
791
+ "WBA : 2023-04-16 - 2023-04-23\n",
792
+ "WBA : 2023-04-23 - 2023-04-30\n",
793
+ "WBA : 2023-04-30 - 2023-05-07\n",
794
+ "WBA : 2023-05-07 - 2023-05-14\n",
795
+ "WBA : 2023-05-14 - 2023-05-21\n",
796
+ "WBA : 2023-05-21 - 2023-05-28\n",
797
+ "WBA : 2023-05-28 - 2023-06-04\n",
798
+ "[*********************100%%**********************] 1 of 1 completed\n",
799
+ "WMT : 2023-01-08 - 2023-01-15\n",
800
+ "WMT : 2023-01-15 - 2023-01-22\n",
801
+ "WMT : 2023-01-22 - 2023-01-29\n",
802
+ "WMT : 2023-01-29 - 2023-02-05\n",
803
+ "WMT : 2023-02-05 - 2023-02-12\n",
804
+ "WMT : 2023-02-12 - 2023-02-19\n",
805
+ "WMT : 2023-02-19 - 2023-02-26\n",
806
+ "WMT : 2023-02-26 - 2023-03-05\n",
807
+ "WMT : 2023-03-05 - 2023-03-12\n",
808
+ "WMT : 2023-03-12 - 2023-03-19\n",
809
+ "WMT : 2023-03-19 - 2023-03-26\n",
810
+ "WMT : 2023-03-26 - 2023-04-02\n",
811
+ "WMT : 2023-04-02 - 2023-04-09\n",
812
+ "WMT : 2023-04-09 - 2023-04-16\n",
813
+ "WMT : 2023-04-16 - 2023-04-23\n",
814
+ "WMT : 2023-04-23 - 2023-04-30\n",
815
+ "WMT : 2023-04-30 - 2023-05-07\n",
816
+ "WMT : 2023-05-07 - 2023-05-14\n",
817
+ "WMT : 2023-05-14 - 2023-05-21\n",
818
+ "WMT : 2023-05-21 - 2023-05-28\n",
819
+ "WMT : 2023-05-28 - 2023-06-04\n",
820
+ "[*********************100%%**********************] 1 of 1 completed\n",
821
+ "DIS : 2023-01-08 - 2023-01-15\n",
822
+ "DIS : 2023-01-15 - 2023-01-22\n",
823
+ "DIS : 2023-01-22 - 2023-01-29\n",
824
+ "DIS : 2023-01-29 - 2023-02-05\n",
825
+ "DIS : 2023-02-05 - 2023-02-12\n",
826
+ "DIS : 2023-02-12 - 2023-02-19\n",
827
+ "DIS : 2023-02-19 - 2023-02-26\n",
828
+ "DIS : 2023-02-26 - 2023-03-05\n",
829
+ "DIS : 2023-03-05 - 2023-03-12\n",
830
+ "DIS : 2023-03-12 - 2023-03-19\n",
831
+ "DIS : 2023-03-19 - 2023-03-26\n",
832
+ "DIS : 2023-03-26 - 2023-04-02\n",
833
+ "DIS : 2023-04-02 - 2023-04-09\n",
834
+ "DIS : 2023-04-09 - 2023-04-16\n",
835
+ "DIS : 2023-04-16 - 2023-04-23\n",
836
+ "DIS : 2023-04-23 - 2023-04-30\n",
837
+ "DIS : 2023-04-30 - 2023-05-07\n",
838
+ "DIS : 2023-05-07 - 2023-05-14\n",
839
+ "DIS : 2023-05-14 - 2023-05-21\n",
840
+ "DIS : 2023-05-21 - 2023-05-28\n",
841
+ "DIS : 2023-05-28 - 2023-06-04\n",
842
+ "[*********************100%%**********************] 1 of 1 completed\n",
843
+ "DOW : 2023-01-08 - 2023-01-15\n",
844
+ "DOW : 2023-01-15 - 2023-01-22\n",
845
+ "DOW : 2023-01-22 - 2023-01-29\n",
846
+ "DOW : 2023-01-29 - 2023-02-05\n",
847
+ "DOW : 2023-02-05 - 2023-02-12\n",
848
+ "DOW : 2023-02-12 - 2023-02-19\n",
849
+ "DOW : 2023-02-19 - 2023-02-26\n",
850
+ "DOW : 2023-02-26 - 2023-03-05\n",
851
+ "DOW : 2023-03-05 - 2023-03-12\n",
852
+ "DOW : 2023-03-12 - 2023-03-19\n",
853
+ "DOW : 2023-03-19 - 2023-03-26\n",
854
+ "DOW : 2023-03-26 - 2023-04-02\n",
855
+ "DOW : 2023-04-02 - 2023-04-09\n",
856
+ "DOW : 2023-04-09 - 2023-04-16\n",
857
+ "DOW : 2023-04-16 - 2023-04-23\n",
858
+ "DOW : 2023-04-23 - 2023-04-30\n",
859
+ "DOW : 2023-04-30 - 2023-05-07\n",
860
+ "DOW : 2023-05-07 - 2023-05-14\n",
861
+ "DOW : 2023-05-14 - 2023-05-21\n",
862
+ "DOW : 2023-05-21 - 2023-05-28\n",
863
+ "DOW : 2023-05-28 - 2023-06-04\n"
864
+ ]
+ }
+ ],
+ "source": [
+ "for symbol in DOW_30:\n",
+ "    prepare_data_for_company(symbol)\n",
+ "# prepare_data_for_company(symbol, False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "af655d8b",
+ "metadata": {},
+ "source": [
+ "# Generate Prompt from Financial Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "id": "5a53c0ae",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "def get_company_prompt(symbol):\n",
+ "    \n",
+ "    profile = finnhub_client.company_profile2(symbol=symbol)\n",
+ "\n",
+ "    company_template = \"[Company Introduction]:\\n\\n{name} is a leading entity in the {finnhubIndustry} sector. Incorporated and publicly traded since {ipo}, the company has established its reputation as one of the key players in the market. As of today, {name} has a market capitalization of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding.\" \\\n",
+ "        \"\\n\\n{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive progress within the industry.\"\n",
+ "\n",
+ "    formatted_str = company_template.format(**profile)\n",
+ "    \n",
+ "    return formatted_str\n",
+ "\n",
+ "\n",
+ "def get_prompt_by_row(symbol, row):\n",
+ "\n",
+ "    start_date = row['Start Date'] if isinstance(row['Start Date'], str) else row['Start Date'].strftime('%Y-%m-%d')\n",
+ "    end_date = row['End Date'] if isinstance(row['End Date'], str) else row['End Date'].strftime('%Y-%m-%d')\n",
+ "    term = 'increased' if row['End Price'] > row['Start Price'] else 'decreased'\n",
+ "    head = \"From {} to {}, {}'s stock price {} from {:.2f} to {:.2f}. Company news during this period are listed below:\\n\\n\".format(\n",
+ "        start_date, end_date, symbol, term, row['Start Price'], row['End Price'])\n",
+ "    \n",
+ "    news = json.loads(row[\"News\"])\n",
+ "    news = [\"[Headline]: {}\\n[Summary]: {}\\n\".format(\n",
+ "        n['headline'], n['summary']) for n in news if n['date'][:8] <= end_date.replace('-', '') and \\\n",
+ "        not n['summary'].startswith(\"Looking for stock market analysis and research with proves results?\")]\n",
+ "\n",
+ "    basics = json.loads(row['Basics'])\n",
+ "    if basics:\n",
+ "        basics = \"Some recent basic financials of {}, reported at {}, are presented below:\\n\\n[Basic Financials]:\\n\\n\".format(\n",
+ "            symbol, basics['period']) + \"\\n\".join(f\"{k}: {v}\" for k, v in basics.items() if k != 'period')\n",
+ "    else:\n",
+ "        basics = \"[Basic Financials]:\\n\\nNo basic financial reported.\"\n",
+ "    \n",
+ "    return head, news, basics\n",
+ "\n",
+ "\n",
+ "def sample_news(news, k=5):\n",
+ "    \n",
+ "    return [news[i] for i in sorted(random.sample(range(len(news)), k))]\n",
+ "\n",
+ "\n",
+ "def map_bin_label(bin_lb):\n",
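+ "    # e.g. 'U3' -> 'up by 2-3%', 'D5+' -> 'down by more than 5%'\n",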
+ " \n",
932
+ " lb = bin_lb.replace('U', 'up by ')\n",
933
+ " lb = lb.replace('D', 'down by ')\n",
934
+ " lb = lb.replace('1', '0-1%')\n",
935
+ " lb = lb.replace('2', '1-2%')\n",
936
+ " lb = lb.replace('3', '2-3%')\n",
937
+ " lb = lb.replace('4', '3-4%')\n",
938
+ " if lb.endswith('+'):\n",
939
+ " lb = lb.replace('5+', 'more than 5%')\n",
940
+ "# lb = lb.replace('5+', '5+%')\n",
941
+ " else:\n",
942
+ " lb = lb.replace('5', '4-5%')\n",
943
+ " \n",
944
+ " return lb\n",
945
+ "\n",
946
+ "\n",
947
+ "def get_all_prompts(symbol, min_past_weeks=1, max_past_weeks=3, with_basics=True):\n",
948
+ "\n",
949
+ " \n",
950
+ " if with_basics:\n",
951
+ " df = pd.read_csv(f'{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}.csv')\n",
952
+ " else:\n",
953
+ " df = pd.read_csv(f'{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}_nobasics.csv')\n",
954
+ " \n",
955
+ " company_prompt = get_company_prompt(symbol)\n",
956
+ "\n",
957
+ " prev_rows = []\n",
958
+ " all_prompts = []\n",
959
+ "\n",
960
+ " for row_idx, row in df.iterrows():\n",
961
+ "\n",
962
+ " prompt = \"\"\n",
963
+ " if len(prev_rows) >= min_past_weeks:\n",
964
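+ "            # look back a random number of weeks in [min_past_weeks, max_past_weeks], capped by available history\n",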
+ " idx = min(random.choice(range(min_past_weeks, max_past_weeks+1)), len(prev_rows))\n",
965
+ " for i in range(-idx, 0):\n",
966
+ " # Add Price Movement (Head)\n",
967
+ " prompt += \"\\n\" + prev_rows[i][0]\n",
968
+ " # Add News of previous weeks\n",
969
+ " sampled_news = sample_news(\n",
970
+ " prev_rows[i][1],\n",
971
+ " min(5, len(prev_rows[i][1]))\n",
972
+ " )\n",
973
+ " if sampled_news:\n",
974
+ " prompt += \"\\n\".join(sampled_news)\n",
975
+ " else:\n",
976
+ " prompt += \"No relative news reported.\"\n",
977
+ "\n",
978
+ " head, news, basics = get_prompt_by_row(symbol, row)\n",
979
+ "\n",
980
+ " prev_rows.append((head, news, basics))\n",
981
+ " if len(prev_rows) > max_past_weeks:\n",
982
+ " prev_rows.pop(0) \n",
983
+ "\n",
984
+ " if not prompt:\n",
985
+ " continue\n",
986
+ "\n",
987
+ " prediction = map_bin_label(row['Bin Label'])\n",
988
+ " \n",
989
+ " prompt = company_prompt + '\\n' + prompt + '\\n' + basics\n",
990
+ " prompt += f\"\\n\\nBased on all the information before {row['Start Date']}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company related news. \" \\\n",
991
+ " f\"Then let's assume your prediction for next week ({row['Start Date']} to {row['End Date']}) is {prediction}. Provide a summary analysis to support your prediction. The prediction result need to be inferred from your analysis at the end, and thus not appearing as a foundational factor of your analysis.\"\n",
992
+ "\n",
993
+ " all_prompts.append(prompt.strip())\n",
994
+ " \n",
995
+ " return all_prompts"
996
+ ]
997
+ },
998
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "92208b72",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n",
+ "B_SYS, E_SYS = \"<<SYS>>\\n\", \"\\n<</SYS>>\\n\\n\"\n",
+ "\n",
+ "\n",
+ "SYSTEM_PROMPT = \"You are a seasoned stock market analyst. Your task is to list the positive developments and potential concerns for companies based on relevant news and basic financials from the past weeks, then provide an analysis and prediction for the companies' stock price movement for the upcoming week. \" \\\n",
+ "    \"Your answer format should be as follows:\\n\\n[Positive Developments]:\\n1. ...\\n\\n[Potential Concerns]:\\n1. ...\\n\\n[Prediction & Analysis]:\\n...\\n\"\n",
+ "\n",
+ "print(SYSTEM_PROMPT)\n",
+ "\n",
+ "# prompts = get_all_prompts(\"AAPL\", 1, 3)\n",
+ "# prompts = get_all_prompts(\"MSFT\", 1, 3, False)\n",
+ "prompts = get_all_prompts(\"TRV\", 1, 4)\n",
+ "\n",
+ "print(prompts[0])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2b010a45",
+ "metadata": {},
+ "source": [
+ "# Request to GPT-4 for Financial Analysis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "id": "3e355117",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def append_to_csv(filename, input_data, output_data):\n",
+ "    \n",
+ "    with open(filename, mode='a', newline='') as file:\n",
+ "        writer = csv.writer(file)\n",
+ "        writer.writerow([input_data, output_data])\n",
+ "\n",
+ "    \n",
+ "def initialize_csv(filename):\n",
+ "    \n",
+ "    with open(filename, mode='w', newline='') as file:\n",
+ "        writer = csv.writer(file)\n",
+ "        writer.writerow([\"prompt\", \"answer\"])\n",
+ "\n",
+ "\n",
+ "def query_gpt4(symbol_list, min_past_weeks=1, max_past_weeks=3, with_basics=True):\n",
+ "\n",
+ "    for symbol in symbol_list:\n",
+ "        \n",
+ "        csv_file = f'{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}_gpt-4.csv' if with_basics else \\\n",
+ "            f'{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}_nobasics_gpt-4.csv'\n",
+ "        \n",
+ "        if not os.path.exists(csv_file):\n",
+ "            initialize_csv(csv_file)\n",
+ "            pre_done = 0\n",
+ "        else:\n",
+ "            df = pd.read_csv(csv_file)\n",
+ "            pre_done = len(df)\n",
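+ "            # pre_done marks how many prompts were already answered, so reruns resume instead of repeating\n",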
+ "\n",
1064
+ " prompts = get_all_prompts(symbol, min_past_weeks, max_past_weeks, with_basics)\n",
1065
+ "\n",
1066
+ " for i, prompt in enumerate(prompts):\n",
1067
+ " \n",
1068
+ " if i < pre_done:\n",
1069
+ " continue\n",
1070
+ "\n",
1071
+ " print(f\"{symbol} - {i}\")\n",
1072
+ " \n",
1073
+ " cnt = 0\n",
1074
+ " while cnt < 5:\n",
1075
+ " try:\n",
1076
+ " completion = client.chat.completions.create(\n",
1077
+ " model=\"gpt-4\",\n",
1078
+ " messages=[\n",
1079
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
1080
+ " {\"role\": \"user\", \"content\": prompt}\n",
1081
+ " ]\n",
1082
+ " )\n",
1083
+ " break \n",
1084
+ " except Exception:\n",
1085
+ " cnt += 1\n",
1086
+ " print(f'retry cnt {cnt}')\n",
1087
+ " \n",
1088
+ " answer = completion.choices[0].message.content if cnt < 5 else \"\"\n",
1089
+ " append_to_csv(csv_file, prompt, answer)\n",
1090
+ " "
1091
+ ]
1092
+ },
1093
+ {
1094
+ "cell_type": "code",
1095
+ "execution_count": 121,
1096
+ "id": "a9ff6ff3",
1097
+ "metadata": {
1098
+ "scrolled": true
1099
+ },
1100
+ "outputs": [
1101
+ {
1102
+ "name": "stdout",
1103
+ "output_type": "stream",
1104
+ "text": [
1105
+ "WBA - 12\n",
1106
+ "WBA - 13\n",
1107
+ "WBA - 14\n",
1108
+ "WBA - 15\n",
1109
+ "WBA - 16\n",
1110
+ "WBA - 17\n",
1111
+ "WBA - 18\n",
1112
+ "WBA - 19\n"
1113
+ ]
1114
+ }
1115
+ ],
1116
+ "source": [
1117
+ "# query_gpt4(DOW_30, 1, 3)\n",
1118
+ "query_gpt4(DOW_30, 1, 4)\n",
1119
+ "# query_gpt4(['WBA'], 1, 4)"
1120
+ ]
1121
+ },
1122
+ {
+ "cell_type": "markdown",
+ "id": "238ba9f0",
+ "metadata": {},
+ "source": [
+ "# Transform into Llama2 Training Format"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "id": "d2627f5a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def gpt4_to_llama(symbol, with_basics=True):\n",
+ "    \n",
+ "    csv_file = f'{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}_gpt-4.csv' if with_basics else \\\n",
+ "        f'{DATA_DIR}/{symbol}_{START_DATE}_{END_DATE}_nobasics_gpt-4.csv'\n",
+ "    \n",
+ "    df = pd.read_csv(csv_file)\n",
+ "    \n",
+ "    prompts, answers, periods, labels = [], [], [], []\n",
+ "    \n",
+ "    for i, row in df.iterrows():\n",
+ "        \n",
+ "        prompt, answer = row['prompt'], row['answer']\n",
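+ "        # pull the assumed period and label out of the prompt, then rewrite prompt and answer into llama-2 chat format\n",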
+ " \n",
1150
+ " res = re.search(r\"Then let's assume your prediction for next week \\((.*)\\) is ((:?up|down) by .*%).\", prompt)\n",
1151
+ " \n",
1152
+ " period, label = res.group(1), res.group(2)\n",
1153
+ "# label = label.replace('more than 5', '5+')\n",
1154
+ " \n",
1155
+ " prompt = re.sub(\n",
1156
+ " r\"Then let's assume your prediction for next week \\((.*)\\) is (up|down) by ((:?.*)%). Provide a summary analysis to support your prediction. The prediction result need to be inferred from your analysis at the end, and thus not appearing as a foundational factor of your analysis.\", \n",
1157
+ " f\"Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction.\",\n",
1158
+ " prompt\n",
1159
+ " )\n",
1160
+ " try:\n",
1161
+ " answer = re.sub(\n",
1162
+ " r\"\\[Prediction & Analysis\\]:\\s*\",\n",
1163
+ " f\"[Prediction & Analysis]:\\nPrediction: {label.capitalize()}\\nAnalysis: \",\n",
1164
+ " answer\n",
1165
+ " )\n",
1166
+ " except Exception:\n",
1167
+ " print(symbol, i)\n",
1168
+ " print(label)\n",
1169
+ " print(answer)\n",
1170
+ " continue\n",
1171
+ " \n",
1172
+ " new_system_prompt = SYSTEM_PROMPT.replace(':\\n...', '\\nPrediction: ...\\nAnalysis: ...')\n",
1173
+ "# new_system_prompt = SYSTEM_PROMPT.replace(':\\n...', '\\nPrediction: {Up|Down} by {1-2|2-3|3-4|4-5|5+}%\\nAnalysis: ...')\n",
1174
+ " \n",
1175
+ " prompt = B_INST + B_SYS + new_system_prompt + E_SYS + prompt + E_INST\n",
1176
+ " \n",
1177
+ " prompts.append(prompt)\n",
1178
+ " answers.append(answer)\n",
1179
+ " periods.append(period)\n",
1180
+ " labels.append(label)\n",
1181
+ " \n",
1182
+ " return {\n",
1183
+ " \"prompt\": prompts,\n",
1184
+ " \"answer\": answers,\n",
1185
+ " \"period\": periods,\n",
1186
+ " \"label\": labels,\n",
1187
+ " }\n",
1188
+ "\n",
1189
+ "\n",
1190
+ "def create_dataset(symbol_list, train_ratio=0.8, with_basics=True):\n",
1191
+ "\n",
1192
+ " train_dataset_list = []\n",
1193
+ " test_dataset_list = []\n",
1194
+ "\n",
1195
+ " for symbol in symbol_list:\n",
1196
+ "\n",
1197
+ " data_dict = gpt4_to_llama(symbol, with_basics)\n",
1198
+ "# print(data_dict['prompt'][-1])\n",
1199
+ "# print(data_dict['answer'][-1])\n",
1200
+ " symbols = [symbol] * len(data_dict['label'])\n",
1201
+ " data_dict.update({\"symbol\": symbols})\n",
1202
+ "\n",
1203
+ " dataset = Dataset.from_dict(data_dict)\n",
1204
+ " train_size = round(train_ratio * len(dataset))\n",
1205
+ "\n",
1206
+ " train_dataset_list.append(dataset.select(range(train_size)))\n",
1207
+ " test_dataset_list.append(dataset.select(range(train_size, len(dataset))))\n",
1208
+ "\n",
1209
+ " train_dataset = datasets.concatenate_datasets(train_dataset_list)\n",
1210
+ " test_dataset = datasets.concatenate_datasets(test_dataset_list)\n",
1211
+ "\n",
1212
+ " dataset = datasets.DatasetDict({\n",
1213
+ " 'train': train_dataset,\n",
1214
+ " 'test': test_dataset\n",
1215
+ " })\n",
1216
+ " \n",
1217
+ " return dataset\n",
1218
+ " "
1219
+ ]
1220
+ },
1221
+ {
+ "cell_type": "code",
+ "execution_count": 129,
+ "id": "e089b1bf",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# v1\n",
+ "# dow30_dataset = create_dataset(DOW30, True)\n",
+ "# v2\n",
+ "# dow30_nobasic_dataset = create_dataset(DOW_30, 0.8, False)\n",
+ "# v3\n",
+ "dow30_v3_dataset = create_dataset(DOW_30, 0.9)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 130,
+ "id": "123f2db9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "439535ce3e804a3d847f1e03df02283d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/1 shards): 0%| | 0/540 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "373274c4f9b547fabe0027bc696912c3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/1 shards): 0%| | 0/60 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# dow30_dataset.save_to_disk('fingpt-forecaster-dow30-20230601-20230930-llama')\n",
+ "# dow30_nobasics_dataset.save_to_disk('fingpt-forecaster-dow30nobasics-20230601-20230930-llama')\n",
+ "dow30_v3_dataset.save_to_disk('fingpt-forecaster-dow30v3-20221231-20230531-llama')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 131,
+ "id": "9ed5cf5f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ "    train: Dataset({\n",
+ "        features: ['prompt', 'answer', 'period', 'label', 'symbol'],\n",
+ "        num_rows: 540\n",
+ "    })\n",
+ "    test: Dataset({\n",
+ "        features: ['prompt', 'answer', 'period', 'label', 'symbol'],\n",
+ "        num_rows: 60\n",
+ "    })\n",
+ "})"
+ ]
+ },
+ "execution_count": 131,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dow30_v3_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d62e711f",
+ "metadata": {},
+ "source": [
+ "# Test-time Information Fetching"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "id": "292268bc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import yfinance as yf\n",
+ "import pandas as pd\n",
+ "from datetime import date, datetime, timedelta\n",
+ "\n",
+ "\n",
+ "def get_curday():\n",
+ "    \n",
+ "    return date.today().strftime(\"%Y-%m-%d\")\n",
+ "\n",
+ "\n",
+ "def n_weeks_before(date_string, n):\n",
+ "    \n",
+ "    date = datetime.strptime(date_string, \"%Y-%m-%d\") - timedelta(days=7*n)\n",
+ "    \n",
+ "    return date.strftime(\"%Y-%m-%d\")\n",
+ "\n",
+ "\n",
+ "def get_stock_data(stock_symbol, steps):\n",
+ "\n",
+ "    stock_data = yf.download(stock_symbol, steps[0], steps[-1])\n",
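+ "    # for each step date, use the close of the first trading day on or after it\n",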
+ " \n",
1345
+ "# print(stock_data)\n",
1346
+ " \n",
1347
+ " dates, prices = [], []\n",
1348
+ " available_dates = stock_data.index.format()\n",
1349
+ " \n",
1350
+ " for date in steps[:-1]:\n",
1351
+ " for i in range(len(stock_data)):\n",
1352
+ " if available_dates[i] >= date:\n",
1353
+ " prices.append(stock_data['Close'][i])\n",
1354
+ " dates.append(datetime.strptime(available_dates[i], \"%Y-%m-%d\"))\n",
1355
+ " break\n",
1356
+ "\n",
1357
+ " dates.append(datetime.strptime(available_dates[-1], \"%Y-%m-%d\"))\n",
1358
+ " prices.append(stock_data['Close'][-1])\n",
1359
+ " \n",
1360
+ " return pd.DataFrame({\n",
1361
+ " \"Start Date\": dates[:-1], \"End Date\": dates[1:],\n",
1362
+ " \"Start Price\": prices[:-1], \"End Price\": prices[1:]\n",
1363
+ " })\n",
1364
+ "\n",
1365
+ "\n",
1366
+ "def get_current_basics(symbol, curday):\n",
1367
+ "\n",
1368
+ " basic_financials = finnhub_client.company_basic_financials(symbol, 'all')\n",
1369
+ " \n",
1370
+ " final_basics, basic_list, basic_dict = [], [], defaultdict(dict)\n",
1371
+ " \n",
1372
+ " for metric, value_list in basic_financials['series']['quarterly'].items():\n",
1373
+ " for value in value_list:\n",
1374
+ " basic_dict[value['period']].update({metric: value['v']})\n",
1375
+ "\n",
1376
+ " for k, v in basic_dict.items():\n",
1377
+ " v.update({'period': k})\n",
1378
+ " basic_list.append(v)\n",
1379
+ " \n",
1380
+ " basic_list.sort(key=lambda x: x['period'])\n",
1381
+ " \n",
1382
+ " for basic in basic_list[::-1]:\n",
1383
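+ "        # stop at the most recent report dated on or before curday\n",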
+ " if basic['period'] <= curday:\n",
1384
+ " break\n",
1385
+ " \n",
1386
+ " return basic\n",
1387
+ " \n",
1388
+ "\n",
1389
+ "def get_all_prompts_online(symbol, data, curday, with_basics=True):\n",
1390
+ "\n",
1391
+ " company_prompt = get_company_prompt(symbol)\n",
1392
+ "\n",
1393
+ " prev_rows = []\n",
1394
+ "\n",
1395
+ " for row_idx, row in data.iterrows():\n",
1396
+ " head, news, _ = get_prompt_by_row(symbol, row)\n",
1397
+ " prev_rows.append((head, news, None))\n",
1398
+ " \n",
1399
+ " prompt = \"\"\n",
1400
+ " for i in range(-len(prev_rows), 0):\n",
1401
+ " prompt += \"\\n\" + prev_rows[i][0]\n",
1402
+ " sampled_news = sample_news(\n",
1403
+ " prev_rows[i][1],\n",
1404
+ " min(5, len(prev_rows[i][1]))\n",
1405
+ " )\n",
1406
+ " if sampled_news:\n",
1407
+ " prompt += \"\\n\".join(sampled_news)\n",
1408
+ " else:\n",
1409
+ " prompt += \"No relative news reported.\"\n",
1410
+ " \n",
1411
+ " period = \"{} to {}\".format(curday, n_weeks_before(curday, -1))\n",
1412
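+ "    # n=-1 moves one week forward, so the period covers the coming week\n",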
+ " \n",
1413
+ " if with_basics:\n",
1414
+ " basics = get_current_basics(symbol, curday)\n",
1415
+ " basics = \"Some recent basic financials of {}, reported at {}, are presented below:\\n\\n[Basic Financials]:\\n\\n\".format(\n",
1416
+ " symbol, basics['period']) + \"\\n\".join(f\"{k}: {v}\" for k, v in basics.items() if k != 'period')\n",
1417
+ " else:\n",
1418
+ " basics = \"[Basic Financials]:\\n\\nNo basic financial reported.\"\n",
1419
+ "\n",
1420
+ " info = company_prompt + '\\n' + prompt + '\\n' + basics\n",
1421
+ " prompt = info + f\"\\n\\nBased on all the information before {curday}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company related news. \" \\\n",
1422
+ " f\"Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction.\"\n",
1423
+ " \n",
1424
+ " return info, prompt"
1425
+ ]
1426
+ },
1427
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "id": "8f48aab1",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[*********************100%%**********************] 1 of 1 completed\n",
+ "AAPL : 2023-10-25 - 2023-11-01\n",
+ "AAPL : 2023-11-01 - 2023-11-07\n"
+ ]
+ }
+ ],
+ "source": [
+ "ticker = \"AAPL\"\n",
+ "n_weeks = 2\n",
+ "curday = get_curday()\n",
+ "steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]\n",
+ "\n",
+ "data = get_stock_data(ticker, steps)\n",
+ "\n",
+ "data = get_news(ticker, data)\n",
+ "\n",
+ "data['Basics'] = [json.dumps({})] * len(data)\n",
+ "# data = get_basics(ticker, data, always=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "id": "84bb302a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Company Introduction]:\n",
+ "\n",
+ "Apple Inc is a leading entity in the Technology sector. Incorporated and publicly traded since 1980-12-12, the company has established its reputation as one of the key players in the market. As of today, Apple Inc has a market capitalization of 2809837.86 in USD, with 15634.23 shares outstanding.\n",
+ "\n",
+ "Apple Inc operates primarily in the US, trading under the ticker AAPL on the NASDAQ NMS - GLOBAL MARKET. As a dominant force in the Technology space, the company continues to innovate and drive progress within the industry.\n",
+ "\n",
+ "From 2023-10-25 to 2023-11-01, AAPL's stock price increased from 171.10 to 173.97. Company news during this period are listed below:\n",
+ "\n",
+ "[Headline]: 25 Largest Economies in the World by 2075\n",
+ "[Summary]: In this article, we will be taking a look at the 25 largest economies in the world by 2075. To skip our detailed analysis, you can go directly to see the 5 largest economies in the world by 2075. In both 2022 and 2023, the global economy has struggled significantly after record inflation enveloped most countries across […]\n",
+ "\n",
+ "[Headline]: India opposition accuses govt of trying to hack lawmakers' iPhones\n",
+ "[Summary]: Indian opposition leader Rahul Gandhi on Tuesday accused Prime Minister Narendra Modi's government of trying to hack into senior opposition politicians' mobile phones, after they reported receiving warning messages from Apple. Some of the lawmakers shared screenshots on social media of a notification quoting the iPhone manufacturer as saying: \"Apple believes you are being targeted by state-sponsored attackers who are trying to remotely compromise the iPhone associated with your Apple ID\". \"Hack us all you want,\" Gandhi told a news conference in New Delhi, in reference to Modi.\n",
+ "\n",
+ "[Headline]: 39% Of This Apple Insider's Holdings Were Sold\n",
+ "[Summary]: Looking at Apple Inc.'s ( NASDAQ:AAPL ) insider transactions over the last year, we can see that insiders were net...\n",
+ "\n",
+ "[Headline]: Indian opposition MPs accuse government of trying to hack their iPhones\n",
+ "[Summary]: Ruling BJP rejects claims of involvement following Apple notifications of possible ‘state-sponsored’ attacks\n",
+ "\n",
+ "[Headline]: Should You Buy These 2 ‘Magnificent Seven’ Stocks Ahead of Earnings? Apple and Nvidia in Focus\n",
+ "[Summary]: What should investors make of this year’s third-quarter earnings? The Q3 results have been pretty good, with 78% of companies reporting so far beating the forecasts, but stocks are still feeling pressure. One obvious sign of that pressure: the S&P 500 this week hit its lowest point since last May, and is just shy of correction territory. The effect is most clearly seen in the ‘Magnificent Seven,’ a group of Big Tech giants whose gains earlier in the year carried the markets generally – but which\n",
+ "\n",
+ "From 2023-11-01 to 2023-11-07, AAPL's stock price increased from 173.97 to 181.25. Company news during this period are listed below:\n",
+ "\n",
+ "[Headline]: Apple Earnings: Why Guidance Will Be Key\n",
+ "[Summary]: Tech giant Apple (NASDAQ: AAPL) is scheduled to report its fiscal fourth-quarter results on Thursday. After all, the company's approximately $2.7 trillion market cap is big enough to influence major market indexes like the S&P 500; Apple represents about 7% of the index. While the company's fiscal fourth-quarter financial performance will definitely be important, investors may pay even closer attention to another metric: management's guidance for its fiscal first-quarter revenue.\n",
+ "\n",
+ "[Headline]: Analysts offer hot takes on Q4 2023 Apple results\n",
+ "[Summary]: Analysts have weighed in on Apple's Q4 2023 financial results, with most taking the view that the quarter is decent-performing, but with caution about a shorter Q1 2024.\n",
+ "\n",
+ "[Headline]: How to run new macOS versions on older Macs with OpenCore\n",
+ "[Summary]: Apple removes support for old Mac hardware in new macOS releases. Here's how to run modern macOS on older Macs using OpenCore.\n",
+ "\n",
+ "[Headline]: Apple Watch import ban: what you need to know\n",
+ "[Summary]: There is a possibility of an import ban in the U.S. on the Apple Watch. Here's what you need to know before it potentially goes into effect on Christmas Day, 2023.\n",
+ "\n",
+ "[Headline]: ChatGPT: Everything you need to know about the AI-powered chatbot\n",
+ "[Summary]: ChatGPT, OpenAI’s text-generating AI chatbot, has taken the world by storm. What started as a tool to hyper-charge productivity through writing essays and code with short text prompts has evolved into a behemoth used by more than 92% of Fortune 500 companies for more wide-ranging needs. While there is a more…nefarious side to ChatGPT, it’s clear that AI tools are not going away anytime soon. Since its initial launch nearly a year ago, ChatGPT has hit 100 million weekly active users, and OpenAI i\n",
+ "\n",
+ "[Basic Financials]:\n",
+ "\n",
+ "No basic financial reported.\n",
+ "\n",
+ "Based on all the information before 2023-11-08, let's first analyze the positive developments and potential concerns for AAPL. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company related news. Then make your prediction of the AAPL stock price movement for next week (2023-11-08 to 2023-11-15). Provide a summary analysis to support your prediction.\n"
+ ]
+ }
+ ],
+ "source": [
+ "info, prompt = get_all_prompts_online(ticker, data, curday, False)\n",
+ "\n",
+ "print(prompt)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.0.1
+ transformers==4.32.0
+ peft==0.5.0
+ pandas
+ yfinance
+ finnhub-python
+ nvidia-ml-py3
train.sh ADDED
@@ -0,0 +1,21 @@
+ export NCCL_IGNORE_DISABLED_P2P=1
+ export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+ export TOKENIZERS_PARALLELISM=0
+
+
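+ # run LoRA fine-tuning with DeepSpeed on GPUs 2 and 3; a '*N' suffix on a dataset name repeats (oversamples) it N times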
+ deepspeed \
+ --include localhost:2,3 \
+ train_lora.py \
+ --run_name dow30v3-llama2-5e-5lr-qkvogud \
+ --base_model llama2 \
+ --dataset dow30-20230601-20230930-llama,dow30nobasics-20230601-20230930-llama,dow30v3-20221231-20230531-llama*2 \
+ --max_length 4096 \
+ --batch_size 1 \
+ --gradient_accumulation_steps 16 \
+ --learning_rate 5e-5 \
+ --num_epochs 5 \
+ --log_interval 10 \
+ --warmup_ratio 0.03 \
+ --scheduler constant \
+ --evaluation_strategy steps \
+ --ds_config config.json
train_lora.py ADDED
@@ -0,0 +1,221 @@
+ from transformers.integrations import TensorBoardCallback
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+ from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
+ from transformers import TrainerCallback, TrainerState, TrainerControl
+ from transformers.trainer import TRAINING_ARGS_NAME
+ from torch.utils.tensorboard import SummaryWriter
+ import datasets
+ import torch
+ import os
+ import re
+ import sys
+ import wandb
+ import argparse
+ from datetime import datetime
+ from functools import partial
+ from tqdm import tqdm
+ from utils import *
+
+ # LoRA
+ from peft import (
+     TaskType,
+     LoraConfig,
+     get_peft_model,
+     get_peft_model_state_dict,
+     prepare_model_for_int8_training,
+     set_peft_model_state_dict,
+ )
+
+ # Replace with your own api_key and project name
+ os.environ['WANDB_API_KEY'] = 'your-wandb-api-key'
+ os.environ['WANDB_PROJECT'] = 'fingpt-forecaster'
+
+
+ class GenerationEvalCallback(TrainerCallback):
+
+     def __init__(self, eval_dataset, ignore_until_epoch=0):
+         self.eval_dataset = eval_dataset
+         self.ignore_until_epoch = ignore_until_epoch
+
+     def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
+
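+         # generation-based evaluation is slow, so skip it for the earliest epochs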
+         if state.epoch is None or state.epoch + 1 < self.ignore_until_epoch:
+             return
+
+         if state.is_local_process_zero:
+             model = kwargs['model']
+             tokenizer = kwargs['tokenizer']
+             generated_texts, reference_texts = [], []
+
+             for feature in tqdm(self.eval_dataset):
+                 prompt = feature['prompt']
+                 gt = feature['answer']
+                 inputs = tokenizer(
+                     prompt, return_tensors='pt',
+                     padding=False, max_length=4096
+                 )
+                 inputs = {key: value.to(model.device) for key, value in inputs.items()}
+
+                 res = model.generate(
+                     **inputs,
+                     use_cache=True
+                 )
+                 output = tokenizer.decode(res[0], skip_special_tokens=True)
+                 answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
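+                 # keep only the completion that follows the [/INST] tag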
+
+                 generated_texts.append(answer)
+                 reference_texts.append(gt)
+
+                 # print("GENERATED: ", answer)
+                 # print("REFERENCE: ", gt)
+
+             metrics = calc_metrics(reference_texts, generated_texts)
+
+             # Ensure wandb is initialized
+             if wandb.run is None:
+                 wandb.init()
+
+             wandb.log(metrics, step=state.global_step)
+             torch.cuda.empty_cache()
+
+
+ def main(args):
+
+     model_name = parse_model_name(args.base_model, args.from_remote)
+
+     # load model
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         # load_in_8bit=True,
+         trust_remote_code=True
+     )
+     if args.local_rank == 0:
+         print(model)
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+
+     # load data
+     dataset_list = load_dataset(args.dataset, args.from_remote)
+
+     dataset_train = datasets.concatenate_datasets([d['train'] for d in dataset_list]).shuffle(seed=42)
+
+     if args.test_dataset:
+         dataset_list = load_dataset(args.test_dataset, args.from_remote)
+
+     dataset_test = datasets.concatenate_datasets([d['test'] for d in dataset_list])
+
+     original_dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})
+
+     eval_dataset = original_dataset['test'].shuffle(seed=42).select(range(50))
+
+     dataset = original_dataset.map(partial(tokenize, args, tokenizer))
+     print('original dataset length: ', len(dataset['train']))
+     dataset = dataset.filter(lambda x: not x['exceed_max_length'])
+     print('filtered dataset length: ', len(dataset['train']))
+     dataset = dataset.remove_columns(
+         ['prompt', 'answer', 'label', 'symbol', 'period', 'exceed_max_length']
+     )
+
+     current_time = datetime.now()
+     formatted_time = current_time.strftime('%Y%m%d%H%M')
+
+     training_args = TrainingArguments(
+         output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',  # where checkpoints are saved
+         logging_steps=args.log_interval,
+         num_train_epochs=args.num_epochs,
+         per_device_train_batch_size=args.batch_size,
+         per_device_eval_batch_size=args.batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         dataloader_num_workers=args.num_workers,
+         learning_rate=args.learning_rate,
+         weight_decay=args.weight_decay,
+         warmup_ratio=args.warmup_ratio,
+         lr_scheduler_type=args.scheduler,
+         save_steps=args.eval_steps,
+         eval_steps=args.eval_steps,
+         fp16=True,
+         deepspeed=args.ds_config,
+         evaluation_strategy=args.evaluation_strategy,
+         remove_unused_columns=False,
+         report_to='wandb',
+         run_name=args.run_name
+     )
+
+     model.gradient_checkpointing_enable()
+     model.enable_input_require_grads()
+     model.is_parallelizable = True
+     model.model_parallel = True
+     model.model.config.use_cache = False
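+     # use_cache is incompatible with gradient checkpointing, so it stays off during training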
+
+     # model = prepare_model_for_int8_training(model)
+
+     # setup peft
+     peft_config = LoraConfig(
+         task_type=TaskType.CAUSAL_LM,
+         inference_mode=False,
+         r=8,
+         lora_alpha=16,
+         lora_dropout=0.1,
+         target_modules=lora_module_dict[args.base_model],
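+         # per-model LoRA target modules are defined in lora_module_dict in utils.py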
+         bias='none',
+     )
+     model = get_peft_model(model, peft_config)
+
+     # Train
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset['train'],
+         eval_dataset=dataset['test'],
+         tokenizer=tokenizer,
+         data_collator=DataCollatorForSeq2Seq(
+             tokenizer, padding=True,
+             return_tensors="pt"
+         ),
+         callbacks=[
+             GenerationEvalCallback(
+                 eval_dataset=eval_dataset,
+                 ignore_until_epoch=round(0.3 * args.num_epochs)
+             )
+         ]
+     )
+
+     if torch.__version__ >= "2" and sys.platform != "win32":
+         model = torch.compile(model)
+
+     torch.cuda.empty_cache()
+     trainer.train()
+
+     # save model
+     model.save_pretrained(training_args.output_dir)
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--local_rank", default=0, type=int)
+     parser.add_argument("--run_name", default='local-test', type=str)
+     parser.add_argument("--dataset", required=True, type=str)
+     parser.add_argument("--test_dataset", type=str)
+     parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2'])
+     parser.add_argument("--max_length", default=512, type=int)
+     parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
+     parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
+     parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
+     parser.add_argument("--num_epochs", default=8, type=float, help="The training epochs")
+     parser.add_argument("--num_workers", default=8, type=int, help="dataloader workers")
+     parser.add_argument("--log_interval", default=20, type=int)
+     parser.add_argument("--gradient_accumulation_steps", default=8, type=int)
+     parser.add_argument("--warmup_ratio", default=0.05, type=float)
+     parser.add_argument("--ds_config", default='./config_new.json', type=str)
+     parser.add_argument("--scheduler", default='linear', type=str)
+     parser.add_argument("--instruct_template", default='default')
+     parser.add_argument("--evaluation_strategy", default='steps', type=str)
+     parser.add_argument("--eval_steps", default=0.1, type=float)
+     parser.add_argument("--from_remote", default=False, type=lambda x: str(x).lower() == 'true')  # plain type=bool would treat any non-empty string, including 'False', as True
+     args = parser.parse_args()
+
+     wandb.login()
+     main(args)
utils.py ADDED
@@ -0,0 +1,162 @@
+ import re
+ import os
+ import datasets
+ from sklearn.metrics import accuracy_score, mean_squared_error
+ from collections import defaultdict
+ from rouge_score import rouge_scorer
+
+
+ lora_module_dict = {
+     'chatglm2': ['query_key_value'],
+     'llama2': [
+         'q_proj', 'k_proj', 'v_proj',
+         'o_proj', 'gate_proj', 'up_proj', 'down_proj',
+         # 'embed_tokens', 'lm_head',
+     ],
+ }
+
+
+ def tokenize(args, tokenizer, feature):
+
+     prompt_ids = tokenizer.encode(
+         feature['prompt'].strip(), padding=False,
+         max_length=args.max_length, truncation=True
+     )
+
+     target_ids = tokenizer.encode(
+         feature['answer'].strip(), padding=False,
+         max_length=args.max_length, truncation=True, add_special_tokens=False
+     )
+
+     input_ids = prompt_ids + target_ids
+     exceed_max_length = len(input_ids) >= args.max_length
+
+     # Add EOS Token
+     if input_ids[-1] != tokenizer.eos_token_id and not exceed_max_length:
+         input_ids.append(tokenizer.eos_token_id)
+
+     label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]
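+     # prompt positions are labelled with pad_token_id; only the answer span (and EOS) keeps real token ids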
+
+     return {
+         "input_ids": input_ids,
+         "labels": label_ids,
+         "exceed_max_length": exceed_max_length
+     }
+
+
+ def parse_model_name(name, from_remote=False):
+
+     if name == 'chatglm2':
+         return 'THUDM/chatglm2-6b' if from_remote else 'base_models/chatglm2-6b'
+     elif name == 'llama2':
+         return 'meta-llama/Llama-2-7b-chat-hf' if from_remote else 'base_models/Llama-2-7b-chat-hf'
+     else:
+         raise ValueError(f"Undefined base model {name}")
+
+
+ def load_dataset(names, from_remote=False):
+
+     dataset_names = [d for d in names.split(',')]
+     dataset_list = []
+
+     for name in dataset_names:
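+         # a trailing '*N' in a dataset name (e.g. 'name*2') oversamples that dataset N times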
+         rep = 1
+         if not os.path.exists(name):
+             rep = int(name.split('*')[1]) if '*' in name else 1
+             name = ('FinGPT/fingpt-forecaster-' if from_remote else 'data/fingpt-forecaster-') + name.split('*')[0]
+         tmp_dataset = datasets.load_dataset(name) if from_remote else datasets.load_from_disk(name)
+
+         if 'test' not in tmp_dataset:
+             tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
+         dataset_list.extend([tmp_dataset] * rep)
+
+     return dataset_list
+
+
+ def parse_answer(answer):
+
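+     # expected layout: [Positive Developments]: ... [Potential Concerns]: ... [Prediction & Analysis]: Prediction: ... Analysis: ...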
+     match_res = re.match(r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction & Analysis\]:\s*(.*)\s*$", answer, flags=re.DOTALL)
+     if not match_res:
+         return None
+
+     pros, cons, pna = match_res.group(1), match_res.group(2), match_res.group(3)
+
+     match_res = re.match(r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', pna, flags=re.DOTALL)
+     if not match_res:
+         return None
+
+     pred, anal = match_res.group(1), match_res.group(2)
+
+     if re.search(r'up|increase', pred.lower()):
+         pred_bin = 1
+     elif re.search(r'down|decrease|decline', pred.lower()):
+         pred_bin = -1
+     else:
+         pred_bin = 0
+
+     match_res = re.search(r'(\d)-(\d)%', pred)
+     if not match_res:
+         match_res = re.search(r'(?:more than )?(\d)+?%', pred)
+
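+     # signed midpoint of the predicted bin, e.g. 'up by 2-3%' -> +2.5, 'down by more than 5%' -> -5.5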
+     pred_margin = pred_bin * (int(match_res.group(1)) + 0.5) if match_res else 0.
+
+     return {
+         "positive developments": pros,
+         "potential concerns": cons,
+         "prediction": pred_margin,
+         "prediction_binary": pred_bin,
+         "analysis": anal
+     }
+
+
+ def calc_rouge_score(references, answers):
+
+     scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+
+     scores_per_pair = [scorer.score(ref, ans) for ref, ans in zip(references, answers)]
+
+     rouge1 = sum(score['rouge1'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
+     rouge2 = sum(score['rouge2'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
+     rougeL = sum(score['rougeL'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
+
+     return {'rouge1': rouge1, 'rouge2': rouge2, 'rougeL': rougeL}
+
+
+ def calc_metrics(answers, gts):
+
+     answers_dict = defaultdict(list)
+     gts_dict = defaultdict(list)
+
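+     # only pairs where both the generated answer and the reference parse successfully are scored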
+     for answer, gt in zip(answers, gts):
+         answer_dict = parse_answer(answer)
+         gt_dict = parse_answer(gt)
+
+         if answer_dict and gt_dict:
+             for k in answer_dict.keys():
+                 answers_dict[k].append(answer_dict[k])
+                 gts_dict[k].append(gt_dict[k])
+
+     if not answers_dict['prediction']:
+         return {}
+
+     bin_acc = accuracy_score(gts_dict['prediction_binary'], answers_dict['prediction_binary'])
+     mse = mean_squared_error(gts_dict['prediction'], answers_dict['prediction'])
+
+     pros_rouge_scores = calc_rouge_score(gts_dict['positive developments'], answers_dict['positive developments'])
+     cons_rouge_scores = calc_rouge_score(gts_dict['potential concerns'], answers_dict['potential concerns'])
+     anal_rouge_scores = calc_rouge_score(gts_dict['analysis'], answers_dict['analysis'])
+
+     print(f"\nBinary Accuracy: {bin_acc:.2f} | Mean Square Error: {mse:.2f}")
+     print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
+     print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
+     print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")
+
+     return {
+         "valid_count": len(answers_dict['prediction']),
+         "bin_acc": bin_acc,
+         "mse": mse,
+         "pros_rouge_scores": pros_rouge_scores,
+         "cons_rouge_scores": cons_rouge_scores,
+         "anal_rouge_scores": anal_rouge_scores
+     }
+