File size: 5,064 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead as _, BufReader, Lines};
use std::mem::size_of;
use std::path::Path;

use memmap2::Mmap;
use memory::madvise::{Advice, AdviceSetting};
use memory::mmap_ops::{open_read_mmap, transmute_from_u8, transmute_from_u8_to_slice};
use validator::ValidationErrors;

use crate::common::sparse_vector::SparseVector;

/// Compressed Sparse Row matrix, baked by memory-mapped file.
///
/// The layout of the memory-mapped file is as follows:
///
/// | name    | type          | size       | start               |
/// |---------|---------------|------------|---------------------|
/// | nrow    | `u64`         | 8          | 0                   |
/// | ncol    | `u64`         | 8          | 8                   |
/// | nnz     | `u64`         | 8          | 16                  |
/// | indptr  | `u64[nrow+1]` | 8*(nrow+1) | 24                  |
/// | indices | `u32[nnz]`    | 4*nnz      | 24+8*(nrow+1)       |
/// | data    | `u32[nnz]`    | 4*nnz      | 24+8*(nrow+1)+4*nnz |
pub struct Csr {
    mmap: Mmap,

    nrow: usize,
    nnz: usize,
    intptr: Vec<u64>,
}

const CSR_HEADER_SIZE: usize = size_of::<u64>() * 3;

impl Csr {
    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
        Self::from_mmap(open_read_mmap(
            path.as_ref(),
            AdviceSetting::from(Advice::Normal),
            false,
        )?)
    }

    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.nrow
    }

    pub fn iter(&self) -> CsrIter<'_> {
        CsrIter { csr: self, row: 0 }
    }

    fn from_mmap(mmap: Mmap) -> io::Result<Self> {
        let (nrow, ncol, nnz) =
            transmute_from_u8::<(u64, u64, u64)>(&mmap.as_ref()[..CSR_HEADER_SIZE]);
        let (nrow, _ncol, nnz) = (*nrow as usize, *ncol as usize, *nnz as usize);

        let indptr = Vec::from(transmute_from_u8_to_slice::<u64>(
            &mmap.as_ref()[CSR_HEADER_SIZE..CSR_HEADER_SIZE + size_of::<u64>() * (nrow + 1)],
        ));
        if !indptr.windows(2).all(|w| w[0] <= w[1]) || indptr.last() != Some(&(nnz as u64)) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid indptr array",
            ));
        }

        Ok(Self {
            mmap,
            nrow,
            nnz,
            intptr: indptr,
        })
    }

    #[inline]
    unsafe fn vec(&self, row: usize) -> Result<SparseVector, ValidationErrors> {
        let start = *self.intptr.get_unchecked(row) as usize;
        let end = *self.intptr.get_unchecked(row + 1) as usize;

        let mut pos = CSR_HEADER_SIZE + size_of::<u64>() * (self.nrow + 1);

        let indices = transmute_from_u8_to_slice::<u32>(
            self.mmap
                .as_ref()
                .get_unchecked(pos + size_of::<u32>() * start..pos + size_of::<u32>() * end),
        );
        pos += size_of::<u32>() * self.nnz;

        let data = transmute_from_u8_to_slice::<f32>(
            self.mmap
                .as_ref()
                .get_unchecked(pos + size_of::<f32>() * start..pos + size_of::<f32>() * end),
        );

        SparseVector::new(indices.to_vec(), data.to_vec())
    }
}

/// Iterator over the rows of a CSR matrix.
///
/// Yields one decoded row per step; a row that fails [`SparseVector`] validation is
/// yielded as `Err` and iteration continues with the next row.
pub struct CsrIter<'a> {
    csr: &'a Csr,
    /// Index of the next row to yield.
    row: usize,
}

impl Iterator for CsrIter<'_> {
    type Item = Result<SparseVector, ValidationErrors>;

    fn next(&mut self) -> Option<Self::Item> {
        (self.row < self.csr.nrow).then(|| {
            // SAFETY: guarded by `self.row < self.csr.nrow` just above.
            let vec = unsafe { self.csr.vec(self.row) };
            self.row += 1;
            vec
        })
    }

    /// Exact: the number of rows not yet yielded. Required by `ExactSizeIterator`,
    /// and lets `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.csr.nrow - self.row;
        (remaining, Some(remaining))
    }
}

/// `len` comes from the exact `size_hint` above.
impl ExactSizeIterator for CsrIter<'_> {}

pub fn load_csr_vecs(path: impl AsRef<Path>) -> io::Result<Vec<SparseVector>> {
    Csr::open(path)?
        .iter()
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Stream of sparse vectors in JSON format, read one line at a time.
pub struct JsonReader(Lines<BufReader<File>>);

impl JsonReader {
    /// Opens `path` for buffered, line-by-line reading.
    ///
    /// # Errors
    ///
    /// Returns the error from [`File::open`] if the file cannot be opened.
    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        Ok(Self(reader.lines()))
    }
}

/// Yields one [`SparseVector`] per line. Each line must be a JSON object whose keys
/// parse as the vector's indices and whose values are `f32`.
impl Iterator for JsonReader {
    type Item = Result<SparseVector, io::Error>;

    /// Parses the next line.
    ///
    /// # Errors
    ///
    /// * Read errors from the underlying line reader are propagated unchanged.
    /// * Malformed JSON is reported through `serde_json`'s conversion into `io::Error`.
    /// * Keys that fail to parse as indices, or a vector rejected by
    ///   [`SparseVector::new`], produce an [`io::ErrorKind::InvalidData`] error.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|line| -> Self::Item {
            // Keep the original error kind of read failures instead of relabelling
            // every one of them as `InvalidData`.
            let line = line?;
            let data: HashMap<String, f32> = serde_json::from_str(&line)?;

            // Walk the map once so each index stays paired with its own value, rather
            // than relying on `keys()` and `values()` iterating in the same order.
            let mut indices = Vec::with_capacity(data.len());
            let mut values = Vec::with_capacity(data.len());
            for (key, value) in data {
                let index = key
                    .parse()
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                indices.push(index);
                values.push(value);
            }

            SparseVector::new(indices, values)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
        })
    }
}