Faizan Azizahmed Shaikh commited on
Commit
e7bd0c2
·
1 Parent(s): 94f83b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -2,7 +2,7 @@
2
  # coding: utf-8
3
 
4
  # importing required libraries
5
- from transformers import pipeline
6
  from torch import bfloat16
7
  import gradio as gr
8
  WARNING = """Whoooa there, partner! Before you dive in, let's establish some ground rules:\nBy using this application, you are stating that you are the 'Big Cheese', the 'Head Honcho', the 'Master of Your Domain', in short, the sole user of this app. Now, don't go blaming us or any other parties if the results are not to your liking, or lead to any unforeseen circumstances.\nIn the simplest terms, the moment you input any data on this page you accept full responsibility for any and all usage of this application. Just like when you eat that extra slice of pizza at midnight, you're the one who's responsible for the extra workout the next day, not the pizza guy!"""
@@ -15,7 +15,7 @@ def story(prompt="When I was young", model_name = "coffeeee/nsfw-story-generator
15
  story_length: number of maximum tokens to generate, function_default: 50, modified_default: 300;
16
  """
17
  # create a pipeline for the model
18
- create = pipeline(model=model_name, torch_dtype=bfloat16, device_map="auto")
19
  # return the output from the model
20
  return create(prompt, max_new_tokens=story_length)[0]['generated_text']
21
 
@@ -28,5 +28,5 @@ with gr.Blocks() as demo:
28
  story_len = gr.Slider(100,500, label="Arc length")
29
  gen_story = gr.Textbox(label="Story", lines=15, max_lines=20)
30
  greet_btn = gr.Button("Entertain")
31
- greet_btn.click(fn=story, inputs=[story_start, selected_model, story_len], outputs=gen_story, api_name="story")
32
  demo.launch(inline=False, share=False)
 
2
  # coding: utf-8
3
 
4
  # importing required libraries
5
+ from transformers import pipeline, GPT2TokenizerFast
6
  from torch import bfloat16
7
  import gradio as gr
8
  WARNING = """Whoooa there, partner! Before you dive in, let's establish some ground rules:\nBy using this application, you are stating that you are the 'Big Cheese', the 'Head Honcho', the 'Master of Your Domain', in short, the sole user of this app. Now, don't go blaming us or any other parties if the results are not to your liking, or lead to any unforeseen circumstances.\nIn the simplest terms, the moment you input any data on this page you accept full responsibility for any and all usage of this application. Just like when you eat that extra slice of pizza at midnight, you're the one who's responsible for the extra workout the next day, not the pizza guy!"""
 
15
  story_length: number of maximum tokens to generate, function_default: 50, modified_default: 300;
16
  """
17
  # create a pipeline for the model
18
+ create = pipeline(model=model_name, torch_dtype=bfloat16, device_map="auto", pad_token_id=GPT2TokenizerFast.from_pretrained("gpt2").eos_token_id)
19
  # return the output from the model
20
  return create(prompt, max_new_tokens=story_length)[0]['generated_text']
21
 
 
28
  story_len = gr.Slider(100,500, label="Arc length")
29
  gen_story = gr.Textbox(label="Story", lines=15, max_lines=20)
30
  greet_btn = gr.Button("Entertain")
31
+ greet_btn.click(fn=story, inputs=[story_start, selected_model, story_len], outputs=gen_story)
32
  demo.launch(inline=False, share=False)