File size: 5,633 Bytes
8c97efe
 
0bf9df3
8c97efe
 
 
0bf9df3
 
8c97efe
 
0bf9df3
 
 
b3f598f
 
 
 
 
 
 
 
 
 
 
 
 
0bf9df3
9f5fd0a
 
0bf9df3
 
 
20734fc
 
 
 
 
 
9213e75
20734fc
 
 
 
 
 
 
 
 
 
8c97efe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06ad8da
 
0bf9df3
 
 
 
6d7bada
 
20734fc
6d7bada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18b3695
 
0bf9df3
 
 
 
 
 
 
6d7bada
 
 
 
 
 
 
 
 
 
 
 
e2e00ab
9213e75
8c97efe
a6ae1be
8c97efe
 
 
 
 
 
 
6d7bada
e2e00ab
42b1d93
6d7bada
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os

import streamlit as st
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaTokenizer

login(token=os.getenv("HUGGINGFACE_API_KEY"))
device = "cuda" if torch.cuda.is_available() else "cpu"

@st.cache_resource
def load():
    base_model = AutoModelForCausalLM.from_pretrained(
        "stabilityai/japanese-stablelm-instruct-alpha-7b", 
        device_map="auto",
        low_cpu_mem_usage=True,
        variant="int8",
        load_in_8bit=True,
        trust_remote_code=True,
        )
    model = PeftModel.from_pretrained(
        base_model, 
        "lora_adapter",
        device_map="auto",
        )
    tokenizer = LlamaTokenizer.from_pretrained(
        "novelai/nerdstash-tokenizer-v1",
         additional_special_tokens=['▁▁']
        )
    return model, tokenizer

def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
    prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": "]
    if messages:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + "\n\n".join(message["content"] for message in messages))

    for role, msg in zip(roles, msgs):
        prompt += sep + role + msg
    return prompt

def get_input_token_length(user_query, system_prompt, messages=""):
    prompt = get_prompt(user_query, system_prompt, messages)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]

def generate_response(user_query: str, system_prompt: str, messages: str="", temperature: float=0, top_k: int=50, top_p: float=0.95, repetition_penalty: float=1.1):
    prompt = get_prompt(user_query, system_prompt, messages)
    inputs = tokenizer(
        prompt, 
        add_special_tokens=False, 
        return_tensors="pt"
        ).to(device)
    max_new_tokens = 2048 - get_input_token_length(user_query, system_prompt, messages)
    model.eval()
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens, 
            temperature=temperature, 
            top_k=top_k, 
            top_p=top_p,
            repetition_penalty=repetition_penalty, 
            )
    response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    return response


st.header(":dna: 遺伝カウンセリング対話AI")


# 初期化
model, tokenizer = load()
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "options" not in st.session_state:
    st.session_state["options"] = {
        "temperature": 0.0, 
        "top_k": 50, 
        "top_p": 0.95, 
        "repetition_penalty": 1.1,
        "system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。
常に安全を考慮し、できる限り有益な回答を心がけてください。
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。
社会的に偏りのない、前向きな回答を心がけてください。
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。
質問の答えを知らない場合は、誤った情報を共有しないでください。"""}

# サイドバー
clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat")

st.sidebar.header("Options")
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"])
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])

# リセット
if clear_chat:
    st.session_state["messages"] = [] 

# チャット履歴の表示
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# 現在のチャット
if user_prompt := st.chat_input("質問を送信してください"):
    with st.chat_message("user"):
        st.markdown(user_prompt)
    st.session_state["messages"].append({"role": "user", "content": user_prompt})
    response = generate_response(
        user_query=user_prompt, 
        system_prompt=st.session_state["options"]["system_prompt"],
        messages=st.session_state["messages"],
        temperature=st.session_state["options"]["temperature"],
        top_k=st.session_state["options"]["top_k"],
        top_p=st.session_state["options"]["top_p"],
        repetition_penalty=st.session_state["options"]["repetition_penalty"],
        )
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state["messages"].append({"role": "assistant", "content": response})