File size: 4,247 Bytes
4cd8155
 
 
5661523
7e21fe0
4cd8155
 
 
 
7e21fe0
 
4cd8155
 
 
 
 
 
 
5661523
4cd8155
 
5661523
4cd8155
 
5661523
 
7e21fe0
5661523
 
 
4cd8155
 
 
 
5661523
 
4cd8155
 
5661523
 
4cd8155
 
 
7e21fe0
5661523
 
 
 
5bf70e2
 
4cd8155
 
 
 
 
d0631f5
4cd8155
 
5661523
 
d0631f5
 
5661523
d0631f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cd8155
d0631f5
 
 
 
 
 
 
4cd8155
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import numpy as np
import pandas as pd
from datetime import datetime
import os

from huggingface_hub import hf_hub_url, cached_download
from gensim.models.fasttext import load_facebook_model

# Secret read from the environment; `process` compares the first input line
# against it. Is None when the ACCESS_KEY variable is not set.
ACCESS_KEY = os.environ.get('ACCESS_KEY')

# download model from huggingface hub (fetched once; the returned local path
# is reused below instead of calling cached_download a second time)
url = hf_hub_url(repo_id="simonschoe/call2vec", filename="model.bin")
model_path = cached_download(url)

# load model via gensim
model = load_facebook_model(model_path)

def _read_prompt_log(path='log.txt'):
    """Return the prompt log at *path* as a DataFrame with columns Time and Prompt."""
    rows = []
    try:
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # partition splits on the FIRST '+++' only, so a prompt that
                # itself contains '+++' still lands in a single column
                timestamp, _, prompt = line.partition('+++')
                rows.append((timestamp, prompt))
    except FileNotFoundError:
        # nothing has been logged yet -> empty table instead of a crash
        pass
    return pd.DataFrame(rows, columns=['Time', 'Prompt'])


def _log_prompt(tokens, path='log.txt'):
    """Append one timestamped line with the '___'-joined *tokens* to the log at *path*."""
    with open(path, 'a') as f:
        f.write(str(datetime.now()) + '+++' + '___'.join(tokens) + '\n')


def process(_input, topn):
    """Look up the nearest neighbours of one or more prompts in the embedding model.

    Each non-empty line of ``_input`` is one prompt. Prompts are stripped,
    lower-cased and their spaces replaced by underscores. A single prompt is
    queried directly; several prompts are averaged into one seed vector first.

    If the first prompt equals ``ACCESS_KEY`` the prompt log is returned
    instead of model output (and the request itself is not logged).
    NOTE: the comparison happens after normalisation, so a key containing
    upper-case letters or spaces can never match.

    Args:
        _input: Raw textbox content, prompts separated by newlines.
        topn: Number of neighbours to return (coerced to ``int``).

    Returns:
        A ``(DataFrame, 'result.csv')`` tuple; the DataFrame is also written
        to ``result.csv``. Empty input yields an empty result table.
    """
    # normalise per line so trailing/leading blanks do not become underscores
    tokens = [
        line.strip().lower().replace(' ', '_')
        for line in (_input or '').splitlines()
    ]
    tokens = [t for t in tokens if t]

    if not tokens:
        # nothing to query: hand back an empty table rather than raising IndexError
        result = pd.DataFrame(columns=['Token', 'Cosine Similarity', 'Frequency'])
    elif tokens[0] == ACCESS_KEY:
        # admin path: show the log, skip the (discarded) model query
        result = _read_prompt_log()
    else:
        _log_prompt(tokens)

        if len(tokens) > 1:
            # compute average seed embedding
            query = np.stack([model.wv[w] for w in tokens], axis=0).mean(axis=0)
        else:
            query = tokens[0]

        # find similarities
        nearest_neighbors = model.wv.most_similar(positive=query, topn=int(topn))
        result = pd.DataFrame(
            [
                (token, similarity, model.wv.get_vecattr(token, 'count'))
                for token, similarity in nearest_neighbors
            ],
            columns=['Token', 'Cosine Similarity', 'Frequency'],
        )

    result.to_csv('result.csv')
    return result, 'result.csv'

def save(df):
    """Write *df* to ``result.csv`` via its ``to_csv`` method and return that file name."""
    destination = 'result.csv'
    df.to_csv(destination)
    return destination

# Top-level Gradio app: two tabs that both call `process`.
demo = gr.Blocks()

with demo:
    gr.Markdown("# Call2Vec")
    gr.Markdown("## Earnings call transformation project")
    with gr.Tabs():
        # Tab 1: hand-built three-column layout (options | input/output | description).
        with gr.TabItem(label='Block Interface'):
            with gr.Row():
                # Left column: input mode, number of neighbours, example prompts.
                with gr.Column():
                    similar_radio = gr.Radio(label="Single or multiple input prompts", value="Single", choices=["Single", "Multiple"])
                    # Passed to `process` as `topn`; integer steps between 5 and 50.
                    n_output = gr.Slider(minimum=5, maximum=50, step=1)
                    gr.Markdown(
                        """### Example prompts:
                        - Example 1
                        - Example 2
                        """
                    )
                # Middle column: prompt textbox, trigger button, result table and CSV download.
                with gr.Column():
                    text_input = gr.Textbox(lines=1)
                    with gr.Row():
                        compute_button = gr.Button("Compute")
                    df_output = gr.Dataframe(interactive=False)
                    file_out = gr.File(interactive=False)
                # Right column: project description (currently placeholder text).
                with gr.Column():
                    gr.Markdown("""
                        ### Project Description
                        Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.""")

            # Clicking "Compute" runs `process(text, topn)` and fills the table and the file widget.
            compute_button.click(process, inputs=[text_input, n_output], outputs=[df_output, file_out])
            # Switching the radio overwrites the textbox: seven newlines for "Multiple",
            # empty string otherwise (presumably to give room for one prompt per line — confirm).
            similar_radio.change(lambda x: "\n\n\n\n\n\n\n" if x=='Multiple' else "", inputs=[similar_radio], outputs=[text_input])
        # Tab 2: the same `process` function wrapped in a stock gr.Interface with two examples.
        with gr.TabItem('Traditional Interface'):
            gr.Interface(process, inputs=[gr.Textbox(lines=3), gr.Slider(minimum=5, maximum=50, step=1)],
                         outputs=[gr.Dataframe(interactive=False), gr.File(interactive=False)],
                         examples=[["Test example", 5],
                                   ["Multiple prompts\nexample", 7]])

# Starts the app at import/run time (no __main__ guard).
demo.launch()