quanglnt commited on
Commit
8c36119
·
1 Parent(s): 07a2ce5

Add application files

Browse files
Files changed (9) hide show
  1. app.py +25 -0
  2. device.py +15 -0
  3. predict.py +105 -0
  4. simple_cnn.py +28 -0
  5. simple_nn.py +17 -0
  6. test.py +130 -0
  7. train.py +109 -0
  8. view_image.py +63 -0
  9. view_model_information.py +26 -0
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from predict import predict_image
3
+ # Custom CSS to make the label text bigger
4
+ def predict(image_dict):
5
+ # Extract the "composite" key from the dictionary
6
+ composite_image = image_dict["composite"]
7
+ # composite_image.save("sketchpad_output.png") # Save as PNG
8
+ predicted = predict_image(composite_image)
9
+ # print(predicted)
10
+ return predicted #, composite_image # Directly return the PIL image
11
+ css = """
12
+ .big-label {
13
+ font-size: 24px; /* Adjust this value to make the label bigger */
14
+ font-weight: bold; /* Optional: to make it bold */
15
+ }
16
+ """
17
+ demo = gr.Interface(
18
+ fn=predict,
19
+ inputs=gr.Sketchpad(type="pil", brush=gr.Brush(default_size=20)), # Ensure it returns a PIL image
20
+ outputs=[gr.Label(num_top_classes=3, label="Predicted number is:", elem_classes=["big-label"])],
21
+ css=css
22
+ )
23
+
24
+ if __name__ == "__main__":
25
+ demo.launch()
device.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def get_device():
4
+ if torch.cuda.is_available:
5
+ # print('cuda is available')
6
+ return 'cuda'
7
+ elif torch.backends.mps.is_available:
8
+ # print('mps is available')
9
+ return 'mps'
10
+ else:
11
+ # print('using cpu')
12
+ return 'cpu'
13
+
14
+ device = get_device()
15
+ # print(device)
predict.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+ from device import get_device
7
+ from simple_nn import SimpleNN
8
+ from simple_cnn import SimpleCNN
9
+ from view_image import view_image, view_tensor_image
10
+
11
+ transform = transforms.Compose([
12
+ transforms.ToTensor(),
13
+ transforms.Normalize((0.5,), (0.5,))
14
+ ])
15
+ def predict_image(image):
16
+
17
+ model = SimpleCNN()
18
+ with open('mnist_simple_cnn.pht', 'rb') as f:
19
+ state_dict = torch.load(f, weights_only=True)
20
+ model.load_state_dict(state_dict)
21
+ model.eval()
22
+
23
+ image = image.convert('RGBA')
24
+ grayscale_image = Image.new("L", image.size, 255) # Create a white background
25
+ grayscale_image.paste(image.convert("L"), mask=image.split()[3]) # Use alpha channel as mask
26
+ grayscale_image = grayscale_image.resize((28, 28)) # Resize to 28x28 pixels
27
+
28
+ grayscale_image.save("processed_image.png")
29
+
30
+ image_np = np.array(grayscale_image)
31
+ image_np = 255 - image_np # Invert colors (MNIST has white digits on black)
32
+
33
+ # Normalize to range [0, 1]
34
+ image_np = image_np / 255.0
35
+
36
+ image_tensor = transform(image_np) # Add batch and channel dimensions
37
+ image_tensor = image_tensor.unsqueeze(0)
38
+ image_tensor = image_tensor.to(torch.float32)
39
+
40
+ # image_tensor = transform(grayscale_image).unsqueeze(0) # Add batch and channel dimensions
41
+
42
+ with torch.no_grad():
43
+ output = model(image_tensor)
44
+ #_, predicted = torch.max(output.data, 1)
45
+ probabilities = torch.softmax(output, dim=1)
46
+ # Convert probabilities to a list of (class, probability)
47
+ class_probabilities = {
48
+ str(class_index): prob.item() for class_index, prob in enumerate(probabilities[0])
49
+ }
50
+ print(class_probabilities)
51
+ # class_probabilities = {}
52
+ return class_probabilities
53
+
54
+ def predict(model_path, image_path):
55
+ model = SimpleCNN()
56
+ with open(model_path, 'rb') as f:
57
+ state_dict = torch.load(f, weights_only=True)
58
+ model.load_state_dict(state_dict)
59
+ model.eval()
60
+
61
+ # Load and preprocess the image
62
+ image = Image.open(image_path).convert("L") # Convert to grayscale
63
+ # view_image(image=image)
64
+ # Resize to 28x28
65
+ image = image.resize((28, 28))
66
+
67
+ # Convert to NumPy array and invert colors if needed
68
+ image_np = np.array(image)
69
+ image_np = 255 - image_np # Invert colors (MNIST has white digits on black)
70
+
71
+ # Normalize to range [0, 1]
72
+ image_np = image_np / 255.0
73
+
74
+ # Convert to tensor
75
+
76
+ image_tensor = transform(image_np) # Add batch and channel dimensions
77
+ image_tensor = image_tensor.unsqueeze(0)
78
+ image_tensor = image_tensor.to(torch.float32)
79
+ # Ensure the tensor is in the correct dtype
80
+
81
+ # view_tensor_image(image_tensor=image_tensor)
82
+
83
+ with torch.no_grad():
84
+ output = model(image_tensor)
85
+ #_, predicted = torch.max(output.data, 1)
86
+ probabilities = torch.softmax(output, dim=1)
87
+ # Convert probabilities to a list of (class, probability)
88
+ class_probabilities = {
89
+ str(class_index): prob.item() for class_index, prob in enumerate(probabilities[0])
90
+ }
91
+ # return predicted.item()
92
+ return class_probabilities
93
+ if __name__ == "__main__":
94
+ device = get_device()
95
+ model_path = "trained_model/mnist_simple_cnn.pht"
96
+
97
+ # Loop through all files in the test folder
98
+ test_folder = "test/"
99
+ for filename in os.listdir(test_folder):
100
+ if filename.endswith(".png"): # Only process .png files (you can add more extensions if needed)
101
+ image_path = os.path.join(test_folder, filename)
102
+ predicted = predict(model_path = "mnist_model.pht",image_path=image_path)
103
+ print(F"[INFO] The predicted results of the image {image_path} are: {predicted}")
104
+ print()
105
+
simple_cnn.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class SimpleCNN(nn.Module):
6
+ def __init__(self):
7
+ super(SimpleCNN, self).__init__()
8
+ # Convolutional layers
9
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) # 28x28 -> 28x28
10
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # 28x28 -> 28x28
11
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 28x28 -> 14x14
12
+
13
+ self.bn1 = nn.BatchNorm2d(32)
14
+ self.bn2 = nn.BatchNorm2d(64)
15
+
16
+ # Fully connected layers
17
+ self.fc1 = nn.Linear(64 * 14 * 14, 128)
18
+ self.dropout = nn.Dropout(0.5)
19
+ self.fc2 = nn.Linear(128, 10)
20
+
21
+ def forward(self, x):
22
+ x = F.relu(self.bn1(self.conv1(x))) # Apply first convolution and ReLU
23
+ x = self.pool(F.relu(self.bn2(self.conv2(x)))) # Apply second convolution, ReLU, and pooling
24
+ x = torch.flatten(x, 1) # Flatten the feature maps
25
+ x = F.relu(self.fc1(x)) # Fully connected layer with ReLU
26
+ x = self.dropout(x)
27
+ x = self.fc2(x) # Output layer
28
+ return x
simple_nn.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class SimpleNN(nn.Module):
5
+ def __init__(self):
6
+ super(SimpleNN, self).__init__()
7
+ self.fc1 = nn.Linear(28*28, 128)
8
+ self.fc2 = nn.Linear(128, 64)
9
+ self.fc3 = nn.Linear(64, 10)
10
+
11
+
12
+ def forward(self, x):
13
+ x = x.view(-1, 28*28)
14
+ x = torch.relu(self.fc1(x))
15
+ x = torch.relu(self.fc2(x))
16
+ x = self.fc3(x)
17
+ return x
test.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torchvision import datasets, transforms
7
+ from torch.utils.data import DataLoader
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ from simple_nn import SimpleNN
10
+ from simple_cnn import SimpleCNN
11
+ from device import get_device
12
+ from view_image import view_image,view_batch_images, save_batch_images
13
+ from tqdm import tqdm # Import tqdm for the progress bar
14
+
15
+ root_path = os.path.expanduser('data')
16
+
17
+ # Define transforms for training and testing
18
+ transforms = {
19
+ 'train': transforms.Compose([
20
+ transforms.RandomRotation(10),
21
+ transforms.RandomHorizontalFlip(),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize((0.5,), (0.5,)) # Normalize for MNIST
24
+ ]),
25
+ 'valid_test': transforms.Compose([
26
+ transforms.ToTensor(),
27
+ transforms.Normalize((0.5,), (0.5,))
28
+ ])
29
+ }
30
+
31
+ # Define dataset and dataloader
32
+ train_dataset = datasets.MNIST(root=root_path, download=True, train=True, transform=transforms['train'])
33
+ test_dataset = datasets.MNIST(root=root_path, download=True, train=False, transform=transforms['valid_test'])
34
+
35
+ train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
36
+ test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
37
+
38
+ model = SimpleCNN()
39
+ device = get_device()
40
+
41
+ model.to(device=device)
42
+ criterion = nn.CrossEntropyLoss()
43
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
44
+ epochs = 50
45
+ for epoch in range(epochs):
46
+ model.train() # Set the model to training mode
47
+ epoch_loss = 0 # Initialize epoch loss
48
+ correct = 0 # Track number of correct predictions
49
+ total = 0 # Track total predictions
50
+
51
+ with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
52
+ for batch_idx, (data, target) in enumerate(train_loader):
53
+ data, target = data.to(device), target.to(device) # Move data to device
54
+
55
+ # Forward pass
56
+ output = model(data)
57
+ loss = criterion(output, target)
58
+
59
+ # Backward pass and optimization
60
+ optimizer.zero_grad()
61
+ loss.backward()
62
+ optimizer.step()
63
+
64
+ epoch_loss += loss.item() # Add current batch loss to epoch loss
65
+
66
+ # Calculate accuracy for this batch
67
+ _, predicted = torch.max(output, 1)
68
+ total += target.size(0)
69
+ correct += (predicted == target).sum().item()
70
+
71
+ # Update the progress bar
72
+ pbar.set_postfix(loss=loss.item(), accuracy=100. * correct / total)
73
+ pbar.update(1)
74
+
75
+ # Calculate the average loss and accuracy for this epoch
76
+ avg_loss = epoch_loss / len(train_loader)
77
+ accuracy = 100. * correct / total
78
+
79
+
80
+ print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
81
+
82
+ # Now validation (on validation set)
83
+ correct_val = 0
84
+ total_val = 0
85
+ model.eval() # Set the model to evaluation mode
86
+ with torch.no_grad():
87
+ for data, target in test_loader: # Assuming `val_loader` is your validation data loader
88
+ data, target = data.to(device), target.to(device)
89
+ output = model(data)
90
+ _, predicted = torch.max(output.data, 1)
91
+ total_val += target.size(0)
92
+ correct_val += (predicted == target).sum().item()
93
+
94
+ accuracy_val = 100 * correct_val / total_val
95
+ print(f'Validation Accuracy: {accuracy_val:.2f}%')
96
+
97
+
98
+ # Save model after training
99
+ torch.save(model.state_dict(), "mnist_simple_cnn.pht")
100
+ print(f'Model saved')
101
+
102
+
103
+ model.eval()
104
+
105
+ test_folder = "test/"
106
+
107
+ # Loop through all files in the test folder
108
+ for filename in os.listdir(test_folder):
109
+ if filename.endswith(".png"): # Only process .png files (you can add more extensions if needed)
110
+ image_path = os.path.join(test_folder, filename)
111
+
112
+ # Load and preprocess the image
113
+ image = Image.open(image_path).convert("L") # Convert to grayscale
114
+ image = image.resize((28, 28)) # Resize to 28x28 pixels
115
+ image_tensor = transforms['valid_test'](image).unsqueeze(0) # Add batch and channel dimensions
116
+ image_tensor = image_tensor.to(device)
117
+
118
+ # Make prediction
119
+ with torch.no_grad():
120
+ output = model(image_tensor)
121
+ probabilities = torch.softmax(output, dim=1)
122
+
123
+ # Convert probabilities to a list of (class, probability)
124
+ class_probabilities = {
125
+ str(class_index): prob.item() for class_index, prob in enumerate(probabilities[0])
126
+ }
127
+
128
+ # Print or store the predictions for the current image
129
+ print(f"Predictions for {filename}: {class_probabilities}")
130
+ print() # Line break for separation
train.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torchvision import datasets, transforms
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ from simple_nn import SimpleNN
9
+ from simple_cnn import SimpleCNN
10
+ from device import get_device
11
+ from view_image import view_batch_images, save_batch_images
12
+ from tqdm import tqdm # Import tqdm for the progress bar
13
+
14
+
15
+ def train(model, device, train_loader, test_loader, criterion, optimizer, epochs = 5):
16
+
17
+ # Initialize TensorBoard writer
18
+ # writer = SummaryWriter('runs/mnist_experiment')
19
+
20
+ # Train the model
21
+ for epoch in range(epochs):
22
+ model.train()
23
+ epoch_loss = 0
24
+ with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
25
+ for batch_idx, (data, target) in enumerate(train_loader):
26
+ # Forward pass
27
+ data, target = data.to(device), target.to(device)
28
+ output = model(data)
29
+ loss = criterion(output, target)
30
+
31
+ # Backward pass and optimization
32
+ optimizer.zero_grad()
33
+ loss.backward()
34
+ optimizer.step()
35
+
36
+ epoch_loss += loss.item()
37
+
38
+ pbar.set_postfix(loss=loss.item())
39
+ pbar.update(1)
40
+
41
+ # Log the average loss for this epoch to TensorBoard
42
+ # avg_loss = epoch_loss / len(train_loader)
43
+ # writer.add_scalar('Loss/train', avg_loss, epoch)
44
+
45
+ print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')
46
+ validation(model,device,test_loader)
47
+ # After training, visualize with TensorBoard
48
+ # writer.close()
49
+
50
+ # Test the model
51
+ def validation(model, device, data_loader):
52
+ correct = 0
53
+ total = 0
54
+ model.eval() # Set the model to evaluation mode
55
+ with torch.no_grad():
56
+ for data, target in data_loader:
57
+ data, target = data.to(device), target.to(device)
58
+ output = model(data)
59
+ _, predicted = torch.max(output.data, 1)
60
+ total += target.size(0)
61
+ correct += (predicted == target).sum().item()
62
+
63
+ accuracy = 100 * correct / total
64
+ print(f'Accuracy on test set: {accuracy:.2f}%')
65
+
66
+ # Save the trained model
67
+ def save_model(model, model_save_path):
68
+ torch.save(model.state_dict(), model_save_path)
69
+ print(f'[INFO] Model saved to {model_save_path}')
70
+
71
+
72
+ if __name__ == "__main__":
73
+ root_path = os.path.expanduser('data')
74
+ # Load the dataset
75
+ transforms = {
76
+ 'train': transforms.Compose([
77
+ transforms.RandomRotation(10),
78
+ transforms.RandomHorizontalFlip(),
79
+ transforms.ToTensor(),
80
+ transforms.Normalize((0.5,), (0.5,)) # Normalize for MNIST
81
+ ]),
82
+ 'valid_test' : transforms.Compose([
83
+ transforms.ToTensor(),
84
+ transforms.Normalize((0.5,), (0.5,))
85
+ ])
86
+ }
87
+ train_dataset = datasets.MNIST(root=root_path, download=False, train=True, transform=transforms['train'])
88
+ test_dataset = datasets.MNIST(root=root_path, download=True, train=False, transform=transforms['valid_test'])
89
+ train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
90
+ test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
91
+ # view_batch_images(train_loader=train_loader)
92
+
93
+ # Get a batch of images
94
+ # data_iter = iter(train_loader)
95
+ # images, labels = next(data_iter)
96
+ # Save the images
97
+ # save_batch_images(images, save_dir="output_images", prefix="mnist_image", file_format="png")
98
+
99
+ # Initialize the model, loss function, and optimizer
100
+ device = get_device()
101
+ # model = SimpleNN()
102
+ model = SimpleCNN()
103
+ model.to(device=device)
104
+ criterion = nn.CrossEntropyLoss()
105
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
106
+
107
+ epochs = 20
108
+ train(model=model,device=device,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,epochs = epochs)
109
+ save_model(model=model,model_save_path="mnist_model.pht")
view_image.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import torchvision
5
+ from torchvision.transforms import ToPILImage
6
+
7
+ def view_image(image):
8
+ plt.imshow(image, cmap="gray")
9
+ plt.title("Grayscale Image")
10
+ plt.axis("off") # Hide axes for better visualization
11
+ plt.show()
12
+
13
+ def view_tensor_image(image_tensor, title="Image"):
14
+ image_np = image_tensor.squeeze().numpy()
15
+ plt.imshow(image_np)
16
+ plt.title(title)
17
+ plt.axis('off')
18
+ plt.show()
19
+
20
+ def view_batch_images(train_loader, num_images=8):
21
+ """
22
+ Display a batch of images from the train_loader.
23
+
24
+ Parameters:
25
+ train_loader (DataLoader): The DataLoader containing the images.
26
+ num_images (int): Number of images to display from the batch.
27
+ """
28
+ data_iter = iter(train_loader)
29
+ images, labels = next(data_iter) # Get a batch of images and labels
30
+
31
+ # Make a grid of images
32
+ img_grid = torchvision.utils.make_grid(images[:num_images], nrow=num_images, normalize=True)
33
+ img_np = img_grid.numpy().transpose((1, 2, 0)) # Rearrange dimensions for plotting
34
+
35
+ plt.figure(figsize=(12, 6))
36
+ plt.imshow(img_np, cmap="gray")
37
+ plt.title("Batch of Images")
38
+ plt.axis("off")
39
+ plt.show()
40
+
41
+
42
+ def save_batch_images(images, save_dir, prefix="image", file_format="png", unnormalize=None):
43
+ """
44
+ Save each image in a batch to a specified directory.
45
+
46
+ Parameters:
47
+ images (torch.Tensor): Batch of images with shape (B, C, H, W).
48
+ save_dir (str): Directory to save the images.
49
+ prefix (str): Prefix for the saved image filenames.
50
+ file_format (str): File format for the saved images (e.g., "png", "jpg").
51
+ unnormalize (callable, optional): Function to unnormalize the images before saving.
52
+ """
53
+ os.makedirs(save_dir, exist_ok=True) # Create the directory if it doesn't exist
54
+ to_pil = ToPILImage() # Converts tensors to PIL images
55
+
56
+ for idx, image in enumerate(images):
57
+ if unnormalize:
58
+ image = unnormalize(image) # Apply unnormalization if provided
59
+
60
+ pil_image = to_pil(image) # Convert to PIL Image
61
+ filename = os.path.join(save_dir, f"{prefix}_{idx}.{file_format}")
62
+ pil_image.save(filename)
63
+ print(f"Saved: {filename}")
view_model_information.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+ from device import get_device
7
+ from simple_nn import SimpleNN
8
+ from simple_cnn import SimpleCNN
9
+ from view_image import view_image, view_tensor_image
10
+
11
+ model_path = "mnist_model.pht"
12
+ # Load model
13
+ model = SimpleCNN()
14
+ with open(model_path, 'rb') as f:
15
+ state_dict = torch.load(f, weights_only=True)
16
+ model.load_state_dict(state_dict)
17
+
18
+ # View model information
19
+ print(model) # Display the model architecture
20
+
21
+ # For more detailed information about the model's parameters:
22
+ print(f"Model summary: {model}")
23
+
24
+ # You can also view the parameters' details (e.g., number of parameters, layers, etc.)
25
+ for name, param in model.named_parameters():
26
+ print(f"Parameter: {name}, Shape: {param.shape}")