Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import re
|
|
3 |
import time
|
4 |
import json
|
5 |
import random
|
|
|
6 |
import finnhub
|
7 |
import torch
|
8 |
import gradio as gr
|
@@ -14,10 +15,18 @@ from collections import defaultdict
|
|
14 |
from datetime import date, datetime, timedelta
|
15 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
16 |
|
|
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
|
|
|
|
|
|
|
21 |
base_model = AutoModelForCausalLM.from_pretrained(
|
22 |
'meta-llama/Llama-2-7b-chat-hf',
|
23 |
token=access_token,
|
@@ -40,15 +49,25 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
40 |
|
41 |
streamer = TextStreamer(tokenizer)
|
42 |
|
|
|
43 |
B_INST, E_INST = "[INST]", "[/INST]"
|
44 |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
45 |
|
46 |
-
SYSTEM_PROMPT =
|
47 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
|
|
|
|
|
|
49 |
|
50 |
def print_gpu_utilization():
|
51 |
-
|
52 |
nvmlInit()
|
53 |
handle = nvmlDeviceGetHandleByIndex(0)
|
54 |
info = nvmlDeviceGetMemoryInfo(handle)
|
@@ -56,37 +75,41 @@ def print_gpu_utilization():
|
|
56 |
|
57 |
|
58 |
def get_curday():
|
59 |
-
|
60 |
return date.today().strftime("%Y-%m-%d")
|
61 |
|
62 |
|
63 |
def n_weeks_before(date_string, n):
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
return date.strftime("%Y-%m-%d")
|
68 |
|
69 |
|
70 |
def get_stock_data(stock_symbol, steps):
|
|
|
|
|
|
|
|
|
|
|
71 |
stock_data = yf.download(stock_symbol, steps[0], steps[-1])
|
72 |
if len(stock_data) == 0:
|
73 |
raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
|
74 |
-
|
75 |
-
print(stock_data)
|
76 |
|
|
|
|
|
77 |
dates, prices = [], []
|
78 |
available_dates = stock_data.index.format()
|
79 |
-
|
80 |
-
for
|
81 |
for i in range(len(stock_data)):
|
82 |
-
if available_dates[i] >=
|
83 |
-
prices.append(stock_data['Close'].iloc[i])
|
84 |
dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
|
85 |
break
|
86 |
|
87 |
-
# Append
|
88 |
dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
|
89 |
-
prices.append(stock_data['Close'].iloc[-1])
|
90 |
|
91 |
return pd.DataFrame({
|
92 |
"Start Date": dates[:-1],
|
@@ -96,195 +119,385 @@ def get_stock_data(stock_symbol, steps):
|
|
96 |
})
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def get_news(symbol, data):
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
for end_date, row in data.iterrows():
|
104 |
start_date = row['Start Date'].strftime('%Y-%m-%d')
|
105 |
end_date = row['End Date'].strftime('%Y-%m-%d')
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
return data
|
124 |
|
125 |
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
|
|
|
|
|
|
|
|
|
|
128 |
profile = finnhub_client.company_profile2(symbol=symbol)
|
129 |
if not profile:
|
130 |
raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
|
131 |
|
132 |
-
company_template =
|
133 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
formatted_str = company_template.format(**profile)
|
136 |
-
|
137 |
return formatted_str
|
138 |
|
139 |
|
140 |
def get_prompt_by_row(symbol, row):
|
|
|
|
|
|
|
|
|
141 |
end_price = float(row['End Price'])
|
142 |
start_price = float(row['Start Price'])
|
143 |
term = 'increased' if end_price > start_price else 'decreased'
|
144 |
|
145 |
-
start_date = row['Start Date']
|
146 |
-
end_date = row['End Date']
|
147 |
head = f"From {start_date} to {end_date}, {symbol}'s stock price {term} from {start_price:.2f} to {end_price:.2f}. Company news during this period are listed below:\n\n"
|
148 |
|
|
|
149 |
news = row["News"] if isinstance(row["News"], list) else json.loads(row["News"])
|
150 |
-
|
151 |
-
|
152 |
-
return head, news, basics
|
153 |
-
|
154 |
|
155 |
|
156 |
def sample_news(news, k=5):
|
157 |
-
|
|
|
|
|
158 |
return [news[i] for i in sorted(random.sample(range(len(news)), k))]
|
159 |
-
|
|
|
160 |
def latest_news(news, k=5):
|
|
|
161 |
sorted_news = sorted(news, key=lambda x: x['date'], reverse=True)
|
162 |
return sorted_news[:k]
|
163 |
|
164 |
|
165 |
-
def get_current_basics(symbol, curday):
|
166 |
-
|
167 |
-
basic_financials = finnhub_client.company_basic_financials(symbol, 'all')
|
168 |
-
if not basic_financials['series']:
|
169 |
-
raise gr.Error(f"Failed to find basic financials for symbol {symbol} from finnhub!")
|
170 |
-
|
171 |
-
final_basics, basic_list, basic_dict = [], [], defaultdict(dict)
|
172 |
-
|
173 |
-
for metric, value_list in basic_financials['series']['quarterly'].items():
|
174 |
-
for value in value_list:
|
175 |
-
basic_dict[value['period']].update({metric: value['v']})
|
176 |
-
|
177 |
-
for k, v in basic_dict.items():
|
178 |
-
v.update({'period': k})
|
179 |
-
basic_list.append(v)
|
180 |
-
|
181 |
-
basic_list.sort(key=lambda x: x['period'])
|
182 |
-
|
183 |
-
for basic in basic_list[::-1]:
|
184 |
-
if basic['period'] <= curday:
|
185 |
-
break
|
186 |
-
|
187 |
-
return basic
|
188 |
-
|
189 |
-
|
190 |
def get_all_prompts_online(symbol, data, curday, with_basics=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
company_prompt = get_company_prompt(symbol)
|
192 |
|
193 |
prev_rows = []
|
194 |
-
|
195 |
-
for row_idx, row in data.iterrows():
|
196 |
head, news, _ = get_prompt_by_row(symbol, row)
|
197 |
-
prev_rows.append((head, news
|
198 |
-
|
199 |
-
|
|
|
200 |
for i in range(-len(prev_rows), 0):
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
210 |
else:
|
211 |
-
|
212 |
-
|
213 |
-
period = "{} to {
|
214 |
-
|
|
|
215 |
if with_basics:
|
216 |
basics = get_current_basics(symbol, curday)
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
else:
|
220 |
-
|
221 |
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
|
|
|
|
|
|
|
229 |
|
230 |
def construct_prompt(ticker, curday, n_weeks, use_basics):
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
232 |
try:
|
233 |
steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
|
234 |
except Exception:
|
235 |
raise gr.Error(f"Invalid date {curday}!")
|
236 |
-
|
|
|
237 |
data = get_stock_data(ticker, steps)
|
|
|
238 |
data = get_news(ticker, data)
|
|
|
239 |
data['Basics'] = [json.dumps({})] * len(data)
|
240 |
-
|
241 |
-
|
242 |
-
info, prompt = get_all_prompts_online(ticker, data, curday,
|
243 |
|
|
|
244 |
prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
|
245 |
-
# print(prompt)
|
246 |
|
247 |
return info, prompt
|
248 |
|
249 |
|
250 |
-
def predict(ticker,
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
252 |
print_gpu_utilization()
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
)
|
259 |
inputs = {key: value.to(model.device) for key, value in inputs.items()}
|
260 |
|
261 |
-
print("Inputs loaded onto
|
262 |
-
|
|
|
263 |
res = model.generate(
|
264 |
-
**inputs,
|
|
|
|
|
265 |
eos_token_id=tokenizer.eos_token_id,
|
266 |
-
use_cache=True,
|
|
|
267 |
)
|
268 |
output = tokenizer.decode(res[0], skip_special_tokens=True)
|
|
|
269 |
answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
|
270 |
|
|
|
271 |
torch.cuda.empty_cache()
|
272 |
|
273 |
return info, answer
|
274 |
|
275 |
|
|
|
|
|
|
|
|
|
276 |
demo = gr.Interface(
|
277 |
-
predict,
|
278 |
inputs=[
|
279 |
gr.Textbox(
|
280 |
label="Ticker",
|
281 |
value="AAPL",
|
282 |
-
info="
|
283 |
),
|
284 |
gr.Textbox(
|
285 |
label="Date",
|
286 |
value=get_curday,
|
287 |
-
info="Date from which the prediction is made,
|
288 |
),
|
289 |
gr.Slider(
|
290 |
minimum=1,
|
@@ -292,25 +505,22 @@ demo = gr.Interface(
|
|
292 |
value=3,
|
293 |
step=1,
|
294 |
label="n_weeks",
|
295 |
-
info="Information of the past n weeks will be utilized
|
296 |
),
|
297 |
gr.Checkbox(
|
298 |
label="Use Latest Basic Financials",
|
299 |
value=False,
|
300 |
-
info="If checked, the latest quarterly reported basic financials
|
301 |
)
|
302 |
],
|
303 |
outputs=[
|
304 |
-
gr.Textbox(
|
305 |
-
|
306 |
-
),
|
307 |
-
gr.Textbox(
|
308 |
-
label="Response"
|
309 |
-
)
|
310 |
],
|
311 |
title="Pro Capital",
|
312 |
-
description="
|
313 |
-
""
|
314 |
)
|
315 |
|
316 |
-
|
|
|
|
3 |
import time
|
4 |
import json
|
5 |
import random
|
6 |
+
import requests # <-- For Polygon API calls
|
7 |
import finnhub
|
8 |
import torch
|
9 |
import gradio as gr
|
|
|
15 |
from datetime import date, datetime, timedelta
|
16 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
17 |
|
18 |
+
################################################################################
|
19 |
+
# Set up environment variables, tokens, model, tokenizer, and other base config
|
20 |
+
################################################################################
|
21 |
|
22 |
+
# Make sure these environment variables or your chosen tokens are set properly.
access_token = os.environ.get("HF_TOKEN", "YOUR_HF_ACCESS_TOKEN")
finnhub_api_key = os.environ.get("FINNHUB_API_KEY", "YOUR_FINNHUB_API_KEY")
# SECURITY FIX: the Polygon key was previously committed as a literal in this
# file. Never keep live credentials in source control — read it from the
# environment like the other tokens, and rotate the leaked key.
polygon_api_key = os.environ.get("POLYGON_API_KEY", "YOUR_POLYGON_API_KEY")

finnhub_client = finnhub.Client(api_key=finnhub_api_key)
|
28 |
+
|
29 |
+
# Load base model & LoRA
|
30 |
base_model = AutoModelForCausalLM.from_pretrained(
|
31 |
'meta-llama/Llama-2-7b-chat-hf',
|
32 |
token=access_token,
|
|
|
49 |
|
50 |
streamer = TextStreamer(tokenizer)
|
51 |
|
52 |
+
# Special Llama-format prompt tokens
|
53 |
B_INST, E_INST = "[INST]", "[/INST]"
|
54 |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
55 |
|
56 |
+
# System instruction injected into every Llama-2 prompt: fixes the analyst
# persona and the exact answer layout the model must produce.
SYSTEM_PROMPT = (
    "You are a seasoned stock market analyst. "
    "Your task is to list the positive developments and potential concerns for companies "
    "based on relevant news and basic financials from the past weeks, "
    "then provide an analysis and prediction for the companies' stock price movement "
    "for the upcoming week. Your answer format should be as follows:\n\n"
    "[Positive Developments]:\n1. ...\n\n"
    "[Potential Concerns]:\n1. ...\n\n"
    "[Prediction & Analysis]\nPrediction: ...\nAnalysis: ..."
)
|
63 |
+
|
64 |
|
65 |
+
###############################################################################
|
66 |
+
# Utility functions
|
67 |
+
###############################################################################
|
68 |
|
69 |
def print_gpu_utilization():
|
70 |
+
"""Helper to print GPU utilization (MB) using NVML."""
|
71 |
nvmlInit()
|
72 |
handle = nvmlDeviceGetHandleByIndex(0)
|
73 |
info = nvmlDeviceGetMemoryInfo(handle)
|
|
|
75 |
|
76 |
|
77 |
def get_curday():
    """Return today's date as a ``YYYY-MM-DD`` string."""
    today = date.today()
    return today.strftime("%Y-%m-%d")
|
80 |
|
81 |
|
82 |
def n_weeks_before(date_string, n):
    """Return the ``YYYY-MM-DD`` date that is *n* weeks before *date_string*.

    Negative *n* moves forward in time (used by the caller to compute the
    prediction week's end date).
    """
    anchor = datetime.strptime(date_string, "%Y-%m-%d")
    shifted = anchor - timedelta(weeks=n)
    return shifted.strftime("%Y-%m-%d")
|
|
|
86 |
|
87 |
|
88 |
def get_stock_data(stock_symbol, steps):
|
89 |
+
"""
|
90 |
+
Downloads stock price data using yfinance for the given date steps.
|
91 |
+
Returns a DataFrame containing Start Date, End Date, Start Price, End Price
|
92 |
+
for each of the time intervals.
|
93 |
+
"""
|
94 |
stock_data = yf.download(stock_symbol, steps[0], steps[-1])
|
95 |
if len(stock_data) == 0:
|
96 |
raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
|
|
|
|
|
97 |
|
98 |
+
print(stock_data) # For debugging
|
99 |
+
|
100 |
dates, prices = [], []
|
101 |
available_dates = stock_data.index.format()
|
102 |
+
|
103 |
+
for date_ in steps[:-1]:
|
104 |
for i in range(len(stock_data)):
|
105 |
+
if available_dates[i] >= date_:
|
106 |
+
prices.append(stock_data['Close'].iloc[i])
|
107 |
dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
|
108 |
break
|
109 |
|
110 |
+
# Append last date & price
|
111 |
dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
|
112 |
+
prices.append(stock_data['Close'].iloc[-1])
|
113 |
|
114 |
return pd.DataFrame({
|
115 |
"Start Date": dates[:-1],
|
|
|
119 |
})
|
120 |
|
121 |
|
122 |
+
###############################################################################
|
123 |
+
# News retrieval
|
124 |
+
###############################################################################
|
125 |
+
|
126 |
+
def parse_polygon_news_item(item):
    """
    Convert a Polygon news item into a ``{date, headline, summary}`` dict
    matching the Finnhub-style format used elsewhere in this module.

    ``published_utc`` is ISO8601 (e.g. '2021-04-23T12:47:00Z'); it is
    re-encoded as a 'YYYYmmddHHMMSS' string so items from both providers
    sort together lexicographically.
    """
    published_str = item.get('published_utc', '')
    try:
        dt = datetime.strptime(published_str, "%Y-%m-%dT%H:%M:%SZ")
        date_fmt = dt.strftime("%Y%m%d%H%M%S")
    except (TypeError, ValueError):
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Only parsing failures should fall back to the
        # raw value.
        date_fmt = published_str

    return {
        "date": date_fmt,
        "headline": item.get("title", ""),
        "summary": item.get("description", ""),
    }
|
147 |
+
|
148 |
+
|
149 |
def get_news(symbol, data):
    """
    For each weekly interval (row) in *data*, fetch company news from both
    Finnhub and Polygon, merge the two lists, sort them chronologically, and
    store each week's JSON-encoded list in ``data['News']``.

    Returns the mutated DataFrame.
    """
    combined_news_list = []

    for idx, row in data.iterrows():
        start_date = row['Start Date'].strftime('%Y-%m-%d')
        end_date = row['End Date'].strftime('%Y-%m-%d')

        # Sleep for QPM (rate-limit) control if needed
        time.sleep(1)

        #######################################################################
        # 1) Finnhub News for the weekly period
        #######################################################################
        finnhub_weekly_news = []
        try:
            finnhub_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
            for n in finnhub_news:
                dt_str = datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S')
                finnhub_weekly_news.append({
                    "date": dt_str,
                    "headline": n['headline'],
                    "summary": n['summary']
                })
        except Exception:
            # Best-effort: if Finnhub errors out or returns nothing usable,
            # continue with Polygon news only.  (Removed unused `as e`.)
            finnhub_weekly_news = []

        #######################################################################
        # 2) Polygon News
        # The news endpoint takes no date-range parameter here, so fetch up to
        # 30 items and filter them by start_date <= published <= end_date.
        #######################################################################
        polygon_weekly_news = []
        polygon_url = (
            f"https://api.polygon.io/v2/reference/news"
            f"?ticker={symbol}"
            f"&order=asc"
            f"&limit=30"
            f"&sort=published_utc"
            f"&apiKey={polygon_api_key}"
        )
        try:
            resp = requests.get(polygon_url)
            if resp.status_code == 200:
                polygon_data = resp.json()
                results = polygon_data.get('results', [])
                for item in results:
                    news_item = parse_polygon_news_item(item)
                    # news_item['date'] is "YYYYmmddHHMMSS"; parse it and the
                    # day-based boundaries so the range comparison is valid.
                    try:
                        dt_item = datetime.strptime(news_item['date'], "%Y%m%d%H%M%S")
                        dt_start = datetime.strptime(start_date, "%Y-%m-%d")
                        dt_end = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1) - timedelta(seconds=1)
                        if dt_start <= dt_item <= dt_end:
                            polygon_weekly_news.append(news_item)
                    except (TypeError, ValueError):
                        # BUG FIX: was a bare `except:`; only skip items whose
                        # timestamp genuinely failed to parse.
                        pass
            else:
                print(f"Polygon news request returned status code {resp.status_code}")
        except Exception as e:
            print(f"Polygon news request failed with exception: {e}")
            polygon_weekly_news = []

        # Merge the two news lists and sort by the sortable date string
        weekly_news_combined = finnhub_weekly_news + polygon_weekly_news
        weekly_news_combined.sort(key=lambda x: x['date'])

        # If both sources are empty we deliberately continue gracefully with
        # an empty list rather than raising.
        combined_news_list.append(json.dumps(weekly_news_combined))

    data['News'] = combined_news_list
    return data
|
235 |
|
236 |
|
237 |
+
###############################################################################
|
238 |
+
# Polygon Financials
|
239 |
+
###############################################################################
|
240 |
+
|
241 |
+
def get_current_basics(symbol, curday):
    """
    Fetch the single most recent financials filing for *symbol* from Polygon
    (order=desc, limit=1) and condense it into a small dict of metrics.
    Replaces the old Finnhub-based approach.

    Raises gr.Error when the request fails or returns no results.
    """
    request_url = (
        f"https://api.polygon.io/vX/reference/financials"
        f"?ticker={symbol}"
        f"&order=desc"
        f"&limit=1"
        f"&sort=filing_date"
        f"&apiKey={polygon_api_key}"
    )
    response = requests.get(request_url)
    if response.status_code != 200:
        raise gr.Error(f"Failed to retrieve financial data from Polygon! status={response.status_code}")

    payload = response.json()
    if 'results' not in payload or len(payload['results']) == 0:
        raise gr.Error(f"No financial results found from Polygon for ticker {symbol}!")

    latest = payload['results'][0]

    # Pull out the filing metadata and the nested statement values we care about.
    filing_date = latest.get('filing_date', '')
    fiscal_year = latest.get('fiscal_year', 'N/A')
    fiscal_period = latest.get('fiscal_period', 'N/A')
    statements = latest.get('financials', {})

    def _value_at(tree, path):
        """Walk *path* through nested dicts; return the leaf's 'value' or None."""
        node = tree
        for key in path:
            if key not in node:
                return None
            node = node[key]
        return node.get("value", None)

    # Condensed dictionary referenced later when building the user prompt.
    return {
        "period": f"{fiscal_year} {fiscal_period}",
        "filing_date": filing_date,
        "Revenue": _value_at(statements, ["income_statement", "revenues"]),
        "NetIncomeLoss": _value_at(statements, ["income_statement", "net_income_loss_attributable_to_parent"]),
        "DilutedEPS": _value_at(statements, ["income_statement", "diluted_earnings_per_share"]),
        "BasicEPS": _value_at(statements, ["income_statement", "basic_earnings_per_share"]),
    }
|
293 |
+
|
294 |
+
|
295 |
+
###############################################################################
|
296 |
+
# Prompt engineering
|
297 |
+
###############################################################################
|
298 |
|
299 |
+
def get_company_prompt(symbol):
    """
    Fetch the Finnhub company profile for *symbol* and render it into a short
    '[Company Introduction]' paragraph for the prompt.

    Raises gr.Error when Finnhub has no profile for the symbol.
    """
    profile = finnhub_client.company_profile2(symbol=symbol)
    if not profile:
        raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")

    # Template fields map 1:1 onto keys of the Finnhub profile dict.
    intro_template = (
        "[Company Introduction]:\n\n{name} is a leading entity in the {finnhubIndustry} sector. "
        "Incorporated and publicly traded since {ipo}, the company has established its reputation "
        "as one of the key players in the market. As of today, {name} has a market capitalization "
        "of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding.\n\n"
        "{name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. "
        "As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive "
        "progress within the industry."
    )
    return intro_template.format(**profile)
|
320 |
|
321 |
|
322 |
def get_prompt_by_row(symbol, row):
    """
    Summarise one weekly interval of *row* for the prompt.

    Returns ``(head, news, None)``: *head* describes the price move over the
    interval, *news* is the decoded list of news dicts.  The third slot is
    always ``None`` — per-row basics are no longer used, since the latest
    basics are fetched once at the end instead.
    """
    price_open = float(row['Start Price'])
    price_close = float(row['End Price'])
    direction = 'increased' if price_close > price_open else 'decreased'

    day_first = row['Start Date'].strftime('%Y-%m-%d')
    day_last = row['End Date'].strftime('%Y-%m-%d')
    head = (
        f"From {day_first} to {day_last}, {symbol}'s stock price {direction} "
        f"from {price_open:.2f} to {price_close:.2f}. Company news during this "
        f"period are listed below:\n\n"
    )

    # row["News"] holds a JSON string unless it was already decoded upstream.
    raw_news = row["News"]
    news = raw_news if isinstance(raw_news, list) else json.loads(raw_news)
    return head, news, None
|
|
|
|
|
339 |
|
340 |
|
341 |
def sample_news(news, k=5):
    """Return up to *k* randomly chosen items from *news*, original order kept."""
    if len(news) <= k:
        return news
    picked_indices = random.sample(range(len(news)), k)
    return [news[idx] for idx in sorted(picked_indices)]
|
346 |
+
|
347 |
+
|
348 |
def latest_news(news, k=5):
    """Return the *k* most recently dated items of *news*, newest first."""
    by_recency = sorted(news, key=lambda item: item['date'], reverse=True)
    return by_recency[:k]
|
352 |
|
353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
def get_all_prompts_online(symbol, data, curday, with_basics=True):
    """
    Assemble the model prompt from three pieces:
      - the company introduction (Finnhub profile),
      - the weekly history intervals in *data* (price move + news),
      - the latest basic financials from Polygon (when *with_basics*).

    Returns ``(info_block, final_prompt)``.
    """
    company_prompt = get_company_prompt(symbol)

    intervals = []
    for _, row in data.iterrows():
        head, news, _ = get_prompt_by_row(symbol, row)
        intervals.append((head, news))

    # Describe each past interval: its price-move header followed by up to
    # five of the most recent news items (or a placeholder when none exist).
    segments = []
    for head, news_list in intervals:
        segments.append("\n" + head)
        chosen_news = latest_news(news_list, min(5, len(news_list)))
        if chosen_news:
            segments.append("\n".join(
                "[Headline]: {}\n[Summary]: {}\n".format(n['headline'], n['summary'])
                for n in chosen_news
            ))
        else:
            segments.append("No relative news reported.")
    prompt_body = "".join(segments)

    # The week we are asking the model to predict (curday .. curday + 1 week).
    period = f"{curday} to {n_weeks_before(curday, -1)}"

    if with_basics:
        basics = get_current_basics(symbol, curday)
        metric_lines = [
            f"{name}: {value}\n"
            for name, value in basics.items()
            if name not in ["period", "filing_date"]
        ]
        basics_text = (
            f"Some recent basic financials of {symbol}, reported at {basics['filing_date']} "
            f"(period: {basics['period']}), are presented below:\n\n[Basic Financials]:\n\n"
        ) + "".join(metric_lines)
    else:
        basics_text = "[Basic Financials]:\n\nNo basic financial reported."

    info_block = company_prompt + '\n' + prompt_body + '\n' + basics_text

    # Final instruction appended after all gathered information.
    final_prompt = (
        info_block
        + f"\n\nBased on all the information before {curday}, let's first analyze the positive developments "
        f"and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and "
        f"keep them concise. Most factors should be inferred from company-related news. Then make your "
        f"prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis "
        f"to support your prediction."
    )

    return info_block, final_prompt
|
414 |
|
415 |
|
416 |
+
###############################################################################
|
417 |
+
# Gradio pipeline
|
418 |
+
###############################################################################
|
419 |
|
420 |
def construct_prompt(ticker, curday, n_weeks, use_basics):
    """
    Build the final prompt to feed into the model by collecting:
      1) Stock data from yfinance (past n_weeks intervals)
      2) News from Finnhub & Polygon
      3) Current basics from Polygon (if *use_basics*)

    Returns ``(info, prompt)`` where *prompt* is wrapped in the Llama-2
    [INST]/<<SYS>> chat markers.
    """
    try:
        steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
    except Exception:
        raise gr.Error(f"Invalid date {curday}!")

    # 1) Stock data
    data = get_stock_data(ticker, steps)
    # 2) News data (from Finnhub & Polygon)
    data = get_news(ticker, data)
    # We don't store the basics in each row anymore, so just place empty dict
    data['Basics'] = [json.dumps({})] * len(data)

    # 3) Construct final prompt.
    # BUG FIX: this previously passed the undefined name `with_basics`,
    # raising NameError on every call — the function parameter is `use_basics`.
    info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)

    # Format with system instructions for Llama
    prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST

    return info, prompt
|
446 |
|
447 |
|
448 |
+
def predict(ticker, date_, n_weeks, use_basics):
    """
    Gradio entry point.

    Builds the prompt, runs Llama-2 + LoRA generation, and returns the
    gathered information block together with the model's cleaned answer.
    """
    print_gpu_utilization()

    # Gather prices/news/financials and wrap them in the Llama chat template.
    info, prompt = construct_prompt(ticker, date_, n_weeks, use_basics)

    # Tokenize and move every tensor onto the model's device.
    model_inputs = tokenizer(prompt, return_tensors='pt', padding=False)
    model_inputs = {name: tensor.to(model.device) for name, tensor in model_inputs.items()}

    print("Inputs loaded onto device...")

    generated = model.generate(
        **model_inputs,
        max_length=4096,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        streamer=streamer
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Keep only the text after the closing [/INST] marker.
    answer = re.sub(r'.*\[/INST\]\s*', '', decoded, flags=re.DOTALL)

    # Release cached GPU memory between requests.
    torch.cuda.empty_cache()

    return info, answer
|
483 |
|
484 |
|
485 |
+
################################################################################
|
486 |
+
# Gradio Interface
|
487 |
+
################################################################################
|
488 |
+
|
489 |
demo = gr.Interface(
|
490 |
+
fn=predict,
|
491 |
inputs=[
|
492 |
gr.Textbox(
|
493 |
label="Ticker",
|
494 |
value="AAPL",
|
495 |
+
info="Companies from Dow-30 are recommended"
|
496 |
),
|
497 |
gr.Textbox(
|
498 |
label="Date",
|
499 |
value=get_curday,
|
500 |
+
info="Date from which the prediction is made, format yyyy-mm-dd"
|
501 |
),
|
502 |
gr.Slider(
|
503 |
minimum=1,
|
|
|
505 |
value=3,
|
506 |
step=1,
|
507 |
label="n_weeks",
|
508 |
+
info="Information of the past n weeks will be utilized. Choose between 1 and 4."
|
509 |
),
|
510 |
gr.Checkbox(
|
511 |
label="Use Latest Basic Financials",
|
512 |
value=False,
|
513 |
+
info="If checked, the latest quarterly reported basic financials from Polygon are taken into account."
|
514 |
)
|
515 |
],
|
516 |
outputs=[
|
517 |
+
gr.Textbox(label="Information"),
|
518 |
+
gr.Textbox(label="Response")
|
|
|
|
|
|
|
|
|
519 |
],
|
520 |
title="Pro Capital",
|
521 |
+
description="Pro Capital implementation using yfinance for historical prices, "
|
522 |
+
"Finnhub + Polygon for news, and Polygon for financials."
|
523 |
)
|
524 |
|
525 |
+
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|