Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -41,68 +41,137 @@ The user provided the additional info about how they would like you to respond:
|
|
41 |
total_tokens_used = 0
|
42 |
|
43 |
def format_prompt(message, history):
|
44 |
-
prompt =
|
45 |
for user_prompt, bot_response in history:
|
46 |
prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> "
|
47 |
prompt += f"[INST] {message} [/INST]"
|
48 |
return prompt
|
49 |
|
50 |
-
stock_info = {
|
51 |
-
"AAPL": {'name': '애플', 'description': '아이폰을 주력으로 생산하는'},
|
52 |
-
"MSFT": {'name': '마이크로소프트', 'description': '윈도우 운영체제와 오피스 소프트웨어를'},
|
53 |
-
"AMZN": {'name': '아마존', 'description': '전자상거래 및 클라우드 서비스를'},
|
54 |
-
"GOOGL": {'name': '알파벳 (구글)', 'description': '검색 엔진 및 온라인 광고를'},
|
55 |
-
"TSLA": {'name': '테슬라', 'description': '전기자동차와 에너지 저장장치를'}
|
56 |
-
}
|
57 |
-
|
58 |
def get_stock_data(ticker):
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
return hist
|
63 |
-
except Exception as e:
|
64 |
-
return f"데이터를 불러오는 중 오류가 발생했습니다: {e}"
|
65 |
|
66 |
def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
|
67 |
global total_tokens_used
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
try:
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
else:
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
except Exception as e:
|
85 |
-
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
mychatbot = gr.Chatbot(
|
88 |
avatar_images=["./user.png", "./botm.png"],
|
89 |
bubble_full_width=False,
|
90 |
show_label=False,
|
91 |
show_copy_button=True,
|
92 |
-
likeable=True
|
93 |
)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
demo = gr.ChatInterface(
|
96 |
fn=generate,
|
97 |
chatbot=mychatbot,
|
98 |
title="글로벌 자산 분석 및 예측 LLM: BloombAI",
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
["요약 결론을 제시해", []],
|
104 |
-
["포트폴리오 분석해줘", []]
|
105 |
-
]
|
106 |
)
|
107 |
|
108 |
-
demo.launch(show_api=False)
|
|
|
41 |
total_tokens_used = 0
|
42 |
|
43 |
def format_prompt(message, history):
|
44 |
+
prompt = "<s>[SYSTEM] {} [/SYSTEM]".format(system_instruction)
|
45 |
for user_prompt, bot_response in history:
|
46 |
prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> "
|
47 |
prompt += f"[INST] {message} [/INST]"
|
48 |
return prompt
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def get_stock_data(ticker):
|
51 |
+
stock = yf.Ticker(ticker)
|
52 |
+
hist = stock.history(period="5d") # 지난 5일간의 주식 데이터를 가져옵니다.
|
53 |
+
return hist
|
|
|
|
|
|
|
54 |
|
55 |
def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
|
56 |
global total_tokens_used
|
57 |
+
input_tokens = len(tokenizer.encode(prompt))
|
58 |
+
total_tokens_used += input_tokens
|
59 |
+
available_tokens = 32768 - total_tokens_used
|
60 |
+
if available_tokens <= 0:
|
61 |
+
yield f"Error: 입력이 최대 허용 토큰 수를 초과합니다. Total tokens used: {total_tokens_used}"
|
62 |
+
return
|
63 |
+
|
64 |
+
formatted_prompt = format_prompt(prompt, history)
|
65 |
+
output_accumulated = ""
|
66 |
try:
|
67 |
+
# 티커 확인 및 데이터 수집
|
68 |
+
stock_info = get_stock_info(prompt) # 종목명을 토대로 티커 정보와 기업 설명을 가져옵니다.
|
69 |
+
if stock_info['ticker']:
|
70 |
+
response_msg = f"{stock_info['name']}은(는) {stock_info['description']} 주력으로 생산하는 기업입니다. {stock_info['name']}의 티커는 {stock_info['ticker']}입니다. 원하시는 종목이 맞는가요?"
|
71 |
+
output_accumulated += response_msg
|
72 |
+
yield output_accumulated
|
73 |
|
74 |
+
# 추가적인 분석 요청이 있다면, yfinance로 데이터 수집 및 분석
|
75 |
+
stock_data = get_stock_data(stock_info['ticker']) # 티커를 이용해 주식 데이터를 가져옵니다.
|
76 |
+
stream = client.text_generation(
|
77 |
+
formatted_prompt,
|
78 |
+
temperature=temperature,
|
79 |
+
max_new_tokens=min(max_new_tokens, available_tokens),
|
80 |
+
top_p=top_p,
|
81 |
+
repetition_penalty=repetition_penalty,
|
82 |
+
do_sample=True,
|
83 |
+
seed=42,
|
84 |
+
stream=True
|
85 |
+
)
|
86 |
+
for response in stream:
|
87 |
+
output_part = response['generated_text'] if 'generated_text' in response else str(response)
|
88 |
+
output_accumulated += output_part
|
89 |
+
yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
|
90 |
else:
|
91 |
+
# 입력이 티커인 경우 처리
|
92 |
+
ticker = prompt.upper()
|
93 |
+
if ticker in ['AAPL', 'MSFT', 'AMZN', 'GOOGL', 'TSLA']:
|
94 |
+
stock_info = get_stock_info_by_ticker(ticker)
|
95 |
+
response_msg = f"{stock_info['name']}은(는) {stock_info['description']} 주력으로 생산하는 기업입니다. {stock_info['name']}의 티커는 {stock_info['ticker']}입니다. 원하시는 종목이 맞는가요?"
|
96 |
+
output_accumulated += response_msg
|
97 |
+
yield output_accumulated
|
98 |
+
|
99 |
+
# 추가적인 분석 요청이 있다면, yfinance로 데이터 수집 및 분석
|
100 |
+
stock_data = get_stock_data(stock_info['ticker']) # 티커를 ��용해 주식 데이터를 가져옵니다.
|
101 |
+
stream = client.text_generation(
|
102 |
+
formatted_prompt,
|
103 |
+
temperature=temperature,
|
104 |
+
max_new_tokens=min(max_new_tokens, available_tokens),
|
105 |
+
top_p=top_p,
|
106 |
+
repetition_penalty=repetition_penalty,
|
107 |
+
do_sample=True,
|
108 |
+
seed=42,
|
109 |
+
stream=True
|
110 |
+
)
|
111 |
+
for response in stream:
|
112 |
+
output_part = response['generated_text'] if 'generated_text' in response else str(response)
|
113 |
+
output_accumulated += output_part
|
114 |
+
yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
|
115 |
+
else:
|
116 |
+
yield f"입력하신 '{prompt}'은(는) 지원되는 종목명 또는 티커가 아닙니다. 현재 지원되는 종목은 애플(AAPL), 마이크로소프트(MSFT), 아마존(AMZN), 알파벳(GOOGL), 테슬라(TSLA) 등입니다. 정확한 종목명 또는 티커를 입력해주세요."
|
117 |
except Exception as e:
|
118 |
+
yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
|
119 |
|
120 |
+
# 티커를 토대로 종목 정보를 제공하는 함수
|
121 |
+
def get_stock_info_by_ticker(ticker):
|
122 |
+
stock_info = {
|
123 |
+
"AAPL": {'ticker': 'AAPL', 'name': '애플', 'description': '아이폰을'},
|
124 |
+
"MSFT": {'ticker': 'MSFT', 'name': '마이크로소프트', 'description': '윈도우 운영체제와 오피스 소프트웨어를'},
|
125 |
+
"AMZN": {'ticker': 'AMZN', 'name': '아마존', 'description': '전자상거래 및 클라우드 서비스를'},
|
126 |
+
"GOOGL": {'ticker': 'GOOGL', 'name': '알파벳', 'description': '검색 엔진 및 온라인 광고를'},
|
127 |
+
"TSLA": {'ticker': 'TSLA', 'name': '테슬라', 'description': '전기자동차와 에너지 저장장치를'},
|
128 |
+
}
|
129 |
+
return stock_info.get(ticker, {'ticker': None, 'name': None, 'description': ''})
|
130 |
+
|
131 |
+
# 종목명을 토대로 티커와 기업 정보를 제공하는 함수
|
132 |
+
def get_stock_info(name):
|
133 |
+
stock_info = {
|
134 |
+
"apple": {'ticker': 'AAPL', 'name': '애플', 'description': '아이폰을'},
|
135 |
+
"microsoft": {'ticker': 'MSFT', 'name': '마이크로소프트', 'description': '윈도우 운영체제와 오피스 소프트웨어를'},
|
136 |
+
"amazon": {'ticker': 'AMZN', 'name': '아마존', 'description': '전자상거래 및 클라우드 서비스를'},
|
137 |
+
"google": {'ticker': 'GOOGL', 'name': '알파벳 (구글)', 'description': '검색 엔진 및 온라인 광고를'},
|
138 |
+
"tesla": {'ticker': 'TSLA', 'name': '테슬라', 'description': '전기자동차와 에너지 저장장치를'},
|
139 |
+
# 추가적인 종목에 대한 정보를 이곳에 구현할 수 있습니다.
|
140 |
+
}
|
141 |
+
return stock_info.get(name.lower(), {'ticker': None, 'name': name, 'description': ''})
|
142 |
+
|
143 |
mychatbot = gr.Chatbot(
|
144 |
avatar_images=["./user.png", "./botm.png"],
|
145 |
bubble_full_width=False,
|
146 |
show_label=False,
|
147 |
show_copy_button=True,
|
148 |
+
likeable=True,
|
149 |
)
|
150 |
|
151 |
+
examples = [
|
152 |
+
["반드시 한글로 답변할것.", []],
|
153 |
+
["좋은 종목(티커) 추천해줘", []],
|
154 |
+
["요약 결론을 제시해", []],
|
155 |
+
["포트폴리오 분석해줘", []]
|
156 |
+
]
|
157 |
+
|
158 |
+
css = """
|
159 |
+
h1 {
|
160 |
+
font-size: 14px;
|
161 |
+
}
|
162 |
+
footer {
|
163 |
+
visibility: hidden;
|
164 |
+
}
|
165 |
+
"""
|
166 |
+
|
167 |
demo = gr.ChatInterface(
|
168 |
fn=generate,
|
169 |
chatbot=mychatbot,
|
170 |
title="글로벌 자산 분석 및 예측 LLM: BloombAI",
|
171 |
+
retry_btn=None,
|
172 |
+
undo_btn=None,
|
173 |
+
css=css,
|
174 |
+
examples=examples
|
|
|
|
|
|
|
175 |
)
|
176 |
|
177 |
+
demo.queue().launch(show_api=False)
|