Update app.py
Browse files
app.py
CHANGED
@@ -27,6 +27,17 @@ os.makedirs(model_cache_dir, exist_ok=True)
|
|
27 |
print(f"Cache directory: {cache_dir}")
|
28 |
print(f"Model cache directory: {model_cache_dir}")
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# Function to load or download the model
|
31 |
def get_model(model_id, revision):
|
32 |
model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
|
@@ -227,6 +238,6 @@ for epoch in range(num_epochs):
|
|
227 |
# Save the fine-tuned model
|
228 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
229 |
os.makedirs(output_dir, exist_ok=True)
|
230 |
-
adjusted_unet.save_pretrained(output_dir, params=state.params
|
231 |
|
232 |
print(f"Model saved to {output_dir}")
|
|
|
27 |
print(f"Cache directory: {cache_dir}")
|
28 |
print(f"Model cache directory: {model_cache_dir}")
|
29 |
|
30 |
+
|
31 |
+
|
32 |
+
def filter_dict(dict_to_filter, target_callable):
|
33 |
+
"""Filter a dictionary to only include keys that are valid parameters for the target callable."""
|
34 |
+
valid_params = signature(target_callable).parameters.keys()
|
35 |
+
return {k: v for k, v in dict_to_filter.items() if k in valid_params}
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
# Function to load or download the model
|
42 |
def get_model(model_id, revision):
|
43 |
model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
|
|
|
238 |
# Save the fine-tuned model
|
239 |
output_dir = "/tmp/montevideo_fine_tuned_model"
|
240 |
os.makedirs(output_dir, exist_ok=True)
|
241 |
+
adjusted_unet.save_pretrained(output_dir, params=state.params)
|
242 |
|
243 |
print(f"Model saved to {output_dir}")
|