PseudoTerminal X commited on
Commit
59e3e19
1 Parent(s): c7b113a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -6
app.py CHANGED
@@ -1,11 +1,17 @@
1
  import torch
2
- from diffusers import PixArtSigmaPipeline
3
  import gradio as gr
4
  import spaces
5
 
6
  # Load the pre-trained diffusion model
7
- pipe = PixArtSigmaPipeline.from_pretrained('ptx0/pixart-900m-1024-ft', torch_dtype=torch.bfloat16)
8
- pipe.to('cuda')
 
 
 
 
 
 
9
  import re
10
 
11
  def extract_resolution(resolution_str):
@@ -21,12 +27,29 @@ def extract_resolution(resolution_str):
21
  @spaces.GPU
22
  def generate(prompt, guidance_scale, num_inference_steps, resolution, negative_prompt):
23
  width, height = extract_resolution(resolution) or (1024, 1024)
24
- return pipe(
25
- prompt,
 
 
26
  negative_prompt=negative_prompt,
 
 
 
27
  guidance_scale=guidance_scale,
 
 
 
 
 
 
 
 
 
28
  num_inference_steps=num_inference_steps,
29
- width=width, height=height
 
 
 
30
  ).images
31
 
32
  # Example prompts to demonstrate the model's capabilities
 
1
  import torch
2
+ from pipeline import PixArtSigmaPipeline
3
  import gradio as gr
4
  import spaces
5
 
6
  # Load the pre-trained diffusion model
7
+ base_model = "ptx0/pixart-900m-1024-ft-v0.7-stage1"
8
+ stg2_model = "ptx0/pixart-900m-1024-ft-v0.7-stage2"
9
+ torch_device = "cuda"
10
+ base_pipeline = PixArtSigmaPipeline.from_pretrained(
11
+ base_model, use_safetensors=True
12
+ ).to(dtype=torch_precision, device=torch_device)
13
+ stg2_pipeline = PixArtSigmaPipeline.from_pretrained(stg2_model, **base_pipeline.components)
14
+ stg2_pipeline.transformer = PixArtTransformer2DModel.from_pretrained(stg2_model, subfolder="transformer").to(dtype=torch_precision, device=torch_device)
15
  import re
16
 
17
  def extract_resolution(resolution_str):
 
27
  @spaces.GPU
28
  def generate(prompt, guidance_scale, num_inference_steps, resolution, negative_prompt):
29
  width, height = extract_resolution(resolution) or (1024, 1024)
30
+ mixture_generator = torch.Generator().manual_seed(444)
31
+ stage1_strength = 0.6
32
+ latent_images = base_pipeline(
33
+ prompt=prompt,
34
  negative_prompt=negative_prompt,
35
+ num_inference_steps=num_inference_steps,
36
+ num_images_per_prompt=1,
37
+ generator=mixture_generator,
38
  guidance_scale=guidance_scale,
39
+ output_type="latent",
40
+ denoising_end=stage1_strength,
41
+ width=width,
42
+ height=height
43
+ ).images
44
+ return refined_images = stg2_pipeline(
45
+ prompt=prompt,
46
+ negative_prompt=negative_prompt,
47
+ latents=latent_images,
48
  num_inference_steps=num_inference_steps,
49
+ num_images_per_prompt=1,
50
+ generator=mixture_generator,
51
+ guidance_scale=guidance_scale,
52
+ denoising_start=stage1_strength
53
  ).images
54
 
55
  # Example prompts to demonstrate the model's capabilities