Update README.md
Browse files
README.md
CHANGED
@@ -20,7 +20,7 @@ library_name: diffusers
|
|
20 |
## Usage
|
21 |
```python
|
22 |
|
23 |
-
#
|
24 |
def load_model(model_path, device):
|
25 |
# Initialize the same model architecture as during training
|
26 |
model = ClassConditionedUnet().to(device)
|
@@ -33,7 +33,7 @@ def load_model(model_path, device):
|
|
33 |
|
34 |
return model
|
35 |
|
36 |
-
|
37 |
def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
38 |
model.eval() # Ensure the model is in evaluation mode
|
39 |
|
@@ -60,7 +60,6 @@ def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
|
60 |
|
61 |
return generated_images
|
62 |
|
63 |
-
# Display predicted images
|
64 |
def display_images(images, num_rows=2):
|
65 |
# Create a grid of images
|
66 |
grid = torchvision.utils.make_grid(images, nrow=num_rows)
|
@@ -75,17 +74,9 @@ def display_images(images, num_rows=2):
|
|
75 |
# Example of loading a model and generating predictions
|
76 |
model_path = "model_epoch_0.pth" # Path to your saved model
|
77 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
78 |
-
|
79 |
-
# Load the model
|
80 |
model = load_model(model_path, device)
|
81 |
-
|
82 |
-
# Create a noise scheduler
|
83 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
84 |
-
|
85 |
-
# Predict and generate samples for a specific class label
|
86 |
class_label = 1 # Example class label, change to your desired class
|
87 |
generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
|
88 |
-
|
89 |
-
# Display the generated images
|
90 |
display_images(generated_images)
|
91 |
```
|
|
|
20 |
## Usage
|
21 |
```python
|
22 |
|
23 |
+
# Predict function to generate images
|
24 |
def load_model(model_path, device):
|
25 |
# Initialize the same model architecture as during training
|
26 |
model = ClassConditionedUnet().to(device)
|
|
|
33 |
|
34 |
return model
|
35 |
|
36 |
+
|
37 |
def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
38 |
model.eval() # Ensure the model is in evaluation mode
|
39 |
|
|
|
60 |
|
61 |
return generated_images
|
62 |
|
|
|
63 |
def display_images(images, num_rows=2):
|
64 |
# Create a grid of images
|
65 |
grid = torchvision.utils.make_grid(images, nrow=num_rows)
|
|
|
74 |
# Example of loading a model and generating predictions
|
75 |
model_path = "model_epoch_0.pth" # Path to your saved model
|
76 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
77 |
model = load_model(model_path, device)
|
|
|
|
|
78 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
|
|
|
|
79 |
class_label = 1 # Example class label, change to your desired class
|
80 |
generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
|
|
|
|
|
81 |
display_images(generated_images)
|
82 |
```
|