Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ import re
|
|
3 |
import time
|
4 |
import json
|
5 |
import random
|
6 |
-
import requests # <-- For Polygon API calls
|
7 |
import finnhub
|
8 |
import torch
|
9 |
import gradio as gr
|
@@ -15,18 +14,10 @@ from collections import defaultdict
|
|
15 |
from datetime import date, datetime, timedelta
|
16 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
17 |
|
18 |
-
################################################################################
|
19 |
-
# Set up environment variables, tokens, model, tokenizer, and other base config
|
20 |
-
################################################################################
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
finnhub_api_key = os.environ.get("FINNHUB_API_KEY", "YOUR_FINNHUB_API_KEY")
|
25 |
-
polygon_api_key = "OeD88ExoL458WlvNpCyiZXnHI0s_h05t" # <--- Replace "X" with your actual Polygon key
|
26 |
|
27 |
-
finnhub_client = finnhub.Client(api_key=finnhub_api_key)
|
28 |
-
|
29 |
-
# Load base model & LoRA
|
30 |
base_model = AutoModelForCausalLM.from_pretrained(
|
31 |
'meta-llama/Llama-2-7b-chat-hf',
|
32 |
token=access_token,
|
@@ -49,25 +40,15 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
49 |
|
50 |
streamer = TextStreamer(tokenizer)
|
51 |
|
52 |
-
# Special Llama-format prompt tokens
|
53 |
B_INST, E_INST = "[INST]", "[/INST]"
|
54 |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
55 |
|
56 |
-
SYSTEM_PROMPT =
|
57 |
-
"
|
58 |
-
"potential concerns for companies based on relevant news and basic financials from the past weeks, "
|
59 |
-
"then provide an analysis and prediction for the companies' stock price movement for the upcoming week. "
|
60 |
-
"Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n"
|
61 |
-
"[Prediction & Analysis]\nPrediction: ...\nAnalysis: ..."
|
62 |
-
)
|
63 |
-
|
64 |
|
65 |
-
###############################################################################
|
66 |
-
# Utility functions
|
67 |
-
###############################################################################
|
68 |
|
69 |
def print_gpu_utilization():
|
70 |
-
|
71 |
nvmlInit()
|
72 |
handle = nvmlDeviceGetHandleByIndex(0)
|
73 |
info = nvmlDeviceGetMemoryInfo(handle)
|
@@ -75,41 +56,37 @@ def print_gpu_utilization():
|
|
75 |
|
76 |
|
77 |
def get_curday():
|
78 |
-
|
79 |
return date.today().strftime("%Y-%m-%d")
|
80 |
|
81 |
|
82 |
def n_weeks_before(date_string, n):
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
86 |
|
87 |
|
88 |
def get_stock_data(stock_symbol, steps):
|
89 |
-
"""
|
90 |
-
Downloads stock price data using yfinance for the given date steps.
|
91 |
-
Returns a DataFrame containing Start Date, End Date, Start Price, End Price
|
92 |
-
for each of the time intervals.
|
93 |
-
"""
|
94 |
stock_data = yf.download(stock_symbol, steps[0], steps[-1])
|
95 |
if len(stock_data) == 0:
|
96 |
raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
|
|
|
|
|
97 |
|
98 |
-
print(stock_data) # For debugging
|
99 |
-
|
100 |
dates, prices = [], []
|
101 |
available_dates = stock_data.index.format()
|
102 |
-
|
103 |
-
for
|
104 |
for i in range(len(stock_data)):
|
105 |
-
if available_dates[i] >=
|
106 |
-
prices.append(stock_data['Close'].iloc[i])
|
107 |
dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
|
108 |
break
|
109 |
|
110 |
-
# Append last date
|
111 |
dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
|
112 |
-
prices.append(stock_data['Close'].iloc[-1])
|
113 |
|
114 |
return pd.DataFrame({
|
115 |
"Start Date": dates[:-1],
|
@@ -119,385 +96,195 @@ def get_stock_data(stock_symbol, steps):
|
|
119 |
})
|
120 |
|
121 |
|
122 |
-
###############################################################################
|
123 |
-
# News retrieval
|
124 |
-
###############################################################################
|
125 |
-
|
126 |
-
def parse_polygon_news_item(item):
|
127 |
-
"""
|
128 |
-
Convert a Polygon news item into a {date, headline, summary} dict similar to the finnhub format used.
|
129 |
-
Published_utc is in ISO8601, e.g. '2021-04-23T12:47:00Z'.
|
130 |
-
"""
|
131 |
-
published_str = item.get('published_utc', '')
|
132 |
-
try:
|
133 |
-
# Convert e.g. "2021-04-23T12:47:00Z" to a datetime, then to YYYYmmddHHMMSS string
|
134 |
-
dt = datetime.strptime(published_str, "%Y-%m-%dT%H:%M:%SZ")
|
135 |
-
date_fmt = dt.strftime("%Y%m%d%H%M%S")
|
136 |
-
except:
|
137 |
-
# In case of parsing error, just keep it raw
|
138 |
-
date_fmt = published_str
|
139 |
-
|
140 |
-
headline = item.get("title", "")
|
141 |
-
summary = item.get("description", "")
|
142 |
-
return {
|
143 |
-
"date": date_fmt,
|
144 |
-
"headline": headline,
|
145 |
-
"summary": summary
|
146 |
-
}
|
147 |
-
|
148 |
-
|
149 |
def get_news(symbol, data):
|
150 |
-
"""
|
151 |
-
For each row in data (weekly intervals), fetch both Finnhub and Polygon news.
|
152 |
-
Combine them, sort by date, store the combined list in data['News'].
|
153 |
-
"""
|
154 |
-
|
155 |
-
combined_news_list = []
|
156 |
|
157 |
-
|
|
|
|
|
158 |
start_date = row['Start Date'].strftime('%Y-%m-%d')
|
159 |
end_date = row['End Date'].strftime('%Y-%m-%d')
|
160 |
-
|
161 |
-
#
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
for n in
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
except Exception as e:
|
178 |
-
# If Finnhub returns an error or no results, we can choose to ignore or handle differently
|
179 |
-
finnhub_weekly_news = []
|
180 |
-
|
181 |
-
#######################################################################
|
182 |
-
# 2) Polygon News
|
183 |
-
# There's no date range param in the sample snippet. We'll fetch
|
184 |
-
# up to 30 news items, then filter them by start_date <= published <= end_date
|
185 |
-
#######################################################################
|
186 |
-
polygon_weekly_news = []
|
187 |
-
polygon_url = (
|
188 |
-
f"https://api.polygon.io/v2/reference/news"
|
189 |
-
f"?ticker={symbol}"
|
190 |
-
f"&order=asc"
|
191 |
-
f"&limit=30"
|
192 |
-
f"&sort=published_utc"
|
193 |
-
f"&apiKey={polygon_api_key}"
|
194 |
-
)
|
195 |
-
try:
|
196 |
-
resp = requests.get(polygon_url)
|
197 |
-
if resp.status_code == 200:
|
198 |
-
polygon_data = resp.json()
|
199 |
-
results = polygon_data.get('results', [])
|
200 |
-
for item in results:
|
201 |
-
news_item = parse_polygon_news_item(item)
|
202 |
-
# Filter by date range
|
203 |
-
# news_item['date'] is "YYYYmmddHHMMSS"
|
204 |
-
# We compare it with start_date/end_date as YYYY-MM-DD (the day-based boundary).
|
205 |
-
# So let's parse it properly to do a valid comparison:
|
206 |
-
try:
|
207 |
-
dt_item = datetime.strptime(news_item['date'], "%Y%m%d%H%M%S")
|
208 |
-
dt_start = datetime.strptime(start_date, "%Y-%m-%d")
|
209 |
-
dt_end = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1) - timedelta(seconds=1)
|
210 |
-
if dt_start <= dt_item <= dt_end:
|
211 |
-
polygon_weekly_news.append(news_item)
|
212 |
-
except:
|
213 |
-
# If for any reason the date parse fails, skip
|
214 |
-
pass
|
215 |
-
else:
|
216 |
-
print(f"Polygon news request returned status code {resp.status_code}")
|
217 |
-
except Exception as e:
|
218 |
-
print(f"Polygon news request failed with exception: {e}")
|
219 |
-
polygon_weekly_news = []
|
220 |
-
|
221 |
-
# Merge the two news lists
|
222 |
-
weekly_news_combined = finnhub_weekly_news + polygon_weekly_news
|
223 |
-
# Sort by date
|
224 |
-
weekly_news_combined.sort(key=lambda x: x['date'])
|
225 |
-
|
226 |
-
# If no news from either source, you can decide to raise an error or keep empty
|
227 |
-
# Here we won't raise an error if both are empty, so it continues gracefully:
|
228 |
-
# if len(weekly_news_combined) == 0:
|
229 |
-
# raise gr.Error(f"No company news found for symbol {symbol} from both Finnhub & Polygon in {start_date}-{end_date}!")
|
230 |
-
|
231 |
-
combined_news_list.append(json.dumps(weekly_news_combined))
|
232 |
-
|
233 |
-
data['News'] = combined_news_list
|
234 |
return data
|
235 |
|
236 |
|
237 |
-
###############################################################################
|
238 |
-
# Polygon Financials
|
239 |
-
###############################################################################
|
240 |
-
|
241 |
-
def get_current_basics(symbol, curday):
|
242 |
-
"""
|
243 |
-
Fetch the latest (limit=1, order=desc) financials from Polygon and parse out
|
244 |
-
a condensed dictionary of relevant metrics. Replaces the old Finnhub approach.
|
245 |
-
"""
|
246 |
-
url = (
|
247 |
-
f"https://api.polygon.io/vX/reference/financials"
|
248 |
-
f"?ticker={symbol}"
|
249 |
-
f"&order=desc"
|
250 |
-
f"&limit=1"
|
251 |
-
f"&sort=filing_date"
|
252 |
-
f"&apiKey={polygon_api_key}"
|
253 |
-
)
|
254 |
-
resp = requests.get(url)
|
255 |
-
if resp.status_code != 200:
|
256 |
-
raise gr.Error(f"Failed to retrieve financial data from Polygon! status={resp.status_code}")
|
257 |
-
|
258 |
-
data = resp.json()
|
259 |
-
if 'results' not in data or len(data['results']) == 0:
|
260 |
-
raise gr.Error(f"No financial results found from Polygon for ticker {symbol}!")
|
261 |
-
|
262 |
-
result = data['results'][0]
|
263 |
-
|
264 |
-
# We can store a small set of relevant metrics
|
265 |
-
filing_date = result.get('filing_date', '')
|
266 |
-
fiscal_year = result.get('fiscal_year', 'N/A')
|
267 |
-
fiscal_period = result.get('fiscal_period', 'N/A')
|
268 |
-
fin = result.get('financials', {})
|
269 |
-
|
270 |
-
# Helper to safely retrieve a nested value
|
271 |
-
def get_nested_value(d, path):
|
272 |
-
# path is a list of keys, e.g. ['income_statement','revenues']
|
273 |
-
temp = d
|
274 |
-
for p in path:
|
275 |
-
if p not in temp:
|
276 |
-
return None
|
277 |
-
temp = temp[p]
|
278 |
-
# Return .get("value") if present
|
279 |
-
return temp.get("value", None)
|
280 |
-
|
281 |
-
# Construct a dictionary for the prompt
|
282 |
-
# We'll reference these in the final user prompt
|
283 |
-
basics = {
|
284 |
-
"period": f"{fiscal_year} {fiscal_period}",
|
285 |
-
"filing_date": filing_date,
|
286 |
-
"Revenue": get_nested_value(fin, ["income_statement", "revenues"]),
|
287 |
-
"NetIncomeLoss": get_nested_value(fin, ["income_statement", "net_income_loss_attributable_to_parent"]),
|
288 |
-
"DilutedEPS": get_nested_value(fin, ["income_statement", "diluted_earnings_per_share"]),
|
289 |
-
"BasicEPS": get_nested_value(fin, ["income_statement", "basic_earnings_per_share"]),
|
290 |
-
}
|
291 |
-
|
292 |
-
return basics
|
293 |
-
|
294 |
-
|
295 |
-
###############################################################################
|
296 |
-
# Prompt engineering
|
297 |
-
###############################################################################
|
298 |
-
|
299 |
def get_company_prompt(symbol):
|
300 |
-
|
301 |
-
Pull a minimal Finnhub company profile to produce a short introduction.
|
302 |
-
(We keep this from Finnhub if desired, as that info isn't replaced by Polygon.)
|
303 |
-
"""
|
304 |
profile = finnhub_client.company_profile2(symbol=symbol)
|
305 |
if not profile:
|
306 |
raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
|
307 |
|
308 |
-
company_template =
|
309 |
-
"
|
310 |
-
"Incorporated and publicly traded since {ipo}, the company has established its reputation "
|
311 |
-
"as one of the key players in the market. As of today, {name} has a market capitalization "
|
312 |
-
"of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding.\n\n"
|
313 |
-
"{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. "
|
314 |
-
"As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive "
|
315 |
-
"progress within the industry."
|
316 |
-
)
|
317 |
|
318 |
formatted_str = company_template.format(**profile)
|
|
|
319 |
return formatted_str
|
320 |
|
321 |
|
322 |
def get_prompt_by_row(symbol, row):
|
323 |
-
"""
|
324 |
-
Given a single row with Start Price, End Price, News, return a textual summary
|
325 |
-
plus the news items.
|
326 |
-
"""
|
327 |
end_price = float(row['End Price'])
|
328 |
start_price = float(row['Start Price'])
|
329 |
term = 'increased' if end_price > start_price else 'decreased'
|
330 |
|
331 |
-
start_date = row['Start Date'].strftime('%Y-%m-%d')
|
332 |
-
end_date = row['End Date'].strftime('%Y-%m-%d')
|
333 |
head = f"From {start_date} to {end_date}, {symbol}'s stock price {term} from {start_price:.2f} to {end_price:.2f}. Company news during this period are listed below:\n\n"
|
334 |
|
335 |
-
# row["News"] is a JSON string, parse it
|
336 |
news = row["News"] if isinstance(row["News"], list) else json.loads(row["News"])
|
337 |
-
|
338 |
-
|
|
|
|
|
339 |
|
340 |
|
341 |
def sample_news(news, k=5):
|
342 |
-
|
343 |
-
if len(news) <= k:
|
344 |
-
return news
|
345 |
return [news[i] for i in sorted(random.sample(range(len(news)), k))]
|
346 |
-
|
347 |
-
|
348 |
def latest_news(news, k=5):
|
349 |
-
"""Get up to k of the most recently dated news from 'news', sorted descending."""
|
350 |
sorted_news = sorted(news, key=lambda x: x['date'], reverse=True)
|
351 |
return sorted_news[:k]
|
352 |
|
353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
def get_all_prompts_online(symbol, data, curday, with_basics=True):
|
355 |
-
"""
|
356 |
-
Build a final prompt string from:
|
357 |
-
- The company prompt (finnhub profile)
|
358 |
-
- The historical intervals from 'data'
|
359 |
-
- The current basics from Polygon
|
360 |
-
"""
|
361 |
company_prompt = get_company_prompt(symbol)
|
362 |
|
363 |
prev_rows = []
|
364 |
-
for _, row in data.iterrows():
|
365 |
-
head, news, _ = get_prompt_by_row(symbol, row)
|
366 |
-
prev_rows.append((head, news))
|
367 |
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
370 |
for i in range(-len(prev_rows), 0):
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
]
|
381 |
-
prompt_body += "\n".join(formatted_news_items)
|
382 |
else:
|
383 |
-
|
384 |
-
|
385 |
-
period =
|
386 |
-
|
387 |
-
# If user wants the latest polygon financials
|
388 |
if with_basics:
|
389 |
basics = get_current_basics(symbol, curday)
|
390 |
-
|
391 |
-
|
392 |
-
f"(period: {basics['period']}), are presented below:\n\n[Basic Financials]:\n\n"
|
393 |
-
)
|
394 |
-
# Append each metric
|
395 |
-
for k, v in basics.items():
|
396 |
-
if k not in ["period", "filing_date"]:
|
397 |
-
basics_text += f"{k}: {v}\n"
|
398 |
else:
|
399 |
-
|
400 |
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
+ f"\n\nBased on all the information before {curday}, let's first analyze the positive developments "
|
407 |
-
f"and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and "
|
408 |
-
f"keep them concise. Most factors should be inferred from company-related news. Then make your "
|
409 |
-
f"prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis "
|
410 |
-
f"to support your prediction."
|
411 |
-
)
|
412 |
-
|
413 |
-
return info_block, final_prompt
|
414 |
|
415 |
|
416 |
-
###############################################################################
|
417 |
-
# Gradio pipeline
|
418 |
-
###############################################################################
|
419 |
|
420 |
def construct_prompt(ticker, curday, n_weeks, use_basics):
|
421 |
-
|
422 |
-
Build the final prompt to feed into the model by collecting:
|
423 |
-
1) Stock data from yfinance (past n_weeks intervals)
|
424 |
-
2) News from Finnhub & Polygon
|
425 |
-
3) Current basics from Polygon (if requested)
|
426 |
-
"""
|
427 |
try:
|
428 |
steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
|
429 |
except Exception:
|
430 |
raise gr.Error(f"Invalid date {curday}!")
|
431 |
-
|
432 |
-
# 1) Stock data
|
433 |
data = get_stock_data(ticker, steps)
|
434 |
-
# 2) News data (from Finnhub & Polygon)
|
435 |
data = get_news(ticker, data)
|
436 |
-
# We don't store the basics in each row anymore, so just place empty dict
|
437 |
data['Basics'] = [json.dumps({})] * len(data)
|
438 |
-
|
439 |
-
|
440 |
info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)
|
441 |
|
442 |
-
# Format with system instructions for Llama
|
443 |
prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
|
|
|
444 |
|
445 |
return info, prompt
|
446 |
|
447 |
|
448 |
-
def predict(ticker,
|
449 |
-
|
450 |
-
Main function triggered by Gradio.
|
451 |
-
1) Builds the prompt,
|
452 |
-
2) Generates the response via the LLaMA-2 + LoRA model,
|
453 |
-
3) Returns the full info block + final answer.
|
454 |
-
"""
|
455 |
print_gpu_utilization()
|
456 |
-
|
457 |
-
# Construct the full prompt
|
458 |
-
info, prompt = construct_prompt(ticker, date_, n_weeks, use_basics)
|
459 |
-
|
460 |
-
# Tokenize
|
461 |
-
inputs = tokenizer(prompt, return_tensors='pt', padding=False)
|
462 |
-
inputs = {key: value.to(model.device) for key, value in inputs.items()}
|
463 |
|
464 |
-
|
|
|
|
|
|
|
|
|
|
|
465 |
|
466 |
-
|
|
|
467 |
res = model.generate(
|
468 |
-
**inputs,
|
469 |
-
max_length=4096,
|
470 |
-
do_sample=False,
|
471 |
eos_token_id=tokenizer.eos_token_id,
|
472 |
-
use_cache=True,
|
473 |
-
streamer=streamer
|
474 |
)
|
475 |
output = tokenizer.decode(res[0], skip_special_tokens=True)
|
476 |
-
# Remove everything up to the closing [/INST]
|
477 |
answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
|
478 |
|
479 |
-
# Clean up GPU memory
|
480 |
torch.cuda.empty_cache()
|
481 |
|
482 |
return info, answer
|
483 |
|
484 |
|
485 |
-
################################################################################
|
486 |
-
# Gradio Interface
|
487 |
-
################################################################################
|
488 |
-
|
489 |
demo = gr.Interface(
|
490 |
-
|
491 |
inputs=[
|
492 |
gr.Textbox(
|
493 |
label="Ticker",
|
494 |
value="AAPL",
|
495 |
-
info="
|
496 |
),
|
497 |
gr.Textbox(
|
498 |
label="Date",
|
499 |
value=get_curday,
|
500 |
-
info="Date from which the prediction is made, format yyyy-mm-dd"
|
501 |
),
|
502 |
gr.Slider(
|
503 |
minimum=1,
|
@@ -505,22 +292,25 @@ demo = gr.Interface(
|
|
505 |
value=3,
|
506 |
step=1,
|
507 |
label="n_weeks",
|
508 |
-
info="Information of the past n weeks will be utilized
|
509 |
),
|
510 |
gr.Checkbox(
|
511 |
label="Use Latest Basic Financials",
|
512 |
value=False,
|
513 |
-
info="If checked, the latest quarterly reported basic financials
|
514 |
)
|
515 |
],
|
516 |
outputs=[
|
517 |
-
gr.Textbox(
|
518 |
-
|
|
|
|
|
|
|
|
|
519 |
],
|
520 |
title="Pro Capital",
|
521 |
-
description="Pro Capital implementation
|
522 |
-
|
523 |
)
|
524 |
|
525 |
-
|
526 |
-
demo.launch()
|
|
|
3 |
import time
|
4 |
import json
|
5 |
import random
|
|
|
6 |
import finnhub
|
7 |
import torch
|
8 |
import gradio as gr
|
|
|
14 |
from datetime import date, datetime, timedelta
|
15 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
16 |
|
|
|
|
|
|
|
17 |
|
18 |
+
access_token = os.environ["HF_TOKEN"]
|
19 |
+
finnhub_client = finnhub.Client(api_key=os.environ["FINNHUB_API_KEY"])
|
|
|
|
|
20 |
|
|
|
|
|
|
|
21 |
base_model = AutoModelForCausalLM.from_pretrained(
|
22 |
'meta-llama/Llama-2-7b-chat-hf',
|
23 |
token=access_token,
|
|
|
40 |
|
41 |
streamer = TextStreamer(tokenizer)
|
42 |
|
|
|
43 |
B_INST, E_INST = "[INST]", "[/INST]"
|
44 |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
45 |
|
46 |
+
SYSTEM_PROMPT = "You are a seasoned stock market analyst. Your task is to list the positive developments and potential concerns for companies based on relevant news and basic financials from the past weeks, then provide an analysis and prediction for the companies' stock price movement for the upcoming week. " \
|
47 |
+
"Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n[Prediction & Analysis]\nPrediction: ...\nAnalysis: ..."
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
|
|
|
|
|
|
49 |
|
50 |
def print_gpu_utilization():
|
51 |
+
|
52 |
nvmlInit()
|
53 |
handle = nvmlDeviceGetHandleByIndex(0)
|
54 |
info = nvmlDeviceGetMemoryInfo(handle)
|
|
|
56 |
|
57 |
|
58 |
def get_curday():
|
59 |
+
|
60 |
return date.today().strftime("%Y-%m-%d")
|
61 |
|
62 |
|
63 |
def n_weeks_before(date_string, n):
|
64 |
+
|
65 |
+
date = datetime.strptime(date_string, "%Y-%m-%d") - timedelta(days=7*n)
|
66 |
+
|
67 |
+
return date.strftime("%Y-%m-%d")
|
68 |
|
69 |
|
70 |
def get_stock_data(stock_symbol, steps):
|
|
|
|
|
|
|
|
|
|
|
71 |
stock_data = yf.download(stock_symbol, steps[0], steps[-1])
|
72 |
if len(stock_data) == 0:
|
73 |
raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
|
74 |
+
|
75 |
+
print(stock_data)
|
76 |
|
|
|
|
|
77 |
dates, prices = [], []
|
78 |
available_dates = stock_data.index.format()
|
79 |
+
|
80 |
+
for date in steps[:-1]:
|
81 |
for i in range(len(stock_data)):
|
82 |
+
if available_dates[i] >= date:
|
83 |
+
prices.append(stock_data['Close'].iloc[i]) # Use .iloc here
|
84 |
dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
|
85 |
break
|
86 |
|
87 |
+
# Append the last date and price
|
88 |
dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
|
89 |
+
prices.append(stock_data['Close'].iloc[-1]) # Use .iloc here as well
|
90 |
|
91 |
return pd.DataFrame({
|
92 |
"Start Date": dates[:-1],
|
|
|
96 |
})
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def get_news(symbol, data):
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
+
news_list = []
|
102 |
+
|
103 |
+
for end_date, row in data.iterrows():
|
104 |
start_date = row['Start Date'].strftime('%Y-%m-%d')
|
105 |
end_date = row['End Date'].strftime('%Y-%m-%d')
|
106 |
+
# print(symbol, ': ', start_date, ' - ', end_date)
|
107 |
+
time.sleep(1) # control qpm
|
108 |
+
weekly_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
|
109 |
+
if len(weekly_news) == 0:
|
110 |
+
raise gr.Error(f"No company news found for symbol {symbol} from finnhub!")
|
111 |
+
weekly_news = [
|
112 |
+
{
|
113 |
+
"date": datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S'),
|
114 |
+
"headline": n['headline'],
|
115 |
+
"summary": n['summary'],
|
116 |
+
} for n in weekly_news
|
117 |
+
]
|
118 |
+
weekly_news.sort(key=lambda x: x['date'])
|
119 |
+
news_list.append(json.dumps(weekly_news))
|
120 |
+
|
121 |
+
data['News'] = news_list
|
122 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
return data
|
124 |
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
def get_company_prompt(symbol):
|
127 |
+
|
|
|
|
|
|
|
128 |
profile = finnhub_client.company_profile2(symbol=symbol)
|
129 |
if not profile:
|
130 |
raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
|
131 |
|
132 |
+
company_template = "[Company Introduction]:\n\n{name} is a leading entity in the {finnhubIndustry} sector. Incorporated and publicly traded since {ipo}, the company has established its reputation as one of the key players in the market. As of today, {name} has a market capitalization of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding." \
|
133 |
+
"\n\n{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive progress within the industry."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
formatted_str = company_template.format(**profile)
|
136 |
+
|
137 |
return formatted_str
|
138 |
|
139 |
|
140 |
def get_prompt_by_row(symbol, row):
|
|
|
|
|
|
|
|
|
141 |
end_price = float(row['End Price'])
|
142 |
start_price = float(row['Start Price'])
|
143 |
term = 'increased' if end_price > start_price else 'decreased'
|
144 |
|
145 |
+
start_date = row['Start Date'] if isinstance(row['Start Date'], str) else row['Start Date'].strftime('%Y-%m-%d')
|
146 |
+
end_date = row['End Date'] if isinstance(row['End Date'], str) else row['End Date'].strftime('%Y-%m-%d')
|
147 |
head = f"From {start_date} to {end_date}, {symbol}'s stock price {term} from {start_price:.2f} to {end_price:.2f}. Company news during this period are listed below:\n\n"
|
148 |
|
|
|
149 |
news = row["News"] if isinstance(row["News"], list) else json.loads(row["News"])
|
150 |
+
basics = json.loads(row['Basics'])
|
151 |
+
|
152 |
+
return head, news, basics
|
153 |
+
|
154 |
|
155 |
|
156 |
def sample_news(news, k=5):
|
157 |
+
|
|
|
|
|
158 |
return [news[i] for i in sorted(random.sample(range(len(news)), k))]
|
159 |
+
|
|
|
160 |
def latest_news(news, k=5):
|
|
|
161 |
sorted_news = sorted(news, key=lambda x: x['date'], reverse=True)
|
162 |
return sorted_news[:k]
|
163 |
|
164 |
|
165 |
+
def get_current_basics(symbol, curday):
|
166 |
+
|
167 |
+
basic_financials = finnhub_client.company_basic_financials(symbol, 'all')
|
168 |
+
if not basic_financials['series']:
|
169 |
+
raise gr.Error(f"Failed to find basic financials for symbol {symbol} from finnhub!")
|
170 |
+
|
171 |
+
final_basics, basic_list, basic_dict = [], [], defaultdict(dict)
|
172 |
+
|
173 |
+
for metric, value_list in basic_financials['series']['quarterly'].items():
|
174 |
+
for value in value_list:
|
175 |
+
basic_dict[value['period']].update({metric: value['v']})
|
176 |
+
|
177 |
+
for k, v in basic_dict.items():
|
178 |
+
v.update({'period': k})
|
179 |
+
basic_list.append(v)
|
180 |
+
|
181 |
+
basic_list.sort(key=lambda x: x['period'])
|
182 |
+
|
183 |
+
for basic in basic_list[::-1]:
|
184 |
+
if basic['period'] <= curday:
|
185 |
+
break
|
186 |
+
|
187 |
+
return basic
|
188 |
+
|
189 |
+
|
190 |
def get_all_prompts_online(symbol, data, curday, with_basics=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
company_prompt = get_company_prompt(symbol)
|
192 |
|
193 |
prev_rows = []
|
|
|
|
|
|
|
194 |
|
195 |
+
for row_idx, row in data.iterrows():
|
196 |
+
head, news, _ = get_prompt_by_row(symbol, row)
|
197 |
+
prev_rows.append((head, news, None))
|
198 |
+
|
199 |
+
prompt = ""
|
200 |
for i in range(-len(prev_rows), 0):
|
201 |
+
prompt += "\n" + prev_rows[i][0]
|
202 |
+
latest_news_items = latest_news(
|
203 |
+
prev_rows[i][1],
|
204 |
+
min(5, len(prev_rows[i][1]))
|
205 |
+
)
|
206 |
+
if latest_news_items:
|
207 |
+
# Ensure each news item is formatted as a string
|
208 |
+
formatted_news_items = ["[Headline]: {}\n[Summary]: {}\n".format(n['headline'], n['summary']) for n in latest_news_items]
|
209 |
+
prompt += "\n".join(formatted_news_items)
|
|
|
|
|
210 |
else:
|
211 |
+
prompt += "No relative news reported."
|
212 |
+
|
213 |
+
period = "{} to {}".format(curday, n_weeks_before(curday, -1))
|
214 |
+
|
|
|
215 |
if with_basics:
|
216 |
basics = get_current_basics(symbol, curday)
|
217 |
+
basics = "Some recent basic financials of {}, reported at {}, are presented below:\n\n[Basic Financials]:\n\n".format(
|
218 |
+
symbol, basics['period']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
else:
|
220 |
+
basics = "[Basic Financials]:\n\nNo basic financial reported."
|
221 |
|
222 |
+
info = company_prompt + '\n' + prompt + '\n' + basics
|
223 |
+
prompt = info + f"\n\nBased on all the information before {curday}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company related news. " \
|
224 |
+
f"Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction."
|
225 |
+
|
226 |
+
return info, prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
|
|
|
|
|
|
|
229 |
|
230 |
def construct_prompt(ticker, curday, n_weeks, use_basics):
|
231 |
+
|
|
|
|
|
|
|
|
|
|
|
232 |
try:
|
233 |
steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
|
234 |
except Exception:
|
235 |
raise gr.Error(f"Invalid date {curday}!")
|
236 |
+
|
|
|
237 |
data = get_stock_data(ticker, steps)
|
|
|
238 |
data = get_news(ticker, data)
|
|
|
239 |
data['Basics'] = [json.dumps({})] * len(data)
|
240 |
+
# print(data)
|
241 |
+
|
242 |
info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)
|
243 |
|
|
|
244 |
prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
|
245 |
+
# print(prompt)
|
246 |
|
247 |
return info, prompt
|
248 |
|
249 |
|
250 |
+
def predict(ticker, date, n_weeks, use_basics):
|
251 |
+
|
|
|
|
|
|
|
|
|
|
|
252 |
print_gpu_utilization()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
+
info, prompt = construct_prompt(ticker, date, n_weeks, use_basics)
|
255 |
+
|
256 |
+
inputs = tokenizer(
|
257 |
+
prompt, return_tensors='pt', padding=False
|
258 |
+
)
|
259 |
+
inputs = {key: value.to(model.device) for key, value in inputs.items()}
|
260 |
|
261 |
+
print("Inputs loaded onto devices.")
|
262 |
+
|
263 |
res = model.generate(
|
264 |
+
**inputs, max_length=4096, do_sample=False,
|
|
|
|
|
265 |
eos_token_id=tokenizer.eos_token_id,
|
266 |
+
use_cache=True, streamer=streamer
|
|
|
267 |
)
|
268 |
output = tokenizer.decode(res[0], skip_special_tokens=True)
|
|
|
269 |
answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
|
270 |
|
|
|
271 |
torch.cuda.empty_cache()
|
272 |
|
273 |
return info, answer
|
274 |
|
275 |
|
|
|
|
|
|
|
|
|
276 |
demo = gr.Interface(
|
277 |
+
predict,
|
278 |
inputs=[
|
279 |
gr.Textbox(
|
280 |
label="Ticker",
|
281 |
value="AAPL",
|
282 |
+
info="Companys from Dow-30 are recommended"
|
283 |
),
|
284 |
gr.Textbox(
|
285 |
label="Date",
|
286 |
value=get_curday,
|
287 |
+
info="Date from which the prediction is made, use format yyyy-mm-dd"
|
288 |
),
|
289 |
gr.Slider(
|
290 |
minimum=1,
|
|
|
292 |
value=3,
|
293 |
step=1,
|
294 |
label="n_weeks",
|
295 |
+
info="Information of the past n weeks will be utilized, choose between 1 and 4"
|
296 |
),
|
297 |
gr.Checkbox(
|
298 |
label="Use Latest Basic Financials",
|
299 |
value=False,
|
300 |
+
info="If checked, the latest quarterly reported basic financials of the company is taken into account."
|
301 |
)
|
302 |
],
|
303 |
outputs=[
|
304 |
+
gr.Textbox(
|
305 |
+
label="Information"
|
306 |
+
),
|
307 |
+
gr.Textbox(
|
308 |
+
label="Response"
|
309 |
+
)
|
310 |
],
|
311 |
title="Pro Capital",
|
312 |
+
description="""Pro Capital implementation.**
|
313 |
+
"""
|
314 |
)
|
315 |
|
316 |
+
demo.launch()
|
|