Tonic committed on
Commit
1223061
·
unverified ·
1 Parent(s): c55665a
Files changed (1) hide show
  1. app.py +37 -21
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from datetime import datetime
5
 
 
6
  model_id = "BSC-LT/salamandra-2b-instruct"
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
@@ -10,9 +11,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_id,
12
  device_map="auto",
13
- torch_dtype=torch.bfloat16,
14
  )
15
 
 
 
 
 
16
  description = """
17
  Salamandra-2b-instruct is a Transformer-based decoder-only language model that has been pre-trained on 7.8 trillion tokens of highly curated data.
18
  The pre-training corpus contains text in 35 European languages and code. This instruction-tuned variant can be used as a general-purpose assistant.
@@ -27,36 +32,42 @@ On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Buil
27
  🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
28
  """
29
 
30
- def generate_text(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
31
  date_string = datetime.today().strftime('%Y-%m-%d')
32
- message = [{"role": "user", "content": prompt}]
 
 
 
33
 
34
  chat_prompt = tokenizer.apply_chat_template(
35
- message,
36
  tokenize=False,
37
  add_generation_prompt=True,
38
  date_string=date_string
39
  )
40
 
41
- inputs = tokenizer.encode(chat_prompt, add_special_tokens=False, return_tensors="pt")
 
42
 
43
  outputs = model.generate(
44
- input_ids=inputs.to(model.device),
45
  max_new_tokens=max_new_tokens,
46
  temperature=temperature,
47
  top_p=top_p,
48
  repetition_penalty=repetition_penalty,
49
- do_sample=True
 
 
50
  )
51
 
52
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
  return generated_text.split("assistant\n")[-1].strip()
54
 
55
- def update_output(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
56
- return generate_text(prompt, temperature, max_new_tokens, top_p, repetition_penalty)
57
 
58
  with gr.Blocks() as demo:
59
- gr.Markdown("# πŸ™‹πŸ»β€β™‚οΈ Welcome to Tonic's πŸ“²πŸ¦ŽSalamandra-2b-instruct Demo")
60
 
61
  with gr.Row():
62
  with gr.Column(scale=1):
@@ -66,8 +77,13 @@ with gr.Blocks() as demo:
66
 
67
  with gr.Row():
68
  with gr.Column(scale=1):
69
- prompt = gr.Textbox(lines=5, label="πŸ™‹β€β™‚οΈ Input Prompt")
70
- generate_button = gr.Button("Try πŸ“²πŸ¦ŽSalamandra-2b-instruct")
 
 
 
 
 
71
 
72
  with gr.Accordion("πŸ§ͺ Parameters", open=False):
73
  temperature = gr.Slider(0.0, 1.0, value=0.7, label="🌑️ Temperature")
@@ -76,24 +92,24 @@ with gr.Blocks() as demo:
76
  repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="πŸ” Repetition Penalty")
77
 
78
  with gr.Column(scale=1):
79
- output = gr.Textbox(lines=10, label="πŸ“²πŸ¦ŽSalamandra")
80
 
81
  generate_button.click(
82
  update_output,
83
- inputs=[prompt, temperature, max_new_tokens, top_p, repetition_penalty],
84
  outputs=output
85
  )
86
 
87
  gr.Examples(
88
  examples=[
89
- ["What are the main advantages of living in a big city like Barcelona?"],
90
- ["Explain the process of photosynthesis in simple terms."],
91
- ["What are some effective strategies for learning a new language?"],
92
- ["Describe the potential impacts of artificial intelligence on the job market in the next decade."],
93
- ["What are the key differences between renewable and non-renewable energy sources?"]
94
  ],
95
- inputs=prompt,
96
- outputs=prompt,
97
  label="Example Prompts"
98
  )
99
 
 
3
  import torch
4
  from datetime import datetime
5
 
6
+ # Model initialization
7
  model_id = "BSC-LT/salamandra-2b-instruct"
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_id,
13
  device_map="auto",
14
+ torch_dtype=torch.bfloat16
15
  )
16
 
17
+ # Set pad_token_id to eos_token_id if it's not set
18
+ if tokenizer.pad_token_id is None:
19
+ tokenizer.pad_token_id = tokenizer.eos_token_id
20
+
21
  description = """
22
  Salamandra-2b-instruct is a Transformer-based decoder-only language model that has been pre-trained on 7.8 trillion tokens of highly curated data.
23
  The pre-training corpus contains text in 35 European languages and code. This instruction-tuned variant can be used as a general-purpose assistant.
 
32
  🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
33
  """
34
 
35
+ def generate_text(system_prompt, user_prompt, temperature, max_new_tokens, top_p, repetition_penalty):
36
  date_string = datetime.today().strftime('%Y-%m-%d')
37
+ messages = [
38
+ {"role": "system", "content": system_prompt},
39
+ {"role": "user", "content": user_prompt}
40
+ ]
41
 
42
  chat_prompt = tokenizer.apply_chat_template(
43
+ messages,
44
  tokenize=False,
45
  add_generation_prompt=True,
46
  date_string=date_string
47
  )
48
 
49
+ inputs = tokenizer(chat_prompt, return_tensors="pt", padding=True, truncation=True)
50
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
51
 
52
  outputs = model.generate(
53
+ **inputs,
54
  max_new_tokens=max_new_tokens,
55
  temperature=temperature,
56
  top_p=top_p,
57
  repetition_penalty=repetition_penalty,
58
+ do_sample=True,
59
+ pad_token_id=tokenizer.pad_token_id,
60
+ eos_token_id=tokenizer.eos_token_id,
61
  )
62
 
63
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
  return generated_text.split("assistant\n")[-1].strip()
65
 
66
+ def update_output(system_prompt, user_prompt, temperature, max_new_tokens, top_p, repetition_penalty):
67
+ return generate_text(system_prompt, user_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
68
 
69
  with gr.Blocks() as demo:
70
+ gr.Markdown("# 🦎 Welcome to Tonic's Salamandra-2b-instruct Demo")
71
 
72
  with gr.Row():
73
  with gr.Column(scale=1):
 
77
 
78
  with gr.Row():
79
  with gr.Column(scale=1):
80
+ system_prompt = gr.Textbox(
81
+ lines=3,
82
+ label="πŸ–₯️ System Prompt",
83
+ value="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
84
+ )
85
+ user_prompt = gr.Textbox(lines=5, label="πŸ™‹β€β™‚οΈ User Prompt")
86
+ generate_button = gr.Button("Generate with 🦎 Salamandra-2b-instruct")
87
 
88
  with gr.Accordion("πŸ§ͺ Parameters", open=False):
89
  temperature = gr.Slider(0.0, 1.0, value=0.7, label="🌑️ Temperature")
 
92
  repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="πŸ” Repetition Penalty")
93
 
94
  with gr.Column(scale=1):
95
+ output = gr.Textbox(lines=10, label="🦎 Salamandra-2b-instruct Output")
96
 
97
  generate_button.click(
98
  update_output,
99
+ inputs=[system_prompt, user_prompt, temperature, max_new_tokens, top_p, repetition_penalty],
100
  outputs=output
101
  )
102
 
103
  gr.Examples(
104
  examples=[
105
+ ["You are a helpful assistant.", "What are the main advantages of living in a big city like Barcelona?"],
106
+ ["You are a biology teacher explaining concepts to students.", "Explain the process of photosynthesis in simple terms."],
107
+ ["You are a language learning expert.", "What are some effective strategies for learning a new language?"],
108
+ ["You are an AI and technology expert.", "Describe the potential impacts of artificial intelligence on the job market in the next decade."],
109
+ ["You are an environmental scientist.", "What are the key differences between renewable and non-renewable energy sources?"]
110
  ],
111
+ inputs=[system_prompt, user_prompt],
112
+ outputs=[system_prompt, user_prompt],
113
  label="Example Prompts"
114
  )
115