File size: 809 Bytes
c09bcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import os
from diffusers import StableDiffusionPipeline
import torch

def generate_images(prompt, model_path="output/checkpoint-500/", num_images=1):
    """Generate images for *prompt* with a Stable Diffusion pipeline loaded from disk.

    Args:
        prompt: Text prompt passed to the pipeline.
        model_path: Local directory loaded with
            ``StableDiffusionPipeline.from_pretrained``.
        num_images: Number of images to generate for the prompt.

    Returns:
        The pipeline output's ``.images`` (a list of PIL images under
        diffusers' default ``output_type="pil"``).

    Raises:
        EnvironmentError: If *model_path* contains none of the recognised
            model files. This is checked before the costly pipeline load.
    """
    # A directory written by ``DiffusionPipeline.save_pretrained`` has
    # ``model_index.json`` at its root and keeps component weights in
    # subfolders (unet/, vae/, ...), so that file alone must be accepted.
    # The root-level weight names are kept so previously accepted
    # directories still pass.
    required_files = (
        "model_index.json",
        "pytorch_model.bin",
        "model.safetensors",
        "tf_model.h5",
        "model.ckpt.index",
        "flax_model.msgpack",
    )
    if not any(
        os.path.exists(os.path.join(model_path, name)) for name in required_files
    ):
        raise EnvironmentError(
            f"Error no file named {', '.join(required_files)} found in directory {model_path}. "
            "Ensure your model is correctly saved."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is only reliable on GPU. Many CPU kernels have no Half
    # implementation, so fall back to float32 when running on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
    pipe = pipe.to(device)

    # The pipeline's per-prompt batch argument is ``num_images_per_prompt``.
    # ``num_images`` is not a pipeline parameter, so the requested count
    # would otherwise never be applied.
    return pipe(prompt, num_images_per_prompt=num_images).images