rhfeiyang commited on
Commit
5d89dc0
·
1 Parent(s): b851c53
Files changed (1) hide show
  1. hf_demo.py +2 -4
hf_demo.py CHANGED
@@ -7,8 +7,6 @@ import matplotlib.pyplot as plt
7
  import torch
8
  from PIL import Image
9
 
10
-
11
-
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
14
  print(f"Using {device} device, dtype={dtype}")
@@ -49,7 +47,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
49
  height=512, width=512, scales=[1.0],
50
  save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
51
  start_noise=-1, show=False, style_prompt="sks art", no_load=True,
52
- from_scratch=True, device=device)[0][1.0]
53
  return pred_images
54
  @spaces.GPU
55
  def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
@@ -59,7 +57,7 @@ def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start
59
  height=512, width=512, scales=[0.,1.],
60
  save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
61
  start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
62
- from_scratch=False, device=device)
63
  return pred_images
64
 
65
  # def infer(prompt, samples, steps, scale, seed):
 
7
  import torch
8
  from PIL import Image
9
 
 
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
12
  print(f"Using {device} device, dtype={dtype}")
 
47
  height=512, width=512, scales=[1.0],
48
  save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
49
  start_noise=-1, show=False, style_prompt="sks art", no_load=True,
50
+ from_scratch=True, device=device, weight_dtype=dtype)[0][1.0]
51
  return pred_images
52
  @spaces.GPU
53
  def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
 
57
  height=512, width=512, scales=[0.,1.],
58
  save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
59
  start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
60
+ from_scratch=False, device=device, weight_dtype=dtype)[0][1.0]
61
  return pred_images
62
 
63
  # def infer(prompt, samples, steps, scale, seed):