sailormars18 committed on
Commit
00e41f6
·
1 Parent(s): fd94e4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -28
app.py CHANGED
@@ -25,8 +25,9 @@ def generate_text(prompt, length=100, theme=None, **kwargs):
25
 
26
  # If a theme is specified, add it to the prompt as a prefix for a special token
27
  if theme:
28
- prompt = ' <{}> '.format(theme.strip()) + prompt.strip()
29
 
 
30
  input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
31
  attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
32
  pad_token_id = tokenizer.eos_token_id
@@ -34,6 +35,7 @@ def generate_text(prompt, length=100, theme=None, **kwargs):
34
  # Set the max length of the generated text based on the input parameter
35
  max_length = length if length > 0 else 100
36
 
 
37
  sample_outputs = model.generate(
38
  input_ids,
39
  attention_mask=attention_mask,
@@ -47,38 +49,28 @@ def generate_text(prompt, length=100, theme=None, **kwargs):
47
  no_repeat_ngram_size=2,
48
  repetition_penalty=1.5,
49
  )
50
- generated_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
51
-
52
- # Post preprocessing of the generated text
53
-
54
- # Remove any leading and trailing quotation marks
55
- generated_text = generated_text.strip('"')
56
-
57
- # Remove leading and trailing whitespace
58
- generated_text = generated_text.strip()
59
 
60
- # Find the special token in the generated text and remove it
61
- match = re.search(r'<([^>]+)>', generated_text)
62
- if match:
63
- generated_text = generated_text[:match.start()] + generated_text[match.end():]
64
-
65
- # Remove any leading numeric characters and quotation marks
66
- generated_text = re.sub(r'^\d+', '', generated_text)
67
- generated_text = re.sub(r'^"', '', generated_text)
68
 
69
- # Remove any newline characters from the generated text
70
- generated_text = generated_text.replace('\n', '')
 
 
 
 
 
71
 
72
- # Remove any other unwanted special characters
73
- generated_text = re.sub(r'[^\w\s]+', '', generated_text)
74
-
75
- return generated_text.strip().capitalize()
76
 
77
- # Define a Gradio interface for the generate_text function, allowing users to input a prompt and generate text based on it
78
  iface = gr.Interface(
79
  fn=generate_text,
80
- inputs=['text', gr.inputs.Slider(minimum=10, maximum=100, default=50, label='Length of text'),
81
- gr.inputs.Textbox(default='Food', label='Theme')],
 
 
 
82
  outputs=[gr.outputs.Textbox(label='Generated Text')],
83
  title='Yelp Review Generator',
84
  description='Generate a Yelp review based on a prompt, length of text, and theme.',
@@ -94,4 +86,4 @@ iface = gr.Interface(
94
  flagging_options=[("πŸ™Œ", "positive"), ("😞", "negative")],
95
  )
96
 
97
- iface.launch(debug=False)
 
25
 
26
  # If a theme is specified, add it to the prompt as a prefix for a special token
27
  if theme:
28
+ prompt = f"<{theme.strip()}> {prompt.strip()}"
29
 
30
+ # Encode the input prompt
31
  input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
32
  attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
33
  pad_token_id = tokenizer.eos_token_id
 
35
  # Set the max length of the generated text based on the input parameter
36
  max_length = length if length > 0 else 100
37
 
38
+ # Generate the text using the model
39
  sample_outputs = model.generate(
40
  input_ids,
41
  attention_mask=attention_mask,
 
49
  no_repeat_ngram_size=2,
50
  repetition_penalty=1.5,
51
  )
 
 
 
 
 
 
 
 
 
52
 
53
+ # Decode the generated text
54
+ generated_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
55
 
56
+ # Postprocessing of the generated text
57
+ generated_text = generated_text.strip().strip('"') # Remove leading and trailing whitespace, remove any leading and trailing quotation marks
58
+ generated_text = re.sub(r'<([^>]+)>', '', generated_text) # Find the special token in the generated text and remove it
59
+ generated_text = re.sub(r'^\d+|^"', '', generated_text) # Remove any leading numeric characters and quotation marks
60
+ generated_text = generated_text.replace('\n', '') # Remove any newline characters from the generated text
61
+ generated_text = re.sub(r'[^\w\s]+', '', generated_text) # Remove any other unwanted special characters
62
+ generated_text = generated_text.capitalize()
63
 
64
+ return generated_text
 
 
 
65
 
66
+ # Define a Gradio interface for the generate_text function
67
  iface = gr.Interface(
68
  fn=generate_text,
69
+ inputs=[
70
+ "text",
71
+ gr.inputs.Slider(minimum=10, maximum=100, default=50, label='Length of text'),
72
+ gr.inputs.Textbox(default='Food', label='Theme')
73
+ ],
74
  outputs=[gr.outputs.Textbox(label='Generated Text')],
75
  title='Yelp Review Generator',
76
  description='Generate a Yelp review based on a prompt, length of text, and theme.',
 
86
  flagging_options=[("πŸ™Œ", "positive"), ("😞", "negative")],
87
  )
88
 
89
+ iface.launch(debug=False, share=True)