import argparse
import os
from queue import SimpleQueue
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from gradio import Chatbot
from huggingface_hub import InferenceClient

from image_utils import ImageStitcher
from StreamDiffusionIO import LatentConsistencyModelStreamIO

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\ | |
# Kanji-Streaming Chat | |
π This Space is adapted from [Llama-2-7b-chat](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat) space, demonstrating how to "chat" with LLM with [Kanji-Streaming](https://github.com/AgainstEntropy/kanji). | |
π¨ The technique behind Kanji-Streaming is [StreamDiffusionIO](https://github.com/AgainstEntropy/StreamDiffusionIO), which is based on [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion), *but especially allows to render text streams into image streams*. | |
π For more details about Kanji-Streaming, take a look at the [github repository](https://github.com/AgainstEntropy/kanji). | |
""" | |
LICENSE = """ | |
<p/> | |
--- | |
As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, | |
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md). | |
""" | |
parser = argparse.ArgumentParser(description="Gradio launcher for Kanji-Streaming.")
parser.add_argument(
    "--sd_model_id_or_path",
    type=str,
    default="runwayml/stable-diffusion-v1-5",
    required=False,
    help="Path to downloaded sd-1-5 model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default="AgainstEntropy/kanji-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LoRA weight or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lcm_lora_path",
    type=str,
    default="AgainstEntropy/kanji-lcm-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LCM-LoRA weight or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--img_res",
    type=int,
    default=64,
    required=False,
    help="Image resolution for displaying Kanji characters in the ChatBot.",
)
parser.add_argument(
    "--img_per_line",
    type=int,
    default=16,
    required=False,
    help="Number of Kanji characters to display in a single line.",
)
parser.add_argument(
    "--tmp_dir",
    type=str,
    default="./tmp",
    required=False,
    help="Path to save temporary images generated by StreamDiffusionIO.",
)
args = parser.parse_args()

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo works best on GPU.</p>"

client = InferenceClient(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
)

def format_prompt(message, history, system_prompt=''):
    # Build a Mixtral-Instruct-style prompt from the chat history.
    prompt = f"<s> {system_prompt}"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        if isinstance(bot_response, tuple):
            # Image replies are stored in the history as tuples; keep the text component.
            bot_response = bot_response[1]
        if not bot_response.endswith("</s>"):
            bot_response += "</s>"
        prompt += f" {bot_response} "
    prompt += f"[INST] {message} [/INST]"
    return prompt
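
# Example (empty system prompt):
#   format_prompt("How are you?", [("Hello", "Hi!")])
#   -> "<s> [INST] Hello [/INST] Hi!</s> [INST] How are you? [/INST]"
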
lcm_stream = LatentConsistencyModelStreamIO(
    model_id_or_path=args.sd_model_id_or_path,
    lcm_lora_path=args.lcm_lora_path,
    lora_dict={args.lora_path: 1},
    resolution=128,
    device=device,
    use_xformers=True,
    verbose=True,
)

tmp_dir_template = f"{args.tmp_dir}/%d"
response_num = 0

stitcher = ImageStitcher(
    tmp_dir=tmp_dir_template % response_num,
    img_res=args.img_res,
    img_per_line=args.img_per_line,
    verbose=True,
)
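
# Streaming pipeline, as wired below: a background thread feeds token-level
# prompts from the LLM stream into a queue; the main loop pops them, asks
# lcm_stream to render each one as a Kanji glyph, and stitcher concatenates
# the glyphs into line images whose file paths are yielded to the Chatbot.
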
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    show_original_response: bool,
    seed: int,
    system_prompt: str = '',
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    if temperature < 1e-2:
        temperature = 1e-2

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )

    formatted_prompt = format_prompt(message, chat_history, system_prompt)
    print(formatted_prompt)
    streamer = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )

    outputs = []
    prompt_queue = SimpleQueue()

    lcm_stream.reset(seed)
    stitcher.reset()

    global response_num
    response_num += 1
    stitcher.update_tmp_dir(tmp_dir_template % response_num)

    def append_to_queue():
        # Producer: push each generated token (minus stop tokens) into the queue.
        for response in streamer:
            text = response.token.text
            outputs.append(text)
            prompt = text.strip()
            if prompt and prompt not in ['</s>']:
                if prompt.endswith("."):
                    prompt = prompt[:-1]
                prompt_queue.put(prompt)
        prompt_queue.put(None)  # sentinel: generation finished

    append_thread = Thread(target=append_to_queue)
    append_thread.start()

    def show_image(prompt: str = None):
        # Run one step of the diffusion stream and stitch the result, if any.
        image, text = lcm_stream(prompt)
        img_path = None
        if image is not None:
            img_path = stitcher.add(image, text)
        return img_path

    # Consumer: turn queued tokens into Kanji images and stream them to the Chatbot.
    while True:
        prompt = prompt_queue.get()
        if prompt is None:
            break
        img_path = show_image(prompt)
        if img_path is not None:
            yield (img_path, )

    # Continue to display the remaining images still in the diffusion pipeline.
    while True:
        img_path = show_image()
        if img_path is not None:
            yield (img_path, ''.join(outputs))
        if lcm_stream.stop():
            break

    print(outputs)
    if show_original_response:
        yield ''.join(outputs)
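
# Note: the order of `additional_inputs` below must match the extra parameters of
# `generate` after (message, chat_history): show_original_response, seed,
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
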
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=Chatbot(height=400),
    additional_inputs=[
        gr.Checkbox(
            label="Show original response",
            value=False,
        ),
        gr.Number(
            label="Seed",
            info="Random seed for Kanji generation (different seeds give the glyphs a slightly different accent).",
            step=1,
            value=1026,
        ),
        gr.Textbox(
            label="System prompt",
            value="",
            lines=4,
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

with gr.Blocks(css="style.css") as demo: | |
gr.Markdown(DESCRIPTION) | |
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button") | |
chat_interface.render() | |
gr.Markdown(LICENSE) | |
if __name__ == "__main__": | |
demo.queue(max_size=20).launch(server_name="0.0.0.0", share=False, show_api=False) | |