import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2
from transformers import (CLIPProcessor, CLIPTextModel,
                          CLIPTextModelWithProjection, CLIPTokenizer,
                          CLIPVisionModelWithProjection)

# from image_generator import get_output_embeds, position_embeddings

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "mps" \
    if torch.backends.mps.is_available() else "cpu"
if torch_device == "mps":
    # Allow CPU fallback for ops not yet implemented on MPS.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Load the tokenizer, text/vision encoders and processor used to embed prompts and images.
clip_model_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(torch_device)
vision_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).to(torch_device)
processor = CLIPProcessor.from_pretrained(clip_model_name)
# Projection variant of the text model, loaded once here rather than on every call.
text_projection = CLIPTextModelWithProjection.from_pretrained(clip_model_name).to(torch_device)


def get_text_embed(prompt="on a mountain"):
    """Return the projected CLIP text embedding for an additional textual prompt."""
    inputs = processor(text=prompt, return_tensors="pt", padding=True).to(torch_device)
    with torch.no_grad():
        text_embed = text_projection(**inputs).text_embeds
    return text_embed


# def get_text_embed(prompt="on a mountain"):
#     text_input = tokenizer([prompt],
#                            padding="max_length",
#                            max_length=tokenizer.model_max_length,
#                            truncation=True,
#                            return_tensors="pt")
#     with torch.no_grad():
#         text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
#     input_embeddings = text_embeddings + position_embeddings.to(torch_device)
#     modified_output_embeddings = get_output_embeds(input_embeddings)
#     return modified_output_embeddings


class cosine_loss(nn.Module):
    """1 - cosine similarity between a prompt's CLIP embedding and the generated image's CLIP embedding."""

    def __init__(self, prompt) -> None:
        super().__init__()  # initialise nn.Module before assigning attributes
        self.text_embed = get_text_embed(prompt)

    def forward(self, gen_image):
        # CLIP's vision tower expects 224x224 inputs normalised with the CLIP mean/std
        # (applied to images in [0, 1]); the transforms stay differentiable so the loss
        # can provide gradients w.r.t. the generated image.
        image = v2.Resize(224, antialias=True)(gen_image.clamp(0, 1))
        image = v2.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711])(image)
        image_embed = vision_encoder(image).image_embeds
        similarity = F.cosine_similarity(self.text_embed, image_embed, dim=1)
        loss = 1 - similarity.mean()
        return loss


def blue_loss(images):
    # How far the blue channel values are from 0.9:
    # [:, 2] -> every image in the batch, blue channel only.
    error = torch.abs(images[:, 2] - 0.9).mean()
    return error
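

# Usage sketch (illustrative, not part of the original module): shows how the two losses
# might be combined to guide a generator. `gen_image` below is a stand-in for a decoded
# (B, 3, H, W) image tensor in [0, 1] produced inside a sampling loop, and the 0.2 weight
# on blue_loss is an arbitrary example value.
if __name__ == "__main__":
    loss_fn = cosine_loss("a photo on a mountain")
    gen_image = torch.rand(1, 3, 512, 512, device=torch_device, requires_grad=True)
    loss = loss_fn(gen_image) + 0.2 * blue_loss(gen_image)
    loss.backward()  # gradients w.r.t. gen_image can steer the sampler
    print(f"guidance loss: {loss.item():.4f}")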