shangrilar committed on
Commit
091be34
·
0 Parent(s):

Duplicate from shangrilar/cat_prediction

Browse files
Files changed (4) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +139 -0
  4. requirements.txt +1 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Cat Prediction
3
+ emoji: 🏆
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.39.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: shangrilar/cat_prediction
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+ import pandas as pd
7
+ import gradio as gr
8
+ from setfit import SetFitModel
9
+ from huggingface_hub import login, hf_hub_download
10
# Authenticate against the Hugging Face Hub using the token stored in the
# `hf_token` environment variable.
# NOTE(review): when the variable is unset, None is handed to login();
# confirm that this is acceptable for a headless deployment.
hf_token = os.environ.get('hf_token')
login(hf_token)
12
+
13
def prepare_setfit_model(repo_id):
    """Load a SetFit model and its category lookup files from a Hub repo.

    Args:
        repo_id: Hub repository holding the model plus the files
            ``id2cat.pkl``, ``cat2id.pkl`` and ``cat_name.csv``.

    Returns:
        A tuple ``(model, id2cat, cat2id, df_cat)`` where ``id2cat`` and
        ``cat2id`` are the unpickled objects and ``df_cat`` is the CSV
        read into a DataFrame.
    """

    def _download_pickle(filename):
        # NOTE(review): pickle.load executes arbitrary code embedded in the
        # file; only point this at repositories you trust.
        local_path = hf_hub_download(repo_id, filename=filename)
        with open(local_path, "rb") as handle:
            return pickle.load(handle)

    model = SetFitModel.from_pretrained(repo_id)
    id2cat = _download_pickle('id2cat.pkl')
    cat2id = _download_pickle('cat2id.pkl')
    df_cat = pd.read_csv(hf_hub_download(repo_id, filename='cat_name.csv'))
    return model, id2cat, cat2id, df_cat
24
+
25
# One classifier per level; the repository ids come from environment
# variables. `df_cat` is rebound on every call, so the table that remains
# bound afterwards is the one downloaded from the `cid_repo_id` repository.
cat1_model, cat1_id2cat, cat1_cat2id, df_cat = prepare_setfit_model(
    os.environ.get('cat1_repo_id')
)
cat2_model, cat2_id2cat, cat2_cat2id, df_cat = prepare_setfit_model(
    os.environ.get('cat2_repo_id')
)
cid_model, cid_id2cat, cid_cat2id, df_cat = prepare_setfit_model(
    os.environ.get('cid_repo_id')
)
28
+
29
+
30
+ # -
31
+
32
def model_predict(model, sentence):
    """Return class indices and probabilities sorted by descending probability.

    Args:
        model: Object whose ``predict_proba(sentence)`` returns a torch
            tensor of class probabilities.
        sentence: Input forwarded unchanged to ``model.predict_proba``.

    Returns:
        A tuple ``(sorted_ids, sorted_probs)`` of NumPy arrays with the same
        shape as the model output. Sorting is done along the last axis, so a
        1-D result is ordered as a whole and a 2-D (batched) result is
        ordered row by row.
    """
    with torch.no_grad():
        predict_result = model.predict_proba(sentence).cpu().detach().numpy()
    # Reverse only the last axis: a plain [::-1] would flip the row order of
    # a 2-D result instead of turning each row into descending order.
    sorted_ids = np.argsort(predict_result, axis=-1)[..., ::-1]
    # Reuse the same permutation so ids and probabilities always line up.
    sorted_probs = np.take_along_axis(predict_result, sorted_ids, axis=-1)
    return sorted_ids, sorted_probs
38
+
39
+
40
+ # +
41
def run_prediction(sentence, state):
    """Run the three classifiers on ``sentence`` and fill the dropdowns.

    Args:
        sentence: Text passed to every model via ``model_predict``.
        state: Dict that receives the ranked names and probabilities of
            each level under the keys ``cat1_names``, ``sorted_cat1_probs``,
            ``cat2_names``, ``sorted_cat2_probs``, ``cid_names`` and
            ``sorted_cid_probs``.

    Returns:
        Three ``gr.Dropdown.update`` objects (Cat 1, Cat 2, CID), each
        preselecting the ``'select'`` placeholder, followed by ``state``.
    """
    cat1_ranked, sorted_cat1_probs = model_predict(cat1_model, sentence)
    cat2_ranked, sorted_cat2_probs = model_predict(cat2_model, sentence)
    cid_ranked, sorted_cid_probs = model_predict(cid_model, sentence)

    # Model output index -> category id (per-model table).
    cat1_ids = [cat1_id2cat[idx] for idx in cat1_ranked]
    cat2_ids = [cat2_id2cat[idx] for idx in cat2_ranked]
    cid_ids = [cid_id2cat[idx] for idx in cid_ranked]

    # Category id -> display name, with the placeholder entry in front.
    cat1_names = ['select'] + [id2catname.get(cat_id) for cat_id in cat1_ids]
    cat2_names = ['select'] + [id2catname.get(cat_id) for cat_id in cat2_ids]
    cid_names = ['select'] + [id2catname.get(cat_id) for cat_id in cid_ids]

    state.update({
        'cat1_names': cat1_names,
        'sorted_cat1_probs': sorted_cat1_probs,
        'cat2_names': cat2_names,
        'sorted_cat2_probs': sorted_cat2_probs,
        'cid_names': cid_names,
        'sorted_cid_probs': sorted_cid_probs,
    })

    def _as_dropdown(names):
        return gr.Dropdown.update(
            choices=names, value=names[0], interactive=True
        )

    return (
        _as_dropdown(cat1_names),
        _as_dropdown(cat2_names),
        _as_dropdown(cid_names),
        state,
    )
66
+
67
def filter_cat2(cat1_name, state):
    """Restrict the Cat 2 dropdown to children of the chosen Cat 1 entry.

    Args:
        cat1_name: Name currently selected in the Cat 1 dropdown.
        state: Dict holding the ranked Cat 2 names under ``'cat2_names'``.

    Returns:
        A ``gr.Dropdown.update`` for the Cat 2 dropdown and ``state``
        (unchanged). When ``cat1_name`` has no entry in ``parent_cat_map``
        (for instance the ``'select'`` placeholder) the dropdown is
        disabled with the single choice ``'None'``, matching ``filter_cid``.
    """
    cat2_list = parent_cat_map.get(cat1_name)
    if cat2_list is None:
        # Without this guard the membership test below raises TypeError.
        return gr.Dropdown.update(
            choices=['None'], value='None', interactive=False
        ), state

    allowed = set(cat2_list)
    # dict.fromkeys drops duplicates while keeping the ranked order.
    ranked = state.get('cat2_names', [])
    cat2_names = ['select'] + list(
        dict.fromkeys(name for name in ranked if name in allowed)
    )
    return gr.Dropdown.update(
        choices=cat2_names, value=cat2_names[0], interactive=True
    ), state
78
+
79
def filter_cid(cat2_name, state):
    """Restrict the CID dropdown to children of the chosen Cat 2 entry.

    Args:
        cat2_name: Name currently selected in the Cat 2 dropdown.
        state: Dict holding the ranked CID names under ``'cid_names'``.

    Returns:
        A ``gr.Dropdown.update`` for the CID dropdown: disabled with the
        single choice ``'None'`` when ``cat2_name`` has no entry in
        ``parent_cat_map``, otherwise the ranked names that are children of
        ``cat2_name`` preceded by ``'select'``.
    """
    cid_list = parent_cat_map.get(cat2_name)
    if cid_list is None:
        return gr.Dropdown.update(
            choices=['None'], value='None', interactive=False
        )

    # Keep the ranked order and drop repeated names.
    matching = dict.fromkeys(
        name for name in state['cid_names'] if name in cid_list
    )
    cid_names = ['select', *matching]
    return gr.Dropdown.update(
        choices=cid_names, value=cid_names[0], interactive=True
    )
93
+
94
+ # def predict_with_title_and_description(title, description):
95
+ # temp_list = list(locations.keys())
96
+ # random.shuffle(temp_list)
97
+ # countries = ['select'] + temp_list
98
+ # return gr.Dropdown.update(
99
+ # choices=countries, value=countries[0], interactive=True
100
+ # )
101
+
102
# Self-join the category table so every row also carries its parent's name.
# The join is a left join, so rows whose `parent_id` matches no `id` keep a
# missing `parent_name`.
parent_cat = df_cat[['id', 'name']].rename(
    columns={'id': 'temp_id', 'name': 'parent_name'}
)
df_cat = df_cat.merge(
    parent_cat, left_on='parent_id', right_on='temp_id', how='left'
).drop(columns='temp_id')

# Category id -> category name.
id2catname = {
    record['id']: record['name']
    for record in df_cat[['id', 'name']].to_dict(orient='records')
}

# Parent category name -> names of its direct children, in table order.
parent_cat_map = {}
for item in df_cat[['parent_name', 'name']].to_dict(orient='records'):
    parent_cat_map.setdefault(item['parent_name'], []).append(item['name'])
114
+
115
# UI: one text box, a button, and three dropdowns (Cat 1 -> Cat 2 -> CID).
# Clicking the button fills all three dropdowns; selecting an entry in one
# dropdown re-filters the dropdown after it.
with gr.Blocks() as demo:
    # Session state shared by the callbacks.
    prediction_results = gr.State({})
    with gr.Tab(label="Predict by title") as t1:
        title = gr.Textbox(
            label='Service Title',
            placeholder='Please enter service title',
        )
        d1 = gr.Dropdown(choices=[], label="Cat 1")
        d2 = gr.Dropdown(choices=[], label='Cat 2')
        d3 = gr.Dropdown(choices=[], label="CID")

        b1 = gr.Button()
        b1.click(
            fn=run_prediction,
            inputs=[title, prediction_results],
            outputs=[d1, d2, d3, prediction_results],
        )
        d1.select(
            fn=filter_cat2,
            inputs=[d1, prediction_results],
            outputs=[d2, prediction_results],
        )
        d2.select(
            fn=filter_cid,
            inputs=[d2, prediction_results],
            outputs=d3,
        )
    # with gr.Tab(label="Predict by title and description") as t2:
    #     title = gr.Textbox(label='Service Title', placeholder='Please enter service title')
    #     description = gr.Textbox(label='Service Description', placeholder="Please enter service description")
    #     d1 = gr.Dropdown(choices = list(locations.keys()), label="Country")
    #     d2 = gr.Dropdown(choices = list(), label='State')
    #     d3 = gr.Dropdown(choices = list(), label="City")

    #     b1 = gr.Button()
    #     b1.click(predict_with_title_and_description, [title, description], d1)
    #     d1.change(filter_states, d1, d2)
    #     d2.change(filter_cities, [d1, d2], d3)

demo.queue(max_size=5).launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ setfit