File size: 5,153 Bytes
59812f5 141ba59 e880564 c86c2f3 d2d3f64 c86c2f3 141ba59 c86c2f3 4522cd0 59812f5 4522cd0 141ba59 4522cd0 2805de5 4522cd0 e6dd388 2805de5 e6dd388 c86c2f3 09b3f75 c86c2f3 1827259 141ba59 f00b880 141ba59 c86c2f3 d2d3f64 4522cd0 c86c2f3 e880564 c86c2f3 141ba59 54995d2 6bc8e25 54995d2 141ba59 54995d2 141ba59 c86c2f3 141ba59 c86c2f3 141ba59 b4ca5ac 141ba59 09b3f75 c86c2f3 141ba59 09b3f75 c86c2f3 4522cd0 c86c2f3 141ba59 09b3f75 c86c2f3 141ba59 09b3f75 c86c2f3 4522cd0 c86c2f3 141ba59 2805de5 141ba59 1827259 141ba59 e6dd388 89f9579 e880564 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
from threading import Thread
from typing import Iterator, List, Tuple
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Upper bound offered by the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 2_048
# Initial position of that slider.
DEFAULT_MAX_NEW_TOKENS = 1_024
# Prompts longer than this many tokens are left-trimmed (only the most recent
# tokens are kept) before generation. Overridable via the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Markdown header rendered above the chat UI. Paragraphs are separated by blank
# lines; without them Markdown merges consecutive lines into one paragraph.
DESCRIPTION = """\
# Llama-2 7B Chat

This Space demonstrates model [llama-2-7b-bics-multi_woz_v22](https://huggingface.co/stevugnin/llama-2-7b-bics-multi_woz_v22) by University of Luxembourg FSTM, a Llama 2 model with 7B parameters fine-tuned for multi-domain customer support chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).

🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
"""
# Markdown footer rendered below the chat UI. Under CommonMark rules a line
# starting with `<p/>` opens a raw HTML block that only ends at a blank line,
# so the blank line after it is required for the rule and links to render.
LICENSE = """
<p/>

---
As a derivative work of [llama-2-7b-bics-multi_woz_v22](https://huggingface.co/stevugnin/llama-2-7b-bics-multi_woz_v22) by University of Luxembourg FSTM,
this demo is governed by the original [license](https://huggingface.co/spaces/stevugnin/multi-domain-customer-support-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/stevugnin/multi-domain-customer-support-chat/blob/main/USE_POLICY.md).
"""
if torch.cuda.is_available():
    # Load the fine-tuned checkpoint in float16; device_map="auto" lets
    # transformers decide where the weights are placed.
    model_id = "stevugnin/llama-2-7b-bics-multi_woz_v22"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Presumably disables the tokenizer's built-in default system prompt in
    # its chat template — confirm against the transformers version in use.
    tokenizer.use_default_system_prompt = False
else:
    # Nothing is loaded without CUDA, so warn visitors in the page header.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
def _build_conversation(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
) -> List[dict]:
    """Assemble the role/content message list consumed by ``apply_chat_template``.

    The system prompt (only if non-empty) comes first, then every past
    (user, assistant) exchange in order, and finally the new user ``message``.
    """
    conversation: List[dict] = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_turn, assistant_turn in chat_history:
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": assistant_turn})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior chat.

    Args:
        message: The new user message.
        chat_history: Earlier (user, assistant) exchanges, oldest first.
        system_prompt: Optional system instruction; skipped when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The cumulative reply text after each streamed chunk, so the UI can
        re-render the growing answer.

    Raises:
        gr.Error: If the model was never loaded (the module only loads it
            when CUDA is available).
    """
    if "model" not in globals() or "tokenizer" not in globals():
        # `model`/`tokenizer` are only defined by the module-level CUDA gate;
        # surface a readable UI error instead of a bare NameError.
        raise gr.Error("This demo requires a GPU; the model is not loaded on CPU.")

    conversation = _build_conversation(message, chat_history, system_prompt)
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt keeps the prompt out of the stream; the timeout makes the
    # consuming loop raise instead of waiting forever if generation stalls.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        # Gradio documents Slider inputs as floats; these arguments are
        # integer counts in transformers, so coerce them explicitly.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=top_p,
        top_k=int(top_k),
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until finished, so run it on a worker thread and
    # consume decoded text from the streamer here as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    chunks: List[str] = []
    for text in streamer:
        chunks.append(text)
        yield "".join(chunks)
    # The streamer is exhausted once generation ends; reap the worker thread.
    worker.join()
def _build_additional_inputs():
    """Create the extra controls that ChatInterface forwards to ``generate``.

    The list order matches the ``generate`` parameters that follow
    ``chat_history``: system prompt, max new tokens, temperature, top-p,
    top-k, repetition penalty.
    """
    system_prompt_box = gr.Textbox(label="System prompt", lines=6)
    max_tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
    top_p_slider = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
    top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
    repetition_slider = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
    return [
        system_prompt_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        top_k_slider,
        repetition_slider,
    ]


# Sample customer-support openers offered as clickable examples.
_EXAMPLE_MESSAGES = (
    "Hi there! Can you give me some info on Cityroomz?",
    "I am looking for a restaurant. I would like something cheap that has Chinese food.",
    "Please find me a train from Cambridge to Stansted airport.",
    "I'd like a train from Leicester to Cambridge, please!",
    "Hi, I'm traveling to Cambridge soon and am looking forward to seeing some local tourist attractions.",
)

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_build_additional_inputs(),
    stop_btn=None,
    # Each example row only fills the message box.
    examples=[[prompt] for prompt in _EXAMPLE_MESSAGES],
)
# Page layout, top to bottom: header text, duplicate button, chat UI, license
# footer. "style.css" is resolved as a path — presumably shipped with the Space.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(value=DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )
    chat_interface.render()
    gr.Markdown(value=LICENSE)
if __name__ == "__main__":
    # Cap the event queue at 20 pending requests, then start the server.
    # share=True asks Gradio for a public share link when run locally.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(share=True)
|