yozozaya commited on
Commit
eff8fcf
·
1 Parent(s): 1bcfe87

update torch load

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -22,6 +22,7 @@ from PIL import Image
22
 
23
  MY_TOKEN = os.environ.get("MY_TOKEN")
24
 
 
25
 
26
  def image_to_base64(image: Image.Image) -> str:
27
  buffered = BytesIO()
@@ -77,7 +78,9 @@ for item in sdxl_loras_raw:
77
  saved_name = hf_hub_download(item["repo"], item["weights"])
78
 
79
  if not saved_name.endswith('.safetensors'):
80
- state_dict = torch.load(saved_name)
 
 
81
  else:
82
  state_dict = load_file(saved_name)
83
 
@@ -153,10 +156,10 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, progress=
153
  global last_lora, last_merged, last_fused, pipe
154
 
155
  print("✅ Running LoRAAAAA >>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>")
156
- print("prompt: ", prompt)
157
- print("negative: ", negative)
158
- print("lora_scale: ", lora_scale)
159
- print("selected_state: ", selected_state)
160
  print("selected_state index: ", selected_state.index)
161
 
162
 
 
22
 
23
  MY_TOKEN = os.environ.get("MY_TOKEN")
24
 
25
+ print(torch.cuda.is_available())
26
 
27
  def image_to_base64(image: Image.Image) -> str:
28
  buffered = BytesIO()
 
78
  saved_name = hf_hub_download(item["repo"], item["weights"])
79
 
80
  if not saved_name.endswith('.safetensors'):
81
+ # state_dict = torch.load(saved_name)
82
+ state_dict = torch.load(saved_name, map_location=torch.device('cpu'))
83
+
84
  else:
85
  state_dict = load_file(saved_name)
86
 
 
156
  global last_lora, last_merged, last_fused, pipe
157
 
158
  print("✅ Running LoRAAAAA >>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>")
159
+ # print("prompt: ", prompt)
160
+ # print("negative: ", negative)
161
+ # print("lora_scale: ", lora_scale)
162
+ # print("selected_state: ", selected_state)
163
  print("selected_state index: ", selected_state.index)
164
 
165