minBERT / zemo2.py
GlowCheese's picture
First model version
9756d99
raw
history blame
1.61 kB
import torch
import torch.nn as nn
# Xây dựng mô hình RNN
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNNModel, self).__init__()
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True) # Định nghĩa RNN
self.fc = nn.Linear(hidden_size, output_size) # Lớp fully connected để dự đoán output
def forward(self, x):
out, _ = self.rnn(x) # Lấy output từ RNN
out = out[:, -1, :] # Lấy output của bước cuối cùng (nếu dữ liệu có nhiều bước thời gian)
out = self.fc(out) # Dự đoán output
return out
# Khởi tạo mô hình
input_size = 10 # Kích thước đầu vào
hidden_size = 20 # Số lượng hidden units
output_size = 1 # Đầu ra (ví dụ: hồi quy)
model = RNNModel(input_size, hidden_size, output_size)
# Khởi tạo dữ liệu giả
X = torch.randn(32, 5, 10) # 32 samples, 5 bước thời gian, mỗi bước có 10 đặc trưng
y = torch.randn(32, 1) # 32 samples, 1 giá trị đầu ra cho mỗi sample
# Hàm mất mát và bộ tối ưu
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Huấn luyện mô hình
for epoch in range(100):
model.train()
optimizer.zero_grad()
output = model(X) # Truyền dữ liệu qua mô hình
loss = criterion(output, y) # Tính mất mát
loss.backward() # Tính gradient
optimizer.step() # Cập nhật trọng số
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')