yelp-reviews / app.py
Eliott
sampling strategy
445b401
raw
history blame
2.27 kB
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import os
# Hub identifier of the fine-tuned BART checkpoint served by this demo.
model_name = 'eliolio/bart-finetuned-yelpreviews'

# Hub token taken from the environment; None when the variable is unset.
# NOTE(review): presumably needed because the checkpoint is private - confirm.
access_token = os.getenv('private_token')

# Tokenizer and model are fetched once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
def create_prompt(stars, useful, funny, cool):
    """Build the text prompt that is fed to the model.

    Args:
        stars: Value inserted after ``stars:``.
        useful: Value inserted after ``useful:``.
        funny: Value inserted after ``funny:``.
        cool: Value inserted after ``cool:``.

    Returns:
        A string of the form
        ``"Generate review: stars: S, useful: U, funny: F, cool: C"``.
    """
    ratings = f"stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
    return f"Generate review: {ratings}"
def postprocess(review):
    """Cut a generated review off at its last period.

    Everything from the final ``'.'`` onward (the period included) is
    removed - presumably to drop a trailing, unfinished sentence.

    Args:
        review: Decoded review text.

    Returns:
        The text before the last period, or ``review`` unchanged when it
        contains no period at all.
    """
    dot = review.rfind('.')
    if dot == -1:
        # Without this guard the -1 from rfind would be used as a slice
        # bound and silently chop the last character off the review.
        return review
    return review[:dot]
def generate_reviews(stars, useful, funny, cool):
    """Generate three sampled reviews for the given rating values.

    The four values are turned into a prompt with ``create_prompt``,
    tokenized, and passed to ``model.generate`` with sampling enabled
    (temperature 1.2, top_p 0.9, three returned sequences). Each sequence
    is decoded without special tokens and trimmed by ``postprocess``.

    Args:
        stars: Value for the ``stars`` field of the prompt.
        useful: Value for the ``useful`` field of the prompt.
        funny: Value for the ``funny`` field of the prompt.
        cool: Value for the ``cool`` field of the prompt.

    Returns:
        A 3-tuple with the first three post-processed review strings.
    """
    prompt = create_prompt(stars, useful, funny, cool)
    encoded = tokenizer(prompt, return_tensors='pt')

    # Sampling settings kept together so they are easy to spot and tune.
    sampling_options = dict(
        do_sample=True,
        num_return_sequences=3,
        temperature=1.2,
        top_p=0.9,
    )
    generated = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **sampling_options,
    )

    texts = [
        postprocess(tokenizer.decode(sequence, skip_special_tokens=True))
        for sequence in generated
    ]
    return texts[0], texts[1], texts[2]
# Custom stylesheet handed to gr.Blocks: centers the element with id "ctr"
# and gives the element with id "btn" white text on an orange-to-coral
# horizontal gradient.
css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient( 90deg, rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
"""
# Markdown heading rendered at the top of the page.
md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""
# Page layout: a title row, a row of four rating sliders, a trigger button
# and a row of three text boxes that receive the generated reviews.
demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text, elem_id='ctr')
    with gr.Row():
        # gr.Slider(value=...) is the component API that goes with gr.Blocks;
        # the legacy gr.inputs.Slider(default=...) form is deprecated in
        # Gradio 3.x and removed in later releases.
        stars = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="stars")
        useful = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="useful")
        funny = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="funny")
        cool = gr.Slider(minimum=0, maximum=5, step=1, value=0, label="cool")
    with gr.Row():
        button = gr.Button("Generate reviews !", elem_id='btn')
    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
    # Clicking the button feeds the four slider values to generate_reviews
    # and writes its three return values into the three text boxes.
    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3],
    )
demo.launch()