File size: 3,282 Bytes
66a73ae
 
bad655a
 
 
66a73ae
 
bad655a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66a73ae
bad655a
 
 
 
 
 
66a73ae
 
 
bad655a
 
 
 
 
 
 
 
 
 
 
 
 
66a73ae
bad655a
66a73ae
 
 
 
 
 
 
 
 
 
 
 
bad655a
 
 
 
 
 
66a73ae
 
 
bad655a
66a73ae
bad655a
66a73ae
bad655a
66a73ae
 
 
bad655a
66a73ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torchvision.models.segmentation import fcn_resnet50
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage
from PIL import Image

# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stable Diffusion v1.5 text-to-image pipeline. Half precision is requested
# only on CUDA; CPU inference stays in float32.
_sd_dtype = torch.float16 if device == "cuda" else torch.float32
text_to_image_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_sd_dtype,
).to(device)

# Pre-trained torchvision FCN-ResNet50 semantic-segmentation network (not a
# UNet, despite the variable name), switched to inference mode and moved to
# `device`. Used by the image-to-image tab.
unet_model = fcn_resnet50(pretrained=True)
unet_model.eval()
unet_model = unet_model.to(device)

# Segmentation-model input pipeline: resize to a fixed 512x512 square,
# convert to a [0, 1] float tensor, then normalize with the ImageNet
# mean/std that torchvision's pretrained models expect.
preprocess = Compose(
    [
        Resize((512, 512)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Converts a CHW tensor back into a PIL image for display.
postprocess = Compose([ToPILImage()])


def text_to_image(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Generate a single image from a text prompt with Stable Diffusion.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        guidance_scale: Guidance weight forwarded unchanged to the pipeline.
        num_inference_steps: Denoising step count forwarded unchanged to the
            pipeline.

    Returns:
        The first entry of the pipeline output's ``images`` list (presumably
        a PIL image, the pipeline's default output type).
    """
    result = text_to_image_pipe(
        prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    generated = result.images
    return generated[0]


def apply_dynamic_unet(init_image, strength):
    """Blend a segmentation-derived foreground mask into an uploaded image.

    The image is run through ``unet_model`` (a torchvision FCN-ResNet50
    semantic-segmentation network, despite the name). Every pixel whose
    predicted class is not 0 (the background label of torchvision's pretrained
    segmentation weights) counts as foreground. The result is
    ``strength * mask + (1 - strength) * image``, computed in ordinary [0, 1]
    RGB pixel space, so foreground regions are pushed toward white.

    Args:
        init_image: PIL image from the upload widget, or ``None`` when nothing
            has been uploaded.
        strength: Blend weight; 0 keeps the (resized) image, 1 yields the bare
            mask. The UI slider supplies values in 0.1-1.0.

    Returns:
        A 512x512 PIL image (the size fixed by ``preprocess``), or ``None``
        if ``init_image`` is ``None``.
    """
    if init_image is None:
        # Clicking the button before uploading would otherwise crash in preprocess.
        return None

    # Uploads may be RGBA (PNG) or single-channel; the 3-channel Normalize in
    # `preprocess` and the pretrained model both require RGB.
    rgb_image = init_image.convert("RGB")

    with torch.no_grad():
        model_input = preprocess(rgb_image).unsqueeze(0).to(device)
        logits = unet_model(model_input)["out"][0]
        # Per-pixel class via argmax (a softmax first would not change it),
        # collapsed to a binary foreground mask instead of raw class indices.
        mask = (logits.argmax(dim=0) != 0).float().cpu()

        # Undo preprocess's Normalize so the blend happens on real [0, 1]
        # pixels; blending the normalized tensor produced distorted colours.
        # These constants must mirror the Normalize in `preprocess`.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        pixels = model_input[0].cpu() * std + mean

        # mask is (H, W); unsqueeze so it broadcasts across the 3 channels.
        blended = (strength * mask.unsqueeze(0) + (1 - strength) * pixels).clamp(0, 1)
    return postprocess(blended)


# Gradio interface: one Blocks app with two tabs.
# Components render in the order they are created inside each `with`
# context, so statement order here defines the page layout.
# NOTE(review): the theme string below looks like a web-scrape
# email-obfuscation artifact ("[email protected]") rather than a real
# "user/theme@version" hub theme ID — TODO restore the intended theme.
# Gradio presumably warns and falls back to its default theme when it
# cannot load it; confirm.
with gr.Blocks(theme='Respair/[email protected]') as demo:
    gr.Markdown("# Text-to-Image and Image-to-Image ")

    # Tab 1: prompt + sampling controls -> text_to_image (Stable Diffusion).
    with gr.Tab("Text-to-Image"):
        with gr.Row():
            text_prompt = gr.Textbox(label="Prompt", placeholder="Enter your text here...")
            text_negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter what to avoid...")
        with gr.Row():
            guidance_scale = gr.Slider(1, 20, value=7.5, step=0.1, label="Guidance Scale")
            num_inference_steps = gr.Slider(10, 100, value=50, step=1, label="Inference Steps")
        with gr.Row():
            # NOTE(review): no custom CSS in this file targets "primary-button".
            generate_btn = gr.Button("Generate", elem_classes=["primary-button"])
        with gr.Row():
            text_output = gr.Image(label="Generated Image")

        # The inputs list must stay in the same order as text_to_image's positional parameters.
        generate_btn.click(
            text_to_image,
            inputs=[text_prompt, text_negative_prompt, guidance_scale, num_inference_steps],
            outputs=text_output,
        )

    # Tab 2: uploaded image + blend strength -> apply_dynamic_unet (segmentation blend).
    with gr.Tab("Image-to-Image"):
        with gr.Row():
            # type="pil" asks Gradio to pass the callback a PIL image
            # (presumably None when nothing has been uploaded).
            init_image = gr.Image(type="pil", label="Upload Initial Image")
        with gr.Row():
            strength = gr.Slider(0.1, 1.0, value=0.75, step=0.05, label="Blend Strength")
        with gr.Row():
            img_generate_btn = gr.Button("Apply UNet", elem_classes=["primary-button"])
        with gr.Row():
            img_output = gr.Image(label="Modified Image")

        # The inputs list must stay in the same order as apply_dynamic_unet's parameters.
        img_generate_btn.click(apply_dynamic_unet, inputs=[init_image, strength], outputs=img_output)

# Launches at import time (there is no `if __name__ == "__main__":` guard).
# share=True additionally requests a public Gradio share link.
demo.launch(share=True)