File size: 4,778 Bytes
aa4694e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92802a1
 
aa4694e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from matplotlib.colors import LinearSegmentedColormap
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI, AuthenticationError, RateLimitError
from dotenv import load_dotenv
import os


# Populate the process environment via python-dotenv, then build the one
# shared OpenAI client that calculate_embeddings() uses for every request.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
oai_client = OpenAI(api_key=openai_api_key)


def calculate_embeddings(words, model="text-embedding-3-small"):
    """Return one OpenAI embedding vector per entry in *words*.

    All words are sent in a single ``embeddings.create`` request through the
    module-level ``oai_client``.

    Args:
        words: List of strings to embed.
        model: OpenAI embedding model name. Defaults to the model this app
            was originally written against.

    Returns:
        A list of embedding vectors, ordered like *words*.
    """
    response = oai_client.embeddings.create(input=words, model=model)
    # Each returned item carries the position of its input; sorting on it pins
    # the output order to the input order rather than to the order of `data`.
    ordered = sorted(response.data, key=lambda item: item.index)
    return [item.embedding for item in ordered]


def process_array(arr):
    """Turn a square similarity matrix into a compact, masked triangle.

    Keeps the strictly-upper triangle (each pair once, no self-similarity),
    flips it left-right, and drops the last row and column, which hold no
    upper-triangle values after the flip. Row ``i`` of the result pairs item
    ``i`` with items ``n-1, n-2, ..., i+1`` across the columns.

    Args:
        arr: Square 2-D array of shape ``(n, n)``.

    Returns:
        A ``numpy.ma.MaskedArray`` of shape ``(n-1, n-1)``. Cell ``(i, j)`` is
        unmasked iff ``i + j < n - 1``; all other cells are masked.

    Raises:
        ValueError: If ``arr`` is not a square 2-D array.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input must be a square array")

    n = arr.shape[0]

    # Step 1: Keep only the upper triangle (excluding diagonal)
    upper_triangle = np.triu(arr, k=1)

    # Step 2: Reverse horizontally
    reversed_upper_triangle = np.fliplr(upper_triangle)

    # Step 3: Drop the final row and column (they hold no upper-triangle cells)
    result = reversed_upper_triangle[:-1, :-1]

    # Step 4: Mask by position, not by value. A real similarity of exactly 0
    # must stay visible, so mask only the cells that came from the diagonal or
    # lower triangle, i.e. those with row + col >= n - 1.
    rows, cols = np.indices(result.shape)
    return np.ma.masked_array(result, mask=(rows + cols) >= n - 1)


def plot_heatmap(masked_result, l1: list[str]):
    """Render the compact similarity triangle from ``process_array`` as a heatmap.

    Args:
        masked_result: Square masked array of shape ``(n, n)``. Masked cells
            are drawn white and left unannotated.
        l1: The word list the matrix was built from (``n + 1`` entries). Row
            ``i`` is labelled ``l1[i]``; column ``j`` is labelled
            ``l1[len(l1) - 1 - j]``.

    Returns:
        The matplotlib ``Figure`` containing the heatmap.
    """
    n, _ = masked_result.shape

    # Increased figure size for better visibility
    fig, ax = plt.subplots(figsize=(12, 10))

    # Diverging colormap: -1 -> dark red, 0 -> light gray, +1 -> blue.
    colors = ["darkred", "lightgray", "dodgerblue"]
    n_bins = 100
    cmap = LinearSegmentedColormap.from_list("custom", colors, N=n_bins)
    cmap.set_bad("white")  # masked cells are drawn white

    # Fixed value range so colors are comparable across inputs.
    im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)

    # Annotate every visible cell with its value.
    for i in range(n):
        for j in range(n):
            if not np.ma.is_masked(masked_result[i, j]):
                ax.text(
                    j,
                    i,
                    f"{masked_result[i, j]:.2f}",
                    ha="center",
                    va="center",
                    color="black",
                )

    # Row i compares l1[i]; column j compares l1[-1 - j] (columns are reversed).
    ax.set_yticks(range(n))
    ax.set_yticklabels(l1[:-1])
    ax.set_xticks(range(n))
    ax.set_xticklabels(list(reversed(l1[1:])))

    # Move x-axis to the top
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    # Rotate x-axis labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")

    # Use explicit fig/ax calls rather than plt.colorbar/plt.title/plt.tight_layout:
    # pyplot's "current figure" is process-global, so implicit calls can hit a
    # different figure if several plots are built at the same time.
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_ticks([-1, 0, 1])
    cbar.set_ticklabels(["-1", "0", "1"])

    ax.set_title("Correlation Heatmap", pad=20)

    fig.tight_layout()
    return fig


def plot_pca(embeddings, words):
    """Scatter-plot *embeddings* on their first two principal components.

    Args:
        embeddings: One embedding vector per word. PCA with two components
            needs at least two of them.
        words: Labels to annotate next to each projected point, in the same
            order as *embeddings*.

    Returns:
        The matplotlib ``Figure`` containing the labelled scatter plot.
    """
    pca = PCA(n_components=2)
    embeddings_2d = pca.fit_transform(embeddings)

    # Create exactly one figure. The previous version also called
    # plt.subplots(figsize=(12, 10)) first and discarded it, which left an
    # extra never-closed pyplot figure behind on every call.
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])

    for i, word in enumerate(words):
        ax.annotate(word, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    ax.set_title("PCA of Word Embeddings")
    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")
    fig.tight_layout()
    return fig


def word_similarity_heatmap(input_text):
    """Gradio handler: build the similarity heatmap and PCA plot for *input_text*.

    Args:
        input_text: Comma-separated words from the textbox.

    Returns:
        A ``(heatmap_figure, pca_figure)`` tuple, one per ``gr.Plot`` output.

    Raises:
        gr.Error: If fewer than two non-empty words were entered.
        AuthenticationError, RateLimitError, Exception: Re-raised unchanged
            after printing a message when embedding or plotting fails.
    """
    # Drop empty entries (e.g. from "a, b," or ",,") so blank strings are
    # neither embedded nor counted toward the two-word minimum.
    words = [word.strip() for word in (input_text or "").split(",") if word.strip()]

    if len(words) < 2:
        # The interface declares two Plot outputs, so a bare string cannot be
        # rendered; gr.Error surfaces the message in the UI instead.
        raise gr.Error("Please enter at least two words.")

    try:
        embeddings = calculate_embeddings(words)
        similarities = cosine_similarity(embeddings)
        new_array = process_array(similarities)
        heatmap = plot_heatmap(new_array, words)
        pca_plot = plot_pca(embeddings, words)
    except AuthenticationError:
        print("OpenAI API key is invalid. Please check your API key.")
        raise
    except RateLimitError:
        print("OpenAI API rate limit exceeded. Please try again later.")
        raise
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise
    return heatmap, pca_plot


# Web UI: one multi-line textbox in, two matplotlib plots out, matching the
# (heatmap, pca_plot) tuple returned by word_similarity_heatmap.
iface = gr.Interface(
    fn=word_similarity_heatmap,
    inputs=gr.Textbox(lines=2, placeholder="Enter words separated by commas"),
    outputs=[gr.Plot(label="Similarity Heatmap"), gr.Plot(label="PCA Plot")],
    title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
    description=(
        "Enter a list of words separated by commas. "
        "The app will calculate the cosine similarity between their OpenAI embeddings, "
        "display a compact heatmap of the upper triangle similarities, "
        "and show a PCA plot of the embeddings."
    ),
)

# Launch the app.
# NOTE(review): share=True asks Gradio for a public share link; anyone with the
# URL could trigger embedding calls billed to this OpenAI key. Confirm this is intended.
iface.launch(share=True)