# Captured from the Hugging Face Space page; the Space status was shown as
# "Runtime error".
from diffusers import DDPMPipeline
import torch
import numpy as np
import gradio as gr
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from PIL import Image
from transformers import logging
from transformers.utils import move_cache

# `transformers.logging` provides `get_logger` but no module-level
# `error`/`warning` functions, so all messages go through a named logger.
logger = logging.get_logger(__name__)

# Optional: handle potential cache migration for Transformers. This is best
# effort; a failure is logged and must not stop the app from starting.
try:
    move_cache()
except Exception as e:
    logger.error(f"Error migrating cache: {str(e)}")

# Check for the 'accelerate' library and suggest installing it if missing.
# The import exists only to probe availability; the name itself is unused here.
try:
    from accelerate import Accelerator  # noqa: F401
except ImportError:
    logger.warning("Accelerate library not found. It's recommended to install it for efficient model loading.")

# Load the pre-trained pipeline from Hugging Face.
MODEL_ID = "ahmetyaylalioglu/textile_diffusion_ddpm"
pipeline = DDPMPipeline.from_pretrained(MODEL_ID)

# Fall back to CPU so the app can still start on hardware without a GPU
# (sampling is much slower there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
pipeline.to(DEVICE)
pipeline.unet.eval()
def generate_images(seed, num_images):
    """Sample a batch of images from the DDPM pipeline and tile them into one grid.

    Args:
        seed: Seed for the sampling noise; any value ``int()`` accepts
            (the Gradio textbox passes it as a string).
        num_images: Number of images to sample; any value ``int()`` accepts.
            Must be at least 1.

    Returns:
        A single PIL image with the samples laid out 4 per row.

    Raises:
        gr.Error: If an input is not a whole number or ``num_images`` < 1,
            so Gradio shows a readable message instead of a stack trace.
    """
    try:
        seed = int(seed)
        num_images = int(num_images)
    except (TypeError, ValueError) as err:
        raise gr.Error("Seed and number of images must be whole numbers.") from err
    if num_images < 1:
        raise gr.Error("Number of images must be at least 1.")

    # A dedicated generator makes the output depend only on `seed`, rather
    # than on the process-wide RNG state that other code may also advance.
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)
    generated_images = pipeline(batch_size=num_images, generator=generator).images

    # ToTensor turns each PIL image into a float CHW tensor in [0, 1].
    # make_grid tiles them, and ToPILImage converts the grid back to 8-bit.
    image_grid = make_grid(
        [transforms.ToTensor()(img) for img in generated_images], nrow=4
    )
    return transforms.ToPILImage()(image_grid)
# Set up the Gradio interface.
interface = gr.Interface(
    fn=generate_images,
    inputs=[
        # A callable default is re-evaluated by Gradio on every page load.
        # A plain value would be computed once at startup and shared by all
        # visitors, which defeats the "Random Seed" label.
        gr.Textbox(
            label="Random Seed",
            placeholder="Enter a seed number",
            value=lambda: str(np.random.randint(0, 1000)),
        ),
        gr.Textbox(
            label="Number of Images",
            placeholder="Enter number of images to generate",
            value="8",
        ),
    ],
    outputs="image",
    title="Textile Diffusion DDPM",
    description="Generate textile images using a trained DDPM model from Hugging Face.",
)

# Launch the interface.
interface.launch()