Spaces:
Build error
Build error
Update inference.py
Browse files- inference.py +10 -29
inference.py
CHANGED
|
@@ -10,7 +10,7 @@ import torch
|
|
| 10 |
from diffusers import StableDiffusionPipeline
|
| 11 |
|
| 12 |
sys.path.insert(0, 'lora')
|
| 13 |
-
from
|
| 14 |
|
| 15 |
|
| 16 |
class InferencePipeline:
|
|
@@ -28,24 +28,16 @@ class InferencePipeline:
|
|
| 28 |
gc.collect()
|
| 29 |
|
| 30 |
@staticmethod
|
| 31 |
-
def
|
| 32 |
curr_dir = pathlib.Path(__file__).parent
|
| 33 |
return curr_dir / name
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
parent_dir = path.parent
|
| 38 |
-
stem = path.stem
|
| 39 |
-
text_encoder_filename = f'{stem}.text_encoder.pt'
|
| 40 |
-
path = parent_dir / text_encoder_filename
|
| 41 |
-
return path.as_posix() if path.exists() else ''
|
| 42 |
-
|
| 43 |
-
def load_pipe(self, model_id: str, lora_filename: str) -> None:
|
| 44 |
-
weight_path = self.get_lora_weight_path(lora_filename)
|
| 45 |
if weight_path == self.weight_path:
|
| 46 |
return
|
| 47 |
self.weight_path = weight_path
|
| 48 |
-
|
| 49 |
|
| 50 |
if self.device.type == 'cpu':
|
| 51 |
pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
|
@@ -54,40 +46,29 @@ class InferencePipeline:
|
|
| 54 |
model_id, torch_dtype=torch.float16)
|
| 55 |
pipe = pipe.to(self.device)
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
lora_text_encoder_weight_path = self.get_lora_text_encoder_weight_path(
|
| 60 |
-
weight_path)
|
| 61 |
-
if lora_text_encoder_weight_path:
|
| 62 |
-
lora_text_encoder_weight = torch.load(
|
| 63 |
-
lora_text_encoder_weight_path, map_location=self.device)
|
| 64 |
-
monkeypatch_lora(pipe.text_encoder,
|
| 65 |
-
lora_text_encoder_weight,
|
| 66 |
-
target_replace_module=['CLIPAttention'])
|
| 67 |
|
| 68 |
self.pipe = pipe
|
| 69 |
|
| 70 |
def run(
|
| 71 |
self,
|
| 72 |
base_model: str,
|
| 73 |
-
|
| 74 |
prompt: str,
|
| 75 |
-
alpha: float,
|
| 76 |
-
alpha_for_text: float,
|
| 77 |
seed: int,
|
| 78 |
n_steps: int,
|
| 79 |
guidance_scale: float,
|
|
|
|
| 80 |
) -> PIL.Image.Image:
|
| 81 |
if not torch.cuda.is_available():
|
| 82 |
raise gr.Error('CUDA is not available.')
|
| 83 |
|
| 84 |
-
self.load_pipe(base_model,
|
| 85 |
|
| 86 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 87 |
-
tune_lora_scale(self.pipe.unet, alpha) # type: ignore
|
| 88 |
-
tune_lora_scale(self.pipe.text_encoder, alpha_for_text) # type: ignore
|
| 89 |
out = self.pipe(prompt,
|
| 90 |
num_inference_steps=n_steps,
|
| 91 |
guidance_scale=guidance_scale,
|
|
|
|
| 92 |
generator=generator) # type: ignore
|
| 93 |
return out.images[0]
|
|
|
|
| 10 |
from diffusers import StableDiffusionPipeline
|
| 11 |
|
| 12 |
sys.path.insert(0, 'lora')
|
| 13 |
+
from src import sample_diffuser, diffuser_training
|
| 14 |
|
| 15 |
|
| 16 |
class InferencePipeline:
|
|
|
|
| 28 |
gc.collect()
|
| 29 |
|
| 30 |
@staticmethod
|
| 31 |
+
def get_weight_path(name: str) -> pathlib.Path:
|
| 32 |
curr_dir = pathlib.Path(__file__).parent
|
| 33 |
return curr_dir / name
|
| 34 |
|
| 35 |
+
def load_pipe(self, model_id: str, filename: str) -> None:
|
| 36 |
+
weight_path = self.get_weight_path(filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
if weight_path == self.weight_path:
|
| 38 |
return
|
| 39 |
self.weight_path = weight_path
|
| 40 |
+
weight = torch.load(self.weight_path, map_location=self.device)
|
| 41 |
|
| 42 |
if self.device.type == 'cpu':
|
| 43 |
pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
|
|
|
| 46 |
model_id, torch_dtype=torch.float16)
|
| 47 |
pipe = pipe.to(self.device)
|
| 48 |
|
| 49 |
+
diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
self.pipe = pipe
|
| 52 |
|
| 53 |
def run(
|
| 54 |
self,
|
| 55 |
base_model: str,
|
| 56 |
+
weight_name: str,
|
| 57 |
prompt: str,
|
|
|
|
|
|
|
| 58 |
seed: int,
|
| 59 |
n_steps: int,
|
| 60 |
guidance_scale: float,
|
| 61 |
+
eta: float,
|
| 62 |
) -> PIL.Image.Image:
|
| 63 |
if not torch.cuda.is_available():
|
| 64 |
raise gr.Error('CUDA is not available.')
|
| 65 |
|
| 66 |
+
self.load_pipe(base_model, weight_name)
|
| 67 |
|
| 68 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
|
|
|
|
|
|
| 69 |
out = self.pipe(prompt,
|
| 70 |
num_inference_steps=n_steps,
|
| 71 |
guidance_scale=guidance_scale,
|
| 72 |
+
eta = eta,
|
| 73 |
generator=generator) # type: ignore
|
| 74 |
return out.images[0]
|