import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2
from transformers import (CLIPTextModel, CLIPTokenizer, CLIPProcessor,
                          CLIPVisionModelWithProjection, CLIPTextModelWithProjection)

# from image_generator import get_output_embeds, position_embeddings


# Pick the compute device: CUDA if available, then Apple MPS, then CPU.
torch_device = ("cuda" if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu")

if torch_device == "mps":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Load the CLIP tokenizer, text encoders, vision encoder, and processor once at import time.
clip_model_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(torch_device)
text_proj_encoder = CLIPTextModelWithProjection.from_pretrained(clip_model_name).to(torch_device)
vision_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).to(torch_device)
processor = CLIPProcessor.from_pretrained(clip_model_name)

# Encode an additional text prompt into the shared CLIP text-image embedding space.
def get_text_embed(prompt="on a mountain"):
    inputs = processor(text=prompt,
                       return_tensors="pt",
                       padding=True).to(torch_device)
    with torch.no_grad():
        text_embed = text_proj_encoder(**inputs).text_embeds
    return text_embed

# def get_text_embed(prompt = "on a mountain"):
#     text_input = tokenizer([prompt],
#                            padding="max_length",
#                            max_length=tokenizer.model_max_length,
#                            truncation=True,
#                            return_tensors="pt")
#     with torch.no_grad():
#         text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
#         input_embeddings = text_embeddings + position_embeddings.to(torch_device)
#     modified_output_embeddings = get_output_embeds(input_embeddings)
#     return modified_output_embeddings

class cosine_loss(nn.Module):
    """1 - cosine similarity between the CLIP embedding of `prompt` and the
    CLIP embedding of the generated image; usable as a guidance loss."""

    def __init__(self, prompt) -> None:
        super().__init__()
        self.text_embed = get_text_embed(prompt)
        # CLIP expects 224x224 inputs normalized with its own mean/std,
        # not raw 0-255 pixel values.
        self.preprocess = v2.Compose([
            v2.Resize(224),
            v2.CenterCrop(224),
            v2.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
        ])

    def forward(self, gen_image):
        # Keep this path differentiable so the loss can steer generation.
        pixel_values = self.preprocess(gen_image.clamp(0, 1))
        image_embed = vision_encoder(pixel_values).image_embeds
        similarity = F.cosine_similarity(self.text_embed, image_embed, dim=1)
        loss = 1 - similarity.mean()
        return loss

def blue_loss(images):
    # How far the blue channel values are from 0.9
    # ([:, 2] -> every image in the batch, blue channel only).
    error = torch.abs(images[:, 2] - 0.9).mean()
    return error
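

# --- Usage sketch (illustrative only) ---
# A minimal smoke test showing how these losses might be combined during
# guided generation. The random image, the prompt, and the 0.5 weight are
# placeholder assumptions; a real loop would pass the decoded output of a
# diffusion model instead of noise.
if __name__ == "__main__":
    dummy_image = torch.rand(1, 3, 512, 512, device=torch_device, requires_grad=True)
    clip_loss = cosine_loss("a fox on a mountain")
    total = clip_loss(dummy_image) + 0.5 * blue_loss(dummy_image)
    total.backward()  # gradients w.r.t. the image could then be used for guidance
    print(f"combined loss: {total.item():.4f}")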