testing new latent space visualization code
Browse files
ContraCLIP/traverse_latent_space.py
CHANGED
|
@@ -10,6 +10,10 @@ from torchvision.transforms import ToPILImage
|
|
| 10 |
from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
|
| 11 |
from models.load_generator import load_generator
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class DataParallelPassthrough(nn.DataParallel):
|
| 15 |
def __getattr__(self, name):
|
|
@@ -97,6 +101,63 @@ def create_gif(image_list, gif_height=256):
|
|
| 97 |
|
| 98 |
return transformed_images_gif_frames
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
def main():
|
| 102 |
"""ContraCLIP -- Latent space traversal script.
|
|
@@ -210,6 +271,8 @@ def main():
|
|
| 210 |
# -- Get prompt corpus list
|
| 211 |
with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
|
| 212 |
semantic_dipoles = json.load(f)
|
|
|
|
|
|
|
| 213 |
|
| 214 |
# Check given pool directory
|
| 215 |
pool = osp.join('experiments', 'latent_codes', gan, args.pool)
|
|
@@ -321,6 +384,9 @@ def main():
|
|
| 321 |
print(" \\__Shift steps : {}".format(2 * args.shift_steps))
|
| 322 |
print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))
|
| 323 |
|
|
|
|
|
|
|
|
|
|
| 324 |
# Iterate over given latent codes
|
| 325 |
for i in range(num_of_latent_codes):
|
| 326 |
# Get latent code
|
|
@@ -333,6 +399,9 @@ def main():
|
|
| 333 |
num_of_latent_codes),
|
| 334 |
num_of_latent_codes, i)
|
| 335 |
|
|
|
|
|
|
|
|
|
|
| 336 |
# Create directory for current latent code
|
| 337 |
latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
|
| 338 |
os.makedirs(latent_code_dir, exist_ok=True)
|
|
@@ -386,7 +455,7 @@ def main():
|
|
| 386 |
latent_code = latent_code[:, 0, :]
|
| 387 |
|
| 388 |
cnt = 0
|
| 389 |
-
for
|
| 390 |
cnt += 1
|
| 391 |
|
| 392 |
# Calculate shift vector based on current z
|
|
@@ -410,6 +479,10 @@ def main():
|
|
| 410 |
latent_code = latent_code + shift
|
| 411 |
current_path_latent_code = latent_code
|
| 412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
# Store latent codes and shifts
|
| 414 |
if cnt == args.shift_leap:
|
| 415 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
|
@@ -421,6 +494,8 @@ def main():
|
|
| 421 |
current_path_latent_codes.append(current_path_latent_code)
|
| 422 |
cnt = 0
|
| 423 |
positive_endpoint = latent_code.clone().reshape(1, -1)
|
|
|
|
|
|
|
| 424 |
# ========================
|
| 425 |
|
| 426 |
# == Negative direction ==
|
|
@@ -430,7 +505,7 @@ def main():
|
|
| 430 |
if stylegan_space == 'W':
|
| 431 |
latent_code = latent_code[:, 0, :]
|
| 432 |
cnt = 0
|
| 433 |
-
for
|
| 434 |
cnt += 1
|
| 435 |
# Calculate shift vector based on current z
|
| 436 |
support_sets_mask = torch.zeros(1, LSS.num_support_sets)
|
|
@@ -453,6 +528,10 @@ def main():
|
|
| 453 |
latent_code = latent_code + shift
|
| 454 |
current_path_latent_code = latent_code
|
| 455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
# Store latent codes and shifts
|
| 457 |
if cnt == args.shift_leap:
|
| 458 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
|
@@ -464,6 +543,8 @@ def main():
|
|
| 464 |
current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes
|
| 465 |
cnt = 0
|
| 466 |
negative_endpoint = latent_code.clone().reshape(1, -1)
|
|
|
|
|
|
|
| 467 |
# ========================
|
| 468 |
|
| 469 |
# Calculate latent path phi coefficient (end-to-end distance / latent path length)
|
|
@@ -531,13 +612,69 @@ def main():
|
|
| 531 |
|
| 532 |
# Save all latent paths and shifts for the current latent code (sample) in a tensor of size:
|
| 533 |
# paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z])
|
| 534 |
-
torch.
|
|
|
|
|
|
|
| 535 |
|
| 536 |
if args.verbose:
|
| 537 |
update_stdout(1)
|
| 538 |
print()
|
| 539 |
print()
|
| 540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
# Create summarizing MD files
|
| 542 |
if args.gif or args.strip:
|
| 543 |
# For each interpretable path (warping function), collect the generated image sequences for each original latent
|
|
|
|
| 10 |
from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
|
| 11 |
from models.load_generator import load_generator
|
| 12 |
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 16 |
+
from sklearn.manifold import TSNE
|
| 17 |
|
| 18 |
class DataParallelPassthrough(nn.DataParallel):
|
| 19 |
def __getattr__(self, name):
|
|
|
|
| 101 |
|
| 102 |
return transformed_images_gif_frames
|
| 103 |
|
| 104 |
+
def visualize_latent_space(tsne_latent_codes, semantic_dipoles, output_dir, save_filename="latent_space_tsne.png", shift_steps=16):
|
| 105 |
+
"""
|
| 106 |
+
Visualize the t-SNE reduced latent space with minimal annotations.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
tsne_latent_codes (np.ndarray): The 3D latent codes after t-SNE transformation.
|
| 110 |
+
semantic_dipoles (list): List of semantic directions (labels) for paths.
|
| 111 |
+
shift_steps (int): Number of positive/negative steps along each path.
|
| 112 |
+
output_dir (str): Directory to save the generated plot.
|
| 113 |
+
save_filename (str): Name of the file to save the plot.
|
| 114 |
+
"""
|
| 115 |
+
fig = plt.figure(figsize=(16, 12)) # Larger figure for clarity
|
| 116 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 117 |
+
|
| 118 |
+
num_paths = len(semantic_dipoles) # Each dipole represents one path
|
| 119 |
+
cmap = plt.cm.get_cmap('tab10', num_paths)
|
| 120 |
+
|
| 121 |
+
for i in range(num_paths):
|
| 122 |
+
# Indices for the path in tsne_latent_codes
|
| 123 |
+
start_idx = i * (2 * shift_steps + 1)
|
| 124 |
+
pos_idx = start_idx + shift_steps # Positive endpoint
|
| 125 |
+
neg_idx = start_idx + 2 * shift_steps # Negative endpoint
|
| 126 |
+
|
| 127 |
+
# Extract path points
|
| 128 |
+
path_indices = list(range(start_idx, neg_idx + 1))
|
| 129 |
+
path_coords = tsne_latent_codes[path_indices]
|
| 130 |
+
|
| 131 |
+
# Plot the entire path (all intermediate points in a single color)
|
| 132 |
+
ax.plot(
|
| 133 |
+
path_coords[:, 0], path_coords[:, 1], path_coords[:, 2],
|
| 134 |
+
color=cmap(i),
|
| 135 |
+
linewidth=2
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Extract positive and negative endpoint coordinates
|
| 139 |
+
pos_coords = tsne_latent_codes[pos_idx]
|
| 140 |
+
neg_coords = tsne_latent_codes[neg_idx]
|
| 141 |
+
|
| 142 |
+
# Plot positive and negative endpoints
|
| 143 |
+
ax.scatter(*pos_coords, color=cmap(i), s=100, label=f"{semantic_dipoles[i][0]} → {semantic_dipoles[i][1]}")
|
| 144 |
+
ax.scatter(*neg_coords, color=cmap(i), s=100)
|
| 145 |
+
|
| 146 |
+
# Add legend
|
| 147 |
+
ax.legend(loc='best', fontsize=10)
|
| 148 |
+
|
| 149 |
+
# Set titles and labels
|
| 150 |
+
ax.set_title("t-SNE Latent Space Visualization")
|
| 151 |
+
ax.set_xlabel("t-SNE Dimension 1")
|
| 152 |
+
ax.set_ylabel("t-SNE Dimension 2")
|
| 153 |
+
ax.set_zlabel("t-SNE Dimension 3")
|
| 154 |
+
|
| 155 |
+
# Save the plot
|
| 156 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 157 |
+
save_path = osp.join(output_dir, save_filename)
|
| 158 |
+
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 159 |
+
print(f"Visualization saved to {save_path}")
|
| 160 |
+
|
| 161 |
|
| 162 |
def main():
|
| 163 |
"""ContraCLIP -- Latent space traversal script.
|
|
|
|
| 271 |
# -- Get prompt corpus list
|
| 272 |
with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
|
| 273 |
semantic_dipoles = json.load(f)
|
| 274 |
+
|
| 275 |
+
# semantic_directions = [f"{dipole[0]} → {dipole[1]}" for dipole in semantic_dipoles]
|
| 276 |
|
| 277 |
# Check given pool directory
|
| 278 |
pool = osp.join('experiments', 'latent_codes', gan, args.pool)
|
|
|
|
| 384 |
print(" \\__Shift steps : {}".format(2 * args.shift_steps))
|
| 385 |
print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))
|
| 386 |
|
| 387 |
+
# Store latent codes for T-SNE visualization (for all paths across each latent code)
|
| 388 |
+
all_paths_latent_codes = []
|
| 389 |
+
|
| 390 |
# Iterate over given latent codes
|
| 391 |
for i in range(num_of_latent_codes):
|
| 392 |
# Get latent code
|
|
|
|
| 399 |
num_of_latent_codes),
|
| 400 |
num_of_latent_codes, i)
|
| 401 |
|
| 402 |
+
# Append the starting latent code to tsne_latent_codes
|
| 403 |
+
# tsne_latent_codes.append(x_.clone().cpu().numpy().flatten())
|
| 404 |
+
|
| 405 |
# Create directory for current latent code
|
| 406 |
latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
|
| 407 |
os.makedirs(latent_code_dir, exist_ok=True)
|
|
|
|
| 455 |
latent_code = latent_code[:, 0, :]
|
| 456 |
|
| 457 |
cnt = 0
|
| 458 |
+
for k in range(args.shift_steps):
|
| 459 |
cnt += 1
|
| 460 |
|
| 461 |
# Calculate shift vector based on current z
|
|
|
|
| 479 |
latent_code = latent_code + shift
|
| 480 |
current_path_latent_code = latent_code
|
| 481 |
|
| 482 |
+
# Append intermediate latent code
|
| 483 |
+
# if k != args.shift_steps - 1:
|
| 484 |
+
# tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
|
| 485 |
+
|
| 486 |
# Store latent codes and shifts
|
| 487 |
if cnt == args.shift_leap:
|
| 488 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
|
|
|
| 494 |
current_path_latent_codes.append(current_path_latent_code)
|
| 495 |
cnt = 0
|
| 496 |
positive_endpoint = latent_code.clone().reshape(1, -1)
|
| 497 |
+
|
| 498 |
+
# tsne_latent_codes.append(positive_endpoint.clone().cpu().numpy().flatten())
|
| 499 |
# ========================
|
| 500 |
|
| 501 |
# == Negative direction ==
|
|
|
|
| 505 |
if stylegan_space == 'W':
|
| 506 |
latent_code = latent_code[:, 0, :]
|
| 507 |
cnt = 0
|
| 508 |
+
for k in range(args.shift_steps):
|
| 509 |
cnt += 1
|
| 510 |
# Calculate shift vector based on current z
|
| 511 |
support_sets_mask = torch.zeros(1, LSS.num_support_sets)
|
|
|
|
| 528 |
latent_code = latent_code + shift
|
| 529 |
current_path_latent_code = latent_code
|
| 530 |
|
| 531 |
+
# Append intermediate latent code
|
| 532 |
+
# if k != args.shift_steps - 1:
|
| 533 |
+
# tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
|
| 534 |
+
|
| 535 |
# Store latent codes and shifts
|
| 536 |
if cnt == args.shift_leap:
|
| 537 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
|
|
|
| 543 |
current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes
|
| 544 |
cnt = 0
|
| 545 |
negative_endpoint = latent_code.clone().reshape(1, -1)
|
| 546 |
+
|
| 547 |
+
# tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
|
| 548 |
# ========================
|
| 549 |
|
| 550 |
# Calculate latent path phi coefficient (end-to-end distance / latent path length)
|
|
|
|
| 612 |
|
| 613 |
# Save all latent paths and shifts for the current latent code (sample) in a tensor of size:
|
| 614 |
# paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z])
|
| 615 |
+
paths_latent_codes_tensor = torch.cat(paths_latent_codes)
|
| 616 |
+
torch.save(paths_latent_codes_tensor, osp.join(latent_code_dir, 'paths_latent_codes.pt'))
|
| 617 |
+
all_paths_latent_codes.append(paths_latent_codes_tensor.cpu().numpy())
|
| 618 |
|
| 619 |
if args.verbose:
|
| 620 |
update_stdout(1)
|
| 621 |
print()
|
| 622 |
print()
|
| 623 |
|
| 624 |
+
# After processing all latent codes and paths
|
| 625 |
+
if args.verbose:
|
| 626 |
+
print("Performing t-SNE on latent codes for visualization...")
|
| 627 |
+
|
| 628 |
+
# # Consolidate all paths for T-SNE visualization (total_paths = num_of_latent_codes * num_gen_paths)
|
| 629 |
+
# all_paths_np = np.concatenate(all_paths_latent_codes, axis=0) # Shape: [total_paths, steps_per_path, latent_dim]
|
| 630 |
+
# all_paths_flattened = all_paths_np.reshape(-1, all_paths_np.shape[-1]) # Flatten paths into 2D array for T-SNE
|
| 631 |
+
|
| 632 |
+
# # Apply 3D T-SNE
|
| 633 |
+
# tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
|
| 634 |
+
# tsne_transformed = tsne_model.fit_transform(all_paths_flattened) # Shape: [total_points, 3]
|
| 635 |
+
|
| 636 |
+
# path_indices = [] # List to store indices for each path
|
| 637 |
+
# start_idx = 0 # Starting index for the current path in all_paths_np
|
| 638 |
+
|
| 639 |
+
# steps_per_path = 2 * args.shift_steps + 1 # Number of points in each path
|
| 640 |
+
|
| 641 |
+
# # Iterate over each latent code and its paths
|
| 642 |
+
# for i in range(num_of_latent_codes): # Loop through latent codes
|
| 643 |
+
# for dim in range(num_gen_paths): # Loop through directions (paths)
|
| 644 |
+
# # Generate the indices for this path
|
| 645 |
+
# indices = list(range(start_idx, start_idx + steps_per_path))
|
| 646 |
+
# path_indices.append(indices)
|
| 647 |
+
|
| 648 |
+
# # Update the starting index for the next path
|
| 649 |
+
# start_idx += steps_per_path
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
all_paths_latent_code_0 = all_paths_latent_codes[0]
|
| 653 |
+
num_paths, num_steps, _ = all_paths_latent_code_0.shape
|
| 654 |
+
tsne_latent_codes = all_paths_latent_code_0.reshape(-1, all_paths_latent_code_0.shape[-1])
|
| 655 |
+
|
| 656 |
+
# Apply 3D T-SNE
|
| 657 |
+
tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
|
| 658 |
+
tsne_transformed = tsne_model.fit_transform(tsne_latent_codes) # Shape: [total_points = num_paths * num_steps, 3]
|
| 659 |
+
|
| 660 |
+
# For this specific latent code, generate indices for each of its paths
|
| 661 |
+
path_indices = []
|
| 662 |
+
start_idx = 0
|
| 663 |
+
for _ in range(num_paths):
|
| 664 |
+
indices = list(range(start_idx, start_idx + num_steps))
|
| 665 |
+
path_indices.append(indices)
|
| 666 |
+
start_idx += num_steps
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
tsne_vis_dir = osp.join(out_dir, 'tsne_visualizations')
|
| 670 |
+
visualize_latent_space(
|
| 671 |
+
tsne_latent_codes=tsne_transformed, # T-SNE-reduced latent codes
|
| 672 |
+
semantic_dipoles=semantic_dipoles, # Semantic labels for paths
|
| 673 |
+
paths=path_indices, # Indices of paths (for a single latent code)
|
| 674 |
+
output_dir=tsne_vis_dir,
|
| 675 |
+
save_filename="latent_space_tsne.png"
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
# Create summarizing MD files
|
| 679 |
if args.gif or args.strip:
|
| 680 |
# For each interpretable path (warping function), collect the generated image sequences for each original latent
|