import os
import math
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import sentencepiece
title = "# Welcome to 🙋🏻♂️Tonic's🌷Tulu Chat!"
description = """[allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) and larger Tulu-2 models are Instruct Llama Finetunes using the [mistralai/Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) recipe. You can use [allenai/tulu-2-13b](https://huggingface.co/allenai/tulu-2-13b) here via API using Gradio by scrolling down and clicking Use 'Via API' or privately by [cloning this space on huggingface](https://huggingface.co/spaces/Tonic1/TuluDemo?duplicate=true) See also the large model here : [allenai/tulu-2-dpo-70b](https://huggingface.co/allenai/tulu-2-dpo-70b) . [Join my active builders' server on discord](https://discord.gg/VqTxc76K3u). Let's build together!. [Add this Space as a discord bot to your server by clicking this link](https://discord.com/oauth2/authorize?client_id=1176628808212828231&scope=bot+applications.commands&permissions=326417525824). Big thanks to 🤗Huggingface Organisation for the🫂Community Grant"""
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "allenai/tulu-2-dpo-13b"
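# Load the tokenizer and model; bfloat16 halves memory versus float32 and
# device_map="auto" lets accelerate place the weights across the available devices.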
tokenizer = AutoTokenizer.from_pretrained("allenai/tulu-2-dpo-13b")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# bos_token_id = 1
# eos_token_id = 2
# tokenizer.bos_token_id = bos_token_id
# tokenizer.eos_token_id = eos_token_id
# model.config.bos_token_id = bos_token_id
# model.config.eos_token_id = eos_token_id
# if tokenizer.pad_token is None:
# tokenizer.pad_token = tokenizer.eos_token
# model.config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
class TuluChatBot:
    def __init__(self, model, tokenizer, system_message="You are 🌷Tulu, an AI language model created by Tonic-AI. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message

    def set_system_message(self, new_system_message):
        self.system_message = new_system_message

    def format_prompt(self, user_message):
        # Tulu-style chat format: system turn, user turn, then the assistant tag the model completes after.
        prompt = f"<|system|>\n{self.system_message}\n<|user|>\n{user_message}\n<|assistant|>\n"
        return prompt
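    # Note: newer transformers releases can build this prompt from the tokenizer's bundled chat
    # template instead of hand-formatting it (a sketch, assuming the tulu-2 tokenizer ships one):
    #   messages = [{"role": "system", "content": self.system_message},
    #               {"role": "user", "content": user_message}]
    #   prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)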
    def Tulu(self, user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
        prompt = self.format_prompt(user_message)
        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=input_ids.shape[1] + max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample
        )
        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        response = response.strip()
        response = response.split("<|assistant|>\n")[-1]
        return response
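# A minimal sketch of using the chat bot directly, outside of Gradio (values are illustrative):
#   bot = TuluChatBot(model, tokenizer)
#   print(bot.Tulu("What can you do?", temperature=0.3, max_new_tokens=100,
#                  top_p=0.9, repetition_penalty=1.1, do_sample=True))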
# Initialize TuluChatBot
Tulu_bot = TuluChatBot(model, tokenizer)

# Gradio interface function (when do_sample is False, generation is greedy and the
# sampling parameters are effectively ignored by model.generate)
def gradio_Tulu(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
    Tulu_bot.set_system_message(system_message)
    response = Tulu_bot.Tulu(user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample)
    return response
with gr.Blocks(theme="ParityError/Anime") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        system_message = gr.Textbox(label="Optional 🌷Tulu Assistant Message", lines=2)
        user_message = gr.Textbox(label="Your Message", lines=3)
    with gr.Row():
        do_sample = gr.Checkbox(label="Advanced", value=True)
    # gr.Accordion expects a boolean for `open`; start it open to match the checkbox default.
    with gr.Accordion("Advanced Settings", open=True):
        with gr.Row():
            max_new_tokens = gr.Slider(label="Max new tokens", value=250, minimum=20, maximum=450, step=1)
            temperature = gr.Slider(label="Temperature", value=0.3, minimum=0.1, maximum=1.0, step=0.1)
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=0.9, minimum=0.05, maximum=1.0, step=0.05)
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="🌷Tulu Response")

    def process(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
        return gradio_Tulu(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample)

    submit_button.click(
        process,
        inputs=[user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
        outputs=output_text
    )

demo.launch()
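# The description above mentions calling this Space over its API. A minimal client-side sketch,
# assuming the gradio_client package is installed and the Space is running
# (fn_index=0 targets the single click handler defined here; arguments follow the `inputs` order):
#   from gradio_client import Client
#   client = Client("Tonic1/TuluDemo")
#   result = client.predict("Hello 🌷Tulu!", "", 250, 0.3, 0.9, 0.9, True, fn_index=0)
#   print(result)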