File size: 4,273 Bytes
375ee53
 
 
 
 
 
4e92ab0
375ee53
427f665
375ee53
 
4e92ab0
 
 
 
 
375ee53
 
 
 
 
 
 
 
 
 
4e92ab0
375ee53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os

import gradio as gr
import rembg
import spaces
import torch
from diffusers import DiffusionPipeline

from instantMesh.src.utils.infer_util import (remove_background, resize_foreground)


# Load the Playground v2.5 text-to-image pipeline once at import time, using
# the half-precision ("fp16") weight variant.
pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    variant="fp16",
    torch_dtype=torch.float16,
)
# Move the loaded pipeline onto the CUDA device.
pipe = pipe.to("cuda")


def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
    """Build the text prompt for the image generation pipeline.

    Args:
        subject: What to render; it is mentioned twice in the prompt.
        style: Rendering style, e.g. ``'Pixar-like'``.
        color_scheme: Colour palette description, e.g. ``'Vibrant'``.
        angle: Viewing angle, e.g. ``'Front'``.
        lighting_type: Lighting description, e.g. ``'Bright and Even'``.
        additional_details: Optional free text appended as its own sentence.
            ``None``, empty or whitespace-only values are left out entirely.

    Returns:
        The complete prompt as a single string.
    """
    sentences = [
        f"A 3D cartoon render of {subject}, featuring the entire body and shape, on a transparent background.",
        f"The style should be {style}, with {color_scheme} colors, emphasizing the essential features and lines.",
        f"The pose should clearly showcase the full form of the {subject} from a {angle} perspective.",
        f"Lighting is {lighting_type}, highlighting the volume and depth of the subject.",
    ]

    # An empty details box used to leave a dangling " ." in the prompt, so the
    # optional sentence is only added when there is something to say.
    details = (additional_details or "").strip()
    if details:
        # Avoid a doubled full stop when the user already ended the sentence.
        if not details.endswith((".", "!", "?")):
            details += "."
        sentences.append(details)

    sentences.append("Output as a high-resolution PNG with no background.")
    return " ".join(sentences)


@spaces.GPU
def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
    """Generate one image from the prompt fields using the module-level ``pipe``.

    The arguments are forwarded unchanged to ``generate_prompt``; the resulting
    prompt is run through the pipeline with 25 inference steps and a guidance
    scale of 7.5.

    NOTE(review): ``@spaces.GPU`` presumably requests a GPU for the duration of
    the call on Hugging Face ZeroGPU Spaces - confirm against the ``spaces``
    package documentation.

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: if ``subject`` is missing or blank.
    """
    # Without a subject the prompt reads "A 3D cartoon render of , ...", so
    # fail fast with a visible UI error instead of running the pipeline.
    if not subject or not subject.strip():
        raise gr.Error("No subject provided!")

    prompt = generate_prompt(subject, style, color_scheme,
                             angle, lighting_type, additional_details)
    results = pipe(prompt, num_inference_steps=25, guidance_scale=7.5)
    return results.images[0]


def check_input_image(input_image):
    """Validate that an image has been provided.

    Only ``None`` is rejected; any other value passes and the function
    returns ``None``.

    Raises:
        gr.Error: if ``input_image`` is ``None``.
    """
    if input_image is not None:
        return
    raise gr.Error("No image selected!")


# Lazily created rembg session, shared by every preprocess() call.
_rembg_session = None


def _get_rembg_session():
    """Return the shared rembg session, creating it on first use."""
    global _rembg_session
    if _rembg_session is None:
        _rembg_session = rembg.new_session()
    return _rembg_session


def preprocess(input_image):
    """Remove the background from ``input_image`` and resize its foreground.

    The rembg session is created once and reused, rather than being rebuilt
    on every call.

    Args:
        input_image: Image passed to ``remove_background``.
            NOTE(review): presumably a PIL image, as the UI components use
            ``type="pil"`` - confirm against ``infer_util``.

    Returns:
        The result of ``resize_foreground`` applied to the background-removed
        image.
    """
    foreground = remove_background(input_image, _get_rembg_session())
    # NOTE(review): 0.85 is presumably the fraction of the frame the
    # foreground should occupy - confirm against ``resize_foreground``.
    return resize_foreground(foreground, 0.85)


def image_generation_ui():
    """Create the image generation / background removal UI and wire its events.

    Layout, in creation order:
      * a row of prompt controls (subject, style, colour scheme, angle,
        lighting, additional details) with a "Generate Image" button;
      * a panel with the input image, the processed image, a
        "Remove Background" button and an examples gallery listing every
        file found in the ``examples`` directory.

    Events:
      * "Generate Image" runs ``generate_image`` into the input image and, on
        success, ``preprocess`` into the processed image;
      * "Remove Background" runs ``check_input_image`` and, on success,
        ``preprocess`` into the processed image.

    NOTE(review): presumably meant to be called inside an active
    ``gr.Blocks`` context - confirm with the caller.

    Returns:
        Tuple ``(input_image, processed_image)`` of the two image components.
    """

    def _single_choice(label, choices):
        # All dropdowns share the same settings and default to their first choice.
        return gr.Dropdown(
            label=label,
            choices=choices,
            value=choices[0],
            multiselect=False,
            scale=2
        )

    with gr.Row():
        subject = gr.Textbox(label='Subject', scale=2)
        style = _single_choice(
            'Style', ['Pixar-like', 'Disney-esque', 'Anime-inspired'])
        color_scheme = _single_choice(
            'Color Scheme', ['Vibrant', 'Pastel', 'Monochromatic', 'Black and White'])
        angle = _single_choice(
            'Angle', ['Front', 'Side', 'Three-quarter'])
        lighting_type = _single_choice(
            'Lighting Type', ['Bright and Even', 'Dramatic Shadows', 'Soft and Warm'])
        additional_details = gr.Textbox(label='Additional Details', scale=2)
        submit_prompt = gr.Button('Generate Image', scale=1, variant='primary')

    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(
                    label="Processed Image",
                    image_mode="RGBA",
                    type="pil",
                    interactive=False
                )
            with gr.Row():
                submit_process = gr.Button(
                    "Remove Background", elem_id="process", variant="primary")
            with gr.Row(variant="panel"):
                # Every entry of the examples directory, in sorted name order.
                example_paths = [
                    os.path.join("examples", img_name)
                    for img_name in sorted(os.listdir("examples"))
                ]
                gr.Examples(
                    examples=example_paths,
                    inputs=[input_image],
                    label="Examples",
                    cache_examples=False,
                    examples_per_page=16
                )

    # Generate an image, then strip its background once generation succeeded.
    prompt_inputs = [subject, style, color_scheme,
                     angle, lighting_type, additional_details]
    generated = submit_prompt.click(
        fn=generate_image, inputs=prompt_inputs, outputs=input_image)
    generated.success(
        fn=preprocess, inputs=[input_image], outputs=[processed_image])

    # Validate the current input image, then strip its background.
    validated = submit_process.click(
        fn=check_input_image, inputs=[input_image])
    validated.success(
        fn=preprocess, inputs=[input_image], outputs=[processed_image])

    return input_image, processed_image