File size: 4,542 Bytes
1393f77
59e3e19
9d65aa7
26139de
d5e5583
1393f77
59e3e19
 
 
 
 
 
 
 
de64520
 
 
 
 
 
 
 
 
 
7338c67
1393f77
26139de
38d6f1a
de64520
59e3e19
 
 
 
70187df
59e3e19
 
 
70187df
59e3e19
 
 
 
 
 
 
 
 
de64520
59e3e19
 
 
 
70187df
7338c67
1393f77
 
 
 
 
 
 
de64520
1393f77
7338c67
1393f77
 
de47ce6
 
de64520
1393f77
 
4064994
1393f77
12d2706
fd948c1
 
 
 
 
 
 
 
 
 
 
 
ecd948a
 
 
 
 
 
 
fd948c1
7338c67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from pipeline import PixArtSigmaPipeline
import gradio as gr
import spaces

# Load the pre-trained diffusion models. Stage 1 produces partially-denoised
# latents; stage 2 finishes denoising them.
base_model = "ptx0/pixart-900m-1024-ft-v0.7-stage1"
stg2_model = "ptx0/pixart-900m-1024-ft-v0.7-stage2"
torch_device = "cuda"
# Precision used for all model weights. `torch_precision` was previously
# referenced without being defined (NameError at import).
# NOTE(review): bfloat16 is assumed here — confirm it is the intended dtype.
torch_precision = torch.bfloat16
base_pipeline = PixArtSigmaPipeline.from_pretrained(
    base_model, use_safetensors=True
).to(dtype=torch_precision, device=torch_device)
# Stage 2 reuses every component of the base pipeline and only swaps in its
# own transformer weights below.
stg2_pipeline = PixArtSigmaPipeline.from_pretrained(stg2_model, **base_pipeline.components)
# `PixArtTransformer2DModel` was never imported. Loading through the base
# transformer's own class (presumably PixArtTransformer2DModel) yields the same
# model type without adding a new import.
stg2_pipeline.transformer = type(base_pipeline.transformer).from_pretrained(
    stg2_model, subfolder="transformer"
).to(dtype=torch_precision, device=torch_device)
import re

def extract_resolution(resolution_str):
    """Parse a ``"<width>x<height>"`` string into a pair of ints.

    Only the beginning of the string is examined (``re.match``), so any
    text following the second number is ignored.

    Returns:
        ``(width, height)`` as ints, or ``None`` when the string does not
        start with ``<digits>x<digits>``.
    """
    parsed = re.match(r'(\d+)x(\d+)', resolution_str)
    if parsed is None:
        return None
    w_text, h_text = parsed.groups()
    return int(w_text), int(h_text)

# Image generation callback used by the Gradio interface.
@spaces.GPU
def generate(prompt, guidance_scale, num_inference_steps, resolution, negative_prompt):
    """Run the two-stage PixArt pipeline and return the refined images.

    Stage 1 (``base_pipeline``) denoises the first ``stage1_strength``
    fraction of the schedule and returns latents. Stage 2 (``stg2_pipeline``)
    resumes from those latents at the same point and produces the final
    images.

    Args:
        prompt: Text prompt, shared by both stages.
        guidance_scale: Classifier-free guidance scale, shared by both stages.
        num_inference_steps: Total number of denoising steps. It is cast to
            ``int`` because UI slider values may arrive as floats.
        resolution: A ``"<width>x<height>"`` string. Falls back to
            1024x1024 when it cannot be parsed.
        negative_prompt: Negative prompt, shared by both stages.

    Returns:
        The ``.images`` result of the stage-2 pipeline (presumably a list of
        PIL images, as expected by the Gallery output).
    """
    width, height = extract_resolution(resolution) or (1024, 1024)
    steps = int(num_inference_steps)
    # Fixed seed so identical requests are reproducible. The same generator is
    # passed to both stages, so stage 2 continues its random stream.
    mixture_generator = torch.Generator().manual_seed(444)
    # Fraction of the denoising schedule handled by stage 1; stage 2 starts
    # exactly where stage 1 stops.
    stage1_strength = 0.6
    latent_images = base_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        num_images_per_prompt=1,
        generator=mixture_generator,
        guidance_scale=guidance_scale,
        output_type="latent",
        denoising_end=stage1_strength,
        width=width,
        height=height,
    ).images
    # The original `return refined_images = ...` was a SyntaxError.
    # The requested size is also passed to stage 2 so that any size-dependent
    # handling there (e.g. resolution binning / final resize in diffusers'
    # PixArt pipelines) targets the requested resolution, not the pipeline
    # default. NOTE(review): confirm the local `pipeline` module accepts these.
    refined_images = stg2_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        latents=latent_images,
        num_inference_steps=steps,
        num_images_per_prompt=1,
        generator=mixture_generator,
        guidance_scale=guidance_scale,
        denoising_start=stage1_strength,
        width=width,
        height=height,
    ).images
    return refined_images

# Example prompts to demonstrate the model's capabilities.
# Each row must supply one value per interface input, in order:
# prompt, guidance scale, inference steps, resolution, negative prompt.
# (Previously the resolution column was missing, so the negative prompt was
# fed into the Resolution radio and the Negative Prompt box got nothing.)
example_prompts = [
    ["A futuristic cityscape at night under a starry sky", 7.5, 25, "1152x960", "blurry, overexposed"],
    ["A serene landscape with a flowing river and autumn trees", 8.0, 20, "1152x960", "crowded, noisy"],
    ["An abstract painting of joy and energy in bright colors", 9.0, 30, "1024x1024", "dark, dull"],
]

# Create the Gradio interface. Offered resolutions: 1024x1024, 1152x960, 896x1152.
# Input order must match `generate`'s parameters and the example rows.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Enter your prompt"),
        gr.Slider(1, 20, step=0.1, label="Guidance Scale", value=3.4),
        gr.Slider(1, 50, step=1, label="Number of Inference Steps", value=28),
        gr.Radio(["1024x1024", "1152x960", "896x1152"], label="Resolution", value="1152x960"),
        gr.Text(value="underexposed, blurry, ugly, washed-out", label="Negative Prompt")
    ],
    outputs=gr.Gallery(height=1024, min_width=1024, columns=2),
    examples=example_prompts,
    title="PixArt 900M",
    description=(
        "This is a 900M parameter model expanded from PixArt Sigma 1024px (600M) by adding 14 layers to deepen the transformer."
        "<br />This model is being <strong>actively trained</strong> on 3.5M samples across a wide distribution of photos, synthetic data, cinema, anime, and safe-for-work furry art."
        "<br />"
        "<br />&nbsp;The datasets have been filtered for extremist and illegal content, but it is possible to produce toxic outputs. <strong>This model has not been safety-aligned or fine-tuned</strong>."
        " You may receive non-aesthetic results, or prompts might be partially or wholly ignored."
        "<br />Although celebrity names and artist styles haven't been scrubbed from the datasets, the low volume of these samples in the training set results in a lack of representation for public figures."
        "<br />"
        "<br />Be mindful when using this demo space that you do not inadvertently share images without adequate preparation and informing the receivers that these images are AI generated."
        "<br />"
        "<br />This model is being trained by <strong>Terminus Research Group</strong> with support from <a href='https://fal.ai'>Fal.ai</a>."
        " See https://fal.ai/grants for more information on how Fal.ai can help your team."
        "<br />"
        "<br />"
        "<ul>"
        "<li>Lead trainer: @pseudoterminalx (bghira@GitHub)</li>"
        "<li>Architecture: @jimmycarter (AmericanPresidentJimmyCarter@GitHub)</li>"
        "<li>Datasets: @ProGamerGov, @jimmycarter, @pseudoterminalx</li>"
        "</ul>"
    )
)
# Launch as a separate statement so `iface` keeps a reference to the Interface
# object. Chaining `.launch()` onto the constructor bound `iface` to launch()'s
# return value instead.
iface.launch()