import torch
from pipeline import PixArtSigmaPipeline
import gradio as gr
import spaces
# Load the pre-trained two-stage PixArt diffusion models onto the GPU.
base_model = "ptx0/pixart-900m-1024-ft-v0.7-stage1"
stg2_model = "ptx0/pixart-900m-1024-ft-v0.7-stage2"
torch_device = "cuda"
# Weight/compute dtype shared by both pipelines.
# NOTE(review): bfloat16 assumed; use torch.float16 if the target GPU lacks bf16.
torch_precision = torch.bfloat16
base_pipeline = PixArtSigmaPipeline.from_pretrained(
    base_model, use_safetensors=True
).to(dtype=torch_precision, device=torch_device)
# Stage 2 is built from every component of the stage-1 pipeline (so nothing is
# loaded twice); only its transformer is replaced below.
stg2_pipeline = PixArtSigmaPipeline.from_pretrained(stg2_model, **base_pipeline.components)
# The transformer class is not imported in this file, so reuse the class of the
# already-loaded stage-1 transformer to load the stage-2 weights.
stg2_pipeline.transformer = type(base_pipeline.transformer).from_pretrained(
    stg2_model, subfolder="transformer"
).to(dtype=torch_precision, device=torch_device)
import re
def extract_resolution(resolution_str):
    """Parse a ``"<width>x<height>"`` label such as ``"1152x960"``.

    Only the leading ``<digits>x<digits>`` part is examined (``re.match``
    anchors at the start only), so trailing text is ignored. Surrounding
    whitespace is tolerated.

    Args:
        resolution_str: The resolution label. May be ``None`` or another
            non-string value, e.g. when no resolution was supplied.

    Returns:
        ``(width, height)`` as ints, or ``None`` when the input is not a string
        or does not start with the expected pattern. Returning ``None`` (rather
        than raising) lets callers fall back to a default size.
    """
    if not isinstance(resolution_str, str):
        return None
    match = re.match(r'(\d+)x(\d+)', resolution_str.strip())
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))
# Define the image generation function with adjustable parameters.
@spaces.GPU
def generate(prompt, guidance_scale, num_inference_steps, resolution, negative_prompt):
    """Generate images with the two-stage (base + refiner) PixArt pipelines.

    Stage 1 (``base_pipeline``) runs the schedule up to ``denoising_end`` and
    returns latents; stage 2 (``stg2_pipeline``) resumes from
    ``denoising_start`` on those latents and returns the final images.

    Args:
        prompt: Text prompt.
        guidance_scale: Guidance scale passed to both stages.
        num_inference_steps: Step count passed to both stages.
        resolution: ``"<width>x<height>"`` label; falls back to 1024x1024 when
            it cannot be parsed.
        negative_prompt: Negative prompt passed to both stages.

    Returns:
        The ``.images`` of the stage-2 pipeline call.
    """
    width, height = extract_resolution(resolution) or (1024, 1024)
    # Defensively coerce UI values: a step count must be an integer.
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    # Fixed seed, shared by both stages, so identical inputs reproduce results.
    mixture_generator = torch.Generator().manual_seed(444)
    # Hand-off point between the stages (fraction of the schedule).
    stage1_strength = 0.6
    latent_images = base_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=mixture_generator,
        guidance_scale=guidance_scale,
        output_type="latent",
        denoising_end=stage1_strength,
        width=width,
        height=height,
    ).images
    # Give stage 2 the same target size as stage 1. Without it the refiner
    # would use its own default size (based on the diffusers PixArt pipeline,
    # where the decoded image is resized to the requested width/height --
    # confirm against pipeline.py).
    refined_images = stg2_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        latents=latent_images,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=mixture_generator,
        guidance_scale=guidance_scale,
        denoising_start=stage1_strength,
        width=width,
        height=height,
    ).images
    return refined_images
# Example inputs to demonstrate the model's capabilities. Each row supplies one
# value per interface input, in order:
# (prompt, guidance scale, inference steps, resolution, negative prompt).
example_prompts = [
    ["A futuristic cityscape at night under a starry sky", 7.5, 25, "1152x960", "blurry, overexposed"],
    ["A serene landscape with a flowing river and autumn trees", 8.0, 20, "1152x960", "crowded, noisy"],
    ["An abstract painting of joy and energy in bright colors", 9.0, 30, "1024x1024", "dark, dull"],
]
# Create the Gradio interface (supported resolutions: 1024x1024, 1152x960, 896x1152).
iface = gr.Interface(
fn=generate,
inputs=[
gr.Text(label="Enter your prompt"),
gr.Slider(1, 20, step=0.1, label="Guidance Scale", value=3.4),
gr.Slider(1, 50, step=1, label="Number of Inference Steps", value=28),
gr.Radio(["1024x1024", "1152x960", "896x1152"], label="Resolution", value="1152x960"),
gr.Text(value="underexposed, blurry, ugly, washed-out", label="Negative Prompt")
],
outputs=gr.Gallery(height=1024, min_width=1024, columns=2),
examples=example_prompts,
title="PixArt 900M",
description=(
"This is a 900M parameter model expanded from PixArt Sigma 1024px (600M) by adding 14 layers to deepen the transformer."
"
This model is being actively trained on 3.5M samples across a wide distribution of photos, synthetic data, cinema, anime, and safe-for-work furry art."
"
"
"
The datasets been filtered for extremist and illegal content, but it is possible to produce toxic outputs. This model has not been safety-aligned or fine-tuned."
" You may receive non-aesthetic results, or prompts might be partially or wholly ignored."
"
Although celebrity names and artist styles haven't been scrubbed from the datasets, the low volume of these samples in the training set result in a lack of representation for public figures."
"
"
"
Be mindful when using this demo space that you do not inadvertently share images without adequate preparation and informing the receivers that these images are AI generated."
"
"
"
This model is being trained by Terminus Research Group with support from Fal.ai."
" See https://fal.ai/grants for more information on how Fal.ai can help your team."
"
"
"
"
"