File size: 750 Bytes
778d937
6dbb9c3
24438e9
778d937
 
 
 
 
 
 
 
 
 
 
 
0977330
ab28a5d
6dbb9c3
50bfbb3
778d937
 
50bfbb3
778d937
 
f3ff2c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import PIL.Image
import torch, gc
from controlnet_aux_local import NormalBaeDetector#, CannyDetector

class Preprocessor:
    """Lazily loaded wrapper around a single image annotator model.

    A detector is loaded on demand with ``load``; the instance is then
    called like a function to run that detector on an image.
    """

    # Identifier handed to ``from_pretrained`` when fetching detector weights.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        # No detector is held until ``load`` succeeds.
        self.model = None
        # Name of the currently loaded detector; "" means nothing loaded yet.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the detector called *name* unless it is already current.

        Args:
            name: Detector to load. Only ``"NormalBae"`` is supported.

        Raises:
            ValueError: If *name* is neither the current name nor a
                supported detector.
        """
        if name == self.name:
            # Already current. Note this also matches the initial empty
            # name, in which case nothing is loaded.
            return
        if name != "NormalBae":
            raise ValueError(
                f"Unknown preprocessor name: {name!r} (supported: 'NormalBae')"
            )

        print("Loading NormalBae")
        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
        torch.cuda.empty_cache()
        # Record the name only once the model is in place, so a load that
        # raised part-way is not mistaken for a completed one next time.
        self.name = name

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run the loaded detector on *image*, forwarding *kwargs* unchanged.

        Args:
            image: Input image handed to the detector as-is.
            **kwargs: Extra keyword arguments passed straight to the detector.

        Returns:
            Whatever the loaded detector returns for *image*.

        Raises:
            RuntimeError: If no detector has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor model loaded; call load() first.")
        return self.model(image, **kwargs)