File size: 2,269 Bytes
e030ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f57142
 
 
 
e030ae6
 
 
 
 
 
445b401
1f57142
445b401
 
e030ae6
 
 
1f57142
e030ae6
 
 
 
 
1f57142
e030ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import os

# Hub repo holding the fine-tuned BART checkpoint used for generation.
model_name = 'eliolio/bart-finetuned-yelpreviews'

# Hub access token taken from the ``private_token`` environment variable;
# ``None`` when the variable is not set.
access_token = os.environ.get('private_token')

# Model and tokenizer are fetched from the same repo with the same
# credentials (model first, then tokenizer).
_hub_kwargs = {'use_auth_token': access_token}
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **_hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_kwargs)

def create_prompt(stars, useful, funny, cool):
    """Build the conditioning prompt fed to the review-generation model.

    Args:
        stars: Star rating to condition on.
        useful: "useful" vote count to condition on.
        funny: "funny" vote count to condition on.
        cool: "cool" vote count to condition on.

    Returns:
        A single-line prompt listing the four attributes in a fixed order.
    """
    prompt = f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
    return prompt

def postprocess(review):
    """Trim a generated review back to its last complete sentence.

    Sampled output can stop mid-sentence, so everything after the final
    ``'.'`` is discarded. The closing period itself is kept.

    Args:
        review: Decoded model output.

    Returns:
        ``review`` up to and including its last ``'.'``. When ``review``
        contains no ``'.'`` at all, it is returned unchanged.
    """
    dot = review.rfind('.')
    if dot == -1:
        # No sentence boundary to cut at. Slicing with -1 here would
        # silently chop off the last character, so leave the text intact.
        return review
    return review[:dot + 1]

def generate_reviews(stars, useful, funny, cool):
    """Sample three candidate reviews conditioned on the given attributes.

    Args:
        stars: Star rating forwarded to ``create_prompt``.
        useful: "useful" vote count forwarded to ``create_prompt``.
        funny: "funny" vote count forwarded to ``create_prompt``.
        cool: "cool" vote count forwarded to ``create_prompt``.

    Returns:
        A 3-tuple of post-processed review strings, one per output textbox.
    """
    encoded = tokenizer(create_prompt(stars, useful, funny, cool), return_tensors='pt')
    sequences = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,          # sample rather than decode greedily
        num_return_sequences=3,  # one candidate per Review textbox
        temperature=1.2,
        top_p=0.9,
    )
    cleaned = [
        postprocess(tokenizer.decode(seq, skip_special_tokens=True))
        for seq in sequences
    ]
    first, second, third = cleaned[:3]
    return first, second, third

css = """
    #ctr {text-align: center;}
    #btn {color: white; background: linear-gradient( 90deg,  rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
"""


md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""
demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text, elem_id='ctr')
    
    with gr.Row():
        stars = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="stars")
        useful = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="useful")
        funny = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="funny")
        cool = gr.inputs.Slider(minimum=0, maximum=5, step=1, default=0, label="cool")
    with gr.Row():
        button = gr.Button("Generate reviews !", elem_id='btn')

    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")

    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3]
    )

demo.launch()