# garment_designer/api/model/src/utils/image_from_pipe.py
import os
from tqdm import tqdm
import torch
import torchvision.transforms as T
from diffusers.pipeline_utils import DiffusionPipeline
from torch.utils.data import DataLoader
from src.utils.image_composition import compose_img, compose_img_dresscode
from PIL import Image
@torch.inference_mode()
def generate_images_from_mgd_pipe(
test_order: bool,
pipe: DiffusionPipeline,
test_dataloader: DataLoader,
save_name: str,
dataset: str,
output_dir: str,
guidance_scale: float = 7.5,
guidance_scale_pose: float = 7.5,
guidance_scale_sketch: float = 7.5,
sketch_cond_rate: float = 1.0,
start_cond_rate: float = 0.0,
no_pose: bool = False,
disentagle: bool = False,
seed: int = 1234,
) -> Image.Image:
    """Generate images from the given test dataloader with the MGD diffusion pipeline.

    Each generated garment is composed onto the original model image; the composed
    result is returned in memory instead of being saved to ``output_dir``.

Args:
        test_order: Identifier of the test split order; used in the output folder name when saving to disk.
pipe: The diffusion pipeline.
test_dataloader: The test dataloader.
save_name: The name of the saved images.
dataset: The name of the dataset.
output_dir: The output directory.
        guidance_scale: Classifier-free guidance scale for the text prompt.
        guidance_scale_pose: Guidance scale for the pose conditioning (used only when disentagle is True).
        guidance_scale_sketch: Guidance scale for the sketch conditioning (used only when disentagle is True).
        sketch_cond_rate: Rate controlling for how much of the denoising process the sketch conditioning is applied.
        start_cond_rate: Rate controlling at which point of the denoising process the conditioning starts.
        no_pose: Whether to drop the pose conditioning.
        disentagle: Whether to use separate (disentangled) guidance scales for pose and sketch.
        seed: Random seed for the generator.
    Returns:
        The composed image for the first processed sample, as a PIL Image.
"""
#assert(save_name != ""), "save_name must be specified"
#assert(output_dir != ""), "output_dir must be specified"
#path = os.path.join(output_dir, f"{save_name}_{test_order}", "images")
#os.makedirs(path, exist_ok=True)
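    # Fixed random seed on the CUDA generator makes generation deterministic across calls.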
generator = torch.Generator("cuda").manual_seed(seed)
for batch in tqdm(test_dataloader):
model_img = batch["image"]
mask_img = batch["inpaint_mask"]
mask_img = mask_img.type(torch.float32)
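        # Cast the inpainting mask so the pipeline receives a float tensor.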
prompts = batch["original_captions"] # prompts is a list of length N, where N=batch size.
pose_map = batch["pose_map"]
sketch = batch["im_sketch"]
ext = ".jpg"
        if disentagle:
            # Disentangled guidance: pose and sketch conditioning each get their own guidance scale.
            num_samples = 1
generated_images = pipe(
prompt=prompts,
image=model_img,
mask_image=mask_img,
pose_map=pose_map,
sketch=sketch,
height=512,
width=384,
guidance_scale=guidance_scale,
num_images_per_prompt=num_samples,
generator=generator,
sketch_cond_rate=sketch_cond_rate,
guidance_scale_pose=guidance_scale_pose,
guidance_scale_sketch=guidance_scale_sketch,
start_cond_rate=start_cond_rate,
no_pose=no_pose,
).images
        else:
            # Single-scale guidance: the passed guidance_scale is overridden with the default 7.5.
            guidance_scale = 7.5
            num_samples = 1
generated_images = pipe(
prompt=prompts,
image=model_img,
mask_image=mask_img,
pose_map=pose_map,
sketch=sketch,
height=512,
width=384,
guidance_scale=guidance_scale,
num_images_per_prompt=num_samples,
generator=generator,
sketch_cond_rate=sketch_cond_rate,
start_cond_rate=start_cond_rate,
no_pose=no_pose,
).images
for i in range(len(generated_images)):
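            # Undo the [-1, 1] normalization so the model image is in [0, 1] for composition.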
model_i = model_img[i] * 0.5 + 0.5
if dataset == "vitonhd":
final_img = compose_img(model_i, generated_images[i], batch['im_parse'][i])
            else:  # dataset == "dresscode"
                # Resize the ground-truth head crop (stitch label) to the generation resolution.
                face = batch["stitch_label"][i].to(model_img.device)
                face = T.functional.resize(
                    face,
                    size=(512, 384),
                    interpolation=T.InterpolationMode.BILINEAR,
                    antialias=True,
                )
                final_img = compose_img_dresscode(
                    gt_img=model_i,
                    fake_img=T.functional.to_tensor(generated_images[i]).to(model_img.device),
                    im_head=face,
                )
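            # Return the first composed image; it stays in memory for the API response.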
return T.functional.to_pil_image(final_img)
#final_img = T.functional.to_pil_image(final_img)
#final_img.save(
# os.path.join(path, batch["im_name"][i].replace(".jpg", ext)))
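

# --- Hedged usage sketch (illustrative, not part of the original module) -------------
# The function above returns the composed image in memory rather than saving it to
# disk. The snippet below only shows how such a PIL image could be serialized to
# bytes for an API response; the dummy image and variable names are hypothetical.
if __name__ == "__main__":
    import io

    # Stand-in for the image returned by generate_images_from_mgd_pipe(...).
    dummy_img = Image.new("RGB", (384, 512), color="white")

    buffer = io.BytesIO()
    dummy_img.save(buffer, format="JPEG")
    image_bytes = buffer.getvalue()  # bytes ready to be sent in an HTTP response body
    print(f"serialized image size: {len(image_bytes)} bytes")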