from typing import Optional, List

import os
import math

from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import requests
from tokenizers import Tokenizer
import matplotlib.pyplot as plt
from torchvision.transforms.functional import center_crop

from fourm.models.fm import FM
from fourm.vq.vqvae import VQVAE, DiVAE
from fourm.models.generate import GenerationSampler, build_chained_generation_schedules, init_empty_target_modality, init_full_input_modality, custom_text
from fourm.data.modality_info import MODALITY_INFO
from fourm.data.modality_transforms import RGBTransform
from fourm.utils import load_safetensors
from fourm.utils.plotting_utils import decode_dict, visualize_bboxes, plot_text_in_square, text_to_pil_image

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Default chained generation order
DEFAULT_ORDER = [
    'tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 'tok_depth@224', 'tok_normal@224',
    'tok_semseg@224', 'tok_canny_edge@224', 'tok_sam_edge@224', 'tok_rgb@224',
    'caption', 'det', 'human_poses', 'sam_instance', 'color_palette', 'metadata',
]

# Default super-resolution chained generation order
DEFAULT_ORDER_SR = [
    'tok_clip@448', 'tok_depth@448', 'tok_normal@448',
    'tok_semseg@448', 'tok_rgb@448',
]

# Default generation parameters for the case where the input contains RGB
DEFAULTS_RGB2X = {
    'tok_clip@224/tok_depth@224/tok_normal@224/tok_semseg@224/tok_canny_edge@224/tok_sam_edge@224': {
        'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 1,
        'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
        'cfg_scale': 2.0, 'cfg_schedule': 'constant',
    },
    'tok_dinov2@224/tok_imagebind@224': {
        'tokens_per_target': 256, 'autoregression_scheme': 'roar', 'decoding_steps': 1,
        'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
        'cfg_scale': 2.0, 'cfg_schedule': 'constant',
    },
    'caption/det': {
        'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.3, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'human_poses': {
        'tokens_per_target': 275, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'sam_instance': {
        'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.01, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'color_palette': {
        'tokens_per_target': 23, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'metadata': {
        'tokens_per_target': 40, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
}

# Default generation parameters for the case where the target is RGB
DEFAULTS_X2RGB = {
    'tok_clip@224': {
        'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 50,
        'token_decoding_schedule': 'linear', 'temp': 5.0, 'temp_schedule': 'onex:0.5:0.5',
        'cfg_scale': 3.0, 'cfg_schedule': 'constant',
    },
    'tok_dinov2@224/tok_imagebind@224': {
        'tokens_per_target': 256, 'autoregression_scheme': 'roar', 'decoding_steps': 8,
        'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
        'cfg_scale': 2.0, 'cfg_schedule': 'constant',
    },
    'tok_depth@224/tok_normal@224/tok_semseg@224/tok_canny_edge@224/tok_sam_edge@224': {
        'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 8,
        'token_decoding_schedule': 'linear', 'temp': 3.0, 'temp_schedule': 'onex:0.5:0.5',
        'cfg_scale': 2.0, 'cfg_schedule': 'constant',
    },
    'tok_rgb@224': {
        'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 25,
        'token_decoding_schedule': 'linear', 'temp': 3.0, 'temp_schedule': 'onex:0.5:0.5',
        'cfg_scale': 2.0, 'cfg_schedule': 'constant',
    },
    'caption/det': {
        'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.3, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'human_poses': {
        'tokens_per_target': 275, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'sam_instance': {
        'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.01, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'color_palette': {
        'tokens_per_target': 23, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
    'metadata': {
        'tokens_per_target': 40, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
        'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
        'cfg_scale': 1.0, 'cfg_schedule': 'constant',
    },
}

# Default generation parameters for super-resolution
DEFAULTS_SR = {
    'tok_clip@448/tok_depth@448/tok_normal@448/tok_semseg@448/tok_rgb@448': {
        'tokens_per_target': 784, 'autoregression_scheme': 'maskgit', 'decoding_steps': 8,
        'token_decoding_schedule': 'cosine', 'temp': 1.0, 'temp_schedule': 'constant',
        'cfg_scale': 2.0, 'cfg_schedule': 'constant',
    },
}

# Plotting names for each modality
MODALITY_PLOTTING_NAME_MAP = {
    'caption': 'Caption',
    'det': 'Bounding boxes',
    'human_poses': 'Human poses',
    'sam_instance': 'SAM instances (single pass)',
    'color_palette': 'Color palette',
    'metadata': 'Metadata',
    'rgb@224': 'RGB (224x224)',
    'rgb@448': 'RGB (448x448)',
    'tok_rgb@224': 'RGB (tokenized, 224x224)',
    'tok_rgb@448': 'RGB (tokenized, 448x448)',
    'tok_clip@224': 'CLIP-B/16 (224x224)',
    'tok_clip@448': 'CLIP-B/16 (448x448)',
    'tok_depth@224': 'Depth (224x224)',
    'tok_depth@448': 'Depth (448x448)',
    'tok_normal@224': 'Normals (224x224)',
    'tok_normal@448': 'Normals (448x448)',
    'tok_semseg@224': 'Semantic segmentation (224x224)',
    'tok_semseg@448': 'Semantic segmentation (448x448)',
    'tok_canny_edge@224': 'Canny edges (224x224)',
    'tok_sam_edge@224': 'SAM edges (224x224)',
    'tok_dinov2@224': 'DINOv2-B/14 (224x224)',
    'tok_imagebind@224': 'ImageBind-H/14 (224x224)',
}

# Optional fixed plotting order (by default, plotting order is determined by generation order)
MODALITY_PLOTTING_ORDER = [
    'rgb@224', 'rgb@448', 'tok_rgb@224', 'tok_rgb@448',
    'tok_depth@224', 'tok_depth@448', 'tok_normal@224', 'tok_normal@448',
    'tok_semseg@224', 'tok_semseg@448', 'tok_canny_edge@224', 'tok_sam_edge@224',
    'sam_instance', 'human_poses', 'det', 'caption', 'metadata', 'color_palette',
    'tok_clip@224', 'tok_clip@448', 'tok_dinov2@224', 'tok_imagebind@224',
]


def get_value(defaults_dict, domain, key):
    """Look up a default value belonging to a given domain and key."""
    for domains, defaults in defaults_dict.items():
        if domain in domains:
            return defaults[key]
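
# Illustrative lookups (added examples, not from the original file) showing how a single
# modality key matches one of the slash-joined group keys above via substring containment:
#   get_value(DEFAULTS_RGB2X, 'tok_depth@224', 'temp')        -> 0.01
#   get_value(DEFAULTS_X2RGB, 'caption', 'tokens_per_target') -> 256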


def load_model(model_id, model_class):
    """Load a model from the Hugging Face Hub or from a given .safetensors checkpoint path."""
    if model_id.endswith('.safetensors'):
        ckpt, config = load_safetensors(model_id)
        model = model_class(config=config)
        model.load_state_dict(ckpt)
    else:
        model = model_class.from_pretrained(model_id)
    return model


def img_from_url(url: str):
    """Download an image from a URL, center-crop and resize it to 224x224, and return it as a [1, 3, 224, 224] tensor."""
    rgb_transform = RGBTransform(imagenet_default_mean_and_std=True)
    img_data = requests.get(url).content
    with open('demo.png', 'wb') as handler:
        handler.write(img_data)
    img_pil = rgb_transform.load('./demo.png')
    img_pil = rgb_transform.preprocess(img_pil)
    img_pil = center_crop(img_pil, (min(img_pil.size), min(img_pil.size))).resize((224, 224))
    img = rgb_transform.postprocess(img_pil).unsqueeze(0)
    return img
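
# Example usage (a sketch; the URL is a placeholder, not taken from the original file):
#   img = img_from_url('https://example.com/photo.png')  # float tensor of shape [1, 3, 224, 224]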


class Demo4MSampler(nn.Module):
    """Convenience wrapper for easy 4M loading and generation. Models can be specified as
    Hugging Face Hub model IDs or as paths to downloaded safetensors checkpoints, and are
    loaded automatically. The `forward` function can be used for RGB-to-all and
    {caption, det}-to-all generation.
    This wrapper is only intended for quickly trying out 4M models. For more advanced use
    cases, we recommend looking at the generation notebooks in `./notebooks/` and at
    `./run_generation.py`.

    Args:
        fm: Hub ID or safetensors path of the 4M base model
        fm_sr: Hub ID or safetensors path of the 4M super-resolution model
        tok_rgb: Hub ID or safetensors path of the RGB tokenizer
        tok_depth: Hub ID or safetensors path of the depth tokenizer
        tok_normal: Hub ID or safetensors path of the surface normal tokenizer
        tok_edge: Hub ID or safetensors path of the edge tokenizer (shared by Canny and SAM edges)
        tok_semseg: Hub ID or safetensors path of the COCO semantic segmentation tokenizer
        tok_clip: Hub ID or safetensors path of the CLIP-B/16 tokenizer
        tok_dinov2: Hub ID or safetensors path of the DINOv2-B/14 tokenizer
        tok_imagebind: Hub ID or safetensors path of the ImageBind-H/14 tokenizer
        tok_sam_instance: Hub ID or safetensors path of the SAM instance tokenizer
        tok_human_poses: Hub ID or safetensors path of the human poses tokenizer
        tok_text: Path to the text tokenizer JSON file
        mods: Optional list of modalities to override the default behavior of generating everything
        mods_sr: Optional list of super-resolution modalities to override the default behavior of generating everything
        verbose: Whether to print progress messages
    """

    def __init__(self,
                 fm: str = 'EPFL-VILAB/4M-21_XL_CC12M',
                 fm_sr: Optional[str] = 'EPFL-VILAB/4M-7-SR_L_CC12M',
                 tok_rgb: Optional[str] = 'EPFL-VILAB/4M_tokenizers_rgb_16k_224-448',
                 tok_depth: Optional[str] = 'EPFL-VILAB/4M_tokenizers_depth_8k_224-448',
                 tok_normal: Optional[str] = 'EPFL-VILAB/4M_tokenizers_normal_8k_224-448',
                 tok_edge: Optional[str] = 'EPFL-VILAB/4M_tokenizers_edge_8k_224-512',
                 tok_semseg: Optional[str] = 'EPFL-VILAB/4M_tokenizers_semseg_4k_224-448',
                 tok_clip: Optional[str] = 'EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448',
                 tok_dinov2: Optional[str] = 'EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448',
                 tok_imagebind: Optional[str] = 'EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448',
                 tok_sam_instance: Optional[str] = 'EPFL-VILAB/4M_tokenizers_sam-instance_1k_64',
                 tok_human_poses: Optional[str] = 'EPFL-VILAB/4M_tokenizers_human-poses_1k_8',
                 tok_text: str = './fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json',
                 mods: Optional[List[str]] = None,
                 mods_sr: Optional[List[str]] = None,
                 verbose: bool = True):
        super().__init__()
        self.verbose = verbose

        if self.verbose:
            print('Loading 4M models and tokenizers...', end='')

        # Load 4M model and initialize sampler
        fm = load_model(fm, FM)
        self.sampler_fm = GenerationSampler(fm)
        self.mods = mods or list(set(fm.encoder_modalities) | set(fm.decoder_modalities))

        # Load optional 4M super-resolution model and initialize sampler
        if fm_sr is not None:
            fm_sr = load_model(fm_sr, FM)
            self.sampler_fm_sr = GenerationSampler(fm_sr)
            self.mods_sr = mods_sr or list(set(fm_sr.encoder_modalities) | set(fm_sr.decoder_modalities))
        else:
            self.sampler_fm_sr = None
            self.mods_sr = mods_sr or []  # Keep mods_sr defined for the membership checks below

        # Load tokenizers
        self.toks = {}
        if ('tok_rgb@224' in self.mods or 'tok_rgb@448' in self.mods_sr) and tok_rgb is not None:
            self.toks['tok_rgb'] = load_model(tok_rgb, DiVAE)
        if ('tok_depth@224' in self.mods or 'tok_depth@448' in self.mods_sr) and tok_depth is not None:
            self.toks['tok_depth'] = load_model(tok_depth, DiVAE)
        if ('tok_normal@224' in self.mods or 'tok_normal@448' in self.mods_sr) and tok_normal is not None:
            self.toks['tok_normal'] = load_model(tok_normal, DiVAE)
        if ('tok_canny_edge@224' in self.mods or 'tok_sam_edge@224' in self.mods) and tok_edge is not None:
            self.toks['tok_canny_edge'] = load_model(tok_edge, DiVAE)
            self.toks['tok_sam_edge'] = self.toks['tok_canny_edge']  # Shared tokenizer
        if ('tok_semseg@224' in self.mods or 'tok_semseg@448' in self.mods_sr) and tok_semseg is not None:
            self.toks['tok_semseg'] = load_model(tok_semseg, VQVAE)
        if ('tok_clip@224' in self.mods or 'tok_clip@448' in self.mods_sr) and tok_clip is not None:
            self.toks['tok_clip'] = load_model(tok_clip, VQVAE)
        if 'tok_dinov2@224' in self.mods and tok_dinov2 is not None:
            self.toks['tok_dinov2'] = load_model(tok_dinov2, VQVAE)
        if 'tok_imagebind@224' in self.mods and tok_imagebind is not None:
            self.toks['tok_imagebind'] = load_model(tok_imagebind, VQVAE)
        if 'sam_instance' in self.mods and tok_sam_instance is not None:
            self.toks['sam_instance'] = load_model(tok_sam_instance, VQVAE)
        if 'human_poses' in self.mods and tok_human_poses is not None:
            self.toks['human_poses'] = load_model(tok_human_poses, VQVAE)
        self.toks = nn.ModuleDict(self.toks)

        self.tok_text = Tokenizer.from_file(tok_text)

        if self.verbose:
            print(' done!')

    @property
    def device(self):
        return next(self.parameters()).device

    def __setup_conds_and_targets(self, sample):
        # Input and output modalities
        cond_domains = [domain for domain in sample.keys() if domain in self.mods]
        target_domains = [domain for domain in DEFAULT_ORDER if (domain not in cond_domains and domain in self.mods)]
        if 'rgb@224' in cond_domains and 'tok_rgb@224' in target_domains:
            # Do not generate tokenized RGB if pixel RGB is given as input
            target_domains.remove('tok_rgb@224')
        return cond_domains, target_domains

    def __setup_sr_conds_and_targets(self, sample):
        cond_domains_sr = [domain for domain in sample.keys() if domain in self.mods_sr]
        target_domains_sr = [domain for domain in DEFAULT_ORDER_SR if (domain.replace('448', '224') in cond_domains_sr and domain in self.mods_sr)]
        return cond_domains_sr, target_domains_sr
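
    # Example (illustrative, not from the original file): if the base model produced
    # 'tok_depth@224' and 'tok_rgb@224', then 'tok_depth@448' and 'tok_rgb@448' become
    # super-resolution targets, conditioned on the base-resolution outputs that are in self.mods_sr.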

    def __setup_sample_and_schedule(self, sample, cond_domains, target_domains, cfg_grow_conditioning=True):
        # 1 - Set up the generation schedule
        defaults = DEFAULTS_RGB2X if ('rgb@224' in cond_domains or 'tok_rgb@224' in cond_domains) else DEFAULTS_X2RGB
        tokens_per_target = [get_value(defaults, domain, 'tokens_per_target') for domain in target_domains]
        autoregression_schemes = [get_value(defaults, domain, 'autoregression_scheme') for domain in target_domains]
        decoding_steps = [get_value(defaults, domain, 'decoding_steps') for domain in target_domains]
        token_decoding_schedules = [get_value(defaults, domain, 'token_decoding_schedule') for domain in target_domains]
        temps = [get_value(defaults, domain, 'temp') for domain in target_domains]
        temp_schedules = [get_value(defaults, domain, 'temp_schedule') for domain in target_domains]
        cfg_scales = [get_value(defaults, domain, 'cfg_scale') for domain in target_domains]
        cfg_schedules = [get_value(defaults, domain, 'cfg_schedule') for domain in target_domains]

        schedule = build_chained_generation_schedules(
            cond_domains=cond_domains, target_domains=target_domains, tokens_per_target=tokens_per_target,
            autoregression_schemes=autoregression_schemes, decoding_steps=decoding_steps,
            token_decoding_schedules=token_decoding_schedules, temps=temps, temp_schedules=temp_schedules,
            cfg_scales=cfg_scales, cfg_schedules=cfg_schedules, cfg_grow_conditioning=cfg_grow_conditioning,
        )

        # 2 - Set up the sample
        sample_dict = {}

        # Handle special cases
        if 'caption' in sample:
            caption = sample.pop('caption')
            sample_dict = custom_text(
                sample_dict, input_text=caption, eos_token='[EOS]',
                key='caption', device=self.device, text_tokenizer=self.tok_text
            )
        if 'det' in sample:
            det = sample.pop('det')
            sample_dict = custom_text(
                sample_dict, input_text=det, eos_token='[EOS]',
                key='det', device=self.device, text_tokenizer=self.tok_text
            )
        # Add remaining modalities
        sample_dict.update({domain: {'tensor': tensor} for domain, tensor in sample.items()})

        # Initialize the remaining input modalities (caption and det are already initialized by custom_text)
        for cond_mod in sample.keys():
            sample_dict = init_full_input_modality(sample_dict, MODALITY_INFO, cond_mod, self.device, eos_id=self.tok_text.token_to_id("[EOS]"))
        # Initialize target modalities
        for target_mod, ntoks in zip(target_domains, tokens_per_target):
            sample_dict = init_empty_target_modality(sample_dict, MODALITY_INFO, target_mod, 1, ntoks, self.device)

        return sample_dict, schedule

    def __setup_sr_sample_and_schedule(self, out_dict, cond_domains_sr, target_domains_sr, cfg_grow_conditioning_sr=True):
        # 1 - Set up the generation schedule
        tokens_per_target_sr = [get_value(DEFAULTS_SR, domain, 'tokens_per_target') for domain in target_domains_sr]
        autoregression_schemes_sr = [get_value(DEFAULTS_SR, domain, 'autoregression_scheme') for domain in target_domains_sr]
        decoding_steps_sr = [get_value(DEFAULTS_SR, domain, 'decoding_steps') for domain in target_domains_sr]
        token_decoding_schedules_sr = [get_value(DEFAULTS_SR, domain, 'token_decoding_schedule') for domain in target_domains_sr]
        temps_sr = [get_value(DEFAULTS_SR, domain, 'temp') for domain in target_domains_sr]
        temp_schedules_sr = [get_value(DEFAULTS_SR, domain, 'temp_schedule') for domain in target_domains_sr]
        cfg_scales_sr = [get_value(DEFAULTS_SR, domain, 'cfg_scale') for domain in target_domains_sr]
        cfg_schedules_sr = [get_value(DEFAULTS_SR, domain, 'cfg_schedule') for domain in target_domains_sr]

        schedule_sr = build_chained_generation_schedules(
            cond_domains=cond_domains_sr, target_domains=target_domains_sr, tokens_per_target=tokens_per_target_sr,
            autoregression_schemes=autoregression_schemes_sr, decoding_steps=decoding_steps_sr,
            token_decoding_schedules=token_decoding_schedules_sr, temps=temps_sr, temp_schedules=temp_schedules_sr,
            cfg_scales=cfg_scales_sr, cfg_schedules=cfg_schedules_sr, cfg_grow_conditioning=cfg_grow_conditioning_sr,
        )

        # 2 - Set up the sample
        sample_sr = out_dict

        # Handle the case where the generated caption or bounding boxes are just [EOS]
        if 'caption' in sample_sr and sample_sr['caption']['tensor'].shape[1] <= 1 and 'caption' in cond_domains_sr:
            sample_sr = custom_text(
                sample_sr, input_text='[S_1]', eos_token='[EOS]',
                key='caption', device=self.device, text_tokenizer=self.tok_text
            )
        if 'det' in sample_sr and sample_sr['det']['tensor'].shape[1] <= 1 and 'det' in cond_domains_sr:
            sample_sr = custom_text(
                sample_sr, input_text='[S_1]', eos_token='[EOS]',
                key='det', device=self.device, text_tokenizer=self.tok_text
            )

        # Initialize input modalities
        for cond_mod in cond_domains_sr:
            sample_sr = init_full_input_modality(sample_sr, MODALITY_INFO, cond_mod, self.device, eos_id=self.tok_text.token_to_id("[EOS]"))
        # Initialize target modalities
        for target_mod, ntoks in zip(target_domains_sr, tokens_per_target_sr):
            sample_sr = init_empty_target_modality(sample_sr, MODALITY_INFO, target_mod, 1, ntoks, self.device)

        return sample_sr, schedule_sr

    def forward(self, sample, seed: Optional[int] = None, top_p: float = 0.8, top_k: float = 0.0, target_modalities: Optional[List[str]] = None, perform_sr: bool = True):
        seed = seed if seed is not None else np.random.randint(np.iinfo(np.int64).max)

        # Prepare the generation parameters and sample
        cond_domains, target_domains = self.__setup_conds_and_targets(sample)
        target_domains = target_modalities or target_domains
        sample, generation_schedule = self.__setup_sample_and_schedule(sample, cond_domains, target_domains)

        # Generation and decoding at the base resolution (224x224)
        if self.verbose:
            print(f'Generating {cond_domains} -> {target_domains} ...')
        out_dict = self.sampler_fm.generate(
            sample, generation_schedule, text_tokenizer=self.tok_text,
            verbose=self.verbose, seed=seed, top_p=top_p, top_k=top_k,
        )
        dec_dict = decode_dict(
            out_dict, self.toks, self.tok_text, image_size=224,
            patch_size=16, decoding_steps=50
        )

        # Optional upsampling to 448x448
        if self.sampler_fm_sr is not None and perform_sr:
            cond_domains_sr, target_domains_sr = self.__setup_sr_conds_and_targets(out_dict)
            sample_sr, generation_schedule_sr = self.__setup_sr_sample_and_schedule(out_dict, cond_domains_sr, target_domains_sr)
            if self.verbose:
                print(f'Super-resolving {target_domains_sr} ...')
            out_dict_sr = self.sampler_fm_sr.generate(
                sample_sr, generation_schedule_sr, text_tokenizer=self.tok_text,
                verbose=self.verbose, seed=seed + 1, top_p=top_p, top_k=top_k,
            )
            dec_dict = decode_dict(
                out_dict_sr, self.toks, self.tok_text, image_size=448,
                patch_size=16, decoding_steps=50
            )

        # Remove padding tokens (assign the result back; str.replace does not modify in place)
        if 'caption' in dec_dict:
            dec_dict['caption'][0] = dec_dict['caption'][0].replace('[PAD]', '').strip()
        if 'det' in dec_dict:
            dec_dict['det'][0] = dec_dict['det'][0].replace('[PAD]', '').strip()

        return dec_dict
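
    # Sketch of the returned dictionary (hedged; the exact keys depend on decode_dict and on
    # which modalities were generated -- see how plot_modalities below consumes it):
    #   preds = sampler({'rgb@224': img})
    #   preds['caption'][0]  # generated caption string, with [PAD] tokens stripped
    #   preds['det'][0]      # bounding boxes as a decoded text sequence
    # Image-like modalities are returned as arrays that can be passed directly to plt.imshow.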

    def plot_modalities(self, mod_dict, ncols_max=5, figscale=4.0, save_path=None, use_fixed_plotting_order=False):
        nmods = len(mod_dict)
        ncols = min(nmods, ncols_max)
        nrows = math.ceil(nmods / ncols)

        fig, ax = plt.subplots(
            nrows=nrows, ncols=ncols,
            figsize=(ncols * figscale, nrows * figscale),
            facecolor=(1, 1, 1)
        )

        if use_fixed_plotting_order:
            mod_dict = {
                k: mod_dict[k] for k in MODALITY_PLOTTING_ORDER
                if k in mod_dict
            }

        for i, (mod_name, mod) in enumerate(mod_dict.items()):
            if nrows == 1:
                ax_i = ax[i]
            else:
                row, col = i // ncols, i % ncols
                ax_i = ax[row, col]

            if mod_name == 'det':
                # Use the first available RGB modality as background, in order of priority
                keys_in_order = ['rgb@448', 'rgb@224', 'tok_rgb@448', 'tok_rgb@224']
                rgb_background = next((mod_dict[key] for key in keys_in_order if key in mod_dict), np.ones((224, 224, 3)))
                rgb_background = (255 * rgb_background).astype(np.uint8)
                ax_i.imshow(visualize_bboxes(rgb_background, mod[0]).astype(np.uint8))
            elif mod_name == 'caption':
                plot_text_in_square(ax_i, mod[0], wrap_width=16, fontsize=14)
            elif mod_name == 'metadata':
                metadata_pred = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in mod.items()])
                plot_text_in_square(ax_i, metadata_pred, wrap_width=36, fontsize=13)
            else:
                ax_i.imshow(mod)
            ax_i.set_title(MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name), fontsize=18)

        for i, axis in enumerate(ax.flatten()):
            axis.set_xticks([])
            axis.set_yticks([])
            # Hide the frames of any unused subplots
            if i >= len(mod_dict):
                axis.spines['top'].set_visible(False)
                axis.spines['right'].set_visible(False)
                axis.spines['bottom'].set_visible(False)
                axis.spines['left'].set_visible(False)

        plt.tight_layout()
        if save_path is not None:
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close()
        else:
            plt.show()

    def modalities_to_pil(self, mod_dict, use_fixed_plotting_order=False, resize=None):
        if use_fixed_plotting_order:
            mod_dict = {
                k: mod_dict[k] for k in MODALITY_PLOTTING_ORDER
                if k in mod_dict
            }

        plotted_modalities = []
        for mod_name, mod in mod_dict.items():
            if mod_name == 'det':
                # Use the first available RGB modality as background, in order of priority
                keys_in_order = ['rgb@448', 'rgb@224', 'tok_rgb@448', 'tok_rgb@224']
                rgb_background = next((mod_dict[key] for key in keys_in_order if key in mod_dict), np.ones((224, 224, 3)))
                rgb_background = (255 * rgb_background).astype(np.uint8)
                img_pil = Image.fromarray(visualize_bboxes(rgb_background, mod[0]).astype(np.uint8))
            elif mod_name == 'caption':
                img_pil = text_to_pil_image(mod[0][:512], wrap_width=40, fontsize=14)
            elif mod_name == 'metadata':
                metadata_pred = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in mod.items()])
                img_pil = text_to_pil_image(metadata_pred, wrap_width=36, fontsize=13)
            else:
                img_pil = Image.fromarray((255 * mod).astype(np.uint8))

            if resize is not None:
                # Nearest-neighbor resampling for feature-map visualizations, bilinear otherwise
                if mod_name in ['tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 'tok_clip@448']:
                    resample_mode = Image.Resampling.NEAREST
                else:
                    resample_mode = Image.Resampling.BILINEAR
                img_pil = img_pil.resize((resize, resize), resample=resample_mode)

            plot_name = MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name)
            plotted_modalities.append((img_pil, plot_name))

        return plotted_modalities
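

# Minimal end-to-end usage sketch (illustrative, not executed in this file; assumes a CUDA
# device, access to the Hugging Face Hub, and an image URL of your choosing):
#
#   sampler = Demo4MSampler().cuda()
#   img = img_from_url('https://example.com/photo.png').cuda()  # placeholder URL
#   preds = sampler({'rgb@224': img}, seed=0)
#   sampler.plot_modalities(preds)                        # matplotlib grid of generated modalities
#   pils = sampler.modalities_to_pil(preds, resize=224)   # or a list of (PIL image, name) tuples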