# Docker Compose definition for the toxic-classifier service.
# Builds the image from the build context in this directory and runs
# model/train.py in a container that uses the NVIDIA runtime.
version: '3.8'

services:
  toxic-classifier:
    build: .
    # Enable the NVIDIA runtime for GPU support.
    runtime: nvidia
    environment:
      # Make all host GPUs visible inside the container.
      - NVIDIA_VISIBLE_DEVICES=all
      # Interpolated by Compose; set WANDB_API_KEY in the .env file.
      - WANDB_API_KEY=${WANDB_API_KEY}
    volumes:
      # Mount the host dataset directory at /app/dataset.
      - ./dataset:/app/dataset
      # Mount the host weights directory at /app/weights.
      - ./weights:/app/weights
    # Default command; can be overridden when the service is run.
    command: python model/train.py