File size: 1,965 Bytes
612ae5b
 
 
7cec3d5
612ae5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from transformers import BertTokenizer
from regression_models import BERTRegression
import gradio as gr

# Fixed token length every encoded input is padded/truncated to.
max_len = 80

# Tokenizer for the "bert-base-uncased" checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def _load_regressor(weights_path):
    """Build a BERTRegression, restore its weights on CPU, and switch to eval mode."""
    model = BERTRegression()
    cpu_state = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(cpu_state)
    model.eval()
    return model


# Regression model used for inference by the prediction function.
bertregressor = _load_regressor("bert_regression_model.pth")

def predict_price(name, item_condition, category, brand_name, shipping_included, item_description):
    """Predict a price for one listing from its form fields.

    The fields are joined into a single labelled text string, tokenized to
    exactly ``max_len`` tokens (padded or truncated), and fed to
    ``bertregressor`` with gradient tracking disabled.

    Args:
        name: Item name text.
        item_condition: Selected condition label. May be ``None`` when the
            dropdown has no selection; treated as an empty string.
        category: Category text.
        brand_name: Brand text.
        shipping_included: Truthy when shipping is included.
        item_description: Free-form description text.

    Returns:
        The model's single output value as a Python scalar (via ``.item()``).
    """
    # Log the raw inputs exactly as received.
    print((name, item_condition, category, brand_name, shipping_included, item_description))

    # UI components can hand over None for an unset field; string
    # concatenation with None would raise TypeError, so coerce to text first.
    name, item_condition, category, brand_name, item_description = (
        "" if field is None else str(field)
        for field in (name, item_condition, category, brand_name, item_description)
    )

    shipping_str = "Includes Shipping" if shipping_included else "No Shipping"

    # NOTE(review): "Brand" has no colon, unlike the other labels. Presumably
    # this matches the text format the model was trained on -- confirm before
    # changing it.
    combined = (
        f"Item Name: {name}"
        f" Description: {item_description}"
        f" Condition: {item_condition}"
        f" Category: {category}"
        f" Brand {brand_name}"
        f" Shipping: {shipping_str}"
    )

    inputs = tokenizer.encode_plus(
        combined,
        None,
        add_special_tokens=True,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = bertregressor(inputs["input_ids"], inputs["attention_mask"])

    return output.item()

    
# Choices offered by the item-condition dropdown.
_condition_choices = ['Poor', 'Okay', 'Good', 'Excellent', 'Like New']

# Form widgets, listed in the same order as predict_price's parameters.
_form_fields = [
    gr.Textbox(label="Item Name"),
    gr.Dropdown(
        _condition_choices,
        label="Item Condition",
        info="What condition is the item in?",
    ),
    gr.Textbox(label="Category on Mercari"),
    gr.Textbox(label="Brand"),
    gr.Checkbox(label="Shipping Included"),
    gr.Textbox(label="Description"),
]

# Web form wired to predict_price; the result is shown as a number.
demo = gr.Interface(
    fn=predict_price,
    inputs=_form_fields,
    outputs=gr.Number(),
)

# Start serving the interface.
demo.launch()