seawolf2357 commited on
Commit
d55c709
ยท
verified ยท
1 Parent(s): dd00036

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -7
app.py CHANGED
@@ -66,16 +66,33 @@ def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.
66
  formatted_prompt = format_prompt(prompt, history)
67
  output_accumulated = ""
68
  try:
69
- stock_data = get_stock_data("AAPL") # ์˜ˆ์‹œ๋กœ 'AAPL' ํ‹ฐ์ปค ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
70
- stream = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=min(max_new_tokens, available_tokens),
71
- top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, seed=42, stream=True)
72
- for response in stream:
73
- output_part = response['generated_text'] if 'generated_text' in response else str(response)
74
- output_accumulated += output_part
75
- yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
 
 
 
 
 
 
 
 
 
76
  except Exception as e:
77
  yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
78
 
 
 
 
 
 
 
 
 
79
  mychatbot = gr.Chatbot(
80
  avatar_images=["./user.png", "./botm.png"],
81
  bubble_full_width=False,
 
66
  formatted_prompt = format_prompt(prompt, history)
67
  output_accumulated = ""
68
  try:
69
+ # ํ‹ฐ์ปค ํ™•์ธ ๋ฐ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘
70
+ stock_info = get_stock_info(prompt) # ์ข…๋ชฉ๋ช…์„ ํ† ๋Œ€๋กœ ํ‹ฐ์ปค ์ •๋ณด์™€ ๊ธฐ์—… ์„ค๋ช…์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
71
+ if stock_info['ticker']:
72
+ response_msg = f"{stock_info['name']}์€(๋Š”) {stock_info['description']} ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š” ๊ธฐ์—…์ž…๋‹ˆ๋‹ค. {stock_info['name']}์˜ ํ‹ฐ์ปค๋Š” {stock_info['ticker']}์ž…๋‹ˆ๋‹ค. ์›ํ•˜์‹œ๋Š” ์ข…๋ชฉ์ด ๋งž๋Š”๊ฐ€์š”?"
73
+ output_accumulated += response_msg
74
+ yield output_accumulated
75
+ # ์ถ”๊ฐ€์ ์ธ ๋ถ„์„ ์š”์ฒญ์ด ์žˆ๋‹ค๋ฉด, yfinance๋กœ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘ ๋ฐ ๋ถ„์„
76
+ stock_data = get_stock_data(stock_info['ticker']) # ํ‹ฐ์ปค๋ฅผ ์ด์šฉํ•ด ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
77
+ stream = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=min(max_new_tokens, available_tokens),
78
+ top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, seed=42, stream=True)
79
+ for response in stream:
80
+ output_part = response['generated_text'] if 'generated_text' in response else str(response)
81
+ output_accumulated += output_part
82
+ yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
83
+ else:
84
+ yield "์ž…๋ ฅํ•˜์‹  ์ข…๋ชฉ๋ช…์„ ํ™•์ธํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ์ •ํ™•ํ•œ ์ข…๋ชฉ๋ช…์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."
85
  except Exception as e:
86
  yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
87
 
88
+ # ์ข…๋ชฉ๋ช…์„ ํ† ๋Œ€๋กœ ํ‹ฐ์ปค์™€ ๊ธฐ์—… ์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•จ์ˆ˜
89
+ def get_stock_info(name):
90
+ if name.lower() == "apple":
91
+ return {'ticker': 'AAPL', 'name': '์• ํ”Œ', 'description': '์•„์ดํฐ์„'}
92
+ # ์ถ”๊ฐ€์ ์ธ ์ข…๋ชฉ์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ์ด๊ณณ์— ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
93
+ return {'ticker': None, 'name': name, 'description': ''}
94
+
95
+
96
  mychatbot = gr.Chatbot(
97
  avatar_images=["./user.png", "./botm.png"],
98
  bubble_full_width=False,