File size: 5,252 Bytes
1393f77
59e3e19
ecbb9e3
9d65aa7
26139de
d5e5583
1393f77
59e3e19
 
 
ec7f291
59e3e19
 
 
 
 
de64520
 
 
 
 
 
 
 
 
 
7338c67
1393f77
26139de
e31a916
de64520
59e3e19
 
 
 
70187df
59e3e19
 
 
e31a916
59e3e19
 
 
 
 
ae90f5c
59e3e19
 
 
de64520
59e3e19
 
e31a916
59e3e19
70187df
7338c67
1393f77
 
fbe4ed2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1393f77
de64520
1393f77
7338c67
1393f77
 
e31a916
 
 
62dff5f
1393f77
 
4064994
1393f77
12d2706
fd948c1
983dbf6
 
 
fd948c1
983dbf6
fd948c1
983dbf6
 
fd948c1
983dbf6
 
ecd948a
 
 
 
 
fd948c1
7338c67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
from pipeline import PixArtSigmaPipeline
from diffusers.models import PixArtTransformer2DModel
import gradio as gr
import spaces

# Load the pre-trained diffusion model
# Two checkpoints of the same architecture, each fine-tuned on a different
# timestep range (mixture-of-experts, per the description text below):
# stage 1 handles the early (high-noise) steps, stage 2 the late ones.
base_model = "ptx0/pixart-900m-1024-ft-v0.7-stage1"
stg2_model = "ptx0/pixart-900m-1024-ft-v0.7-stage2"
torch_device = "cuda"
torch_precision = torch.bfloat16
base_pipeline = PixArtSigmaPipeline.from_pretrained(
	base_model, use_safetensors=True
).to(dtype=torch_precision, device=torch_device)
# Reuse the stage-1 pipeline's components (VAE, text encoder, tokenizer, ...)
# so only the transformer weights differ between the two pipelines.
stg2_pipeline = PixArtSigmaPipeline.from_pretrained(stg2_model, **base_pipeline.components)
# Swap in the stage-2 transformer on top of the shared components.
stg2_pipeline.transformer = PixArtTransformer2DModel.from_pretrained(stg2_model, subfolder="transformer").to(dtype=torch_precision, device=torch_device)
# NOTE(review): mid-file import; conventionally this belongs with the imports above.
import re

def extract_resolution(resolution_str):
    """Parse a leading ``<width>x<height>`` spec into an ``(int, int)`` tuple.

    Returns ``None`` when the string does not begin with that pattern.
    """
    parsed = re.match(r'(\d+)x(\d+)', resolution_str)
    if parsed is None:
        return None
    return int(parsed.group(1)), int(parsed.group(2))

# Two-stage generation entry point wired into the Gradio interface below.
@spaces.GPU
def generate(prompt, stage1_guidance_scale, stage2_guidance_scale, num_inference_steps, resolution, negative_prompt):
    """Generate images: stage 1 composes latents, stage 2 refines them.

    Stage 1 denoises the first 60% of the schedule and emits latents; stage 2
    resumes from that point and decodes the final images.
    """
    # Fall back to a square canvas if the resolution string fails to parse.
    width, height = extract_resolution(resolution) or (1024, 1024)
    # Fixed seed keeps repeated runs with identical settings reproducible.
    mixture_generator = torch.Generator().manual_seed(444)
    # Fraction of the denoising schedule owned by the stage-1 model.
    stage1_strength = 0.6
    # Arguments common to both pipeline invocations.
    shared_args = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=mixture_generator,
    )
    stage1_latents = base_pipeline(
        guidance_scale=stage1_guidance_scale,
        output_type="latent",
        denoising_end=stage1_strength,
        width=width,
        height=height,
        **shared_args,
    ).images
    return stg2_pipeline(
        guidance_scale=stage2_guidance_scale,
        latents=stage1_latents,
        denoising_start=stage1_strength,
        **shared_args,
    ).images

# Example prompts to demonstrate the model's capabilities
# Each row mirrors generate()'s parameter order:
# [prompt, stage-1 guidance, stage-2 guidance, steps, resolution, negative prompt]
example_prompts = [
    [
        "A futuristic cityscape at night under a starry sky",
        3.5,
        4.5,
        25,
        "1152x960",
        "blurry, overexposed"
    ],
    [
        "A serene landscape with a flowing river and autumn trees",
        3.0,
        4.0,
        20,
        "1152x960",
        "crowded, noisy"
    ],
    [
        "An abstract painting of joy and energy in bright colors",
        3.0,
        4.5,
        30,
        "896x1152",
        "dark, dull"
    ],
    [
        "a stunning portrait of a hamster with an eye patch, piloting a miniature cessna on a wooden desk in an office, depth of field, bokeh, sharp, f1.4",
        3.2,
        4.6,
        40,
        "1024x1024",
        "this is an ugly photograph that no one liked"
    ],
    [
        "Check out my cousin larry in his dirty room, he is such a damn mess",
        3.2,
        4.6,
        40,
        "1152x960",
        "the photograph is blurry and unremarkable"
    ]
]
# Create a Gradio interface, 1024x1024,1152x960,896x1152
# The inputs list maps 1:1, in order, onto generate()'s positional parameters.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Enter your prompt"),
        gr.Slider(1, 20, step=0.1, label="Guidance Scale (Stage I)", value=3.4),
        gr.Slider(1, 20, step=0.1, label="Guidance Scale (Stage II)", value=4.2),
        gr.Slider(1, 50, step=1, label="Number of Inference Steps", value=35),
        # Choices match the aspect ratios listed in the comment above.
        gr.Radio(["1024x1024", "1152x960", "896x1152"], label="Resolution", value="1024x1024"),
        gr.Text(value="underexposed, blurry, ugly, washed-out", label="Negative Prompt")
    ],
    # generate() returns a list of images, shown as a gallery.
    outputs=gr.Gallery(height=1024, min_width=1024, columns=2),
    examples=example_prompts,
    title="PixArt 900M",
    description=(
        "This is a two-stage mixture-of-experts model implemented in the spirit of NVIDIA's E-Diffi model."
        "<br />The weights were initialised from <strong>terminusresearch/pixart-900m-1024-ft-v0.6</strong> and trained separately on timestep ranges <strong>999-400</strong> and <strong>400-0</strong>."
        "<br />This results in two models where the first stage is responsible for most of the image's composition and colour, and the second stage handles minor-to-fine details."
        "<br />"
        "<br />In comparison to SDXL's refiner, the second stage here handles twice as many timesteps, which allows it to make more use of the text-conditional guidance, improving its capabilities."
        "<br />"
        "<br />Despite being trained with 40% of the schedule, you will discover that using stage 2 stand-alone as a refiner (img2img) will need half the strength - about 20%."
        "<br />When being used in the two-stage pipeline, it should be configured to handle all of its 40% range."
        "<br />"
        "<br />This model is funded and trained by <strong>Terminus Research Group</strong>."
        " If you would like to collaborate or provide compute, please see the organisation page for how to locate us on Discord."
        "<br />"
        "<br />"
        "<ul>"
        "<li>Lead trainer: @pseudoterminalx (bghira@GitHub)</li>"
        "</ul>"
    )
).launch()