File size: 3,926 Bytes
8c497f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7702cd1
8c497f4
 
740a21c
 
 
8c497f4
 
 
 
740a21c
8c497f4
 
 
 
 
 
740a21c
 
 
8c497f4
 
 
 
 
 
 
 
740a21c
8c497f4
740a21c
8c497f4
 
 
 
 
 
 
 
 
 
 
 
740a21c
 
 
 
 
 
 
 
 
7702cd1
740a21c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7702cd1
740a21c
 
 
 
 
8c497f4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
from sklearn.decomposition import PCA
import gensim.downloader as api
import gradio as gr
import plotly.graph_objects as go

# Load the "word2vec-google-news-300" word vectors once, at import time.
# gradio_interface reads this module-level ``model`` directly and passes it to
# gensim_analogy / plot_words_plotly.
# NOTE(review): api.load presumably downloads and caches the dataset on the
# first run (it is large) — confirm the startup cost is acceptable.
model = api.load("word2vec-google-news-300")


def gensim_analogy(model, word1, word2, word3):
    """Solve the analogy ``word1 : word2 :: word3 : ?`` with vector arithmetic.

    Asks ``model.most_similar`` for the single word closest to
    ``word2 - word1 + word3`` (e.g. man : king :: woman : ?).

    Args:
        model: Word-vector model exposing ``most_similar`` (the gensim model
            loaded in this file).
        word1: Word whose vector is subtracted.
        word2: First word whose vector is added.
        word3: Second word whose vector is added.

    Returns:
        The best-matching word, or the KeyError message text when the model
        raises KeyError (e.g. a word is not in its vocabulary).
    """
    try:
        result = model.most_similar(positive=[word2, word3], negative=[word1], topn=1)
    except KeyError as e:
        # str(KeyError) returns the repr of its argument, which wraps the
        # message in an extra pair of quotes; show the raw message instead.
        return str(e.args[0]) if e.args else str(e)
    return result[0][0]  # Return the word


def plot_words_plotly(model, words):
    """Project word vectors to 2-D with PCA and return a Plotly scatter figure.

    Words missing from ``model.key_to_index`` are skipped. Each remaining word
    becomes its own labelled trace, so it also appears in the legend.

    Args:
        model: Word-vector model supporting ``key_to_index`` and ``model[word]``.
        words: Words to plot, in display order.

    Returns:
        A ``plotly.graph_objects.Figure``. It contains no traces when none of
        the words are in the vocabulary.
    """
    # Filter the words themselves (not only their vectors) so every label stays
    # paired with its own vector when some words are out of vocabulary.
    known_words = [word for word in words if word in model.key_to_index]

    # Create a scatter plot
    fig = go.Figure()

    if known_words:
        vectors = np.array([model[word] for word in known_words])

        if len(known_words) >= 2:
            # Reduce dimensions to 2D for plotting
            vectors_2d = PCA(n_components=2).fit_transform(vectors)
        else:
            # PCA(n_components=2) needs at least two samples; place a lone
            # word at the origin instead of raising.
            vectors_2d = np.zeros((1, 2))

        # Add scatter points for each word vector
        for word, vec in zip(known_words, vectors_2d):
            fig.add_trace(go.Scatter(x=[vec[0]], y=[vec[1]],
                                     text=[word], mode='markers+text',
                                     textposition="bottom center",
                                     name=word))

    fig.update_layout(title="Visualization of Word Vectors",
                      xaxis_title="PCA 1",
                      yaxis_title="PCA 2",
                      showlegend=True,
                      width=600,  # Adjust width as needed
                      height=400)  # Adjust height as needed

    return fig


def gradio_interface(choice, custom_input):
    """Handle a Submit click: solve the chosen analogy and build the outputs.

    Args:
        choice: Selected radio option — a predefined comma-separated triplet,
            "Custom", or empty/None when nothing is selected.
        custom_input: Free-text triplet; only read when ``choice == "Custom"``.

    Returns:
        A 3-tuple for (output word, plot, vectorization text):
        ``(word4, figure, vector_display)`` on success, or
        ``(message, None, {"error": "Invalid input"})`` when the input is
        malformed or contains words outside the model's vocabulary.
    """
    if choice == "Custom":
        # Accept "a,b,c" and "a , b ,c" as well as "a, b, c": split on bare
        # commas, trim whitespace and drop empty entries.
        words = [word.strip() for word in (custom_input or "").split(",") if word.strip()]
        if len(words) != 3:
            return "Invalid input. Please enter exactly three words, separated by commas.", None, {
                "error": "Invalid input"}
    else:
        if not choice:
            return "Invalid input. Please select or enter words.", None, {
                "error": "Invalid input"}
        words = [word.strip() for word in choice.split(",")]

    # Report out-of-vocabulary words up front instead of letting an error
    # string flow through as if it were the answer (and into the plot).
    missing = [word for word in words if word not in model.key_to_index]
    if missing:
        return f"Not in vocabulary: {', '.join(missing)}", None, {"error": "Invalid input"}

    word1, word2, word3 = words
    word4 = gensim_analogy(model, word1, word2, word3)
    plot_fig = plot_words_plotly(model, [word1, word2, word3, word4])

    if word4 in model.key_to_index:
        vector = model[word4]
        vector_display = f"{word4}: {np.round(vector, 2).tolist()}"
    else:
        vector_display = "Vector not available for the resulting word"

    return word4, plot_fig, vector_display


# Options for the radio button. Each predefined entry is a comma-separated
# triplet "word1, word2, word3", solved as word1 : word2 :: word3 : ?
# (e.g. man : king :: woman : ?). Selecting "Custom" makes gradio_interface
# read the free-text box instead.
choices = [
    "man, king, woman",
    "Paris, France, London",
    "strong, stronger, weak",
    "pork, pig, beef",
    "Custom"
]


def clear_inputs():
    """Produce blank values for the Clear button.

    Four empty strings (radio, custom words, output word, vectorization text)
    followed by ``None`` for the plot, in the order the Clear button's
    outputs are wired.
    """
    blank_text = ""
    return (blank_text,) * 4 + (None,)


# Define the layout using Rows and Columns.
# First row: a column with the title, instructions, inputs (radio + custom
# textbox), Clear/Submit buttons and the output word, with the plot beside it.
# Second row: the text box showing the output word's vector.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Word Analogy and Vector Visualization")
            gr.Markdown(
                "Select a predefined triplet of words or choose 'Custom' and enter your own (comma-separated) to find a fourth word by analogy, and see their vectors plotted with Plotly.")

            radio = gr.Radio(choices=choices, label="Choose predefined words or enter custom words")

            # Only consulted by gradio_interface when the radio is "Custom".
            custom_words = gr.Textbox(
                label="Custom words (comma-separated, required for custom choice; use only if 'Custom' is selected)",
                placeholder="Enter 3 words separated by commas")

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

            output_word = gr.Textbox(label="Output Word")

        word_plot = gr.Plot(label="Word Vectors Visualization")

    with gr.Row():
        word_vectorization = gr.Textbox(label="Vectorization of the Output Word", lines=4, max_lines=4)

    # Each outputs list must match, in order, the values its handler returns:
    # clear_inputs returns 5 values, gradio_interface returns 3.
    clear_btn.click(fn=clear_inputs, inputs=None,
                    outputs=[radio, custom_words, output_word, word_vectorization, word_plot])
    submit_btn.click(fn=gradio_interface, inputs=[radio, custom_words],
                     outputs=[output_word, word_plot, word_vectorization])

# share=True presumably asks Gradio for a public share link in addition to the
# local server. NOTE(review): this also runs whenever the module is imported —
# consider an `if __name__ == "__main__":` guard if that matters.
iface.launch(share=True)