vivekvar commited on
Commit
7d605fc
·
verified ·
1 Parent(s): 79c9372

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -58
main.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- import torch
3
- from PIL import Image
4
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
5
- import gradio as gr
6
-
7
- # Disable oneDNN custom operations
8
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
9
-
10
- # Clear PyTorch cache
11
- torch.cuda.empty_cache()
12
-
13
- # Check if CUDA is available
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
- if device == "cuda":
16
- print("CUDA is available. Device count:", torch.cuda.device_count())
17
- print("Current device:", torch.cuda.current_device())
18
- print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
19
- else:
20
- print("CUDA is not available. Using CPU.")
21
-
22
- # Load ControlNet model with OpenPose pre-trained weights from Hugging Face
23
- controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
24
-
25
- # Load the Stable Diffusion model
26
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
27
- "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
28
- ).to(device)
29
-
30
- # Function for inference
31
- def generate_image(prompt, target_image, pose_image):
32
- try:
33
- # Resize images
34
- target_image = target_image.resize((512, 512))
35
- pose_image = pose_image.resize((512, 512))
36
-
37
- # Generate image with ControlNet
38
- output = pipe(prompt=prompt, image=target_image, control_image=pose_image, num_inference_steps=50)
39
-
40
- # Return the result
41
- return output["sample"][0]
42
- except Exception as e:
43
- print(f"Error during image generation: {e}")
44
- return None
45
-
46
- # Setup Gradio Interface
47
- interface = gr.Interface(
48
- fn=generate_image,
49
- inputs=[
50
- gr.Textbox(label="Prompt"),
51
- gr.Image(label="Target Image", type="pil"),
52
- gr.Image(label="Pose Image (Reference)", type="pil")
53
- ],
54
- outputs=gr.Image(label="Generated Image")
55
- )
56
-
57
- # Launch the interface
58
- interface.launch()