simple-chatbot / app.py
YoheiHayamizu
update
81c8f5c
raw
history blame
2.41 kB
import re
import unicodedata
import requests
import streamlit as st
from bs4 import BeautifulSoup
from transformers import (
AutoModelForQuestionAnswering,
AutoTokenizer,
QuestionAnsweringPipeline,
)
model_name = "KoichiYasuoka/bert-base-japanese-wikipedia-ud-head"


@st.cache_resource
def _load_qa_components(name):
    """Load the tokenizer, model and QA pipeline for *name* exactly once.

    Streamlit re-executes this script on every user interaction; without
    caching, the pretrained weights would be reloaded on each rerun.
    ``st.cache_resource`` keeps the returned objects alive and hands back
    the same instances on later runs.

    Args:
        name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model = AutoModelForQuestionAnswering.from_pretrained(name)
    pipeline = QuestionAnsweringPipeline(
        model=loaded_model,
        tokenizer=loaded_tokenizer,
    )
    return loaded_tokenizer, loaded_model, pipeline


# Keep the original module-level names so the rest of the script is unaffected.
tokenizer, model, qa_pipeline = _load_qa_components(model_name)
# Page header and the securities-code input field.
# The same sentence is used as the field label and as the first-run hint.
_CODE_PROMPT = "株価を知りたい企業の証券コードを入力してください"

st.title("株価お知らせBot")

stock_code = st.text_input(
    label=_CODE_PROMPT,
    max_chars=4,
    placeholder="証券コード",
    help="4桁の数字",
)

# Show the hint only while no page text has been stored in the session yet.
if not ("content" in st.session_state):
    st.write(_CODE_PROMPT)
def _fetch_stock_summary(code):
    """Download the Nikkei company page for *code* and extract its summary text.

    Args:
        code: Securities code entered by the user.

    Returns:
        The text of the ``div.m-stockInfo_top_left`` element, NFKD-normalized
        with runs of CR/TAB/LF collapsed to a single space, or ``None`` when
        the page does not contain that element.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-success HTTP status.
    """
    # Pass the code via ``params`` so it is URL-encoded rather than spliced
    # into the URL verbatim; the timeout keeps the app from hanging.
    res = requests.get(
        "https://www.nikkei.com/nkd/company/",
        params={"scode": code},
        timeout=10,
    )
    res.raise_for_status()

    soup = BeautifulSoup(res.text, "html.parser")
    # NOTE(review): presumably this div holds the quote header; confirm
    # against the live page markup.
    stock_info = soup.find("div", attrs={"class": "m-stockInfo_top_left"})
    if stock_info is None:
        return None

    # NOTE(review): NFKD decomposes voiced kana into base + combining mark;
    # confirm this is what the QA tokenizer expects (NFKC may be intended).
    normalized = unicodedata.normalize("NFKD", stock_info.text)
    return re.sub(r"[\r\t\n]+", " ", normalized)


# On button press, fetch the page and keep its text as the QA context.
if st.button("株価を知りたい"):
    code = stock_code.strip()
    if not code:
        st.warning("証券コードを入力してください")
    else:
        try:
            summary = _fetch_stock_summary(code)
        except requests.RequestException as err:
            st.error(f"株価情報の取得に失敗しました: {err}")
        else:
            if summary is None:
                # The original code crashed here with AttributeError.
                st.error("株価情報が見つかりませんでした。証券コードを確認してください")
            else:
                st.session_state.content = summary
# Build the answer with the Transformers question-answering pipeline.
def generate_response(prompt, max_length=50):
    """Answer *prompt* using the fetched page text as context.

    Args:
        prompt: The user's question.
        max_length: Unused; kept so existing callers keep working.

    Returns:
        The ``"answer"`` string from the QA pipeline, or a short guidance
        message when there is no question or no page text has been fetched.
    """
    if not prompt:
        return "質問を入力してください"

    # Chatting before any page was fetched used to raise AttributeError.
    if "content" not in st.session_state or not st.session_state.content:
        return "先に証券コードを入力して、株価情報を取得してください"

    # Only the first 100 characters of the stored text are used as context.
    context = st.session_state.content[:100]
    answer = qa_pipeline(context=context, question=prompt)
    return answer["answer"]
# First run of a session: seed the conversation with the assistant greeting.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant", "content": "何か御用ですか?"},
    ]

# Replay the stored conversation, one chat bubble per entry.
for entry in st.session_state["messages"]:
    speaker, text = entry["role"], entry["content"]
    with st.chat_message(speaker):
        st.write(text)
# User's question: store it and echo it into the chat.
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Assistant's answer: reply whenever the latest stored message is not from
# the assistant.
if (
    st.session_state.messages
    and st.session_state.messages[-1]["role"] != "assistant"
):
    # Answer the stored message rather than ``prompt``: on a rerun where
    # nothing was typed, ``prompt`` is None while a user message may still
    # be awaiting a reply.
    pending_question = st.session_state.messages[-1]["content"]
    with st.chat_message("assistant"):
        with st.spinner("考え中..."):
            response = generate_response(pending_question)
            st.write(response)
    st.session_state.messages.append({"role": "assistant", "content": response})