import os.path as op
import pickle as pkl
import random

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data


class SteamData(data.Dataset):
    """Session dataset that yields a history, its next item and sampled candidates.

    Each sample is built from one row of a pickled pandas DataFrame that must
    provide the columns ``seq`` (padded item-ID sequence), ``len_seq`` and
    ``next`` (ID of the item that follows the sequence). Item IDs are mapped
    to names through ``id2name.txt`` located in ``data_dir``.

    Args:
        data_dir: Directory holding ``id2name.txt`` and the stage data files.
        stage: One of ``'train'``, ``'val'`` or ``'test'``; selects the file.
        cans_num: Number of candidates per sample (the true next item plus
            ``cans_num - 1`` sampled negatives).
        sep: Separator used to join names into ``seq_str`` / ``cans_str``.
        no_augment: When False and ``stage == 'train'``, ``self.aug`` is True.
            The flag is only stored; nothing in this class reads it.
        padding_item_id: Item ID used as padding inside ``seq``.

    Raises:
        ValueError: If ``stage`` is not one of the supported stages.
    """

    # File name of the pickled DataFrame for each supported stage.
    _STAGE_FILES = {
        'train': "train_data.df",
        'val': "Val_data.df",
        'test': "Test_data.df",
    }

    def __init__(self, data_dir=r'data/ref/steam', stage=None, cans_num=10,
                 sep=", ", no_augment=True, padding_item_id=3581):
        super().__init__()
        # Explicit assignments instead of ``self.__dict__.update(locals())``,
        # which also stored ``self.self`` and created a reference cycle.
        self.data_dir = data_dir
        self.stage = stage
        self.cans_num = cans_num
        self.sep = sep
        self.no_augment = no_augment
        self.aug = (stage == 'train') and not no_augment
        self.padding_item_id = padding_item_id
        self.check_files()

    def __len__(self):
        """Return the number of sessions kept after filtering."""
        return len(self.session_data['seq'])

    def __getitem__(self, i):
        """Return the sample dict for the session at position ``i``."""
        row = self.session_data.iloc[i]
        candidates = self.negative_sampling(row['seq_unpad'], row['next'])
        cans_name = [self.item_id2name[can] for can in candidates]
        return {
            'seq': row['seq'],
            'seq_name': row['seq_title'],
            'len_seq': row['len_seq'],
            'seq_str': self.sep.join(row['seq_title']),
            'cans': candidates,
            'cans_name': cans_name,
            'cans_str': self.sep.join(cans_name),
            'len_cans': self.cans_num,
            'item_id': row['next'],
            'item_name': row['next_item_name'],
            'correct_answer': row['next_item_name'],
        }

    def negative_sampling(self, seq_unpad, next_item):
        """Sample candidate item IDs for one session.

        Args:
            seq_unpad: Item IDs of the session history without padding.
            next_item: ID of the ground-truth next item.

        Returns:
            A shuffled list of ``cans_num`` IDs: ``next_item`` plus
            ``cans_num - 1`` IDs drawn from the items that are neither in
            ``seq_unpad`` nor equal to ``next_item``.

        Raises:
            ValueError: Propagated from ``random.sample`` when fewer than
                ``cans_num - 1`` eligible items exist.
        """
        # Build the exclusion set once so each membership test is O(1)
        # instead of scanning the history list for every catalogue item.
        excluded = set(seq_unpad)
        excluded.add(next_item)
        canset = [item_id for item_id in self.item_id2name
                  if item_id not in excluded]
        candidates = random.sample(canset, self.cans_num - 1) + [next_item]
        random.shuffle(candidates)
        return candidates

    def check_files(self):
        """Load the ID-to-name mapping and the session data for ``self.stage``.

        Raises:
            ValueError: If ``self.stage`` has no associated data file.
        """
        try:
            filename = self._STAGE_FILES[self.stage]
        except (KeyError, TypeError):
            # Previously an unknown stage surfaced as an UnboundLocalError.
            raise ValueError(
                f"Unknown stage {self.stage!r}; expected one of "
                f"{sorted(self._STAGE_FILES)}"
            ) from None
        self.item_id2name = self.get_game_id2name()
        data_path = op.join(self.data_dir, filename)
        self.session_data = self.session_data4frame(data_path, self.item_id2name)

    def get_game_id2name(self):
        """Read ``id2name.txt`` and return a dict mapping item ID to name.

        Each non-blank line is split on ``'::'``; the first field is parsed
        as the integer ID and the second field (stripped) is the name.
        """
        game_id2name = dict()
        item_path = op.join(self.data_dir, 'id2name.txt')
        # NOTE(review): UTF-8 is assumed for the mapping file - confirm.
        with open(item_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip('\n')
                if not line.strip():
                    # Tolerate blank lines instead of failing on int('').
                    continue
                fields = line.split('::')
                game_id2name[int(fields[0])] = fields[1].strip()
        return game_id2name

    def session_data4frame(self, datapath, game_id2name):
        """Load the pickled session frame and add the derived columns.

        Rows with ``len_seq < 3`` are dropped. Added columns:
        ``seq_unpad`` (``seq`` without padding IDs), ``seq_title`` (names of
        ``seq_unpad``) and ``next_item_name`` (name of ``next``).

        Args:
            datapath: Path of the pickled DataFrame.
            game_id2name: Mapping from item ID to item name.

        Returns:
            The filtered DataFrame including the derived columns.
        """
        # SECURITY: read_pickle executes arbitrary pickle payloads; only
        # load data files that come from a trusted source.
        frame = pd.read_pickle(datapath)
        # Copy the filtered view so the column assignments below write to an
        # independent frame rather than a slice of the loaded one.
        frame = frame[frame['len_seq'] >= 3].copy()

        padding_id = self.padding_item_id

        def remove_padding(seq):
            # Drop every padding entry; the old loop stopped after ten.
            return [item for item in seq if item != padding_id]

        def seq_to_title(ids):
            return [game_id2name[item_id] for item_id in ids]

        def next_item_title(item_id):
            return game_id2name[item_id]

        frame['seq_unpad'] = frame['seq'].apply(remove_padding)
        frame['seq_title'] = frame['seq_unpad'].apply(seq_to_title)
        frame['next_item_name'] = frame['next'].apply(next_item_title)
        return frame