ciover2024 commited on
Commit
9244e51
·
verified ·
1 Parent(s): 8de7092

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ import numpy as np
6
+ import cv2
7
+ import os
8
+ import shutil
9
+ from gradio_client import Client, handle_file
10
+
11
+ # Load the model once globally to avoid repeated loading
12
+ def load_inpainting_model():
13
+ model_path = "uberRealisticPornMerge_v23Inpainting.safetensors"
14
+ device = "cpu" # Explicitly use CPU
15
+ pipe = StableDiffusionInpaintPipeline.from_single_file(
16
+ model_path,
17
+ torch_dtype=torch.float32, # Use float32 for CPU
18
+ safety_checker=None
19
+ ).to(device)
20
+ return pipe
21
+
22
+ # Preload the model once
23
+ inpaint_pipeline = load_inpainting_model()
24
+
25
+ # Function to resize image (simpler interpolation method for speed)
26
+ def resize_to_match(input_image, output_image):
27
+ return output_image.resize(input_image.size, Image.BILINEAR) # Use BILINEAR for faster resizing
28
+
29
+ # Function to generate the mask using Florence SAM Masking API (Replicate)
30
+ def generate_mask(image_path, text_prompt="clothing"):
31
+ client_sam = Client("SkalskiP/florence-sam-masking")
32
+ mask_result = client_sam.predict(
33
+ image_input=handle_file(image_path), # Provide your image path here
34
+ text_input=text_prompt, # Use "clothing" as the prompt
35
+ api_name="/process_image"
36
+ )
37
+ return mask_result # This is the local path to the generated mask
38
+
39
+ # Save the generated mask
40
+ def save_mask(mask_local_path, save_path="generated_mask.png"):
41
+ try:
42
+ shutil.copy(mask_local_path, save_path)
43
+ except Exception as e:
44
+ print(f"Failed to save the mask: {e}")
45
+
46
+ # Function to perform inpainting
47
+ def inpaint_image(input_image, mask_image):
48
+ prompt = "undress, naked"
49
+ result = inpaint_pipeline(prompt=prompt, image=input_image, mask_image=mask_image)
50
+ inpainted_image = result.images[0]
51
+ inpainted_image = resize_to_match(input_image, inpainted_image)
52
+ return inpainted_image
53
+
54
+ # Function to process input image and mask
55
+ def process_image(input_image):
56
+ # Save the input image temporarily to process with Replicate
57
+ input_image_path = "temp_input_image.png"
58
+ input_image.save(input_image_path)
59
+
60
+ # Generate the mask using Florence SAM API
61
+ mask_local_path = generate_mask(image_path=input_image_path)
62
+
63
+ # Save the generated mask
64
+ mask_image_path = "generated_mask.png"
65
+ save_mask(mask_local_path, save_path=mask_image_path)
66
+
67
+ # Open the mask image and perform inpainting
68
+ mask_image = Image.open(mask_image_path)
69
+ result_image = inpaint_image(input_image, mask_image)
70
+
71
+ # Clean up temporary files
72
+ os.remove(input_image_path)
73
+ os.remove(mask_image_path)
74
+
75
+ return result_image
76
+
77
+ # Define Gradio interface using Blocks API
78
+ with gr.Blocks() as demo:
79
+ with gr.Row():
80
+ input_image = gr.Image(label="Upload Input Image", type="pil")
81
+ output_image = gr.Image(type="pil", label="Output Image")
82
+
83
+ # Button to trigger the process
84
+ with gr.Row():
85
+ btn = gr.Button("Run Inpainting")
86
+
87
+ # Function to run when button is clicked
88
+ btn.click(fn=process_image, inputs=[input_image], outputs=output_image)
89
+
90
+ # Launch the Gradio app
91
+ demo.launch(share=True)