Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- hf_demo.py +4 -3
hf_demo.py
CHANGED
@@ -11,8 +11,9 @@ from PIL import Image
|
|
11 |
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
|
|
|
14 |
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
15 |
-
|
16 |
|
17 |
from inference import get_lora_network, inference, get_validation_dataloader
|
18 |
lora_map = {
|
@@ -42,7 +43,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
|
|
42 |
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
43 |
|
44 |
prompts = [prompt]*samples
|
45 |
-
infer_loader = get_validation_dataloader(prompts)
|
46 |
network = get_lora_network(pipe.unet, adapter_path)["network"]
|
47 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
48 |
height=512, width=512, scales=[1.0],
|
@@ -52,7 +53,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
|
|
52 |
return pred_images
|
53 |
@spaces.GPU
|
54 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
55 |
-
infer_loader = get_validation_dataloader(prompts, image)
|
56 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
57 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
58 |
height=512, width=512, scales=[0.,1.],
|
|
|
11 |
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
|
14 |
+
print(f"Using {device} device, dtype={dtype}")
|
15 |
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
16 |
+
torch_dtype=dtype).to(device)
|
17 |
|
18 |
from inference import get_lora_network, inference, get_validation_dataloader
|
19 |
lora_map = {
|
|
|
43 |
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
44 |
|
45 |
prompts = [prompt]*samples
|
46 |
+
infer_loader = get_validation_dataloader(prompts,num_workers=0)
|
47 |
network = get_lora_network(pipe.unet, adapter_path)["network"]
|
48 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
49 |
height=512, width=512, scales=[1.0],
|
|
|
53 |
return pred_images
|
54 |
@spaces.GPU
|
55 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
56 |
+
infer_loader = get_validation_dataloader(prompts, image,num_workers=0)
|
57 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
58 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
59 |
height=512, width=512, scales=[0.,1.],
|