Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
import html
import os
import re
import urllib.parse as ul
import warnings
from html.parser import HTMLParser

import ftfy
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPTextModel
class _HTMLTextExtractor(HTMLParser):
    """Accumulate the text nodes of an HTML fragment, discarding the tags."""

    def __init__(self):
        # convert_charrefs=True makes the parser pass character references
        # such as ``&amp;`` to handle_data() already decoded.
        super().__init__(convert_charrefs=True)
        self._chunks = []

    def handle_data(self, data):
        self._chunks.append(data)

    def get_text(self):
        """Return everything collected so far as a single string."""
        return ''.join(self._chunks)


class CLIPEmbedder:
    """Embed texts and images with a pretrained CLIP model.

    Two text paths are offered: ``get_text_embeddings`` goes through
    ``CLIPModel.get_text_features``, while ``encode_text`` runs the separate
    ``CLIPTextModel`` and exposes its last hidden state and pooled output.

    Attributes:
        device: ``torch.device`` all models and inputs are moved to.
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        cache_dir: Directory used as ``cache_dir`` for every download.
        use_text_preprocessing: Whether ``get_text_embeddings`` cleans its
            inputs with ``_clean_text`` first.
        max_length: Maximum token length used for truncation/padding.
    """

    # Runs of "noise" punctuation that clean_caption() collapses to a space.
    # NOTE(review): the original code read self.bad_punct_regex without ever
    # defining it. This character set follows the DeepFloyd IF caption cleaner
    # that clean_caption() appears to be adapted from -- confirm it is the
    # intended set.
    bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]+')

    def __init__(self, device='cuda', model_name='openai/clip-vit-base-patch32',
                 cache_dir='./cache_dir', use_text_preprocessing=True, max_length=77):
        """Load the processor, tokenizer and both models.

        Args:
            device: Device spec accepted by ``torch.device``. If a CUDA device
                is requested but CUDA is unavailable, a ``RuntimeWarning`` is
                issued and the CPU is used instead of crashing on ``.to()``.
            model_name: Checkpoint name or path for ``from_pretrained``.
            cache_dir: Download cache directory; created if missing.
            use_text_preprocessing: Clean prompts in ``get_text_embeddings``.
            max_length: Token limit for truncation (and padding in
                ``encode_text``).
        """
        target = torch.device(device)
        if target.type == 'cuda' and not torch.cuda.is_available():
            warnings.warn(
                f"CUDA device '{device}' requested but CUDA is not available; "
                "falling back to CPU.",
                RuntimeWarning,
                stacklevel=2,
            )
            target = torch.device('cpu')

        self.device = target
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.use_text_preprocessing = use_text_preprocessing
        self.max_length = max_length

        os.makedirs(self.cache_dir, exist_ok=True)

        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=self.cache_dir)
        self.model = CLIPModel.from_pretrained(model_name, cache_dir=self.cache_dir).to(self.device).eval()
        # The tokenizer now shares cache_dir with the other three loaders.
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name, cache_dir=self.cache_dir)
        self.text_model = CLIPTextModel.from_pretrained(model_name, cache_dir=self.cache_dir).to(self.device).eval()
        # Freeze the standalone text encoder's parameters.
        self.text_model.requires_grad_(False)

    def get_text_embeddings(self, texts):
        """Return ``CLIPModel.get_text_features`` for a list of prompts.

        Args:
            texts: List of strings. Cleaned with ``_clean_text`` first when
                ``use_text_preprocessing`` is true.

        Returns:
            The tensor produced by ``self.model.get_text_features``.

        Raises:
            ValueError: If ``texts`` is not a list of ``str``.
        """
        self._validate_input_list(texts, str)
        if self.use_text_preprocessing:
            texts = [self._clean_text(text) for text in texts]
        inputs = self.processor(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)
        with torch.no_grad():
            return self.model.get_text_features(**inputs)

    def encode_text(self, texts):
        """Run the standalone text model on a list of prompts.

        Inputs are tokenized as-is (no ``_clean_text`` step), truncated to
        ``max_length`` and padded to ``max_length``.

        Args:
            texts: List of strings.

        Returns:
            Tuple ``(last_hidden_state, pooler_output)`` taken from the
            ``CLIPTextModel`` output.

        Raises:
            ValueError: If ``texts`` is not a list of ``str``.
        """
        self._validate_input_list(texts, str)
        batch_encoding = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.text_model(**batch_encoding)
        return outputs.last_hidden_state, outputs.pooler_output

    def get_image_embeddings(self, image_paths):
        """Return ``CLIPModel.get_image_features`` for a list of image files.

        Args:
            image_paths: List of file-path strings.

        Returns:
            The tensor produced by ``self.model.get_image_features``.

        Raises:
            ValueError: If ``image_paths`` is not a list of ``str``.
            FileNotFoundError: If one of the paths does not exist.
            RuntimeError: If an image cannot be loaded for any other reason.
        """
        self._validate_input_list(image_paths, str)
        images = [self._load_image(path) for path in image_paths]
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model.get_image_features(**inputs)

    def _validate_input_list(self, input_list, expected_type):
        """Raise ``ValueError`` unless ``input_list`` is a list of ``expected_type``.

        An empty list passes validation.
        """
        if not isinstance(input_list, list) or not all(isinstance(item, expected_type) for item in input_list):
            raise ValueError(f"Input must be a list of {expected_type.__name__}.")

    def _clean_text(self, text):
        """Fix mojibake with ftfy, unescape HTML entities once, and trim."""
        text = ftfy.fix_text(text)
        text = html.unescape(text)
        return text.strip()

    def _load_image(self, image_path):
        """Open ``image_path`` and return it converted to RGB.

        The file is opened in a ``with`` block so the underlying handle is
        closed once the converted copy has been produced.

        Raises:
            FileNotFoundError: If the file does not exist.
            RuntimeError: For any other failure while opening or converting;
                the original exception is chained as ``__cause__``.
        """
        try:
            with Image.open(image_path) as handle:
                return handle.convert("RGB")
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Image file not found: {image_path}") from exc
        except Exception as exc:
            raise RuntimeError(f"Error loading image {image_path}: {exc}") from exc

    @staticmethod
    def _strip_html(text):
        """Return ``text`` with HTML tags removed and entities decoded.

        Standard-library stand-in for the ``BeautifulSoup(...).text`` call the
        original code made without importing BeautifulSoup.
        """
        extractor = _HTMLTextExtractor()
        extractor.feed(text)
        extractor.close()  # flushes any text still buffered by the parser
        return extractor.get_text()

    def clean_caption(self, caption):
        """Normalize a scraped caption into plain lower-case text.

        Applies, in order: URL-decoding, lower-casing, URL removal, HTML tag
        stripping, removal of @handles, CJK characters, IP addresses, ids,
        file names and common boilerplate phrases, plus punctuation and
        whitespace normalization.

        Args:
            caption: Any object; it is converted with ``str()`` first.

        Returns:
            The cleaned caption string, stripped of surrounding whitespace.
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub('<person>', 'person', caption)
        # urls:
        caption = re.sub(
            r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        caption = re.sub(
            r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        # html tags
        caption = self._strip_html(caption)

        # @<nickname>
        caption = re.sub(r'@[\w\d]+\b', '', caption)

        # CJK blocks. The original removed seven adjacent ranges one by one
        # (U+31C0-31EF, 31F0-31FF, 3200-32FF, 3300-33FF, 3400-4DBF, 4DC0-4DFF,
        # 4E00-9FFF); together they form this single contiguous range.
        caption = re.sub(r'[\u31c0-\u9fff]+', '', caption)

        # every kind of dash -> "-"
        caption = re.sub(
            r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
            '-', caption)

        # normalize quotation marks to one standard
        caption = re.sub(r'[`´«»“”¨]', '"', caption)
        caption = re.sub(r'[‘’]', "'", caption)

        # Left-over (double-encoded) entities.
        # NOTE(review): the original patterns here were r'"?' and r'&', which
        # delete every quote and ampersand and look like HTML-unescaped forms
        # of the entity patterns below -- confirm.
        caption = re.sub(r'&quot;?', '', caption)
        caption = re.sub(r'&amp', '', caption)

        # ip addresses
        caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

        # trailing "d:dd" followed by whitespace
        caption = re.sub(r'\d:\d\d\s+$', '', caption)

        # literal backslash-n sequences
        caption = re.sub(r'\\n', ' ', caption)

        # "#123", "#12345..", bare numbers of six or more digits
        caption = re.sub(r'#\d{1,3}\b', '', caption)
        caption = re.sub(r'#\d{5,}\b', '', caption)
        caption = re.sub(r'\b\d{6,}\b', '', caption)

        # file names
        caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

        caption = re.sub(r'[\"\']{2,}', r'"', caption)  # repeated quotes -> "
        caption = re.sub(r'[\.]{2,}', r' ', caption)  # runs of dots -> space

        caption = re.sub(self.bad_punct_regex, r' ', caption)
        caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

        # More than three "-"/"_" suggests a slug like this-is-my-cute-cat:
        # turn the separators into spaces.
        separators = re.compile(r'(?:\-|\_)')
        if len(re.findall(separators, caption)) > 3:
            caption = re.sub(separators, ' ', caption)

        caption = self.basic_clean(caption)

        caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
        caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
        caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

        caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
        caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
        caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
        caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
        caption = re.sub(r'\bpage\s+\d+\b', '', caption)

        caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

        # dimensions such as 1920x1080 (latin x, cyrillic х, or ×)
        caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

        caption = re.sub(r'\b\s+\:\s+', r': ', caption)
        caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption)

        # The original called caption.strip() and dropped the result, so the
        # ^...$ anchored patterns below ran on untrimmed text.
        caption = caption.strip()

        caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)  # wrapping quotes
        caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)  # leading punctuation
        caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)  # trailing punctuation
        caption = re.sub(r'^\.\S+$', '', caption)  # a lone ".token"

        return caption.strip()

    @staticmethod
    def basic_clean(text):
        """Fix mojibake with ftfy, unescape HTML entities twice, and trim.

        Declared static so that both ``self.basic_clean(text)`` (as used by
        ``clean_caption``) and ``CLIPEmbedder.basic_clean(text)`` work; the
        original definition took no ``self`` and failed when called on an
        instance.
        """
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()
def _run_demo():
    """Embed three sample prompts and two example image paths, printing shapes.

    Image-loading problems are reported on stdout instead of propagating.
    """
    embedder = CLIPEmbedder()

    # Example
    prompts = [
        'A photo of a cute puppy playing with a ball.',
        'An image of a beautiful sunset over the ocean.',
        'A scene depicting a busy city street.'
    ]
    prompt_vectors = embedder.get_text_embeddings(prompts)
    print(f"Text embeddings shape: {prompt_vectors.shape}")

    sample_images = ['image1.jpg', 'image2.png']
    try:
        image_vectors = embedder.get_image_embeddings(sample_images)
        print(f"Image embeddings shape: {image_vectors.shape}")
    except FileNotFoundError as missing:
        # A missing example file is reported with its own message.
        print(missing)
    except Exception as failure:
        print(f"An error occurred: {failure}")


if __name__ == '__main__':
    _run_demo()