Medresearch / components /data_utils.py
mgbam's picture
Add untracked files and synchronize with remote
9c7387c
raw
history blame contribute delete
666 Bytes
def partition_data(dataset, num_clients):
    """
    Partition a dataset into `num_clients` contiguous, near-equal subsets.

    This is just a placeholder. Implement a more sophisticated partitioning
    strategy (e.g., based on medical specialty, patient demographics) for a
    real application.

    Args:
        dataset: A sequence supporting `len()` and slice indexing.
        num_clients: Number of subsets to produce; must be at least 1.

    Returns:
        A list of `num_clients` slices of `dataset`, in original order. The
        first `len(dataset) % num_clients` slices each hold one extra item,
        so slice sizes differ by at most one. When `num_clients` exceeds
        `len(dataset)`, the trailing slices are empty.

    Raises:
        ValueError: If `num_clients` is less than 1.
    """
    # Without this guard, 0 fails with an opaque ZeroDivisionError and a
    # negative count silently returns [] (dropping every record).
    if num_clients < 1:
        raise ValueError(
            f"num_clients must be a positive integer, got {num_clients!r}"
        )

    base_size, extra = divmod(len(dataset), num_clients)

    partitions = []
    start = 0
    for client in range(num_clients):
        # The first `extra` clients absorb the remainder, one item each.
        stop = start + base_size + (1 if client < extra else 0)
        partitions.append(dataset[start:stop])
        start = stop
    return partitions