spdraptor committed
Commit fe074c0 · verified · 1 Parent(s): 0a05d1f

Upload back_task.py

Files changed (1):
  back_task.py  +55  -0
back_task.py ADDED
@@ -0,0 +1,55 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from datasets import load_dataset
+ from diffusers import DDIMScheduler, DDPMPipeline
+ from matplotlib import pyplot as plt
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ import open_clip
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # Load the pretrained pipeline
+ pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
+ image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
+
+ # Load a CLIP model (ViT-B/32) for text-based guidance losses
+ clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
+ clip_model.to(device)
+
+ # Sample some images with a DDIM Scheduler over 40 steps
+ scheduler = DDIMScheduler.from_pretrained(pipeline_name)
+ scheduler.set_timesteps(num_inference_steps=40)
+
+
+ def color_loss(images, target_color=(0.1, 0.9, 0.5)):
+     # Mean absolute error between the pixels and a target colour,
+     # with the target mapped from (0, 1) to (-1, 1) to match the model's output range
+     target = torch.tensor(target_color).to(images.device) * 2 - 1
+     target = target[None, :, None, None]  # Broadcast over batch, height and width
+     error = torch.abs(images - target).mean()
+     return error
+
+
+ # Random crop/affine/flip augmentations plus CLIP normalisation, applied before encoding
+ tfms = torchvision.transforms.Compose(
+     [
+         torchvision.transforms.RandomResizedCrop(224),
+         torchvision.transforms.RandomAffine(5),
+         torchvision.transforms.RandomHorizontalFlip(),
+         torchvision.transforms.Normalize(
+             mean=(0.48145466, 0.4578275, 0.40821073),
+             std=(0.26862954, 0.26130258, 0.27577711),
+         ),
+     ]
+ )
+
+
+ def clip_loss(image, text_features):
+     image_features = clip_model.encode_image(tfms(image))  # Note: applies the above transforms
+     input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
+     embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
+     dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)  # Squared Great Circle Distance
+     return dists.mean()
+
+
+ n_cuts = 4  # @param
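
The uploaded file ends at n_cuts = 4, so the sampling/guidance loop itself is not part of this commit. Purely as a hedged sketch of how clip_loss and n_cuts are typically wired into DDIM sampling with loss-based guidance (the prompt, guidance_scale, image size, and loop structure below are assumptions, not code from this upload):

# NOTE: illustrative sketch only - not part of the committed file.
# The prompt, guidance_scale, and image size are assumed values.
prompt = "A colorful painting of a forest"  # hypothetical prompt
text = open_clip.tokenize([prompt]).to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text)

guidance_scale = 8  # assumed strength of the CLIP guidance
x = torch.randn(4, 3, 256, 256).to(device)  # start from pure noise

for i, t in tqdm(enumerate(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)

    with torch.no_grad():  # the UNet forward pass does not need gradients
        noise_pred = image_pipe.unet(model_input, t)["sample"]

    cond_grad = 0
    for cut in range(n_cuts):
        x = x.detach().requires_grad_()  # track gradients w.r.t. the current image
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample  # predicted denoised image
        loss = clip_loss(x0, text_features) * guidance_scale
        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts  # average gradient over the cuts

    # Nudge x against the loss gradient (scaling by sqrt(alpha_bar) is one common choice),
    # then take the usual DDIM step
    alpha_bar = scheduler.alphas_cumprod[int(t)]
    x = x.detach() + cond_grad * alpha_bar.sqrt()
    x = scheduler.step(noise_pred, t, x).prev_sample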