File size: 2,648 Bytes
c2a8649
6031dc7
 
 
 
1888310
6031dc7
 
9af81fd
c2a8649
6031dc7
9af81fd
1888310
 
 
6031dc7
1888310
 
 
 
 
 
6031dc7
 
234658e
6031dc7
1888310
 
6031dc7
 
c2a8649
9af81fd
 
 
 
 
 
 
 
6031dc7
 
1888310
9af81fd
 
c2a8649
 
1888310
 
6031dc7
c2a8649
 
 
6031dc7
1888310
6031dc7
 
1888310
c2a8649
6031dc7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch

# Function to automatically switch between GPU and CPU
def load_model(base_model_id, adapter_model_id=None):
    """Load a Stable Diffusion pipeline on the best available device.

    Args:
        base_model_id: Hugging Face model ID of the base pipeline.
        adapter_model_id: Optional model ID of a LoRA-style adapter whose
            attention processors are applied to the pipeline's UNet.
            Falsy values (None, "") skip the adapter step.

    Returns:
        (pipe, info): the ready-to-use pipeline and a human-readable
        string describing the device (and adapter, if any) in use.
    """
    if torch.cuda.is_available():
        device = "cuda"
        info = "Running on GPU (CUDA) 🔥"
    else:
        device = "cpu"
        info = "Running on CPU 🥶"

    # fp16 only makes sense on CUDA; CPU kernels need fp32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load the base model dynamically on the correct device
    pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=dtype)
    pipe = pipe.to(device)

    # If an adapter model is provided, apply it to the UNet.
    # FIX: the original also downloaded the adapter as a *full* second
    # pipeline into an unused variable — a large wasted download and
    # memory allocation. load_attn_procs alone is sufficient.
    if adapter_model_id:
        pipe.unet.load_attn_procs(adapter_model_id)  # applies LoRA-style attention processors
        info += f" with Adapter Model: {adapter_model_id}"

    return pipe, info

# Cache of loaded pipelines so repeated generations with the same model
# IDs don't re-download and re-instantiate the pipeline on every click.
_PIPE_CACHE = {}

# Function for text-to-image generation with dynamic model ID and device info
def generate_image(base_model_id, adapter_model_id, prompt):
    """Generate an image from *prompt* using the requested model(s).

    Args:
        base_model_id: Hugging Face model ID of the base pipeline.
        adapter_model_id: Optional adapter model ID ("" or None to skip).
        prompt: Text prompt describing the desired image.

    Returns:
        (image, info): the first generated PIL image and the device/adapter
        description produced by ``load_model``.
    """
    # Treat "" (Gradio's default textbox value) and None as the same key.
    cache_key = (base_model_id, adapter_model_id or None)
    if cache_key not in _PIPE_CACHE:
        _PIPE_CACHE[cache_key] = load_model(base_model_id, adapter_model_id)
    pipe, info = _PIPE_CACHE[cache_key]
    image = pipe(prompt).images[0]
    return image, info

# Probe the compute device once at startup so the UI header can show it.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
info = "Running on GPU (CUDA) 🔥" if _cuda_available else "Running on CPU 🥶"

# Create the Gradio interface
# Layout: a two-column row — inputs (model IDs, prompt, button) on the left,
# the generated image and a device/adapter status line on the right.
# NOTE: statement order inside gr.Blocks defines the rendered layout.
with gr.Blocks() as demo:
    gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
    gr.Markdown(f"**{info}**")  # Display GPU/CPU information in the UI

    with gr.Row():
        with gr.Column():
            base_model_id = gr.Textbox(label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)", placeholder="Base Model ID")
            adapter_model_id = gr.Textbox(label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)", placeholder="Adapter Model ID (optional)", value="")
            prompt = gr.Textbox(label="Enter your prompt", placeholder="Describe the image you want to generate")
            generate_btn = gr.Button("Generate Image")
        
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            device_info = gr.Markdown()  # To display if GPU or CPU is used and whether an adapter is applied
    
    # Link the button to the image generation function
    # generate_image returns (image, info), mapped to the two outputs in order.
    generate_btn.click(fn=generate_image, inputs=[base_model_id, adapter_model_id, prompt], outputs=[output_image, device_info])

# Launch the app only when executed as a script, so importing this module
# (e.g. for testing) does not start a web server.
if __name__ == "__main__":
    demo.launch()