muneebable commited on
Commit
0c33861
·
verified ·
1 Parent(s): 6379200

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -11
README.md CHANGED
@@ -20,7 +20,7 @@ library_name: diffusers
20
  ## Usage
21
  ```python
22
 
23
- # Load the model
24
  def load_model(model_path, device):
25
  # Initialize the same model architecture as during training
26
  model = ClassConditionedUnet().to(device)
@@ -33,7 +33,7 @@ def load_model(model_path, device):
33
 
34
  return model
35
 
36
- # Predict function to generate images
37
  def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
38
  model.eval() # Ensure the model is in evaluation mode
39
 
@@ -60,7 +60,6 @@ def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
60
 
61
  return generated_images
62
 
63
- # Display predicted images
64
  def display_images(images, num_rows=2):
65
  # Create a grid of images
66
  grid = torchvision.utils.make_grid(images, nrow=num_rows)
@@ -75,17 +74,9 @@ def display_images(images, num_rows=2):
75
  # Example of loading a model and generating predictions
76
  model_path = "model_epoch_0.pth" # Path to your saved model
77
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
78
-
79
- # Load the model
80
  model = load_model(model_path, device)
81
-
82
- # Create a noise scheduler
83
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
84
-
85
- # Predict and generate samples for a specific class label
86
  class_label = 1 # Example class label, change to your desired class
87
  generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
88
-
89
- # Display the generated images
90
  display_images(generated_images)
91
  ```
 
20
  ## Usage
21
  ```python
22
 
23
+ # Predict function to generate images
24
  def load_model(model_path, device):
25
  # Initialize the same model architecture as during training
26
  model = ClassConditionedUnet().to(device)
 
33
 
34
  return model
35
 
36
+
37
  def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
38
  model.eval() # Ensure the model is in evaluation mode
39
 
 
60
 
61
  return generated_images
62
 
 
63
  def display_images(images, num_rows=2):
64
  # Create a grid of images
65
  grid = torchvision.utils.make_grid(images, nrow=num_rows)
 
74
  # Example of loading a model and generating predictions
75
  model_path = "model_epoch_0.pth" # Path to your saved model
76
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
77
  model = load_model(model_path, device)
 
 
78
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
 
 
79
  class_label = 1 # Example class label, change to your desired class
80
  generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
 
 
81
  display_images(generated_images)
82
  ```