Spaces:
Build error
Build error
import io | |
import math | |
import pickle | |
import imageio | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
from moviepy.editor import ImageSequenceClip, VideoFileClip | |
class ClusterSOM:
    """
    A collection of self-organizing maps (SOMs), one per cluster.

    ``som_models`` maps a cluster index to a SOM object exposing ``winner()``
    and ``get_weights()``. A sample is assigned to the SOM whose best-matching
    unit (BMU) weight vector is closest to it, provided that distance is within
    a per-node threshold derived from ``mean_values``; otherwise the sample is
    reported as noise (cluster ``-1``).
    """

    def __init__(self):
        self.hdbscan_model = None
        # cluster index -> SOM object (must provide winner() and get_weights())
        self.som_models = {}
        # cluster index -> per-node statistics, indexed by the BMU coordinate (x, y)
        self.sigma_values = {}
        self.mean_values = {}
        self.cluster_mapping = {}
        self.embedding = None
        self.dim_red_op = None
        # NOTE: label_centroids / label_encodings are only set by load() when the
        # pickled model contains them; the rest of the class detects labels with
        # hasattr(self, 'label_centroids').

    def load(self, file_path):
        """
        Load a ClusterSOM model from a pickle file.

        The unpickled object must be a sequence whose first five items are
        ``(hdbscan_model, som_models, mean_values, sigma_values, cluster_mapping)``,
        optionally followed by ``(label_centroids, label_encodings)``.

        WARNING: ``pickle.load`` can execute arbitrary code. Only load files
        from trusted sources.

        :param file_path: Path of the pickle file to read.
        """
        with open(file_path, "rb") as f:
            model_data = pickle.load(f)
        (self.hdbscan_model, self.som_models, self.mean_values,
         self.sigma_values, self.cluster_mapping) = model_data[:5]
        if len(model_data) > 5:
            self.label_centroids, self.label_encodings = model_data[5:]
        else:
            # Drop label data left over from a previously loaded model so the
            # hasattr()-based label checks reflect the model just loaded.
            for attr in ("label_centroids", "label_encodings"):
                if hasattr(self, attr):
                    delattr(self, attr)

    def _has_labels(self):
        """Return True when label data was loaded alongside the SOMs."""
        return hasattr(self, 'label_centroids')

    def _label_for(self, cluster_idx):
        """Map a cluster index to its label via the loaded label encoder."""
        # NOTE(review): the "- 1" assumes cluster keys start at 1 while encoder
        # classes start at 0 — confirm against the training code.
        return self.label_encodings.inverse_transform([cluster_idx - 1])[0]

    def predict(self, data, sigma_factor=2.5, mean_factor=1.5):
        """
        Predict the cluster and BMU SOM coordinate for each sample in the data.

        For every sample, each SOM's BMU is found and the SOM whose BMU weight
        vector is nearest (Euclidean) wins. The sample is accepted only if that
        distance is at most ``mean_factor`` times the stored mean value of the
        winning node; otherwise it is reported as noise. If labels are loaded,
        the label of the winning cluster and the distance to that label's
        centroid are reported as well.

        :param data: Iterable of samples accepted by ``som.winner``.
        :param sigma_factor: Currently unused; kept for backward compatibility
            (the sigma-based threshold is disabled in favour of the mean-based one).
        :param mean_factor: Multiplier on the node's mean value that forms the
            acceptance threshold (default 1.5, the previously hard-coded value).
        :return: A list with one tuple per sample. Without labels:
            ``(cluster_idx, (x, y))``; with labels:
            ``(cluster_idx, (x, y), label, label_distance)``. Noise samples get
            ``-1`` as the cluster and ``None`` for every other field, using the
            same tuple length as accepted samples.
        :raises ValueError: If no SOM models are available.
        """
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")
        has_labels = self._has_labels()
        noise = (-1, None, None, None) if has_labels else (-1, None)

        results = []
        for sample in data:
            min_distance = float('inf')
            nearest_cluster_idx = None
            nearest_node = None
            for i, som in self.som_models.items():
                x, y = som.winner(sample)
                distance = np.linalg.norm(sample - som.get_weights()[x, y])
                if distance < min_distance:
                    min_distance = distance
                    nearest_cluster_idx = i
                    nearest_node = (x, y)

            # No finite distance (e.g. NaN in the sample): treat as noise
            # instead of indexing mean_values with None.
            if nearest_cluster_idx is None:
                results.append(noise)
                continue

            threshold = self.mean_values[nearest_cluster_idx][nearest_node] * mean_factor
            if min_distance > threshold:
                results.append(noise)
                continue

            if has_labels:
                label_idx = self._label_for(nearest_cluster_idx)
                label_distance = np.linalg.norm(sample - self.label_centroids[label_idx])
                results.append((nearest_cluster_idx, nearest_node, label_idx, label_distance))
            else:
                results.append((nearest_cluster_idx, nearest_node))
        return results

    def score(self, data, midpoints=None, threshold_radius=4):
        """
        Score each sample by the grid distance of its BMU to the closest midpoint.

        :param data: The input data.
        :param midpoints: A dictionary with keys as the indices of the SOMs and
            values as lists of midpoints on the grid for the corresponding SOMs.
            SOMs without an entry use the centre of their grid.
        :param threshold_radius: The threshold radius for score calculation.
        :return: One entry per sample: the minimum grid distance minus
            ``threshold_radius``, or ``None`` for samples predicted as noise.
        """
        scores = []
        # A single predict() call over all samples instead of one per sample.
        for result in self.predict(data):
            activated_som_index, bmu = result[0], result[1]
            if activated_som_index == -1:
                scores.append(None)  # Noise
                continue

            if midpoints is not None and activated_som_index in midpoints:
                specified_midpoints = midpoints[activated_som_index]
            else:
                grid_shape = self.som_models[activated_som_index].get_weights().shape[:2]
                specified_midpoints = [tuple((dim - 1) / 2 for dim in grid_shape)]

            min_distance = min(
                np.sqrt((midpoint[0] - bmu[0]) ** 2 + (midpoint[1] - bmu[1]) ** 2)
                for midpoint in specified_midpoints
            )
            scores.append(min_distance - threshold_radius)
        return scores

    def rearrange_subplots(self, num_subplots):
        """
        Create a near-square grid of subplots with room for ``num_subplots`` axes.

        :param num_subplots: Number of axes needed (must be >= 1).
        :return: ``(fig, axes)`` where ``axes`` is a flat 1-D array of Axes.
            Cells beyond ``num_subplots`` are hidden.
        :raises ValueError: If ``num_subplots`` is less than 1.
        """
        if num_subplots < 1:
            raise ValueError("num_subplots must be at least 1.")
        num_rows = math.isqrt(num_subplots)
        num_cols = math.ceil(num_subplots / num_rows)
        # squeeze=False always returns a 2-D ndarray, so a 1x1 grid is handled
        # like any other (otherwise plt.subplots returns a bare, unindexable Axes).
        fig, axes = plt.subplots(num_rows, num_cols, sharex=True, sharey=True, squeeze=False)
        axes = axes.flatten()
        for ax in axes[num_subplots:]:
            ax.axis('off')
        return fig, axes

    @staticmethod
    def _activation_map(som, sample):
        """
        Return ``(activation_map, winner)`` for one SOM and one sample.

        ``activation_map`` holds the Euclidean distance from ``sample`` to every
        node's weight vector, with the BMU cell forced to 0 (the colormap minimum).
        """
        weights = som.get_weights()
        activation_map = np.asarray(np.linalg.norm(weights - sample, axis=-1), dtype=float)
        winner = som.winner(sample)
        activation_map[winner] = 0
        return activation_map, winner

    def _draw_activation(self, ax, som_key, activation_map, winner, active, title_fontsize=None):
        """Draw one SOM's activation map on ``ax``, highlighting it if ``active``."""
        title_kwargs = {} if title_fontsize is None else {'fontsize': title_fontsize}
        if active:
            ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
            ax.plot(winner[1], winner[0], 'r+')  # Mark the BMU with a red plus sign
            ax.set_title(f"A {som_key}", color='blue', fontweight='bold', **title_kwargs)
            if self._has_labels():
                ax.set_xlabel(f"Label: {self._label_for(som_key)}", fontsize=12)
        else:
            ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
            ax.set_title(f"A {som_key}", **title_kwargs)

    def plot_activation(self, data, start=None, end=None, times=None):
        """
        Render the activation maps of all SOMs for each sample and return a video clip.

        One frame is rendered per sample in ``data[start:end]``. The SOM that
        ``predict`` selects is drawn in colour with its BMU marked; the other
        SOMs are drawn in grayscale.

        :param data: Sliceable sequence of samples.
        :param start: First sample index (default 0).
        :param end: End sample index, exclusive (default ``len(data)``).
        :param times: Currently unused; frames are always shown at 2 fps.
        :return: A moviepy ``ImageSequenceClip`` of the rendered frames.
        :raises ValueError: If no SOM models are available.
        """
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")
        if start is None:
            start = 0
        if end is None:
            end = len(data)

        images = []
        for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
            prediction = self.predict([sample])[0]
            fig, axes = self.rearrange_subplots(len(self.som_models))
            try:
                fig.suptitle(f"Activation map for A {prediction[0]}, node {prediction[1]}", fontsize=16)
                for ax, (som_key, som) in zip(axes, self.som_models.items()):
                    activation_map, winner = self._activation_map(som, sample)
                    self._draw_activation(ax, som_key, activation_map, winner,
                                          som_key == prediction[0], title_fontsize=10)
                    ax.set_xticks([])
                    ax.set_yticks([])
                    ax.grid(True, linestyle='-', linewidth=0.5)
                fig.tight_layout()
                fig.subplots_adjust(wspace=0, hspace=0)
                # Rasterise the figure to PNG in memory and decode it as a frame.
                with io.BytesIO() as buf:
                    fig.savefig(buf, format='png')
                    buf.seek(0)
                    images.append(imageio.imread(buf))
            finally:
                # Always release the figure, even if drawing fails.
                plt.close(fig)

        return ImageSequenceClip(images, fps=2)

    def plot_activation_v2(self, data, slice_select):
        """
        Plot the activation maps of all SOMs for a single sample, in one row.

        :param data: Indexable sequence of samples.
        :param slice_select: 1-based index of the sample to plot (passed through ``int()``).
        :return: The matplotlib Figure (it is not closed here).
        :raises ValueError: If no SOM models are available.
        """
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")
        sample = data[int(slice_select) - 1]
        prediction = self.predict([sample])[0]
        # squeeze=False keeps axes indexable even when there is a single SOM.
        fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5),
                                 sharex=True, sharey=True, squeeze=False)
        axes = axes.ravel()
        fig.suptitle(f"Activation map for A {prediction[0]}, node {prediction[1]}", fontsize=16)
        for ax, (som_key, som) in zip(axes, self.som_models.items()):
            activation_map, winner = self._activation_map(som, sample)
            self._draw_activation(ax, som_key, activation_map, winner, som_key == prediction[0])
            ax.set_xticks(range(activation_map.shape[1]))
            ax.set_yticks(range(activation_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
        fig.tight_layout()
        return fig