File size: 2,887 Bytes
8c497f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
from sklearn.decomposition import PCA
import gensim.downloader as api
import gradio as gr
import plotly.graph_objects as go

# Load the pretrained Word2Vec model once, at import time. It is kept in the
# module-level name ``model`` and reused by gradio_interface on every call.
model = api.load("word2vec-google-news-300")


def gensim_analogy(model, word1, word2, word3):
    """Solve the analogy "word1 is to word2 as word3 is to ?".

    Asks the model for the single nearest neighbour of
    ``word2 + word3 - word1`` via ``most_similar``.

    Args:
        model: Object exposing ``most_similar(positive=, negative=, topn=)``.
        word1: Word passed in the ``negative`` list.
        word2: First word passed in the ``positive`` list.
        word3: Second word passed in the ``positive`` list.

    Returns:
        The best-matching word, or ``str(error)`` of the ``KeyError`` raised
        while answering the query.
    """
    try:
        best_match = model.most_similar(
            positive=[word2, word3],
            negative=[word1],
            topn=1,
        )[0]
        # Each match is indexed as (word, ...); only the word is returned.
        return best_match[0]
    except KeyError as lookup_error:
        return str(lookup_error)


def plot_words_plotly(model, words):
    """Plot the given words as labelled points in a 2-D PCA projection.

    Args:
        model: Embedding lookup supporting ``word in model.key_to_index`` and
            ``model[word]``.
        words: Sequence of words to plot. Words that are not in the model's
            vocabulary are skipped.

    Returns:
        A ``plotly.graph_objects.Figure`` with one labelled scatter trace per
        in-vocabulary word.
    """
    # Filter the words themselves (not only their vectors) so that every
    # label is paired with its own vector even when some words are unknown.
    known_words = [word for word in words if word in model.key_to_index]

    if len(known_words) >= 2:
        vectors = np.array([model[word] for word in known_words])
        # Reduce dimensions to 2D for plotting
        vectors_2d = PCA(n_components=2).fit_transform(vectors)
    else:
        # A two-component PCA cannot be fitted on fewer than two samples, so
        # place the lone word (if any) at the origin instead of raising.
        vectors_2d = np.zeros((len(known_words), 2))

    # Create a scatter plot
    fig = go.Figure()

    # Add scatter points for each word vector
    for word, vec in zip(known_words, vectors_2d):
        fig.add_trace(go.Scatter(x=[vec[0]], y=[vec[1]],
                                 text=[word], mode='markers+text',
                                 textposition="bottom center",
                                 name=word))

    fig.update_layout(title="Word Vectors Visualization",
                      xaxis_title="PCA 1",
                      yaxis_title="PCA 2",
                      showlegend=True)

    return fig


def gradio_interface(choice, custom_input=None):
    """Resolve the selected word triplet, solve the analogy and build outputs.

    Args:
        choice: Dropdown value: a predefined comma-separated triplet or the
            literal ``"Custom"``. ``None`` (nothing selected) is treated as
            invalid input.
        custom_input: Textbox value with three comma-separated words; only
            read when ``choice`` is ``"Custom"``.

    Returns:
        A 3-tuple ``(text, figure, data)``. ``text`` is the analogy result
        (or the message returned by ``gensim_analogy`` on a lookup error),
        ``figure`` is the Plotly figure of the four words, and ``data`` maps
        the result word to its vector rounded to two decimals. For invalid
        input the tuple is ``(error message, None, {"error": ...})``.
    """
    raw = custom_input if choice == "Custom" else choice
    raw = raw or ""  # None / empty selection falls through to the error path

    # Try the original ", " separator first so inputs that already worked
    # parse exactly as before; otherwise accept bare commas as the prompt
    # only promises "separated by commas".
    parts = raw.split(", ")
    if len(parts) != 3:
        parts = raw.split(",")
    words = [part.strip() for part in parts]

    if len(words) != 3 or not all(words):
        return "Invalid input. Please enter exactly three words, separated by commas.", None, {
            "error": "Invalid input"}

    word1, word2, word3 = words
    word4 = gensim_analogy(model, word1, word2, word3)
    plot_fig = plot_words_plotly(model, [word1, word2, word3, word4])

    # word4 may be an error message rather than a vocabulary word, so check
    # before looking up its vector.
    if word4 in model.key_to_index:
        vector = model[word4]
        vector_display = {word4: [round(num, 2) for num in vector.tolist()]}
    else:
        vector_display = {"error": "Vector not available for the resulting word"}

    return word4, plot_fig, vector_display


# Dropdown options. Each predefined entry is a ", "-separated word triplet
# that gradio_interface splits into (word1, word2, word3); the literal
# "Custom" makes it read the triplet from the textbox instead.
choices = [
    "man, king, woman",
    "Paris, France, London",
    "strong, stronger, weak",
    "pork, pig, beef",
    "Custom"
]

# Wire gradio_interface into a form with two inputs (dropdown + free-text
# box, matching its ``choice`` and ``custom_input`` parameters) and three
# outputs matching its return tuple: result text, Plotly figure, JSON data.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=choices, label="Choose predefined words or enter custom words"),
        gr.Textbox(label="Custom words (comma-separated, required for custom choice; use only if 'Custom' is selected)",
                   placeholder="Enter 3 words separated by commas")
    ],
    outputs=["text", "plot", "json"],
    title="Word Analogy and Vector Visualization with Plotly",
    description="Select a predefined triplet of words or choose 'Custom' and enter your own (comma-separated) to find a fourth word by analogy, and see their vectors plotted with Plotly."
)

# Runs unconditionally when this module is executed or imported.
# NOTE(review): share=True presumably requests a public share link from
# Gradio — confirm against the Gradio launch() documentation.
iface.launch(share=True)