test_call2vec / app.py
FreshP's picture
Create app.py
4cd8155
raw
history blame
4.12 kB
import os
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
from gensim.models.fasttext import load_facebook_model
from huggingface_hub import cached_download, hf_hub_url
# Download the fastText model from the Hugging Face Hub and load it with gensim.
# cached_download returns the local path of the downloaded (cached) file; call
# it once and reuse that path instead of resolving the download a second time.
# NOTE(review): cached_download is deprecated in recent huggingface_hub
# releases (hf_hub_download is the replacement) — consider migrating.
url = hf_hub_url(repo_id="simonschoe/call2vec", filename="model.bin")
model_path = cached_download(url)
# Load the full fastText model (not just the word vectors) via gensim.
model = load_facebook_model(model_path)
# Column headers of the table returned by process() and shown in the UI.
_RESULT_COLUMNS = ['Token', 'Cosine Similarity', 'Frequency']


def _parse_tokens(text):
    """Split raw textbox input into model tokens, one per non-blank line.

    Each line is lower-cased and its whitespace-separated words are joined
    with underscores ("Cash  Flow " -> "cash_flow"). Blank lines are dropped
    so a trailing newline does not add an empty seed token.
    """
    candidates = ('_'.join(line.lower().split()) for line in text.splitlines())
    return [token for token in candidates if token]


def process(_input, topn, similar):
    """Return the nearest (or farthest) neighbours of one or more seed terms.

    Args:
        _input: Text from a Gradio textbox with one seed term per line; the
            words of a multi-word term are joined with underscores. ``None``
            is treated like an empty string.
        topn: Number of neighbours to return. Coerced to ``int`` in case the
            slider delivers a float.
        similar: ``"Dissimilar"`` ranks terms by lowest cosine similarity to
            the seed(s); any other value (including ``None``) ranks by highest.

    Returns:
        A pandas DataFrame with columns "Token", "Cosine Similarity" and
        "Frequency" (the token's ``count`` attribute in the model vocabulary).
        The frame is empty when the input contains no seed terms.
    """
    tokens = _parse_tokens(_input or '')
    if not tokens:
        # Nothing to look up; don't query the model with an empty token.
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    if len(tokens) > 1:
        # Several seeds: query with the plain (unnormalised) mean of their
        # vectors. gensim cannot tell which keys produced a raw vector, so the
        # seeds themselves may show up among the neighbours.
        query = np.stack([model.wv[token] for token in tokens], axis=0).mean(axis=0)
    else:
        # Single seed: query by key, which lets gensim leave the seed itself
        # out of the results.
        query = tokens[0]

    n_neighbors = int(topn)
    if similar == 'Dissimilar':
        nearest_neighbors = model.wv.most_similar(negative=query, topn=n_neighbors)
    else:
        nearest_neighbors = model.wv.most_similar(positive=query, topn=n_neighbors)

    rows = [
        (token, similarity, model.wv.get_vecattr(token, 'count'))
        for token, similarity in nearest_neighbors
    ]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)
def save(df):
    """Write a results table to CSV and return the file path for download.

    Each call writes into its own freshly created temporary directory, so
    concurrent sessions (or the two tabs) never overwrite each other's export,
    while the downloaded file keeps the name ``result.csv``.

    Args:
        df: pandas DataFrame taken from a ``gr.Dataframe`` component, or None
            when there is nothing to export.

    Returns:
        Path of the written CSV file, or None when ``df`` is None.
    """
    if df is None:
        return None
    # NOTE(review): these temp directories are never cleaned up here.
    out_path = os.path.join(tempfile.mkdtemp(), 'result.csv')
    # The index is only a row counter, not data, so leave it out of the CSV.
    df.to_csv(out_path, index=False)
    return out_path
# Gradio UI: shared controls and example prompts on the left, "Single" and
# "Multiple" seed tabs in the middle, and a project description on the right.
demo = gr.Blocks(theme="dark")
with demo:
    gr.Markdown("# Title")
    gr.Markdown("## Subtitle")
    with gr.Row():
        with gr.Column():
            # Controls shared by both tabs. Default to "Similar", which is also
            # what process() does when no option is selected.
            similar_radio = gr.Radio(choices=["Similar", "Dissimilar"], value="Similar")
            n_output = gr.Slider(minimum=5, maximum=50, step=1)
            # Continuation lines of this string stay at column 0 on purpose:
            # they are part of the Markdown text.
            gr.Markdown(
                """### Example prompts:
- Example 1
- Example 2
"""
            )
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Single"):
                    with gr.Column():
                        text_input = gr.Textbox(lines=1)
                        df_output = gr.Dataframe(interactive=False)
                    with gr.Row():
                        compute_button_s = gr.Button("Compute")
                        export_button_s = gr.Button("Export as CSV")
                    file_out_s = gr.File(interactive=False)
                with gr.TabItem("Multiple"):
                    with gr.Column():
                        # One seed term per line; process() averages them.
                        text_input_multiple = gr.Textbox(lines=3)
                        df_output_multiple = gr.Dataframe(interactive=False)
                    with gr.Row():
                        compute_button_m = gr.Button("Compute")
                        export_button_m = gr.Button("Export as CSV")
                    file_out_m = gr.File(interactive=False)
        with gr.Column():
            gr.Markdown("""
### Project Description
Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.""")

    # Event wiring (must happen inside the Blocks context). Each tab's buttons
    # read from and write to that tab's own components; the radio and slider
    # are shared by both tabs.
    compute_button_s.click(process, inputs=[text_input, n_output, similar_radio], outputs=df_output)
    compute_button_m.click(process, inputs=[text_input_multiple, n_output, similar_radio], outputs=df_output_multiple)
    export_button_s.click(save, inputs=[df_output], outputs=file_out_s)
    # The Multiple tab exports its own table through its own button and file
    # output.
    export_button_m.click(save, inputs=[df_output_multiple], outputs=file_out_m)

demo.launch()