# 'Deforum' plugin for Automatic1111's Stable Diffusion WebUI.
# Copyright (C) 2023 Artem Khrapov (kabachuha) and Deforum team listed in AUTHORS.md
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Contact the dev team: https://discord.gg/deforum

import os

import cv2
import torch
from PIL import Image
from torch.nn.functional import interpolate
from torchvision import transforms

# CLIPseg expects 512x512 inputs normalized with the standard ImageNet statistics.
preclipseg_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((512, 512)),  # TODO: check whether the 512x512 size is hardcoded in CLIPseg
])

def find_clipseg():
    """Locate the bundled CLIPseg weights under deforum_helpers/src."""
    # The extension root is two directory levels above this file.
    ext_root = os.path.sep.join(os.path.abspath(__file__).split(os.path.sep)[:-2])
    pth = os.path.join(ext_root, 'deforum_helpers', 'src', 'clipseg', 'weights', 'rd64-uni.pth')
    if os.path.exists(pth):
        return pth
    raise FileNotFoundError('CLIPseg weights not found!')

def setup_clipseg(root):
    """Load the CLIPseg model once and cache it on the shared root object."""
    from clipseg.models.clipseg import CLIPDensePredT
    model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
    model.eval()
    # strict=False because the checkpoint does not include the frozen CLIP backbone weights.
    model.load_state_dict(torch.load(find_clipseg(), map_location=root.device), strict=False)
    model.to(root.device)
    root.clipseg_model = model

def get_word_mask(root, frame, word_mask):
    """Segment `frame` (a PIL image) by the text prompt `word_mask`; return a binary PIL mask."""
    if root.clipseg_model is None:
        setup_clipseg(root)
    img = preclipseg_transform(frame).to(root.device, dtype=torch.float32)
    word_masks = [word_mask]
    with torch.no_grad():
        preds = root.clipseg_model(img.repeat(len(word_masks), 1, 1, 1), word_masks)[0]
    mask = torch.sigmoid(preds[0][0]).unsqueeze(0).unsqueeze(0)  # add batch and channel dims for interpolate()
    # Rescale the 512x512 prediction back to the frame resolution (PIL size is (width, height)).
    resized_mask = interpolate(mask, size=(frame.size[1], frame.size[0]), mode='bicubic').squeeze()
    # Clamp bicubic overshoot before the uint8 cast, then binarize with Otsu's method
    # (the fixed threshold of 32 is ignored when THRESH_OTSU is set).
    numpy_array = resized_mask.multiply(255).clamp(0, 255).to(dtype=torch.uint8, device='cpu').numpy()
    return Image.fromarray(cv2.threshold(numpy_array, 32, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1])
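
# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). In Deforum, `root`
# is the shared runtime object; the SimpleNamespace below is a hypothetical
# stand-in assuming only the two attributes this module touches (`device` and
# `clipseg_model`). Running this requires the clipseg package and weights to
# be importable, which is only guaranteed inside the WebUI environment.
#
#     from types import SimpleNamespace
#
#     root = SimpleNamespace(
#         device='cuda' if torch.cuda.is_available() else 'cpu',
#         clipseg_model=None,  # get_word_mask() lazily loads CLIPseg on first use
#     )
#     frame = Image.open('frame_000001.png').convert('RGB')  # hypothetical input frame
#     mask = get_word_mask(root, frame, 'the sky')           # white where "the sky" is detected
#     mask.save('sky_mask.png')
# -----------------------------------------------------------------------------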