File size: 5,945 Bytes
a0bd6fb
4f199bf
 
 
 
a0bd6fb
4f199bf
3858798
 
 
 
4f199bf
a0bd6fb
4f199bf
 
c46a8ee
 
 
4f199bf
 
 
e25f9d4
c46a8ee
 
 
4f199bf
 
 
e25f9d4
4f199bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e25f9d4
3858798
4f199bf
 
 
 
e25f9d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f199bf
 
 
 
 
 
 
 
 
 
 
 
 
e25f9d4
4f199bf
 
 
e25f9d4
4f199bf
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, StoppingCriteria
import spaces
import torch
from PIL import Image

# Hub repo ids of every xGen-MM checkpoint offered by the demo. The order here
# fixes both the load order below and the order of the model dropdown choices.
_CHECKPOINT_IDS = (
    "Salesforce/xgen-mm-phi3-mini-instruct-r-v1",
    "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5",
    "Salesforce/xgen-mm-phi3-mini-instruct-singleimg-r-v1.5",
    "Salesforce/xgen-mm-phi3-mini-instruct-dpo-r-v1.5",
)

# All checkpoints are loaded eagerly at import time, one dict per component,
# each keyed by repo id. trust_remote_code is required because these repos
# ship their own modeling/processing code.
models = {
    repo_id: AutoModelForVision2Seq.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in _CHECKPOINT_IDS
}

processors = {
    repo_id: AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in _CHECKPOINT_IDS
}

tokenizers = {
    repo_id: AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True, use_fast=False, legacy=False)
    for repo_id in _CHECKPOINT_IDS
}


# Markdown header rendered at the top of the demo page.
DESCRIPTION = "# [xGen-MM Demo](https://huggingface.co/collections/Salesforce/xgen-mm-1-models-662971d6cecbf3a7f80ecc2e)"


def apply_prompt_template(prompt):
    """Wrap a user question in the chat template fed to the xGen-MM tokenizer.

    The result is a fixed system turn, then a user turn made of the
    ``<image>`` placeholder followed by ``prompt``, and finally the opening
    ``<|assistant|>`` tag so the model's generation continues from there.

    Args:
        prompt: The user's question; it is interpolated with an f-string.

    Returns:
        The complete prompt string.
    """
    system_turn = (
        '<|system|>\nA chat between a curious user and an artificial intelligence assistant. '
        "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    )
    user_turn = f'<|user|>\n<image>\n{prompt}<|end|>\n'
    assistant_prefix = '<|assistant|>\n'
    return system_turn + user_turn + assistant_prefix


class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once generated ids end with a given token-id sequence.

    Args:
        eos_sequence: Token ids that mark the end of an answer. Defaults to
            ``[32007]`` — presumably the id of the ``<|end|>`` token for these
            checkpoints' tokenizer (TODO confirm). Any iterable of ints is
            accepted; it is copied into a fresh list.

    Raises:
        ValueError: If ``eos_sequence`` is empty (an empty suffix would make
            ``input_ids[:, -0:]`` select every column and never match).
    """

    def __init__(self, eos_sequence=None):
        # A None sentinel replaces the former mutable ``[32007]`` default, so no
        # two instances can share (and accidentally mutate) one list object.
        if eos_sequence is None:
            eos_sequence = [32007]
        # Stored as a list because __call__ compares it against the lists
        # produced by Tensor.tolist(); a tuple would never compare equal.
        self.eos_sequence = list(eos_sequence)
        if not self.eos_sequence:
            raise ValueError("eos_sequence must contain at least one token id")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any row of ``input_ids`` ends with ``eos_sequence``.

        ``input_ids`` is indexed as ``[:, -k:]``, i.e. treated as 2-D with one
        row per sequence. A single bool is returned for the whole batch.
        """
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids


@spaces.GPU
def run_example(image, text_input=None, model_id="Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"):
    """Answer ``text_input`` about ``image`` using the selected xGen-MM checkpoint.

    Args:
        image: The uploaded picture. It is passed to ``PIL.Image.fromarray``,
            so it is expected to be an array (presumably the NumPy array that
            ``gr.Image`` produces — confirm if the component type changes).
        text_input: The user's question. ``None`` is treated as an empty
            question instead of being rendered as the literal text "None".
        model_id: Key into the module-level ``models`` / ``processors`` /
            ``tokenizers`` dicts.

    Returns:
        The decoded answer, truncated at the first ``<|end|>`` marker.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
        KeyError: If ``model_id`` is not one of the loaded checkpoints.
    """
    if image is None:
        # Image.fromarray(None) would otherwise fail with an opaque PIL error.
        raise gr.Error("Please upload an image.")
    if text_input is None:
        text_input = ""

    model = models[model_id].to("cuda").eval()
    processor = processors[model_id]
    tokenizer = model.update_special_tokens(tokenizers[model_id])

    # Work shared by both checkpoint families.
    img = Image.fromarray(image).convert("RGB")
    prompt = apply_prompt_template(text_input)
    language_inputs = tokenizer([prompt], return_tensors="pt")

    if model_id == "Salesforce/xgen-mm-phi3-mini-instruct-r-v1":
        # v1 checkpoint: the processor returns tensors; every input tensor is
        # moved to the GPU and generation is cut by a stop-sequence criterion.
        inputs = processor([img], return_tensors="pt", image_aspect_ratio='anyres')
        inputs.update(language_inputs)
        inputs = {name: tensor.cuda() for name, tensor in inputs.items()}

        generated_text = model.generate(**inputs, image_size=[img.size],
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False, max_new_tokens=768, top_p=None, num_beams=1,
            stopping_criteria=[EosListStoppingCriteria()],
        )
    else:
        # v1.5 checkpoints: pixel_values is a list of lists of per-image tensors
        # (presumably one inner list per prompt in the batch — TODO confirm),
        # with image_size nested the same way.
        image_list = [processor([img], image_aspect_ratio='anyres')["pixel_values"].cuda()]
        image_sizes = [img.size]

        inputs = {"pixel_values": [image_list]}
        inputs.update(language_inputs)
        # Only real tensors move to the GPU; the nested pixel_values list is
        # already on CUDA and is left as-is.
        for name, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[name] = value.cuda()

        generated_text = model.generate(**inputs, image_size=[image_sizes],
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1,
        )

    prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True).split("<|end|>")[0]
    return prediction
# Styles the answer box: fixed height, scrolling for long answers, and a border.
# The rule targets elem_id "output", which the output Textbox below must carry.
css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="xGen-MM Input"):
        with gr.Row():
            with gr.Column():
                # type="numpy" (Gradio's default, made explicit) yields the
                # array that run_example hands to PIL.Image.fromarray.
                input_img = gr.Image(label="Input Picture", type="numpy")
                model_selector = gr.Dropdown(choices=list(models), label="Model", value="Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5")
                text_input = gr.Textbox(label="Question")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id="output" connects this box to the #output CSS rule;
                # without it the stylesheet above matched nothing.
                output_text = gr.Textbox(label="Output Text", elem_id="output")

        submit_btn.click(run_example, [input_img, text_input, model_selector], [output_text])

demo.launch(debug=True)