seawolf2357 commited on
Commit
2217397
·
verified ·
1 Parent(s): 467cd56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -38
app.py CHANGED
@@ -41,68 +41,137 @@ The user provided the additional info about how they would like you to respond:
41
  total_tokens_used = 0
42
 
43
  def format_prompt(message, history):
44
- prompt = f"<s>[SYSTEM] {system_instruction} [/SYSTEM]"
45
  for user_prompt, bot_response in history:
46
  prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> "
47
  prompt += f"[INST] {message} [/INST]"
48
  return prompt
49
 
50
- stock_info = {
51
- "AAPL": {'name': '애플', 'description': '아이폰을 주력으로 생산하는'},
52
- "MSFT": {'name': '마이크로소프트', 'description': '윈도우 운영체제와 오피스 소프트웨어를'},
53
- "AMZN": {'name': '아마존', 'description': '전자상거래 및 클라우드 서비스를'},
54
- "GOOGL": {'name': '알파벳 (구글)', 'description': '검색 엔진 및 온라인 광고를'},
55
- "TSLA": {'name': '테슬라', 'description': '전기자동차와 에너지 저장장치를'}
56
- }
57
-
58
  def get_stock_data(ticker):
59
- try:
60
- stock = yf.Ticker(ticker)
61
- hist = stock.history(period="5d")
62
- return hist
63
- except Exception as e:
64
- return f"데이터를 불러오는 중 오류가 발생했습니다: {e}"
65
 
66
  def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
67
  global total_tokens_used
 
 
 
 
 
 
 
 
 
68
  try:
69
- input_tokens = len(tokenizer.encode(prompt))
70
- total_tokens_used += input_tokens
71
- available_tokens = 32768 - total_tokens_used
72
- if available_tokens <= 0:
73
- return "Error: 입력이 최대 허용 토큰 수를 초과합니다."
 
74
 
75
- formatted_prompt = format_prompt(prompt, history)
76
- ticker = prompt.upper()
77
- stock_info_detail = stock_info.get(ticker, None)
78
- if stock_info_detail:
79
- response_msg = f"{stock_info_detail['name']}은(는) {stock_info_detail['description']} 주력으로 생산하는 기업입니다. 티커는 {ticker}입니다. 원하시는 종목이 맞는가요?"
80
- stock_data = get_stock_data(ticker)
81
- return f"{response_msg}\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
 
 
 
 
 
 
 
 
 
82
  else:
83
- return f"입력하신 '{prompt}'은(는) 지원되는 종목명 또는 티커가 아닙니다. 지원되는 티커: {', '.join(stock_info.keys())}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  except Exception as e:
85
- return f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  mychatbot = gr.Chatbot(
88
  avatar_images=["./user.png", "./botm.png"],
89
  bubble_full_width=False,
90
  show_label=False,
91
  show_copy_button=True,
92
- likeable=True
93
  )
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  demo = gr.ChatInterface(
96
  fn=generate,
97
  chatbot=mychatbot,
98
  title="글로벌 자산 분석 및 예측 LLM: BloombAI",
99
- css="h1 { font-size: 14px; } footer { visibility: hidden; }",
100
- examples=[
101
- ["반드시 한글로 답변할것.", []],
102
- ["좋은 종목(티커) 추천해줘", []],
103
- ["요약 결론을 제시해", []],
104
- ["포트폴리오 분석해줘", []]
105
- ]
106
  )
107
 
108
- demo.launch(show_api=False)
 
41
  total_tokens_used = 0
42
 
43
  def format_prompt(message, history):
44
+ prompt = "<s>[SYSTEM] {} [/SYSTEM]".format(system_instruction)
45
  for user_prompt, bot_response in history:
46
  prompt += f"[INST] {user_prompt} [/INST]{bot_response}</s> "
47
  prompt += f"[INST] {message} [/INST]"
48
  return prompt
49
 
 
 
 
 
 
 
 
 
50
  def get_stock_data(ticker):
51
+ stock = yf.Ticker(ticker)
52
+ hist = stock.history(period="5d") # 지난 5일간의 주식 데이터를 가져옵니다.
53
+ return hist
 
 
 
54
 
55
  def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
56
  global total_tokens_used
57
+ input_tokens = len(tokenizer.encode(prompt))
58
+ total_tokens_used += input_tokens
59
+ available_tokens = 32768 - total_tokens_used
60
+ if available_tokens <= 0:
61
+ yield f"Error: 입력이 최대 허용 토큰 수를 초과합니다. Total tokens used: {total_tokens_used}"
62
+ return
63
+
64
+ formatted_prompt = format_prompt(prompt, history)
65
+ output_accumulated = ""
66
  try:
67
+ # 티커 확인 및 데이터 수집
68
+ stock_info = get_stock_info(prompt) # 종목명을 토대로 티커 정보와 기업 설명을 가져옵니다.
69
+ if stock_info['ticker']:
70
+ response_msg = f"{stock_info['name']}은(는) {stock_info['description']} 주력으로 생산하는 기업입니다. {stock_info['name']}의 티커는 {stock_info['ticker']}입니다. 원하시는 종목이 맞는가요?"
71
+ output_accumulated += response_msg
72
+ yield output_accumulated
73
 
74
+ # 추가적인 분석 요청이 있다면, yfinance로 데이터 수집 및 분석
75
+ stock_data = get_stock_data(stock_info['ticker']) # 티커를 이용해 주식 데이터를 가져옵니다.
76
+ stream = client.text_generation(
77
+ formatted_prompt,
78
+ temperature=temperature,
79
+ max_new_tokens=min(max_new_tokens, available_tokens),
80
+ top_p=top_p,
81
+ repetition_penalty=repetition_penalty,
82
+ do_sample=True,
83
+ seed=42,
84
+ stream=True
85
+ )
86
+ for response in stream:
87
+ output_part = response['generated_text'] if 'generated_text' in response else str(response)
88
+ output_accumulated += output_part
89
+ yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
90
  else:
91
+ # 입력이 티커인 경우 처리
92
+ ticker = prompt.upper()
93
+ if ticker in ['AAPL', 'MSFT', 'AMZN', 'GOOGL', 'TSLA']:
94
+ stock_info = get_stock_info_by_ticker(ticker)
95
+ response_msg = f"{stock_info['name']}은(는) {stock_info['description']} 주력으로 생산하는 기업입니다. {stock_info['name']}의 티커는 {stock_info['ticker']}입니다. 원하시는 종목이 맞는가요?"
96
+ output_accumulated += response_msg
97
+ yield output_accumulated
98
+
99
+ # 추가적인 분석 요청이 있다면, yfinance로 데이터 수집 및 분석
100
+ stock_data = get_stock_data(stock_info['ticker']) # 티커를 ��용해 주식 데이터를 가져옵니다.
101
+ stream = client.text_generation(
102
+ formatted_prompt,
103
+ temperature=temperature,
104
+ max_new_tokens=min(max_new_tokens, available_tokens),
105
+ top_p=top_p,
106
+ repetition_penalty=repetition_penalty,
107
+ do_sample=True,
108
+ seed=42,
109
+ stream=True
110
+ )
111
+ for response in stream:
112
+ output_part = response['generated_text'] if 'generated_text' in response else str(response)
113
+ output_accumulated += output_part
114
+ yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
115
+ else:
116
+ yield f"입력하신 '{prompt}'은(는) 지원되는 종목명 또는 티커가 아닙니다. 현재 지원되는 종목은 애플(AAPL), 마이크로소프트(MSFT), 아마존(AMZN), 알파벳(GOOGL), 테슬라(TSLA) 등입니다. 정확한 종목명 또는 티커를 입력해주세요."
117
  except Exception as e:
118
+ yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
119
 
120
+ # 티커를 토대로 종목 정보를 제공하는 함수
121
+ def get_stock_info_by_ticker(ticker):
122
+ stock_info = {
123
+ "AAPL": {'ticker': 'AAPL', 'name': '애플', 'description': '아이폰을'},
124
+ "MSFT": {'ticker': 'MSFT', 'name': '마이크로소프트', 'description': '윈도우 운영체제와 오피스 소프트웨어를'},
125
+ "AMZN": {'ticker': 'AMZN', 'name': '아마존', 'description': '전자상거래 및 클라우드 서비스를'},
126
+ "GOOGL": {'ticker': 'GOOGL', 'name': '알파벳', 'description': '검색 엔진 및 온라인 광고를'},
127
+ "TSLA": {'ticker': 'TSLA', 'name': '테슬라', 'description': '전기자동차와 에너지 저장장치를'},
128
+ }
129
+ return stock_info.get(ticker, {'ticker': None, 'name': None, 'description': ''})
130
+
131
+ # 종목명을 토대로 티커와 기업 정보를 제공하는 함수
132
+ def get_stock_info(name):
133
+ stock_info = {
134
+ "apple": {'ticker': 'AAPL', 'name': '애플', 'description': '아이폰을'},
135
+ "microsoft": {'ticker': 'MSFT', 'name': '마이크로소프트', 'description': '윈도우 운영체제와 오피스 소프트웨어를'},
136
+ "amazon": {'ticker': 'AMZN', 'name': '아마존', 'description': '전자상거래 및 클라우드 서비스를'},
137
+ "google": {'ticker': 'GOOGL', 'name': '알파벳 (구글)', 'description': '검색 엔진 및 온라인 광고를'},
138
+ "tesla": {'ticker': 'TSLA', 'name': '테슬라', 'description': '전기자동차와 에너지 저장장치를'},
139
+ # 추가적인 종목에 대한 정보를 이곳에 구현할 수 있습니다.
140
+ }
141
+ return stock_info.get(name.lower(), {'ticker': None, 'name': name, 'description': ''})
142
+
143
  mychatbot = gr.Chatbot(
144
  avatar_images=["./user.png", "./botm.png"],
145
  bubble_full_width=False,
146
  show_label=False,
147
  show_copy_button=True,
148
+ likeable=True,
149
  )
150
 
151
+ examples = [
152
+ ["반드시 한글로 답변할것.", []],
153
+ ["좋은 종목(티커) 추천해줘", []],
154
+ ["요약 결론을 제시해", []],
155
+ ["포트폴리오 분석해줘", []]
156
+ ]
157
+
158
+ css = """
159
+ h1 {
160
+ font-size: 14px;
161
+ }
162
+ footer {
163
+ visibility: hidden;
164
+ }
165
+ """
166
+
167
  demo = gr.ChatInterface(
168
  fn=generate,
169
  chatbot=mychatbot,
170
  title="글로벌 자산 분석 및 예측 LLM: BloombAI",
171
+ retry_btn=None,
172
+ undo_btn=None,
173
+ css=css,
174
+ examples=examples
 
 
 
175
  )
176
 
177
+ demo.queue().launch(show_api=False)