Spaces:
Runtime error
Runtime error
File size: 3,396 Bytes
829da7c 31a1ff8 e2b5fc2 829da7c 54fe16b 829da7c 54fe16b 829da7c 54fe16b 31a1ff8 54fe16b 829da7c 54fe16b 829da7c 53b40bf 54fe16b 829da7c 4856892 829da7c 54fe16b 4856892 54fe16b 829da7c 54fe16b 829da7c 54fe16b 994685c 54fe16b 53b40bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import argparse
import os
import spaces
import gradio as gr
import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Longest prompt, in tokens, handed to the model; longer prompts are cut
# down to their most recent MAX_LENGTH tokens before generation.
MAX_LENGTH: int = 4_096
# Default generation budget (max_new_tokens) for one assistant reply.
DEFAULT_MAX_NEW_TOKENS: int = 1_024
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is exactly
            what the previous zero-argument version did.

    Returns:
        argparse.Namespace with ``base_model`` (str, or None when the flag
        is absent) and ``n_gpus`` (int, default 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, help="Model path.")
    parser.add_argument("--n_gpus", type=int, default=1, help="Number of GPUs.")
    return parser.parse_args(argv)
def _build_messages(message, history, system_prompt):
    """Build the chat-template message list for one request.

    Args:
        message: The newest user message.
        history: Earlier turns, either as ``(user, assistant)`` pairs or as
            dicts carrying ``"role"`` and ``"content"`` keys. In pairs, a
            ``None`` side is skipped.
        system_prompt: Content of the leading system message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system message,
        then the history in order, then ``message`` as the final user turn.
    """
    messages = [{'role': 'system', 'content': system_prompt}]
    for turn in history:
        if isinstance(turn, dict):
            # Messages-format history: copy only the fields the template needs.
            messages.append({'role': turn['role'], 'content': turn['content']})
            continue
        human, assistant = turn
        if human is not None:
            messages.append({'role': 'user', 'content': human})
        if assistant is not None:
            messages.append({'role': 'assistant', 'content': assistant})
    messages.append({'role': 'user', 'content': message})
    return messages


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream the model's reply to ``message`` for gr.ChatInterface.

    Args:
        message: The newest user message.
        history: Earlier conversation turns as passed by the chat UI.
        system_prompt: Text of the system message placed first in the prompt.
        temperature: Sampling temperature. Values <= 0 (or None) select
            greedy decoding, because transformers rejects a non-positive
            temperature when sampling.
        max_tokens: Maximum number of new tokens to generate; ``None`` falls
            back to DEFAULT_MAX_NEW_TOKENS.

    Yields:
        The reply text accumulated so far, growing as tokens are streamed.

    Uses the module-level ``model``, ``tokenizer`` and ``device`` that the
    ``__main__`` block creates.
    """
    messages = _build_messages(message, history, system_prompt)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # A single prompt needs no padding. Right-truncation is also avoided:
    # it would drop the newest turn and the generation prompt at the end.
    enc = tokenizer(prompt, return_tensors="pt")
    # Keep only the most recent MAX_LENGTH tokens. Slice the attention mask
    # identically so both tensors keep the same shape.
    input_ids = enc.input_ids[:, -MAX_LENGTH:].to(device)
    attention_mask = enc.attention_mask[:, -MAX_LENGTH:].to(device)

    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
    max_new_tokens = DEFAULT_MAX_NEW_TOKENS if max_tokens is None else int(max_tokens)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        eos_token_id=100278,  # <|im_end|>
    )
    if temperature is not None and temperature > 0:
        generate_kwargs.update(do_sample=True, top_p=0.95, temperature=float(temperature))
    else:
        # The temperature slider's range starts at 0; treat that as greedy.
        generate_kwargs["do_sample"] = False

    # generate() runs in a worker thread while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has finished, so this is quick.
    worker.join()
if __name__ == "__main__":
    args = parse_args()
    # Checkpoint loaded when --base_model is not supplied on the command line.
    default_model_id = "stabilityai/stablelm-2-12b-chat"
    model_id = args.base_model or default_model_id
    # NOTE(review): args.n_gpus is parsed but not used anywhere in this file.
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # checkpoint; only point --base_model at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # NOTE(review): retry_btn / undo_btn / clear_btn are Gradio 4.x ChatInterface
    # parameters that later Gradio releases dropped; confirm against the pinned
    # gradio version.
    gr.ChatInterface(
        predict,
        title="StableLM 2 12B Chat - Demo",
        description="StableLM 2 12B Chat - StabilityAI",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to predict() positionally after (message, history):
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        examples=[
            ["Implement snake game using pygame"],
            ["Escreva um poema sobre a saudade."],
            ["Ecris une prose a propos de la mer du Nord"],
            ["Schreibe ein Haiku ueber die Alpen"],
            ["What's been the role of music in human societies?"],
            ["How to become a good programmer?"]
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()