import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the inpainting model
try:
    inpaint_model = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1").to(device)
except Exception as e:
    print(f"Error initializing model: {e}")
    inpaint_model = None
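
# Note: this repo id loads an SDXL inpainting pipeline, which is called with
# `prompt`, `image`, and `mask_image` arguments; a hedged sketch of that call
# is included (commented out) inside process_image below.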

# Load a test image
def load_test_image():
    try:
        # Provide the absolute path to a test image
        return Image.open("/absolute/path/to/your/test_image.png")
    except Exception as e:
        print(f"Error loading test image: {e}")
        return None

def process_image(prompt, image, style, upscale_factor, inpaint):
    if image is None:
        image = load_test_image()
        if image is None:
            return None, "No image received and failed to load test image."
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image)
        elif not isinstance(image, Image.Image):
            return None, f"Unsupported image format: {type(image)}."
        print(f"Received image: {image.size}")
        # Placeholder for processing logic
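        # A minimal sketch (commented out) of what the real inpainting step could
        # look like. It assumes a mask image that this demo does not collect yet,
        # so `mask_image` below is a hypothetical input (white = area to repaint):
        #
        # if inpaint and inpaint_model is not None:
        #     image = inpaint_model(
        #         prompt=prompt,
        #         image=image.convert("RGB"),
        #         mask_image=mask_image.convert("RGB"),
        #     ).images[0]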
        return image, None
    except Exception as e:
        error_message = f"Error in process_image function: {e}"
        print(error_message)
        return None, error_message

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter your prompt")
            image_input = gr.Image(label="Image (for inpainting)", type="pil")
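            # Note: inpainting also needs a mask; depending on the installed Gradio
            # version, an image-editor / mask component could be added here so the
            # user can paint the region to repaint (not part of the original demo).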
            style_input = gr.Dropdown(choices=["Fooocus Style", "SAI Anime"], label="Select Style")
            upscale_input = gr.Slider(minimum=1, maximum=4, step=1, label="Upscale Factor")
            inpaint_input = gr.Checkbox(label="Enable Inpainting")
            output_image = gr.Image(label="Generated Image", type="pil")
            error_output = gr.Textbox(label="Error Details", lines=4, placeholder="Error details will appear here")
            generate_button = gr.Button("Generate Image")

    generate_button.click(
        process_image,
        inputs=[prompt_input, image_input, style_input, upscale_input, inpaint_input],
        outputs=[output_image, error_output]
    )

demo.launch()