File size: 4,666 Bytes
e030ae6
 
9df4338
e030ae6
 
 
 
9df4338
e030ae6
 
 
9df4338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0cf912
9df4338
 
7517145
 
9df4338
e030ae6
 
 
 
9df4338
1f57142
 
bd89db0
1f57142
9df4338
e030ae6
 
 
 
 
 
445b401
1f57142
445b401
 
e030ae6
 
9df4338
e030ae6
c88b60d
2a2e9cb
9df4338
 
 
 
 
e030ae6
 
 
 
1f57142
e030ae6
 
 
5e02842
 
 
85054d0
 
 
861be40
85054d0
 
 
 
 
 
 
 
861be40
85054d0
861be40
85054d0
 
e030ae6
 
 
be193e7
9df4338
e030ae6
9df4338
 
 
 
 
 
 
 
e030ae6
 
 
 
 
 
 
 
9df4338
 
 
 
 
861be40
 
 
e030ae6
 
 
9df4338
e030ae6
 
9df4338
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import gradio as gr
import os

# Hub checkpoint ids: a fine-tuned BART review generator and a BERT
# classifier used by correlation_score to rate prompt/review agreement.
model_name = 'eliolio/bart-finetuned-yelpreviews'
bert_model_name = 'eliolio/bert-correlation-yelpreviews'

# Optional Hub token read from the `private_token` environment variable;
# os.environ.get yields None when the variable is unset.
access_token = os.environ.get('private_token')

# Generator: seq2seq model and its tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)

# Scorer: sequence-classification model and its tokenizer.
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name, use_auth_token=access_token)
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_model_name, use_auth_token=access_token)


def correlation_score(table, review):
    """Score how well a generated review agrees with its attribute string.

    ``table`` and ``review`` are encoded as a sentence pair by the BERT
    tokenizer (padded, truncated to at most 128 tokens) and classified by the
    BERT model.

    Args:
        table: Attribute text, e.g. ``"stars: 5, useful: 1, funny: 0, cool: 1"``.
        review: The generated review text.

    Returns:
        A dict with the softmax probability of class 1 under ``"correlated"``
        and of class 0 under ``"uncorrelated"``. ``.item()`` requires a single
        pair, which is what this function builds.
    """
    inputs = bert_tokenizer(
        table,
        review,
        padding=True,
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )
    # Pure inference: disable autograd so no graph/activations are kept.
    with torch.no_grad():
        logits = bert_model(**inputs).logits
    probs = logits.softmax(dim=-1)
    return {
        "correlated": probs[:, 1].item(),
        "uncorrelated": probs[:, 0].item(),
    }

def create_prompt(stars, useful, funny, cool):
    """Build the instruction string fed to the BART review generator.

    Each attribute is rendered with its default string formatting, e.g.
    ``create_prompt(5, 1, 0, 1)`` gives
    ``'Generate review: stars: 5, useful: 1, funny: 0, cool: 1'``.
    """
    prompt = (
        f"Generate review: stars: {stars}, useful: {useful}, "
        f"funny: {funny}, cool: {cool}"
    )
    return prompt


def postprocess(review, terminators=".!?"):
    """Trim a generated review back to its last complete sentence.

    Generation may stop mid-sentence (presumably at the model's length
    limit), so everything after the last sentence terminator is dropped.

    Args:
        review: Decoded review text.
        terminators: Characters treated as sentence ends. Defaults to
            ``".!?"`` so reviews ending in ``!`` or ``?`` keep their final
            sentence.

    Returns:
        ``review`` up to and including its last terminator, or ``review``
        unchanged when it contains no terminator at all (rather than an
        empty string).
    """
    cut = max((review.rfind(ch) for ch in terminators), default=-1)
    if cut < 0:
        # Nothing to trim back to; keep the text instead of blanking it.
        return review
    return review[:cut + 1]


def generate_reviews(stars, useful, funny, cool):
    """Sample three reviews for the given attributes and score each one.

    The prompt from ``create_prompt`` is sampled three times with the BART
    model (temperature 1.2, nucleus top-p 0.9). Each decoded review is
    trimmed by ``postprocess`` and scored by ``correlation_score`` against
    the attribute part of the prompt.

    Returns:
        A 6-tuple ``(review1, review2, review3, score1, score2, score3)`` in
        the order the Gradio outputs are wired: three review strings followed
        by their three score dicts.
    """
    text = create_prompt(stars, useful, funny, cool)
    inputs = tokenizer(text, return_tensors='pt')
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        do_sample=True,
        num_return_sequences=3,
        temperature=1.2,
        top_p=0.9,
    )
    reviews = [
        postprocess(decoded)
        for decoded in tokenizer.batch_decode(out, skip_special_tokens=True)
    ]
    # The scorer sees only the attributes ("stars: ..., cool: ..."). Split
    # off the "Generate review" instruction at its first ": " instead of
    # slicing at a hard-coded character offset.
    attributes = text.partition(": ")[2]
    scores = [correlation_score(attributes, review) for review in reviews]

    return reviews[0], reviews[1], reviews[2], scores[0], scores[1], scores[2]


# Custom CSS passed to gr.Blocks. `#btn` matches the component created with
# elem_id='btn' (the generate button); `#ctr` is not referenced by any
# component in this file.
css = """
    #ctr {text-align: center;}
    #btn {color: white; background: linear-gradient( 90deg,  rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
"""


# Markdown/HTML header shown at the top of the demo page.
md_text = """<h1 style='text-align: center; margin-bottom: 1rem'>Generating Yelp reviews with BART-base ⭐⭐⭐</h1>

This space demonstrates how synthetic data generation can be performed on natural language columns, as found in the Yelp reviews dataset.   

| review id | stars | useful | funny | cool | text |
|:---:|:---:|:---:|:---:|:---:|:---:|
| 0 | 5 | 1 | 0 | 1 | "Wow! Yummy, different, delicious. Our favorite is the lamb curry and korma. With 10 different kinds of naan!!!  Don't let the outside deter you (because we almost changed our minds)...go in and try something new! You'll be glad you did!"




The model is a fine-tuned version of [facebook/bart-base](https://huggingface.co/facebook/bart-base) on Yelp reviews with the following input-output pairs:

- **Input**: "Generate review: stars: 5, useful: 1, funny: 0, cool: 1"
- **Output**: "Wow!  Yummy, different,  delicious.   Our favorite is the lamb curry and korma.  With 10 different kinds of naan!!!  Don't let the outside deter you (because we almost changed our minds)...go in and try something new!   You'll be glad you did!"
"""

# Markdown footer shown at the bottom of the demo page.
resources = """## Resources
- The Yelp reviews dataset can be found in json format [here](https://www.yelp.com/dataset)."""

# Page layout: header, attribute sliders, generate button, three review
# boxes with their correlation labels, and a resources footer.
demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text)

    # Integer attribute sliders (0-5, step 1) passed to generate_reviews.
    # `gr.Slider(value=...)` replaces the legacy `gr.inputs.Slider(default=...)`
    # namespace, which is deprecated in Gradio 3 and removed in Gradio 4.
    # NOTE(review): Yelp star ratings run 1-5, so stars=0 is likely outside
    # the fine-tuning data - confirm before keeping 0 as the minimum/default.
    with gr.Row():
        stars = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="stars")
        useful = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="useful")
        funny = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="funny")
        cool = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="cool")
    with gr.Row():
        # elem_id='btn' picks up the `#btn` gradient style from `css`.
        button = gr.Button("Generate reviews !", elem_id='btn')

    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")

    with gr.Row():
        score1 = gr.Label(label="Correlation score #1")
        score2 = gr.Label(label="Correlation score #2")
        score3 = gr.Label(label="Correlation score #3")

    with gr.Row():
        gr.Markdown(resources)

    # generate_reviews returns three reviews then three score dicts, in
    # exactly this output order.
    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3, score1, score2, score3]
    )

demo.launch()