# File size: 2,688 Bytes
# 4c53d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# 'Deforum' plugin for Automatic1111's Stable Diffusion WebUI.
# Copyright (C) 2023 Artem Khrapov (kabachuha) and Deforum team listed in AUTHORS.md
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Contact the dev team: https://discord.gg/deforum

import os
import torch
from PIL import Image
from torchvision import transforms
from torch.nn.functional import interpolate
import cv2

# Pre-processing pipeline applied to a frame before it is handed to CLIPSeg:
#   1. ToTensor  - image -> float tensor (channels first),
#   2. Normalize - per-channel mean/std (the standard ImageNet statistics),
#   3. Resize    - to a fixed square resolution.
preclipseg_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        # TODO: check if the size is hardcoded
        transforms.Resize((512, 512)),
    ]
)

def find_clipseg():
    """Locate the bundled CLIPSeg ``rd64-uni.pth`` weights file.

    The file is expected at
    ``<two directories above this file>/deforum_helpers/src/clipseg/weights/rd64-uni.pth``.

    Returns:
        str: absolute path to the weights file.

    Raises:
        FileNotFoundError: if no file exists at the expected location. This is
            a subclass of ``Exception``, so existing ``except Exception``
            handlers keep working.
    """
    # Same directory the original split/join on os.path.sep produced (path
    # minus its last two components). The old os.getcwd() loop variable was
    # never used in the path, so that dead loop is gone.
    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    pth = os.path.join(package_root, 'deforum_helpers', 'src', 'clipseg', 'weights', 'rd64-uni.pth')
    # isfile (not exists): a directory with this name is not usable weights.
    if os.path.isfile(pth):
        return pth
    raise FileNotFoundError(f'CLIPseg weights not found! Expected at: {pth}')

def setup_clipseg(root):
    """Build the CLIPSeg model, load its weights and store it on ``root``.

    Args:
        root: run-state object. ``root.device`` is read (for weight loading and
            model placement) and ``root.clipseg_model`` is assigned.
    """
    # Imported inside the function, so this module can be imported without it.
    from clipseg.models.clipseg import CLIPDensePredT

    net = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
    # Switch to inference behaviour (affects layers such as Dropout/BatchNorm).
    net.eval()

    weights_path = find_clipseg()
    state_dict = torch.load(weights_path, map_location=root.device)
    # strict=False tolerates missing/unexpected keys.
    # NOTE(review): presumably the checkpoint covers only part of the
    # network - confirm before tightening.
    net.load_state_dict(state_dict, strict=False)

    net.to(root.device)
    root.clipseg_model = net

def get_word_mask(root, frame, word_mask):
    """Return a binary mask of the part of ``frame`` described by ``word_mask``.

    The CLIPSeg model is created lazily via ``setup_clipseg`` the first time
    this is called (when ``root.clipseg_model`` is ``None``).

    Args:
        root: run-state object providing ``device`` and ``clipseg_model``.
        frame: PIL image to segment. ``frame.size`` is read as (width, height).
            Non-RGB modes are converted to RGB first.
        word_mask: text prompt describing the region to segment.

    Returns:
        PIL.Image.Image: single-channel mask (0/255) with the same width and
        height as ``frame``, binarized with Otsu's method.
    """
    if root.clipseg_model is None:
        setup_clipseg(root)

    # The pipeline's Normalize has 3-channel statistics; RGBA/L/P frames would
    # fail there, so convert them to RGB first (RGB frames pass through as-is).
    rgb_frame = frame if frame.mode == 'RGB' else frame.convert('RGB')
    img = preclipseg_transform(rgb_frame).to(root.device, dtype=torch.float32)
    word_masks = [word_mask]
    with torch.no_grad():
        # One copy of the image per prompt, stacked along a new batch dim.
        preds = root.clipseg_model(img.repeat(len(word_masks), 1, 1, 1), word_masks)[0]

    # First prompt, first channel -> probabilities; add batch and channel dims
    # because bicubic interpolate requires a 4-D (N, C, H, W) input.
    mask = torch.sigmoid(preds[0][0]).unsqueeze(0).unsqueeze(0)
    width, height = frame.size
    # Index [0, 0] instead of squeeze() so a 1-pixel-wide/tall frame keeps its
    # 2-D shape.
    resized_mask = interpolate(mask, size=(height, width), mode='bicubic')[0, 0]
    # Bicubic interpolation can overshoot [0, 1]. Clamp so the *255 -> uint8
    # cast below never sees out-of-range values, which would wrap or be
    # undefined and leave speckles in the mask.
    resized_mask = resized_mask.clamp(0.0, 1.0)
    numpy_array = resized_mask.multiply(255).to(dtype=torch.uint8, device='cpu').numpy()
    # With THRESH_OTSU, OpenCV picks the threshold from the histogram; the
    # 32 passed here is ignored.
    return Image.fromarray(cv2.threshold(numpy_array, 32, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1])