import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import open_clip
# Backend code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load the pretrained pipeline
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)
# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
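# Minimal unguided sampling sketch using the scheduler above (assumption: the
# guided loop the app actually uses comes later in this file; `sample_unguided`
# is a hypothetical helper, shown only to illustrate the bare DDIM pattern).
def sample_unguided(batch_size=4):
    sample_size = image_pipe.unet.config.sample_size  # e.g. 256 for this checkpoint
    x = torch.randn(batch_size, 3, sample_size, sample_size).to(device)
    for t in tqdm(scheduler.timesteps):
        with torch.no_grad():
            noise_pred = image_pipe.unet(x, t).sample  # predict the noise residual
        x = scheduler.step(noise_pred, t, x).prev_sample  # one DDIM denoising step
    return x  # images in [-1, 1]; rescale with (x + 1) / 2 for display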
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Mean absolute error between every pixel and a target RGB colour."""
    # Map the target from [0, 1] to the pipeline's [-1, 1] range
    target = torch.tensor(target_color).to(images.device) * 2 - 1
    # Reshape to (1, 3, 1, 1) so it broadcasts over batch and spatial dims
    target = target[None, :, None, None]
    error = torch.abs(images - target).mean()
    return error
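# Usage sketch: images are expected in the pipeline's [-1, 1] range and the
# result is a differentiable scalar. For an all-zero (mid-grey) batch the
# per-channel errors are (0.8, 0.8, 0.0), so the loss is ~0.533:
# color_loss(torch.zeros(1, 3, 256, 256, device=device))  # -> tensor(0.5333)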
# Differentiable augmentations applied before CLIP: random crops, a little
# affine jitter and horizontal flips make the guidance gradient less brittle;
# Normalize uses CLIP's own image statistics.
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomAffine(5),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(tfms(image))  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)  # Squared Great Circle Distance
    return dists.mean()
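# Helper sketch for producing the text_features that clip_loss expects
# (assumption: the caller tokenizes a prompt roughly like this; `embed_text`
# is a hypothetical name, not part of the original file).
def embed_text(prompt):
    tokens = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens)
    return text_features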
n_cuts = 4  # Number of random augmented "cuts" to average the CLIP loss over