Spaces:
Build error
Build error
File size: 5,064 Bytes
84d2a97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead as _, BufReader, Lines};
use std::mem::size_of;
use std::path::Path;
use memmap2::Mmap;
use memory::madvise::{Advice, AdviceSetting};
use memory::mmap_ops::{open_read_mmap, transmute_from_u8, transmute_from_u8_to_slice};
use validator::ValidationErrors;
use crate::common::sparse_vector::SparseVector;
/// Compressed Sparse Row matrix, backed by a memory-mapped file.
///
/// The layout of the memory-mapped file is as follows:
///
/// | name | type | size | start |
/// |---------|---------------|------------|---------------------|
/// | nrow | `u64` | 8 | 0 |
/// | ncol | `u64` | 8 | 8 |
/// | nnz | `u64` | 8 | 16 |
/// | indptr | `u64[nrow+1]` | 8*(nrow+1) | 24 |
/// | indices | `u32[nnz]` | 4*nnz | 24+8*(nrow+1) |
/// | data | `f32[nnz]` | 4*nnz | 24+8*(nrow+1)+4*nnz |
///
/// Row `i` owns the entries `indptr[i]..indptr[i+1]` of both `indices` and `data`.
pub struct Csr {
/// Mapping of the whole file; `indices` and `data` are read from it on demand.
mmap: Mmap,
/// Number of rows, as stored in the file header.
nrow: usize,
/// Number of stored entries, as stored in the file header.
nnz: usize,
/// Owned copy of the `indptr` array (`nrow + 1` entries) read from the file.
intptr: Vec<u64>,
}
/// Size in bytes of the fixed file header: three `u64` values (`nrow`, `ncol`, `nnz`).
const CSR_HEADER_SIZE: usize = size_of::<u64>() * 3;
impl Csr {
    /// Opens the CSR file at `path` as a memory map and validates its layout.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be mapped, or an
    /// [`io::ErrorKind::InvalidData`] error if the file is too small for the
    /// layout its header declares or its `indptr` array is inconsistent.
    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
        Self::from_mmap(open_read_mmap(
            path.as_ref(),
            AdviceSetting::from(Advice::Normal),
            false,
        )?)
    }

    /// Number of rows in the matrix.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.nrow
    }

    /// Returns an iterator over the rows, starting at row 0.
    pub fn iter(&self) -> CsrIter<'_> {
        CsrIter { csr: self, row: 0 }
    }

    /// Parses the header and `indptr` array out of `mmap` and validates that
    /// every section described by the header actually fits inside the mapping.
    ///
    /// This validation is what makes the unchecked accesses in [`Csr::vec`]
    /// sound, so nothing is trusted before it has been checked against the
    /// real length of the file.
    fn from_mmap(mmap: Mmap) -> io::Result<Self> {
        fn invalid(msg: &'static str) -> io::Error {
            io::Error::new(io::ErrorKind::InvalidData, msg)
        }

        let bytes: &[u8] = mmap.as_ref();

        // A checked slice instead of `[..]`: a short file is an error, not a panic.
        let header = bytes
            .get(..CSR_HEADER_SIZE)
            .ok_or_else(|| invalid("CSR file is too small to contain the header"))?;
        // An array has a guaranteed layout, unlike a tuple.
        let &[nrow, _ncol, nnz] = transmute_from_u8::<[u64; 3]>(header);
        let nrow = usize::try_from(nrow)
            .map_err(|_| invalid("CSR row count does not fit in usize"))?;
        let nnz = usize::try_from(nnz)
            .map_err(|_| invalid("CSR entry count does not fit in usize"))?;

        // Compute the section boundaries with overflow checks; the header is
        // untrusted input and may contain arbitrary values.
        let layout = || -> Option<(usize, usize)> {
            let indptr_len = nrow.checked_add(1)?.checked_mul(size_of::<u64>())?;
            let indptr_end = CSR_HEADER_SIZE.checked_add(indptr_len)?;
            let indices_len = nnz.checked_mul(size_of::<u32>())?;
            let data_len = nnz.checked_mul(size_of::<f32>())?;
            let data_end = indptr_end
                .checked_add(indices_len)?
                .checked_add(data_len)?;
            Some((indptr_end, data_end))
        };
        let (indptr_end, data_end) =
            layout().ok_or_else(|| invalid("CSR header describes a layout that is too large"))?;
        if bytes.len() < data_end {
            return Err(invalid("CSR file is smaller than its header declares"));
        }

        let indptr = Vec::from(transmute_from_u8_to_slice::<u64>(
            &bytes[CSR_HEADER_SIZE..indptr_end],
        ));
        // Offsets must be non-decreasing and end exactly at `nnz`, so every
        // row range `indptr[i]..indptr[i + 1]` lies within `0..=nnz`.
        if !indptr.windows(2).all(|w| w[0] <= w[1]) || indptr.last() != Some(&(nnz as u64)) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid indptr array",
            ));
        }

        Ok(Self {
            mmap,
            nrow,
            nnz,
            intptr: indptr,
        })
    }

    /// Builds the sparse vector stored in row `row` by copying its indices and
    /// values out of the mapping.
    ///
    /// # Safety
    ///
    /// `row` must be less than `self.nrow`. The remaining bounds (row ranges
    /// within `nnz`, sections within the mapping) are established once by
    /// [`Csr::from_mmap`].
    #[inline]
    unsafe fn vec(&self, row: usize) -> Result<SparseVector, ValidationErrors> {
        let start = *self.intptr.get_unchecked(row) as usize;
        let end = *self.intptr.get_unchecked(row + 1) as usize;

        // Byte offset of the `indices` section (right after `indptr`).
        let mut pos = CSR_HEADER_SIZE + size_of::<u64>() * (self.nrow + 1);
        let indices = transmute_from_u8_to_slice::<u32>(
            self.mmap
                .as_ref()
                .get_unchecked(pos + size_of::<u32>() * start..pos + size_of::<u32>() * end),
        );

        // Byte offset of the `data` section (right after `indices`).
        pos += size_of::<u32>() * self.nnz;
        let data = transmute_from_u8_to_slice::<f32>(
            self.mmap
                .as_ref()
                .get_unchecked(pos + size_of::<f32>() * start..pos + size_of::<f32>() * end),
        );

        SparseVector::new(indices.to_vec(), data.to_vec())
    }
}
/// Iterator over the rows of a CSR matrix.
///
/// Yields one item per row, in row order.
pub struct CsrIter<'a> {
/// Matrix being iterated.
csr: &'a Csr,
/// Index of the next row to yield.
row: usize,
}
impl<'a> Iterator for CsrIter<'a> {
    type Item = Result<SparseVector, ValidationErrors>;

    /// Yields the next row as a sparse vector, or `None` once all rows have
    /// been visited.
    fn next(&mut self) -> Option<Self::Item> {
        (self.row < self.csr.nrow).then(|| {
            // SAFETY: guarded by `self.row < self.csr.nrow` just above, which
            // is the precondition of `Csr::vec`.
            let vec = unsafe { self.csr.vec(self.row) };
            self.row += 1;
            vec
        })
    }

    /// Reports the exact number of remaining rows.
    ///
    /// `ExactSizeIterator` requires `size_hint` to be exact; it also lets
    /// `collect` allocate the right capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.csr.nrow - self.row;
        (remaining, Some(remaining))
    }
}
impl<'a> ExactSizeIterator for CsrIter<'a> {
fn len(&self) -> usize {
self.csr.nrow - self.row
}
}
/// Reads every row of the CSR file at `path` into memory.
///
/// # Errors
///
/// Propagates any error from opening the file. A row that fails sparse-vector
/// validation aborts the load and is reported as
/// [`io::ErrorKind::InvalidData`].
pub fn load_csr_vecs(path: impl AsRef<Path>) -> io::Result<Vec<SparseVector>> {
    let matrix = Csr::open(path)?;
    let mut rows = Vec::new();
    for row in matrix.iter() {
        match row {
            Ok(vector) => rows.push(vector),
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e)),
        }
    }
    Ok(rows)
}
/// Stream of sparse vectors in JSON format, read from a file one line at a time.
pub struct JsonReader(Lines<BufReader<File>>);

impl JsonReader {
    /// Opens `path` for buffered, line-by-line reading.
    ///
    /// # Errors
    ///
    /// Returns the error reported by [`File::open`] if the file cannot be opened.
    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
        let file = File::open(path)?;
        let lines = BufReader::new(file).lines();
        Ok(Self(lines))
    }
}
impl Iterator for JsonReader {
type Item = Result<SparseVector, io::Error>;
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(|line| {
line.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
.and_then(|line| {
let data: HashMap<String, f32> = serde_json::from_str(&line)?;
SparseVector::new(
data.keys()
.map(|k| k.parse())
.collect::<Result<Vec<_>, _>>()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
data.values().copied().collect(),
)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
})
})
}
}
|