amos1088 commited on
Commit
68e88ea
·
1 Parent(s): 91a655a

test gradio

Browse files
Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -1,20 +1,19 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image
4
  from models.transformer_sd3 import SD3Transformer2DModel
5
  from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
6
  import os
 
7
  from huggingface_hub import login
8
-
9
  token = os.getenv("HF_TOKEN")
10
  login(token=token)
11
 
12
- # Model and paths
13
  model_path = 'stabilityai/stable-diffusion-3.5-large'
14
  ip_adapter_path = './ip-adapter.bin'
15
  image_encoder_path = "google/siglip-so400m-patch14-384"
16
 
17
- # Load SD3.5 pipeline and components
18
  transformer = SD3Transformer2DModel.from_pretrained(
19
  model_path, subfolder="transformer", torch_dtype=torch.bfloat16
20
  )
@@ -22,6 +21,7 @@ pipe = StableDiffusion3Pipeline.from_pretrained(
22
  model_path, transformer=transformer, torch_dtype=torch.bfloat16
23
  ).to("cuda")
24
 
 
25
  pipe.init_ipadapter(
26
  ip_adapter_path=ip_adapter_path,
27
  image_encoder_path=image_encoder_path,
@@ -29,38 +29,44 @@ pipe.init_ipadapter(
29
  )
30
 
31
 
32
- @gr.Interface()
33
- def gui_generation(image: Image, style_image: Image):
34
  """
35
- Generate an image based on input and style images.
36
  """
37
- generator = torch.Generator("cuda").manual_seed(42) # Reproducibility
38
-
39
- output = pipe(
40
- width=1024,
41
- height=1024,
42
- prompt="",
43
- negative_prompt="",
44
  num_inference_steps=24,
45
  guidance_scale=5.0,
46
- generator=generator,
47
- clip_image=style_image,
48
- ipadapter_scale=0.5,
49
- ).images[0]
50
- return output
51
 
 
52
 
53
- # Gradio UI elements
54
- image_input = gr.Image(type="pil", label="Input Image")
55
- style_image_input = gr.Image(type="pil", label="Style Image")
56
- output_image = gr.Image(label="Generated Image")
57
 
58
- interface = gr.Interface(
59
- gui_generation,
60
- inputs=[image_input, style_image_input],
61
- outputs=output_image,
62
- title="Image Generation with Style Image",
63
- description="Upload an input image and a style image to generate a new image based on the style."
64
- )
 
 
 
 
 
 
 
 
65
 
66
- interface.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
 
3
  from models.transformer_sd3 import SD3Transformer2DModel
4
  from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
5
  import os
6
+ import spaces
7
  from huggingface_hub import login
 
8
  token = os.getenv("HF_TOKEN")
9
  login(token=token)
10
 
11
+ # Model and Pipeline Setup
12
  model_path = 'stabilityai/stable-diffusion-3.5-large'
13
  ip_adapter_path = './ip-adapter.bin'
14
  image_encoder_path = "google/siglip-so400m-patch14-384"
15
 
16
+ # Load transformer and pipeline
17
  transformer = SD3Transformer2DModel.from_pretrained(
18
  model_path, subfolder="transformer", torch_dtype=torch.bfloat16
19
  )
 
21
  model_path, transformer=transformer, torch_dtype=torch.bfloat16
22
  ).to("cuda")
23
 
24
+ # Initialize IP Adapter
25
  pipe.init_ipadapter(
26
  ip_adapter_path=ip_adapter_path,
27
  image_encoder_path=image_encoder_path,
 
29
  )
30
 
31
 
32
+ @spaces.GPU
33
+ def gui_generation(text, num_imgs, width, height):
34
  """
35
+ Generate images using Stable Diffusion 3.5
36
  """
37
+ images = pipe(
38
+ prompt=text,
39
+ width=width,
40
+ height=height,
41
+ num_images_per_prompt=num_imgs,
42
+ negative_prompt="lowres, low quality, worst quality",
 
43
  num_inference_steps=24,
44
  guidance_scale=5.0,
45
+ generator=torch.Generator("cuda").manual_seed(42),
46
+ ).images
 
 
 
47
 
48
+ return images
49
 
 
 
 
 
50
 
51
+ # Create Gradio interface
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("# Stable Diffusion 3.5 Image Generation")
54
+
55
+ with gr.Row():
56
+ prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
57
+ number_slider = gr.Slider(1, 30, value=2, step=1, label="Batch size")
58
+
59
+ with gr.Row():
60
+ width_slider = gr.Slider(256, 1536, value=1024, step=64, label="Width")
61
+ height_slider = gr.Slider(256, 1536, value=1024, step=64, label="Height")
62
+
63
+ gallery = gr.Gallery(columns=[3], rows=[1], object_fit="contain", height="auto")
64
+
65
+ generate_btn = gr.Button("Generate")
66
 
67
+ generate_btn.click(
68
+ fn=gui_generation,
69
+ inputs=[prompt_box, number_slider, width_slider, height_slider],
70
+ outputs=gallery
71
+ )
72
+ demo.launch()