File size: 5,633 Bytes
8c97efe 0bf9df3 8c97efe 0bf9df3 8c97efe 0bf9df3 b3f598f 0bf9df3 9f5fd0a 0bf9df3 20734fc 9213e75 20734fc 8c97efe 06ad8da 0bf9df3 6d7bada 20734fc 6d7bada 18b3695 0bf9df3 6d7bada e2e00ab 9213e75 8c97efe a6ae1be 8c97efe 6d7bada e2e00ab 42b1d93 6d7bada |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import streamlit as st
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaTokenizer
login(token=os.getenv("HUGGINGFACE_API_KEY"))
device = "cuda" if torch.cuda.is_available() else "cpu"
@st.cache_resource
def load():
base_model = AutoModelForCausalLM.from_pretrained(
"stabilityai/japanese-stablelm-instruct-alpha-7b",
device_map="auto",
low_cpu_mem_usage=True,
variant="int8",
load_in_8bit=True,
trust_remote_code=True,
)
model = PeftModel.from_pretrained(
base_model,
"lora_adapter",
device_map="auto",
)
tokenizer = LlamaTokenizer.from_pretrained(
"novelai/nerdstash-tokenizer-v1",
additional_special_tokens=['▁▁']
)
return model, tokenizer
def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
roles = ["指示", "応答"]
msgs = [": \n" + user_query, ": "]
if messages:
roles.insert(1, "入力")
msgs.insert(1, ": \n" + "\n\n".join(message["content"] for message in messages))
for role, msg in zip(roles, msgs):
prompt += sep + role + msg
return prompt
def get_input_token_length(user_query, system_prompt, messages=""):
prompt = get_prompt(user_query, system_prompt, messages)
input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
return input_ids.shape[-1]
def generate_response(user_query: str, system_prompt: str, messages: str="", temperature: float=0, top_k: int=50, top_p: float=0.95, repetition_penalty: float=1.1):
prompt = get_prompt(user_query, system_prompt, messages)
inputs = tokenizer(
prompt,
add_special_tokens=False,
return_tensors="pt"
).to(device)
max_new_tokens = 2048 - get_input_token_length(user_query, system_prompt, messages)
model.eval()
with torch.no_grad():
tokens = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
return response
st.header(":dna: 遺伝カウンセリング対話AI")
# 初期化
model, tokenizer = load()
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "options" not in st.session_state:
st.session_state["options"] = {
"temperature": 0.0,
"top_k": 50,
"top_p": 0.95,
"repetition_penalty": 1.1,
"system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。
常に安全を考慮し、できる限り有益な回答を心がけてください。
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。
社会的に偏りのない、前向きな回答を心がけてください。
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。
質問の答えを知らない場合は、誤った情報を共有しないでください。"""}
# サイドバー
clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat")
st.sidebar.header("Options")
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"])
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])
# リセット
if clear_chat:
st.session_state["messages"] = []
# チャット履歴の表示
for message in st.session_state["messages"]:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# 現在のチャット
if user_prompt := st.chat_input("質問を送信してください"):
with st.chat_message("user"):
st.markdown(user_prompt)
st.session_state["messages"].append({"role": "user", "content": user_prompt})
response = generate_response(
user_query=user_prompt,
system_prompt=st.session_state["options"]["system_prompt"],
messages=st.session_state["messages"],
temperature=st.session_state["options"]["temperature"],
top_k=st.session_state["options"]["top_k"],
top_p=st.session_state["options"]["top_p"],
repetition_penalty=st.session_state["options"]["repetition_penalty"],
)
with st.chat_message("assistant"):
st.markdown(response)
st.session_state["messages"].append({"role": "assistant", "content": response})
|