rhfeiyang commited on
Commit
d4fbe32
·
1 Parent(s): c2786e2
Files changed (1) hide show
  1. hf_demo.py +4 -3
hf_demo.py CHANGED
@@ -11,8 +11,9 @@ from PIL import Image
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
 
14
  pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
15
- dtype=dtype).to(device)
16
 
17
  from inference import get_lora_network, inference, get_validation_dataloader
18
  lora_map = {
@@ -42,7 +43,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
42
  adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
43
 
44
  prompts = [prompt]*samples
45
- infer_loader = get_validation_dataloader(prompts)
46
  network = get_lora_network(pipe.unet, adapter_path)["network"]
47
  pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
48
  height=512, width=512, scales=[1.0],
@@ -52,7 +53,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
52
  return pred_images
53
  @spaces.GPU
54
  def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
55
- infer_loader = get_validation_dataloader(prompts, image)
56
  network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
57
  pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
58
  height=512, width=512, scales=[0.,1.],
 
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
14
+ print(f"Using {device} device, dtype={dtype}")
15
  pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
16
+ torch_dtype=dtype).to(device)
17
 
18
  from inference import get_lora_network, inference, get_validation_dataloader
19
  lora_map = {
 
43
  adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
44
 
45
  prompts = [prompt]*samples
46
+ infer_loader = get_validation_dataloader(prompts,num_workers=0)
47
  network = get_lora_network(pipe.unet, adapter_path)["network"]
48
  pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
49
  height=512, width=512, scales=[1.0],
 
53
  return pred_images
54
  @spaces.GPU
55
  def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
56
+ infer_loader = get_validation_dataloader(prompts, image,num_workers=0)
57
  network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
58
  pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
59
  height=512, width=512, scales=[0.,1.],