willsh1997 committed on
Commit
075f0c6
·
verified ·
1 Parent(s): 22b2550

:robot: fix tensor device issue with claude

Browse files
Files changed (1) hide show
  1. app.py +60 -48
app.py CHANGED
@@ -106,7 +106,6 @@ class customUnClipPipeline(UnCLIPImageVariationPipeline):
106
  ):
107
  """
108
  The call function to the pipeline for generation.
109
-
110
  Args:
111
  image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
112
  `Image` or tensor representing an image batch to be used as the starting point. If you provide a
@@ -138,7 +137,6 @@ class customUnClipPipeline(UnCLIPImageVariationPipeline):
138
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
139
  return_dict (`bool`, *optional*, defaults to `True`):
140
  Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
141
-
142
  Returns:
143
  [`~pipelines.ImagePipelineOutput`] or `tuple`:
144
  If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
@@ -314,14 +312,11 @@ class customUnClipPipeline(UnCLIPImageVariationPipeline):
314
 
315
 
316
  ### ADDITIONAL PIPELINE CODE FOR KARLO
317
- torch_device = torch.device('cpu')
318
- pipe = customUnClipPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float32, trust_remote_code=True,
319
- # device=torch_device,
320
- # device_map='cpu'
321
- )
322
- pipe.to(torch.device("cuda"))
323
- # pipe.enable_model_cpu_offload()
324
 
 
 
325
 
326
  # func for getting tensor embeddings from cand image
327
 
@@ -334,27 +329,27 @@ def load_img_from_URL(URL):
334
  init_image = Image.open(BytesIO(response.content)).convert("RGB")
335
  return init_image
336
 
337
- def embed_img(input_image):
338
- tokens = pipe.feature_extractor(input_image).to(torch_device)
339
- img_model = pipe.image_encoder.to(torch_device)
340
  with torch.no_grad():
341
- embeds = img_model(torch.tensor(tokens.pixel_values[0]).unsqueeze(0).to(torch_device))
342
 
343
- return embeds.image_embeds.to(torch_device)
344
 
345
- def localimg_2_embed(image_dir):
346
- embeds = embed_img(load_image(image_dir))
347
  return embeds
348
 
349
- def URLimg_2_embed(URL):
350
- embeds = embed_img(load_img_from_URL(URL))
351
  return embeds
352
 
353
 
354
  # random generator for softmaxxed outputs
355
 
356
- def random_probdist(num_cands):
357
- random_numbers = torch.randn(num_cands)
358
  softmax_output = torch.nn.functional.softmax(random_numbers, dim=0).reshape((num_cands,1))
359
  return softmax_output
360
 
@@ -366,13 +361,15 @@ def scalesum_candtensors(list_scale, cand_tensors):
366
  assert sum(list_scale) == 1, f"you didn't input a valid probability distribution - make sure your scales add up to 1, currently it adds up to {sum(list_scale)}"
367
  assert len(list_scale) == len(cand_tensors), f"your scale list is not the same length as your list of candidate tensors. len list = {len(list_scale)}, len cand tensors = {len(cand_tensors)}"
368
 
369
- scaled = torch.tensor(list_scale), cand_tensors
 
370
  output = scaled.sum(dim=0)
371
  return output
372
 
373
 
374
  def random_candtensor(cand_tensors):
375
- scaled = random_probdist(len(cand_tensors)) * cand_tensors
 
376
  output = scaled.sum(dim=0)
377
  return output
378
 
@@ -390,37 +387,52 @@ def image_grid(imgs, rows, cols):
390
  return grid
391
 
392
 
393
- chaosclicker_willtensor = localimg_2_embed('willpaint-imgs/chaosclicker-willpaint.png').to(torch_device)
394
- contentcnsr_willtensor = localimg_2_embed('willpaint-imgs/contentconnoisseur-willpaint.png').to(torch_device)
395
- digdaydrmr_willtensor = localimg_2_embed('willpaint-imgs/digitaldaydreamer-willpaint.png').to(torch_device)
396
- ecoexplr_willtensor = localimg_2_embed('willpaint-imgs/ecoexplorer-willpaint.png').to(torch_device)
397
- fandomfox_willtensor = localimg_2_embed('willpaint-imgs/fandomfox-willpaint.png').to(torch_device)
398
- mememaven_willtensor = localimg_2_embed('willpaint-imgs/mememaven-willpaint.png').to(torch_device)
399
- newsnerd_willtensor = localimg_2_embed('willpaint-imgs/newnerd-willpaint.png').to(torch_device)
400
- nostalgicnvgtr_willtensor = localimg_2_embed('willpaint-imgs/nostalgicnavigator-willpaint.png').to(torch_device)
401
- scrollseeker_willtensor = localimg_2_embed('willpaint-imgs/scrollseeker-willpaint.png').to(torch_device)
402
- trendtracker_willtensor = localimg_2_embed('willpaint-imgs/trendtracker-willpaint.png').to(torch_device)
403
-
404
-
405
- will_cand_tensors = torch.cat([chaosclicker_willtensor,
406
- contentcnsr_willtensor ,
407
- digdaydrmr_willtensor,
408
- ecoexplr_willtensor,
409
- fandomfox_willtensor,
410
- mememaven_willtensor,
411
- newsnerd_willtensor,
412
- nostalgicnvgtr_willtensor,
413
- scrollseeker_willtensor,
414
- trendtracker_willtensor,], dim=0)
415
-
 
 
 
 
 
 
416
 
417
 
418
  ### FUNCTION FOR EXECUTION
419
  @spaces.GPU
420
  def generate_freak():
421
- will_randomised_input = random_candtensor(will_cand_tensors).unsqueeze(0)
422
- #will_randomised_input
423
- output = pipe(image_embeddings=will_randomised_input.to("cuda"), num_images_per_prompt=1, decoder_num_inference_steps = 15, super_res_num_inference_steps = 4)
 
 
 
 
 
 
 
 
 
424
  return output.images[0]
425
 
426
  ### GRADIO BACKEND
 
106
  ):
107
  """
108
  The call function to the pipeline for generation.
 
109
  Args:
110
  image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
111
  `Image` or tensor representing an image batch to be used as the starting point. If you provide a
 
137
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
138
  return_dict (`bool`, *optional*, defaults to `True`):
139
  Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
140
  Returns:
141
  [`~pipelines.ImagePipelineOutput`] or `tuple`:
142
  If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
 
312
 
313
 
314
  ### ADDITIONAL PIPELINE CODE FOR KARLO
315
+ # Initialize pipeline on CPU first
316
+ pipe = customUnClipPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float32, trust_remote_code=True)
 
 
 
 
 
317
 
318
+ # Global variable to store embeddings - will be loaded on GPU when needed
319
+ will_cand_tensors = None
320
 
321
  # func for getting tensor embeddings from cand image
322
 
 
329
  init_image = Image.open(BytesIO(response.content)).convert("RGB")
330
  return init_image
331
 
332
def embed_img(input_image, device):
    """Compute the image embedding of ``input_image`` on ``device``.

    The image is preprocessed with ``pipe.feature_extractor``, the first
    entry of the resulting ``pixel_values`` is turned into a batch of one,
    and that batch is run through ``pipe.image_encoder`` with gradient
    tracking disabled.

    Args:
        input_image: Image accepted by ``pipe.feature_extractor``
            (presumably a ``PIL.Image.Image`` -- confirm against callers).
        device: Device the encoder and the pixel batch are moved to.

    Returns:
        The ``image_embeds`` attribute of the encoder output. It is left on
        whatever device the encoder produced it on.
    """
    features = pipe.feature_extractor(input_image)
    # NOTE: this moves the shared pipeline's image encoder to ``device``.
    encoder = pipe.image_encoder.to(device)
    # Only the first preprocessed image is used; unsqueeze adds the batch axis.
    pixel_batch = torch.tensor(features.pixel_values[0]).unsqueeze(0).to(device)
    with torch.no_grad():
        encoded = encoder(pixel_batch)
    return encoded.image_embeds
339
 
340
def localimg_2_embed(image_dir, device):
    """Load the image at ``image_dir`` and return its embedding.

    Args:
        image_dir: Location handed to ``load_image`` (the callers in this
            file pass local file paths).
        device: Device forwarded to ``embed_img``.

    Returns:
        Whatever ``embed_img`` returns for the loaded image.
    """
    source_image = load_image(image_dir)
    return embed_img(source_image, device)
343
 
344
def URLimg_2_embed(URL, device):
    """Download the image at ``URL`` and return its embedding.

    Args:
        URL: Address handed to ``load_img_from_URL``.
        device: Device forwarded to ``embed_img``.

    Returns:
        Whatever ``embed_img`` returns for the downloaded image.
    """
    downloaded_image = load_img_from_URL(URL)
    return embed_img(downloaded_image, device)
347
 
348
 
349
  # random generator for softmaxxed outputs
350
 
351
def random_probdist(num_cands, device=None):
    """Draw a random probability distribution over ``num_cands`` candidates.

    ``num_cands`` standard-normal samples are pushed through a softmax, so
    the resulting weights are positive and sum to 1.

    Args:
        num_cands: Number of candidates, i.e. the length of the distribution.
        device: Device on which the samples are created. Defaults to
            ``None`` (PyTorch's current default device) so that callers
            passing only ``num_cands`` keep working.

    Returns:
        Tensor of shape ``(num_cands, 1)``; the trailing axis lets the
        weights broadcast across the rows of a ``(num_cands, dim)`` tensor.
    """
    random_numbers = torch.randn(num_cands, device=device)
    softmax_output = torch.nn.functional.softmax(random_numbers, dim=0)
    return softmax_output.reshape((num_cands, 1))
355
 
 
361
  assert sum(list_scale) == 1, f"you didn't input a valid probability distribution - make sure your scales add up to 1, currently it adds up to {sum(list_scale)}"
362
  assert len(list_scale) == len(cand_tensors), f"your scale list is not the same length as your list of candidate tensors. len list = {len(list_scale)}, len cand tensors = {len(cand_tensors)}"
363
 
364
+ device = cand_tensors.device
365
+ scaled = torch.tensor(list_scale, device=device).unsqueeze(1) * cand_tensors
366
  output = scaled.sum(dim=0)
367
  return output
368
 
369
 
370
  def random_candtensor(cand_tensors):
371
+ device = cand_tensors.device
372
+ scaled = random_probdist(len(cand_tensors), device) * cand_tensors
373
  output = scaled.sum(dim=0)
374
  return output
375
 
 
387
  return grid
388
 
389
 
390
def initialize_embeddings(device):
    """Return the stacked candidate embeddings on ``device``, building them once.

    On the first call every candidate image is embedded with
    ``localimg_2_embed`` and the results are concatenated along the first
    axis into the module-level ``will_cand_tensors`` cache. Later calls reuse
    the cache; the cached tensor is moved to ``device`` so the result always
    honours the requested device (``Tensor.to`` returns the same tensor when
    it is already there).

    Args:
        device: Device the embeddings should live on.

    Returns:
        The concatenated candidate embeddings, one entry per image, in the
        order the image paths are listed below.
    """
    global will_cand_tensors

    if will_cand_tensors is None:
        # Order matters: it fixes which row belongs to which candidate.
        image_paths = (
            'willpaint-imgs/chaosclicker-willpaint.png',
            'willpaint-imgs/contentconnoisseur-willpaint.png',
            'willpaint-imgs/digitaldaydreamer-willpaint.png',
            'willpaint-imgs/ecoexplorer-willpaint.png',
            'willpaint-imgs/fandomfox-willpaint.png',
            'willpaint-imgs/mememaven-willpaint.png',
            'willpaint-imgs/newnerd-willpaint.png',
            'willpaint-imgs/nostalgicnavigator-willpaint.png',
            'willpaint-imgs/scrollseeker-willpaint.png',
            'willpaint-imgs/trendtracker-willpaint.png',
        )
        will_cand_tensors = torch.cat(
            [localimg_2_embed(path, device) for path in image_paths],
            dim=0,
        )

    # A cache filled for a different device must not leak stale-device
    # tensors to the caller; this is a no-op when the device already matches.
    will_cand_tensors = will_cand_tensors.to(device)
    return will_cand_tensors
419
 
420
 
421
  ### FUNCTION FOR EXECUTION
422
  @spaces.GPU
423
  def generate_freak():
424
+ # Move pipeline to GPU
425
+ device = torch.device("cuda")
426
+ pipe.to(device)
427
+
428
+ # Initialize embeddings on GPU
429
+ cand_tensors = initialize_embeddings(device)
430
+
431
+ # Generate random input on GPU
432
+ will_randomised_input = random_candtensor(cand_tensors).unsqueeze(0)
433
+
434
+ # Generate image
435
+ output = pipe(image_embeddings=will_randomised_input, num_images_per_prompt=1, decoder_num_inference_steps = 15, super_res_num_inference_steps = 4)
436
  return output.images[0]
437
 
438
  ### GRADIO BACKEND