File size: 3,448 Bytes
4602ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d248fe
4602ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr


# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP backbone to load; the `#@param` comment is a Colab form annotation
# listing the selectable choices.
model_name = 'ViT-B/16' #@param  ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
model, preprocess = clip.load(model_name)
model.to(DEVICE)
model.eval()

# Input resolution reported by CLIP's visual encoder; single-pixel colour
# images are resized to this square size before encoding.
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize(size=(resolution,) * 2)


def create_rgb_tensor(color):
  """Build a single-pixel RGB image tensor on DEVICE.

  Args:
    color: three channel values, e.g. [1, 0, 0] for red, intended to lie
      in [0, 1].

  Returns:
    A float32 tensor of shape (1, 3, 1, 1) on DEVICE: a batch of one 1x1
    image, ready to be upsampled by ``resizer``.
  """
  # Force a float dtype: an all-int colour such as [1, 0, 0] would otherwise
  # become an int64 tensor, which torchvision's tensor-image convention reads
  # as a [0, max-int] image rather than a [0, 1] one.
  pixel = torch.tensor(color, dtype=torch.float32, device=DEVICE)
  return pixel.reshape((1, 3, 1, 1))

def encode_color(color):
  """Return CLIP's image embedding for a solid colour.

  ``color`` is a 3-element RGB sequence such as [1, 0, 0]. It becomes a
  single-pixel image, is resized to the model's input resolution, and is
  passed through the image encoder.
  """
  image = resizer(create_rgb_tensor(color))
  return model.encode_image(image)

def encode_text(text):
  """Return CLIP's text embedding for ``text``.

  ``text`` is passed straight to ``clip.tokenize``; the tokens are moved to
  DEVICE before being encoded.
  """
  tokens = clip.tokenize(text)
  return model.encode_text(tokens.to(DEVICE))

def lerp(x, y, steps=11):
  """Linearly interpolate between two tensors at ``steps`` evenly spaced weights.

  Args:
    x: start tensor with a leading dimension of 1, e.g. the (1, 3, 1, 1)
      colour tensors built by ``create_rgb_tensor``.
    y: end tensor, broadcast-compatible with ``x``.
    steps: number of samples, both endpoints included.

  Returns:
    A tensor whose dim 0 has ``steps`` entries, running from ``x`` (index 0)
    to ``y`` (last index). The weights are float64, so the result is
    promoted to float64.
  """
  # Build the weights where the inputs live instead of on the module-level
  # DEVICE, so tensors on any device can be interpolated without a mismatch.
  weights = torch.tensor(np.linspace(0, 1, steps), device=x.device)
  # One weight per step along dim 0; the trailing singleton dims broadcast
  # over the channel and spatial dims of the (1, 3, 1, 1) inputs.
  weights = weights.reshape([-1, 1, 1, 1])

  return x * (1 - weights) + y * weights

def get_interpolated_scores(x, y, encoded_text, steps=11):
  """Score the colours along a linear RGB path against a text embedding.

  Args:
    x: start colour tensor, e.g. from ``create_rgb_tensor`` (1, 3, 1, 1).
    y: end colour tensor of the same shape.
    encoded_text: text embedding to compare against, e.g. from
      ``encode_text``.
    steps: number of colours sampled along the path, endpoints included.

  Returns:
    A DataFrame with one row per step and the columns ``step`` (0-based
    position), ``similarity`` (cosine similarity between that colour's image
    embedding and ``encoded_text``) and ``color`` (the colour as a hex string).
  """
  interpolated = lerp(x, y, steps)

  # Pure inference: skip autograd so CLIP's activations are not kept around
  # for a backward pass that never happens.
  with torch.no_grad():
    image_features = model.encode_image(resizer(interpolated))
    scores = torch.cosine_similarity(image_features, encoded_text).cpu().numpy()

  # (steps, 3, 1, 1) -> (steps, 3): one RGB triple per step.
  pixels = interpolated.detach().cpu().numpy().reshape(-1, 3)
  hex_colors = [rgb2hex(pixel) for pixel in pixels]

  data = pd.DataFrame({
      'similarity': scores,
      'color': hex_colors,
  }).reset_index().rename(columns={'index': 'step'})

  return data

def rgb2hex(rgb):
  """Convert an RGB triple with channels in [0, 1] to a ``#rrggbb`` string.

  Args:
    rgb: three channel values (NumPy array, list or tuple). Values outside
      [0, 1] are clamped, so the result is always a valid 6-digit colour.

  Returns:
    A lower-case hex colour string, e.g. ``'#ff0000'`` for ``[1, 0, 0]``.

  Raises:
    ValueError: if ``rgb`` does not have exactly three channels.
  """
  # asarray lets plain lists work too (``list * 255`` would repeat the list).
  values = np.asarray(rgb, dtype=float)
  # Round to the nearest 8-bit level instead of truncating (which biased every
  # channel down, e.g. 127.5 -> 127). Clamp so out-of-range input cannot
  # produce malformed strings such as '#1fe0000' or '#-ff0000'.
  levels = np.clip(np.rint(values * 255), 0, 255).astype(int)
  r, g, b = (int(level) for level in levels)
  return "#{:02x}{:02x}{:02x}".format(r, g, b)


def similarity_plot(data, text_prompt):
  """Draw the per-step similarity scores as a bar chart coloured by step.

  Args:
    data: DataFrame with ``similarity`` and ``color`` columns, as returned by
      ``get_interpolated_scores``.
    text_prompt: the prompt the scores were computed against; shown in the
      title.

  Returns:
    A matplotlib ``Figure`` with one bar per row of ``data``, each bar filled
    with that row's colour and the x-axis hidden.
  """
  # Local import keeps the block self-contained; matplotlib is already a
  # dependency of this module.
  from matplotlib.figure import Figure

  title = f'CLIP Cosine Similarity Prompt="{text_prompt}"'

  # Build the Figure directly rather than through pyplot. pyplot keeps every
  # figure it creates alive until it is explicitly closed, so a server calling
  # this once per request would accumulate figures without bound.
  fig = Figure()
  ax = fig.subplots()
  data['similarity'].plot(kind='bar',
                          ax=ax,
                          stacked=True,
                          title=title,
                          color=data['color'].tolist(),
                          width=1.0,
                          grid=False)

  # Step indices add nothing beyond bar order, so hide the x-axis.
  ax.get_xaxis().set_visible(False)
  return fig



def interpolation_experiment(rgb_start, rgb_end, text_prompt, steps=11):
  """Plot CLIP similarity to ``text_prompt`` along an RGB colour gradient.

  Args:
    rgb_start: start colour as three channel values, e.g. [1, 0, 0].
    rgb_end: end colour in the same format.
    text_prompt: text the interpolated colours are compared against.
    steps: number of colours sampled between the endpoints, both included.

  Returns:
    The figure produced by ``similarity_plot``.
  """
  # Inference only: disable autograd so neither the text encoder nor the
  # image encoder records a computation graph.
  with torch.no_grad():
    start = create_rgb_tensor(rgb_start)
    end = create_rgb_tensor(rgb_end)
    encoded_text = encode_text(text_prompt)
    data = get_interpolated_scores(start, end, encoded_text, steps)

  return similarity_plot(data, text_prompt)




# Gradio input widgets (legacy ``gr.inputs`` API). The colour-format hint is
# part of the labels so that it actually appears in the UI.
start_input = gr.inputs.Textbox(lines=1, default="1, 0, 0",
                                label="Start RGB (Comma separated numbers between 0 and 1)")
end_input = gr.inputs.Textbox(lines=1, default="0, 1, 0",
                              label="End RGB (Comma separated numbers between 0 and 1)")

text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')

# Number of colours sampled along the gradient, endpoints included.
steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Interpolation Steps")

def gradio_fn(rgb_start, rgb_end, text_prompt, steps=11):
  """Gradio callback: parse the textbox inputs and run the experiment.

  Args:
    rgb_start: start colour as text, e.g. "1, 0, 0".
    rgb_end: end colour as text, e.g. "0, 1, 0".
    text_prompt: prompt the colours are scored against.
    steps: interpolation step count from the slider.

  Returns:
    The figure produced by ``interpolation_experiment``.

  Raises:
    ValueError: if a colour is not exactly three comma-separated numbers.
  """
  def parse_rgb(text, name):
    # float() ignores surrounding whitespace itself.
    channels = [float(part) for part in text.split(',')]
    if len(channels) != 3:
      raise ValueError(
          f"{name} must be three comma-separated numbers, got {text!r}")
    return channels

  # The slider value may arrive as a float (e.g. 11.0); np.linspace in lerp
  # requires an integer sample count.
  return interpolation_experiment(parse_rgb(rgb_start, 'Start RGB'),
                                  parse_rgb(rgb_end, 'End RGB'),
                                  text_prompt,
                                  int(steps))

# Wire the input widgets to gradio_fn; the result is rendered as a plot.
iface = gr.Interface(fn=gradio_fn,
                     inputs=[start_input, end_input, text_input, steps_input],
                     outputs="plot")

# Launch only when run as a script, so importing this module (e.g. from a
# notebook or a test) does not start a web server.
if __name__ == "__main__":
  iface.launch(debug=True, share=False)