import os import streamlit as st import torch from huggingface_hub import login from peft import PeftModel from transformers import AutoModelForCausalLM, LlamaTokenizer login(token=os.getenv("HUGGINGFACE_API_KEY")) device = "cuda" if torch.cuda.is_available() else "cpu" @st.cache_resource def load(): base_model = AutoModelForCausalLM.from_pretrained( "stabilityai/japanese-stablelm-instruct-alpha-7b", device_map="auto", low_cpu_mem_usage=True, variant="int8", load_in_8bit=True, trust_remote_code=True, ) model = PeftModel.from_pretrained( base_model, "lora_adapter", device_map="auto", ) tokenizer = LlamaTokenizer.from_pretrained( "novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁'] ) return model, tokenizer def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "): prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。" roles = ["指示", "応答"] msgs = [": \n" + user_query, ": "] if messages: roles.insert(1, "入力") msgs.insert(1, ": \n" + "\n\n".join(message["content"] for message in messages)) for role, msg in zip(roles, msgs): prompt += sep + role + msg return prompt def get_input_token_length(user_query, system_prompt, messages=""): prompt = get_prompt(user_query, system_prompt, messages) input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids'] return input_ids.shape[-1] def generate_response(user_query: str, system_prompt: str, messages: str="", temperature: float=0, top_k: int=50, top_p: float=0.95, repetition_penalty: float=1.1): prompt = get_prompt(user_query, system_prompt, messages) inputs = tokenizer( prompt, add_special_tokens=False, return_tensors="pt" ).to(device) max_new_tokens = 2048 - get_input_token_length(user_query, system_prompt, messages) model.eval() with torch.no_grad(): tokens = model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, ) response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip() return response st.header(":dna: 遺伝カウンセリング対話AI") # 初期化 model, tokenizer = load() if "messages" not in st.session_state: st.session_state["messages"] = [] if "options" not in st.session_state: st.session_state["options"] = { "temperature": 0.0, "top_k": 50, "top_p": 0.95, "repetition_penalty": 1.1, "system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。 常に安全を考慮し、できる限り有益な回答を心がけてください。 あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。 社会的に偏りのない、前向きな回答を心がけてください。 質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。 質問の答えを知らない場合は、誤った情報を共有しないでください。"""} # サイドバー clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat") st.sidebar.header("Options") st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"]) st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"]) st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"]) st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"]) st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"]) # リセット if clear_chat: st.session_state["messages"] = [] # チャット履歴の表示 for message in st.session_state["messages"]: with st.chat_message(message["role"]): st.markdown(message["content"]) # 現在のチャット if user_prompt := st.chat_input("質問を送信してください"): with st.chat_message("user"): st.text(user_prompt) st.session_state["messages"].append({"role": "user", "content": user_prompt}) response = generate_response( user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"], temperature=st.session_state["options"]["temperature"], top_k=st.session_state["options"]["top_k"], top_p=st.session_state["options"]["top_p"], repetition_penalty=st.session_state["options"]["repetition_penalty"], ) with st.chat_message("assistant"): st.text(response) st.session_state["messages"].append({"role": "assistant", "content": response})