# iLoRA/data/movielens_data.py
# MingLi
# Fork and bug fix from https://github.com/AkaliKong/iLoRA
# 9f13819
import os.path as op
import pickle as pkl
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
class MovielensData(data.Dataset):
    """Next-item recommendation dataset built from MovieLens session pickles.

    Each sample pairs a user's interaction history (item ids and readable
    titles) with a shuffled candidate list made of the true next item plus
    randomly sampled negatives.

    Args:
        data_dir: Directory holding ``u.item`` and the per-stage pickles.
        stage: One of ``'train'``, ``'val'`` or ``'test'``; selects the pickle.
        cans_num: Candidates per sample (true next item + ``cans_num - 1``
            negatives).
        sep: Separator used to join titles into ``seq_str`` / ``cans_str``.
        no_augment: When False and ``stage == 'train'``, ``self.aug`` is True.
            Nothing in this class reads ``aug``; presumably consumed elsewhere.

    Raises:
        ValueError: If ``stage`` is not one of the supported values.
    """

    # Pickle file per stage. The mixed capitalisation is taken verbatim from
    # the original code and must match the names on disk.
    _STAGE_FILES = {
        'train': "train_data.df",
        'val': "Val_data.df",
        'test': "Test_data.df",
    }

    # Trailing articles moved to the front of a title, tried in this order.
    # ``\b`` makes each pattern match the whole word only, so ", A" no longer
    # fires inside ", An" or ", And".
    _ARTICLE_PATTERNS = tuple(
        (article, re.compile(", " + article + r"\b"))
        for article in ("The", "A", "An")
    )

    def __init__(self, data_dir=r'data/ref/movielens',
                 stage=None,
                 cans_num=10,
                 sep=", ",
                 no_augment=True):
        # Explicit assignment instead of ``__dict__.update(locals())``, which
        # also stored a circular ``self.self`` reference.
        self.data_dir = data_dir
        self.stage = stage
        self.cans_num = cans_num
        self.sep = sep
        self.no_augment = no_augment
        self.aug = (stage == 'train') and not no_augment
        # (item id, rating) pair used to pad short sequences in the pickles.
        # Ids parsed from u.item are shifted to be zero-based, so 1682 is
        # presumably one past the last real ML-100K item id -- TODO confirm.
        self.padding_item_id = 1682
        self.padding_rating = 0
        self.check_files()

    def __len__(self):
        """Return the number of sessions kept after filtering."""
        return len(self.session_data['seq'])

    def __getitem__(self, i):
        """Build the sample dict for the ``i``-th session.

        Negatives are re-sampled on every call, so ``cans`` differs between
        calls unless the ``random`` module is seeded.
        """
        temp = self.session_data.iloc[i]
        candidates = self.negative_sampling(temp['seq_unpad'], temp['next'])
        cans_name = [self.item_id2name[can] for can in candidates]
        sample = {
            'seq': temp['seq'],
            'seq_name': temp['seq_title'],
            'len_seq': temp['len_seq'],
            'seq_str': self.sep.join(temp['seq_title']),
            'cans': candidates,
            'cans_name': cans_name,
            'cans_str': self.sep.join(cans_name),
            'len_cans': self.cans_num,
            'item_id': temp['next'],
            'item_name': temp['next_item_name'],
            'correct_answer': temp['next_item_name']
        }
        return sample

    def negative_sampling(self, seq_unpad, next_item):
        """Return ``cans_num`` shuffled candidate ids containing ``next_item``.

        Negatives are drawn without replacement from items that are neither in
        the (unpadded) history nor the target.

        Raises:
            ValueError: From ``random.sample`` when fewer than
                ``cans_num - 1`` eligible negatives exist.
        """
        # Set lookup keeps the filter linear in the catalogue size; the
        # candidate order (dict order) is unchanged, so seeded draws match.
        excluded = set(seq_unpad)
        excluded.add(next_item)
        canset = [item for item in self.item_id2name if item not in excluded]
        candidates = random.sample(canset, self.cans_num - 1) + [next_item]
        random.shuffle(candidates)
        return candidates

    def check_files(self):
        """Load the id-to-title map and the session frame for ``self.stage``.

        Raises:
            ValueError: If ``self.stage`` has no associated data file.
        """
        # Validate before any file I/O so a missing/invalid stage fails with a
        # clear message instead of an UnboundLocalError.
        try:
            filename = self._STAGE_FILES[self.stage]
        except (KeyError, TypeError):
            raise ValueError(
                "stage must be one of %s, got %r"
                % (sorted(self._STAGE_FILES), self.stage)
            ) from None
        self.item_id2name = self.get_movie_id2name()
        data_path = op.join(self.data_dir, filename)
        self.session_data = self.session_data4frame(data_path, self.item_id2name)

    def get_mv_title(self, s):
        """Move a catalogue-style article to the front of a title.

        ``"Lion King, The"`` becomes ``"The Lion King"``. The first article (in
        the order The, A, An) found as a whole word after ``", "`` is moved to
        the front and every occurrence of it is removed from the title. Titles
        without such an article are returned unchanged.
        """
        for article, pattern in self._ARTICLE_PATTERNS:
            if pattern.search(s):
                return article + " " + pattern.sub("", s)
        return s

    def get_movie_id2name(self):
        """Parse ``u.item`` into a ``{zero-based item id: title}`` dict."""
        movie_id2name = dict()
        item_path = op.join(self.data_dir, 'u.item')
        with open(item_path, 'r', encoding="ISO-8859-1") as f:
            for line in f:
                fields = line.strip('\n').split('|')
                if len(fields) < 2:
                    # Blank or malformed line: nothing to index.
                    continue
                # ``[:-7]`` drops the last 7 characters, presumably the
                # " (YYYY)" release-year suffix; a title of 7 characters or
                # fewer therefore becomes an empty string.
                movie_id2name[int(fields[0]) - 1] = self.get_mv_title(fields[1][:-7])
        return movie_id2name

    def session_data4frame(self, datapath, movie_id2name):
        """Load a pickled session frame and derive the columns used by samples.

        Rows with ``len_seq < 3`` are dropped. ``seq`` and ``next`` are read as
        ``(item_id, rating)`` pairs and reduced to item ids. Added columns:
        ``seq_unpad`` (ids without padding), ``seq_title`` (titles of the
        unpadded history) and ``next_item_name``.

        Args:
            datapath: Path of the pickled ``DataFrame``.
            movie_id2name: Mapping from item id to title.

        Returns:
            The filtered and augmented ``DataFrame``.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        train_data = pd.read_pickle(datapath)
        # Copy the filtered slice so the column assignments below write to an
        # independent frame rather than a view of the loaded one.
        train_data = train_data[train_data['len_seq'] >= 3].copy()

        padding = (self.padding_item_id, self.padding_rating)

        def remove_padding(pairs):
            # Drop every padding pair regardless of how many there are.
            return [pair for pair in pairs if pair != padding]

        def seq_to_title(pairs):
            return [movie_id2name[pair[0]] for pair in pairs]

        def next_item_title(pair):
            return movie_id2name[pair[0]]

        def get_id_from_tuple(pair):
            return pair[0]

        def get_id_from_list(pairs):
            return [pair[0] for pair in pairs]

        # Titles are derived while the columns still hold (id, rating) pairs;
        # the pairs are collapsed to bare ids only afterwards.
        train_data['seq_unpad'] = train_data['seq'].apply(remove_padding)
        train_data['seq_title'] = train_data['seq_unpad'].apply(seq_to_title)
        train_data['next_item_name'] = train_data['next'].apply(next_item_title)
        train_data['next'] = train_data['next'].apply(get_id_from_tuple)
        train_data['seq'] = train_data['seq'].apply(get_id_from_list)
        train_data['seq_unpad'] = train_data['seq_unpad'].apply(get_id_from_list)
        return train_data