import sys
import os

root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
sys.path.append(root)

import argparse
import json
import random
from copy import deepcopy

import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from tqdm import tqdm
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, SiglipImageProcessor, SiglipVisionModel

from univa.utils.flux_pipeline import FluxPipeline
from univa.eval.configuration_eval import EvalConfig
from univa.utils.get_ocr import get_ocr_result
from univa.utils.denoiser_prompt_embedding_flux import encode_prompt
from univa.models.qwen2p5vl.modeling_univa_qwen2p5vl import UnivaQwen2p5VLForConditionalGeneration


def get_meta(prompt_path):
    '''
    Load the prompt metadata JSON (a dict keyed by id) and normalize each
    entry to:
        {
            "Prompts": "a photo of a cat",
            "Category": "No Category",
            "id": "000000000",
        }
    '''
    with open(prompt_path, 'r') as f:
        meta_info = json.load(f)
    ret_meta_info = []
    for v in meta_info.values():
        v.pop('models', None)
        v.pop('prompt in Chinese', None)
        v['Prompts'] = deepcopy(v.pop('prompt'))
        v['Category'] = 'No Category'
        v['id'] = f"{int(v['id']):09d}"
        ret_meta_info.append(v)
    return ret_meta_info


# adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/random.py#L31
def set_seed(seed, rank, device_specific=True):
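    """Seed Python, NumPy, and CUDA RNGs, and force deterministic cuDNN.

    With device_specific=True each rank is seeded with ``seed + rank`` so
    data-parallel workers do not all draw identical samples.
    """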
    if device_specific:
        seed += rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def initialize_models(args, device):
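    """Load the UniVA LVLM, its processor, the FLUX pipeline, and SigLIP.

    Returns a dict bundling everything the generation loop needs.
    """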
    # Load main model and task head
    model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
        args.pretrained_lvlm_name_or_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to(device)
    processor = AutoProcessor.from_pretrained(
        args.pretrained_lvlm_name_or_path,
        min_pixels=args.min_pixels,
        max_pixels=args.max_pixels,
    )

    # Load FLUX pipeline, reusing the LVLM's denoise tower as the transformer
    pipe = FluxPipeline.from_pretrained(
        args.pretrained_denoiser_name_or_path,
        transformer=model.denoise_tower.denoiser,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]

    # Load SigLIP vision encoder for reference-image conditioning
    siglip_processor = SiglipImageProcessor.from_pretrained(args.pretrained_siglip_name_or_path)
    siglip_model = SiglipVisionModel.from_pretrained(
        args.pretrained_siglip_name_or_path,
        torch_dtype=torch.bfloat16,
    ).to(device)

    return {
        'model': model,
        'processor': processor,
        'pipe': pipe,
        'tokenizers': tokenizers,
        'text_encoders': text_encoders,
        'device': device,
        'siglip_model': siglip_model,
        'siglip_processor': siglip_processor,
    }


def init_gpu_env(args):
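    """Attach rank/world-size info to ``args`` and initialize NCCL.

    Note: this reads the global ``RANK`` and also uses it as the CUDA device
    index, which assumes a single-node launch (one process per GPU, e.g. via
    torchrun).
    """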
    local_rank = int(os.getenv('RANK', 0))
    world_size = int(os.getenv('WORLD_SIZE', 1))
    args.local_rank = local_rank
    args.world_size = world_size
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl', init_method='env://',
        world_size=world_size, rank=local_rank
    )
    return args


def run_model_and_return_samples(args, state, text, image1=None, image2=None):
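    """Encode the prompt (and up to two reference images) and run FLUX.

    Returns the list of PIL images produced by the pipeline.
    """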
    # Build the multimodal conversation content
    convo = []
    image_paths = []
    content = []
    for img in (image1, image2):
        if img:
            content.append({
                'type': 'image', 'image': img,
                'min_pixels': args.min_pixels, 'max_pixels': args.max_pixels,
            })
            image_paths.append(img)
    if text:
        ocr_text = ''
        if args.ocr_enhancer and content:
            ocr_texts = []
            cur_ocr_i = 0  # index of the image currently being OCR'd
            for img in (image1, image2):
                if img:
                    ocr_texts.append(get_ocr_result(img, cur_ocr_i))
                    cur_ocr_i += 1
            ocr_text = '\n'.join(ocr_texts)
        content.append({'type': 'text', 'text': text + ocr_text})

    if not args.only_use_t5:
        convo.append({'role': 'user', 'content': content})
        # Prepare inputs
        chat_text = state['processor'].apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Strip the leading system turn from the templated chat text
        chat_text = '<|im_end|>\n'.join(chat_text.split('<|im_end|>\n')[1:])
        image_inputs, video_inputs = process_vision_info(convo)
        inputs = state['processor'](
            text=[chat_text], images=image_inputs, videos=video_inputs,
            padding=True, return_tensors='pt',
        ).to(state['device'])

        # Optional SigLIP hidden states for the reference images
        siglip_hs = None
        if state['siglip_processor'] and image_paths:
            vals = [
                state['siglip_processor'].preprocess(
                    images=Image.open(p).convert('RGB'), do_resize=True,
                    return_tensors='pt', do_convert_rgb=True,
                ).pixel_values.to(state['device'])
                for p in image_paths
            ]
            with torch.no_grad():
                siglip_hs = state['siglip_model'](torch.concat(vals)).last_hidden_state

        # Run the LVLM to produce denoiser prompt embeddings
        with torch.no_grad():
            lvlm = state['model'](
                inputs.input_ids,
                pixel_values=getattr(inputs, 'pixel_values', None),
                attention_mask=inputs.attention_mask,
                image_grid_thw=getattr(inputs, 'image_grid_thw', None),
                siglip_hidden_states=siglip_hs,
                output_type='denoise_embeds',
            )
        prm_embeds, pooled = encode_prompt(
            state['text_encoders'], state['tokenizers'],
            text if args.joint_with_t5 else '', 256, state['device'], 1,
        )
        emb = torch.concat([lvlm, prm_embeds], dim=1) if args.joint_with_t5 else lvlm
    else:
        # T5-only path: skip the LVLM and encode the raw prompt directly
        prm_embeds, pooled = encode_prompt(
            state['text_encoders'], state['tokenizers'],
            text, 256, state['device'], 1,
        )
        emb = prm_embeds

    # Run the FLUX denoising pipeline on the prompt embeddings
    with torch.no_grad():
        img = state['pipe'](
            prompt_embeds=emb,
            pooled_prompt_embeds=pooled,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_images_per_prompt=args.num_images_per_prompt,
        ).images
    return img


def main(args):
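    """Shard the prompt list across ranks and render each prompt to disk."""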
    args = init_gpu_env(args)

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    set_seed(args.seed, rank=args.local_rank, device_specific=True)
    device = torch.cuda.current_device()
    state = initialize_models(args, device)

    meta_info = get_meta(args.genai_prompt_path)
    print(f'original meta_info ({len(meta_info)})')
    text_and_savepath = [
        [
            meta_info[i]['Prompts'], os.path.join(args.output_dir, f"{meta_info[i]['id']}.jpg")
        ] for i in range(len(meta_info))
    ]
    text_and_savepath_ = [
        [text_prompt, save_path] for text_prompt, save_path in text_and_savepath if not os.path.exists(save_path)
    ]
    print(f'need to process ({len(text_and_savepath_)})')
    if len(text_and_savepath_) == 0:
        sys.exit(0)
    # Shard the full (unfiltered) list so indexing stays stable across ranks;
    # already-finished items are skipped inside the loop below.
    text_and_savepath = text_and_savepath[args.local_rank::args.world_size]
    os.makedirs(args.output_dir, exist_ok=True)
    print(f'args: {args}')

    assert args.num_samples_per_prompt == 1
    cnt = 0
    for text_prompt, save_path in tqdm(text_and_savepath):
        if os.path.exists(save_path):
            continue
        # Vary the seed per prompt so each image is drawn from a distinct RNG state
        set_seed(args.seed + cnt * 50, rank=args.local_rank, device_specific=True)
        image = run_model_and_return_samples(args, state, text_prompt, image1=None, image2=None)
        image = image[0]
        image.save(save_path)
        cnt += 1


if __name__ == "__main__":
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str)
    parser.add_argument("--pretrained_lvlm_name_or_path", type=str, default=None, required=False)
    parser.add_argument("--output_dir", type=str, default=None, required=False)
    args = parser.parse_args()

    # Merge the YAML config into the structured EvalConfig schema,
    # then apply any CLI overrides.
    config = OmegaConf.load(args.config)
    schema = OmegaConf.structured(EvalConfig)
    conf = OmegaConf.merge(schema, config)
    if args.pretrained_lvlm_name_or_path is not None:
        assert args.output_dir is not None
        conf.pretrained_lvlm_name_or_path = args.pretrained_lvlm_name_or_path
        conf.output_dir = args.output_dir
    main(conf)
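
# Usage sketch (assumed launch command; the script name, config path, and
# model/output paths below are placeholders, not taken from the repo):
#
#   torchrun --nproc_per_node=8 gen_samples.py configs/genai_eval.yaml \
#       --pretrained_lvlm_name_or_path /path/to/univa-qwen2.5-vl \
#       --output_dir ./outputs/genai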