HarryLee committed
Commit fa57f25 · 1 Parent(s): 8135811

initial commit

Files changed (1)
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ import os
+
+ os.system('cd fairseq;'
+           'pip install --use-feature=in-tree-build ./; cd ..')
+ os.system('ls -l')
+
+ import torch
+ import numpy as np
+ from fairseq import utils, tasks
+ from fairseq import checkpoint_utils
+ from utils.eval_utils import eval_step
+ from tasks.mm_tasks.caption import CaptionTask
+ from models.ofa import OFAModel
+ from PIL import Image
+ from torchvision import transforms
+ import gradio as gr
+
+ # Register caption task
+ tasks.register_task('caption', CaptionTask)
+ # turn on cuda if GPU is available
+ use_cuda = torch.cuda.is_available()
+ # use fp16 only when GPU is available
+ use_fp16 = False
+
+ os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/caption_large_best_clean.pt; '
+           'mkdir -p checkpoints; mv caption_large_best_clean.pt checkpoints/caption.pt')
+
+ # Load pretrained ckpt & config
+ overrides = {"bpe_dir": "utils/BPE", "eval_cider": False, "beam": 5,
+              "max_len_b": 16, "no_repeat_ngram_size": 3, "seed": 7}
+ models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+     utils.split_paths('checkpoints/caption.pt'),
+     arg_overrides=overrides
+ )
+
+ # Move models to GPU
+ for model in models:
+     model.eval()
+     if use_fp16:
+         model.half()
+     if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
+         model.cuda()
+     model.prepare_for_inference_(cfg)
+
+ # Initialize generator
+ generator = task.build_generator(models, cfg.generation)
+
+ mean = [0.5, 0.5, 0.5]
+ std = [0.5, 0.5, 0.5]
+
+ patch_resize_transform = transforms.Compose([
+     lambda image: image.convert("RGB"),
+     transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=mean, std=std),
+ ])
+
+ # Text preprocess
+ bos_item = torch.LongTensor([task.src_dict.bos()])
+ eos_item = torch.LongTensor([task.src_dict.eos()])
+ pad_idx = task.src_dict.pad()
+
+
+ def encode_text(text, length=None, append_bos=False, append_eos=False):
+     s = task.tgt_dict.encode_line(
+         line=task.bpe.encode(text),
+         add_if_not_exist=False,
+         append_eos=False
+     ).long()
+     if length is not None:
+         s = s[:length]
+     if append_bos:
+         s = torch.cat([bos_item, s])
+     if append_eos:
+         s = torch.cat([s, eos_item])
+     return s
+
+
+ # Construct input for caption task
+ def construct_sample(image: Image):
+     patch_image = patch_resize_transform(image).unsqueeze(0)
+     patch_mask = torch.tensor([True])
+     src_text = encode_text(" what does the image describe?", append_bos=True, append_eos=True).unsqueeze(0)
+     src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
+     sample = {
+         "id": np.array(['42']),
+         "net_input": {
+             "src_tokens": src_text,
+             "src_lengths": src_length,
+             "patch_images": patch_image,
+             "patch_masks": patch_mask
+         }
+     }
+     return sample
+
+
+ # Function to turn FP32 to FP16
+ def apply_half(t):
+     if t.dtype is torch.float32:
+         return t.to(dtype=torch.half)
+     return t
+
+
+ # Function for image captioning
+ def image_caption(Image):
+     sample = construct_sample(Image)
+     sample = utils.move_to_cuda(sample) if use_cuda else sample
+     sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
+     with torch.no_grad():
+         result, scores = eval_step(task, generator, models, sample)
+     return result[0]['caption']
+
+
+ title = "eRupt e-Commerce Image Captioning"
+ description = "Online Demo for e-Commerce Image Captioning. Upload your own image or click any one of the examples, and click " \
+               "\"Submit\" and then wait for the generated caption. "
+ #article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
+ #          "Repo</a></p> "
+ examples = [['0001.jpg'], ['0002.jpg'], ['0003.jpg'], ['0004.jpg'], ['0005.jpg']]
+ io = gr.Interface(fn=image_caption, inputs=gr.inputs.Image(type='pil'), outputs=gr.outputs.Textbox(label="Caption"),
+                   title=title, description=description, examples=examples,
+                   allow_flagging=False, allow_screenshot=False)
+ io.launch(cache_examples=True)
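A minimal local smoke test of the captioning function, sketched here rather than part of the commit: it assumes the module-level setup in app.py above (fairseq install, checkpoint download, model loading) has already run in the same process, and that the bundled example image 0001.jpg from the examples list is in the working directory.

from PIL import Image

# Hypothetical quick check of the pipeline without the Gradio UI
# (assumes app.py's setup code has already executed in this process).
img = Image.open('0001.jpg')
print(image_caption(img))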