import numpy as np
import pandas as pd
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import os
import logging
import time as Time
from collections import Counter
from SASRecModules_ori import *

def extract_axis_1(data, indices):
    """Gather one timestep per sequence: returns data[i, indices[i], :] for each i,
    stacked into a tensor of shape (batch, 1, hidden)."""
    res = []
    for i in range(data.shape[0]):
        res.append(data[i, indices[i], :])
    res = torch.stack(res, dim=0).unsqueeze(1)
    return res

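# A minimal shape check (toy tensors, not part of the original code):
#
#   data = torch.randn(4, 10, 64)        # (batch, seq_len, hidden)
#   idx = torch.tensor([9, 3, 0, 7])     # last valid position per sequence
#   extract_axis_1(data, idx).shape      # -> torch.Size([4, 1, 64])
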
class GRU(nn.Module):
    def __init__(self, hidden_size, item_num, state_size, gru_layers=1):
        super(GRU, self).__init__()
        self.hidden_size = hidden_size
        self.item_num = item_num
        self.state_size = state_size
        # Index item_num is reserved for the padding item, hence item_num + 1 rows.
        self.item_embeddings = nn.Embedding(
            num_embeddings=item_num + 1,
            embedding_dim=self.hidden_size,
        )
        nn.init.normal_(self.item_embeddings.weight, 0, 0.01)
        self.gru = nn.GRU(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=gru_layers,
            batch_first=True
        )
        self.s_fc = nn.Linear(self.hidden_size, self.item_num)

    def forward(self, states, len_states):
        emb = self.item_embeddings(states)
        # len_states must be a 1-D CPU int64 tensor (or Python list) of true sequence lengths.
        emb_packed = torch.nn.utils.rnn.pack_padded_sequence(emb, len_states, batch_first=True, enforce_sorted=False)
        emb_packed, hidden = self.gru(emb_packed)
        # hidden: (num_layers, batch, hidden_size); this flattening assumes gru_layers == 1.
        hidden = hidden.view(-1, hidden.shape[2])
        supervised_output = self.s_fc(hidden)
        return supervised_output

    def forward_eval(self, states, len_states):
        # Identical to forward(); kept as a separate entry point for the evaluation loop.
        emb = self.item_embeddings(states)
        emb_packed = torch.nn.utils.rnn.pack_padded_sequence(emb, len_states, batch_first=True, enforce_sorted=False)
        emb_packed, hidden = self.gru(emb_packed)
        hidden = hidden.view(-1, hidden.shape[2])
        supervised_output = self.s_fc(hidden)
        return supervised_output

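# Hypothetical usage sketch for GRU (hyperparameters and tensors below are made up):
#
#   model = GRU(hidden_size=64, item_num=100, state_size=10)
#   states = torch.randint(0, 100, (8, 10))      # padded item-id sequences
#   lengths = torch.randint(1, 11, (8,))         # true lengths, CPU int64
#   logits = model(states, lengths)              # -> (8, 100) scores over items
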
class Caser(nn.Module):
    def __init__(self, hidden_size, item_num, state_size, num_filters, filter_sizes,
                 dropout_rate):
        super(Caser, self).__init__()
        self.hidden_size = hidden_size
        self.item_num = int(item_num)
        self.state_size = state_size
        # filter_sizes is passed as a string such as "[2,3,4]" and parsed here.
        self.filter_sizes = eval(filter_sizes)
        self.num_filters = num_filters
        self.dropout_rate = dropout_rate
        self.item_embeddings = nn.Embedding(
            num_embeddings=item_num + 1,
            embedding_dim=self.hidden_size,
        )
        nn.init.normal_(self.item_embeddings.weight, 0, 0.01)

        # Horizontal convolutions: one (height=i, width=hidden_size) kernel set per filter size.
        self.horizontal_cnn = nn.ModuleList(
            [nn.Conv2d(1, self.num_filters, (i, self.hidden_size)) for i in self.filter_sizes])
        for cnn in self.horizontal_cnn:
            nn.init.xavier_normal_(cnn.weight)
            nn.init.constant_(cnn.bias, 0.1)

        # Vertical convolution spans the full sequence length.
        self.vertical_cnn = nn.Conv2d(1, 1, (self.state_size, 1))
        nn.init.xavier_normal_(self.vertical_cnn.weight)
        nn.init.constant_(self.vertical_cnn.bias, 0.1)

        self.num_filters_total = self.num_filters * len(self.filter_sizes)
        final_dim = self.hidden_size + self.num_filters_total
        self.s_fc = nn.Linear(final_dim, item_num)

        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, states, len_states):
        input_emb = self.item_embeddings(states)
        # Zero out padding positions (padding id == item_num).
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
        input_emb *= mask
        input_emb = input_emb.unsqueeze(1)
        pooled_outputs = []
        for cnn in self.horizontal_cnn:
            h_out = nn.functional.relu(cnn(input_emb))
            # Drop only the width dimension; an explicit dim keeps batches of size 1 intact.
            h_out = h_out.squeeze(-1)
            p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
            pooled_outputs.append(p_out)

        h_pool = torch.cat(pooled_outputs, 1)
        h_pool_flat = h_pool.view(-1, self.num_filters_total)

        v_out = nn.functional.relu(self.vertical_cnn(input_emb))
        v_flat = v_out.view(-1, self.hidden_size)

        out = torch.cat([h_pool_flat, v_flat], 1)
        out = self.dropout(out)
        supervised_output = self.s_fc(out)
        return supervised_output

    def forward_eval(self, states, len_states):
        # Identical to forward(); kept as a separate entry point for the evaluation loop.
        input_emb = self.item_embeddings(states)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
        input_emb *= mask
        input_emb = input_emb.unsqueeze(1)
        pooled_outputs = []
        for cnn in self.horizontal_cnn:
            h_out = nn.functional.relu(cnn(input_emb))
            h_out = h_out.squeeze(-1)
            p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
            pooled_outputs.append(p_out)

        h_pool = torch.cat(pooled_outputs, 1)
        h_pool_flat = h_pool.view(-1, self.num_filters_total)

        v_out = nn.functional.relu(self.vertical_cnn(input_emb))
        v_flat = v_out.view(-1, self.hidden_size)

        out = torch.cat([h_pool_flat, v_flat], 1)
        out = self.dropout(out)
        supervised_output = self.s_fc(out)
        return supervised_output

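# Hypothetical usage sketch for Caser (all values below are made-up examples):
#
#   model = Caser(hidden_size=64, item_num=100, state_size=10,
#                 num_filters=16, filter_sizes="[2,3,4]", dropout_rate=0.1)
#   states = torch.randint(0, 100, (8, 10))   # padded item-id sequences
#   logits = model(states, None)              # len_states is unused here -> (8, 100)
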
class SASRec(nn.Module):
    def __init__(self, hidden_size, item_num, state_size, dropout, device, num_heads=1):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.item_num = int(item_num)
        self.dropout = nn.Dropout(dropout)
        self.device = device
        self.item_embeddings = nn.Embedding(
            num_embeddings=item_num + 1,
            embedding_dim=hidden_size,
        )
        nn.init.normal_(self.item_embeddings.weight, 0, 1)
        self.positional_embeddings = nn.Embedding(
            num_embeddings=state_size,
            embedding_dim=hidden_size
        )
        self.emb_dropout = nn.Dropout(dropout)
        self.ln_1 = nn.LayerNorm(hidden_size)
        self.ln_2 = nn.LayerNorm(hidden_size)
        self.ln_3 = nn.LayerNorm(hidden_size)
        # MultiHeadAttention and PositionwiseFeedForward are provided by SASRecModules_ori.
        self.mh_attn = MultiHeadAttention(hidden_size, hidden_size, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(hidden_size, hidden_size, dropout)
        self.s_fc = nn.Linear(hidden_size, item_num)

    def forward(self, states, len_states):
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        # Zero out padding positions (padding id == item_num).
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        # Take the hidden state at the last valid position of each sequence.
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        supervised_output = self.s_fc(state_hidden).squeeze()
        return supervised_output

    def forward_eval(self, states, len_states):
        # Identical to forward(); kept as a separate entry point for the evaluation loop.
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        supervised_output = self.s_fc(state_hidden).squeeze()
        return supervised_output

    def cacul_h(self, states, len_states):
        # Same encoder pass as forward(), but returns the sequence representation
        # (before the final linear layer) instead of item scores.
        device = self.device
        states = states.to(device)
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        return state_hidden

    def cacu_x(self, x):
        # Look up item embeddings for the given item ids.
        x = self.item_embeddings(x)
        return x

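
# Hypothetical smoke test (not part of the original module): builds SASRec with
# toy hyperparameters and random item-id sequences, then checks the output shape.
# All names and values below are illustrative assumptions.
if __name__ == "__main__":
    item_num, state_size, hidden_size = 100, 10, 64
    device = torch.device("cpu")
    model = SASRec(hidden_size, item_num, state_size, dropout=0.1, device=device, num_heads=1)
    states = torch.randint(0, item_num, (8, state_size))   # padded item-id sequences
    len_states = torch.randint(1, state_size + 1, (8,))    # true length of each sequence
    scores = model(states, len_states)
    print(scores.shape)                                    # expected: torch.Size([8, 100])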