File size: 3,082 Bytes
42466a5
 
 
 
 
68957a6
42466a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517cbd2
42466a5
68957a6
 
 
fd637c4
 
 
 
 
 
ae6b2f5
42466a5
517cbd2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import torch
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Hugging Face Hub repository holding the QPAVE T5 checkpoint and tokenizer.
model_id = 'ksabeh/qpave'
# Token budget for the encoder prompt; predict() truncates longer prompts to this.
max_input_length = 512
# NOTE(review): not read anywhere in this file (predict() hard-codes its own
# generation length); kept for reference.
max_target_length = 20
# Hub credential taken from the TOKEN environment variable; None when unset.
# Presumably required because the repo is private/gated — confirm.
auth_token = os.getenv('TOKEN')

# Model and tokenizer come from the same repo and use the same credential.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in favour
# of `token`; confirm the installed version before switching.
_hub_kwargs = {'use_auth_token': auth_token}
model = T5ForConditionalGeneration.from_pretrained(model_id, **_hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id, **_hub_kwargs)

def predict(cg_attribute, text, fg_attribute, category):
    """Predict the value of a fine-grained attribute from a context string.

    The prompt sent to the model is ``"<fg_attribute>: <text>"``.

    Args:
        cg_attribute: Coarse-grained attribute name entered in the UI. It is
            not part of the prompt. NOTE(review): confirm the checkpoint was
            trained on ``"<fine-grained attribute>: <context>"`` only.
        text: Context string (the value of the coarse-grained attribute).
        fg_attribute: Name of the fine-grained attribute to extract.
        category: Product category entered in the UI; also not part of the
            prompt (see the note on ``cg_attribute``).

    Returns:
        The decoded prediction for the single input, special tokens removed.
    """
    prompt = f"{fg_attribute}: {text}"
    # return_tensors="pt" yields ready-batched (1, seq_len) tensors. A single
    # example needs no padding: padding to max_input_length only made the
    # encoder run over hundreds of attention-masked pad tokens.
    model_input = tokenizer(
        prompt,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # do_sample=True together with num_beams=4 samples within the beams, so
    # repeated calls on the same input may return different values.
    # max_length=10 caps the generated sequence. NOTE(review): the module also
    # defines max_target_length (20), which is not used here — confirm intent.
    predictions = model.generate(
        **model_input, num_beams=4, do_sample=True, max_length=10
    )
    return tokenizer.batch_decode(predictions, skip_special_tokens=True)[0]

# iface = gr.Interface(
#     predict,
#     inputs=["text", "text", "text", "text"],
#     outputs=['text'],
#     title="QPAVE",
#     examples=[["Arriba Salsa Garlic and Cilantro, 16 oz", "Food"], 
#     ["MV Verholen Black GPS Ball Mount for BMW K1200S K1200R K1300S K1300R Black GPS Ball Mount VER-4901-10181", "Toys"],
#     ["Mitsubishi 3000GT License Plate Frame (Zince Metal)", "Automotive"],
#     ["Fun Fire Truck Pinata Personalized", "Toys"],
#     ["White Chocolate Caramel Gourmet Popcorn Kelly", "Food"]
#     ]
# )

# iface.launch()

# (label, info) for each text input, in the positional order predict() takes:
# cg_attribute, text, fg_attribute, category.
_input_fields = [
    ("Coarse-grained Attribute", "The coarse-grained attribute name"),
    ("Context", "The value of the coarse-grained attribute"),
    ("Fine-grained Attribute", "The target fine-grained attribute name"),
    ("Category", "The product category"),
]

# Clickable example rows; each row lists values in the same order as _input_fields.
_examples = [
    ["Processor", "3ghz intel core i5", "Brand Name", "Computers & Tablets"],
    ["Special Feature", "Electric Razor for Men,Beard Trimmer,Rechargeable,Wet and Dry,Cordles", "Uses", "Electric Shavers"],
    ["Special Feature", "Electric Razor for Men,Beard Trimmer,Rechargeable,Wet and Dry,Cordles", "Skin Type", "Electric Shavers"],
    ["Color", "Black Foil Razor", "Head Type", "Electric Shavers"],
    ["Color", "2 ustzc-541 black with silver wood", "Material Type", "Office Electronics Accessories"],
    ["Brand Name", "beiter gray battery power", "Power Source", "Laptop Accessories"],
    ["Color", "14.5-inch red color spectrum", "Size", "Novelty Lighting"],
    ["Fixture Features", "wattage 21w type 1200 mm input end", "Wattage", "Fluorescent Tubes"],
    ["Fixture Features", "wattage 21w type 1200 mm input end", "Size", "Fluorescent Tubes"],
]

# Single-line text boxes in, one text box out.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label=label, info=info, lines=1) for label, info in _input_fields],
    outputs="textbox",
    title="QPAVE",
    examples=_examples,
)

demo.launch()