Sourudra commited on
Commit
ce5bf23
·
verified ·
1 Parent(s): 932864a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -36
app.py CHANGED
@@ -1,46 +1,85 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- # CSS for aligning the UI elements in the middle
4
- css = """
5
- #col-container {
6
- margin: 0 auto;
7
- max-width: 640px;
8
- text-align: center;
9
- }
10
- """
11
-
12
- # Load the ZB-Tech/Text-to-Image model
13
- def infer(prompt):
14
- model = gr.Interface.load("models/ZB-Tech/Text-to-Image")
15
- return model(prompt)
16
-
17
- # Gradio app
18
- with gr.Blocks(css=css) as demo:
19
- with gr.Column(elem_id="col-container"):
20
- # Title and description
21
- gr.Markdown(
22
- """
23
- # AI-Powered Text-to-Image Generator
24
- *Generate stunning images from text prompts using the ZB-Tech/Text-to-Image model.*
25
- """
26
- )
 
 
 
 
 
 
 
 
27
 
28
- # Input: Prompt
29
- prompt = gr.Textbox(
30
- label="Prompt",
31
- placeholder="Enter your prompt here...",
32
- lines=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
 
35
- # Output: Generated Image
36
- result = gr.Image(label="Generated Image", type="pil")
 
 
 
 
 
 
 
 
 
37
 
38
- # Generate Button
39
- generate_button = gr.Button("Generate")
 
 
40
 
41
- # Button action: Call infer function
42
- generate_button.click(infer, inputs=[prompt], outputs=result)
43
 
44
- # Launch the app
45
  if __name__ == "__main__":
46
  demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ from diffusers import DiffusionPipeline
6
 
7
+ # Check for GPU availability
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
+
11
+ # Load your DiffusionPipeline model
12
+ model_repo_id = "stabilityai/sdxl-turbo"
13
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
14
+ pipe = pipe.to(device)
15
+
16
+ MAX_SEED = np.iinfo(np.int32).max
17
+ MAX_IMAGE_SIZE = 1024
18
+
19
+ # Define the custom model inference function
20
+ def custom_infer(
21
+ prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
22
+ ):
23
+ if randomize_seed:
24
+ seed = random.randint(0, MAX_SEED)
25
+
26
+ generator = torch.Generator().manual_seed(seed)
27
+
28
+ image = pipe(
29
+ prompt=prompt,
30
+ negative_prompt=negative_prompt,
31
+ guidance_scale=guidance_scale,
32
+ num_inference_steps=num_inference_steps,
33
+ width=width,
34
+ height=height,
35
+ generator=generator,
36
+ ).images[0]
37
+
38
+ return image, seed
39
 
40
+
41
+ # Gradio interface for custom model
42
+ def custom_model_ui():
43
+ with gr.Blocks() as custom_demo:
44
+ gr.Markdown("## Custom Model: Stability AI SDXL")
45
+ with gr.Row():
46
+ prompt = gr.Text(label="Prompt")
47
+ run_button = gr.Button("Generate")
48
+
49
+ result = gr.Image(label="Generated Image")
50
+ negative_prompt = gr.Text(label="Negative Prompt", placeholder="Optional")
51
+ seed = gr.Slider(0, MAX_SEED, label="Seed", step=1, value=0)
52
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
53
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=1024, label="Width")
54
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=1024, label="Height")
55
+ guidance_scale = gr.Slider(0, 10, step=0.1, value=7.5, label="Guidance Scale")
56
+ num_inference_steps = gr.Slider(1, 50, step=1, value=30, label="Inference Steps")
57
+
58
+ run_button.click(
59
+ custom_infer,
60
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
61
+ outputs=[result, seed],
62
  )
63
 
64
+ return custom_demo
65
+
66
+
67
+ # Preloaded Gradio model
68
+ def preloaded_model_ui():
69
+ with gr.Blocks() as preloaded_demo:
70
+ gr.Markdown("## Preloaded Model: ZB-Tech Text-to-Image")
71
+ preloaded_demo = gr.load("models/ZB-Tech/Text-to-Image")
72
+
73
+ return preloaded_demo
74
+
75
 
76
+ # Combine both interfaces in tabs
77
+ with gr.Blocks() as demo:
78
+ with gr.Tab("Custom Model"):
79
+ custom_ui = custom_model_ui()
80
 
81
+ with gr.Tab("Preloaded Model"):
82
+ preloaded_ui = preloaded_model_ui()
83
 
 
84
  if __name__ == "__main__":
85
  demo.launch()