File size: 1,470 Bytes
739e183
ca559d4
739e183
 
 
 
 
 
 
 
 
 
 
 
1b68d22
739e183
 
 
 
 
 
1b68d22
739e183
 
 
d161d18
739e183
 
 
1b68d22
739e183
 
9c0e205
739e183
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import warnings

import torch
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
    ControlNetModel
)


class GeoPainting:
    """Generate images from a text prompt guided by a ControlNet control image.

    Wraps a diffusers ``StableDiffusionControlNetPipeline`` assembled from a
    ControlNet checkpoint (a depth ControlNet by default) and a base diffusion
    model, with the scheduler swapped for ``UniPCMultistepScheduler``.
    """

    DEFAULT_CONTROLNET_MODEL = "lllyasviel/control_v11f1p_sd15_depth"
    DEFAULT_DIFFUSER_MODEL = "geospatial_diffuser"
    # Negative prompt used by generate_painting unless the caller overrides it.
    DEFAULT_NEGATIVE_PROMPT = "ugly, disfigured, low quality, blurry, nsfw"

    def __init__(self, controlnet_model_path=DEFAULT_CONTROLNET_MODEL,
                 diffuser_model=DEFAULT_DIFFUSER_MODEL, *, torch_dtype=None,
                 seed=2):
        """Load the ControlNet and the diffusion pipeline.

        Args:
            controlnet_model_path: Identifier or path passed to
                ``ControlNetModel.from_pretrained``.
            diffuser_model: Identifier or path passed to
                ``StableDiffusionControlNetPipeline.from_pretrained``.
            torch_dtype: Weight dtype for both models. ``None`` (the default)
                selects ``torch.float16`` when CUDA is available and
                ``torch.float32`` otherwise.
            seed: Seed for the CPU ``torch.Generator`` shared by every
                ``generate_painting`` call on this instance.
        """
        dtype = self._resolve_dtype(torch_dtype)
        self.controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=dtype)
        # A single generator lives for the whole instance: each call advances
        # its random stream, so repeated calls with identical inputs produce
        # different images, while a fresh instance with the same seed replays
        # the same sequence.
        self.generator = torch.Generator(device="cpu").manual_seed(seed)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            diffuser_model,
            low_cpu_mem_usage=False,
            device_map=None,
            controlnet=self.controlnet,
            torch_dtype=dtype
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        if torch.cuda.is_available():
            # Model CPU offload: sub-models are moved to the GPU only while in use.
            self.pipe.enable_model_cpu_offload()
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except (ImportError, ValueError) as err:
                # xformers is an optional speed-up (and may not be installed);
                # keep the pipeline's default attention instead of failing.
                warnings.warn(
                    f"xformers memory-efficient attention not enabled: {err}"
                )

    @staticmethod
    def _resolve_dtype(torch_dtype=None):
        """Return ``torch_dtype``, or a device-appropriate default when it is None."""
        if torch_dtype is not None:
            return torch_dtype
        # Half precision is only practical on CUDA; float16 pipelines left on
        # the CPU are unsupported or extremely slow, so use float32 there.
        return torch.float16 if torch.cuda.is_available() else torch.float32

    def generate_painting(self, input_promp, control_image, *,
                          negative_prompt=DEFAULT_NEGATIVE_PROMPT,
                          num_inference_steps=20):
        """Run the ControlNet pipeline once and return the first output image.

        Args:
            input_promp: Text prompt (parameter name kept as-is for callers).
            control_image: Either a ``PIL.Image.Image`` used as-is, or an
                array-like with an ``astype`` method. The array is cast to
                uint8 without rescaling, so it is presumably already in the
                0-255 range. TODO confirm with callers.
            negative_prompt: Text the pipeline is steered away from.
            num_inference_steps: Number of denoising steps.

        Returns:
            ``output.images[0]`` from the pipeline call. This is a PIL image
            under diffusers' default output type.
        """
        if isinstance(control_image, Image.Image):
            image = control_image
        else:
            image = Image.fromarray(control_image.astype('uint8'))
        output = self.pipe(
            input_promp,
            image,
            negative_prompt=negative_prompt,
            generator=self.generator,
            num_inference_steps=num_inference_steps,
        )
        return output.images[0]