from __future__ import annotations

import math
import os

import PIL.Image
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

# batch-template support from the original ReVersion script, unused here:
# sys.path.insert(0, './ReVersion')
# from templates.templates import inference_templates

"""
Inference script for generating batch results
"""

def make_image_grid(imgs, rows, cols):
    # The grid may have more cells than images (e.g. an odd number of
    # samples on a two-row grid); unused cells stay black.
    assert len(imgs) <= rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def inference_fn(
        model_id: str,
        prompt: str,
        num_samples: int,
        guidance_scale: float,
        ddim_steps: int,
    ) -> PIL.Image.Image:
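    """Generate num_samples images for `prompt` with the pipeline stored
    under experiments/<model_id> and return them tiled into one grid."""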

    # create the inference pipeline (fp16 on GPU, full precision on CPU)
    if torch.cuda.is_available():
        pipe = StableDiffusionPipeline.from_pretrained(
            os.path.join('experiments', model_id),
            torch_dtype=torch.float16,
        ).to('cuda')
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            os.path.join('experiments', model_id)
        ).to('cpu')
    # # make directory to save images
    # image_root_folder = os.path.join('experiments', model_id, 'inference')
    # os.makedirs(image_root_folder, exist_ok = True)

    # if prompt is None and args.template_name is None:
    #     raise ValueError("please input a single prompt through'--prompt' or select a batch of prompts using '--template_name'.")

    # a single text prompt is required
    if prompt is None:
        raise ValueError("please provide a single text prompt through 'prompt'.")
    prompt_list = [prompt]

    # if args.template_name is not None:
    #     # read the selected text prompts for generation
    #     prompt_list.extend(inference_templates[args.template_name])

    for prompt in prompt_list:
        # normalize casing, then restore the relation token <R>
        # (lower() downcases the token, so swap "<r>" back afterwards)
        prompt = prompt.lower().replace("<r>", "<R>").format("<R>")

        # # make sub-folder
        # image_folder = os.path.join(image_root_folder, prompt, 'samples')
        # os.makedirs(image_folder, exist_ok = True)

        # batch generation
        images = pipe(
            prompt,
            num_inference_steps=ddim_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_samples,
        ).images

        # # save generated images
        # for idx, image in enumerate(images):
        #     image_name = f"{str(idx).zfill(4)}.png"
        #     image_path = os.path.join(image_folder, image_name)
        #     image.save(image_path)

        # tile the samples into a two-row grid (trailing cells stay black
        # when num_samples is odd)
        image_grid = make_image_grid(images, rows=2, cols=math.ceil(num_samples / 2))
        # image_grid_path = os.path.join(image_root_folder, prompt, f'{prompt}.png')

        # return the grid for the first (and only) prompt
        return image_grid

if __name__ == "__main__":
    # example invocation; 'reversion' and the prompt are placeholders for
    # a trained model folder under experiments/ and a real relation prompt
    grid = inference_fn('reversion', 'cat <R> stone', num_samples=4,
                        guidance_scale=7.5, ddim_steps=50)
    grid.save('grid.png')
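
# A minimal sketch of serving inference_fn as a web demo with Gradio
# (component choices below are illustrative assumptions):
#
#     import gradio as gr
#
#     demo = gr.Interface(
#         fn=inference_fn,
#         inputs=[
#             gr.Textbox(label='model_id'),
#             gr.Textbox(label='prompt'),
#             gr.Number(label='num_samples', precision=0),
#             gr.Number(label='guidance_scale'),
#             gr.Number(label='ddim_steps', precision=0),
#         ],
#         outputs=gr.Image(label='image grid'),
#     )
#     demo.launch()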