Fta98's picture
fix
e2e00ab
raw
history blame
5.63 kB
import os
import streamlit as st
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaTokenizer
login(token=os.getenv("HUGGINGFACE_API_KEY"))
device = "cuda" if torch.cuda.is_available() else "cpu"
@st.cache_resource
def load():
base_model = AutoModelForCausalLM.from_pretrained(
"stabilityai/japanese-stablelm-instruct-alpha-7b",
device_map="auto",
low_cpu_mem_usage=True,
variant="int8",
load_in_8bit=True,
trust_remote_code=True,
)
model = PeftModel.from_pretrained(
base_model,
"lora_adapter",
device_map="auto",
)
tokenizer = LlamaTokenizer.from_pretrained(
"novelai/nerdstash-tokenizer-v1",
additional_special_tokens=['▁▁']
)
return model, tokenizer
def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
roles = ["指示", "応答"]
msgs = [": \n" + user_query, ": "]
if messages:
roles.insert(1, "入力")
msgs.insert(1, ": \n" + "\n\n".join(message["content"] for message in messages))
for role, msg in zip(roles, msgs):
prompt += sep + role + msg
return prompt
def get_input_token_length(user_query, system_prompt, messages=""):
prompt = get_prompt(user_query, system_prompt, messages)
input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
return input_ids.shape[-1]
def generate_response(user_query: str, system_prompt: str, messages: str="", temperature: float=0, top_k: int=50, top_p: float=0.95, repetition_penalty: float=1.1):
prompt = get_prompt(user_query, system_prompt, messages)
inputs = tokenizer(
prompt,
add_special_tokens=False,
return_tensors="pt"
).to(device)
max_new_tokens = 2048 - get_input_token_length(user_query, system_prompt, messages)
model.eval()
with torch.no_grad():
tokens = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
return response
st.header(":dna: 遺伝カウンセリング対話AI")
# 初期化
model, tokenizer = load()
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "options" not in st.session_state:
st.session_state["options"] = {
"temperature": 0.0,
"top_k": 50,
"top_p": 0.95,
"repetition_penalty": 1.1,
"system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。
常に安全を考慮し、できる限り有益な回答を心がけてください。
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。
社会的に偏りのない、前向きな回答を心がけてください。
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。
質問の答えを知らない場合は、誤った情報を共有しないでください。"""}
# サイドバー
clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat")
st.sidebar.header("Options")
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"])
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])
# リセット
if clear_chat:
st.session_state["messages"] = []
# チャット履歴の表示
for message in st.session_state["messages"]:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# 現在のチャット
if user_prompt := st.chat_input("質問を送信してください"):
with st.chat_message("user"):
st.markdown(user_prompt)
st.session_state["messages"].append({"role": "user", "content": user_prompt})
response = generate_response(
user_query=user_prompt,
system_prompt=st.session_state["options"]["system_prompt"],
messages=st.session_state["messages"],
temperature=st.session_state["options"]["temperature"],
top_k=st.session_state["options"]["top_k"],
top_p=st.session_state["options"]["top_p"],
repetition_penalty=st.session_state["options"]["repetition_penalty"],
)
with st.chat_message("assistant"):
st.markdown(response)
st.session_state["messages"].append({"role": "assistant", "content": response})