joaco7172 committed on
Commit
baa795e
·
verified ·
1 Parent(s): 3af524a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -342
app.py CHANGED
@@ -3,7 +3,6 @@ import re
3
  import time
4
  import json
5
  import random
6
- import requests # <-- For Polygon API calls
7
  import finnhub
8
  import torch
9
  import gradio as gr
@@ -15,18 +14,10 @@ from collections import defaultdict
15
  from datetime import date, datetime, timedelta
16
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
17
 
18
- ################################################################################
19
- # Set up environment variables, tokens, model, tokenizer, and other base config
20
- ################################################################################
21
 
22
- # Make sure these environment variables or your chosen tokens are set properly.
23
- access_token = os.environ.get("HF_TOKEN", "YOUR_HF_ACCESS_TOKEN")
24
- finnhub_api_key = os.environ.get("FINNHUB_API_KEY", "YOUR_FINNHUB_API_KEY")
25
- polygon_api_key = "OeD88ExoL458WlvNpCyiZXnHI0s_h05t" # <--- Replace "X" with your actual Polygon key
26
 
27
- finnhub_client = finnhub.Client(api_key=finnhub_api_key)
28
-
29
- # Load base model & LoRA
30
  base_model = AutoModelForCausalLM.from_pretrained(
31
  'meta-llama/Llama-2-7b-chat-hf',
32
  token=access_token,
@@ -49,25 +40,15 @@ tokenizer = AutoTokenizer.from_pretrained(
49
 
50
  streamer = TextStreamer(tokenizer)
51
 
52
- # Special Llama-format prompt tokens
53
  B_INST, E_INST = "[INST]", "[/INST]"
54
  B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
55
 
56
- SYSTEM_PROMPT = (
57
- "You are a seasoned stock market analyst. Your task is to list the positive developments and "
58
- "potential concerns for companies based on relevant news and basic financials from the past weeks, "
59
- "then provide an analysis and prediction for the companies' stock price movement for the upcoming week. "
60
- "Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n"
61
- "[Prediction & Analysis]\nPrediction: ...\nAnalysis: ..."
62
- )
63
-
64
 
65
- ###############################################################################
66
- # Utility functions
67
- ###############################################################################
68
 
69
  def print_gpu_utilization():
70
- """Helper to print GPU utilization (MB) using NVML."""
71
  nvmlInit()
72
  handle = nvmlDeviceGetHandleByIndex(0)
73
  info = nvmlDeviceGetMemoryInfo(handle)
@@ -75,41 +56,37 @@ def print_gpu_utilization():
75
 
76
 
77
  def get_curday():
78
- """Returns today's date in YYYY-MM-DD."""
79
  return date.today().strftime("%Y-%m-%d")
80
 
81
 
82
  def n_weeks_before(date_string, n):
83
- """Given 'date_string' in YYYY-MM-DD and an integer n, returns the date that is n weeks prior."""
84
- d = datetime.strptime(date_string, "%Y-%m-%d") - timedelta(days=7*n)
85
- return d.strftime("%Y-%m-%d")
 
86
 
87
 
88
  def get_stock_data(stock_symbol, steps):
89
- """
90
- Downloads stock price data using yfinance for the given date steps.
91
- Returns a DataFrame containing Start Date, End Date, Start Price, End Price
92
- for each of the time intervals.
93
- """
94
  stock_data = yf.download(stock_symbol, steps[0], steps[-1])
95
  if len(stock_data) == 0:
96
  raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
 
 
97
 
98
- print(stock_data) # For debugging
99
-
100
  dates, prices = [], []
101
  available_dates = stock_data.index.format()
102
-
103
- for date_ in steps[:-1]:
104
  for i in range(len(stock_data)):
105
- if available_dates[i] >= date_:
106
- prices.append(stock_data['Close'].iloc[i])
107
  dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
108
  break
109
 
110
- # Append last date & price
111
  dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
112
- prices.append(stock_data['Close'].iloc[-1])
113
 
114
  return pd.DataFrame({
115
  "Start Date": dates[:-1],
@@ -119,385 +96,195 @@ def get_stock_data(stock_symbol, steps):
119
  })
120
 
121
 
122
- ###############################################################################
123
- # News retrieval
124
- ###############################################################################
125
-
126
- def parse_polygon_news_item(item):
127
- """
128
- Convert a Polygon news item into a {date, headline, summary} dict similar to the finnhub format used.
129
- Published_utc is in ISO8601, e.g. '2021-04-23T12:47:00Z'.
130
- """
131
- published_str = item.get('published_utc', '')
132
- try:
133
- # Convert e.g. "2021-04-23T12:47:00Z" to a datetime, then to YYYYmmddHHMMSS string
134
- dt = datetime.strptime(published_str, "%Y-%m-%dT%H:%M:%SZ")
135
- date_fmt = dt.strftime("%Y%m%d%H%M%S")
136
- except:
137
- # In case of parsing error, just keep it raw
138
- date_fmt = published_str
139
-
140
- headline = item.get("title", "")
141
- summary = item.get("description", "")
142
- return {
143
- "date": date_fmt,
144
- "headline": headline,
145
- "summary": summary
146
- }
147
-
148
-
149
  def get_news(symbol, data):
150
- """
151
- For each row in data (weekly intervals), fetch both Finnhub and Polygon news.
152
- Combine them, sort by date, store the combined list in data['News'].
153
- """
154
-
155
- combined_news_list = []
156
 
157
- for idx, row in data.iterrows():
 
 
158
  start_date = row['Start Date'].strftime('%Y-%m-%d')
159
  end_date = row['End Date'].strftime('%Y-%m-%d')
160
-
161
- # Sleep for QPM control if needed
162
- time.sleep(1)
163
-
164
- #######################################################################
165
- # 1) Finnhub News for the weekly period
166
- #######################################################################
167
- finnhub_weekly_news = []
168
- try:
169
- finnhub_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
170
- for n in finnhub_news:
171
- dt_str = datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S')
172
- finnhub_weekly_news.append({
173
- "date": dt_str,
174
- "headline": n['headline'],
175
- "summary": n['summary']
176
- })
177
- except Exception as e:
178
- # If Finnhub returns an error or no results, we can choose to ignore or handle differently
179
- finnhub_weekly_news = []
180
-
181
- #######################################################################
182
- # 2) Polygon News
183
- # There's no date range param in the sample snippet. We'll fetch
184
- # up to 30 news items, then filter them by start_date <= published <= end_date
185
- #######################################################################
186
- polygon_weekly_news = []
187
- polygon_url = (
188
- f"https://api.polygon.io/v2/reference/news"
189
- f"?ticker={symbol}"
190
- f"&order=asc"
191
- f"&limit=30"
192
- f"&sort=published_utc"
193
- f"&apiKey={polygon_api_key}"
194
- )
195
- try:
196
- resp = requests.get(polygon_url)
197
- if resp.status_code == 200:
198
- polygon_data = resp.json()
199
- results = polygon_data.get('results', [])
200
- for item in results:
201
- news_item = parse_polygon_news_item(item)
202
- # Filter by date range
203
- # news_item['date'] is "YYYYmmddHHMMSS"
204
- # We compare it with start_date/end_date as YYYY-MM-DD (the day-based boundary).
205
- # So let's parse it properly to do a valid comparison:
206
- try:
207
- dt_item = datetime.strptime(news_item['date'], "%Y%m%d%H%M%S")
208
- dt_start = datetime.strptime(start_date, "%Y-%m-%d")
209
- dt_end = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1) - timedelta(seconds=1)
210
- if dt_start <= dt_item <= dt_end:
211
- polygon_weekly_news.append(news_item)
212
- except:
213
- # If for any reason the date parse fails, skip
214
- pass
215
- else:
216
- print(f"Polygon news request returned status code {resp.status_code}")
217
- except Exception as e:
218
- print(f"Polygon news request failed with exception: {e}")
219
- polygon_weekly_news = []
220
-
221
- # Merge the two news lists
222
- weekly_news_combined = finnhub_weekly_news + polygon_weekly_news
223
- # Sort by date
224
- weekly_news_combined.sort(key=lambda x: x['date'])
225
-
226
- # If no news from either source, you can decide to raise an error or keep empty
227
- # Here we won't raise an error if both are empty, so it continues gracefully:
228
- # if len(weekly_news_combined) == 0:
229
- # raise gr.Error(f"No company news found for symbol {symbol} from both Finnhub & Polygon in {start_date}-{end_date}!")
230
-
231
- combined_news_list.append(json.dumps(weekly_news_combined))
232
-
233
- data['News'] = combined_news_list
234
  return data
235
 
236
 
237
- ###############################################################################
238
- # Polygon Financials
239
- ###############################################################################
240
-
241
- def get_current_basics(symbol, curday):
242
- """
243
- Fetch the latest (limit=1, order=desc) financials from Polygon and parse out
244
- a condensed dictionary of relevant metrics. Replaces the old Finnhub approach.
245
- """
246
- url = (
247
- f"https://api.polygon.io/vX/reference/financials"
248
- f"?ticker={symbol}"
249
- f"&order=desc"
250
- f"&limit=1"
251
- f"&sort=filing_date"
252
- f"&apiKey={polygon_api_key}"
253
- )
254
- resp = requests.get(url)
255
- if resp.status_code != 200:
256
- raise gr.Error(f"Failed to retrieve financial data from Polygon! status={resp.status_code}")
257
-
258
- data = resp.json()
259
- if 'results' not in data or len(data['results']) == 0:
260
- raise gr.Error(f"No financial results found from Polygon for ticker {symbol}!")
261
-
262
- result = data['results'][0]
263
-
264
- # We can store a small set of relevant metrics
265
- filing_date = result.get('filing_date', '')
266
- fiscal_year = result.get('fiscal_year', 'N/A')
267
- fiscal_period = result.get('fiscal_period', 'N/A')
268
- fin = result.get('financials', {})
269
-
270
- # Helper to safely retrieve a nested value
271
- def get_nested_value(d, path):
272
- # path is a list of keys, e.g. ['income_statement','revenues']
273
- temp = d
274
- for p in path:
275
- if p not in temp:
276
- return None
277
- temp = temp[p]
278
- # Return .get("value") if present
279
- return temp.get("value", None)
280
-
281
- # Construct a dictionary for the prompt
282
- # We'll reference these in the final user prompt
283
- basics = {
284
- "period": f"{fiscal_year} {fiscal_period}",
285
- "filing_date": filing_date,
286
- "Revenue": get_nested_value(fin, ["income_statement", "revenues"]),
287
- "NetIncomeLoss": get_nested_value(fin, ["income_statement", "net_income_loss_attributable_to_parent"]),
288
- "DilutedEPS": get_nested_value(fin, ["income_statement", "diluted_earnings_per_share"]),
289
- "BasicEPS": get_nested_value(fin, ["income_statement", "basic_earnings_per_share"]),
290
- }
291
-
292
- return basics
293
-
294
-
295
- ###############################################################################
296
- # Prompt engineering
297
- ###############################################################################
298
-
299
  def get_company_prompt(symbol):
300
- """
301
- Pull a minimal Finnhub company profile to produce a short introduction.
302
- (We keep this from Finnhub if desired, as that info isn't replaced by Polygon.)
303
- """
304
  profile = finnhub_client.company_profile2(symbol=symbol)
305
  if not profile:
306
  raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
307
 
308
- company_template = (
309
- "[Company Introduction]:\n\n{name} is a leading entity in the {finnhubIndustry} sector. "
310
- "Incorporated and publicly traded since {ipo}, the company has established its reputation "
311
- "as one of the key players in the market. As of today, {name} has a market capitalization "
312
- "of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding.\n\n"
313
- "{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. "
314
- "As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive "
315
- "progress within the industry."
316
- )
317
 
318
  formatted_str = company_template.format(**profile)
 
319
  return formatted_str
320
 
321
 
322
  def get_prompt_by_row(symbol, row):
323
- """
324
- Given a single row with Start Price, End Price, News, return a textual summary
325
- plus the news items.
326
- """
327
  end_price = float(row['End Price'])
328
  start_price = float(row['Start Price'])
329
  term = 'increased' if end_price > start_price else 'decreased'
330
 
331
- start_date = row['Start Date'].strftime('%Y-%m-%d')
332
- end_date = row['End Date'].strftime('%Y-%m-%d')
333
  head = f"From {start_date} to {end_date}, {symbol}'s stock price {term} from {start_price:.2f} to {end_price:.2f}. Company news during this period are listed below:\n\n"
334
 
335
- # row["News"] is a JSON string, parse it
336
  news = row["News"] if isinstance(row["News"], list) else json.loads(row["News"])
337
- # We do not need 'Basics' from each row, because we fetch the latest basics only at the end
338
- return head, news, None
 
 
339
 
340
 
341
  def sample_news(news, k=5):
342
- """Randomly sample up to k news items from a list."""
343
- if len(news) <= k:
344
- return news
345
  return [news[i] for i in sorted(random.sample(range(len(news)), k))]
346
-
347
-
348
  def latest_news(news, k=5):
349
- """Get up to k of the most recently dated news from 'news', sorted descending."""
350
  sorted_news = sorted(news, key=lambda x: x['date'], reverse=True)
351
  return sorted_news[:k]
352
 
353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  def get_all_prompts_online(symbol, data, curday, with_basics=True):
355
- """
356
- Build a final prompt string from:
357
- - The company prompt (finnhub profile)
358
- - The historical intervals from 'data'
359
- - The current basics from Polygon
360
- """
361
  company_prompt = get_company_prompt(symbol)
362
 
363
  prev_rows = []
364
- for _, row in data.iterrows():
365
- head, news, _ = get_prompt_by_row(symbol, row)
366
- prev_rows.append((head, news))
367
 
368
- # Build up the text describing previous intervals
369
- prompt_body = ""
 
 
 
370
  for i in range(-len(prev_rows), 0):
371
- head, news_list = prev_rows[i]
372
- prompt_body += "\n" + head
373
- # Show up to 5 latest news for each interval
374
- chosen_news = latest_news(news_list, min(5, len(news_list)))
375
- if chosen_news:
376
- # Format each news item
377
- formatted_news_items = [
378
- "[Headline]: {}\n[Summary]: {}\n".format(n['headline'], n['summary'])
379
- for n in chosen_news
380
- ]
381
- prompt_body += "\n".join(formatted_news_items)
382
  else:
383
- prompt_body += "No relative news reported."
384
-
385
- period = f"{curday} to {n_weeks_before(curday, -1)}"
386
-
387
- # If user wants the latest polygon financials
388
  if with_basics:
389
  basics = get_current_basics(symbol, curday)
390
- basics_text = (
391
- f"Some recent basic financials of {symbol}, reported at {basics['filing_date']} "
392
- f"(period: {basics['period']}), are presented below:\n\n[Basic Financials]:\n\n"
393
- )
394
- # Append each metric
395
- for k, v in basics.items():
396
- if k not in ["period", "filing_date"]:
397
- basics_text += f"{k}: {v}\n"
398
  else:
399
- basics_text = "[Basic Financials]:\n\nNo basic financial reported."
400
 
401
- info_block = company_prompt + '\n' + prompt_body + '\n' + basics_text
402
-
403
- # This is the final request prompt for the language model
404
- final_prompt = (
405
- info_block
406
- + f"\n\nBased on all the information before {curday}, let's first analyze the positive developments "
407
- f"and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and "
408
- f"keep them concise. Most factors should be inferred from company-related news. Then make your "
409
- f"prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis "
410
- f"to support your prediction."
411
- )
412
-
413
- return info_block, final_prompt
414
 
415
 
416
- ###############################################################################
417
- # Gradio pipeline
418
- ###############################################################################
419
 
420
  def construct_prompt(ticker, curday, n_weeks, use_basics):
421
- """
422
- Build the final prompt to feed into the model by collecting:
423
- 1) Stock data from yfinance (past n_weeks intervals)
424
- 2) News from Finnhub & Polygon
425
- 3) Current basics from Polygon (if requested)
426
- """
427
  try:
428
  steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
429
  except Exception:
430
  raise gr.Error(f"Invalid date {curday}!")
431
-
432
- # 1) Stock data
433
  data = get_stock_data(ticker, steps)
434
- # 2) News data (from Finnhub & Polygon)
435
  data = get_news(ticker, data)
436
- # We don't store the basics in each row anymore, so just place empty dict
437
  data['Basics'] = [json.dumps({})] * len(data)
438
-
439
- # 3) Construct final prompt
440
  info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)
441
 
442
- # Format with system instructions for Llama
443
  prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
 
444
 
445
  return info, prompt
446
 
447
 
448
- def predict(ticker, date_, n_weeks, use_basics):
449
- """
450
- Main function triggered by Gradio.
451
- 1) Builds the prompt,
452
- 2) Generates the response via the LLaMA-2 + LoRA model,
453
- 3) Returns the full info block + final answer.
454
- """
455
  print_gpu_utilization()
456
-
457
- # Construct the full prompt
458
- info, prompt = construct_prompt(ticker, date_, n_weeks, use_basics)
459
-
460
- # Tokenize
461
- inputs = tokenizer(prompt, return_tensors='pt', padding=False)
462
- inputs = {key: value.to(model.device) for key, value in inputs.items()}
463
 
464
- print("Inputs loaded onto device...")
 
 
 
 
 
465
 
466
- # Generate
 
467
  res = model.generate(
468
- **inputs,
469
- max_length=4096,
470
- do_sample=False,
471
  eos_token_id=tokenizer.eos_token_id,
472
- use_cache=True,
473
- streamer=streamer
474
  )
475
  output = tokenizer.decode(res[0], skip_special_tokens=True)
476
- # Remove everything up to the closing [/INST]
477
  answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
478
 
479
- # Clean up GPU memory
480
  torch.cuda.empty_cache()
481
 
482
  return info, answer
483
 
484
 
485
- ################################################################################
486
- # Gradio Interface
487
- ################################################################################
488
-
489
  demo = gr.Interface(
490
- fn=predict,
491
  inputs=[
492
  gr.Textbox(
493
  label="Ticker",
494
  value="AAPL",
495
- info="Companies from Dow-30 are recommended"
496
  ),
497
  gr.Textbox(
498
  label="Date",
499
  value=get_curday,
500
- info="Date from which the prediction is made, format yyyy-mm-dd"
501
  ),
502
  gr.Slider(
503
  minimum=1,
@@ -505,22 +292,25 @@ demo = gr.Interface(
505
  value=3,
506
  step=1,
507
  label="n_weeks",
508
- info="Information of the past n weeks will be utilized. Choose between 1 and 4."
509
  ),
510
  gr.Checkbox(
511
  label="Use Latest Basic Financials",
512
  value=False,
513
- info="If checked, the latest quarterly reported basic financials from Polygon are taken into account."
514
  )
515
  ],
516
  outputs=[
517
- gr.Textbox(label="Information"),
518
- gr.Textbox(label="Response")
 
 
 
 
519
  ],
520
  title="Pro Capital",
521
- description="Pro Capital implementation using yfinance for historical prices, "
522
- "Finnhub + Polygon for news, and Polygon for financials."
523
  )
524
 
525
- if __name__ == "__main__":
526
- demo.launch()
 
3
  import time
4
  import json
5
  import random
 
6
  import finnhub
7
  import torch
8
  import gradio as gr
 
14
  from datetime import date, datetime, timedelta
15
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
16
 
 
 
 
17
 
18
+ access_token = os.environ["HF_TOKEN"]
19
+ finnhub_client = finnhub.Client(api_key=os.environ["FINNHUB_API_KEY"])
 
 
20
 
 
 
 
21
  base_model = AutoModelForCausalLM.from_pretrained(
22
  'meta-llama/Llama-2-7b-chat-hf',
23
  token=access_token,
 
40
 
41
  streamer = TextStreamer(tokenizer)
42
 
 
43
  B_INST, E_INST = "[INST]", "[/INST]"
44
  B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
45
 
46
+ SYSTEM_PROMPT = "You are a seasoned stock market analyst. Your task is to list the positive developments and potential concerns for companies based on relevant news and basic financials from the past weeks, then provide an analysis and prediction for the companies' stock price movement for the upcoming week. " \
47
+ "Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n[Prediction & Analysis]\nPrediction: ...\nAnalysis: ..."
 
 
 
 
 
 
48
 
 
 
 
49
 
50
  def print_gpu_utilization():
51
+
52
  nvmlInit()
53
  handle = nvmlDeviceGetHandleByIndex(0)
54
  info = nvmlDeviceGetMemoryInfo(handle)
 
56
 
57
 
58
def get_curday():
    """Today's date as an ISO ``YYYY-MM-DD`` string."""
    return date.today().isoformat()
61
 
62
 
63
def n_weeks_before(date_string, n):
    """Return the date ``n`` weeks before ``date_string`` (``YYYY-MM-DD``).

    A negative ``n`` moves forward in time (used elsewhere to compute the
    upcoming week's period).

    Raises:
        ValueError: if ``date_string`` is not in ``YYYY-MM-DD`` format.
    """
    # Use a distinct local name: the original bound `date`, shadowing the
    # `date` class imported from datetime at module level.
    target = datetime.strptime(date_string, "%Y-%m-%d") - timedelta(days=7 * n)
    return target.strftime("%Y-%m-%d")
68
 
69
 
70
  def get_stock_data(stock_symbol, steps):
 
 
 
 
 
71
  stock_data = yf.download(stock_symbol, steps[0], steps[-1])
72
  if len(stock_data) == 0:
73
  raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
74
+
75
+ print(stock_data)
76
 
 
 
77
  dates, prices = [], []
78
  available_dates = stock_data.index.format()
79
+
80
+ for date in steps[:-1]:
81
  for i in range(len(stock_data)):
82
+ if available_dates[i] >= date:
83
+ prices.append(stock_data['Close'].iloc[i]) # Use .iloc here
84
  dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
85
  break
86
 
87
+ # Append the last date and price
88
  dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
89
+ prices.append(stock_data['Close'].iloc[-1]) # Use .iloc here as well
90
 
91
  return pd.DataFrame({
92
  "Start Date": dates[:-1],
 
96
  })
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
def get_news(symbol, data):
    """Attach a 'News' column to `data` (one weekly row per interval).

    For each row, fetches finnhub company news between the row's Start Date
    and End Date and stores it as a JSON-encoded list of
    {date, headline, summary} dicts, with 'date' as a 'YYYYmmddHHMMSS'
    string (sortable lexicographically).

    Raises gr.Error when any week has zero news items.
    """

    news_list = []

    # NOTE(review): the loop target shadows `end_date` — the iterrows()
    # index value is immediately overwritten by row['End Date'] below, so
    # the index is effectively unused.
    for end_date, row in data.iterrows():
        start_date = row['Start Date'].strftime('%Y-%m-%d')
        end_date = row['End Date'].strftime('%Y-%m-%d')
        # print(symbol, ': ', start_date, ' - ', end_date)
        # Throttle one request per second to stay under finnhub's
        # queries-per-minute limit.
        time.sleep(1) # control qpm
        weekly_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
        if len(weekly_news) == 0:
            raise gr.Error(f"No company news found for symbol {symbol} from finnhub!")
        # Normalize each finnhub item (epoch 'datetime') to the compact
        # timestamp-string format used throughout the prompt builders.
        weekly_news = [
            {
                "date": datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S'),
                "headline": n['headline'],
                "summary": n['summary'],
            } for n in weekly_news
        ]
        # Zero-padded timestamp strings sort chronologically.
        weekly_news.sort(key=lambda x: x['date'])
        news_list.append(json.dumps(weekly_news))

    data['News'] = news_list

    return data
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
def get_company_prompt(symbol):
    """Build the '[Company Introduction]' paragraph from the finnhub
    company profile for `symbol`.

    Raises gr.Error when finnhub returns an empty profile.
    """

    profile = finnhub_client.company_profile2(symbol=symbol)
    if not profile:
        raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")

    # Template fields (name, ipo, marketCapitalization, currency, ...) must
    # all be present in the profile dict; .format(**profile) raises
    # KeyError otherwise — TODO confirm finnhub always supplies them.
    company_template = "[Company Introduction]:\n\n{name} is a leading entity in the {finnhubIndustry} sector. Incorporated and publicly traded since {ipo}, the company has established its reputation as one of the key players in the market. As of today, {name} has a market capitalization of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding." \
    "\n\n{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive progress within the industry."

    formatted_str = company_template.format(**profile)

    return formatted_str
138
 
139
 
140
def get_prompt_by_row(symbol, row):
    """Summarize one weekly interval as ``(head, news, basics)``.

    ``head`` is a one-paragraph price-movement summary, ``news`` the decoded
    news list, ``basics`` the decoded basics dict. ``row`` may hold dates as
    strings or datetime-like objects, and 'News' as a list or JSON string.
    """
    price_start = float(row['Start Price'])
    price_end = float(row['End Price'])
    term = 'increased' if price_end > price_start else 'decreased'

    def _as_day(value):
        # Accept both preformatted strings and datetime-like objects.
        return value if isinstance(value, str) else value.strftime('%Y-%m-%d')

    start_date = _as_day(row['Start Date'])
    end_date = _as_day(row['End Date'])
    head = f"From {start_date} to {end_date}, {symbol}'s stock price {term} from {price_start:.2f} to {price_end:.2f}. Company news during this period are listed below:\n\n"

    raw_news = row["News"]
    news = raw_news if isinstance(raw_news, list) else json.loads(raw_news)
    basics = json.loads(row['Basics'])

    return head, news, basics
+
154
 
155
 
156
def sample_news(news, k=5):
    """Return up to ``k`` randomly chosen news items, preserving order.

    Restores the short-list guard from the previous revision: without it,
    ``random.sample`` raises ValueError whenever ``k > len(news)``.
    """
    if len(news) <= k:
        return news
    # Sample k distinct indices, then emit them in original order.
    return [news[i] for i in sorted(random.sample(range(len(news)), k))]
159
+
 
160
def latest_news(news, k=5):
    """Return up to ``k`` news items ordered newest-first by 'date'."""
    newest_first = sorted(news, key=lambda item: item["date"], reverse=True)
    return newest_first[:k]
163
 
164
 
165
def get_current_basics(symbol, curday):
    """Return the most recent quarterly basic-financials dict for `symbol`
    whose 'period' is on or before `curday` (YYYY-MM-DD).

    Raises gr.Error when finnhub returns no 'series' data.
    """

    basic_financials = finnhub_client.company_basic_financials(symbol, 'all')
    if not basic_financials['series']:
        raise gr.Error(f"Failed to find basic financials for symbol {symbol} from finnhub!")

    # NOTE(review): final_basics is never used below.
    final_basics, basic_list, basic_dict = [], [], defaultdict(dict)

    # Pivot finnhub's per-metric series into one dict per reporting period.
    for metric, value_list in basic_financials['series']['quarterly'].items():
        for value in value_list:
            basic_dict[value['period']].update({metric: value['v']})

    for k, v in basic_dict.items():
        v.update({'period': k})
        basic_list.append(v)

    # Periods are ISO date strings, so lexicographic sort is chronological.
    basic_list.sort(key=lambda x: x['period'])

    # Walk newest-to-oldest and stop at the first period not after curday.
    # NOTE(review): if every period is after curday, the loop falls through
    # without breaking and the OLDEST entry is returned — confirm intended.
    for basic in basic_list[::-1]:
        if basic['period'] <= curday:
            break

    return basic
188
+
189
+
190
def get_all_prompts_online(symbol, data, curday, with_basics=True):
    """Assemble the model's information block and final user prompt.

    Returns (info, prompt): `info` is the company introduction plus one
    price/news summary per weekly row of `data` plus a basics section;
    `prompt` appends the analysis/prediction request for the week after
    `curday`. When `with_basics` is False the basics section is a stub.
    """
    company_prompt = get_company_prompt(symbol)

    prev_rows = []

    # Basics from get_prompt_by_row are discarded here; only the latest
    # basics (fetched below) are used.
    for row_idx, row in data.iterrows():
        head, news, _ = get_prompt_by_row(symbol, row)
        prev_rows.append((head, news, None))

    prompt = ""
    for i in range(-len(prev_rows), 0):
        prompt += "\n" + prev_rows[i][0]
        # Show at most the 5 most recent news items per interval.
        latest_news_items = latest_news(
            prev_rows[i][1],
            min(5, len(prev_rows[i][1]))
        )
        if latest_news_items:
            # Ensure each news item is formatted as a string
            formatted_news_items = ["[Headline]: {}\n[Summary]: {}\n".format(n['headline'], n['summary']) for n in latest_news_items]
            prompt += "\n".join(formatted_news_items)
        else:
            prompt += "No relative news reported."

    # n_weeks_before with n=-1 yields the date one week AFTER curday.
    period = "{} to {}".format(curday, n_weeks_before(curday, -1))

    if with_basics:
        basics = get_current_basics(symbol, curday)
        # Rebind `basics` from dict to its rendered text section.
        basics = "Some recent basic financials of {}, reported at {}, are presented below:\n\n[Basic Financials]:\n\n".format(
            symbol, basics['period']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
    else:
        basics = "[Basic Financials]:\n\nNo basic financial reported."

    info = company_prompt + '\n' + prompt + '\n' + basics
    prompt = info + f"\n\nBased on all the information before {curday}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company related news. " \
        f"Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction."

    return info, prompt
 
 
 
 
 
 
 
 
227
 
228
 
 
 
 
229
 
230
def construct_prompt(ticker, curday, n_weeks, use_basics):
    """Build (info, prompt) for `ticker` covering the past `n_weeks` weeks.

    Pipeline: weekly date boundaries -> price data (yfinance) -> news
    (finnhub) -> prompt assembly, then wrap the prompt in the Llama-2
    [INST]/<<SYS>> chat format. Raises gr.Error for an unparseable date.
    """

    try:
        # Oldest-to-newest weekly boundaries, ending at curday.
        steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
    except Exception:
        raise gr.Error(f"Invalid date {curday}!")

    data = get_stock_data(ticker, steps)
    data = get_news(ticker, data)
    # Per-row basics are unused placeholders; the latest basics are fetched
    # separately inside get_all_prompts_online.
    data['Basics'] = [json.dumps({})] * len(data)
    # print(data)

    info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)

    # Wrap in the Llama-2 chat template with the fixed system prompt.
    prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
    # print(prompt)

    return info, prompt
248
 
249
 
250
def predict(ticker, date, n_weeks, use_basics):
    """Gradio entry point: build the prompt and generate the analysis.

    Returns (info, answer) where `answer` is the generated text after the
    closing [/INST] tag.

    NOTE(review): the `date` parameter shadows the `date` class imported
    from datetime at module level (harmless inside this function).
    """
    print_gpu_utilization()

    info, prompt = construct_prompt(ticker, date, n_weeks, use_basics)

    inputs = tokenizer(
        prompt, return_tensors='pt', padding=False
    )
    # NOTE(review): generation uses `model`, which is not defined in this
    # view of the file (only `base_model` is) — confirm it exists upstream.
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    print("Inputs loaded onto devices.")

    # Greedy decoding (do_sample=False); tokens stream to stdout via
    # `streamer` while generating.
    res = model.generate(
        **inputs, max_length=4096, do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True, streamer=streamer
    )
    output = tokenizer.decode(res[0], skip_special_tokens=True)
    # Strip everything up to and including the closing [/INST] tag.
    answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)

    # Free cached GPU memory between requests.
    torch.cuda.empty_cache()

    return info, answer
274
 
275
 
 
 
 
 
276
  demo = gr.Interface(
277
+ predict,
278
  inputs=[
279
  gr.Textbox(
280
  label="Ticker",
281
  value="AAPL",
282
+ info="Companys from Dow-30 are recommended"
283
  ),
284
  gr.Textbox(
285
  label="Date",
286
  value=get_curday,
287
+ info="Date from which the prediction is made, use format yyyy-mm-dd"
288
  ),
289
  gr.Slider(
290
  minimum=1,
 
292
  value=3,
293
  step=1,
294
  label="n_weeks",
295
+ info="Information of the past n weeks will be utilized, choose between 1 and 4"
296
  ),
297
  gr.Checkbox(
298
  label="Use Latest Basic Financials",
299
  value=False,
300
+ info="If checked, the latest quarterly reported basic financials of the company is taken into account."
301
  )
302
  ],
303
  outputs=[
304
+ gr.Textbox(
305
+ label="Information"
306
+ ),
307
+ gr.Textbox(
308
+ label="Response"
309
+ )
310
  ],
311
  title="Pro Capital",
312
+ description="""Pro Capital implementation.**
313
+ """
314
  )
315
 
316
# Launch the Gradio app only when executed as a script. The previous
# revision guarded this; an unguarded demo.launch() starts the server as a
# side effect of merely importing this module.
if __name__ == "__main__":
    demo.launch()