File size: 3,785 Bytes
51a7d9e 7cf1091 51a7d9e 7cf1091 51a7d9e 289c0ee 51a7d9e fd6304d 51a7d9e fd6304d 3b9cb87 3f4dd7a 3b9cb87 639e063 51a7d9e 997f90e 51a7d9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os
import time
# Optional Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub repository id of the chat model, and its short name (text after the last "/").
MODEL_ID = "CohereForAI/aya-23-8B"
MODEL_NAME = MODEL_ID.rsplit("/", 1)[-1]

# HTML fragments shown at the top of the page.
TITLE = "<h1><center>Aya-23-Chatbox</center></h1>"
DESCRIPTION = (
    f'<h3><center>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></center></h3>'
)

# Custom stylesheet for elements tagged with the "duplicate-button" class.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# --- Quantization / loading switches ---------------------------------------
QUANTIZE_4BIT = True
# NOTE(review): the four settings below are not referenced anywhere in the
# lines visible here; they look like leftovers from a training script — confirm
# before removing.
USE_GRAD_CHECKPOINTING = True
TRAIN_BATCH_SIZE = 2
TRAIN_MAX_SEQ_LENGTH = 512
USE_FLASH_ATTENTION = False
GRAD_ACC_STEPS = 16

# Build the bitsandbytes config only when 4-bit loading is requested;
# otherwise leave it as None so the model is loaded unquantized.
quantization_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    if QUANTIZE_4BIT
    else None
)

# None keeps the library's default attention implementation.
attn_implementation = "flash_attention_2" if USE_FLASH_ATTENTION else None

# Load the causal LM in bfloat16 and let `device_map="auto"` decide placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    attn_implementation=attn_implementation,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Tokenizer from the same repository as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int) -> str:
    """Generate the assistant's reply to ``message`` given the prior chat turns.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user_prompt, assistant_answer)`` pairs.
        temperature: Sampling temperature. A value of 0 (or below) selects
            greedy decoding, because sampling requires a strictly positive
            temperature and would otherwise raise inside ``generate``.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded reply as a single string, with the prompt tokens and
        special tokens removed.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')

    # Flatten the (user, assistant) pairs into the role/content message list
    # consumed by the tokenizer's chat template, then add the new user turn.
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})
    print(f"Conversation is -\n{conversation}")

    input_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    # The generated sequence starts with the prompt; remember its length so
    # only the newly generated tokens are decoded.
    prompt_len = input_ids.shape[-1]

    # int() guards against the UI delivering the slider value as a float.
    generation_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # The temperature slider allows 0; fall back to greedy decoding
        # instead of passing an invalid temperature to the sampler.
        generation_kwargs["do_sample"] = False

    gen_tokens = model.generate(input_ids, **generation_kwargs)

    # A single conversation is generated per call, so decode that one sequence
    # and return plain text (the chat UI expects a string, not a list).
    reply_tokens = gen_tokens[0][prompt_len:]
    return tokenizer.decode(reply_tokens, skip_special_tokens=True)
# Chat display component, created up front and handed to the ChatInterface.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Page header: title, model link, and a button to duplicate the Space.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Extra generation controls. They are built with render=False so the
    # ChatInterface decides where they are placed.
    params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.8,
        label="Temperature",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=1024,
        label="Max new tokens",
        render=False,
    )

    # Chat UI wired to stream_chat; the slider values are passed to it as the
    # arguments following (message, history).
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=params_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )


# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|