File size: 3,307 Bytes
563f98d
bc3802f
 
7d58261
 
 
563f98d
 
bc3802f
 
 
 
 
 
 
 
563f98d
bc3802f
 
 
 
 
 
 
 
 
 
 
7d58261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc3802f
 
 
 
14901fa
bc3802f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList


# Command-line options for the demo script.
parser = argparse.ArgumentParser()
# type=int keeps a value given on the command line consistent with the
# integer default (without it, "--port 8080" would be stored as a string).
parser.add_argument("--port", type=int, default=9527, help="port number (integer)")

args = parser.parse_args()
# Forced on unconditionally; there is no matching command-line flag.
# NOTE(review): nothing in the visible part of this file reads
# `args.low_resource` or `args.port` -- confirm whether they are still needed.
args.low_resource = True

# HTML heading rendered at the top of the page.
title = '<h1 style="text-align: center;">Product description generator</h1>'

# Stylesheet passed to gr.Blocks: centres the element whose id is
# "col-container" and caps its width.
css = (
    "\n"
    "div#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 840px;\n"
    "}\n"
)

# Compute device, resolved once at import time: the GPU when torch reports one,
# otherwise the CPU. The model is placed on it below, and any code that reads
# the module-level `device` name sees the same choice.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Both loaders are called with trust_remote_code=True, which lets the
# checkpoint repository supply its own Python model/processor code.
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation on a fixed set of token ids."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Report whether the most recent token is one of the stop ids.

        Only the first sequence in ``input_ids`` is inspected; ``scores`` and
        any extra keyword arguments are ignored.

        Returns:
            True if the last token of ``input_ids[0]`` equals a stop id,
            False otherwise.
        """
        # 151645 is presumably the tokenizer's end-of-turn marker -- TODO confirm.
        terminators = (151645,)
        newest_token = input_ids[0][-1]
        return any(newest_token == terminator for terminator in terminators)

@torch.no_grad()
def response(message, history, image):
    """Stream the model's reply to ``message`` about ``image``.

    This is a generator. It builds a chat prompt from ``history`` plus the new
    ``message``, runs ``model.generate`` on a background thread, and yields
    after every text chunk produced by the streamer so a UI can refresh
    incrementally.

    Args:
        message: Text of the new user turn. On the first turn (empty
            ``history``) it is prefixed with `` <image>`` before being sent
            to the model and stored in ``history``.
        history: Previous turns as ``[user_text, assistant_text]`` pairs.
            Mutated in place: the new turn is appended and its reply is
            filled in progressively.
        image: Image passed to ``processor.feature_extractor`` (presumably a
            PIL image -- confirm against the UI wiring).

    Yields:
        A 3-tuple ``(history, gr.Button(visible=False),
        gr.Button(visible=True, interactive=True))`` after each chunk.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Fail with a readable UI error instead of crashing inside the
    # feature extractor when no image was supplied.
    if image is None:
        raise gr.Error("Please upload an image first.")

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # Only the first user turn carries the image placeholder.
    if not history:
        message = f" <image>{message}"
    messages.append({"role": "user", "content": message})

    input_ids = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    image_features = processor.feature_extractor(image).unsqueeze(0)

    # One mask entry per text token plus the image latents; the "- 1"
    # presumably accounts for the placeholder token the latents replace
    # -- TODO confirm against the model code.
    attention_mask = torch.ones(
        1, input_ids.shape[1] + processor.num_image_latents - 1
    )

    # Move every input to wherever the model's weights live, so this function
    # does not depend on a separately maintained device variable.
    target_device = model.device
    model_inputs = {
        "input_ids": input_ids.to(target_device),
        "images": image_features.to(target_device),
        "attention_mask": attention_mask.to(target_device),
    }

    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    def _generate():
        # Grad mode is thread-local in PyTorch, so the decorator on
        # `response` does not reach this worker thread; disable it here too.
        with torch.no_grad():
            model.generate(**generate_kwargs)

    # Generation runs in the background while this generator drains the
    # streamer and pushes partial text to the caller.
    worker = Thread(target=_generate)
    worker.start()

    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)

    # The streamer is exhausted once generation has finished; reap the thread.
    worker.join()

# Page layout: a single centred column (styled through the "col-container"
# id in `css`) holding the heading, an image input and an upload button.
with gr.Blocks(css=css) as demo, gr.Column(elem_id="col-container"):
    gr.HTML(title)
    gr.Image(type="pil")
    gr.Button(value="Upload")
    # NOTE(review): no event handler is attached to these components in the
    # visible code, so `response` is never invoked from the UI -- confirm
    # whether the wiring is still to be added.


# Start the Gradio server with its default settings.
demo.launch()