Spaces:
Sleeping
Sleeping
File size: 4,778 Bytes
aa4694e 92802a1 aa4694e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from matplotlib.colors import LinearSegmentedColormap
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI, AuthenticationError, RateLimitError
from dotenv import load_dotenv
import os
# Let python-dotenv populate the environment first, then read the API key
# from it and build the single OpenAI client shared by the rest of the module.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
oai_client = OpenAI(api_key=openai_api_key)
def calculate_embeddings(words):
    """Request one embedding per entry of ``words`` from the OpenAI API.

    The call goes through the module-level ``oai_client`` and uses the
    ``text-embedding-3-small`` model.

    Args:
        words: The inputs to embed, passed to the API unchanged.

    Returns:
        A list holding the ``embedding`` of every item in the API response,
        in the order the response lists them.
    """
    api_result = oai_client.embeddings.create(
        model="text-embedding-3-small",
        input=words,
    )
    vectors = []
    for item in api_result.data:
        vectors.append(item.embedding)
    return vectors
def process_array(arr):
    """Extract the strict upper triangle of a square matrix in compact, mirrored form.

    The entries above the main diagonal are kept, the matrix is flipped
    left-to-right, and the last row and last column are dropped. After the
    flip those two hold no entry from above the diagonal, so nothing is lost.
    For an ``n x n`` input the result is ``(n - 1) x (n - 1)``.

    Cells are masked by *position* (diagonal and lower triangle), not by
    value, so a genuine ``0`` above the diagonal stays visible instead of
    being mistaken for padding.

    Args:
        arr: Square 2-D array (or array-like) of pairwise values.

    Returns:
        A ``numpy.ma.MaskedArray`` whose unmasked cells are exactly the
        entries of ``arr`` that lie above the main diagonal.

    Raises:
        ValueError: If ``arr`` is not a square 2-D array.
    """
    arr = np.asarray(arr)
    # Checking ndim first keeps 1-D or scalar input on the documented
    # ValueError path rather than failing on ``shape[1]``.
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input must be a square array")

    # Values: everything on or below the diagonal is zeroed out.
    upper_triangle = np.triu(arr, k=1)
    # Positions: True exactly where a cell lies above the diagonal.
    in_triangle = np.triu(np.ones(arr.shape, dtype=bool), k=1)

    # Apply the same flip-and-trim to values and positions so they stay aligned.
    mirrored_values = np.fliplr(upper_triangle)[:-1, :-1]
    mirrored_keep = np.fliplr(in_triangle)[:-1, :-1]

    return np.ma.masked_array(mirrored_values, mask=~mirrored_keep)
def plot_heatmap(masked_result, l1: list[str]):
    """Render a masked square array as an annotated heatmap.

    Every colorbar, title and layout call goes through the ``fig`` / ``ax``
    objects created here rather than pyplot's implicit "current figure", so
    the result does not depend on which figure happens to be current when
    the function runs.

    Args:
        masked_result: Square masked array of values. Masked cells are drawn
            white and get no text annotation. Colors are scaled to the fixed
            range ``[-1, 1]``.
        l1: Labels of the compared items. Rows are labelled with ``l1[:-1]``
            and columns with ``l1[1:]`` in reverse order, so it is expected
            to hold one more entry than ``masked_result`` has rows.

    Returns:
        The matplotlib figure holding the heatmap and its colorbar.
    """
    n, _ = masked_result.shape

    # Large canvas so per-cell annotations stay legible.
    fig, ax = plt.subplots(figsize=(12, 10))

    # Diverging red -> gray -> blue map; masked cells fall back to white.
    cmap = LinearSegmentedColormap.from_list(
        "custom", ["darkred", "lightgray", "dodgerblue"], N=100
    )
    cmap.set_bad("white")

    im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)

    # Write the numeric value into every unmasked cell.
    for i in range(n):
        for j in range(n):
            value = masked_result[i, j]
            if not np.ma.is_masked(value):
                ax.text(
                    j,
                    i,
                    f"{value:.2f}",
                    ha="center",
                    va="center",
                    color="black",
                )

    ax.set_yticks(range(n))
    ax.set_yticklabels(l1[:-1])
    ax.set_xticks(range(n))
    # Materialize the reversed labels: the tick API expects a sequence.
    ax.set_xticklabels(list(reversed(l1[1:])))

    # Column labels go on top, tilted so longer words do not collide.
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_ticks([-1, 0, 1])
    cbar.set_ticklabels(["-1", "0", "1"])

    ax.set_title("Correlation Heatmap", pad=20)

    fig.tight_layout()
    return fig
def plot_pca(embeddings, words):
    """Scatter-plot the embeddings reduced to two principal components.

    Args:
        embeddings: One embedding vector per word, in the same order as
            ``words``.
        words: Labels used to annotate the plotted points.

    Returns:
        The matplotlib figure containing the annotated scatter plot.
    """
    # Reduce first: if PCA fails, no figure has been created yet.
    embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

    # Exactly one figure is created per call; an extra one that is never
    # returned would stay registered with pyplot and accumulate.
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    for i, word in enumerate(words):
        ax.annotate(word, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    ax.set_title("PCA of Word Embeddings")
    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")

    # Lay out this specific figure instead of pyplot's "current" one.
    fig.tight_layout()
    return fig
def word_similarity_heatmap(input_text):
    """Build the similarity heatmap and PCA plot for a comma-separated word list.

    Args:
        input_text: Words separated by commas. Surrounding whitespace is
            stripped and blank entries are ignored.

    Returns:
        A ``(heatmap, pca_plot)`` tuple of matplotlib figures.

    Raises:
        gr.Error: If fewer than two non-blank words were supplied. Gradio
            shows the message to the user.
        AuthenticationError, RateLimitError, Exception: Failures from the
            embedding or plotting steps are logged with ``print`` and
            re-raised unchanged.
    """
    # Blank entries (from "a,,b", a trailing comma, or empty input) would
    # otherwise count toward the two-word minimum and be embedded and
    # plotted as empty labels.
    words = [word.strip() for word in (input_text or "").split(",")]
    words = [word for word in words if word]

    # This handler feeds two Plot outputs, so a bare string cannot be
    # returned as a message; report the problem as an error instead.
    if len(words) < 2:
        raise gr.Error("Please enter at least two words.")

    try:
        embeddings = calculate_embeddings(words)
        similarities = cosine_similarity(embeddings)
        new_array = process_array(similarities)
        heatmap = plot_heatmap(new_array, words)
        pca_plot = plot_pca(embeddings, words)
    except AuthenticationError:
        print("OpenAI API key is invalid. Please check your API key.")
        raise
    except RateLimitError:
        print("OpenAI API rate limit exceeded. Please try again later.")
        raise
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise

    return heatmap, pca_plot
# Gradio UI: one free-text box in, two plots out (heatmap first, PCA second),
# matching the tuple order returned by word_similarity_heatmap.
iface = gr.Interface(
    fn=word_similarity_heatmap,
    inputs=gr.Textbox(placeholder="Enter words separated by commas", lines=2),
    outputs=[
        gr.Plot(label="Similarity Heatmap"),
        gr.Plot(label="PCA Plot"),
    ],
    title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
    description=(
        "Enter a list of words separated by commas. "
        "The app will calculate the cosine similarity between their OpenAI embeddings, "
        "display a compact heatmap of the upper triangle similarities, "
        "and show a PCA plot of the embeddings."
    ),
)

# Start the app with share=True.
iface.launch(share=True)
|