Sourudra commited on
Commit
56971aa
·
verified ·
1 Parent(s): dcf1a60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -14
app.py CHANGED
@@ -6,7 +6,9 @@ from diffusers import DiffusionPipeline
6
  # Device setup
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_repo_id_turbo = "stabilityai/sdxl-turbo" # Stability AI Model
9
- pipe_turbo = DiffusionPipeline.from_pretrained(model_repo_id_turbo, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
 
 
10
 
11
  # Placeholder for ZB-Tech model
12
  def load_zb_model():
@@ -16,16 +18,17 @@ def load_zb_model():
16
  def custom_infer(
17
  model_choice, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
18
  ):
19
- # Load the selected model
20
  if model_choice == "Faster image generation (suitable for CPUs)":
 
21
  model = load_zb_model()
22
  return model(prompt)
23
  else:
 
24
  default_negative_prompt = "no watermark, hezzy, blurry"
25
  combined_negative_prompt = f"{default_negative_prompt}, {negative_prompt}" if negative_prompt else default_negative_prompt
26
 
27
  if randomize_seed:
28
- seed = random.randint(0, np.iinfo(np.int32).max)
29
 
30
  generator = torch.Generator().manual_seed(seed)
31
  image = pipe_turbo(
@@ -39,22 +42,19 @@ def custom_infer(
39
  ).images[0]
40
  return image, seed
41
 
42
- # CSS for centering UI
43
  css = """
44
  #col-container {
45
- display: flex;
46
- flex-direction: column;
47
- align-items: center;
48
- justify-content: center;
49
- text-align: center;
50
  margin: 0 auto;
 
 
51
  }
52
  """
53
 
54
  # Gradio app
55
  with gr.Blocks(css=css) as demo:
56
  with gr.Column(elem_id="col-container"):
57
- # App name and description
58
  gr.Markdown(
59
  """
60
  # AI-Powered Text-to-Image Generator
@@ -72,9 +72,11 @@ with gr.Blocks(css=css) as demo:
72
  value="Faster image generation (suitable for CPUs)",
73
  )
74
 
75
- # Input section
76
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
77
- with gr.Accordion("Advanced Settings", open=False):
 
 
78
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter a negative prompt here...")
79
  seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=0)
80
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -84,8 +86,21 @@ with gr.Blocks(css=css) as demo:
84
  num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=25)
85
 
86
  # Output section
87
- result = gr.Image(label="Generated Image", type="pil")
88
- gr.Button("Generate").click(
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  custom_infer,
90
  inputs=[model_choice, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
91
  outputs=result
 
6
  # Device setup
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_repo_id_turbo = "stabilityai/sdxl-turbo" # Stability AI Model
9
+ pipe_turbo = DiffusionPipeline.from_pretrained(
10
+ model_repo_id_turbo, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
11
+ ).to(device)
12
 
13
  # Placeholder for ZB-Tech model
14
  def load_zb_model():
 
18
  def custom_infer(
19
  model_choice, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
20
  ):
 
21
  if model_choice == "Faster image generation (suitable for CPUs)":
22
+ # Call ZB-Tech model for faster generation
23
  model = load_zb_model()
24
  return model(prompt)
25
  else:
26
+ # Use Stability AI's model with customizable options
27
  default_negative_prompt = "no watermark, hezzy, blurry"
28
  combined_negative_prompt = f"{default_negative_prompt}, {negative_prompt}" if negative_prompt else default_negative_prompt
29
 
30
  if randomize_seed:
31
+ seed = random.randint(0, 2147483647)
32
 
33
  generator = torch.Generator().manual_seed(seed)
34
  image = pipe_turbo(
 
42
  ).images[0]
43
  return image, seed
44
 
45
+ # CSS for aligning the UI
46
  css = """
47
  #col-container {
 
 
 
 
 
48
  margin: 0 auto;
49
+ max-width: 640px;
50
+ text-align: center;
51
  }
52
  """
53
 
54
  # Gradio app
55
  with gr.Blocks(css=css) as demo:
56
  with gr.Column(elem_id="col-container"):
57
+ # App title and description
58
  gr.Markdown(
59
  """
60
  # AI-Powered Text-to-Image Generator
 
72
  value="Faster image generation (suitable for CPUs)",
73
  )
74
 
75
+ # Input for the text prompt
76
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
77
+
78
+ # Advanced options (conditionally displayed)
79
+ with gr.Row(visible=False) as advanced_settings:
80
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter a negative prompt here...")
81
  seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=0)
82
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
86
  num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=25)
87
 
88
  # Output section
89
+ result = gr.Image(label="Generated Image", type="pil", elem_id="col-container")
90
+ generate_button = gr.Button("Generate")
91
+
92
+ # Event to toggle advanced options based on model selection
93
+ def toggle_advanced_options(model_choice):
94
+ return model_choice != "Faster image generation (suitable for CPUs)"
95
+
96
+ model_choice.change(
97
+ toggle_advanced_options,
98
+ inputs=[model_choice],
99
+ outputs=[advanced_settings]
100
+ )
101
+
102
+ # Generate button action
103
+ generate_button.click(
104
  custom_infer,
105
  inputs=[model_choice, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
106
  outputs=result