File size: 4,041 Bytes
ab4cf94
 
6ef117e
 
 
 
 
 
 
 
 
 
 
 
 
 
80b040e
6ef117e
 
 
 
ab4cf94
 
6ef117e
 
 
 
 
 
 
 
 
 
 
ab4cf94
 
6ef117e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab4cf94
6ef117e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab4cf94
6ef117e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab4cf94
6ef117e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from torch import Tensor, ones_like
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2

# Maps each supported condition type name to the integer id that
# `Condition.type_id` / `Condition.get_type_id` look up and that
# `Condition.encode` broadcasts over the condition tokens.
# Ids 2, 3, 5 and 8 are not used anywhere in this file.
condition_dict = dict(
    depth=0,
    canny=1,
    subject=4,
    coloring=6,
    deblurring=7,
    fill=9,
)

class Condition:
    """A conditioning input for a Flux pipeline.

    Holds a condition image, either derived from ``raw_img`` or supplied
    directly as ``condition``, together with its ``condition_type``. It can
    encode the image into packed latent tokens, latent position ids and a
    per-token type-id tensor via :meth:`encode`.
    """

    # Condition types that `encode` handles by VAE-encoding the condition image.
    _IMAGE_CONDITION_TYPES = (
        "depth",
        "canny",
        "subject",
        "coloring",
        "deblurring",
        "fill",
    )

    def __init__(
        self,
        condition_type: str,
        raw_img: Optional[Union[Image.Image, Tensor]] = None,
        condition: Optional[Union[Image.Image, Tensor]] = None,
        mask=None,
    ) -> None:
        """
        Args:
            condition_type: Name of the condition; ``type_id`` looks it up in
                the module-level ``condition_dict``.
            raw_img: Source image. When given, the condition is derived from it
                with :meth:`get_condition` and ``condition`` is ignored.
            condition: A ready-made condition, used only when ``raw_img`` is
                None.
            mask: Not supported yet; must be None.

        Raises:
            AssertionError: If ``mask`` is not None, or if neither ``raw_img``
                nor ``condition`` is provided.
            NotImplementedError: If ``raw_img`` is given and
                ``condition_type`` is not one :meth:`get_condition` knows.
        """
        self.condition_type = condition_type
        # TODO: Add mask support
        # Checked before deriving the condition so an unsupported mask is
        # rejected without running the (possibly model-loading) derivation.
        assert mask is None, "Mask not supported yet"
        assert raw_img is not None or condition is not None, (
            "Either raw_img or condition must be provided"
        )
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, Tensor]
    ) -> Union[Image.Image, Tensor]:
        """
        Derive the condition image for ``condition_type`` from ``raw_img``.

        Every branch except "subject" and "canny" calls PIL methods
        (``convert``/``filter``) on ``raw_img``, so it is expected to be a
        ``PIL.Image.Image`` in practice.

        Returns:
            - "depth": an RGB depth map predicted by a depth-estimation model.
            - "canny": an RGB image of Canny edges.
            - "subject": ``raw_img`` unchanged.
            - "coloring": the grayscale image, converted back to RGB.
            - "deblurring": the RGB image blurred with a Gaussian of radius 10.
            - "fill": ``raw_img`` converted to RGB.

        Raises:
            NotImplementedError: If ``condition_type`` is not one of the above.
        """
        if condition_type == "depth":
            from transformers import pipeline

            # NOTE(review): this builds (and loads) the depth-estimation model
            # on every call; consider caching it if conditions are created
            # frequently (e.g. once per request).
            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cuda",
            )
            source_image = raw_img.convert("RGB")
            return depth_pipe(source_image)["depth"].convert("RGB")
        if condition_type == "canny":
            # Fixed Canny hysteresis thresholds: low=100, high=200.
            edges = cv2.Canny(np.array(raw_img), 100, 200)
            return Image.fromarray(edges).convert("RGB")
        if condition_type == "subject":
            return raw_img
        if condition_type == "coloring":
            return raw_img.convert("L").convert("RGB")
        if condition_type == "deblurring":
            # Blurring an RGB image keeps its mode, so no second conversion
            # is needed.
            return raw_img.convert("RGB").filter(ImageFilter.GaussianBlur(10))
        if condition_type == "fill":
            return raw_img.convert("RGB")
        # Fail explicitly. Falling back to `self.condition` would raise an
        # AttributeError when called from __init__, because it is not set yet.
        raise NotImplementedError(
            f"Condition type {condition_type} not implemented"
        )

    @property
    def type_id(self) -> int:
        """
        Returns the type id of this condition, looked up in ``condition_dict``.

        Raises:
            KeyError: If ``self.condition_type`` is not in ``condition_dict``.
        """
        return condition_dict[self.condition_type]

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """
        Returns the type id for ``condition_type``, looked up in
        ``condition_dict``.

        Raises:
            KeyError: If ``condition_type`` is not in ``condition_dict``.
        """
        return condition_dict[condition_type]

    def _encode_image(
        self, pipe: FluxPipeline, cond_img: Image.Image
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode an image condition into packed latent tokens and their ids.

        The image is preprocessed by ``pipe.image_processor``, moved to the
        pipeline's device and dtype, and encoded by ``pipe.vae`` (sampling the
        latent distribution). It is then normalized with the VAE config's
        ``shift_factor`` and ``scaling_factor`` and packed with
        ``pipe._pack_latents``.

        Returns:
            ``(cond_tokens, cond_ids)``: the packed latents, and the ids from
            ``pipe._prepare_latent_image_ids`` for a grid of half the latent
            height and width.
        """
        cond_img = pipe.image_processor.preprocess(cond_img)
        cond_img = cond_img.to(device=pipe.device, dtype=pipe.dtype)
        cond_img = pipe.vae.encode(cond_img).latent_dist.sample()
        cond_img = (
            cond_img - pipe.vae.config.shift_factor
        ) * pipe.vae.config.scaling_factor
        cond_tokens = pipe._pack_latents(cond_img, *cond_img.shape)
        # Height/width are halved, presumably because packing groups 2x2
        # latent patches into one token -- TODO confirm against _pack_latents.
        cond_ids = pipe._prepare_latent_image_ids(
            cond_img.shape[0],
            cond_img.shape[2] // 2,
            cond_img.shape[3] // 2,
            pipe.device,
            pipe.dtype,
        )
        return cond_tokens, cond_ids

    def encode(self, pipe: FluxPipeline) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encode the condition into tokens, ids and a type-id tensor.

        Returns:
            ``(tokens, ids, type_id)``. ``type_id`` is a tensor shaped like
            ``ids[:, :1]`` filled with ``self.type_id``.

        Raises:
            NotImplementedError: If ``condition_type`` is not an image
                condition type.
            KeyError: If ``condition_type`` is not in ``condition_dict``.
        """
        if self.condition_type not in self._IMAGE_CONDITION_TYPES:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        tokens, ids = self._encode_image(pipe, self.condition)
        type_id = ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id