Spaces:
Running
Running
fixed cpu error
Browse files
app.py
CHANGED
@@ -80,8 +80,11 @@ def prepare_pipeline(model_name):
|
|
80 |
if 'dpo' in OUTPUT_DIR:
|
81 |
args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
|
82 |
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
85 |
|
86 |
pipe.verbose = True
|
87 |
pipe.v = 're'
|
@@ -116,7 +119,7 @@ def prepare_pipeline(model_name):
|
|
116 |
ID2NAME = open('data/dogs/class_names.txt').readlines()
|
117 |
ID2NAME = [line.strip() for line in ID2NAME]
|
118 |
|
119 |
-
return pipe, MAPPING, ID2NAME
|
120 |
|
121 |
|
122 |
def download_file(url, local_path):
|
@@ -159,11 +162,11 @@ def process_text(text, MAPPING, ID2NAME):
|
|
159 |
|
160 |
|
161 |
def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
|
162 |
-
generator = torch.Generator(device='cuda')
|
163 |
-
generator = generator.manual_seed(int(seed))
|
164 |
-
|
165 |
try:
|
166 |
-
pipe, MAPPING, ID2NAME = prepare_pipeline(model_name)
|
|
|
|
|
|
|
167 |
|
168 |
prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
|
169 |
negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
|
@@ -179,7 +182,8 @@ def generate_images(model_name, prompt, negative_prompt, num_inference_steps, gu
|
|
179 |
f"The error message: {e}")
|
180 |
finally:
|
181 |
gc.collect()
|
182 |
-
torch.cuda.
|
|
|
183 |
|
184 |
return images, '; '.join(part2id)
|
185 |
|
|
|
80 |
if 'dpo' in OUTPUT_DIR:
|
81 |
args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
|
82 |
|
83 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
84 |
+
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
85 |
+
|
86 |
+
pipe = load_pipeline(args, weight_dtype, device)
|
87 |
+
pipe = pipe.to(weight_dtype)
|
88 |
|
89 |
pipe.verbose = True
|
90 |
pipe.v = 're'
|
|
|
119 |
ID2NAME = open('data/dogs/class_names.txt').readlines()
|
120 |
ID2NAME = [line.strip() for line in ID2NAME]
|
121 |
|
122 |
+
return pipe, MAPPING, ID2NAME, device
|
123 |
|
124 |
|
125 |
def download_file(url, local_path):
|
|
|
162 |
|
163 |
|
164 |
def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
|
|
|
|
|
|
|
165 |
try:
|
166 |
+
pipe, MAPPING, ID2NAME, device = prepare_pipeline(model_name)
|
167 |
+
|
168 |
+
generator = torch.Generator(device=device)
|
169 |
+
generator = generator.manual_seed(int(seed))
|
170 |
|
171 |
prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
|
172 |
negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
|
|
|
182 |
f"The error message: {e}")
|
183 |
finally:
|
184 |
gc.collect()
|
185 |
+
if torch.cuda.is_available():
|
186 |
+
torch.cuda.empty_cache()
|
187 |
|
188 |
return images, '; '.join(part2id)
|
189 |
|