File size: 5,724 Bytes
091be34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import pickle
import torch
import random
import numpy as np
import pandas as pd
import gradio as gr
from setfit import SetFitModel
from huggingface_hub import login, hf_hub_download
# Authenticate with the Hugging Face Hub using the token stored in the
# `hf_token` environment variable (None when the variable is not set).
hf_token = os.environ.get('hf_token')
login(token=hf_token)

def _load_pickle_from_hub(repo_id, filename):
    """Download ``filename`` from the Hub repo ``repo_id`` and unpickle it."""
    local_path = hf_hub_download(repo_id, filename=filename)
    with open(local_path, "rb") as file:
        # SECURITY NOTE(review): pickle.load can execute arbitrary code.
        # Presumably acceptable only because these repos are trusted/owned by
        # the app author -- confirm, and never point repo_id at untrusted repos.
        return pickle.load(file)


def prepare_setfit_model(repo_id):
    """Load a SetFit classifier and its label metadata from a Hub repository.

    Args:
        repo_id: Hugging Face Hub repository id that holds the SetFit model
            together with ``id2cat.pkl``, ``cat2id.pkl`` and ``cat_name.csv``.

    Returns:
        A ``(model, id2cat, cat2id, df_cat)`` tuple: the loaded
        ``SetFitModel``, the two unpickled mappings, and ``cat_name.csv``
        read into a pandas DataFrame.

    Raises:
        ValueError: if ``repo_id`` is None or empty (for example when the
            environment variable it was read from is not set).
    """
    if not repo_id:
        # Callers pass os.getenv(...), which yields None for a missing
        # variable; fail with a clear message instead of an opaque error
        # from deep inside from_pretrained.
        raise ValueError(
            'prepare_setfit_model: repo_id is empty; '
            'check that the repo-id environment variable is set'
        )
    model = SetFitModel.from_pretrained(repo_id)
    id2cat = _load_pickle_from_hub(repo_id, 'id2cat.pkl')
    cat2id = _load_pickle_from_hub(repo_id, 'cat2id.pkl')
    df_cat = pd.read_csv(hf_hub_download(repo_id, filename='cat_name.csv'))
    return model, id2cat, cat2id, df_cat

# Load the cat1, cat2 and CID classifiers. Every repo ships its own
# cat_name.csv, but only the table from the CID repo is kept as `df_cat`;
# the first two are discarded. NOTE(review): presumably all three CSVs are
# the same category table -- verify, since id2catname is built from it for
# all three levels.
cat1_model, cat1_id2cat, cat1_cat2id, _ = prepare_setfit_model(os.environ.get('cat1_repo_id'))
cat2_model, cat2_id2cat, cat2_cat2id, _ = prepare_setfit_model(os.environ.get('cat2_repo_id'))
cid_model, cid_id2cat, cid_cat2id, df_cat = prepare_setfit_model(os.environ.get('cid_repo_id'))


# -

def model_predict(model, sentence):
    """Score ``sentence`` with ``model`` and rank the classes by probability.

    Args:
        model: a classifier exposing ``predict_proba`` that returns a torch
            tensor (as the ``SetFitModel`` instances in this file do).
        sentence: the input passed unchanged to ``predict_proba``.
            NOTE(review): assumes a single sentence yields a 1-D probability
            vector -- confirm for the installed SetFit version.

    Returns:
        ``(sorted_ids, sorted_probs)``: NumPy arrays of class indices and
        their probabilities, both in descending order of probability.
    """
    with torch.no_grad():
        # Inference only; no autograd graph is needed.
        probs = model.predict_proba(sentence).detach().cpu().numpy()
    # Sort once: argsort ascending, reversed for descending order, then
    # gather the probabilities with the same indices so ids and probs always
    # line up (the original sorted the array a second time with np.sort).
    sorted_ids = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_ids]
    return sorted_ids, sorted_probs


# +
def run_prediction(sentence, state):
    """Rank every category level for ``sentence`` and refresh the dropdowns.

    Runs the cat1, cat2 and CID classifiers. Each ranked class index is
    translated to a category id through that model's ``id2cat`` mapping and
    then to a display name through ``id2catname``. Ids missing from
    ``id2catname`` become None via ``dict.get``. Each list is prefixed with a
    ``'select'`` placeholder. The name lists and sorted probabilities are
    written into ``state`` so the filter callbacks can reuse them.

    Returns:
        Three ``gr.Dropdown.update`` objects (cat1, cat2, CID) followed by
        the mutated ``state`` dict.
    """
    level_models = (cat1_model, cat2_model, cid_model)
    level_id2cats = (cat1_id2cat, cat2_id2cat, cid_id2cat)
    level_state_keys = (
        ('cat1_names', 'sorted_cat1_probs'),
        ('cat2_names', 'sorted_cat2_probs'),
        ('cid_names', 'sorted_cid_probs'),
    )

    # Run every classifier before anything else is computed or mutated.
    ranked = [model_predict(level_model, sentence) for level_model in level_models]

    # Class index -> category id (per-model mapping) -> display name.
    name_lists = [
        ['select'] + [id2catname.get(level_id2cat[idx]) for idx in ranked_ids]
        for level_id2cat, (ranked_ids, _ranked_probs) in zip(level_id2cats, ranked)
    ]

    for (names_key, probs_key), names, (_ranked_ids, ranked_probs) in zip(
        level_state_keys, name_lists, ranked
    ):
        state[names_key] = names
        state[probs_key] = ranked_probs

    cat1_update, cat2_update, cid_update = (
        gr.Dropdown.update(choices=names, value=names[0], interactive=True)
        for names in name_lists
    )
    return cat1_update, cat2_update, cid_update, state

def filter_cat2(cat1_name, state):
    """Restrict the cat2 dropdown to children of the selected cat1 category.

    Keeps the model's ranking order from ``state['cat2_names']`` (written by
    ``run_prediction``), drops names that ``parent_cat_map`` does not list
    as children of ``cat1_name``, and removes duplicates.

    Args:
        cat1_name: the display name chosen in the cat1 dropdown.
        state: the shared prediction-state dict.

    Returns:
        A ``gr.Dropdown.update`` for the cat2 dropdown and the unchanged
        ``state``. When ``cat1_name`` has no known children (e.g. the
        ``'select'`` placeholder), the dropdown is disabled with a single
        ``'None'`` choice, the same way ``filter_cid`` handles it.
    """
    children = parent_cat_map.get(cat1_name)
    if children is None:
        # Without this guard the membership test below would evaluate
        # `name in None` and raise TypeError.
        return gr.Dropdown.update(
            choices=['None'], value='None', interactive=False
        ), state
    allowed = set(children)
    # Set lookups replace O(n) list scans. dict.fromkeys de-duplicates while
    # preserving the model's ranking order.
    ranked = [name for name in state.get('cat2_names', []) if name in allowed]
    cat2_names = ['select'] + list(dict.fromkeys(ranked))
    return gr.Dropdown.update(
        choices=cat2_names, value=cat2_names[0], interactive=True
    ), state

def filter_cid(cat2_name, state):
    """Restrict the CID dropdown to children of the selected cat2 category.

    Keeps the model's ranking order from ``state['cid_names']`` (written by
    ``run_prediction``), drops names that ``parent_cat_map`` does not list
    as children of ``cat2_name``, and removes duplicates.

    Args:
        cat2_name: the display name chosen in the cat2 dropdown.
        state: the shared prediction-state dict (read only).

    Returns:
        A ``gr.Dropdown.update`` for the CID dropdown. When ``cat2_name`` has
        no known children, the dropdown is disabled with a single ``'None'``
        choice.
    """
    children = parent_cat_map.get(cat2_name)
    if children is None:
        return gr.Dropdown.update(
            choices=['None'], value='None', interactive=False
        )
    allowed = set(children)
    # Set lookups replace O(n) list scans. dict.fromkeys de-duplicates while
    # preserving the model's ranking order.
    ranked = [name for name in state.get('cid_names', []) if name in allowed]
    cid_names = ['select'] + list(dict.fromkeys(ranked))
    return gr.Dropdown.update(
        choices=cid_names, value=cid_names[0], interactive=True
    )

# def predict_with_title_and_description(title, description):
#     temp_list = list(locations.keys())
#     random.shuffle(temp_list)
#     countries = ['select'] + temp_list
#     return gr.Dropdown.update(
#                 choices=countries, value=countries[0], interactive=True
#             )

# Self-join the category table so each row also carries its parent's name.
# The left join leaves parent_name as NaN where parent_id has no matching id.
parent_cat = df_cat[['id', 'name']].rename(
    columns={'id': 'temp_id', 'name': 'parent_name'}
)
df_cat = (
    df_cat
    .merge(parent_cat, how='left', left_on='parent_id', right_on='temp_id')
    .drop(columns='temp_id')
)

# Category id -> display name; used to label the dropdown choices.
id2catname = dict(zip(df_cat['id'], df_cat['name']))

# Parent category name -> list of its child category names, in table order.
# The filter_* callbacks use it to narrow one dropdown level to the next.
parent_cat_map = {}
for record in df_cat[['parent_name', 'name']].to_dict(orient='records'):
    parent_cat_map.setdefault(record['parent_name'], []).append(record['name'])

with gr.Blocks() as demo:
    # Dict written by run_prediction and read by the filter_* callbacks.
    prediction_results = gr.State({})
    with gr.Tab(label="Predict by title") as t1:
        title = gr.Textbox(label='Service Title', placeholder='Please enter service title')
        d1 = gr.Dropdown(choices=[], label="Cat 1")
        d2 = gr.Dropdown(choices=[], label='Cat 2')
        d3 = gr.Dropdown(choices=[], label="CID")
        b1 = gr.Button()

        # Button: run all three classifiers and populate every dropdown.
        b1.click(
            run_prediction,
            inputs=[title, prediction_results],
            outputs=[d1, d2, d3, prediction_results],
        )
        # Choosing a category narrows the next level to its children.
        d1.select(
            filter_cat2,
            inputs=[d1, prediction_results],
            outputs=[d2, prediction_results],
        )
        d2.select(
            filter_cid,
            inputs=[d2, prediction_results],
            outputs=d3,
        )
# NOTE(review): disabled second tab, kept for reference. It references
# names not defined in the visible code (locations, filter_states, filter_cities).
#     with gr.Tab(label="Predict by title and description") as t2:
#         title = gr.Textbox(label='Service Title', placeholder='Please enter service title')
#         description = gr.Textbox(label='Service Description', placeholder="Please enter service description")
#         d1 = gr.Dropdown(choices = list(locations.keys()), label="Country")
#         d2 = gr.Dropdown(choices = list(), label='State')
#         d3 = gr.Dropdown(choices = list(), label="City")

#         b1 = gr.Button()
#         b1.click(predict_with_title_and_description, [title, description], d1)
#         d1.change(filter_states, d1, d2)
#         d2.change(filter_cities, [d1, d2], d3)

# Bound the request queue to 5 pending jobs, then start the web server.
demo.queue(max_size=5).launch()