# Authors: Hui Ren (rhfeiyang.github.io)
from transformers import CLIPProcessor, CLIPModel
import torch
import numpy as np
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from tqdm import tqdm
class Caption_filter:
    """Flag samples whose captions contain art-related keywords."""
    def __init__(self, filter_prompts=["painting", "paintings", "art", "artwork", "drawings", "sketch", "sketches", "illustration", "illustrations",
                                       "sculpture", "sculptures", "installation", "printmaking", "digital art", "conceptual art", "mosaic", "tapestry",
                                       "abstract", "realism", "surrealism", "impressionism", "expressionism", "cubism", "minimalism", "baroque", "rococo",
                                       "pop art", "art nouveau", "art deco", "futurism", "dadaism",
                                       "stamp", "stamps", "advertisement", "advertisements", "logo", "logos",
                                       ],):
        self.filter_prompts = filter_prompts
        self.total_count = 0
        self.filter_count = [0] * len(filter_prompts)

    def reset(self):
        self.total_count = 0
        self.filter_count = [0] * len(self.filter_prompts)

    def filter(self, captions):
        """Return one (filtered, reason) tuple per caption."""
        filter_result = []
        for caption in captions:
            words = caption[0]
            if words is None:
                filter_result.append((True, "None"))
                continue
            words = words.lower().split()
            filt = False
            reason = None
            for i, filter_keyword in enumerate(self.filter_prompts):
                # Slide a word-aligned window so multi-word keywords also match.
                key_len = len(filter_keyword.split())
                for j in range(len(words) - key_len + 1):
                    if " ".join(words[j:j + key_len]) == filter_keyword:
                        self.filter_count[i] += 1
                        filt = True
                        reason = filter_keyword
                        break
                if filt:
                    break
            filter_result.append((filt, reason))
            self.total_count += 1
        return filter_result
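
# Example usage (sketch): each caption is expected to be a tuple whose first element is the
# caption string, as yielded by the dataloader used below.
#
#   cap_filter = Caption_filter()
#   results = cap_filter.filter([("An oil painting of a dog",), ("A photo of a street",)])
#   # -> [(True, "painting"), (False, None)]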
class Clip_filter:
    # Per-prompt logit thresholds used downstream: an image is flagged as art-related
    # if its CLIP image-text logit for any prompt reaches that prompt's threshold.
    prompt_threshold = {
        "painting": 17,
        "art": 17.5,
        "artwork": 19,
        "drawing": 15.8,
        "sketch": 17,
        "illustration": 15,
        "sculpture": 19.2,
        "installation art": 20,
        "printmaking art": 16.3,
        "digital art": 15,
        "conceptual art": 18,
        "mosaic art": 19,
        "tapestry": 16,
        "abstract art": 16.5,
        "realism art": 16,
        "surrealism art": 15,
        "impressionism art": 17,
        "expressionism art": 17,
        "cubism art": 15,
        "minimalism art": 16,
        "baroque art": 17.5,
        "rococo art": 17,
        "pop art": 16,
        "art nouveau": 19,
        "art deco": 19,
        "futurism art": 16.5,
        "dadaism art": 16.5,
        "stamp": 18,
        "advertisement": 16.5,
        "logo": 15.5,
    }

    def __init__(self, positive_prompt=["painting", "art", "artwork", "drawing", "sketch", "illustration",
                                        "sculpture", "installation art", "printmaking art", "digital art", "conceptual art", "mosaic art", "tapestry",
                                        "abstract art", "realism art", "surrealism art", "impressionism art", "expressionism art", "cubism art",
                                        "minimalism art", "baroque art", "rococo art",
                                        "pop art", "art nouveau", "art deco", "futurism art", "dadaism art",
                                        "stamp", "advertisement",
                                        "logo",
                                        ],
                 device="cuda"):
        self.device = device
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.positive_prompt = positive_prompt
        self.text = self.positive_prompt
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        # Pre-compute the normalized text features once; they are reused for every image batch.
        self.text_encoding = self.tokenizer(self.text, return_tensors="pt", padding=True).to(device)
        self.text_features = self.model.get_text_features(**self.text_encoding)
        self.text_features = self.text_features / self.text_features.norm(p=2, dim=-1, keepdim=True)

    def similarity(self, image):
        # inputs = self.processor(text=self.text, images=image, return_tensors="pt", padding=True)
        image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
        inputs = {**self.text_encoding, **image_processed}
        outputs = self.model(**inputs)
        logits_per_image = outputs.logits_per_image
        return logits_per_image

    def get_logits(self, image):
        logits_per_image = self.similarity(image)
        return logits_per_image.cpu()

    def get_image_features(self, image):
        image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
        image_features = self.model.get_image_features(**image_processed)
        return image_features
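
# Example usage (sketch, assumes a CUDA device and Pillow; "test.jpg" is a placeholder path):
#
#   from PIL import Image
#   clip_filter = Clip_filter(device="cuda")
#   logits = clip_filter.get_logits(Image.open("test.jpg"))   # shape: (1, len(positive_prompt))
#   flagged = any(logits[0, i] >= Clip_filter.prompt_threshold[p]
#                 for i, p in enumerate(clip_filter.text))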
class Art_filter:
    """Combines caption keyword filtering and CLIP logit filtering over a dataset."""
    def __init__(self):
        self.caption_filter = Caption_filter()
        self.clip_filter = Clip_filter()

    def caption_filt(self, dataloader):
        """Filter the dataset by caption keywords; returns remaining and filtered ids."""
        self.caption_filter.reset()
        dataloader.dataset.get_img = False
        dataloader.dataset.get_cap = True
        remain_ids = []
        filtered_ids = []
        for i, batch in tqdm(enumerate(dataloader)):
            captions = batch["text"]
            filter_result = self.caption_filter.filter(captions)
            for j, (filt, reason) in enumerate(filter_result):
                if filt:
                    filtered_ids.append((batch["ids"][j], reason))
                    if i % 10 == 0:
                        print(f"Filtered caption: {captions[j]}, reason: {reason}")
                else:
                    remain_ids.append(batch["ids"][j])
        return {"remain_ids": remain_ids, "filtered_ids": filtered_ids,
                "total_count": self.caption_filter.total_count,
                "filter_count": self.caption_filter.filter_count,
                "filter_prompts": self.caption_filter.filter_prompts}

    def clip_filt(self, clip_logits_ckpt: dict):
        """Split ids into filtered/remaining using the per-prompt CLIP logit thresholds."""
        logits = clip_logits_ckpt["clip_logits"]
        ids = clip_logits_ckpt["ids"]
        text = clip_logits_ckpt["text"]
        filt_mask = torch.zeros(logits.shape[0], dtype=torch.bool)
        for i, prompt in enumerate(text):
            threshold = Clip_filter.prompt_threshold[prompt]
            filt_mask = filt_mask | (logits[:, i] >= threshold)
        filt_ids = []
        remain_ids = []
        for i, sample_id in enumerate(ids):
            if filt_mask[i]:
                filt_ids.append(sample_id)
            else:
                remain_ids.append(sample_id)
        return {"remain_ids": remain_ids, "filtered_ids": filt_ids}

    def clip_feature(self, dataloader):
        """Extract CLIP image features for every sample in the dataloader."""
        dataloader.dataset.get_img = True
        dataloader.dataset.get_cap = False
        clip_features = []
        ids = []
        for i, batch in enumerate(dataloader):
            images = batch["images"]
            features = self.clip_filter.get_image_features(images).cpu()
            clip_features.append(features)
            ids.extend(batch["ids"])
        clip_features = torch.cat(clip_features)
        return {"clip_features": clip_features, "ids": ids}
    def clip_logit(self, dataloader):
        """Extract CLIP image features and image-text logits for every sample."""
        dataloader.dataset.get_img = True
        dataloader.dataset.get_cap = False
        clip_features = []
        clip_logits = []
        ids = []
        for i, batch in enumerate(dataloader):
            images = batch["images"]
            # logits = self.clip_filter.get_logits(images)
            feature = self.clip_filter.get_image_features(images)
            logits = self.clip_logit_by_feat(feature)["clip_logits"]
            # Move features to CPU so a full pass over the dataset does not accumulate on GPU,
            # matching the behavior of clip_feature().
            clip_features.append(feature.cpu())
            clip_logits.append(logits)
            ids.extend(batch["ids"])
        clip_features = torch.cat(clip_features)
        clip_logits = torch.cat(clip_logits)
        return {"clip_features": clip_features, "clip_logits": clip_logits, "ids": ids, "text": self.clip_filter.text}

    def clip_logit_by_feat(self, feature):
        """Compute image-text logits from pre-extracted (unnormalized) CLIP image features."""
        feature = feature.clone().to(self.clip_filter.device)
        feature = feature / feature.norm(p=2, dim=-1, keepdim=True)
        logit_scale = self.clip_filter.model.logit_scale.exp()
        logits = ((feature @ self.clip_filter.text_features.T) * logit_scale).cpu()
        return {"clip_logits": logits, "text": self.clip_filter.text}
if __name__ == "__main__":
    import pickle
    with open("/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result/sa_000000/clip_logits_result.pickle", "rb") as f:
        result = pickle.load(f)
    feat = result["clip_features"]
    logits = Art_filter().clip_logit_by_feat(feat)
    print(logits)