import subprocess

# Install dependencies at runtime (huggingface_hub is needed below to fetch the
# fine-tuned weights from the Space).
subprocess.run(["pip", "install", "gradio", "torch", "transformers", "huggingface_hub"])
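# Note: on Hugging Face Spaces, listing these packages in requirements.txt is
# the more reliable way to declare dependencies than calling pip at runtime.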
import re

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the fine-tuned model once at startup rather than on every request.
# `from_pretrained` expects a Hub repo id or a local directory, not a blob URL,
# so first download the Space's files (pytorch_model.bin, config.json,
# generation_config.json) into a local snapshot.
model_dir = snapshot_download("sailormars18/Yelp-reviews-usingGPT2", repo_type="space")
model = GPT2LMHeadModel.from_pretrained(model_dir).to(device)

# The Space ships only the weights and model configs, so this assumes the
# fine-tune reuses the stock GPT-2 vocabulary for its tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Generate text from a prompt using the fine-tuned GPT-2 model and tokenizer.
def generate_text(prompt, length=100, theme=None, **kwargs):
    # If a theme is specified, prepend it to the prompt as a special token.
    if theme:
        prompt = f"<{theme.strip()}> {prompt.strip()}"
    # Encode the input prompt.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
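    # GPT-2 defines no dedicated pad token, so reuse the EOS token for padding.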
    pad_token_id = tokenizer.eos_token_id
    # Cap the generated length; fall back to 100 tokens if the requested length
    # is not positive. Cast to int since the Gradio slider may pass a float.
    max_length = int(length) if length > 0 else 100
    # Generate the text using the model.
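    # top_k/top_p restrict sampling to the most likely tokens, temperature < 1
    # sharpens the distribution, and no_repeat_ngram_size plus
    # repetition_penalty discourage repetitive output.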
    sample_outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
        do_sample=True,
        max_length=max_length,
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        repetition_penalty=1.5,
    )
    # Decode the generated tokens back to text.
    generated_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Post-process the generated text.
    generated_text = generated_text.strip().strip('"')          # drop surrounding whitespace and quotation marks
    generated_text = re.sub(r'<([^>]+)>', '', generated_text)   # strip the <theme> special token
    generated_text = re.sub(r'^\d+|^"', '', generated_text)     # drop any leading digits or quotation mark
    generated_text = generated_text.replace('\n', ' ')          # replace newlines with spaces so words don't run together
    generated_text = re.sub(r'[^\w\s]+', '', generated_text)    # strip any remaining punctuation
    # Note: str.capitalize() also lowercases everything after the first character.
    generated_text = generated_text.capitalize()
    return generated_text
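# Example call (hypothetical prompt and theme):
#   generate_text("The pizza was amazing", length=60, theme="Food")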
# Define a Gradio interface for the generate_text function.
# (The gr.inputs/gr.outputs namespaces are deprecated; components are now
# top-level classes such as gr.Slider and take `value` instead of `default`.)
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label='Prompt'),
        gr.Slider(minimum=10, maximum=100, value=50, label='Length of text'),
        gr.Textbox(value='Food', label='Theme'),
    ],
    outputs=[gr.Textbox(label='Generated Text')],
    title='Yelp Review Generator',
    description='Generate a Yelp review based on a prompt, length of text, and theme.',
    examples=[
        ['I had a great experience at this restaurant.', 50, 'Service'],
        ['The service was terrible and the food was cold.', 50, 'Atmosphere'],
        ['The food was delicious but the service was slow.', 50, 'Food'],
        ['The ambiance was amazing and the staff was friendly.', 75, 'Service'],
        ['The waitstaff was knowledgeable and attentive, but the noise level was a bit high.', 75, 'Atmosphere'],
        ['The menu had a good variety of options, but the portion sizes were a bit small for the price.', 75, 'Food'],
    ],
    allow_flagging="manual",
    flagging_options=["positive", "negative"],
)
# share=True creates a temporary public link when run locally; a deployed
# Space is already hosted publicly.
iface.launch(debug=False, share=True)