import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# ⛰️⛰️ JAIS Initiative: Atlas-Chat-9B ⛰️⛰️

Disclaimer: This is a research demonstration of Atlas-Chat-9B and is not intended for end-user applications. Because the model is trained on diverse internet data, it may generate biased, offensive, or inaccurate content. The developers do not endorse any views expressed by the model and assume no responsibility for the consequences of its use. Users should critically evaluate the generated responses and use the tool at their own risk.

Note: The model takes input and generates output in Moroccan Darija, written in Arabic script.
"""

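# Generation limits; the input cap can be overridden via the
# MAX_INPUT_TOKEN_LENGTH environment variable.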
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))

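# Load the tokenizer and model once at startup; bfloat16 weights and
# device_map="auto" place the model on the available GPU.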
model_id = "MBZUAI-Paris/Atlas-Chat-9B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()


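# On ZeroGPU Spaces, `spaces.GPU` attaches a GPU to each call of this
# function for up to `duration` seconds.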
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

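    # Apply the model's chat template to the conversation and keep only the
    # most recent MAX_INPUT_TOKEN_LENGTH tokens if the prompt is too long.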
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

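    # Generate in a background thread so decoded text can be streamed back
    # through the TextIteratorStreamer while generation is still running.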
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

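    # Accumulate streamed chunks and yield the growing response so the UI
    # updates as tokens arrive.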
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


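# Chat UI: the controls below map one-to-one onto the keyword arguments of
# `generate`.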
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Checkbox(label="Do Sample"),
        gr.Slider(
            label="Temperature",
            minimum=0.0,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn=None,
    examples=[
        ["شكون لي صنعك؟"],  # "Who created you?"
        ["شنو كيتسمى المنتخب المغربي ؟"],  # "What is the Moroccan national team called?"
        ["أشنو كايمييز المملكة المغربية."],  # "What distinguishes the Kingdom of Morocco?"
        ["ترجم للدارجة:\nAtlas-Chat is the first open source large language model that talks in Darija."],  # "Translate to Darija: ..."
    ],
    cache_examples=False,
    type="messages",
)

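# Page layout: description, a "duplicate this Space" button, then the chat UI.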
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
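
# A minimal sketch of calling the running demo from Python. The local URL and
# the "/chat" endpoint name are assumptions, not defined in this app; adjust
# them to match your deployment:
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   reply = client.predict("شكون لي صنعك؟", api_name="/chat")
#   print(reply)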