mod resharding;

use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::ops::Deref as _;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;

use common::cpu::CpuBudget;
use common::tar_ext::BuilderExt;
use futures::{Future, TryStreamExt as _};
use itertools::Itertools;
use segment::common::validate_snapshot_archive::open_snapshot_archive_with_validation;
use segment::types::{ShardKey, SnapshotFormat};
use tokio::runtime::Handle;
use tokio::sync::{broadcast, OwnedRwLockReadGuard, RwLock};
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio_util::io::SyncIoBridge;

use super::replica_set::{AbortShardTransfer, ChangePeerFromState};
use super::resharding::tasks_pool::ReshardTasksPool;
use super::resharding::{ReshardStage, ReshardState};
use super::transfer::transfer_tasks_pool::TransferTasksPool;
use crate::collection::payload_index_schema::PayloadIndexSchema;
use crate::common::snapshot_stream::SnapshotStream;
use crate::config::{CollectionConfigInternal, ShardingMethod};
use crate::hash_ring::HashRingRouter;
use crate::operations::cluster_ops::ReshardingDirection;
use crate::operations::shard_selector_internal::ShardSelectorInternal;
use crate::operations::shared_storage_config::SharedStorageConfig;
use crate::operations::snapshot_ops::SnapshotDescription;
use crate::operations::types::{
    CollectionError, CollectionResult, ReshardingInfo, ShardTransferInfo,
};
use crate::operations::{OperationToShard, SplitByShard};
use crate::optimizers_builder::OptimizersConfig;
use crate::save_on_disk::SaveOnDisk;
use crate::shards::channel_service::ChannelService;
use crate::shards::local_shard::LocalShard;
use crate::shards::replica_set::{ReplicaState, ShardReplicaSet};
use crate::shards::shard::{PeerId, ShardId};
use crate::shards::shard_config::{ShardConfig, ShardType};
use crate::shards::shard_versioning::latest_shard_paths;
use crate::shards::transfer::{ShardTransfer, ShardTransferKey};
use crate::shards::CollectionId;

const SHARD_TRANSFERS_FILE: &str = "shard_transfers";
const RESHARDING_STATE_FILE: &str = "resharding_state.json";
pub const SHARD_KEY_MAPPING_FILE: &str = "shard_key_mapping.json";

pub type ShardKeyMapping = HashMap<ShardKey, HashSet<ShardId>>;
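// A hedged sketch (added, not from the original file) of what the persisted
// mapping in `shard_key_mapping.json` plausibly looks like, assuming the
// `ShardKey` keys and shard id sets serialize to a plain JSON object
// (keys and ids are illustrative):
//
//     { "eu": [0, 1], "us": [2, 3] }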
pub struct ShardHolder {
    shards: HashMap<ShardId, ShardReplicaSet>,
    pub(crate) shard_transfers: SaveOnDisk<HashSet<ShardTransfer>>,
    pub(crate) shard_transfer_changes: broadcast::Sender<ShardTransferChange>,
    pub(crate) resharding_state: SaveOnDisk<Option<ReshardState>>,
    pub(crate) rings: HashMap<Option<ShardKey>, HashRingRouter>,
    key_mapping: SaveOnDisk<ShardKeyMapping>,
    // Duplicates the information from `key_mapping` for faster access,
    // and does not require locking
    shard_id_to_key_mapping: HashMap<ShardId, ShardKey>,
}

pub type LockedShardHolder = RwLock<ShardHolder>;
impl ShardHolder {
    pub async fn trigger_optimizers(&self) {
        for shard in self.shards.values() {
            shard.trigger_optimizers().await;
        }
    }

    pub fn new(collection_path: &Path) -> CollectionResult<Self> {
        let shard_transfers =
            SaveOnDisk::load_or_init_default(collection_path.join(SHARD_TRANSFERS_FILE))?;
        let resharding_state: SaveOnDisk<Option<ReshardState>> =
            SaveOnDisk::load_or_init_default(collection_path.join(RESHARDING_STATE_FILE))?;
        let key_mapping: SaveOnDisk<ShardKeyMapping> =
            SaveOnDisk::load_or_init_default(collection_path.join(SHARD_KEY_MAPPING_FILE))?;

        let mut shard_id_to_key_mapping = HashMap::new();
        for (shard_key, shard_ids) in key_mapping.read().iter() {
            for shard_id in shard_ids {
                shard_id_to_key_mapping.insert(*shard_id, shard_key.clone());
            }
        }

        let rings = HashMap::from([(None, HashRingRouter::single())]);
        let (shard_transfer_changes, _) = broadcast::channel(64);

        Ok(Self {
            shards: HashMap::new(),
            shard_transfers,
            shard_transfer_changes,
            resharding_state,
            rings,
            key_mapping,
            shard_id_to_key_mapping,
        })
    }

    pub async fn save_key_mapping_to_tar(
        &self,
        tar: &common::tar_ext::BuilderExt,
    ) -> CollectionResult<()> {
        self.key_mapping
            .save_to_tar(tar, Path::new(SHARD_KEY_MAPPING_FILE))
            .await?;
        Ok(())
    }

    pub fn get_shard_id_to_key_mapping(&self) -> &HashMap<ShardId, ShardKey> {
        &self.shard_id_to_key_mapping
    }

    pub fn get_shard_key_to_ids_mapping(&self) -> ShardKeyMapping {
        self.key_mapping.read().clone()
    }

    pub async fn drop_and_remove_shard(
        &mut self,
        shard_id: ShardId,
    ) -> Result<(), CollectionError> {
        if let Some(replica_set) = self.shards.remove(&shard_id) {
            let shard_path = replica_set.shard_path.clone();
            drop(replica_set);

            // Explicitly remove the shard config file first.
            // If we removed all shard files at once, the shard configuration file
            // could be left behind if the process were killed midway. Removing it
            // first guarantees we won't attempt to load this shard again on restart.
            let shard_config_path = ShardConfig::get_config_path(&shard_path);
            if let Err(err) = tokio::fs::remove_file(shard_config_path).await {
                log::error!(
                    "Failed to remove shard config file before removing the rest of the files: {err}"
                );
            }

            tokio::fs::remove_dir_all(shard_path).await?;
        }
        Ok(())
    }
    pub fn remove_shard_from_key_mapping(
        &mut self,
        shard_id: ShardId,
        shard_key: &ShardKey,
    ) -> Result<(), CollectionError> {
        self.key_mapping.write_optional(|key_mapping| {
            if !key_mapping.contains_key(shard_key) {
                return None;
            }
            let mut key_mapping = key_mapping.clone();
            key_mapping.get_mut(shard_key).unwrap().remove(&shard_id);
            Some(key_mapping)
        })?;
        self.shard_id_to_key_mapping.remove(&shard_id);
        Ok(())
    }

    pub fn add_shard(
        &mut self,
        shard_id: ShardId,
        shard: ShardReplicaSet,
        shard_key: Option<ShardKey>,
    ) -> Result<(), CollectionError> {
        self.shards.insert(shard_id, shard);
        self.rings
            .entry(shard_key.clone())
            .or_insert_with(HashRingRouter::single)
            .add(shard_id);

        if let Some(shard_key) = shard_key {
            self.key_mapping.write_optional(|key_mapping| {
                let has_id = key_mapping
                    .get(&shard_key)
                    .map(|shard_ids| shard_ids.contains(&shard_id))
                    .unwrap_or(false);
                if has_id {
                    return None;
                }
                let mut copy_of_mapping = key_mapping.clone();
                let shard_ids = copy_of_mapping.entry(shard_key.clone()).or_default();
                shard_ids.insert(shard_id);
                Some(copy_of_mapping)
            })?;
            self.shard_id_to_key_mapping.insert(shard_id, shard_key);
        }
        Ok(())
    }

    pub async fn remove_shard_key(&mut self, shard_key: &ShardKey) -> Result<(), CollectionError> {
        let mut remove_shard_ids = Vec::new();
        self.key_mapping.write_optional(|key_mapping| {
            if key_mapping.contains_key(shard_key) {
                let mut new_key_mapping = key_mapping.clone();
                if let Some(shard_ids) = new_key_mapping.remove(shard_key) {
                    for shard_id in shard_ids {
                        remove_shard_ids.push(shard_id);
                    }
                }
                Some(new_key_mapping)
            } else {
                None
            }
        })?;
        self.rings.remove(&shard_key.clone().into());
        for shard_id in remove_shard_ids {
            self.drop_and_remove_shard(shard_id).await?;
            self.shard_id_to_key_mapping.remove(&shard_id);
        }
        Ok(())
    }
    fn rebuild_rings(&mut self) {
        let mut rings = HashMap::from([(None, HashRingRouter::single())]);
        let ids_to_key = self.get_shard_id_to_key_mapping();
        for shard_id in self.shards.keys() {
            let shard_key = ids_to_key.get(shard_id).cloned();
            rings
                .entry(shard_key)
                .or_insert_with(HashRingRouter::single)
                .add(*shard_id);
        }

        // Restore the resharding hash ring if resharding is active and hasn't
        // reached the `WriteHashRingCommitted` stage yet
        if let Some(state) = self.resharding_state.read().deref() {
            let ring = rings
                .get_mut(&state.shard_key)
                .expect("must have hash ring for current resharding shard key");
            ring.start_resharding(state.shard_id, state.direction);
            if state.stage >= ReshardStage::WriteHashRingCommitted {
                ring.commit_resharding();
            }
        }

        self.rings = rings;
    }

    pub async fn apply_shards_state(
        &mut self,
        shard_ids: HashSet<ShardId>,
        shard_key_mapping: ShardKeyMapping,
        extra_shards: HashMap<ShardId, ShardReplicaSet>,
    ) -> Result<(), CollectionError> {
        self.shards.extend(extra_shards.into_iter());

        let all_shard_ids = self.shards.keys().cloned().collect::<HashSet<_>>();

        self.key_mapping
            .write_optional(|_key_mapping| Some(shard_key_mapping))?;

        for shard_id in all_shard_ids {
            if !shard_ids.contains(&shard_id) {
                self.drop_and_remove_shard(shard_id).await?;
            }
        }

        self.rebuild_rings();
        Ok(())
    }
    pub fn contains_shard(&self, shard_id: ShardId) -> bool {
        self.shards.contains_key(&shard_id)
    }

    pub fn get_shard(&self, shard_id: ShardId) -> Option<&ShardReplicaSet> {
        self.shards.get(&shard_id)
    }

    pub fn get_shards(&self) -> impl Iterator<Item = (ShardId, &ShardReplicaSet)> {
        self.shards.iter().map(|(id, shard)| (*id, shard))
    }

    pub fn all_shards(&self) -> impl Iterator<Item = &ShardReplicaSet> {
        self.shards.values()
    }

    pub fn split_by_shard<O: SplitByShard + Clone>(
        &self,
        operation: O,
        shard_keys_selection: &Option<ShardKey>,
    ) -> CollectionResult<Vec<(&ShardReplicaSet, O)>> {
        let Some(hashring) = self.rings.get(shard_keys_selection) else {
            return if let Some(shard_key) = shard_keys_selection {
                Err(CollectionError::bad_input(format!(
                    "Shard key {shard_key} not found"
                )))
            } else {
                Err(CollectionError::bad_input(
                    "Shard key not specified".to_string(),
                ))
            };
        };

        if hashring.is_empty() {
            return Err(CollectionError::bad_input(
                "No shards found for shard key".to_string(),
            ));
        }

        let operation_to_shard = operation.split_by_shard(hashring);
        let shard_ops: Vec<_> = match operation_to_shard {
            OperationToShard::ByShard(by_shard) => by_shard
                .into_iter()
                .map(|(shard_id, operation)| (self.shards.get(&shard_id).unwrap(), operation))
                .collect(),
            OperationToShard::ToAll(operation) => {
                if let Some(shard_key) = shard_keys_selection {
                    let shard_ids = self
                        .key_mapping
                        .read()
                        .get(shard_key)
                        .cloned()
                        .unwrap_or_default();
                    shard_ids
                        .into_iter()
                        .map(|shard_id| (self.shards.get(&shard_id).unwrap(), operation.clone()))
                        .collect()
                } else {
                    self.all_shards()
                        .map(|shard| (shard, operation.clone()))
                        .collect()
                }
            }
        };
        Ok(shard_ops)
    }
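
    // A hedged usage sketch (added; the caller and the dispatch step are
    // assumed, not part of this file): fan an operation out to the shards
    // that own it according to the hash ring.
    //
    //     let shard_ops = shard_holder.split_by_shard(operation, &None)?;
    //     for (replica_set, shard_operation) in shard_ops {
    //         // dispatch `shard_operation` to `replica_set` here
    //     }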
    pub fn register_start_shard_transfer(&self, transfer: ShardTransfer) -> CollectionResult<bool> {
        let changed = self
            .shard_transfers
            .write(|transfers| transfers.insert(transfer.clone()))?;
        let _ = self
            .shard_transfer_changes
            .send(ShardTransferChange::Start(transfer));
        Ok(changed)
    }

    pub fn register_finish_transfer(&self, key: &ShardTransferKey) -> CollectionResult<bool> {
        let any_removed = self.shard_transfers.write(|transfers| {
            let before_remove = transfers.len();
            transfers.retain(|transfer| !key.check(transfer));
            before_remove != transfers.len()
        })?;
        let _ = self
            .shard_transfer_changes
            .send(ShardTransferChange::Finish(*key));
        Ok(any_removed)
    }

    pub fn register_abort_transfer(&self, key: &ShardTransferKey) -> CollectionResult<bool> {
        let any_removed = self.shard_transfers.write(|transfers| {
            let before_remove = transfers.len();
            transfers.retain(|transfer| !key.check(transfer));
            before_remove != transfers.len()
        })?;
        let _ = self
            .shard_transfer_changes
            .send(ShardTransferChange::Abort(*key));
        Ok(any_removed)
    }
    /// Wait for a given shard transfer to complete.
    ///
    /// The returned inner result defines whether the transfer finished successfully or
    /// whether it was aborted/cancelled.
    pub fn await_shard_transfer_end(
        &self,
        transfer: ShardTransferKey,
        timeout: Duration,
    ) -> impl Future<Output = CollectionResult<Result<(), ()>>> {
        let mut subscriber = self.shard_transfer_changes.subscribe();
        let receiver = async move {
            loop {
                match subscriber.recv().await {
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => return Err(CollectionError::service_error(
                        "Failed to await shard transfer end: failed to listen for shard transfer changes, channel closed"
                    )),
                    Err(err @ tokio::sync::broadcast::error::RecvError::Lagged(_)) => return Err(CollectionError::service_error(format!(
                        "Failed to await shard transfer end: failed to listen for shard transfer changes, channel lagged behind: {err}"
                    ))),
                    Ok(ShardTransferChange::Finish(key)) if key == transfer => return Ok(Ok(())),
                    Ok(ShardTransferChange::Abort(key)) if key == transfer => return Ok(Err(())),
                    Ok(_) => {}
                }
            }
        };

        async move {
            match tokio::time::timeout(timeout, receiver).await {
                Ok(operation) => Ok(operation?),
                // Timeout
                Err(err) => Err(CollectionError::service_error(format!(
                    "Waiting for shard transfer end timed out: {err}"
                ))),
            }
        }
    }
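
    // A hedged usage sketch (added; caller names assumed): build the future
    // while holding the shard holder lock, so the broadcast subscription is
    // registered before any change event can be missed, then await it after
    // releasing the lock so the transfer can actually progress.
    //
    //     let transfer_end = {
    //         let holder = shard_holder.read().await;
    //         holder.await_shard_transfer_end(transfer_key, Duration::from_secs(60))
    //     };
    //     match transfer_end.await? {
    //         Ok(()) => { /* transfer finished */ }
    //         Err(()) => { /* transfer aborted or cancelled */ }
    //     }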
    /// The count of incoming and outgoing shard transfers on the given peer
    ///
    /// This only includes shard transfers that are in consensus for the current collection. A
    /// shard transfer that has just been proposed may not be included yet.
    pub fn count_shard_transfer_io(&self, peer_id: PeerId) -> (usize, usize) {
        let (mut incoming, mut outgoing) = (0, 0);
        for transfer in self.shard_transfers.read().iter() {
            incoming += usize::from(transfer.to == peer_id);
            outgoing += usize::from(transfer.from == peer_id);
        }
        (incoming, outgoing)
    }

    pub fn get_shard_transfer_info(
        &self,
        tasks_pool: &TransferTasksPool,
    ) -> Vec<ShardTransferInfo> {
        let mut shard_transfers = vec![];
        for shard_transfer in self.shard_transfers.read().iter() {
            let shard_id = shard_transfer.shard_id;
            let to_shard_id = shard_transfer.to_shard_id;
            let to = shard_transfer.to;
            let from = shard_transfer.from;
            let sync = shard_transfer.sync;
            let method = shard_transfer.method;
            let status = tasks_pool.get_task_status(&shard_transfer.key());
            shard_transfers.push(ShardTransferInfo {
                shard_id,
                to_shard_id,
                from,
                to,
                sync,
                method,
                comment: status.map(|p| p.comment),
            })
        }
        shard_transfers.sort_by_key(|k| k.shard_id);
        shard_transfers
    }

    pub fn get_resharding_operations_info(
        &self,
        tasks_pool: &ReshardTasksPool,
    ) -> Option<Vec<ReshardingInfo>> {
        let mut resharding_operations = vec![];

        // We eventually expect to extend this to multiple concurrent operations, which is why
        // we're using a list here
        let Some(resharding_state) = &*self.resharding_state.read() else {
            return None;
        };

        let status = tasks_pool.get_task_status(&resharding_state.key());
        resharding_operations.push(ReshardingInfo {
            shard_id: resharding_state.shard_id,
            peer_id: resharding_state.peer_id,
            direction: resharding_state.direction,
            shard_key: resharding_state.shard_key.clone(),
            comment: status.map(|p| p.comment),
        });

        resharding_operations.sort_by_key(|k| k.shard_id);
        Some(resharding_operations)
    }

    pub fn get_related_transfers(&self, shard_id: ShardId, peer_id: PeerId) -> Vec<ShardTransfer> {
        self.get_transfers(|transfer| {
            transfer.shard_id == shard_id && (transfer.from == peer_id || transfer.to == peer_id)
        })
    }

    fn get_shard_ids_by_key(&self, shard_key: &ShardKey) -> CollectionResult<HashSet<ShardId>> {
        match self.key_mapping.read().get(shard_key).cloned() {
            None => Err(CollectionError::bad_request(format!(
                "Shard key {shard_key} not found"
            ))),
            Some(ids) => Ok(ids),
        }
    }
    pub fn select_shards<'a>(
        &'a self,
        shard_selector: &'a ShardSelectorInternal,
    ) -> CollectionResult<Vec<(&ShardReplicaSet, Option<&ShardKey>)>> {
        let mut res = Vec::new();

        match shard_selector {
            ShardSelectorInternal::Empty => {
                debug_assert!(false, "Do not expect empty shard selector")
            }
            ShardSelectorInternal::All => {
                for (&shard_id, shard) in self.shards.iter() {
                    // Ignore a new resharding shard until it has completed point migration.
                    // The shard will be marked as active at the end of the migration stage.
                    let resharding_migrating_up =
                        self.resharding_state.read().clone().map_or(false, |state| {
                            state.direction == ReshardingDirection::Up
                                && state.shard_id == shard_id
                                && state.stage < ReshardStage::ReadHashRingCommitted
                        });
                    if resharding_migrating_up {
                        continue;
                    }

                    let shard_key = self.shard_id_to_key_mapping.get(&shard_id);
                    res.push((shard, shard_key));
                }
            }
            ShardSelectorInternal::ShardKey(shard_key) => {
                for shard_id in self.get_shard_ids_by_key(shard_key)? {
                    if let Some(replica_set) = self.shards.get(&shard_id) {
                        res.push((replica_set, Some(shard_key)));
                    } else {
                        debug_assert!(false, "Shard id {shard_id} not found")
                    }
                }
            }
            ShardSelectorInternal::ShardKeys(shard_keys) => {
                for shard_key in shard_keys {
                    for shard_id in self.get_shard_ids_by_key(shard_key)? {
                        if let Some(replica_set) = self.shards.get(&shard_id) {
                            res.push((replica_set, Some(shard_key)));
                        } else {
                            debug_assert!(false, "Shard id {shard_id} not found")
                        }
                    }
                }
            }
            ShardSelectorInternal::ShardId(shard_id) => {
                if let Some(replica_set) = self.shards.get(shard_id) {
                    res.push((replica_set, self.shard_id_to_key_mapping.get(shard_id)));
                } else {
                    return Err(shard_not_found_error(*shard_id));
                }
            }
        }
        Ok(res)
    }

    pub fn len(&self) -> usize {
        self.shards.len()
    }

    pub fn is_empty(&self) -> bool {
        self.shards.is_empty()
    }
    pub async fn load_shards(
        &mut self,
        collection_path: &Path,
        collection_id: &CollectionId,
        collection_config: Arc<RwLock<CollectionConfigInternal>>,
        effective_optimizers_config: OptimizersConfig,
        shared_storage_config: Arc<SharedStorageConfig>,
        payload_index_schema: Arc<SaveOnDisk<PayloadIndexSchema>>,
        channel_service: ChannelService,
        on_peer_failure: ChangePeerFromState,
        abort_shard_transfer: AbortShardTransfer,
        this_peer_id: PeerId,
        update_runtime: Handle,
        search_runtime: Handle,
        optimizer_cpu_budget: CpuBudget,
    ) {
        let shard_number = collection_config.read().await.params.shard_number.get();

        let (shard_ids_list, shard_id_to_key_mapping) = match collection_config
            .read()
            .await
            .params
            .sharding_method
            .unwrap_or_default()
        {
            ShardingMethod::Auto => {
                let ids_list = (0..shard_number).collect::<Vec<_>>();
                let shard_id_to_key_mapping = HashMap::new();
                (ids_list, shard_id_to_key_mapping)
            }
            ShardingMethod::Custom => {
                let shard_id_to_key_mapping = self.get_shard_id_to_key_mapping();
                let ids_list = shard_id_to_key_mapping
                    .keys()
                    .cloned()
                    .sorted()
                    .collect::<Vec<_>>();
                (ids_list, shard_id_to_key_mapping.clone())
            }
        };

        // ToDo: remove after version 0.11.0
        for shard_id in shard_ids_list {
            for (path, _shard_version, shard_type) in
                latest_shard_paths(collection_path, shard_id).await.unwrap()
            {
                let replica_set = ShardReplicaSet::load(
                    shard_id,
                    collection_id.clone(),
                    &path,
                    collection_config.clone(),
                    effective_optimizers_config.clone(),
                    shared_storage_config.clone(),
                    payload_index_schema.clone(),
                    channel_service.clone(),
                    on_peer_failure.clone(),
                    abort_shard_transfer.clone(),
                    this_peer_id,
                    update_runtime.clone(),
                    search_runtime.clone(),
                    optimizer_cpu_budget.clone(),
                )
                .await;

                let mut require_migration = true;
                match shard_type {
                    ShardType::Local => {
                        // deprecated
                        let local_shard = LocalShard::load(
                            shard_id,
                            collection_id.clone(),
                            &path,
                            collection_config.clone(),
                            effective_optimizers_config.clone(),
                            shared_storage_config.clone(),
                            payload_index_schema.clone(),
                            update_runtime.clone(),
                            search_runtime.clone(),
                            optimizer_cpu_budget.clone(),
                        )
                        .await
                        .unwrap();
                        replica_set
                            .set_local(local_shard, Some(ReplicaState::Active))
                            .await
                            .unwrap();
                    }
                    ShardType::Remote { peer_id } => {
                        // deprecated
                        replica_set
                            .add_remote(peer_id, ReplicaState::Active)
                            .await
                            .unwrap();
                    }
                    ShardType::Temporary => {
                        // deprecated
                        let temp_shard = LocalShard::load(
                            shard_id,
                            collection_id.clone(),
                            &path,
                            collection_config.clone(),
                            effective_optimizers_config.clone(),
                            shared_storage_config.clone(),
                            payload_index_schema.clone(),
                            update_runtime.clone(),
                            search_runtime.clone(),
                            optimizer_cpu_budget.clone(),
                        )
                        .await
                        .unwrap();
                        replica_set
                            .set_local(temp_shard, Some(ReplicaState::Partial))
                            .await
                            .unwrap();
                    }
                    ShardType::ReplicaSet => {
                        require_migration = false;
                        // nothing to do, the replica set should already be loaded
                    }
                }

                // Migrate the shard config to a replica set,
                // overriding the existing shard configuration
                if require_migration {
                    ShardConfig::new_replica_set()
                        .save(&path)
                        .map_err(|e| panic!("Failed to save shard config {path:?}: {e}"))
                        .unwrap();
                }

                // Change local shards stuck in Initializing state to Active
                let local_peer_id = replica_set.this_peer_id();
                let not_distributed = !shared_storage_config.is_distributed;
                let is_local =
                    replica_set.this_peer_id() == local_peer_id && replica_set.is_local().await;
                let is_initializing =
                    replica_set.peer_state(local_peer_id) == Some(ReplicaState::Initializing);
                if not_distributed && is_local && is_initializing {
                    log::warn!(
                        "Local shard {collection_id}:{} stuck in Initializing state, changing to Active",
                        replica_set.shard_id,
                    );
                    replica_set
                        .set_replica_state(local_peer_id, ReplicaState::Active)
                        .expect("Failed to set local shard state");
                }

                let shard_key = shard_id_to_key_mapping.get(&shard_id).cloned();
                self.add_shard(shard_id, replica_set, shard_key).unwrap();
            }
        }

        // If resharding is active, rebuild the hash rings, because the plain
        // rings built while adding shards don't reflect the resharding state
        if self.resharding_state.read().is_some() {
            self.rebuild_rings();
        }
    }
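
    // A hedged sketch (added) of the load sequence at collection startup;
    // the call site and the elided arguments are assumed, not from this file:
    //
    //     let mut holder = ShardHolder::new(collection_path)?;
    //     holder
    //         .load_shards(collection_path, &collection_id, config, /* ... */)
    //         .await;
    //     let shard_holder: Arc<LockedShardHolder> = Arc::new(RwLock::new(holder));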
    pub fn assert_shard_exists(&self, shard_id: ShardId) -> CollectionResult<()> {
        match self.get_shard(shard_id) {
            Some(_) => Ok(()),
            None => Err(shard_not_found_error(shard_id)),
        }
    }

    async fn assert_shard_is_local(&self, shard_id: ShardId) -> CollectionResult<()> {
        let is_local_shard = self
            .is_shard_local(shard_id)
            .await
            .ok_or_else(|| shard_not_found_error(shard_id))?;
        if is_local_shard {
            Ok(())
        } else {
            Err(CollectionError::bad_input(format!(
                "Shard {shard_id} is not a local shard"
            )))
        }
    }

    async fn assert_shard_is_local_or_queue_proxy(
        &self,
        shard_id: ShardId,
    ) -> CollectionResult<()> {
        let is_local_shard = self
            .is_shard_local_or_queue_proxy(shard_id)
            .await
            .ok_or_else(|| shard_not_found_error(shard_id))?;
        if is_local_shard {
            Ok(())
        } else {
            Err(CollectionError::bad_input(format!(
                "Shard {shard_id} is not a local or queue proxy shard"
            )))
        }
    }

    /// Returns true if the shard is explicitly local, false otherwise.
    pub async fn is_shard_local(&self, shard_id: ShardId) -> Option<bool> {
        match self.get_shard(shard_id) {
            Some(shard) => Some(shard.is_local().await),
            None => None,
        }
    }

    /// Returns true if the shard is explicitly local or is a queue proxy shard, false otherwise.
    pub async fn is_shard_local_or_queue_proxy(&self, shard_id: ShardId) -> Option<bool> {
        match self.get_shard(shard_id) {
            Some(shard) => Some(shard.is_local().await || shard.is_queue_proxy().await),
            None => None,
        }
    }

    /// Return a list of local shards, present on this peer
    pub async fn get_local_shards(&self) -> Vec<ShardId> {
        let mut res = Vec::with_capacity(1);
        for (shard_id, replica_set) in self.get_shards() {
            if replica_set.has_local_shard().await {
                res.push(shard_id);
            }
        }
        res
    }

    /// Count how many shard replicas are on the given peer.
    pub fn count_peer_shards(&self, peer_id: PeerId) -> usize {
        self.get_shards()
            .filter(|(_, replica_set)| replica_set.peer_state(peer_id).is_some())
            .count()
    }

    pub fn check_transfer_exists(&self, transfer_key: &ShardTransferKey) -> bool {
        self.shard_transfers
            .read()
            .iter()
            .any(|transfer| transfer_key.check(transfer))
    }

    pub fn get_transfer(&self, transfer_key: &ShardTransferKey) -> Option<ShardTransfer> {
        self.shard_transfers
            .read()
            .iter()
            .find(|transfer| transfer_key.check(transfer))
            .cloned()
    }

    pub fn get_transfers<F>(&self, mut predicate: F) -> Vec<ShardTransfer>
    where
        F: FnMut(&ShardTransfer) -> bool,
    {
        self.shard_transfers
            .read()
            .iter()
            .filter(|&transfer| predicate(transfer))
            .cloned()
            .collect()
    }
    pub fn get_outgoing_transfers(&self, current_peer_id: PeerId) -> Vec<ShardTransfer> {
        self.get_transfers(|transfer| transfer.from == current_peer_id)
    }

    /// # Cancel safety
    ///
    /// This method is cancel safe.
    pub async fn list_shard_snapshots(
        &self,
        snapshots_path: &Path,
        shard_id: ShardId,
    ) -> CollectionResult<Vec<SnapshotDescription>> {
        self.assert_shard_is_local(shard_id).await?;

        let snapshots_path = Self::snapshots_path_for_shard_unchecked(snapshots_path, shard_id);
        let shard = self
            .get_shard(shard_id)
            .ok_or_else(|| shard_not_found_error(shard_id))?;
        let snapshot_manager = shard.get_snapshots_storage_manager()?;
        snapshot_manager.list_snapshots(&snapshots_path).await
    }
    /// # Cancel safety
    ///
    /// This method is cancel safe.
    pub async fn create_shard_snapshot(
        &self,
        snapshots_path: &Path,
        collection_name: &str,
        shard_id: ShardId,
        temp_dir: &Path,
    ) -> CollectionResult<SnapshotDescription> {
        // - `snapshot_temp_dir` and `temp_file` are managed by `tempfile`
        //   and will be deleted if the future is cancelled

        let shard = self
            .get_shard(shard_id)
            .ok_or_else(|| shard_not_found_error(shard_id))?;

        if !shard.is_local().await && !shard.is_queue_proxy().await {
            return Err(CollectionError::bad_input(format!(
                "Shard {shard_id} is not a local or queue proxy shard"
            )));
        }

        let snapshot_file_name = format!(
            "{collection_name}-shard-{shard_id}-{}.snapshot",
            chrono::Utc::now().format("%Y-%m-%d-%H-%M-%S"),
        );

        let snapshot_temp_dir = tempfile::Builder::new()
            .prefix(&format!("{snapshot_file_name}-temp-"))
            .tempdir_in(temp_dir)?;

        let temp_file = tempfile::Builder::new()
            .prefix(&format!("{snapshot_file_name}-"))
            .suffix(".tar")
            .tempfile_in(temp_dir)?;

        let tar = BuilderExt::new_seekable_owned(File::create(temp_file.path())?);

        shard
            .create_snapshot(
                snapshot_temp_dir.path(),
                &tar,
                SnapshotFormat::Regular,
                false,
            )
            .await?;

        let snapshot_temp_dir_path = snapshot_temp_dir.path().to_path_buf();
        if let Err(err) = snapshot_temp_dir.close() {
            log::error!(
                "Failed to remove temporary directory {}: {err}",
                snapshot_temp_dir_path.display(),
            );
        }

        tar.finish().await?;

        let snapshot_path =
            Self::shard_snapshot_path_unchecked(snapshots_path, shard_id, snapshot_file_name)?;

        let snapshot_manager = shard.get_snapshots_storage_manager()?;
        let snapshot_description = snapshot_manager
            .store_file(temp_file.path(), &snapshot_path)
            .await;
        if snapshot_description.is_ok() {
            // On success the temp file has been consumed by the storage manager,
            // so don't let `tempfile` try to delete it on drop
            let _ = temp_file.keep();
        }
        snapshot_description
    }
    /// # Cancel safety
    ///
    /// This method is cancel safe.
    pub async fn stream_shard_snapshot(
        shard: OwnedRwLockReadGuard<ShardHolder, ShardReplicaSet>,
        collection_name: &str,
        shard_id: ShardId,
        temp_dir: &Path,
    ) -> CollectionResult<SnapshotStream> {
        // - `snapshot_temp_dir` is managed by `tempfile`
        //   and will be deleted if the future is cancelled

        if !shard.is_local().await && !shard.is_queue_proxy().await {
            return Err(CollectionError::bad_input(format!(
                "Shard {shard_id} is not a local or queue proxy shard"
            )));
        }

        let snapshot_file_name = format!(
            "{collection_name}-shard-{shard_id}-{}.snapshot",
            chrono::Utc::now().format("%Y-%m-%d-%H-%M-%S"),
        );

        let snapshot_temp_dir = tempfile::Builder::new()
            .prefix(&format!("{snapshot_file_name}-temp-"))
            .tempdir_in(temp_dir)?;

        let (read_half, write_half) = tokio::io::duplex(4096);

        tokio::spawn(async move {
            let tar = BuilderExt::new_streaming_owned(SyncIoBridge::new(write_half));

            shard
                .create_snapshot(
                    snapshot_temp_dir.path(),
                    &tar,
                    SnapshotFormat::Streamable,
                    false,
                )
                .await?;

            let snapshot_temp_dir_path = snapshot_temp_dir.path().to_path_buf();
            if let Err(err) = snapshot_temp_dir.close() {
                log::error!(
                    "Failed to remove temporary directory {}: {err}",
                    snapshot_temp_dir_path.display(),
                );
            }

            tar.finish().await?;
            CollectionResult::Ok(())
        });

        Ok(SnapshotStream::new_stream(
            FramedRead::new(read_half, BytesCodec::new()).map_ok(|bytes| bytes.freeze()),
            Some(snapshot_file_name),
        ))
    }
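
    // Design note (added): the snapshot tar is produced in a spawned task that
    // writes into one half of an in-memory `tokio::io::duplex` pipe, while the
    // returned `SnapshotStream` reads from the other half. The small
    // (4096-byte) pipe buffer provides backpressure: snapshot production
    // stalls whenever the consumer reads more slowly than the tar is written.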
    /// # Cancel safety
    ///
    /// This method is *not* cancel safe.
    pub async fn restore_shard_snapshot(
        &self,
        snapshot_path: &Path,
        collection_name: &str,
        shard_id: ShardId,
        this_peer_id: PeerId,
        is_distributed: bool,
        temp_dir: &Path,
        cancel: cancel::CancellationToken,
    ) -> CollectionResult<()> {
        if !self.contains_shard(shard_id) {
            return Err(shard_not_found_error(shard_id));
        }

        if !temp_dir.exists() {
            std::fs::create_dir_all(temp_dir)?;
        }

        let snapshot_file_name = snapshot_path.file_name().unwrap().to_string_lossy();
        let snapshot_path = snapshot_path.to_path_buf();
        let snapshot_temp_dir = tempfile::Builder::new()
            .prefix(&format!(
                "{collection_name}-shard-{shard_id}-{snapshot_file_name}"
            ))
            .tempdir_in(temp_dir)?;

        let task = {
            let snapshot_temp_dir = snapshot_temp_dir.path().to_path_buf();
            cancel::blocking::spawn_cancel_on_token(
                cancel.child_token(),
                move |cancel| -> CollectionResult<_> {
                    let mut tar = open_snapshot_archive_with_validation(&snapshot_path)?;

                    if cancel.is_cancelled() {
                        return Err(cancel::Error::Cancelled.into());
                    }

                    tar.unpack(&snapshot_temp_dir)?;
                    drop(tar);

                    if cancel.is_cancelled() {
                        return Err(cancel::Error::Cancelled.into());
                    }

                    ShardReplicaSet::restore_snapshot(
                        &snapshot_temp_dir,
                        this_peer_id,
                        is_distributed,
                    )?;
                    Ok(())
                },
            )
        };
        task.await??;

        // `ShardHolder::recover_local_shard_from` is *not* cancel safe
        // (see `ShardReplicaSet::restore_local_replica_from`)
        let recovered = self
            .recover_local_shard_from(snapshot_temp_dir.path(), shard_id, cancel)
            .await?;

        if !recovered {
            return Err(CollectionError::bad_request(format!(
                "Invalid snapshot {snapshot_file_name}"
            )));
        }

        Ok(())
    }
    /// # Cancel safety
    ///
    /// This method is *not* cancel safe.
    pub async fn recover_local_shard_from(
        &self,
        snapshot_shard_path: &Path,
        shard_id: ShardId,
        cancel: cancel::CancellationToken,
    ) -> CollectionResult<bool> {
        // TODO:
        //   Check that the shard snapshot is compatible with the collection
        //   (see `VectorsConfig::check_compatible_with_segment_config`)

        let replica_set = self
            .get_shard(shard_id)
            .ok_or_else(|| shard_not_found_error(shard_id))?;

        // `ShardReplicaSet::restore_local_replica_from` is *not* cancel safe
        replica_set
            .restore_local_replica_from(snapshot_shard_path, cancel)
            .await
    }

    /// # Cancel safety
    ///
    /// This method is cancel safe.
    pub async fn get_shard_snapshot_path(
        &self,
        snapshots_path: &Path,
        shard_id: ShardId,
        snapshot_file_name: impl AsRef<Path>,
    ) -> CollectionResult<PathBuf> {
        self.assert_shard_is_local_or_queue_proxy(shard_id).await?;
        Self::shard_snapshot_path_unchecked(snapshots_path, shard_id, snapshot_file_name)
    }

    fn snapshots_path_for_shard_unchecked(snapshots_path: &Path, shard_id: ShardId) -> PathBuf {
        // E.g. `<snapshots_path>/shards/3` for shard 3
        snapshots_path.join(format!("shards/{shard_id}"))
    }

    fn shard_snapshot_path_unchecked(
        snapshots_path: &Path,
        shard_id: ShardId,
        snapshot_file_name: impl AsRef<Path>,
    ) -> CollectionResult<PathBuf> {
        let snapshots_path = Self::snapshots_path_for_shard_unchecked(snapshots_path, shard_id);

        let snapshot_file_name = snapshot_file_name.as_ref();

        // Reject anything that isn't a bare file name (e.g. paths containing
        // separators or `..`), so callers can't escape the snapshots directory
        if snapshot_file_name.file_name() != Some(snapshot_file_name.as_os_str()) {
            return Err(CollectionError::not_found(format!(
                "Snapshot {}",
                snapshot_file_name.display(),
            )));
        }

        let snapshot_path = snapshots_path.join(snapshot_file_name);
        Ok(snapshot_path)
    }
    pub async fn remove_shards_at_peer(&self, peer_id: PeerId) -> CollectionResult<()> {
        for (_shard_id, replica_set) in self.get_shards() {
            replica_set.remove_peer(peer_id).await?;
        }
        Ok(())
    }
}

// `Clone` is required because these events are fanned out over a
// `tokio::sync::broadcast` channel
#[derive(Clone, Debug)]
pub(crate) enum ShardTransferChange {
    Start(ShardTransfer),
    Finish(ShardTransferKey),
    Abort(ShardTransferKey),
}

pub(crate) fn shard_not_found_error(shard_id: ShardId) -> CollectionError {
    CollectionError::NotFound {
        what: format!("shard {shard_id}"),
    }
}
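
#[cfg(test)]
mod shard_holder_sketch_tests {
    use super::*;

    // A minimal sketch (added, not part of the original file), assuming
    // `CollectionError` implements `Display` as elsewhere in the codebase:
    // the not-found error should mention the missing shard id.
    #[test]
    fn shard_not_found_error_mentions_shard_id() {
        let err = shard_not_found_error(42);
        assert!(err.to_string().contains("42"));
    }
}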