Spaces:
Runtime error
Runtime error
File size: 3,992 Bytes
a403fa8 945e7fe a403fa8 945e7fe 3581479 b014abf 3581479 0af08df 3581479 4cb5bf4 3581479 c0357a8 3581479 00e41f6 3581479 00e41f6 3581479 00e41f6 3581479 00e41f6 3581479 00e41f6 3581479 00e41f6 909e3bf 00e41f6 909e3bf 00e41f6 909e3bf da320a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import json
import re
import subprocess
import sys

# Best-effort runtime install of the third-party dependencies. This must run
# before the third-party imports below. Invoking pip as a module of the
# running interpreter guarantees the packages land in the environment this
# script actually imports from (a bare "pip" on PATH may belong to another
# Python installation).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "gradio", "torch", "transformers"]
)

import gradio as gr
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Process-wide cache for the loaded (model, tokenizer) pair so the expensive
# weight loading happens once instead of on every request.
_MODEL_CACHE = {}


def _load_model_and_tokenizer():
    """Return the fine-tuned GPT-2 model and its tokenizer, loading on first use."""
    if "loaded" not in _MODEL_CACHE:
        # Load the fine-tuned weights and config from the working directory
        model = transformers.GPT2LMHeadModel.from_pretrained(
            "./pytorch_model.bin", config="./config.json"
        ).to(device)
        # The tokenizer is the stock GPT-2 tokenizer
        tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        # Store both under a single key so a concurrent caller never observes
        # a half-populated cache.
        _MODEL_CACHE["loaded"] = (model, tokenizer)
    return _MODEL_CACHE["loaded"]


# Define a function for generating text based on a prompt using the fine-tuned GPT-2 model and the tokenizer
def generate_text(prompt, length=100, theme=None, **kwargs):
    """Generate text from ``prompt`` with the fine-tuned GPT-2 model.

    Args:
        prompt: Seed text the model continues.
        length: Value passed to ``model.generate`` as ``max_length``.
            Fractional values (e.g. from a UI slider) are truncated to an
            integer; ``None`` or non-positive values fall back to 100.
        theme: Optional theme. When non-blank it is prepended to the prompt
            as ``<theme> ``.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The decoded generation after post-processing: angle-bracket tags,
        newlines and punctuation are removed, and the result is passed
        through ``str.capitalize`` (first character upper-cased, the rest
        lower-cased).
    """
    model, tokenizer = _load_model_and_tokenizer()

    # If a theme is specified, add it to the prompt as a prefix for a special token.
    # A whitespace-only theme is treated as "no theme" to avoid an empty "<>" tag.
    theme = theme.strip() if theme else ""
    if theme:
        prompt = f"<{theme}> {prompt.strip()}"

    # Encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
    # GPT-2 has no dedicated pad token; reuse end-of-sequence for padding
    pad_token_id = tokenizer.eos_token_id

    # Set the max length of the generated text based on the input parameter.
    # max_length must be an integer, but UI sliders can deliver floats.
    max_length = int(length) if length else 0
    if max_length <= 0:
        max_length = 100

    # Generate the text using the model
    sample_outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
        do_sample=True,
        max_length=max_length,
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        repetition_penalty=1.5,
    )

    # Decode the generated text
    generated_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)

    # Postprocessing of the generated text
    generated_text = generated_text.strip().strip('"')  # Remove leading and trailing whitespace, remove any leading and trailing quotation marks
    generated_text = re.sub(r'<([^>]+)>', '', generated_text)  # Find the special token in the generated text and remove it
    generated_text = re.sub(r'^\d+|^"', '', generated_text)  # Remove any leading numeric characters and quotation marks
    generated_text = generated_text.replace('\n', '')  # Remove any newline characters from the generated text
    generated_text = re.sub(r'[^\w\s]+', '', generated_text)  # Remove any other unwanted special characters
    # NOTE: capitalize() also lower-cases everything after the first character
    generated_text = generated_text.capitalize()
    return generated_text
# Define a Gradio interface for the generate_text function.
# Inputs map positionally onto generate_text(prompt, length, theme).
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        "text",
        # step=1 keeps the slider on whole numbers, since the value is used
        # as a token count.
        gr.Slider(minimum=10, maximum=100, value=50, step=1, label='Length of text'),
        gr.Textbox(value='Food', label='Theme'),
    ],
    outputs=[gr.Textbox(label='Generated Text')],
    title='Yelp Review Generator',
    description='Generate a Yelp review based on a prompt, length of text, and theme.',
    examples=[
        ['I had a great experience at this restaurant.', 50, 'Service'],
        ['The service was terrible and the food was cold.', 50, 'Atmosphere'],
        ['The food was delicious but the service was slow.', 50, 'Food'],
        ['The ambiance was amazing and the staff was friendly.', 75, 'Service'],
        ['The waitstaff was knowledgeable and attentive, but the noise level was a bit high.', 75, 'Atmosphere'],
        ['The menu had a good variety of options, but the portion sizes were a bit small for the price.', 75, 'Food'],
    ],
    allow_flagging="manual",
    # (button label, stored value) pairs. Labels are thumbs-up / thumbs-down,
    # written as escapes so they survive encoding round-trips.
    flagging_options=[("\U0001F44D", "positive"), ("\U0001F44E", "negative")],
)
iface.launch(debug=False)