Spaces:
Build error
Build error
File size: 2,269 Bytes
e030ae6 1f57142 e030ae6 445b401 1f57142 445b401 e030ae6 1f57142 e030ae6 1f57142 e030ae6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import os
# Hub repository holding the fine-tuned BART checkpoint.
model_name = 'eliolio/bart-finetuned-yelpreviews'
# Optional access token read from the 'private_token' environment variable;
# os.environ.get yields None when it is not set. (Presumably a Space secret
# holding a Hugging Face token -- confirm against the Space settings.)
access_token = os.environ.get('private_token')
# Both artifacts come from the same repo with the same credentials; the model
# is loaded first, then the tokenizer.
_hub_kwargs = {'use_auth_token': access_token}
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **_hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_kwargs)
def create_prompt(stars, useful, funny, cool):
    """Build the conditioning prompt fed to the review generator.

    Args:
        stars, useful, funny, cool: Attribute values, each rendered with its
            default string formatting.

    Returns:
        A string of the form
        "Generate review: stars: S, useful: U, funny: F, cool: C".
    """
    attributes = (
        ("stars", stars),
        ("useful", useful),
        ("funny", funny),
        ("cool", cool),
    )
    body = ", ".join(f"{name}: {value}" for name, value in attributes)
    return "Generate review: " + body
def postprocess(review):
    """Trim a generated review back to its last complete sentence.

    Sampled text may stop partway through a sentence, so everything after
    the final '.' is discarded. The period itself is kept so that the
    review ends cleanly.

    Args:
        review: Decoded review text.

    Returns:
        ``review`` up to and including its last '.', or ``review``
        unchanged when it contains no '.' at all (including the empty
        string).
    """
    dot = review.rfind('.')
    if dot == -1:
        # No sentence boundary found: slicing with rfind's -1 sentinel would
        # silently chop the final character, so return the text untouched.
        return review
    # +1 keeps the terminating period of the last complete sentence.
    return review[:dot + 1]
def generate_reviews(stars, useful, funny, cool):
    """Sample three candidate reviews conditioned on the given attributes.

    The attributes are turned into a prompt by ``create_prompt``, tokenized
    into PyTorch tensors, and passed to the module-level ``model`` with
    nucleus sampling (temperature 1.2, top_p 0.9). Each returned sequence is
    decoded without special tokens and trimmed by ``postprocess``.

    Args:
        stars, useful, funny, cool: Attribute values forwarded to
            ``create_prompt``.

    Returns:
        A 3-tuple of review strings, one per output textbox.
    """
    prompt = create_prompt(stars, useful, funny, cool)
    encoded = tokenizer(prompt, return_tensors='pt')
    sequences = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        num_return_sequences=3,
        temperature=1.2,
        top_p=0.9,
    )
    cleaned = [
        postprocess(tokenizer.decode(seq, skip_special_tokens=True))
        for seq in sequences
    ]
    # The UI wires exactly three output boxes, matching num_return_sequences.
    return cleaned[0], cleaned[1], cleaned[2]
# Custom styling: #ctr centres the title text, #btn gives the generate
# button white text on an orange-to-red gradient.
css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient( 90deg, rgba(255,166,0,1) 14.7%, rgba(255,99,97,1) 73% );}
"""
md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""


def _rating_slider(label):
    """Return an integer 0-5 slider (step 1, starting at 0) with ``label``.

    Uses ``gr.Slider`` with ``value=``; the old ``gr.inputs.Slider`` /
    ``default=`` interface API was deprecated in Gradio 3 and removed in
    Gradio 4.
    """
    return gr.Slider(minimum=0, maximum=5, step=1, value=0, label=label)


demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text, elem_id='ctr')
    with gr.Row():
        stars = _rating_slider("stars")
        useful = _rating_slider("useful")
        funny = _rating_slider("funny")
        cool = _rating_slider("cool")
    with gr.Row():
        button = gr.Button("Generate reviews !", elem_id='btn')
    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
    # Event listeners must be registered inside the Blocks context.
    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3],
    )
demo.launch()