seawolf2357 commited on
Commit
e33796d
ยท
verified ยท
1 Parent(s): 713547c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -106
app.py CHANGED
@@ -41,137 +41,68 @@ The user provided the additional info about how they would like you to respond:
41
  total_tokens_used = 0
42
 
43
  def format_prompt(message, history):
44
- prompt = "<s>[SYSTEM] {} [/SYSTEM]".format(system_instruction)
45
  for user_prompt, bot_response in history:
46
  prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> "
47
  prompt += f"[INST] {message} [/INST]"
48
  return prompt
49
 
 
 
 
 
 
 
 
 
50
  def get_stock_data(ticker):
51
- stock = yf.Ticker(ticker)
52
- hist = stock.history(period="5d") # ์ง€๋‚œ 5์ผ๊ฐ„์˜ ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
53
- return hist
 
 
 
54
 
55
  def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
56
  global total_tokens_used
57
- input_tokens = len(tokenizer.encode(prompt))
58
- total_tokens_used += input_tokens
59
- available_tokens = 32768 - total_tokens_used
60
- if available_tokens <= 0:
61
- yield f"Error: ์ž…๋ ฅ์ด ์ตœ๋Œ€ ํ—ˆ์šฉ ํ† ํฐ ์ˆ˜๋ฅผ ์ดˆ๊ณผํ•ฉ๋‹ˆ๋‹ค. Total tokens used: {total_tokens_used}"
62
- return
63
-
64
- formatted_prompt = format_prompt(prompt, history)
65
- output_accumulated = ""
66
  try:
67
- # ํ‹ฐ์ปค ํ™•์ธ ๋ฐ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘
68
- stock_info = get_stock_info(prompt) # ์ข…๋ชฉ๋ช…์„ ํ† ๋Œ€๋กœ ํ‹ฐ์ปค ์ •๋ณด์™€ ๊ธฐ์—… ์„ค๋ช…์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
69
- if stock_info['ticker']:
70
- response_msg = f"{stock_info['name']}์€(๋Š”) {stock_info['description']} ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š” ๊ธฐ์—…์ž…๋‹ˆ๋‹ค. {stock_info['name']}์˜ ํ‹ฐ์ปค๋Š” {stock_info['ticker']}์ž…๋‹ˆ๋‹ค. ์›ํ•˜์‹œ๋Š” ์ข…๋ชฉ์ด ๋งž๋Š”๊ฐ€์š”?"
71
- output_accumulated += response_msg
72
- yield output_accumulated
73
 
74
- # ์ถ”๊ฐ€์ ์ธ ๋ถ„์„ ์š”์ฒญ์ด ์žˆ๋‹ค๋ฉด, yfinance๋กœ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘ ๋ฐ ๋ถ„์„
75
- stock_data = get_stock_data(stock_info['ticker']) # ํ‹ฐ์ปค๋ฅผ ์ด์šฉํ•ด ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
76
- stream = client.text_generation(
77
- formatted_prompt,
78
- temperature=temperature,
79
- max_new_tokens=min(max_new_tokens, available_tokens),
80
- top_p=top_p,
81
- repetition_penalty=repetition_penalty,
82
- do_sample=True,
83
- seed=42,
84
- stream=True
85
- )
86
- for response in stream:
87
- output_part = response['generated_text'] if 'generated_text' in response else str(response)
88
- output_accumulated += output_part
89
- yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
90
  else:
91
- # ์ž…๋ ฅ์ด ํ‹ฐ์ปค์ธ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ
92
- ticker = prompt.upper()
93
- if ticker in ['AAPL', 'MSFT', 'AMZN', 'GOOGL', 'TSLA']:
94
- stock_info = get_stock_info_by_ticker(ticker)
95
- response_msg = f"{stock_info['name']}์€(๋Š”) {stock_info['description']} ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š” ๊ธฐ์—…์ž…๋‹ˆ๋‹ค. {stock_info['name']}์˜ ํ‹ฐ์ปค๋Š” {stock_info['ticker']}์ž…๋‹ˆ๋‹ค. ์›ํ•˜์‹œ๋Š” ์ข…๋ชฉ์ด ๋งž๋Š”๊ฐ€์š”?"
96
- output_accumulated += response_msg
97
- yield output_accumulated
98
-
99
- # ์ถ”๊ฐ€์ ์ธ ๋ถ„์„ ์š”์ฒญ์ด ์žˆ๋‹ค๋ฉด, yfinance๋กœ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘ ๋ฐ ๋ถ„์„
100
- stock_data = get_stock_data(stock_info['ticker']) # ํ‹ฐ์ปค๋ฅผ ์ด์šฉํ•ด ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
101
- stream = client.text_generation(
102
- formatted_prompt,
103
- temperature=temperature,
104
- max_new_tokens=min(max_new_tokens, available_tokens),
105
- top_p=top_p,
106
- repetition_penalty=repetition_penalty,
107
- do_sample=True,
108
- seed=42,
109
- stream=True
110
- )
111
- for response in stream:
112
- output_part = response['generated_text'] if 'generated_text' in response else str(response)
113
- output_accumulated += output_part
114
- yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
115
- else:
116
- yield f"์ž…๋ ฅํ•˜์‹  '{prompt}'์€(๋Š”) ์ง€์›๋˜๋Š” ์ข…๋ชฉ๋ช… ๋˜๋Š” ํ‹ฐ์ปค๊ฐ€ ์•„๋‹™๋‹ˆ๋‹ค. ํ˜„์žฌ ์ง€์›๋˜๋Š” ์ข…๋ชฉ์€ ์• ํ”Œ(AAPL), ๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ(MSFT), ์•„๋งˆ์กด(AMZN), ์•ŒํŒŒ๋ฒณ(GOOGL), ํ…Œ์Šฌ๋ผ(TSLA) ๋“ฑ์ž…๋‹ˆ๋‹ค. ์ •ํ™•ํ•œ ์ข…๋ชฉ๋ช… ๋˜๋Š” ํ‹ฐ์ปค๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."
117
  except Exception as e:
118
- yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
119
 
120
- # ํ‹ฐ์ปค๋ฅผ ํ† ๋Œ€๋กœ ์ข…๋ชฉ ์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•จ์ˆ˜
121
- def get_stock_info_by_ticker(ticker):
122
- stock_info = {
123
- "AAPL": {'ticker': 'AAPL', 'name': '์• ํ”Œ', 'description': '์•„์ดํฐ์„'},
124
- "MSFT": {'ticker': 'MSFT', 'name': '๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ', 'description': '์œˆ๋„์šฐ ์šด์˜์ฒด์ œ์™€ ์˜คํ”ผ์Šค ์†Œํ”„ํŠธ์›จ์–ด๋ฅผ'},
125
- "AMZN": {'ticker': 'AMZN', 'name': '์•„๋งˆ์กด', 'description': '์ „์ž์ƒ๊ฑฐ๋ž˜ ๋ฐ ํด๋ผ์šฐ๋“œ ์„œ๋น„์Šค๋ฅผ'},
126
- "GOOGL": {'ticker': 'GOOGL', 'name': '์•ŒํŒŒ๋ฒณ', 'description': '๊ฒ€์ƒ‰ ์—”์ง„ ๋ฐ ์˜จ๋ผ์ธ ๊ด‘๊ณ ๋ฅผ'},
127
- "TSLA": {'ticker': 'TSLA', 'name': 'ํ…Œ์Šฌ๋ผ', 'description': '์ „๊ธฐ์ž๋™์ฐจ์™€ ์—๋„ˆ์ง€ ์ €์žฅ์žฅ์น˜๋ฅผ'},
128
- }
129
- return stock_info.get(ticker, {'ticker': None, 'name': None, 'description': ''})
130
-
131
- # ์ข…๋ชฉ๋ช…์„ ํ† ๋Œ€๋กœ ํ‹ฐ์ปค์™€ ๊ธฐ์—… ์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•จ์ˆ˜
132
- def get_stock_info(name):
133
- stock_info = {
134
- "apple": {'ticker': 'AAPL', 'name': '์• ํ”Œ', 'description': '์•„์ดํฐ์„'},
135
- "microsoft": {'ticker': 'MSFT', 'name': '๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ', 'description': '์œˆ๋„์šฐ ์šด์˜์ฒด์ œ์™€ ์˜คํ”ผ์Šค ์†Œํ”„ํŠธ์›จ์–ด๋ฅผ'},
136
- "amazon": {'ticker': 'AMZN', 'name': '์•„๋งˆ์กด', 'description': '์ „์ž์ƒ๊ฑฐ๋ž˜ ๋ฐ ํด๋ผ์šฐ๋“œ ์„œ๋น„์Šค๋ฅผ'},
137
- "google": {'ticker': 'GOOGL', 'name': '์•ŒํŒŒ๋ฒณ (๊ตฌ๊ธ€)', 'description': '๊ฒ€์ƒ‰ ์—”์ง„ ๋ฐ ์˜จ๋ผ์ธ ๊ด‘๊ณ ๋ฅผ'},
138
- "tesla": {'ticker': 'TSLA', 'name': 'ํ…Œ์Šฌ๋ผ', 'description': '์ „๊ธฐ์ž๋™์ฐจ์™€ ์—๋„ˆ์ง€ ์ €์žฅ์žฅ์น˜๋ฅผ'},
139
- # ์ถ”๊ฐ€์ ์ธ ์ข…๋ชฉ์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ์ด๊ณณ์— ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
140
- }
141
- return stock_info.get(name.lower(), {'ticker': None, 'name': name, 'description': ''})
142
-
143
  mychatbot = gr.Chatbot(
144
  avatar_images=["./user.png", "./botm.png"],
145
  bubble_full_width=False,
146
  show_label=False,
147
  show_copy_button=True,
148
  likeable=True,
 
 
 
 
 
 
149
  )
150
 
151
- examples = [
152
- ["๋ฐ˜๋“œ์‹œ ํ•œ๊ธ€๋กœ ๋‹ต๋ณ€ํ• ๊ฒƒ.", []],
153
- ["์ข‹์€ ์ข…๋ชฉ(ํ‹ฐ์ปค) ์ถ”์ฒœํ•ด์ค˜", []],
154
- ["์š”์•ฝ ๊ฒฐ๋ก ์„ ์ œ์‹œํ•ด", []],
155
- ["ํฌํŠธํด๋ฆฌ์˜ค ๋ถ„์„ํ•ด์ค˜", []]
156
- ]
157
-
158
- css = """
159
- h1 {
160
- font-size: 14px;
161
- }
162
- footer {
163
- visibility: hidden;
164
- }
165
- """
166
-
167
  demo = gr.ChatInterface(
168
  fn=generate,
169
  chatbot=mychatbot,
170
  title="๊ธ€๋กœ๋ฒŒ ์ž์‚ฐ ๋ถ„์„ ๋ฐ ์˜ˆ์ธก LLM: BloombAI",
171
- retry_btn=None,
172
- undo_btn=None,
173
- css=css,
174
- examples=examples
175
  )
176
 
177
- demo.queue().launch(show_api=False)
 
41
  total_tokens_used = 0
42
 
43
  def format_prompt(message, history):
44
+ prompt = f"<s>[SYSTEM] {system_instruction} [/SYSTEM]"
45
  for user_prompt, bot_response in history:
46
  prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> "
47
  prompt += f"[INST] {message} [/INST]"
48
  return prompt
49
 
50
+ stock_info = {
51
+ "AAPL": {'name': '์• ํ”Œ', 'description': '์•„์ดํฐ์„ ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š”'},
52
+ "MSFT": {'name': '๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ', 'description': '์œˆ๋„์šฐ ์šด์˜์ฒด์ œ์™€ ์˜คํ”ผ์Šค ์†Œํ”„ํŠธ์›จ์–ด๋ฅผ'},
53
+ "AMZN": {'name': '์•„๋งˆ์กด', 'description': '์ „์ž์ƒ๊ฑฐ๋ž˜ ๋ฐ ํด๋ผ์šฐ๋“œ ์„œ๋น„์Šค๋ฅผ'},
54
+ "GOOGL": {'name': '์•ŒํŒŒ๋ฒณ (๊ตฌ๊ธ€)', 'description': '๊ฒ€์ƒ‰ ์—”์ง„ ๋ฐ ์˜จ๋ผ์ธ ๊ด‘๊ณ ๋ฅผ'},
55
+ "TSLA": {'name': 'ํ…Œ์Šฌ๋ผ', 'description': '์ „๊ธฐ์ž๋™์ฐจ์™€ ์—๋„ˆ์ง€ ์ €์žฅ์žฅ์น˜๋ฅผ'}
56
+ }
57
+
58
  def get_stock_data(ticker):
59
+ try:
60
+ stock = yf.Ticker(ticker)
61
+ hist = stock.history(period="5d")
62
+ return hist
63
+ except Exception as e:
64
+ return f"๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
65
 
66
  def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
67
  global total_tokens_used
 
 
 
 
 
 
 
 
 
68
  try:
69
+ input_tokens = len(tokenizer.encode(prompt))
70
+ total_tokens_used += input_tokens
71
+ available_tokens = 32768 - total_tokens_used
72
+ if available_tokens <= 0:
73
+ return "Error: ์ž…๋ ฅ์ด ์ตœ๋Œ€ ํ—ˆ์šฉ ํ† ํฐ ์ˆ˜๋ฅผ ์ดˆ๊ณผํ•ฉ๋‹ˆ๋‹ค."
 
74
 
75
+ formatted_prompt = format_prompt(prompt, history)
76
+ ticker = prompt.upper()
77
+ stock_info_detail = stock_info.get(ticker, None)
78
+ if stock_info_detail:
79
+ response_msg = f"{stock_info_detail['name']}์€(๋Š”) {stock_info_detail['description']} ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š” ๊ธฐ์—…์ž…๋‹ˆ๋‹ค. ํ‹ฐ์ปค๋Š” {ticker}์ž…๋‹ˆ๋‹ค. ์›ํ•˜์‹œ๋Š” ์ข…๋ชฉ์ด ๋งž๋Š”๊ฐ€์š”?"
80
+ stock_data = get_stock_data(ticker)
81
+ return f"{response_msg}\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
 
 
 
 
 
 
 
 
 
82
  else:
83
+ return f"์ž…๋ ฅํ•˜์‹  '{prompt}'์€(๋Š”) ์ง€์›๋˜๋Š” ์ข…๋ชฉ๋ช… ๋˜๋Š” ํ‹ฐ์ปค๊ฐ€ ์•„๋‹™๋‹ˆ๋‹ค. ์ง€์›๋˜๋Š” ํ‹ฐ์ปค: {', '.join(stock_info.keys())}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  except Exception as e:
85
+ return f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  mychatbot = gr.Chatbot(
88
  avatar_images=["./user.png", "./botm.png"],
89
  bubble_full_width=False,
90
  show_label=False,
91
  show_copy_button=True,
92
  likeable=True,
93
+ examples=[
94
+ ["๋ฐ˜๋“œ์‹œ ํ•œ๊ธ€๋กœ ๋‹ต๋ณ€ํ• ๊ฒƒ.", []],
95
+ ["์ข‹์€ ์ข…๋ชฉ(ํ‹ฐ์ปค) ์ถ”์ฒœํ•ด์ค˜", []],
96
+ ["์š”์•ฝ ๊ฒฐ๋ก ์„ ์ œ์‹œํ•ด", []],
97
+ ["ํฌํŠธํด๋ฆฌ์˜ค ๋ถ„์„ํ•ด์ค˜", []]
98
+ ]
99
  )
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  demo = gr.ChatInterface(
102
  fn=generate,
103
  chatbot=mychatbot,
104
  title="๊ธ€๋กœ๋ฒŒ ์ž์‚ฐ ๋ถ„์„ ๋ฐ ์˜ˆ์ธก LLM: BloombAI",
105
+ css="h1 { font-size: 14px; } footer { visibility: hidden; }"
 
 
 
106
  )
107
 
108
+ demo.launch(show_api=False)