import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, CLIPImageProcessor, Pix2StructForConditionalGeneration
cfg = {
    "crop_size": 256,
    "do_center_crop": True,
    "do_normalize": True,
    "do_resize": True,
    "feature_extractor_type": "CLIPFeatureExtractor",
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "resample": 3,
    "size": 256,
}
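
# A minimal, hedged sketch (not called anywhere in this module) of how cfg is
# consumed: CLIPImageProcessor resizes, center-crops, and normalizes an image
# into a (1, 3, 256, 256) tensor. The blank test image is a stand-in.
def _demo_clip_preprocess():
    processor = CLIPImageProcessor(**cfg)
    image = Image.new("RGB", (640, 480))  # placeholder for a real input image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    assert pixel_values.shape == (1, 3, 256, 256)
    return pixel_values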
'''
Pix2Struct-Large Model (pretrained version)
'''
class Pix2StructLargeVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.do_resize = args.do_resize
        self.de_normalize = args.de_normalize  # de-normalize the input image and re-preprocess it with the pix2struct processor
        self.select_layer = args.mm_vision_select_layer  # NOTE: not implemented yet; this parameter has no effect
        self.input_image_size = args.input_image_size
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.freeze_vision = args.freeze_vision
        self.args = args
        if not delay_load:  # honor delay_load; the original check on is_loaded always loaded
            self.load_model()
    def load_model(self):
        if self.is_loaded:
            return
        # Keep only the encoder from the full conditional-generation checkpoint.
        whole_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-large")
        self.vision_tower = whole_model.encoder
        self.pix2struct_processor = AutoProcessor.from_pretrained("google/pix2struct-large")
        self.pix2struct_processor.image_processor.is_vqa = False
        self.image_processor = CLIPImageProcessor(**cfg)
        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size,
            }
        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)
        # CLIP normalization statistics, kept for de-normalization in forward().
        self.image_mean = torch.tensor(self.image_processor.image_mean).view(1, 3, 1, 1)
        self.image_std = torch.tensor(self.image_processor.image_std).view(1, 3, 1, 1)
        self.is_loaded = True
    def feature_select(self, image_forward_outs):
        # NOTE: currently unused; forward() reads last_hidden_state directly.
        image_features = image_forward_outs.hidden_states[self.select_layer]  # [bs, n, c], cls at idx=0
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features
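    # A hedged shape note (feature_select itself is unused by forward()): for a
    # CLIP-style encoder that prepends a CLS token, select_feature='patch' maps
    # (bs, n, c) -> (bs, n - 1, c). Pix2Struct emits no CLS token, so 'patch'
    # would drop a real patch token here.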
    # @torch.no_grad()
    def forward(self, images):
        if not self.de_normalize:
            # The pix2struct processor expects raw pixel values, so the
            # CLIP-normalized input must be de-normalized first; without it,
            # the original code left x undefined and crashed.
            raise NotImplementedError('Pix2StructLargeVisionTower requires de_normalize=True')
        # Undo CLIP normalization and rescale to [0, 255] before re-preprocessing.
        mean = self.image_mean.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
        std = self.image_std.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
        x = (images * std + mean) * 255.0
        x = self.pix2struct_processor(images=x.float(), return_tensors="pt")
        image_features = self.vision_tower(**(x.to(device=self.device, dtype=self.dtype))).last_hidden_state
        bs, n, c = image_features.shape
        image_features = image_features[:, :2025, :]  # HARD CODE: keep the first 45 * 45 = 2025 patch tokens
        if self.do_resize:
            image_features = image_features.transpose(1, 2).reshape(bs, c, 45, 45)  # HARD CODE: 45x45 token grid
            image_features = F.interpolate(image_features.float(), size=(32, 32), mode='bilinear', align_corners=True).to(dtype=image_features.dtype)  # HARD CODE: pool to 32x32
        return image_features
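    # Shape walkthrough for forward(), assuming the pix2struct processor's
    # default max_patches (2048) so at least 2025 tokens come back:
    #   images:          (bs, 3, H, W), CLIP-normalized
    #   encoder output:  (bs, n, 1536), truncated to 2025 = 45 * 45 tokens
    #   after do_resize: (bs, 1536, 32, 32), the 45x45 token grid bilinearly
    #                    downsampled to 32x32 (i.e. 1024 spatial positions)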
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
@property
def device(self):
return next(self.vision_tower.parameters()).device
@property
def config(self):
return self.vision_tower.config
    @property
    def hidden_size(self):
        # return self.config.hidden_size
        hidden_dim = 1536  # HARD CODE: hidden size of pix2struct-large
        return hidden_dim

    @property
    def num_patches(self):
        # return (self.config.image_size // self.config.patch_size) ** 2
        # The config is a PretrainedConfig, not a dict, so the original
        # subscript access failed; derive the count from forward()'s hard codes:
        # 45 * 45 tokens, pooled to 32 * 32 when do_resize is set.
        return 32 * 32 if self.do_resize else 2025
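
# A hedged construction sketch (not called anywhere): SimpleNamespace stands in
# for the real training-args object and carries only the attributes __init__
# reads; the values are illustrative, not project defaults.
def _demo_build_tower():
    from types import SimpleNamespace
    args = SimpleNamespace(
        do_resize=True,
        de_normalize=True,
        mm_vision_select_layer=-2,  # currently has no effect (see __init__)
        input_image_size=256,
        freeze_vision=True,
    )
    tower = Pix2StructLargeVisionTower("google/pix2struct-large", args)
    # Stand-in for a CLIP-normalized batch; real inputs come from image_processor.
    feats = tower(torch.randn(1, 3, 256, 256))
    return feats  # (1, 1536, 32, 32) with do_resize=True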
#main
if __name__ == "__main__":
'''
print('hello')
from PIL import Image
import requests
from transformers import AutoProcessor, Pix2StructVisionModel
model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open("/lustre/fsw/portfolios/llmservice/users/fuxiaol/me.jpg")
for name, param in model.named_parameters():
param.requires_grad = False
#inputs = processor(images=image, return_tensors="pt")
image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
#inputs = pixel_values.to(torch.bfloat16)
print('pixel_values:', pixel_values.size())
inputs = processor(images=pixel_values, max_patches=1024, return_tensors='pt')['flattened_patches']
    print(inputs.size())
outputs = model(inputs)
print(outputs.last_hidden_state.size())
'''
    cfg = {
        "crop_size": 1024,
        "do_center_crop": True,
        "do_normalize": True,
        "do_resize": True,
        "feature_extractor_type": "CLIPFeatureExtractor",
        "image_mean": [0.48145466, 0.4578275, 0.40821073],
        "image_std": [0.26862954, 0.26130258, 0.27577711],
        "resample": 3,
        "size": 1024,
    }
    import requests
    from transformers import AutoProcessor, Pix2StructForConditionalGeneration
    from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
    import torchvision.transforms as T
processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-large")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large")
#url = "https://www.ilankelman.org/stopsigns/australia.jpg"
#image = Image.open(requests.get(url, stream=True).raw)
image = Image.open("/lustre/fsw/portfolios/llmservice/users/fuxiaol/sample2.jpg")
    image_processor = CLIPImageProcessor(**cfg)
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
print(pixel_values.size())
    # Undo CLIP normalization so pixel_values holds raw [0, 1] values again.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
    pixel_values = pixel_values * std + mean
    print(pixel_values.size())
#pixel_values.save('pix2image.jpg')
transform = T.ToPILImage()
img = transform(pixel_values.squeeze(0))
img.save('pix2image.jpg')
    inputs = processor(images=pixel_values, max_patches=1024, return_tensors="pt")['flattened_patches']
# autoregressive generation
generated_ids = model.generate(inputs, max_new_tokens=50)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
    #A stop sign is on a street corner.
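    # The round trip above (CLIP preprocessing -> de-normalization -> pix2struct
    # processor -> generate) mirrors what Pix2StructLargeVisionTower.forward()
    # does to its input, which is what this block sanity-checks.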
'''
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPModel
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336')
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
print(image)
inputs = processor(images=image, return_tensors="pt")
#image_features = model.get_image_features(**inputs)
outputs = model(**inputs,output_hidden_states=True)
print(outputs.hidden_states[-1].size())
print(outputs.hidden_states[-2].size())
print(outputs.hidden_states[-3].size())
'''
#sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
#sequence = processor.post_process_generation(sequence, fix_markdown=False)
    # note: we're using repr here for the sake of printing the \n characters; feel free to just print the sequence
#print(repr(sequence))