update torch load
Browse files
app.py
CHANGED
@@ -22,6 +22,7 @@ from PIL import Image
|
|
22 |
|
23 |
MY_TOKEN = os.environ.get("MY_TOKEN")
|
24 |
|
|
|
25 |
|
26 |
def image_to_base64(image: Image.Image) -> str:
|
27 |
buffered = BytesIO()
|
@@ -77,7 +78,9 @@ for item in sdxl_loras_raw:
|
|
77 |
saved_name = hf_hub_download(item["repo"], item["weights"])
|
78 |
|
79 |
if not saved_name.endswith('.safetensors'):
|
80 |
-
state_dict = torch.load(saved_name)
|
|
|
|
|
81 |
else:
|
82 |
state_dict = load_file(saved_name)
|
83 |
|
@@ -153,10 +156,10 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, progress=
|
|
153 |
global last_lora, last_merged, last_fused, pipe
|
154 |
|
155 |
print("✅ Running LoRAAAAA >>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>")
|
156 |
-
print("prompt: ", prompt)
|
157 |
-
print("negative: ", negative)
|
158 |
-
print("lora_scale: ", lora_scale)
|
159 |
-
print("selected_state: ", selected_state)
|
160 |
print("selected_state index: ", selected_state.index)
|
161 |
|
162 |
|
|
|
22 |
|
23 |
MY_TOKEN = os.environ.get("MY_TOKEN")
|
24 |
|
25 |
+
print(torch.cuda.is_available())
|
26 |
|
27 |
def image_to_base64(image: Image.Image) -> str:
|
28 |
buffered = BytesIO()
|
|
|
78 |
saved_name = hf_hub_download(item["repo"], item["weights"])
|
79 |
|
80 |
if not saved_name.endswith('.safetensors'):
|
81 |
+
# state_dict = torch.load(saved_name)
|
82 |
+
state_dict = torch.load(saved_name, map_location=torch.device('cpu'))
|
83 |
+
|
84 |
else:
|
85 |
state_dict = load_file(saved_name)
|
86 |
|
|
|
156 |
global last_lora, last_merged, last_fused, pipe
|
157 |
|
158 |
print("✅ Running LoRAAAAA >>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>")
|
159 |
+
# print("prompt: ", prompt)
|
160 |
+
# print("negative: ", negative)
|
161 |
+
# print("lora_scale: ", lora_scale)
|
162 |
+
# print("selected_state: ", selected_state)
|
163 |
print("selected_state index: ", selected_state.index)
|
164 |
|
165 |
|