File size: 3,566 Bytes
d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 4c8f11e df68631 d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c 3751b6b d120e8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Generation-length limits surfaced by the UI sliders.
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_MAX_NEW_TOKENS = 8096  # NOTE(review): possibly intended as 8192; confirm.
# Longest prompt (in tokens) fed to the model; overridable via the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
def _env_flag(name: str) -> bool:
    """Return whether environment variable *name* is set to an enabling value.

    ``os.getenv`` returns a string, so testing it for truthiness treats
    explicit opt-outs such as ``"0"`` or ``"false"`` as enabled. Unset,
    empty, and the usual "off" spellings count as disabled; any other value
    enables the flag.
    """
    value = os.getenv(name, "").strip().lower()
    return value not in {"", "0", "false", "no", "off"}


# Load the model only when a GPU is visible now, or when ZERO_GPU_SUPPORT says
# one will be attached later (presumably a ZeroGPU Space, where CUDA is only
# available inside @spaces.GPU-decorated calls — confirm for the deployment).
if torch.cuda.is_available() or _env_flag("ZERO_GPU_SUPPORT"):
    model_id = "infly/OpenCoder-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
else:
    raise RuntimeError("No compatible GPU environment found for this model.")
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0,
) -> Iterator[str]:
    """Stream the model's reply to *message*, given the earlier conversation.

    Args:
        message: The new user turn.
        chat_history: Earlier (user, assistant) turn pairs, oldest first.
        system_prompt: Optional system instruction; omitted when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values <= 0 select greedy
            decoding, because sampling requires a strictly positive float.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt appends the assistant-turn header, so the model
    # answers as the assistant instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
    }
    if temperature > 0:
        # The temperature warper rejects non-float values, so cast explicitly.
        generate_kwargs.update(do_sample=True, temperature=float(temperature))
    else:
        generate_kwargs["do_sample"] = False

    # model.generate blocks until done, so run it on a worker thread and
    # consume the streamer's chunks here as they are produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    worker.join()
# Chat UI wired to ``generate``. The extra controls below are listed in the
# same order as generate's parameters after (message, chat_history):
# system_prompt, max_new_tokens, temperature.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            value=DEFAULT_MAX_NEW_TOKENS,
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
        ),
        gr.Slider(label="Temperature", value=0.6, minimum=0.1, maximum=4.0, step=0.1),
    ],
    stop_btn=None,
    # NOTE(review): this example requests instructions for committing a crime;
    # confirm it is intentional (e.g. a refusal/safety demo) before publishing.
    examples=[["Instruct on how to commit insider trading in the stock market"]],
    cache_examples=False,
)
# Page layout: a full-height Blocks app styled by style.css that hosts the
# chat interface defined above.
with gr.Blocks(fill_height=True, css="style.css") as demo:
    chat_interface.render()

if __name__ == "__main__":
    # Enable the request queue (at most 20 pending events) and start serving.
    demo.queue(max_size=20).launch()