amos1088 commited on
Commit
c1497a6
·
1 Parent(s): 4fbc46c

test gradio

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -2,27 +2,32 @@ import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
4
  import os
5
- from huggingface_hub import HfApi, login
6
 
 
7
  token = os.getenv("HF_TOKEN")
8
- login(token=token) # Logs in with the token in Hugging Face Spaces
9
 
10
- # Load Stable Diffusion model and ControlNet reference-only model
11
  model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
12
- controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" # Use an appropriate ControlNet variant
13
 
 
14
  controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
15
  pipeline = StableDiffusionControlNetPipeline.from_pretrained(
16
  model_id,
17
  controlnet=controlnet,
18
  torch_dtype=torch.float32
19
  )
 
 
20
 
21
  # Define the Gradio interface function
22
  def generate_image(prompt, reference_image):
23
- # Process reference image
24
- reference_image = reference_image.resize((512, 512))
25
- # Generate image with reference-only style transfer
 
26
  generated_image = pipeline(
27
  prompt=prompt,
28
  image=reference_image,
@@ -32,6 +37,7 @@ def generate_image(prompt, reference_image):
32
  ).images[0]
33
  return generated_image
34
 
 
35
  # Set up Gradio interface
36
  interface = gr.Interface(
37
  fn=generate_image,
 
2
  import torch
3
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
4
  import os
5
+ from huggingface_hub import login
6
 
7
+ # Log in with your Hugging Face token (assumed stored in HF_TOKEN)
8
  token = os.getenv("HF_TOKEN")
9
+ login(token=token)
10
 
11
+ # Model IDs for the base Stable Diffusion model and ControlNet variant
12
  model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
13
+ controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" # Make sure this ControlNet is compatible
14
 
15
+ # Load ControlNet model and pipeline
16
  controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
17
  pipeline = StableDiffusionControlNetPipeline.from_pretrained(
18
  model_id,
19
  controlnet=controlnet,
20
  torch_dtype=torch.float32
21
  )
22
+ pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline
23
+
24
 
25
  # Define the Gradio interface function
26
  def generate_image(prompt, reference_image):
27
+ # Ensure the reference image is in the correct format
28
+ reference_image = reference_image.convert("RGB").resize((512, 512))
29
+
30
+ # Generate the image with ControlNet
31
  generated_image = pipeline(
32
  prompt=prompt,
33
  image=reference_image,
 
37
  ).images[0]
38
  return generated_image
39
 
40
+
41
  # Set up Gradio interface
42
  interface = gr.Interface(
43
  fn=generate_image,