import argparse
import os
from queue import SimpleQueue
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from gradio import Chatbot
from StreamDiffusionIO import LatentConsistencyModelStreamIO
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          TextIteratorStreamer)

from image_utils import ImageStitcher
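
# Generation limits; MAX_INPUT_TOKEN_LENGTH can be overridden via an
# environment variable of the same name.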
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\ | |
# Kanji-Streaming Chat | |
π This Space is adapted from [Llama-2-7b-chat](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat) space, demonstrating how to "chat" with LLM with [Kanji-Streaming](https://github.com/AgainstEntropy/kanji). | |
π¨ The technique behind Kanji-Streaming is [StreamDiffusionIO](https://github.com/AgainstEntropy/StreamDiffusionIO), which is based on [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion), *but especially allows to render text streams into image streams*. | |
π For more details about Kanji-Streaming, take a look at the [github repository](https://github.com/AgainstEntropy/kanji). | |
""" | |
LICENSE = """ | |
<p/> | |
--- | |
As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, | |
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md). | |
""" | |
parser = argparse.ArgumentParser(description="Gradio launcher for Kanji-Streaming.")
parser.add_argument(
    "--llama_model_id_or_path",
    type=str,
    default="meta-llama/Llama-2-7b-chat-hf",
    required=False,
    help="Path to a local llama-chat-hf model or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--sd_model_id_or_path",
    type=str,
    default="runwayml/stable-diffusion-v1-5",
    required=False,
    help="Path to a local sd-1-5 model or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default="AgainstEntropy/kanji-lora-sd-v1-5",
    required=False,
    help="Path to local LoRA weights or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lcm_lora_path",
    type=str,
    default="AgainstEntropy/kanji-lcm-lora-sd-v1-5",
    required=False,
    help="Path to local LCM-LoRA weights or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--img_res",
    type=int,
    default=64,
    required=False,
    help="Resolution at which each Kanji character is displayed in the chatbot.",
)
parser.add_argument(
    "--img_per_line",
    type=int,
    default=16,
    required=False,
    help="Number of Kanji characters to display per line.",
)
parser.add_argument(
    "--tmp_dir",
    type=str,
    default="./tmp",
    required=False,
    help="Directory for temporary images generated by StreamDiffusionIO.",
)
args = parser.parse_args()
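
# Prefer GPU when available; CPU generation works but is very slow.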
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo works best on GPU.</p>"
    DESCRIPTION += "\n<p>For the best Kanji-streaming experience, run this demo on localhost (or through SSH port forwarding) rather than via a Gradio share link.</p>"
model = AutoModelForCausalLM.from_pretrained(
    args.llama_model_id_or_path,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(args.llama_model_id_or_path)
tokenizer.use_default_system_prompt = False
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
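
# Diffusion stream that renders incoming text chunks as Kanji images.
# The Kanji LoRA is applied at weight 1 on top of the LCM-LoRA, which
# enables few-step latent-consistency sampling.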
lcm_stream = LatentConsistencyModelStreamIO(
    model_id_or_path=args.sd_model_id_or_path,
    lcm_lora_path=args.lcm_lora_path,
    lora_dict={args.lora_path: 1},
    resolution=128,
    device=device,
    use_xformers=True,
    verbose=True,
)
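
# Stitch per-character images into one growing picture per response;
# each response gets its own temporary directory.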
tmp_dir_template = f"{args.tmp_dir}/%d"
response_num = 0
stitcher = ImageStitcher(
    tmp_dir=tmp_dir_template % response_num,
    img_res=args.img_res,
    img_per_line=args.img_per_line,
    verbose=True,
)

# Request ZeroGPU hardware for each call; the decorator is a no-op off
# Spaces and is the reason `spaces` is imported above.
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    show_original_response: bool,
    seed: int,
    system_prompt: str = '',
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
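    """Stream a chat response as progressively stitched Kanji images.

    Yields paths to the stitched image (paired with the text decoded so
    far once generation finishes) and, if requested, the full plain-text
    response at the end.
    """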
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        if isinstance(assistant, tuple):
            assistant = assistant[1]
        else:
            assistant = str(assistant)
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})
    print(conversation)
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
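
    # Run LLM generation in a background thread; tokens arrive through
    # the TextIteratorStreamer as they are produced.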
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    prompt_queue = SimpleQueue()

    lcm_stream.reset(seed)
    stitcher.reset()
    global response_num
    response_num += 1
    stitcher.update_tmp_dir(tmp_dir_template % response_num)
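
    # Producer: push cleaned text chunks from the streamer into a queue,
    # ending with a None sentinel once generation finishes.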
    def append_to_queue():
        for text in streamer:
            outputs.append(text)
            prompt = text.strip()
            if prompt:
                prompt_queue.put(prompt.removesuffix("."))
        prompt_queue.put(None)

    append_thread = Thread(target=append_to_queue)
    append_thread.start()
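
    # Consumer: feed each prompt to the diffusion stream; whenever a new
    # character image is ready, stitch it in and return the updated image.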
    def show_image(prompt: str | None = None):
        image, text = lcm_stream(prompt)
        img_path = None
        if image is not None:
            img_path = stitcher.add(image, text)
        return img_path
    while True:
        prompt = prompt_queue.get()
        if prompt is None:
            break
        img_path = show_image(prompt)
        if img_path is not None:
            yield (img_path,)

    # Continue displaying the remaining images until the stream stops.
    while True:
        img_path = show_image()
        if img_path is not None:
            yield (img_path, ''.join(outputs))
        if lcm_stream.stop():
            break
    print(outputs)

    if show_original_response:
        yield ''.join(outputs)
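
# Chat UI; every control listed in `additional_inputs` is passed to
# `generate` positionally, after the message and the chat history.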
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=Chatbot(height=400),
    additional_inputs=[
        gr.Checkbox(
            label="Show original response",
            value=False,
        ),
        gr.Number(
            label="Seed",
            info="Random seed for Kanji generation (maybe some kind of accent)",
            step=1,
            value=1026,
        ),
        gr.Textbox(label="System prompt", lines=4),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
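
# Assemble the page: description, duplicate button, chat interface, license.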
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", share=False)