File size: 3,552 Bytes
80956aa
d851055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80956aa
d851055
 
 
80956aa
d851055
 
 
 
80956aa
 
d851055
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import spaces

# Target languages offered in the UI dropdown. Each entry is lower-cased and
# wrapped in angle brackets to form the target-language tag of the prompt.
# Entries must be unique, otherwise the dropdown shows duplicates.
languages = [
    "English", "Chinese (Simplified)", "Chinese (Traditional)", "Spanish", "Arabic", "Hindi",
    "Bengali", "Portuguese", "Russian", "Japanese", "German", "French", "Urdu", "Indonesian",
    "Italian", "Turkish", "Korean", "Vietnamese", "Tamil", "Marathi", "Telugu", "Persian",
    "Polish", "Dutch", "Thai", "Gujarati", "Romanian", "Ukrainian", "Malay", "Kannada", "Oriya (Odia)",
    "Burmese (Myanmar)", "Azerbaijani", "Uzbek", "Kurdish (Kurmanji)", "Swedish", "Filipino (Tagalog)",
    "Serbian", "Czech", "Hungarian", "Greek", "Belarusian", "Bulgarian", "Hebrew", "Finnish",
    "Slovak", "Norwegian", "Danish", "Sinhala", "Croatian", "Lithuanian", "Slovenian", "Latvian",
    "Estonian", "Armenian", "Malayalam", "Georgian", "Mongolian", "Afrikaans", "Nepali", "Pashto",
    "Punjabi", "Kurdish", "Kyrgyz", "Somali", "Albanian", "Icelandic", "Basque", "Luxembourgish",
    "Macedonian", "Maltese", "Hawaiian", "Yoruba", "Maori", "Zulu", "Welsh", "Swahili", "Haitian Creole",
    "Lao", "Amharic", "Khmer", "Javanese", "Kazakh", "Malagasy", "Sindhi", "Sundanese", "Tajik", "Xhosa",
    "Yiddish", "Bosnian", "Cebuano", "Chichewa", "Corsican", "Esperanto", "Frisian", "Galician", "Hausa",
    "Hmong", "Igbo", "Irish", "Kinyarwanda", "Latin", "Samoan", "Scots Gaelic", "Sesotho", "Shona",
    "Sotho", "Uyghur",
]

# Load the Honyaku-13b tokenizer and half-precision (float16) weights.
tokenizer = AutoTokenizer.from_pretrained("aixsatoshi/Honyaku-13b")
model = AutoModelForCausalLM.from_pretrained("aixsatoshi/Honyaku-13b", torch_dtype=torch.float16)
# predict() moves its inputs to "cuda", so the weights must live there as well;
# left on the CPU, generate() fails with a tensors-on-different-devices error.
# Moving the model to CUDA at import time is the pattern used with ZeroGPU (@spaces.GPU).
model = model.to("cuda")

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Only the last token of the first sequence in the batch is inspected, so the
    criterion is meant for single-prompt generation.

    Args:
        stop_ids: Token ids that terminate generation. Defaults to ``(2,)`` —
            presumably the tokenizer's EOS id; TODO confirm for this tokenizer.
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        self.stop_ids = tuple(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of sequence 0 is one of ``stop_ids``."""
        return int(input_ids[0][-1]) in self.stop_ids
        
def _build_prompt(message, history, language):
    """Render the chat history plus the new message as one Honyaku prompt string.

    Every turn becomes ``"\\n<english>:" + source + "</english>\\n" + tag + translation``,
    where ``tag`` is ``"<" + language.lower() + ">"``. The new message is appended
    with an empty translation, so the prompt ends with the bare target tag.
    """
    tag = "<" + language.lower() + ">"
    turns = history + [[message, ""]]
    return "".join(
        "\n<english>:" + turn[0] + "</english>\n" + tag + turn[1]
        for turn in turns
    )


@spaces.GPU
def predict(message, history, tokens, temperature, language):
    """Stream a translation of ``message`` into ``language``.

    Args:
        message: Text submitted by the user; placed inside ``<english>`` tags.
        history: Previous turns from ``gr.ChatInterface``; each is indexed as
            ``[user_text, translation]`` (presumably the tuple-style history format).
        tokens: Upper bound for ``max_new_tokens``.
        temperature: Sampling temperature. A value ``<= 0`` switches to greedy
            decoding, because sampling rejects a non-positive temperature.
        language: Target language name; lower-cased to form the target tag.

    Yields:
        The translation accumulated so far, after each streamed text chunk.
    """
    prompt = _build_prompt(message, history, language)
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    temperature = float(temperature)
    if temperature > 0:
        decoding = dict(do_sample=True, temperature=temperature, top_p=0.95, top_k=20)
    else:
        # The UI slider allows 0.0; do_sample=True with temperature=0.0 raises a
        # ValueError inside the worker thread and the streamer would then time out.
        decoding = dict(do_sample=False)

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(tokens),
        repetition_penalty=1.15,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        **decoding,
    )
    # generate() runs in a background thread and feeds decoded text into the
    # streamer, which this generator consumes chunk by chunk.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        # Chunks that are exactly "<" are dropped — presumably to hide the start of
        # a tag the model may emit; TODO confirm intent.
        if new_text != "<":
            partial_message += new_text
            yield partial_message
    worker.join()

# Gradio chat UI. Each submitted message is handled by `predict`; the widgets in
# `additional_inputs` are passed after (message, history), in order:
# tokens, temperature, language.
demo = gr.ChatInterface(
    fn=predict,
    title="Honyaku-13b webui",
    description="Translate using Honyaku-13b model",
    additional_inputs=[
        gr.Slider(100, 4096, value=1000, label="Tokens"),
        gr.Slider(0.0, 1.0, value=0.3, label="Temperature"),
        gr.Dropdown(choices=languages, value="Japanese", label="Language"),
    ],
)

demo.queue().launch()