import numpy as np
import pandas as pd
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import os
import logging
import time as Time
from collections import Counter
from SASRecModules_ori import *


def extract_axis_1(data, indices):
    # Gather, for each sequence in the batch, the hidden state at the given
    # time index: data is (batch, seq_len, hidden), indices is (batch,).
    res = []
    for i in range(data.shape[0]):
        res.append(data[i, indices[i], :])
    res = torch.stack(res, dim=0).unsqueeze(1)
    return res


class GRU(nn.Module):
    def __init__(self, hidden_size, item_num, state_size, gru_layers=1):
        super(GRU, self).__init__()
        self.hidden_size = hidden_size
        self.item_num = item_num
        self.state_size = state_size
        self.item_embeddings = nn.Embedding(
            num_embeddings=item_num + 1,
            embedding_dim=self.hidden_size,
        )
        nn.init.normal_(self.item_embeddings.weight, 0, 0.01)
        self.gru = nn.GRU(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=gru_layers,
            batch_first=True
        )
        self.s_fc = nn.Linear(self.hidden_size, self.item_num)

    def forward(self, states, len_states):
        # Supervised Head
        emb = self.item_embeddings(states)
        emb_packed = torch.nn.utils.rnn.pack_padded_sequence(
            emb, len_states, batch_first=True, enforce_sorted=False)
        emb_packed, hidden = self.gru(emb_packed)
        # hidden: (num_layers, batch, hidden_size) -> (num_layers * batch, hidden_size)
        hidden = hidden.view(-1, hidden.shape[2])
        supervised_output = self.s_fc(hidden)
        return supervised_output

    def forward_eval(self, states, len_states):
        # Supervised Head
        emb = self.item_embeddings(states)
        emb_packed = torch.nn.utils.rnn.pack_padded_sequence(
            emb, len_states, batch_first=True, enforce_sorted=False)
        emb_packed, hidden = self.gru(emb_packed)
        hidden = hidden.view(-1, hidden.shape[2])
        supervised_output = self.s_fc(hidden)
        return supervised_output


class Caser(nn.Module):
    def __init__(self, hidden_size, item_num, state_size, num_filters, filter_sizes, dropout_rate):
        super(Caser, self).__init__()
        self.hidden_size = hidden_size
        self.item_num = int(item_num)
        self.state_size = state_size
        self.filter_sizes = eval(filter_sizes)
        self.num_filters = num_filters
        self.dropout_rate = dropout_rate
        self.item_embeddings = nn.Embedding(
            num_embeddings=item_num + 1,
            embedding_dim=self.hidden_size,
        )

        # init embedding
        nn.init.normal_(self.item_embeddings.weight, 0, 0.01)

        # Horizontal Convolutional Layers
        self.horizontal_cnn = nn.ModuleList(
            [nn.Conv2d(1, self.num_filters, (i, self.hidden_size)) for i in self.filter_sizes])
        # Initialize weights and biases
        for cnn in self.horizontal_cnn:
            nn.init.xavier_normal_(cnn.weight)
            nn.init.constant_(cnn.bias, 0.1)

        # Vertical Convolutional Layer
        self.vertical_cnn = nn.Conv2d(1, 1, (self.state_size, 1))
        nn.init.xavier_normal_(self.vertical_cnn.weight)
        nn.init.constant_(self.vertical_cnn.bias, 0.1)

        # Fully Connected Layer
        self.num_filters_total = self.num_filters * len(self.filter_sizes)
        final_dim = self.hidden_size + self.num_filters_total
        self.s_fc = nn.Linear(final_dim, item_num)

        # dropout
        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, states, len_states):
        input_emb = self.item_embeddings(states)
        # item_num is the padding id; zero out the embeddings of padded positions
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
        input_emb *= mask
        input_emb = input_emb.unsqueeze(1)
        pooled_outputs = []
        for cnn in self.horizontal_cnn:
            h_out = nn.functional.relu(cnn(input_emb))
            h_out = h_out.squeeze()
            p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
            pooled_outputs.append(p_out)

        h_pool = torch.cat(pooled_outputs, 1)
        h_pool_flat = h_pool.view(-1, self.num_filters_total)

        v_out = nn.functional.relu(self.vertical_cnn(input_emb))
        v_flat = v_out.view(-1, self.hidden_size)

        out = torch.cat([h_pool_flat, v_flat], 1)
        out = self.dropout(out)
        supervised_output = self.s_fc(out)
        return supervised_output

    def forward_eval(self, states, len_states):
        input_emb = self.item_embeddings(states)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
        input_emb *= mask
        input_emb = input_emb.unsqueeze(1)
        pooled_outputs = []
        for cnn in self.horizontal_cnn:
            h_out = nn.functional.relu(cnn(input_emb))
            h_out = h_out.squeeze()
            p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
            pooled_outputs.append(p_out)

        h_pool = torch.cat(pooled_outputs, 1)
        h_pool_flat = h_pool.view(-1, self.num_filters_total)

        v_out = nn.functional.relu(self.vertical_cnn(input_emb))
        v_flat = v_out.view(-1, self.hidden_size)

        out = torch.cat([h_pool_flat, v_flat], 1)
        out = self.dropout(out)
        supervised_output = self.s_fc(out)
        return supervised_output


class SASRec(nn.Module):
    def __init__(self, hidden_size, item_num, state_size, dropout, device, num_heads=1):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.item_num = int(item_num)
        self.dropout = nn.Dropout(dropout)
        self.device = device
        self.item_embeddings = nn.Embedding(
            num_embeddings=item_num + 1,
            embedding_dim=hidden_size,
        )
        nn.init.normal_(self.item_embeddings.weight, 0, 1)
        self.positional_embeddings = nn.Embedding(
            num_embeddings=state_size,
            embedding_dim=hidden_size
        )
        self.emb_dropout = nn.Dropout(dropout)
        self.ln_1 = nn.LayerNorm(hidden_size)
        self.ln_2 = nn.LayerNorm(hidden_size)
        self.ln_3 = nn.LayerNorm(hidden_size)
        self.mh_attn = MultiHeadAttention(hidden_size, hidden_size, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(hidden_size, hidden_size, dropout)
        self.s_fc = nn.Linear(hidden_size, item_num)

    def forward(self, states, len_states):
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        # Take the hidden state at the last real position of each sequence
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        supervised_output = self.s_fc(state_hidden).squeeze()
        return supervised_output

    def forward_eval(self, states, len_states):
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        supervised_output = self.s_fc(state_hidden).squeeze()
        return supervised_output

    def cacul_h(self, states, len_states):
        # Same encoder pass as forward(), but returns the sequence representation
        # instead of the item logits.
        device = self.device
        states = states.to(device)
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        # print("state_hidden.size", state_hidden.size())
        return state_hidden

    def cacu_x(self, x):
        x = self.item_embeddings(x)
        return x
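

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (illustrative only). The hyper-parameters below
# are hypothetical and are NOT taken from the original training scripts; they
# only demonstrate the calling convention shared by the models above:
# `states` is a (batch, state_size) LongTensor of item ids padded with
# `item_num`, and `len_states` holds the true sequence lengths. SASRec is
# omitted here because its attention blocks come from SASRecModules_ori and
# it additionally needs a `device` argument.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    item_num, state_size, hidden_size, batch_size = 100, 10, 64, 4

    # Build a padded batch: positions beyond each sequence length hold the
    # padding id `item_num` (the embedding tables have item_num + 1 rows).
    states = torch.randint(0, item_num, (batch_size, state_size))
    lengths = [10, 8, 5, 3]
    for i, seq_len in enumerate(lengths):
        states[i, seq_len:] = item_num
    len_states = torch.tensor(lengths)

    with torch.no_grad():
        gru = GRU(hidden_size, item_num, state_size)
        print("GRU logits:", gru(states, len_states).shape)      # (4, 100)

        caser = Caser(hidden_size, item_num, state_size,
                      num_filters=16, filter_sizes="[2,3,4]", dropout_rate=0.1)
        print("Caser logits:", caser(states, len_states).shape)  # (4, 100)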