Gouzi Mohaled
Ajout du dossier lib
84d2a97
raw
history blame
3.54 kB
use std::fs::{create_dir_all, File};
use std::path::{Path, PathBuf};
use anyhow::{anyhow, Context, Result};
use flate2::read::GzDecoder;
use indicatif::{ProgressBar, ProgressDrawTarget};
pub enum Dataset {
// https://github.com/qdrant/sparse-vectors-experiments
SpladeWikiMovies,
// https://github.com/qdrant/sparse-vectors-benchmark
NeurIps2023Full,
NeurIps2023_1M,
NeurIps2023Small,
NeurIps2023Queries,
}
impl Dataset {
pub fn download(&self) -> Result<PathBuf> {
download_cached(&self.url())
}
fn url(&self) -> String {
const NEUR_IPS_2023_BASE: &str =
"https://storage.googleapis.com/ann-challenge-sparse-vectors/csr";
match self {
Dataset::SpladeWikiMovies => {
"https://storage.googleapis.com/dataset-sparse-vectors/sparse-vectors.jsonl.gz"
.to_string()
}
Dataset::NeurIps2023Full => format!("{NEUR_IPS_2023_BASE}/base_full.csr.gz"),
Dataset::NeurIps2023_1M => format!("{NEUR_IPS_2023_BASE}/base_1M.csr.gz"),
Dataset::NeurIps2023Small => format!("{NEUR_IPS_2023_BASE}/base_small.csr.gz"),
Dataset::NeurIps2023Queries => format!("{NEUR_IPS_2023_BASE}/queries.dev.csr.gz"),
}
}
}
fn download_cached(url: &str) -> Result<PathBuf> {
// Filename without an ".gz" extension, e.g. "base_full.csr".
let basename = {
let path = Path::new(url);
match path.extension() {
Some(gz) if gz == "gz" => path.file_stem(),
_ => path.file_name(),
}
.ok_or_else(|| anyhow!("Failed to extract basename from {url}"))?
};
// Cache directory, e.g. "target/datasets".
let cache_dir = workspace_dir()
.join(std::env::var_os("CARGO_TARGET_DIR").unwrap_or_else(|| "target".into()))
.join("datasets");
// Cache file path, e.g. "target/datasets/base_full.csr".
let cache_path = cache_dir.join(basename);
if cache_path.exists() {
return Ok(cache_path);
}
eprintln!("Downloading {url} to {cache_path:?}...");
create_dir_all(cache_dir)?;
let resp = reqwest::blocking::get(url)?;
if !resp.status().is_success() {
anyhow::bail!("Failed to download {url}, status: {}", resp.status());
}
let total_size = resp.content_length();
// Download to a temporary file, e.g. "target/datasets/base_full.csr.tmp", to avoid
// incomplete files.
let mut tmp_fname = cache_path.clone().into_os_string();
tmp_fname.push(".tmp");
// Progress bar.
let pb = ProgressBar::with_draw_target(total_size, ProgressDrawTarget::stderr_with_hz(12));
pb.set_style(
indicatif::ProgressStyle::default_bar()
.template("{msg} {wide_bar} {bytes}/{total_bytes} (eta:{eta})")
.expect("failed to set style"),
);
// Download + decompress.
std::io::copy(
&mut GzDecoder::new(pb.wrap_read(resp)),
&mut File::create(&tmp_fname)?,
)?;
std::fs::rename(&tmp_fname, &cache_path)
.with_context(|| format!("Failed to rename {tmp_fname:?} to {cache_path:?}"))?;
Ok(cache_path)
}
fn workspace_dir() -> PathBuf {
let output = std::process::Command::new(env!("CARGO"))
.arg("locate-project")
.arg("--workspace")
.arg("--message-format=plain")
.output()
.unwrap()
.stdout;
let cargo_path = Path::new(std::str::from_utf8(&output).unwrap().trim());
cargo_path.parent().unwrap().to_path_buf()
}