muneebable
commited on
Update README.md
Browse files
README.md
CHANGED
@@ -21,6 +21,7 @@ library_name: diffusers
|
|
21 |
```python
|
22 |
|
23 |
# Predict function to generate images
|
|
|
24 |
def load_model(model_path, device):
|
25 |
# Initialize the same model architecture as during training
|
26 |
model = ClassConditionedUnet().to(device)
|
@@ -60,6 +61,7 @@ def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
|
60 |
|
61 |
return generated_images
|
62 |
|
|
|
63 |
def display_images(images, num_rows=2):
|
64 |
# Create a grid of images
|
65 |
grid = torchvision.utils.make_grid(images, nrow=num_rows)
|
@@ -72,6 +74,7 @@ def display_images(images, num_rows=2):
|
|
72 |
plt.show()
|
73 |
|
74 |
# Example of loading a model and generating predictions
|
|
|
75 |
model_path = "model_epoch_0.pth" # Path to your saved model
|
76 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
77 |
model = load_model(model_path, device)
|
|
|
21 |
```python
|
22 |
|
23 |
# Predict function to generate images
|
24 |
+
|
25 |
def load_model(model_path, device):
|
26 |
# Initialize the same model architecture as during training
|
27 |
model = ClassConditionedUnet().to(device)
|
|
|
61 |
|
62 |
return generated_images
|
63 |
|
64 |
+
|
65 |
def display_images(images, num_rows=2):
|
66 |
# Create a grid of images
|
67 |
grid = torchvision.utils.make_grid(images, nrow=num_rows)
|
|
|
74 |
plt.show()
|
75 |
|
76 |
# Example of loading a model and generating predictions
|
77 |
+
|
78 |
model_path = "model_epoch_0.pth" # Path to your saved model
|
79 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
80 |
model = load_model(model_path, device)
|