Upload back_task.py
back_task.py (ADDED, +55 lines)
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import open_clip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pretrained pipeline
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)

# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)


def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    target = torch.tensor(target_color).to(images.device) * 2 - 1
    target = target[None, :, None, None]
    error = torch.abs(images - target).mean()
    return error
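
# Illustrative sanity check: color_loss expects image batches in NCHW layout,
# scaled to (-1, 1) like the pipeline's outputs, and returns a scalar mean
# absolute distance to the target colour. The dummy batch here is hypothetical.
dummy_batch = torch.rand(2, 3, 64, 64).to(device) * 2 - 1
print(f"color_loss on a random batch: {color_loss(dummy_batch).item():.3f}")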

tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomAffine(5),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

def clip_loss(image, text_features):
    image_features = clip_model.encode_image(tfms(image))  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)  # Squared Great Circle Distance
    return dists.mean()
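
# Illustrative example: clip_loss needs CLIP text features as its target. One way
# to produce them with open_clip (the prompt below is a hypothetical placeholder):
prompt = "A colourful painting of a landscape"
with torch.no_grad():
    text_features = clip_model.encode_text(open_clip.tokenize([prompt]).to(device))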

n_cuts = 4  # @param
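
The file sets up the scheduler, the two guidance losses, and n_cuts, but stops before the sampling loop that ties them together. Below is a minimal sketch of how these pieces are typically combined for CLIP-guided DDIM sampling, following the guidance pattern from the Hugging Face diffusion models course; the 256x256 resolution, the guidance_scale value, and the reuse of the text_features example above are assumptions rather than part of back_task.py.

guidance_scale = 8  # assumed guidance strength
x = torch.randn(4, 3, 256, 256).to(device)  # assumes the model works on 256x256 RGB images

for t in tqdm(scheduler.timesteps):
    model_input = scheduler.scale_model_input(x, t)

    # Predict the noise residual; no gradients are needed through the UNet itself
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t).sample

    # Average the guidance gradient over n_cuts randomly augmented views of the
    # predicted clean image
    cond_grad = 0
    for _ in range(n_cuts):
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = clip_loss(x0, text_features) * guidance_scale
        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

    # Nudge x along the gradient (scaling this update by the current noise level
    # is a common tweak), then take the usual DDIM step
    x = x.detach() + cond_grad
    x = scheduler.step(noise_pred, t, x).prev_sample

In the same loop, color_loss could be swapped in for clip_loss to steer samples toward a target colour instead of a text prompt.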