Fta98's picture
fix
b3f598f
raw
history blame
4.45 kB
import streamlit as st
from transformers import AutoModelForCausalLM, LlamaTokenizer
@st.cache_resource
def load():
"""
base_model = AutoModelForCausalLM.from_pretrained(
"stabilityai/japanese-stablelm-instruct-alpha-7b",
device_map="auto",
low_cpu_mem_usage=True,
variant="int8",
load_in_8bit=True,
trust_remote_code=True,
)
model = PeftModel.from_pretrained(
base_model,
"lora_adapter",
device_map="auto",
)
"""
tokenizer = LlamaTokenizer.from_pretrained(
"lora_adapter",
)
return model, tokenizer
def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
roles = ["指示", "応答"]
msgs = [": \n" + user_query, ": "]
if messages:
roles.insert(1, "入力")
msgs.insert(1, ": \n" + "\n".join(message["content"] for message in messages))
for role, msg in zip(roles, msgs):
prompt += sep + role + msg
return prompt
def get_input_token_length(user_query, system_prompt, messages=""):
prompt = get_prompt(user_query, system_prompt, messages)
input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
return input_ids.shape[-1]
def generate():
pass
st.header(":dna: 遺伝カウンセリング対話AI")
# 初期化
model, tokenizer = load()
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "options" not in st.session_state:
st.session_state["options"] = {
"temperature": 0.0,
"top_k": 50,
"top_p": 0.95,
"repetition_penalty": 1.1,
"system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。
常に安全を考慮し、できる限り有益な回答を心がけてください。
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。
社会的に偏りのない、前向きな回答を心がけてください。
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。
質問の答えを知らない場合は、誤った情報を共有しないでください。"""}
# サイドバー
clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat")
st.sidebar.header("Options")
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"])
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])
# リセット
if clear_chat:
st.session_state["messages"] = []
# チャット履歴の表示
for message in st.session_state["messages"]:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# 現在のチャット
if user_prompt := st.chat_input("質問を送信してください"):
with st.chat_message("user"):
st.text(user_prompt)
st.session_state["messages"].append({"role": "user", "content": user_prompt})
token_kength = get_input_token_length(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
response = f"{token_kength}: " + get_prompt(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
with st.chat_message("assistant"):
st.text(response)
st.session_state["messages"].append({"role": "assistant", "content": user_prompt})