Prathm committed
Commit 8a5a3ed · 1 Parent(s): 6b34730

Update app.py

Files changed (1): app.py +9 -4
app.py CHANGED
@@ -50,8 +50,8 @@ def clip_optimized_latent(text, seed, iterations=25, lr=1e-2):
     params = [torch.nn.Parameter(latent_vector[i], requires_grad=True) for i in range(len(latent_vector))]
     optimizer = Adam(params, lr=lr, betas=(0.9, 0.999))

-    with torch.no_grad():
-        text_features = clip_model.encode_text(text_input)
+    #with torch.no_grad():
+    #    text_features = clip_model.encode_text(text_input)

     #pbar = tqdm(range(iterations), dynamic_ncols=True)

@@ -67,10 +67,11 @@ def clip_optimized_latent(text, seed, iterations=25, lr=1e-2):
         #image = clip_preprocess(Image.fromarray((image_np * 255).astype(np.uint8))).unsqueeze(0).to(device)

         # Extract features from the image
-        image_features = clip_model.encode_image(image)
+        #image_features = clip_model.encode_image(image)

         # Calculate the loss and backpropagate
-        loss = -torch.cosine_similarity(text_features, image_features).mean()
+        loss = 1 - clip_model(image, text_input)[0] / 100
+        #loss = -torch.cosine_similarity(text_features, image_features).mean()
         loss.backward()
         optimizer.step()

@@ -189,6 +190,8 @@ def load_model():
 clip_model, inst = load_model()
 model = inst.model

+similarity = 1 - clip_model(image, text)[0] / 100
+

 path_to_components = get_or_compute(config, inst)
 comps = np.load(path_to_components)
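The substantive change in the first three hunks swaps the explicit encode-then-cosine loss for CLIP's own forward pass. A minimal sketch of why the two formulations are roughly interchangeable, assuming OpenAI's clip package and the ViT-B/32 checkpoint (both assumptions; app.py may load a different model, and the random image tensor below is just a stand-in for a preprocessed generator output):

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = torch.randn(1, 3, 224, 224, device=device)  # stand-in for a preprocessed render
text = clip.tokenize(["a photo of a castle"]).to(device)

# Old formulation: encode each modality separately, then negate the cosine similarity.
image_features = model.encode_image(image)
text_features = model.encode_text(text)
loss_old = -torch.cosine_similarity(image_features, text_features).mean()

# New formulation: model(image, text) returns (logits_per_image, logits_per_text),
# where logits_per_image is the cosine similarity of the normalized embeddings
# multiplied by the learned logit scale, which saturates near 100 in the released
# checkpoints. Dividing by 100 therefore approximately recovers the cosine term.
logits_per_image, logits_per_text = model(image, text)
loss_new = 1 - logits_per_image[0] / 100

Since the added constant 1 contributes no gradient, minimizing 1 - cos is equivalent to minimizing -cos. The practical difference is that the new form no longer needs text features precomputed under torch.no_grad(), which is why those lines are commented out, at the cost of re-encoding the text on every optimization step. The module-level similarity line added in the third hunk applies the same expression, and assumes image and text are already defined at that point in app.py.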
 
@@ -227,6 +230,8 @@ if 'seed1' not in st.session_state and 'seed2' not in st.session_state:
     st.session_state['seed2'] = random.randint(1, 1000)
 seed1 = st.sidebar.number_input("Seed 1", value= st.session_state['seed1'])
 seed2 = st.sidebar.number_input("Seed 2", value= st.session_state['seed2'])
+st.session_state['seed1'] = seed1
+st.session_state['seed2'] = seed1
 iters = st.sidebar.number_input("Iterations for CLIP Optimization", value = 50)
 submit_button = st.sidebar.button("Submit")
 content = st.sidebar.slider("Structural Composition", min_value=0.0, max_value=1.0, value=0.5)
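The last hunk writes the widget values back into st.session_state so user-edited seeds persist across Streamlit reruns. Note that the committed second line stores seed1 under the 'seed2' key, which looks like a typo; a corrected sketch of the pattern (assuming the standard Streamlit rerun model, nothing beyond the lines shown in the diff):

import random
import streamlit as st

# Roll the defaults once; the guard keeps the seeds from being
# re-randomized on every rerun of the script.
if 'seed1' not in st.session_state and 'seed2' not in st.session_state:
    st.session_state['seed1'] = random.randint(1, 1000)
    st.session_state['seed2'] = random.randint(1, 1000)

# The widgets display the stored values; writing the returned values back
# keeps session_state in sync with whatever the user typed, so code that
# reads session_state later in the script sees the edited seeds.
seed1 = st.sidebar.number_input("Seed 1", value=st.session_state['seed1'])
seed2 = st.sidebar.number_input("Seed 2", value=st.session_state['seed2'])
st.session_state['seed1'] = seed1
st.session_state['seed2'] = seed2  # the commit assigns seed1 here, likely unintended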