amos1088 commited on
Commit
3aadc38
·
1 Parent(s): 4de092c

test gradio

Browse files
Files changed (1) hide show
  1. app.py +65 -24
app.py CHANGED
@@ -1,32 +1,73 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
4
- import gradio as gr
5
- from huggingface_hub import login
6
  import os
7
- import spaces,tempfile
8
- import torch
9
-
10
 
11
  token = os.getenv("HF_TOKEN")
12
  login(token=token)
13
- model_id = "stabilityai/stable-diffusion-2-base"
14
- scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
15
- pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
16
-
17
- lora_path = "Jl-wei/ui-diffuser-v2"
18
- pipe.load_lora_weights(lora_path)
19
- pipe.to("cuda")
20
-
21
- @spaces.GPU
22
- def gui_generation(text, num_imgs):
23
- prompt = f"Mobile app: {text}"
24
- images = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, height=512, width=288, num_images_per_prompt=num_imgs).images
25
- yield images
26
-
27
- gallery = gr.Gallery(columns=[3], rows=[1], object_fit="contain", height="auto")
28
- number_slider = gr.Slider(1, 30, value=2, step=1, label="Batch size")
29
- prompt_box = gr.Textbox(label="Prompt", placeholder="Health monittoring report")
30
- interface = gr.Interface(gui_generation, inputs=[prompt_box, number_slider], outputs=gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  interface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from PIL import Image
4
+ from models.transformer_sd3 import SD3Transformer2DModel
5
+ from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
6
  import os
7
+ from huggingface_hub import login
 
 
8
 
9
  token = os.getenv("HF_TOKEN")
10
  login(token=token)
11
+
12
+ # Model and paths
13
+ model_path = 'stabilityai/stable-diffusion-3.5-large'
14
+ ip_adapter_path = './ip-adapter.bin'
15
+ image_encoder_path = "google/siglip-so400m-patch14-384"
16
+ ref_img_path = './assets/1.jpg' # Reference image path
17
+
18
+ # Load SD3.5 pipeline and components
19
+ transformer = SD3Transformer2DModel.from_pretrained(
20
+ model_path, subfolder="transformer", torch_dtype=torch.bfloat16
21
+ )
22
+ pipe = StableDiffusion3Pipeline.from_pretrained(
23
+ model_path, transformer=transformer, torch_dtype=torch.bfloat16
24
+ ).to("cuda")
25
+
26
+ pipe.init_ipadapter(
27
+ ip_adapter_path=ip_adapter_path,
28
+ image_encoder_path=image_encoder_path,
29
+ nb_token=64,
30
+ )
31
+
32
+
33
+ @gr.Interface()
34
+ def gui_generation(prompt: str, negative_prompt: str, ipadapter_scale: float, num_imgs: int):
35
+ """
36
+ Generate images based on prompt, negative prompt, and IP-Adapter scale.
37
+ """
38
+ ref_img = Image.open(ref_img_path).convert('RGB') # Load reference image
39
+ generator = torch.Generator("cuda").manual_seed(42) # Reproducibility
40
+
41
+ images = []
42
+ for _ in range(num_imgs):
43
+ output = pipe(
44
+ width=1024,
45
+ height=1024,
46
+ prompt=prompt,
47
+ negative_prompt=negative_prompt,
48
+ num_inference_steps=24,
49
+ guidance_scale=5.0,
50
+ generator=generator,
51
+ clip_image=ref_img,
52
+ ipadapter_scale=ipadapter_scale,
53
+ ).images[0]
54
+ images.append(output)
55
+ return images
56
+
57
+
58
+ # Gradio UI elements
59
+ prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your generation prompt here")
60
+ negative_prompt_box = gr.Textbox(label="Negative Prompt", placeholder="e.g., lowres, worst quality")
61
+ ipadapter_slider = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="IP-Adapter Scale")
62
+ number_slider = gr.Slider(1, 5, value=1, step=1, label="Number of Images")
63
+ gallery = gr.Gallery(label="Generated Images", columns=[3], rows=[1], object_fit="contain", height="auto")
64
+
65
+ interface = gr.Interface(
66
+ gui_generation,
67
+ inputs=[prompt_box, negative_prompt_box, ipadapter_slider, number_slider],
68
+ outputs=gallery,
69
+ title="Stable Diffusion 3.5 Image Generation with IP-Adapter",
70
+ description="Generate high-quality images with Stable Diffusion 3.5 Large and IP-Adapter guidance."
71
+ )
72
 
73
  interface.launch()