import cv2
import numpy as np
import torch
import threading
from torchvision import transforms

from clip.clipseg import CLIPDensePredT
from roop.typing import Frame

THREAD_LOCK_CLIP = threading.Lock()
class Mask_Clip2Seg():
    # Text-prompted masking plugin built on CLIPSeg (CLIPDensePredT).
    plugin_options: dict = None
    model_clip = None

    processorname = 'clip2seg'
    type = 'mask'
    def Initialize(self, plugin_options: dict):
        # Re-initialize if the target device changed since the last call.
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_clip is None:
            # Load the refined CLIPSeg ViT-B/16 checkpoint onto the CPU first.
            self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            self.model_clip.eval()
            self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)

        device = torch.device(self.plugin_options["devicename"])
        self.model_clip.to(device)
    def Run(self, img1, keywords: str) -> Frame:
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1

        source_image_small = cv2.resize(img1, (256, 256))

        # Base mask: a white rectangle slightly inset from the 256x256 border,
        # softened with a Gaussian blur and normalized to [0, 1].
        img_mask = np.full((source_image_small.shape[0], source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1
        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border + int(l), mask_border + int(t)),
                                 (256 - mask_border - int(r), 256 - mask_border - int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0)
        img_mask /= 255

        input_image = source_image_small

        # CLIPSeg preprocessing: ImageNet normalization at 256x256.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')

        # One forward pass with the image repeated once per prompt; inference is
        # serialized behind a lock since the shared model is not thread-safe.
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]

        # Sum the per-prompt probability maps into a single mask.
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts) - 1):
            clip_mask += torch.sigmoid(preds[i + 1][0])

        clip_mask = clip_mask.data.cpu().numpy()
        clip_mask = np.clip(clip_mask, 0, 1)  # np.clip returns a new array, so reassign

        # Binarize, then dilate and blur to soften the mask edges.
        clip_mask[clip_mask > thresh] = 1.0
        clip_mask[clip_mask <= thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0)

        # Combine the border mask with the CLIPSeg mask and clamp negatives.
        img_mask *= clip_mask
        img_mask[img_mask < 0.0] = 0.0
        return img_mask
    def Release(self):
        self.model_clip = None
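

# A minimal usage sketch, not part of the original module: it assumes the
# CLIPSeg checkpoint exists at models/CLIP/rd64-uni-refined.pth and that a
# hypothetical test image 'example.jpg' is available. Run() returns a 256x256
# float32 mask in [0, 1] covering the comma-separated prompts.
if __name__ == '__main__':
    masker = Mask_Clip2Seg()
    masker.Initialize({"devicename": "cpu"})  # device name assumed; "cuda" is also possible
    frame = cv2.imread('example.jpg')         # hypothetical input path
    if frame is not None:
        mask = masker.Run(frame, 'hair,glasses')
        cv2.imwrite('mask.png', (mask * 255).astype(np.uint8))
    masker.Release()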