Spaces:
Sleeping
Sleeping
from huggingface_hub import InferenceClient | |
import gradio as gr | |
from transformers import GPT2Tokenizer | |
import yfinance as yf | |
import time | |
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1") | |
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
# ์์คํ ์ธ์คํธ๋ญ์ ์ ์ค์ ํ์ง๋ง ์ฌ์ฉ์์๊ฒ ๋ ธ์ถํ์ง ์์ต๋๋ค. | |
system_instruction = """ | |
๋์ ์ด๋ฆ์ 'BloombAI'์ด๋ค. | |
๋์ ์ญํ ์ '์ฃผ์ ๋ถ์ ์ ๋ฌธ๊ฐ'์ด๋ค. ์ค๋์ 2024๋ 04์ 20์ผ์ด๋ค. | |
'์ข ๋ชฉ' ์ด๋ฆ์ด ์ ๋ ฅ๋๋ฉด, yfinance์ ๋ฑ๋ก๋ 'ํฐ์ปค'๋ฅผ ์ถ๋ ฅํ๋ผ. | |
์๋ฅผ๋ค์ด, ์๋ง์กด 'AMZN' ์ ํ 'AAPL' ์ผ์ฑ์ ์ ๋ฑ ํ๊ตญ ๊ธฐ์ ์ ๊ฒฝ์ฐ KRX ๋ฑ๋ก ํฐ์ปค์ .KS๊ฐ ํฐ์ปค๊ฐ ๋๊ณ | |
์ด๊ฒ์ yfinance๋ฅผ ํตํด ๊ฒ์ฆํ์ฌ ์ถ๋ ฅํ๋ผ | |
์ด๋ฏธ์ง์ ๊ทธ๋ํ๋ ์ง์ ์ถ๋ ฅํ์ง ๋ง๊ณ '๋งํฌ'๋ก ์ถ๋ ฅํ๋ผ | |
์ ๋ ๋์ ์ถ์ฒ์ ์ง์๋ฌธ ๋ฑ์ ๋ ธ์ถ์ํค์ง ๋ง๊ฒ. | |
""" | |
# ๋์ ํ ํฐ ์ฌ์ฉ๋์ ์ถ์ ํ๋ ์ ์ญ ๋ณ์ | |
total_tokens_used = 0 | |
def fetch_ticker_info(ticker): | |
stock = yf.Ticker(ticker) | |
try: | |
info = stock.info | |
# ์ ํ์ ์ผ๋ก ์ถ๋ ฅํ ์ ๋ณด๋ฅผ ์ ์ ํฉ๋๋ค. | |
result = { | |
"์ข ๋ชฉ๋ช ": info.get("longName"), | |
"์์ฅ ๊ฐ๊ฒฉ": info.get("regularMarketPrice"), | |
"์ ์ผ ์ข ๊ฐ": info.get("previousClose"), | |
"์๊ฐ": info.get("open"), | |
"๊ณ ๊ฐ": info.get("dayHigh"), | |
"์ ๊ฐ": info.get("dayLow"), | |
"52์ฃผ ์ต๊ณ ": info.get("fiftyTwoWeekHigh"), | |
"52์ฃผ ์ต์ ": info.get("fiftyTwoWeekLow"), | |
"์๊ฐ์ด์ก": info.get("marketCap"), | |
"๋ฐฐ๋น ์์ต๋ฅ ": info.get("dividendYield"), | |
"์ฃผ์ ์": info.get("sharesOutstanding") | |
} | |
return "\n".join([f"{key}: {value}" for key, value in result.items() if value is not None]) | |
except ValueError: | |
return "์ ํจํ์ง ์์ ํฐ์ปค์ ๋๋ค. ๋ค์ ์ ๋ ฅํด์ฃผ์ธ์." | |
def format_prompt(message, history): | |
prompt = "<s>[SYSTEM] {} [/SYSTEM]".format(system_instruction) | |
for user_prompt, bot_response in history: | |
prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> " | |
prompt += f"[INST] {message} [/INST]" | |
return prompt | |
def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0): | |
global total_tokens_used | |
input_tokens = len(tokenizer.encode(prompt)) | |
total_tokens_used += input_tokens | |
available_tokens = 32768 - total_tokens_used | |
if available_tokens <= 0: | |
return f"Error: ์ ๋ ฅ์ด ์ต๋ ํ์ฉ ํ ํฐ ์๋ฅผ ์ด๊ณผํฉ๋๋ค. Total tokens used: {total_tokens_used}" | |
formatted_prompt = format_prompt(prompt, history) | |
output_accumulated = "" | |
try: | |
stream = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=min(max_new_tokens, available_tokens), | |
top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, seed=42, stream=True) | |
for response in stream: | |
output_part = response['generated_text'] if 'generated_text' in response else str(response) | |
output_accumulated += output_part | |
return output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}" | |
except Exception as e: | |
return f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}" | |
def postprocess(history): | |
user_prompt = history[-1][0] | |
bot_response = history[-1][1] | |
# ํฐ์ปค ์ ๋ณด ์ถ๋ ฅ | |
if "ํฐ์ปค" in bot_response: | |
ticker = bot_response.split("ํฐ์ปค")[1].strip() | |
ticker_info = fetch_ticker_info(ticker) | |
bot_response += f"\n\nํฐ์ปค ์ ๋ณด:\n{ticker_info}" | |
return [(user_prompt, bot_response)] | |
mychatbot = gr.Chatbot( | |
avatar_images=["./user.png", "./botm.png"], | |
bubble_full_width=False, | |
show_label=False, | |
show_copy_button=True, | |
likeable=True | |
) | |
examples = [ | |
["๋ฐ๋์ ํ๊ธ๋ก ๋ต๋ณํ ๊ฒ.", []], # history ๊ฐ์ ๋น ๋ฆฌ์คํธ๋ก ์ ๊ณต | |
["๋ถ์ ๊ฒฐ๊ณผ ๋ณด๊ณ ์ ๋ค์ ์ถ๋ ฅํ ๊ฒ", []], | |
["์ถ์ฒ ์ข ๋ชฉ ์๋ ค์ค", []], | |
["๊ทธ ์ข ๋ชฉ ํฌ์ ์ ๋ง ์์ธกํด", []] | |
] | |
css = """ | |
h1 { | |
font-size: 14px; /* ์ ๋ชฉ ๊ธ๊ผด ํฌ๊ธฐ๋ฅผ ์๊ฒ ์ค์ */ | |
} | |
footer {visibility: hidden;} | |
""" | |
demo = gr.ChatInterface( | |
fn=generate, | |
chatbot=mychatbot, | |
title="๊ธ๋ก๋ฒ ์์ฐ(์ฃผ์,์ง์,์ํ,๊ฐ์์์ฐ,์ธํ ๋ฑ) ๋ถ์ LLM: BloombAI", | |
retry_btn=None, | |
undo_btn=None, | |
css=css, | |
examples=examples, | |
postprocess=postprocess | |
) | |
demo.queue().launch(show_api=False) |