DigiP-AI committed
Commit 99fef99 · verified · 1 Parent(s): 260b768

Delete predict.py

Files changed (1)
  1. predict.py +0 -134
predict.py DELETED
@@ -1,134 +0,0 @@
- # Prediction interface for Cog ⚙️
- # https://cog.run/python
-
- from cog import BasePredictor, Input, Path
- import os
- import time
- import torch
- import subprocess
- from PIL import Image
- from typing import List
- from image_datasets.canny_dataset import canny_processor, c_crop
- from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors
-
- OUTPUT_DIR = "controlnet_results"
- MODEL_CACHE = "checkpoints"
- CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors"
- T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar"
- CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar"
- HF_TOKEN = "hf_..." # Your HuggingFace token
-
- def download_weights(url, dest):
-     start = time.time()
-     print("downloading url: ", url)
-     print("downloading to: ", dest)
-     subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
-     print("downloading took: ", time.time() - start)
-
- def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
-     t5 = load_t5(device, max_length=256 if is_schnell else 512)
-     clip = load_clip(device)
-     model = load_flow_model(name, device="cpu" if offload else device)
-     ae = load_ae(name, device="cpu" if offload else device)
-     controlnet = load_controlnet(name, device).to(torch.bfloat16)
-     return model, ae, t5, clip, controlnet
-
- class Predictor(BasePredictor):
-     def setup(self) -> None:
-         """Load the model into memory to make running multiple predictions efficient"""
-         t1 = time.time()
-         os.system(f"huggingface-cli login --token {HF_TOKEN}")
-         name = "flux-dev"
-         self.offload = False
-         checkpoint = "controlnet.safetensors"
-
-         print("Checking ControlNet weights")
-         checkpoint = "controlnet.safetensors"
-         if not os.path.exists(checkpoint):
-             os.system(f"wget {CONTROLNET_URL}")
-         print("Checking T5 weights")
-         if not os.path.exists(MODEL_CACHE+"/models--google--t5-v1_1-xxl"):
-             download_weights(T5_URL, MODEL_CACHE)
-         print("Checking CLIP weights")
-         if not os.path.exists(MODEL_CACHE+"/models--openai--clip-vit-large-patch14"):
-             download_weights(CLIP_URL, MODEL_CACHE)
-
-         self.is_schnell = False
-         device = "cuda"
-         self.torch_device = torch.device(device)
-         model, ae, t5, clip, controlnet = get_models(
-             name,
-             device=self.torch_device,
-             offload=self.offload,
-             is_schnell=self.is_schnell,
-         )
-         self.ae = ae
-         self.t5 = t5
-         self.clip = clip
-         self.controlnet = controlnet
-         self.model = model.to(self.torch_device)
-         if '.safetensors' in checkpoint:
-             checkpoint1 = load_safetensors(checkpoint)
-         else:
-             checkpoint1 = torch.load(checkpoint, map_location='cpu')
-
-         controlnet.load_state_dict(checkpoint1, strict=False)
-         t2 = time.time()
-         print(f"Setup time: {t2 - t1}")
-
-     def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512):
-         image = Image.open(image_path)
-         image = c_crop(image)
-         image = image.resize((width, height))
-         image = canny_processor(image)
-         return image
-
-     def predict(
-         self,
-         prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"),
-         image: Path = Input(description="Input image", default=None),
-         num_inference_steps: int = Input(description="Number of inference steps", ge=1, le=64, default=28),
-         cfg: float = Input(description="CFG", ge=0, le=10, default=3.5),
-         seed: int = Input(description="Random seed", default=None)
-     ) -> List[Path]:
-         """Run a single prediction on the model"""
-         if seed is None:
-             seed = int.from_bytes(os.urandom(2), "big")
-         print(f"Using seed: {seed}")
-
-         # clean output dir
-         output_dir = "controlnet_results"
-         os.system(f"rm -rf {output_dir}")
-
-         input_image = str(image)
-         img = Image.open(input_image)
-         width, height = img.size
-         # Resize input image if it's too large
-         max_image_size = 1536
-         scale = min(max_image_size / width, max_image_size / height, 1)
-         if scale < 1:
-             width = int(width * scale)
-             height = int(height * scale)
-             print(f"Scaling image down to {width}x{height}")
-             img = img.resize((width, height), resample=Image.Resampling.LANCZOS)
-             input_image = "/tmp/resized_image.png"
-             img.save(input_image)
-
-         subprocess.check_call(
-             ["python3", "main.py",
-              "--local_path", "controlnet.safetensors",
-              "--image", input_image,
-              "--use_controlnet",
-              "--control_type", "canny",
-              "--prompt", prompt,
-              "--width", str(width),
-              "--height", str(height),
-              "--num_steps", str(num_inference_steps),
-              "--guidance", str(cfg),
-              "--seed", str(seed)
-              ], close_fds=False)
-
-         # Find the first file that begins with "controlnet_result_"
-         for file in os.listdir(output_dir):
-             if file.startswith("controlnet_result_"):
-                 return [Path(os.path.join(output_dir, file))]
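For reference, a minimal sketch of how the removed Predictor could have been exercised directly in Python against the pre-deletion tree; it assumes the repository's image_datasets/ and src/flux/ modules, main.py, and the weights fetched in setup() are all available, and the file name control.png plus the fixed seed are illustrative placeholders only:

# Minimal local smoke test for the deleted predict.py (sketch, not part of the repo)
from predict import Predictor

predictor = Predictor()
predictor.setup()  # logs in to HF, fetches ControlNet/T5/CLIP weights if missing, loads flux-dev
outputs = predictor.predict(
    prompt="a handsome viking man with white hair, cinematic, MM full HD",
    image="control.png",   # placeholder conditioning image; handed to main.py with --control_type canny
    num_inference_steps=28,
    cfg=3.5,
    seed=42,               # placeholder; passing None would draw a random seed
)
print(outputs)  # e.g. [Path("controlnet_results/controlnet_result_...")]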