Spaces:
Sleeping
Sleeping
[UPD] Upd. app.py
Browse files
app.py
CHANGED
@@ -2,10 +2,11 @@ import gradio as gr
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
|
5 |
-
# import spaces #[uncomment to use ZeroGPU]
|
6 |
from diffusers import DiffusionPipeline
|
|
|
7 |
import torch
|
8 |
|
|
|
9 |
|
10 |
# Model list including your LoRA model
|
11 |
MODEL_LIST = [
|
@@ -16,14 +17,14 @@ MODEL_LIST = [
|
|
16 |
"YaArtemNosenko/dino_stickers",
|
17 |
]
|
18 |
|
19 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
-
|
21 |
if torch.cuda.is_available():
|
22 |
torch_dtype = torch.float16
|
23 |
else:
|
24 |
torch_dtype = torch.float32
|
25 |
|
26 |
-
|
|
|
|
|
27 |
def load_pipeline(model_id: str):
|
28 |
"""
|
29 |
Loads or retrieves a cached DiffusionPipeline.
|
@@ -56,20 +57,13 @@ def load_pipeline(model_id: str):
|
|
56 |
|
57 |
pipe.to(device)
|
58 |
model_cache[model_id] = pipe
|
59 |
-
|
60 |
return pipe
|
61 |
|
62 |
-
# model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
|
63 |
-
# pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
64 |
-
# pipe = pipe.to(device)
|
65 |
-
|
66 |
MAX_SEED = np.iinfo(np.int32).max
|
67 |
MAX_IMAGE_SIZE = 1024
|
68 |
|
69 |
-
|
70 |
-
# @spaces.GPU #[uncomment to use ZeroGPU]
|
71 |
def infer(
|
72 |
-
|
73 |
prompt,
|
74 |
negative_prompt,
|
75 |
seed,
|
@@ -78,15 +72,20 @@ def infer(
|
|
78 |
height,
|
79 |
guidance_scale,
|
80 |
num_inference_steps,
|
|
|
81 |
progress=gr.Progress(track_tqdm=True),
|
82 |
):
|
|
|
|
|
|
|
83 |
if randomize_seed:
|
84 |
seed = random.randint(0, MAX_SEED)
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
if model_id == "YaArtemNosenko/dino_stickers":
|
|
|
90 |
if hasattr(pipe.unet, "set_lora_scale"):
|
91 |
pipe.unet.set_lora_scale(lora_scale)
|
92 |
else:
|
@@ -104,7 +103,6 @@ def infer(
|
|
104 |
|
105 |
return image, seed
|
106 |
|
107 |
-
|
108 |
examples = [
|
109 |
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
110 |
"An astronaut riding a green horse",
|
@@ -224,4 +222,4 @@ with gr.Blocks(css=css) as demo:
|
|
224 |
)
|
225 |
|
226 |
if __name__ == "__main__":
|
227 |
-
demo.launch()
|
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
|
|
|
5 |
from diffusers import DiffusionPipeline
|
6 |
+
from peft import PeftModel, PeftConfig
|
7 |
import torch
|
8 |
|
9 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
|
11 |
# Model list including your LoRA model
|
12 |
MODEL_LIST = [
|
|
|
17 |
"YaArtemNosenko/dino_stickers",
|
18 |
]
|
19 |
|
|
|
|
|
20 |
if torch.cuda.is_available():
|
21 |
torch_dtype = torch.float16
|
22 |
else:
|
23 |
torch_dtype = torch.float32
|
24 |
|
25 |
+
# Cache to avoid re-initializing pipelines repeatedly
|
26 |
+
model_cache = {}
|
27 |
+
|
28 |
def load_pipeline(model_id: str):
|
29 |
"""
|
30 |
Loads or retrieves a cached DiffusionPipeline.
|
|
|
57 |
|
58 |
pipe.to(device)
|
59 |
model_cache[model_id] = pipe
|
|
|
60 |
return pipe
|
61 |
|
|
|
|
|
|
|
|
|
62 |
MAX_SEED = np.iinfo(np.int32).max
|
63 |
MAX_IMAGE_SIZE = 1024
|
64 |
|
|
|
|
|
65 |
def infer(
|
66 |
+
model_id,
|
67 |
prompt,
|
68 |
negative_prompt,
|
69 |
seed,
|
|
|
72 |
height,
|
73 |
guidance_scale,
|
74 |
num_inference_steps,
|
75 |
+
lora_scale, # New parameter for adjusting LoRA scale
|
76 |
progress=gr.Progress(track_tqdm=True),
|
77 |
):
|
78 |
+
# Load the pipeline for the chosen model
|
79 |
+
pipe = load_pipeline(model_id)
|
80 |
+
|
81 |
if randomize_seed:
|
82 |
seed = random.randint(0, MAX_SEED)
|
83 |
+
|
84 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
85 |
+
|
86 |
+
# If using the LoRA model, update the LoRA scale if supported.
|
87 |
if model_id == "YaArtemNosenko/dino_stickers":
|
88 |
+
# This assumes your pipeline's unet has a method to update the LoRA scale.
|
89 |
if hasattr(pipe.unet, "set_lora_scale"):
|
90 |
pipe.unet.set_lora_scale(lora_scale)
|
91 |
else:
|
|
|
103 |
|
104 |
return image, seed
|
105 |
|
|
|
106 |
examples = [
|
107 |
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
108 |
"An astronaut riding a green horse",
|
|
|
222 |
)
|
223 |
|
224 |
if __name__ == "__main__":
|
225 |
+
demo.launch()
|