clockclock commited on
Commit
55d13a0
·
verified ·
1 Parent(s): b01886f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import AutoPipelineForInpainting
6
+ from PIL import Image
7
+
8
+ # --- Model Loading ---
9
+ # Load the model only once at the start of the application
10
+ # We use float16 for memory efficiency and speed on GPUs
11
+ # If no GPU is available, this will run on CPU (but it will be very slow)
12
+ try:
13
+ pipe = AutoPipelineForInpainting.from_pretrained(
14
+ "stabilityai/stable-diffusion-2-inpainting",
15
+ torch_dtype=torch.float16,
16
+ variant="fp16"
17
+ ).to("cuda")
18
+ except Exception as e:
19
+ print(f"Could not load model on GPU: {e}. Falling back to CPU.")
20
+ pipe = AutoPipelineForInpainting.from_pretrained(
21
+ "stabilityai/stable-diffusion-2-inpainting"
22
+ )
23
+
24
+ # --- The Inpainting Function ---
25
+ # This is the core function that takes user inputs and generates the image
26
+ def inpaint_image(input_dict, prompt, negative_prompt, guidance_scale, num_steps):
27
+ """
28
+ Performs inpainting on an image based on a mask and a prompt.
29
+
30
+ Args:
31
+ input_dict (dict): A dictionary from Gradio's Image component containing 'image' and 'mask'.
32
+ prompt (str): The text prompt describing what to generate in the masked area.
33
+ negative_prompt (str): The text prompt describing what to avoid.
34
+ guidance_scale (float): A value to control how much the generation follows the prompt.
35
+ num_steps (int): The number of inference steps.
36
+
37
+ Returns:
38
+ PIL.Image: The resulting image after inpainting.
39
+ """
40
+ # Separate the image and the mask from the input dictionary
41
+ image = input_dict["image"].convert("RGB")
42
+ mask_image = input_dict["mask"].convert("RGB")
43
+
44
+ # The model works best with images of a specific size (e.g., 512x512)
45
+ # We can resize for consistency, but for user-friendliness, we'll let the pipeline handle it.
46
+ # However, it's good practice to inform the user that square images work best.
47
+
48
+ print(f"Starting inpainting with prompt: '{prompt}'")
49
+
50
+ # Run the inpainting pipeline
51
+ result_image = pipe(
52
+ prompt=prompt,
53
+ image=image,
54
+ mask_image=mask_image,
55
+ negative_prompt=negative_prompt,
56
+ guidance_scale=guidance_scale,
57
+ num_inference_steps=int(num_steps),
58
+ ).images[0]
59
+
60
+ return result_image
61
+
62
+ # --- Gradio User Interface ---
63
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
64
+ gr.Markdown(
65
+ """
66
+ # 🎨 AI Image Fixer (Inpainting)
67
+
68
+ Have an AI-generated image with weird hands, faces, or artifacts? Fix it here!
69
+
70
+ **How to use:**
71
+ 1. Upload your image.
72
+ 2. Use the brush tool to "paint" over the parts you want to replace. This is your mask.
73
+ 3. Write a prompt describing what you want in the painted-over area.
74
+ 4. Adjust the advanced settings if you want more control.
75
+ 5. Click "Fix It!" and see the magic happen.
76
+ """
77
+ )
78
+
79
+ with gr.Row():
80
+ # Input column
81
+ with gr.Column():
82
+ gr.Markdown("### 1. Upload & Mask Your Image")
83
+ # The Image component with a drawing tool for masking
84
+ input_image = gr.Image(
85
+ label="Upload Image & Draw Mask",
86
+ source="upload",
87
+ tool="brush",
88
+ type="pil" # We want to work with PIL images in our function
89
+ )
90
+
91
+ gr.Markdown("### 2. Describe Your Fix")
92
+ prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'A beautiful, realistic human hand, detailed fingers'")
93
+
94
+ # Accordion for advanced settings to keep the UI clean
95
+ with gr.Accordion("Advanced Settings", open=False):
96
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., 'blurry, distorted, extra fingers, cartoon'")
97
+ guidance_scale = gr.Slider(minimum=0, maximum=20, value=8.0, label="Guidance Scale")
98
+ num_steps = gr.Slider(minimum=10, maximum=100, step=1, value=40, label="Inference Steps")
99
+
100
+ # Output column
101
+ with gr.Column():
102
+ gr.Markdown("### 3. Get Your Result")
103
+ output_image = gr.Image(
104
+ label="Resulting Image",
105
+ type="pil"
106
+ )
107
+
108
+ # The button to trigger the process
109
+ submit_button = gr.Button("Fix It!", variant="primary")
110
+
111
+ # Connect the button to the function
112
+ submit_button.click(
113
+ fn=inpaint_image,
114
+ inputs=[input_image, prompt, negative_prompt, guidance_scale, num_steps],
115
+ outputs=output_image
116
+ )
117
+
118
+ # Launch the Gradio app
119
+ if __name__ == "__main__":
120
+ demo.launch()