import os
from threading import Thread
from typing import Iterator
import queue

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# magro-7b

Japanese-language AI
"""

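# Generation limits: a hard ceiling on new tokens, the UI default, and an
# input-token budget that can be overridden via the MAX_INPUT_TOKEN_LENGTH env var.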
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

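# Move generation inputs to the first CUDA device when one is available.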
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

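# Load the tokenizer and fp16 weights; device_map="auto" lets accelerate
# choose placement across the available device(s).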
model_id = "Sakalti/magro-7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.config.sliding_window = 4096
model.eval()

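# On Hugging Face Spaces, @spaces.GPU requests a GPU for the duration of each call.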
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,  # keep in sync with the Top-p slider default below
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = chat_history + [{"role": "user", "content": message}]

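    # Render the running conversation with the model's chat template, then
    # left-truncate so only the most recent MAX_INPUT_TOKEN_LENGTH tokens remain.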
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(device)

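    # Generate on a background thread and stream decoded text through
    # TextIteratorStreamer; skip_prompt drops the echoed input, and timeout
    # bounds how long each read waits for the next chunk.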
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

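    # Accumulate streamed chunks and re-yield the growing reply so the UI
    # updates incrementally.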
    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
    except queue.Empty:
        # The streamer timed out waiting for the next chunk; surface the partial output.
        gr.Warning("The generation process timed out.")
        yield "".join(outputs)

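# type="messages" hands generate() the history as a list of
# {"role": ..., "content": ...} dicts, matching the chat-template input above.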
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
    additional_inputs_accordion=gr.Accordion(label="Advanced settings", open=False),
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello, please introduce yourself."],
        ["Please write a poem about machine learning."],
        ["Is the C language difficult?"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()
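
# Local usage sketch (assumption: outside Spaces the `spaces` GPU decorator is
# expected to be a no-op, so the script should run wherever the dependencies
# are installed):
#   pip install torch transformers accelerate gradio spaces
#   python app.py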