import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

DESCRIPTION = """\
# Llama backend
This is a demo of text completion with an LLM (Llama 3.1 8B).
Enter your text in the box below and click "Complete" to have the model generate a continuation of your input. The generated text is appended to your input. You can stop generation at any time by clicking the "Stop" button.
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
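# MAX_INPUT_TOKEN_LENGTH caps the prompt length; longer prompts are trimmed from the left in generate().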
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the model with 8-bit quantization so it fits on a single GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model.eval()
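
# @spaces.GPU requests a GPU for the duration of each call when the Space runs on ZeroGPU hardware.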
@spaces.GPU
def generate(
    message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.1,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    prompt = f"{message}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
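    # Sampling configuration for model.generate; the streamer collects decoded tokens as they are produced.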
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
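    # Run generation in a background thread so this generator can yield partial completions while tokens stream in.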
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = message
    for text in streamer:
        partial_message += text
        yield partial_message
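
# Gradio UI: a prompt textbox, Complete/Stop buttons, and sliders for the sampling parameters.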
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    with gr.Row():
        with gr.Column(scale=4):
            text_box = gr.Textbox(
                label="Enter your text",
                placeholder="Type your message here...",
                lines=5,
            )
        with gr.Column(scale=1):
            complete_button = gr.Button("Complete")
            stop_button = gr.Button("Stop")
    max_new_tokens = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=1.0,
        step=0.1,
        value=0.1,
    )
    top_p = gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.9,
    )
    top_k = gr.Slider(
        label="Top-k",
        minimum=1,
        maximum=100,
        step=1,
        value=50,
    )
    repetition_penalty = gr.Slider(
        label="Repetition penalty",
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.2,
    )
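    # generate is a generator, so Gradio streams each yielded string back into text_box.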
    # Set up the generation event
    generation_event = complete_button.click(
        generate,
        inputs=[text_box, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[text_box],
    )
    # Set up the stop event: the Stop button runs no function of its own;
    # cancels=[generation_event] simply aborts the in-flight generation.
    stop_button.click(
        None,
        None,
        None,
        cancels=[generation_event],
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch()