import gradio as gr
import spaces
import argparse
import torch
from threading import Thread
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=9527)
args = parser.parse_args()
args.low_resource = True

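# Example invocation (assumed script name): python app.py --port 7860
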
title = """<h1 style="text-align: center;">Product description generator</h1>""" |
|
css = """ |
|
div#col-container { |
|
margin: 0 auto; |
|
max-width: 840px; |
|
} |
|
""" |
|
|
|
# Pick the device once; the model weights and the prepared inputs are moved to it below.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # 151645 is the <|im_end|> token of the Qwen chat template; stop as soon as it appears.
        stop_ids = [151645]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


# `spaces.GPU` requests a GPU slot when the demo runs on Hugging Face ZeroGPU Spaces;
# it is assumed to be the reason the `spaces` package is imported above.
@spaces.GPU
@torch.no_grad()
def response(message, history, image):
    stop = StopOnTokens()

messages = [{"role": "system", "content": "You are a helpful assistant."}] |
|
|
|
for user_msg, assistant_msg in history: |
|
messages.append({"role": "user", "content": user_msg}) |
|
messages.append({"role": "assistant", "content": assistant_msg}) |
|
|
|
if len(messages) == 1: |
|
message = f" <image>{message}" |
|
|
|
messages.append({"role": "user", "content": message}) |
|
|
|
    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )

    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )

    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )

    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }

    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    # Run generation on a background thread so tokens can be read from the streamer as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)

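
# A minimal, commented-out sketch (not part of the original app) of driving `response`
# directly, assuming a local image file named "product.jpg":
#
#   from PIL import Image
#   img = Image.open("product.jpg")
#   for history, *_ in response("Write a product description.", [], img):
#       pass
#   print(history[-1][1])

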
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        image = gr.Image(type="pil")
        gr.Button(value="Upload")
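        # --- Assumed wiring (not in the original file): a minimal sketch of how these
        # components could drive `response`, which expects (message, history, image)
        # and yields (history, send-button update, stop-button update). The names
        # `chatbot`, `prompt`, `send_btn`, and `stop_btn` are invented here; `image`
        # refers to the gr.Image created above.
        chatbot = gr.Chatbot()
        prompt = gr.Textbox(placeholder="Ask for a product description...")
        with gr.Row():
            send_btn = gr.Button(value="Send")
            stop_btn = gr.Button(value="Stop", visible=False)

        send_event = send_btn.click(
            response,
            inputs=[prompt, chatbot, image],
            outputs=[chatbot, send_btn, stop_btn],
        )
        # Let the Stop button cancel the streaming generation.
        stop_btn.click(None, None, None, cancels=[send_event])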

# queue() enables streamed (generator) outputs; the parsed --port selects the server port.
demo.queue().launch(server_port=args.port)