Spaces:
Sleeping
Sleeping
File size: 5,724 Bytes
091be34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import pickle
import torch
import random
import numpy as np
import pandas as pd
import gradio as gr
from setfit import SetFitModel
from huggingface_hub import login, hf_hub_download
# Read the Hugging Face access token from the `hf_token` environment variable
# and authenticate this process with the Hub before any download happens.
# NOTE(review): presumably required because the model repos are private - confirm.
hf_token = os.environ.get('hf_token')
login(token=hf_token)
def _load_pickle_from_hub(repo_id, filename):
    """Download ``filename`` from ``repo_id`` and return the unpickled object."""
    local_path = hf_hub_download(repo_id, filename=filename)
    # SECURITY: unpickling can execute arbitrary code. Only point this at
    # repositories whose contents are fully trusted.
    with open(local_path, "rb") as file:
        return pickle.load(file)


def prepare_setfit_model(repo_id):
    """Load a SetFit model and its category lookup files from a Hub repo.

    Args:
        repo_id: Hugging Face repository id holding the model together with
            ``id2cat.pkl``, ``cat2id.pkl`` and ``cat_name.csv``.

    Returns:
        A ``(model, id2cat, cat2id, df_cat)`` tuple: the loaded
        ``SetFitModel``, the two unpickled lookup objects, and the contents
        of ``cat_name.csv`` as a DataFrame.

    Raises:
        ValueError: If ``repo_id`` is ``None`` or empty (typically because the
            environment variable that should supply it is not set).
    """
    # Fail early with a readable message instead of an obscure error from
    # deep inside the hub client when the repo id is missing.
    if not repo_id:
        raise ValueError(
            f"repo_id must be a non-empty Hugging Face repository id, got {repo_id!r}"
        )

    model = SetFitModel.from_pretrained(repo_id)
    id2cat = _load_pickle_from_hub(repo_id, 'id2cat.pkl')
    cat2id = _load_pickle_from_hub(repo_id, 'cat2id.pkl')

    cat_name_path = hf_hub_download(repo_id, filename='cat_name.csv')
    df_cat = pd.read_csv(cat_name_path)

    return model, id2cat, cat2id, df_cat
# Load one classifier per taxonomy level; the repo ids come from the
# environment. `df_cat` is rebound by every call, so the category table that
# remains in scope afterwards is the one shipped with the `cid` repository.
cat1_model, cat1_id2cat, cat1_cat2id, df_cat = prepare_setfit_model(
    os.environ.get('cat1_repo_id')
)
cat2_model, cat2_id2cat, cat2_cat2id, df_cat = prepare_setfit_model(
    os.environ.get('cat2_repo_id')
)
cid_model, cid_id2cat, cid_cat2id, df_cat = prepare_setfit_model(
    os.environ.get('cid_repo_id')
)
# -
def model_predict(model, sentence):
    """Rank a model's class probabilities for ``sentence`` from high to low.

    Args:
        model: Any object exposing ``predict_proba(sentence)``. The result may
            be a torch tensor or anything ``numpy.asarray`` accepts.
        sentence: Input passed to ``predict_proba`` unchanged.

    Returns:
        A ``(sorted_ids, sorted_probs)`` pair of numpy arrays. ``sorted_ids``
        holds class indices ordered by descending probability and
        ``sorted_probs`` the matching probabilities. A result of shape
        ``(1, n)`` is flattened to ``(n,)``; results with more rows are
        ranked row by row along the last axis.
    """
    with torch.no_grad():
        raw = model.predict_proba(sentence)

    # Accept both tensor and array-like outputs.
    if isinstance(raw, torch.Tensor):
        predict_result = raw.detach().cpu().numpy()
    else:
        predict_result = np.asarray(raw)

    # A single input may come back as a batch of one; callers iterate the
    # ids as scalars, so drop the batch axis in that case.
    if predict_result.ndim == 2 and predict_result.shape[0] == 1:
        predict_result = predict_result[0]

    # Sort once along the class axis and reuse the permutation so ids and
    # probabilities are guaranteed to stay aligned.
    sorted_ids = np.argsort(predict_result, axis=-1)[..., ::-1]
    sorted_probs = np.take_along_axis(predict_result, sorted_ids, axis=-1)
    return sorted_ids, sorted_probs
# +
def run_prediction(sentence, state):
    """Run all three classifiers on ``sentence`` and fill the dropdowns.

    For each level (cat1, cat2, cid) the ranked class indices are mapped to
    category ids through that level's ``id2cat`` table and then to display
    names through ``id2catname``. The ranked names (prefixed with a
    ``'select'`` placeholder) and the ranked probabilities are stored in
    ``state`` under ``<level>_names`` and ``sorted_<level>_probs``.

    Returns:
        Three dropdown updates (cat1, cat2, cid) followed by ``state``.
    """
    levels = (
        ('cat1', cat1_model, cat1_id2cat),
        ('cat2', cat2_model, cat2_id2cat),
        ('cid', cid_model, cid_id2cat),
    )

    # Predict for every level first, then translate indices to names.
    ranked = [model_predict(model, sentence) for _, model, _ in levels]
    category_ids = [
        [id2cat[idx] for idx in ids]
        for (_, _, id2cat), (ids, _) in zip(levels, ranked)
    ]
    names_per_level = [
        ['select'] + [id2catname.get(cat_id) for cat_id in ids]
        for ids in category_ids
    ]

    for (prefix, _, _), names, (_, probs) in zip(levels, names_per_level, ranked):
        state[f'{prefix}_names'] = names
        state[f'sorted_{prefix}_probs'] = probs

    updates = [
        gr.Dropdown.update(choices=names, value=names[0], interactive=True)
        for names in names_per_level
    ]
    return (*updates, state)
def filter_cat2(cat1_name, state):
    """Restrict the Cat 2 dropdown to children of the selected Cat 1 entry.

    Args:
        cat1_name: Name currently selected in the Cat 1 dropdown.
        state: Session dict; ``state['cat2_names']`` holds the ranked Cat 2
            names from the last prediction.

    Returns:
        A dropdown update for Cat 2 followed by ``state`` (unchanged). The
        ranked order from the prediction is kept and duplicates are dropped.
        When ``cat1_name`` has no known children the dropdown is disabled
        with a single ``'None'`` entry, matching ``filter_cid``.
    """
    cat2_list = parent_cat_map.get(cat1_name)
    if cat2_list is None:
        # Happens for any name without children in the taxonomy, e.g. the
        # 'select' placeholder; previously this raised TypeError below.
        return gr.Dropdown.update(
            choices=['None'], value='None', interactive=False
        ), state

    allowed = set(cat2_list)
    # dict.fromkeys de-duplicates while preserving the ranked order.
    matches = dict.fromkeys(
        name for name in state.get('cat2_names', []) if name in allowed
    )
    cat2_names = ['select', *matches]
    return gr.Dropdown.update(
        choices=cat2_names, value=cat2_names[0], interactive=True
    ), state
def filter_cid(cat2_name, state):
    """Restrict the CID dropdown to children of the selected Cat 2 entry.

    Looks up the children of ``cat2_name`` in ``parent_cat_map``. If there
    are none, returns a disabled dropdown holding only ``'None'``. Otherwise
    keeps the names from ``state['cid_names']`` that are children, in their
    existing order and without duplicates, behind a ``'select'`` placeholder.
    """
    children = parent_cat_map.get(cat2_name)
    if children is None:
        return gr.Dropdown.update(
            choices=['None'], value='None', interactive=False
        )

    # Ordered, de-duplicated subset of the predicted names.
    kept = dict.fromkeys(
        name for name in state['cid_names'] if name in children
    )
    options = ['select', *kept]
    return gr.Dropdown.update(
        choices=options, value=options[0], interactive=True
    )
# def predict_with_title_and_description(title, description):
# temp_list = list(locations.keys())
# random.shuffle(temp_list)
# countries = ['select'] + temp_list
# return gr.Dropdown.update(
# choices=countries, value=countries[0], interactive=True
# )
# Self-join the category table so every row also carries its parent's name.
parent_cat = df_cat[['id', 'name']].rename(
    columns={'id': 'temp_id', 'name': 'parent_name'}
)
df_cat = df_cat.merge(
    parent_cat, left_on='parent_id', right_on='temp_id', how='left'
).drop(columns='temp_id')

# category id -> display name
id2catname = {
    row['id']: row['name']
    for row in df_cat[['id', 'name']].to_dict(orient='records')
}

# parent display name -> list of child display names (in table order)
parent_cat_map = {}
for item in df_cat[['parent_name', 'name']].to_dict(orient='records'):
    parent_cat_map.setdefault(item['parent_name'], []).append(item['name'])
# Gradio UI: one tab where a service title is classified and the three
# dropdowns (Cat 1 -> Cat 2 -> CID) are filled and progressively narrowed.
with gr.Blocks() as demo:
    # Dict shared between the callbacks; run_prediction stores the ranked
    # names/probabilities here and the filter callbacks read them back.
    prediction_results = gr.State({})

    with gr.Tab(label="Predict by title") as t1:
        title = gr.Textbox(label='Service Title', placeholder='Please enter service title')
        d1 = gr.Dropdown(choices=[], label="Cat 1")
        d2 = gr.Dropdown(choices=[], label='Cat 2')
        d3 = gr.Dropdown(choices=[], label="CID")
        b1 = gr.Button()

        # Button click: predict all three levels and populate every dropdown.
        b1.click(run_prediction, [title, prediction_results], [d1, d2, d3, prediction_results])
        # Selecting a parent narrows the choices offered one level down.
        d1.select(filter_cat2, [d1, prediction_results], [d2, prediction_results])
        d2.select(filter_cid, [d2, prediction_results], d3)

    # Disabled second tab; it references names (locations, filter_states,
    # filter_cities) that are not defined in this file.
    # with gr.Tab(label="Predict by title and description") as t2:
    #     title = gr.Textbox(label='Service Title', placeholder='Please enter service title')
    #     description = gr.Textbox(label='Service Description', placeholder="Please enter service description")
    #     d1 = gr.Dropdown(choices = list(locations.keys()), label="Country")
    #     d2 = gr.Dropdown(choices = list(), label='State')
    #     d3 = gr.Dropdown(choices = list(), label="City")
    #     b1 = gr.Button()
    #     b1.click(predict_with_title_and_description, [title, description], d1)
    #     d1.change(filter_states, d1, d2)
    #     d2.change(filter_cities, [d1, d2], d3)

# Enable request queuing (max_size=5) and start the server.
demo.queue(max_size=5).launch()