# Install runtime dependencies from inside the script, as the Space expects
import subprocess
subprocess.run(["pip", "install", "gradio", "torch", "transformers"], check=True)
import re

import gradio as gr
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a function for generating text based on a prompt using the fine-tuned GPT-2 model and the tokenizer
def generate_text(prompt, length=100, theme=None, **kwargs):

    # The fine-tuned weights and configuration live alongside this script in the Space:
    # https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2
    # (pytorch_model.bin, config.json, generation_config.json)

    # Load the fine-tuned model from the local files in the Space repository
    model = GPT2LMHeadModel.from_pretrained(".", config="./config.json").to(device)
    model.eval()  # inference only; disables dropout
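
    # (Sketch, not executed here: if this script were run outside the Space, the same
    #  files could be fetched first with huggingface_hub -- the repo_id below is the
    #  Space named above; everything else is plain hf_hub_download usage:
    #      from huggingface_hub import hf_hub_download
    #      weights_path = hf_hub_download(
    #          repo_id="sailormars18/Yelp-reviews-usingGPT2",
    #          repo_type="space",
    #          filename="pytorch_model.bin",
    #      )
    #  which returns a local path to the downloaded checkpoint.)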
    
    # Load the base GPT-2 tokenizer from the Hub (the fine-tuned model uses the same vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # If a theme is specified, add it to the prompt as a prefix for a special token
    if theme:
        prompt = f"<{theme.strip()}> {prompt.strip()}"

    # Encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
    pad_token_id = tokenizer.eos_token_id

    # Set the max length of the generated text; fall back to 100 and coerce to int
    # (the Gradio slider may pass a float)
    max_length = int(length) if length > 0 else 100

    # Generate the text using the model
    sample_outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
        do_sample=True,
        max_length=max_length,
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        repetition_penalty=1.5,
    )
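
    # (Sketch: these sampling settings could alternatively be read from the Space's
    #  generation_config.json, assuming a transformers release with GenerationConfig:
    #      gen_cfg = transformers.GenerationConfig.from_pretrained(".")
    #      sample_outputs = model.generate(input_ids, attention_mask=attention_mask,
    #                                      generation_config=gen_cfg)
    #  The hard-coded values above are kept to preserve the original behaviour.)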

    # Decode the generated text
    generated_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)

    # Postprocessing of the generated text
    generated_text = generated_text.strip().strip('"') # Remove leading and trailing whitespace, remove any leading and trailing quotation marks
    generated_text = re.sub(r'<([^>]+)>', '', generated_text) # Find the special token in the generated text and remove it
    generated_text = re.sub(r'^\d+|^"', '', generated_text) # Remove any leading numeric characters and quotation marks
    generated_text = generated_text.replace('\n', '') # Remove any newline characters from the generated text
    generated_text = re.sub(r'[^\w\s]+', '', generated_text) # Remove any other unwanted special characters
    generated_text = generated_text.capitalize()

    return generated_text
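
# Quick way to exercise the generator without the Gradio UI (a sketch; the prompt,
# length, and theme below are illustrative examples, not part of the deployed Space):
#
#     print(generate_text("I had a great experience at this restaurant.",
#                         length=50, theme="Food"))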

# Define a Gradio interface for the generate_text function
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        "text", 
        gr.Slider(minimum=10, maximum=100, value=50, label='Length of text'),
        gr.Textbox(value='Food', label='Theme')
    ],
    outputs=[gr.Textbox(label='Generated Text')],
    title='Yelp Review Generator',
    description='Generate a Yelp review based on a prompt, length of text, and theme.',
    examples=[
        ['I had a great experience at this restaurant.', 50, 'Service'],
        ['The service was terrible and the food was cold.', 50, 'Atmosphere'],
        ['The food was delicious but the service was slow.', 50, 'Food'],
        ['The ambiance was amazing and the staff was friendly.', 75, 'Service'],
        ['The waitstaff was knowledgeable and attentive, but the noise level was a bit high.', 75, 'Atmosphere'],
        ['The menu had a good variety of options, but the portion sizes were a bit small for the price.', 75, 'Food']
    ],
    allow_flagging="manual",
    flagging_options=[("🙌", "positive"), ("😞", "negative")],
)

iface.launch(debug=False)
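
# When run locally rather than inside the Space, Gradio can also expose a temporary
# public link (illustrative alternative, not used by the Space itself):
#     iface.launch(share=True)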