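"""Test/sampling script for Mix-of-Show (ED-)LoRA checkpoints.

Loads a pretrained Stable Diffusion pipeline, merges a trained (ED-)LoRA
checkpoint into it at each alpha in ``val.alpha_list``, and renders the
validation prompts to ``path.visualization``.
"""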
import argparse
import os
import os.path as osp

import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import DPMSolverMultistepScheduler
from diffusers.utils import check_min_version
from omegaconf import OmegaConf
from tqdm import tqdm

from mixofshow.data.prompt_dataset import PromptDataset
from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline
from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora
from mixofshow.utils.util import NEGATIVE_PROMPT, compose_visualize, dict2str, pil_imwrite, set_path_logger

# Raise an error if the minimum required version of diffusers is not installed. Remove at your own risk.
check_min_version('0.18.2')


def visual_validation(accelerator, pipe, dataloader, current_iter, opt):
    """Sample images for every prompt batch in `dataloader` and save them under the visualization directory."""
    dataset_name = dataloader.dataset.opt['name']
    pipe.unet.eval()
    pipe.text_encoder.eval()

    for idx, val_data in enumerate(tqdm(dataloader)):
        output = pipe(
            prompt=val_data['prompts'],
            latents=val_data['latents'].to(dtype=torch.float16),
            negative_prompt=[NEGATIVE_PROMPT] * len(val_data['prompts']),
            num_inference_steps=opt['val']['sample'].get('num_inference_steps', 50),
            guidance_scale=opt['val']['sample'].get('guidance_scale', 7.5),
        ).images

        for img, prompt, indice in zip(output, val_data['prompts'], val_data['indices']):
            img_name = '{prompt}---G_{guidance_scale}_S_{steps}---{indice}'.format(
                prompt=prompt.replace(' ', '_'),
                guidance_scale=opt['val']['sample'].get('guidance_scale', 7.5),
                steps=opt['val']['sample'].get('num_inference_steps', 50),
                indice=indice)

            save_img_path = osp.join(opt['path']['visualization'], dataset_name, f'{current_iter}', f'{img_name}---{current_iter}.png')

            pil_imwrite(img, save_img_path)
        # free outputs between batches to avoid running out of GPU memory
        del output
        torch.cuda.empty_cache()

    # wait for all processes to finish writing their images
    accelerator.wait_for_everyone()

    if opt['val'].get('compose_visualize') and accelerator.is_main_process:
        compose_visualize(os.path.dirname(save_img_path))


def test(root_path, args):
    """Load the config, merge (ED-)LoRA weights at each alpha, and sample validation images."""
    # load config
    opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True)
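    # A minimal sketch of the keys this script reads from the YAML config
    # (inferred from the lookups below; the example values are placeholders):
    #
    #   mixed_precision: fp16
    #   manual_seed: 0
    #   models:
    #     pretrained_path: <base Stable Diffusion weights>
    #     enable_edlora: true
    #   path:
    #     lora_path: <trained (ED-)LoRA checkpoint>
    #   datasets:
    #     val_vis:
    #       name: <dataset name>
    #       batch_size_per_gpu: 4
    #   val:
    #     alpha_list: [0.7, 1.0]
    #     compose_visualize: true
    #     sample:
    #       num_inference_steps: 50
    #       guidance_scale: 7.5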

    # set up the accelerator; mixed precision is taken from the 'mixed_precision' config key
    accelerator = Accelerator(mixed_precision=opt['mixed_precision'])

    # set experiment dir
    with accelerator.main_process_first():
        set_path_logger(accelerator, root_path, args.opt, opt, is_train=False)

    # get logger
    logger = get_logger('mixofshow', log_level='INFO')
    logger.info(accelerator.state, main_process_only=True)

    logger.info(dict2str(opt))

    # If a manual seed is provided, set it now for reproducible sampling.
    if opt.get('manual_seed') is not None:
        set_seed(opt['manual_seed'])

    # Build the validation prompt dataset and dataloader
    valset_cfg = opt['datasets']['val_vis']
    val_dataset = PromptDataset(valset_cfg)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=valset_cfg['batch_size_per_gpu'], shuffle=False)

    enable_edlora = opt['models']['enable_edlora']

    for lora_alpha in opt['val']['alpha_list']:
        pipeclass = EDLoRAPipeline if enable_edlora else StableDiffusionPipeline
        pipe = pipeclass.from_pretrained(
            opt['models']['pretrained_path'],
            scheduler=DPMSolverMultistepScheduler.from_pretrained(opt['models']['pretrained_path'], subfolder='scheduler'),
            torch_dtype=torch.float16,
        ).to('cuda')
        pipe, new_concept_cfg = convert_edlora(pipe, torch.load(opt['path']['lora_path']), enable_edlora=enable_edlora, alpha=lora_alpha)
        pipe.set_new_concept_cfg(new_concept_cfg)
        # sample images to visualize the merged embedding + LoRA weight shift
        logger.info(f'Start validation sampling (lora alpha={lora_alpha}):')

        lora_type = 'edlora' if enable_edlora else 'lora'
        visual_validation(accelerator, pipe, val_dataloader, f'validation_{lora_type}_{lora_alpha}', opt)
        del pipe
        torch.cuda.empty_cache()  # release cached GPU memory before loading the next alpha


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml')
    args = parser.parse_args()

    root_path = osp.abspath(osp.join(__file__, osp.pardir))
    test(root_path, args)
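
# Example invocation (single GPU; the config path shown is the argparse
# default above, and multi-GPU runs would typically go through `accelerate launch`):
#
#   python test.py -opt options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml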