toshi456 commited on
Commit
744af63
·
verified ·
1 Parent(s): 62a1e5c

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +5 -0
  2. app.py +182 -0
  3. imgs/sample1.jpg +3 -0
  4. imgs/sample2.jpg +3 -0
  5. imgs/sample3.jpg +3 -0
  6. imgs/sample4.jpg +3 -0
  7. imgs/sample5.jpg +3 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ imgs/sample1.jpg filter=lfs diff=lfs merge=lfs -text
37
+ imgs/sample2.jpg filter=lfs diff=lfs merge=lfs -text
38
+ imgs/sample3.jpg filter=lfs diff=lfs merge=lfs -text
39
+ imgs/sample4.jpg filter=lfs diff=lfs merge=lfs -text
40
+ imgs/sample5.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import transformers
4
+
5
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
6
+ from llava.conversation import conv_templates
7
+ from llava.model.llava_gpt2 import LlavaGpt2ForCausalLM
8
+ from llava.train.arguments_dataclass import ModelArguments, DataArguments, TrainingArguments
9
+ from llava.train.dataset import tokenizer_image_token
10
+
11
+
12
+ # load model
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
15
+ model_path = 'toshi456/llava-jp-1.3b-v1.1'
16
+
17
+ model = LlavaGpt2ForCausalLM.from_pretrained(
18
+ model_path,
19
+ low_cpu_mem_usage=True,
20
+ use_safetensors=True,
21
+ torch_dtype=torch_dtype,
22
+ device_map=device,
23
+ )
24
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
25
+ model_path,
26
+ model_max_length=1024,
27
+ padding_side="right",
28
+ use_fast=False,
29
+ )
30
+ model.eval()
31
+ conv_mode = "v1"
32
+
33
+
34
+ @torch.inference_mode()
35
+ def inference_fn(
36
+ image,
37
+ prompt,
38
+ max_len,
39
+ temperature,
40
+ top_p,
41
+ ):
42
+ # prepare inputs
43
+ # image pre-process
44
+ image_size = model.get_model().vision_tower.image_processor.size["height"]
45
+ if model.get_model().vision_tower.scales is not None:
46
+ image_size = model.get_model().vision_tower.image_processor.size["height"] * len(model.get_model().vision_tower.scales)
47
+
48
+ if device == "cuda":
49
+ image_tensor = model.get_model().vision_tower.image_processor(
50
+ image,
51
+ return_tensors='pt',
52
+ size={"height": image_size, "width": image_size}
53
+ )['pixel_values'].half().cuda().to(torch_dtype)
54
+ else:
55
+ image_tensor = model.get_model().vision_tower.image_processor(
56
+ image,
57
+ return_tensors='pt',
58
+ size={"height": image_size, "width": image_size}
59
+ )['pixel_values'].to(torch_dtype)
60
+
61
+ # create prompt
62
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
63
+ conv = conv_templates[conv_mode].copy()
64
+ conv.append_message(conv.roles[0], inp)
65
+ conv.append_message(conv.roles[1], None)
66
+ prompt = conv.get_prompt()
67
+
68
+ input_ids = tokenizer_image_token(
69
+ prompt,
70
+ tokenizer,
71
+ IMAGE_TOKEN_INDEX,
72
+ return_tensors='pt'
73
+ ).unsqueeze(0)
74
+ if device == "cuda":
75
+ input_ids = input_ids.to(device)
76
+
77
+ input_ids = input_ids[:, :-1] # </sep>がinputの最後に入るので削除する
78
+
79
+ # generate
80
+ output_ids = model.generate(
81
+ inputs=input_ids,
82
+ images=image_tensor,
83
+ do_sample= temperature != 0.0,
84
+ temperature=temperature,
85
+ top_p=top_p,
86
+ max_new_tokens=max_len,
87
+ use_cache=True,
88
+ )
89
+ output_ids = [token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX]
90
+ output = tokenizer.decode(output_ids, skip_special_tokens=True)
91
+
92
+ target = "システム: "
93
+ idx = output.find(target)
94
+ output = output[idx+len(target):]
95
+
96
+ return output
97
+
98
+ with gr.Blocks() as demo:
99
+ gr.Markdown(f"# LLaVA-JP Demo")
100
+
101
+ with gr.Row():
102
+ with gr.Column():
103
+ # input_instruction = gr.TextArea(label="instruction", value=DEFAULT_INSTRUCTION)
104
+ input_image = gr.Image(type="pil", label="image")
105
+ prompt = gr.Textbox(label="prompt (optional)", value="")
106
+ with gr.Accordion(label="Configs", open=False):
107
+ max_len = gr.Slider(
108
+ minimum=10,
109
+ maximum=256,
110
+ value=128,
111
+ step=5,
112
+ interactive=True,
113
+ label="Max New Tokens",
114
+ )
115
+
116
+ temperature = gr.Slider(
117
+ minimum=0.0,
118
+ maximum=1.0,
119
+ value=0.1,
120
+ step=0.1,
121
+ interactive=True,
122
+ label="Temperature",
123
+ )
124
+
125
+ top_p = gr.Slider(
126
+ minimum=0.5,
127
+ maximum=1.0,
128
+ value=0.9,
129
+ step=0.1,
130
+ interactive=True,
131
+ label="Top p",
132
+ )
133
+
134
+ # button
135
+ input_button = gr.Button(value="Submit")
136
+ with gr.Column():
137
+ output = gr.Textbox(label="Output")
138
+
139
+ inputs = [input_image, prompt, max_len, temperature, top_p]
140
+ input_button.click(inference_fn, inputs=inputs, outputs=[output])
141
+ prompt.submit(inference_fn, inputs=inputs, outputs=[output])
142
+ img2txt_examples = gr.Examples(examples=[
143
+ [
144
+ "./imgs/sample1.jpg",
145
+ "猫は何をしていますか?",
146
+ 32,
147
+ 0.0,
148
+ 0.9,
149
+ ],
150
+ [
151
+ "./imgs/sample2.jpg",
152
+ "この自動販売機にはどのブランドの飲料が含まれていますか?",
153
+ 256,
154
+ 0.0,
155
+ 0.9,
156
+ ],
157
+ [
158
+ "./imgs/sample3.jpg",
159
+ "この料理の作り方を教えてください。",
160
+ 256,
161
+ 0.0,
162
+ 0.9,
163
+ ],
164
+ [
165
+ "./imgs/sample4.jpg",
166
+ "このコンピュータの名前を教えてください。",
167
+ 256,
168
+ 0.0,
169
+ 0.9,
170
+ ],
171
+ [
172
+ "./imgs/sample5.jpg",
173
+ "これらを使って作ることができる料理を教えてください。",
174
+ 256,
175
+ 0.0,
176
+ 0.9,
177
+ ],
178
+ ], inputs=inputs)
179
+
180
+
181
+ if __name__ == "__main__":
182
+ demo.queue().launch(share=True)
imgs/sample1.jpg ADDED

Git LFS Details

  • SHA256: 66b0df6b0906b24089bf24b9ad0a82efb75d410d9f59a96df8e1a5f7f4a36fdd
  • Pointer size: 132 Bytes
  • Size of remote file: 2.6 MB
imgs/sample2.jpg ADDED

Git LFS Details

  • SHA256: 2318c8fb84a2fd6e077d3b02a8898c87220182ca35d755c74264d6807ddc6563
  • Pointer size: 132 Bytes
  • Size of remote file: 2.55 MB
imgs/sample3.jpg ADDED

Git LFS Details

  • SHA256: 289defa9edf31eddfcac6a0d396326acf213da2445a059e11e3f8c74f45388a6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.43 MB
imgs/sample4.jpg ADDED

Git LFS Details

  • SHA256: 77dde212742b8b6b7f4898759db189b896fdc19a0abf01717caecdb4927b2672
  • Pointer size: 132 Bytes
  • Size of remote file: 1.63 MB
imgs/sample5.jpg ADDED

Git LFS Details

  • SHA256: ca02d3b4e57dadc0931c7b84768fee14d3273410e6b5ced80835588c46e0855f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB