uruguayai commited on
Commit
cfafe9a
·
verified ·
1 Parent(s): ceeeb32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -27,6 +27,17 @@ os.makedirs(model_cache_dir, exist_ok=True)
27
  print(f"Cache directory: {cache_dir}")
28
  print(f"Model cache directory: {model_cache_dir}")
29
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Function to load or download the model
31
  def get_model(model_id, revision):
32
  model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
@@ -227,6 +238,6 @@ for epoch in range(num_epochs):
227
  # Save the fine-tuned model
228
  output_dir = "/tmp/montevideo_fine_tuned_model"
229
  os.makedirs(output_dir, exist_ok=True)
230
- adjusted_unet.save_pretrained(output_dir, params=state.params["params"])
231
 
232
  print(f"Model saved to {output_dir}")
 
27
  print(f"Cache directory: {cache_dir}")
28
  print(f"Model cache directory: {model_cache_dir}")
29
 
30
+
31
+
32
+ def filter_dict(dict_to_filter, target_callable):
33
+ """Filter a dictionary to only include keys that are valid parameters for the target callable."""
34
+ valid_params = signature(target_callable).parameters.keys()
35
+ return {k: v for k, v in dict_to_filter.items() if k in valid_params}
36
+
37
+
38
+
39
+
40
+
41
  # Function to load or download the model
42
  def get_model(model_id, revision):
43
  model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
 
238
  # Save the fine-tuned model
239
  output_dir = "/tmp/montevideo_fine_tuned_model"
240
  os.makedirs(output_dir, exist_ok=True)
241
+ adjusted_unet.save_pretrained(output_dir, params=state.params)
242
 
243
  print(f"Model saved to {output_dir}")