"""Gradio app that visualises similarity between OpenAI word embeddings.

The user enters a comma-separated list of words. The app requests an
embedding for each word, computes pairwise cosine similarities, and shows
two plots: a compact upper-triangle heatmap and a 2-D PCA scatter plot.
"""

import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from matplotlib.colors import LinearSegmentedColormap
from openai import AuthenticationError, OpenAI, RateLimitError
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
oai_client = OpenAI(api_key=openai_api_key)


def calculate_embeddings(words):
    """Return one embedding vector per entry of *words*.

    Sends all words in a single request to the ``text-embedding-3-small``
    model and collects the ``embedding`` field of each returned item.
    """
    response = oai_client.embeddings.create(
        input=words, model="text-embedding-3-small"
    )
    return [item.embedding for item in response.data]


def process_array(arr):
    """Build the compact, mirrored upper triangle of a square matrix.

    The strict upper triangle (diagonal excluded) is flipped left-to-right
    and the last row and column, which then hold no upper-triangle entries,
    are dropped. Cells that do not come from the strict upper triangle are
    masked.

    Args:
        arr: Square 2-D array-like of shape (n, n).

    Returns:
        A masked array of shape (n - 1, n - 1).

    Raises:
        ValueError: If *arr* is not a square 2-D array.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input must be a square array")

    # Keep the strict upper triangle, mirror it horizontally, then trim the
    # final row and column.
    result = np.fliplr(np.triu(arr, k=1))[:-1, :-1]

    # Derive the mask from cell position, not from the cell value, so that a
    # genuine value of exactly 0 in the upper triangle stays visible.
    keep = np.fliplr(np.triu(np.ones(arr.shape, dtype=bool), k=1))[:-1, :-1]
    return np.ma.masked_array(result, mask=~keep)


def plot_heatmap(masked_result, l1: list[str]):
    """Draw the annotated heatmap for the output of ``process_array``.

    Args:
        masked_result: Square (masked) array of values to display; the colour
            scale is fixed to the range [-1, 1].
        l1: Labels. All but the last label the rows (top to bottom); all but
            the first, in reverse order, label the columns (left to right).

    Returns:
        The matplotlib figure containing the heatmap.
    """
    n, _ = masked_result.shape

    # Large figure so the per-cell annotations stay legible.
    fig, ax = plt.subplots(figsize=(12, 10))

    # Custom diverging colormap; masked cells are drawn white.
    colors = ["darkred", "lightgray", "dodgerblue"]
    n_bins = 100
    cmap = LinearSegmentedColormap.from_list("custom", colors, N=n_bins)
    cmap.set_bad("white")

    im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)

    # Write the numeric value into every unmasked cell.
    visible = ~np.ma.getmaskarray(masked_result)
    for i, j in np.argwhere(visible):
        ax.text(
            j,
            i,
            f"{masked_result[i, j]:.2f}",
            ha="center",
            va="center",
            color="black",
        )

    # Row labels skip the last word; column labels are the remaining words
    # in reverse order, matching the horizontal flip in process_array.
    ax.set_yticks(range(n))
    ax.set_yticklabels(l1[:-1])
    ax.set_xticks(range(n))
    ax.set_xticklabels(l1[:0:-1])

    # Move x-axis to the top and rotate its labels for readability.
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_ticks([-1, 0, 1])
    cbar.set_ticklabels(["-1", "0", "1"])

    ax.set_title("Correlation Heatmap", pad=20)

    fig.tight_layout()
    return fig


def plot_pca(embeddings, words):
    """Scatter-plot the first two principal components of *embeddings*.

    Args:
        embeddings: One embedding vector per word.
        words: Labels used to annotate the points, in the same order as
            *embeddings*.

    Returns:
        The matplotlib figure containing the annotated scatter plot.
    """
    embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

    # Create a single figure; the original opened a second, unused one that
    # was never closed.
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    for word, (x, y) in zip(words, embeddings_2d):
        ax.annotate(word, (x, y))

    ax.set_title("PCA of Word Embeddings")
    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")
    fig.tight_layout()
    return fig


def word_similarity_heatmap(input_text):
    """Gradio callback: turn comma-separated words into the two plots.

    Args:
        input_text: Comma-separated words typed by the user.

    Returns:
        A ``(heatmap_figure, pca_figure)`` tuple, matching the two
        ``gr.Plot`` outputs of the interface.

    Raises:
        gr.Error: If fewer than two non-empty words were entered.
        Exception: Any error raised while computing embeddings or plotting
            is logged to stdout and re-raised.
    """
    # Drop blank entries (e.g. from a trailing or doubled comma) so they are
    # neither counted nor sent to the embeddings API.
    stripped = (word.strip() for word in (input_text or "").split(","))
    words = [word for word in stripped if word]

    if len(words) < 2:
        # The interface declares two plot outputs, so a plain string cannot
        # be returned here; gr.Error surfaces the message to the user.
        raise gr.Error("Please enter at least two words.")

    try:
        embeddings = calculate_embeddings(words)
        similarities = cosine_similarity(embeddings)
        new_array = process_array(similarities)
        heatmap = plot_heatmap(new_array, words)
        pca_plot = plot_pca(embeddings, words)
        return heatmap, pca_plot
    except AuthenticationError:
        print("OpenAI API key is invalid. Please check your API key.")
        raise
    except RateLimitError:
        print("OpenAI API rate limit exceeded. Please try again later.")
        raise
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise


iface = gr.Interface(
    fn=word_similarity_heatmap,
    inputs=gr.Textbox(lines=2, placeholder="Enter words separated by commas"),
    outputs=[gr.Plot(label="Similarity Heatmap"), gr.Plot(label="PCA Plot")],
    title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
    description=(
        "Enter a list of words separated by commas. "
        "The app will calculate the cosine similarity between their OpenAI "
        "embeddings, display a compact heatmap of the upper triangle "
        "similarities, and show a PCA plot of the embeddings."
    ),
)

# Launch the app
iface.launch(share=True)