import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datetime import datetime

model_id = "BSC-LT/salamandra-2b-instruct"

# Load the tokenizer and model once at startup; device_map="auto" places the
# bfloat16 weights on GPU when one is available and falls back to CPU otherwise.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

description = """
Salamandra-2b-instruct is a Transformer-based decoder-only language model that has been pre-trained on 7.8 trillion tokens of highly curated data. 
The pre-training corpus contains text in 35 European languages and code. This instruction-tuned variant can be used as a general-purpose assistant.
"""

join_us = """
## Join us:
🌟TeamTonic🌟 is always making cool demos! Join our active builders' 🛠️ community 👻 
[![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) 
On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer) 
On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)
🤗 Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""

def generate_text(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    # The chat template takes a date_string variable, so pass today's date.
    date_string = datetime.today().strftime('%Y-%m-%d')
    message = [{"role": "user", "content": prompt}]

    chat_prompt = tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=True,
        date_string=date_string
    )

    # The rendered template already contains the special tokens, so don't add them again.
    inputs = tokenizer.encode(chat_prompt, add_special_tokens=False, return_tensors="pt")

    outputs = model.generate(
        input_ids=inputs.to(model.device),
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True
    )

    # Decode only the newly generated tokens, slicing off the prompt portion.
    generated_text = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return generated_text.strip()

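# Thin passthrough used as the button's click callback.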
def update_output(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    return generate_text(prompt, temperature, max_new_tokens, top_p, repetition_penalty)

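# Gradio UI: model description and community links up top, prompt and
# generation parameters on the left, generated text on the right.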
with gr.Blocks() as demo:
    gr.Markdown("# 🙋🏻‍♂️ Welcome to Tonic's 📲🦎Salamandra-2b-instruct Demo")
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(description)
        with gr.Column(scale=1):
            gr.Markdown(join_us)
    
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(lines=5, label="🙋‍♂️ Input Prompt")
            generate_button = gr.Button("Try 📲🦎Salamandra-2b-instruct")
            
            with gr.Accordion("🧪 Parameters", open=False):
                temperature = gr.Slider(0.1, 1.0, value=0.7, label="🌡️ Temperature")  # strictly positive minimum avoids a sampling error at 0
                max_new_tokens = gr.Slider(1, 1000, value=200, step=1, label="🔢 Max New Tokens")
                top_p = gr.Slider(0.0, 1.0, value=0.95, label="⚛️ Top P")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="🔁 Repetition Penalty")
        
        with gr.Column(scale=1):
            output = gr.Textbox(lines=10, label="📲🦎Salamandra")
    
    generate_button.click(
        update_output,
        inputs=[prompt, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output
    )
    
    gr.Examples(
        examples=[
            ["What are the main advantages of living in a big city like Barcelona?"],
            ["Explain the process of photosynthesis in simple terms."],
            ["What are some effective strategies for learning a new language?"],
            ["Describe the potential impacts of artificial intelligence on the job market in the next decade."],
            ["What are the key differences between renewable and non-renewable energy sources?"]
        ],
        inputs=prompt,
        label="Example Prompts"
    )

if __name__ == "__main__":
    demo.launch()