Nawinkumar15 commited on
Commit
4721e1c
·
verified ·
1 Parent(s): fd64259

Upload train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +74 -0
train_model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from ultralytics import YOLO
3
+
4
+ # --- Configuration ---
5
+ # Path to your data.yaml file.
6
+ DATA_YAML_PATH = "C:\\Users\\Veera\\Downloads\\My First Project.v1i.yolov8\\data.yaml"
7
+
8
+ # Choose a pre-trained YOLOv8 model to start with.
9
+ PRETRAINED_MODEL = "yolov8n.pt"
10
+
11
+ # Training parameters
12
+ EPOCHS = 50 # Number of training epochs. Adjust based on your dataset size and desired accuracy.
13
+ IMG_SIZE = 640 # Image size for training (as per Roboflow preprocessing).
14
+ BATCH_SIZE = 8 # Reduced batch size since you don't have a GPU. You might need to lower it further if you encounter memory issues.
15
+ PROJECT_NAME = "Model_trained" # Name of the project directory where results will be saved
16
+ RUN_NAME = "best.pt" # Name of the specific run within the project directory
17
+
18
+ # --- Main Training Logic ---
19
+ def train_yolov8_model():
20
+ """
21
+ Trains a YOLOv8 model using the specified dataset and parameters.
22
+ The trained model (best.pt) will be saved in runs/detect/{RUN_NAME}/weights/.
23
+ """
24
+ print(f"Starting YOLOv8 model training with {PRETRAINED_MODEL}...")
25
+
26
+ # 1. Load a pre-trained YOLOv8 model
27
+ try:
28
+ model = YOLO(PRETRAINED_MODEL)
29
+ print(f"Successfully loaded pre-trained model: {PRETRAINED_MODEL}")
30
+ except Exception as e:
31
+ print(f"Error loading pre-trained model: {e}")
32
+ print("Please ensure you have an active internet connection if downloading for the first time.")
33
+ return
34
+
35
+ # 2. Check if data.yaml exists
36
+ if not os.path.exists(DATA_YAML_PATH):
37
+ print(f"Error: data.yaml not found at '{DATA_YAML_PATH}'.")
38
+ print("Please ensure the 'data.yaml' file is in the correct location.")
39
+ return
40
+
41
+ # 3. Train the model
42
+ print(f"Training model on dataset defined in: {DATA_YAML_PATH}")
43
+ print(f"Training for {EPOCHS} epochs with image size {IMG_SIZE} and batch size {BATCH_SIZE} on CPU...")
44
+ print("Training on CPU will be significantly slower.")
45
+
46
+ try:
47
+ results = model.train(
48
+ data=DATA_YAML_PATH,
49
+ epochs=EPOCHS,
50
+ imgsz=IMG_SIZE,
51
+ batch=BATCH_SIZE,
52
+ project=PROJECT_NAME,
53
+ name=RUN_NAME
54
+ )
55
+ print("\nTraining completed successfully!")
56
+
57
+ # The best.pt file is typically saved in runs/detect/{RUN_NAME}/weights/best.pt
58
+ output_weights_dir = os.path.join("runs", "detect", RUN_NAME, "weights")
59
+ best_pt_path = os.path.join(output_weights_dir, "best.pt")
60
+
61
+ if os.path.exists(best_pt_path):
62
+ print(f"Your trained model (best.pt) is saved at: {os.path.abspath(best_pt_path)}")
63
+ print("You can now use this .pt file for local inference or upload it to Hugging Face.")
64
+ else:
65
+ print("Warning: 'best.pt' file not found in the expected location after training.")
66
+ print(f"Please check the output directory: {os.path.abspath(output_weights_dir)}")
67
+
68
+ except Exception as e:
69
+ print(f"An error occurred during training: {e}")
70
+ print("Common issues: insufficient CPU memory (try reducing batch_size), incorrect data.yaml paths.")
71
+
72
+ # Run the training function
73
+ if __name__ == "__main__":
74
+ train_yolov8_model()