edmundhui commited on
Commit
612ae5b
·
verified ·
1 Parent(s): 78845e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertTokenizer
3
+ from regression_models import BERTRegression
4
+
5
+ max_len = 80
6
+
7
+ # Load tokenizer
8
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
9
+
10
+ # Load model architecture
11
+ bertregressor = BERTRegression()
12
+ bertregressor.load_state_dict(torch.load('bert_regression_model.pth', map_location=torch.device('cpu')))
13
+ bertregressor.eval()
14
+
15
+ def predict_price(name, item_condition, category, brand_name, shipping_included, item_description):
16
+ print((name, item_condition, category, brand_name, shipping_included, item_description))
17
+ # Preprocess Input
18
+ if shipping_included:
19
+ shipping_str = "Includes Shipping"
20
+ else:
21
+ shipping_str = "No Shipping"
22
+
23
+ combined = "Item Name: " + name + \
24
+ " Description: " + item_description + \
25
+ " Condition: " + item_condition + \
26
+ " Category: " + category + \
27
+ " Brand " + brand_name + \
28
+ " Shipping: " + shipping_str
29
+
30
+ inputs = tokenizer.encode_plus(
31
+ combined,
32
+ None,
33
+ add_special_tokens=True,
34
+ max_length=max_len,
35
+ padding="max_length",
36
+ truncation=True,
37
+ return_tensors="pt"
38
+ )
39
+
40
+ input_ids = inputs["input_ids"]
41
+ attention_mask = inputs["attention_mask"]
42
+
43
+ with torch.no_grad():
44
+ output = bertregressor(input_ids, attention_mask)
45
+
46
+ return output.item()
47
+
48
+
49
+ demo = gr.Interface(
50
+
51
+ fn = predict_price,
52
+
53
+ inputs = [gr.Textbox(label="Item Name"),
54
+ gr.Dropdown(['Poor', 'Okay', 'Good', 'Excellent', 'Like New'], label="Item Condition", info="What condition is the item in?"),
55
+ gr.Textbox(label="Category on Mercari"),
56
+ gr.Textbox(label="Brand"),
57
+ gr.Checkbox(label="Shipping Included"),
58
+ gr.Textbox(label="Description")
59
+ ],
60
+
61
+ #outputs = gr.Textbox()
62
+ outputs= gr.Number()
63
+ )
64
+
65
+
66
+ demo.launch()