File size: 750 Bytes
db26c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#pragma once

#include "./common.h"

/**
 * Checks whether task groups are "split" between visited and unvisited tasks.
 *
 * Tasks are divided into groups. For each batch row, the output is false if
 * the tasks of a group are either all visited or all not visited, and true
 * otherwise (i.e. some tasks of the group are visited and some are not).
 *
 * NOTE(review): how per-group results are combined into the single value per
 * row is not visible in this header; presumably a row is true if ANY of its
 * groups is split. Confirm against task_group_split_cpu / task_group_split_cuda.
 *
 * @param group  group id of each task, shape (batch_size, task_num)
 *               (presumably an int32 tensor, matching the int* buffer taken
 *               by the _cpu/_cuda backends; TODO confirm)
 * @param value  whether each task has been visited, shape (batch_size, task_num)
 *               (presumably a bool tensor, matching the bool* buffer of the
 *               backends; TODO confirm)
 * @return       the result, shape (batch_size,)
 */
auto task_group_split(const Tensor& group, const Tensor& value) -> Tensor;

/**
 * CPU backend for task_group_split, operating on raw buffers.
 *
 * @param group       group id of each task (presumably batch_size * task_num
 *                    ints laid out row-major; TODO confirm)
 * @param value       visited flag of each task (presumably batch_size * task_num
 *                    bools, same layout as group)
 * @param output      destination buffer for the result (presumably batch_size
 *                    bools, one per row)
 * @param batch_size  number of batch rows
 * @param task_num    number of tasks per row
 * @param group_num   number of groups (presumably group ids lie in
 *                    [0, group_num); TODO confirm)
 */
void task_group_split_cpu(int* group,
                          bool* value,
                          bool* output,
                          int batch_size,
                          int task_num,
                          int group_num);

/**
 * CUDA backend for task_group_split, operating on raw buffers.
 *
 * @param group       group id of each task (presumably a device buffer of
 *                    batch_size * task_num ints, row-major; TODO confirm)
 * @param value       visited flag of each task (presumably a device buffer of
 *                    batch_size * task_num bools, same layout as group)
 * @param output      destination buffer for the result (presumably batch_size
 *                    bools, one per row)
 * @param batch_size  number of batch rows
 * @param task_num    number of tasks per row
 * @param group_num   number of groups (presumably group ids lie in
 *                    [0, group_num); TODO confirm)
 * @param device      presumably the CUDA device ordinal the buffers live on
 *                    (TODO confirm)
 */
void task_group_split_cuda(int* group,
                           bool* value,
                           bool* output,
                           int batch_size,
                           int task_num,
                           int group_num,
                           int device);