|
import os
import subprocess
import sys

# Directory of the vendored TimeSformer checkout, resolved against the current
# working directory (the same place the shell command `cd TimeSformer` targeted).
TIMESFORMER_DIR = os.path.abspath('TimeSformer')

# Install TimeSformer with the pip belonging to *this* interpreter, not whatever
# `pip` happens to be first on PATH. `cwd=` replaces `cd TimeSformer; ...; cd ..`
# (every os.system call runs in its own shell, so a trailing `cd ..` is a no-op).
# The install stays best-effort: failures are reported, not raised, because the
# sys.path entry added below still exposes the sources.
if os.path.isdir(TIMESFORMER_DIR):
    _pip_result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '.'],
        cwd=TIMESFORMER_DIR,
        check=False,
    )
    if _pip_result.returncode != 0:
        print(f'WARNING: pip install in {TIMESFORMER_DIR} exited with code '
              f'{_pip_result.returncode}', flush=True)
else:
    print(f'WARNING: {TIMESFORMER_DIR} not found; skipping TimeSformer install',
          flush=True)

# Diagnostic output: directory listing and current working directory.
os.system('ls -l')
print(os.getcwd(), flush=True)

# Make the TimeSformer sources importable even if the install above failed.
sys.path.append(TIMESFORMER_DIR)

# NOTE(review): not referenced elsewhere in this file; presumably imported so a
# missing/broken TimeSformer fails fast at startup — confirm.
import timesformer  # noqa: E402,F401
|
|
|
|
|
import torch |
|
from torchvision import transforms |
|
|
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from PIL import Image |
|
import json |
|
import os |
|
|
|
from torchvision import transforms |
|
|
|
from models.epalm import ePALM |
|
|
|
import os |
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from ruamel.yaml import YAML |
|
|
|
import torch |
|
import gradio as gr |
|
|
|
import torchaudio |
|
|
|
# ruamel.yaml loader in 'safe' mode; used below to parse the model config file.
yaml = YAML(typ='safe')

# Prefer the GPU when PyTorch can see one, else run on the CPU.
# `device_type` is the plain string form that torch.autocast expects.
use_cuda = torch.cuda.is_available()
device_type = 'cuda' if use_cuda else 'cpu'
device = torch.device(device_type)
|
|
|
|
|
|
|
|
|
# Build the single ePALM model shared by every task and initialise it with the
# image-captioning checkpoint.
config_path = 'configs/image/ePALM_caption.yaml'
# Use a context manager so the config file is closed after parsing.
with open(config_path, 'r') as config_file:
    config = yaml.load(config_file)

text_model = 'facebook/opt-2.7b'
vision_model_name = 'vit_base_patch16_224'

start_layer_idx = 19
end_layer_idx = 31
low_cpu = True

model_caption = ePALM(
    opt_model_name=text_model,
    vision_model_name=vision_model_name,
    use_vis_prefix=True,
    start_layer_idx=start_layer_idx,
    end_layer_idx=end_layer_idx,
    return_hidden_state_vision=True,
    config=config,
    low_cpu=low_cpu,
)
print("Model Built")
model_caption.to(device)

checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint['model']
# strict=False tolerates keys missing from / extra to the checkpoint; `msg`
# holds the missing/unexpected-key report returned by load_state_dict.
msg = model_caption.load_state_dict(state_dict, strict=False)

model_caption.bfloat16()
# Nothing else in this script switches the model to inference mode; eval()
# turns off dropout and other train-only behaviour before generation.
model_caption.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_checkpoint_weights(path):
    """Return the ``'model'`` state dict stored in the checkpoint at ``path``.

    Tensors are mapped to the CPU. Only the ``'model'`` entry is returned, so the
    rest of the checkpoint dictionary is not kept alive by any module global.
    """
    return torch.load(path, map_location='cpu')['model']


# Per-task weights that inference() loads into the shared model on demand.
state_dict_vqa = _load_checkpoint_weights(
    'checkpoints/float32/ePALM_vqa/checkpoint_best.pth')
state_dict_video_caption = _load_checkpoint_weights(
    'checkpoints/float32/ePALM_video_caption_msrvtt/checkpoint_best.pth')
state_dict_video_qa = _load_checkpoint_weights(
    'checkpoints/float32/ePALM_video_qa_msrvtt/checkpoint_best.pth')
state_dict_audio_caption = _load_checkpoint_weights(
    'checkpoints/float32/ePALM_audio_caption/checkpoint_best.pth')
|
|
|
|
|
|
|
|
|
|
|
# Slow (use_fast=False) tokenizer matching the language model in `text_model`.
tokenizer = AutoTokenizer.from_pretrained(text_model, use_fast=False)
pad_token = tokenizer.pad_token
eos_token = tokenizer.eos_token

# Extra special token appended to a question prompt; generated text after it is
# taken as the answer when decoding.
special_answer_token = '</a>'
special_tokens_dict = dict(additional_special_tokens=[special_answer_token])
tokenizer.add_special_tokens(special_tokens_dict)
|
|
|
|
|
image_size = 224
# Per-channel mean/std normalisation, applied to float pixels in [0, 1].
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))

# torchvision documents InterpolationMode as the `interpolation` argument;
# passing PIL's integer Image.BICUBIC is a deprecated legacy form. One resize
# instance is shared by both pipelines (it holds no per-call state).
_bicubic_resize = transforms.Resize(
    (image_size, image_size),
    interpolation=transforms.InterpolationMode.BICUBIC,
)

# Still images: PIL image -> square resize -> [0, 1] float tensor -> normalise.
transform = transforms.Compose([
    _bicubic_resize,
    transforms.ToTensor(),
    normalize,
])

# Convert integer pixel values (0..255) to floats in [0, 1].
type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

# Video frames (already a tensor): square resize -> scale to [0, 1] -> normalise.
# Assumes 0..255 pixel values from the video reader — see type_transform.
test_transform = transforms.Compose([
    _bicubic_resize,
    type_transform,
    normalize,
])
|
from dataset.video_utils import VIDEO_READER_FUNCS

# decord-backed frame reader, taken from the project's reader registry.
video_reader = VIDEO_READER_FUNCS['decord']


def read_video(path, num_frames=16, sample='rand'):
    """Read ``num_frames`` frames from the video at ``path`` and preprocess them.

    Args:
        path: path to a video file.
        num_frames: number of frames requested from the reader.
        sample: frame-sampling mode forwarded to the reader. The default
            ``'rand'`` keeps the original behaviour. Judging by its name it picks
            frame positions randomly, so repeated calls may differ.
            TODO confirm which deterministic modes the reader supports.

    Returns:
        The frames after ``test_transform`` (resize, scale to [0, 1], normalise).
    """
    # The reader also returns frame indices and the video duration; unused here.
    frames, _frame_indices, _duration = video_reader(
        path, num_frames, sample, max_num_frames=-1
    )
    return test_transform(frames)
|
|
|
def read_audio(path, target_length=1024, melbins=128):
    """Load an audio file and convert it to a normalised log-mel filterbank.

    Args:
        path: path to an audio file readable by ``torchaudio.load``.
        target_length: number of time frames in the result. Shorter inputs are
            zero-padded at the end and longer ones are truncated.
        melbins: number of mel bins passed to ``kaldi.fbank``.

    Returns:
        Tensor of shape ``(target_length, melbins)``.
    """
    # Fixed normalisation constants (kept from the original code).
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform, sr = torchaudio.load(path)
    # Remove the DC offset (a single mean over all samples and channels).
    waveform = waveform - waveform.mean()

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=melbins, dither=0.0,
        frame_shift=10)

    # Force exactly target_length frames along the time axis (dim 0).
    n_frames = fbank.shape[0]
    if n_frames < target_length:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, target_length - n_frames))(fbank)
    elif n_frames > target_length:
        fbank = fbank[:target_length, :]

    # Standardise using twice the std, exactly as the original code did.
    return (fbank - norm_mean) / (norm_std * 2)
|
|
|
# Decoding settings forwarded to the model's 'generate' mode: no sampling,
# beam search with 3 beams, and a max_length cap of 30 (presumably tokens,
# possibly counting the prompt — confirm against ePALM's generate).
do_sample, num_beams, max_length = False, 3, 30
|
|
|
|
|
|
|
|
|
|
|
def inference(image, audio, video, task_type, instruction):
    """Run one eP-ALM generation request coming from the Gradio UI.

    Every task shares the single ``model_caption`` instance. The weights for the
    requested task are loaded into it in place before generating.

    Args:
        image: PIL image, used by 'Image Captioning' and 'Visual Question Answering'.
        audio: audio file path, used by 'Audio Captioning'.
        video: video file path, used by 'Video Captioning' and
            'Video Question Answering'.
        task_type: name of the task to run (a key of ``task_table`` below).
        instruction: question text for the question-answering tasks. It is
            ignored for captioning and may be None.

    Returns:
        The decoded model output as a string.

    Raises:
        NotImplementedError: if ``task_type`` is not a supported task.
        ValueError: if the media input required by the task was not provided.
    """
    # task -> (weights to load, media input it consumes, takes a question?)
    task_table = {
        'Image Captioning': (state_dict, image, False),
        'Video Captioning': (state_dict_video_caption, video, False),
        'Audio Captioning': (state_dict_audio_caption, audio, False),
        'Visual Question Answering': (state_dict_vqa, image, True),
        'Video Question Answering': (state_dict_video_qa, video, True),
    }
    if task_type not in task_table:
        raise NotImplementedError(f'Unsupported task: {task_type!r}')
    task_weights, media, is_question = task_table[task_type]
    if media is None:
        raise ValueError(f'No input was provided for task {task_type!r}')

    # Captioning uses an empty prompt. QA uses "<question>?</a>"; a '?' the user
    # already typed is dropped so the prompt does not end in '??'.
    if is_question:
        question = (instruction or '').strip().rstrip('?').rstrip()
        prompt = question + '?' + special_answer_token
    else:
        prompt = ''
    text_input = tokenizer([prompt], padding='longest', return_tensors="pt").to(device)

    # load_state_dict updates the module in place; its return value is only the
    # missing/unexpected-key report, so never rebind model_caption to it.
    # NOTE(review): this mutates the shared model per request, so concurrent
    # requests for different tasks could interfere.
    model_caption.load_state_dict(task_weights, strict=False)

    if 'Video' in task_type:
        media_tensor = read_video(media)
    elif 'Audio' in task_type:
        media_tensor = read_audio(media)
    else:
        media_tensor = transform(media)
    # Each helper returns one unbatched sample. Add a batch dimension of one,
    # matching text_input, and move it to the model's device for every modality.
    media_tensor = media_tensor.to(device, non_blocking=True).unsqueeze(0)

    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=True):
        out = model_caption(image=media_tensor, text=text_input, mode='generate',
                            return_dict=True, max_length=max_length,
                            do_sample=do_sample, num_beams=num_beams)

    response = ''
    for o in out:
        decoded = tokenizer.decode(o.tolist())
        if is_question:
            # The answer is the text following the special answer token.
            decoded = decoded.split(special_answer_token)[1]
        else:
            # The caption is the text between the first and second '</s>'.
            decoded = decoded.split('</s>')[1]
        response = decoded.replace(pad_token, '').replace('</s>', '').replace(eos_token, '')
    return response
|
|
|
|
|
# Gradio input components, in the same order as inference()'s parameters:
# (image, audio, video, task_type, instruction). The non-deprecated gr.*
# components are used consistently (gr.Audio/gr.Video already were).
inputs = [
    gr.Image(type='pil'),
    gr.Audio(source="upload", type="filepath"),
    gr.Video(source="upload", type="filepath"),
    # Only tasks that inference() dispatches on are offered.
    gr.Radio(
        choices=['Image Captioning', 'Video Captioning', 'Audio Captioning',
                 "Visual Question Answering"],
        type="value",
        value="Image Captioning",
        label="Task",
    ),
    gr.Textbox(lines=1, label="Instruction"),
]

outputs = ['text']

# Each row: [image, audio, video, task, instruction].
examples = [
    ['examples/images/soccer.jpg', None, None, 'Image Captioning', None],
    ['examples/images/ski.jpg', None, None, 'Visual Question Answering', 'what does the woman wearing black do?'],
    ['examples/images/banana.jpg', None, None, 'Image Captioning', None],
    ['examples/images/skateboard.jpg', None, None, 'Visual Question Answering', 'what is on top of the skateboard?'],
    ['examples/images/baseball.jpg', None, None, 'Image Captioning', None],
    [None, None, 'examples/videos/video7014.mp4', 'Video Captioning', None],
    [None, None, 'examples/videos/video7017.mp4', 'Video Captioning', None],
    [None, None, 'examples/videos/video7019.mp4', 'Video Captioning', None],
    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
    [None, 'examples/audios/6cS0FsUM-cQ.wav', None, 'Audio Captioning', None],
    [None, 'examples/audios/AJtNitYMa1I.wav', None, 'Audio Captioning', None],
]

title = "eP-ALM"
description = "Gradio Demo for eP-ALM: Efficient Perceptual Augmentation of Language Models"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2303.11403' target='_blank'>Paper</a> | <a href='https://github.com/mshukor/eP-ALM' target='_blank'>Github Repo</a></p>"

io = gr.Interface(fn=inference, inputs=inputs, outputs=outputs,
                  title=title, description=description, article=article,
                  examples=examples, cache_examples=False)
io.launch()