# iLoRA/recommender/A_SASRec_final_bce_llm.py
# Fork and bug fix from https://github.com/AkaliKong/iLoRA
# (author: MingLi, commit 9f13819)
import ast
import numpy as np
import pandas as pd
import argparse
import torch
from torch import nn
import torch.nn.functional as F
import os
import logging
import time as Time
from collections import Counter
from SASRecModules_ori import *
def extract_axis_1(data, indices):
    # Gather one timestep per sequence: for data of shape (B, T, H) and
    # indices of shape (B,), stack data[i, indices[i]] into shape (B, 1, H).
    res = []
    for i in range(data.shape[0]):
        res.append(data[i, indices[i], :])
    res = torch.stack(res, dim=0).unsqueeze(1)
    return res
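# A minimal usage sketch of extract_axis_1 (shapes are illustrative
# assumptions, not values from the original code):
#   data = torch.randn(2, 10, 64)        # (B, T, H)
#   idx = torch.tensor([9, 6])           # last valid position per sequence
#   last = extract_axis_1(data, idx)     # -> shape (2, 1, 64)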
class GRU(nn.Module):
def __init__(self, hidden_size, item_num, state_size, gru_layers=1):
super(GRU, self).__init__()
self.hidden_size = hidden_size
self.item_num = item_num
self.state_size = state_size
self.item_embeddings = nn.Embedding(
num_embeddings=item_num + 1,
embedding_dim=self.hidden_size,
)
nn.init.normal_(self.item_embeddings.weight, 0, 0.01)
self.gru = nn.GRU(
input_size=self.hidden_size,
hidden_size=self.hidden_size,
num_layers=gru_layers,
batch_first=True
)
self.s_fc = nn.Linear(self.hidden_size, self.item_num)
    def forward(self, states, len_states):
        # Supervised head: embed the padded sequences, run the packed GRU,
        # and score every item from the final hidden state.
        emb = self.item_embeddings(states)
        # pack_padded_sequence requires the lengths on CPU.
        lengths = len_states.cpu() if torch.is_tensor(len_states) else len_states
        emb_packed = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(emb_packed)
        # hidden: (num_layers, B, H); keep only the last layer so the output
        # stays (B, item_num) even when gru_layers > 1.
        hidden = hidden[-1]
        supervised_output = self.s_fc(hidden)
        return supervised_output
    def forward_eval(self, states, len_states):
        # Evaluation uses the same computation as training; delegate to
        # forward() instead of duplicating its body.
        return self.forward(states, len_states)
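# A minimal GRU usage sketch (hyperparameters are illustrative assumptions):
#   model = GRU(hidden_size=64, item_num=50, state_size=10)
#   states = torch.randint(0, 50, (2, 10))   # padded item-id sequences
#   len_states = torch.tensor([10, 7])       # true length of each sequence
#   logits = model(states, len_states)       # -> shape (2, 50)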
class Caser(nn.Module):
def __init__(self, hidden_size, item_num, state_size, num_filters, filter_sizes,
dropout_rate):
super(Caser, self).__init__()
self.hidden_size = hidden_size
self.item_num = int(item_num)
self.state_size = state_size
        # filter_sizes arrives as a string such as '[2,3,4]'; literal_eval
        # parses it without eval's arbitrary-code-execution risk.
        self.filter_sizes = ast.literal_eval(filter_sizes)
self.num_filters = num_filters
self.dropout_rate = dropout_rate
self.item_embeddings = nn.Embedding(
num_embeddings=item_num + 1,
embedding_dim=self.hidden_size,
)
# init embedding
nn.init.normal_(self.item_embeddings.weight, 0, 0.01)
# Horizontal Convolutional Layers
self.horizontal_cnn = nn.ModuleList(
[nn.Conv2d(1, self.num_filters, (i, self.hidden_size)) for i in self.filter_sizes])
# Initialize weights and biases
for cnn in self.horizontal_cnn:
nn.init.xavier_normal_(cnn.weight)
nn.init.constant_(cnn.bias, 0.1)
# Vertical Convolutional Layer
self.vertical_cnn = nn.Conv2d(1, 1, (self.state_size, 1))
nn.init.xavier_normal_(self.vertical_cnn.weight)
nn.init.constant_(self.vertical_cnn.bias, 0.1)
# Fully Connected Layer
self.num_filters_total = self.num_filters * len(self.filter_sizes)
final_dim = self.hidden_size + self.num_filters_total
self.s_fc = nn.Linear(final_dim, item_num)
# dropout
self.dropout = nn.Dropout(self.dropout_rate)
    def forward(self, states, len_states):
        input_emb = self.item_embeddings(states)
        # Zero out padding positions (the padding id is item_num).
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1)
        input_emb *= mask
        input_emb = input_emb.unsqueeze(1)  # (B, 1, T, H) for Conv2d
        pooled_outputs = []
        for cnn in self.horizontal_cnn:
            h_out = nn.functional.relu(cnn(input_emb))
            # Squeeze only the trailing width dim; a bare squeeze() would
            # also drop the batch dim when B == 1.
            h_out = h_out.squeeze(-1)
            p_out = nn.functional.max_pool1d(h_out, h_out.shape[2])
            pooled_outputs.append(p_out)
        h_pool = torch.cat(pooled_outputs, 1)
        h_pool_flat = h_pool.view(-1, self.num_filters_total)
        v_out = nn.functional.relu(self.vertical_cnn(input_emb))
        v_flat = v_out.view(-1, self.hidden_size)
        out = torch.cat([h_pool_flat, v_flat], 1)
        out = self.dropout(out)
        supervised_output = self.s_fc(out)
        return supervised_output
    def forward_eval(self, states, len_states):
        # Evaluation uses the same computation as training; delegate to
        # forward() instead of duplicating its body.
        return self.forward(states, len_states)
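# A minimal Caser usage sketch (hyperparameters are illustrative assumptions;
# filter_sizes is passed as a string, matching the constructor above):
#   model = Caser(hidden_size=64, item_num=50, state_size=10,
#                 num_filters=16, filter_sizes='[2,3,4]', dropout_rate=0.1)
#   logits = model(states, len_states)       # -> shape (2, 50)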
class SASRec(nn.Module):
def __init__(self, hidden_size, item_num, state_size, dropout, device, num_heads=1):
super().__init__()
self.state_size = state_size
self.hidden_size = hidden_size
self.item_num = int(item_num)
self.dropout = nn.Dropout(dropout)
self.device = device
self.item_embeddings = nn.Embedding(
num_embeddings=item_num + 1,
embedding_dim=hidden_size,
)
nn.init.normal_(self.item_embeddings.weight, 0, 1)
self.positional_embeddings = nn.Embedding(
num_embeddings=state_size,
embedding_dim=hidden_size
)
self.emb_dropout = nn.Dropout(dropout)
self.ln_1 = nn.LayerNorm(hidden_size)
self.ln_2 = nn.LayerNorm(hidden_size)
self.ln_3 = nn.LayerNorm(hidden_size)
self.mh_attn = MultiHeadAttention(hidden_size, hidden_size, num_heads, dropout)
self.feed_forward = PositionwiseFeedForward(hidden_size, hidden_size, dropout)
self.s_fc = nn.Linear(hidden_size, item_num)
    def forward(self, states, len_states):
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        # Zero out padding positions (the padding id is item_num).
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        # Score items from the hidden state at each sequence's last valid step.
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        # Squeeze only the timestep dim; a bare squeeze() would also drop the
        # batch dim when B == 1.
        supervised_output = self.s_fc(state_hidden).squeeze(1)
        return supervised_output
    def forward_eval(self, states, len_states):
        # Evaluation uses the same computation as training; delegate to
        # forward() instead of duplicating its body.
        return self.forward(states, len_states)
    def cacul_h(self, states, len_states):
        # Sequence representation only: the hidden state at the last valid
        # position, shape (B, 1, H), without the final scoring layer.
        states = states.to(self.device)
        inputs_emb = self.item_embeddings(states)
        inputs_emb += self.positional_embeddings(torch.arange(self.state_size).to(self.device))
        seq = self.emb_dropout(inputs_emb)
        mask = torch.ne(states, self.item_num).float().unsqueeze(-1).to(self.device)
        seq *= mask
        seq_normalized = self.ln_1(seq)
        mh_attn_out = self.mh_attn(seq_normalized, seq)
        ff_out = self.feed_forward(self.ln_2(mh_attn_out))
        ff_out *= mask
        ff_out = self.ln_3(ff_out)
        state_hidden = extract_axis_1(ff_out, len_states - 1)
        return state_hidden
    def cacu_x(self, x):
        # Raw item-embedding lookup for the given item ids.
        x = self.item_embeddings(x)
        return x
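# A minimal smoke test, guarded so importing this module stays side-effect
# free. Shapes and hyperparameters are illustrative assumptions, and
# SASRecModules_ori must be importable for SASRec to build.
if __name__ == "__main__":
    item_num, state_size, hidden_size = 50, 10, 64
    states = torch.randint(0, item_num, (2, state_size))
    states[1, 7:] = item_num                  # pad the shorter sequence
    len_states = torch.tensor([10, 7])        # true length of each sequence
    model = SASRec(hidden_size, item_num, state_size, dropout=0.1, device="cpu")
    model.eval()
    with torch.no_grad():
        logits = model.forward_eval(states, len_states)   # -> (2, 50)
        h = model.cacul_h(states, len_states)             # -> (2, 1, 64)
    print(logits.shape, h.shape)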