File size: 1,822 Bytes
171f55b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from diffusers import UnCLIPScheduler, DDPMScheduler, StableUnCLIPPipeline
from diffusers.models import PriorTransformer
from transformers import CLIPTokenizer, CLIPTextModelWithProjection


def init_text2img_pipe():
    """Build a Stable unCLIP text-to-image pipeline that uses the Karlo prior.

    The prior stage (prior transformer, tokenizer, text encoder, scheduler)
    is loaded separately and injected into the
    ``stabilityai/stable-diffusion-2-1-unclip-small`` pipeline.

    Returns:
        The ``StableUnCLIPPipeline`` moved to ``"cuda"`` with float16 weights
        when CUDA is available, otherwise to ``"cpu"`` with float32 weights.
    """
    has_gpu = torch.cuda.is_available()
    device, data_type = ("cuda", torch.float16) if has_gpu else ("cpu", torch.float32)

    karlo_repo = "kakaobrain/karlo-v1-alpha"
    clip_repo = "openai/clip-vit-large-patch14"

    # Prior-stage components, built in the same order as before and passed
    # to the pipeline as keyword overrides.
    prior_parts = {
        "prior": PriorTransformer.from_pretrained(
            karlo_repo, subfolder="prior", torch_dtype=data_type
        ),
        "prior_tokenizer": CLIPTokenizer.from_pretrained(clip_repo),
        "prior_text_encoder": CLIPTextModelWithProjection.from_pretrained(
            clip_repo, torch_dtype=data_type
        ),
        # Reuse the Karlo prior scheduler's config, but run it as a
        # DDPMScheduler (this mirrors the diffusers Stable unCLIP example).
        "prior_scheduler": DDPMScheduler.from_config(
            UnCLIPScheduler.from_pretrained(karlo_repo, subfolder="prior_scheduler").config
        ),
    }

    # NOTE: the fp16 weight variant is requested on every device; on CPU
    # torch_dtype is float32, so those weights are loaded as float32.
    unclip_pipe = StableUnCLIPPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-unclip-small",
        torch_dtype=data_type,
        variant="fp16",
        **prior_parts,
    )
    return unclip_pipe.to(device)


def predict(
    prompt: str,
    negative_prompt: str,
    pipeline,
    *,
    height: int = 600,
    width: int = 400,
    num_inference_steps: int = 60,
):
    """Generate images for *prompt* with a text-to-image pipeline.

    Args:
        prompt: Text describing what the image should show.
        negative_prompt: Text describing what the image should avoid.
        pipeline: Callable pipeline (e.g. the one built by
            ``init_text2img_pipe``) whose return value has an ``images``
            attribute.
        height: Output image height in pixels. Diffusers Stable Diffusion
            pipelines reject sizes not divisible by 8.
        width: Output image width in pixels (same divisibility rule).
        num_inference_steps: Number of denoising steps the pipeline runs.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
    )
    return output.images


if __name__ == "__main__":
    # Script entry point: load the models, render one prompt/negative-prompt
    # pair, and write each returned image to a numbered PNG file.
    text2img_pipeline = init_text2img_pipe()
    images = predict(
        prompt="a dog",
        negative_prompt="a cat",
        pipeline=text2img_pipeline,
    )
    for idx, image in enumerate(images):
        output_path = f"/root/autodl-tmp/image_{idx}.png"
        image.save(output_path)