Source code for src.pipeline.train

from ultralytics import YOLO
from pathlib import Path
from datetime import datetime
import json

from pathlib import Path

def find_latest_model(model_dir: str, fallback_model: str) -> str:
    """
    Pick the newest ``*_updated_yolo.pt`` checkpoint from a directory.

    Candidates are ranked by filename stem in descending lexicographic order,
    so the "latest" model is the one whose name sorts last (which matches the
    latest date when the filenames begin with a sortable date/timestamp).

    Args:
        model_dir (str): Directory containing YOLO model `.pt` files.
        fallback_model (str): Path handed back when no matching file exists.

    Returns:
        str: Path to the highest-sorting matching model, or `fallback_model`
        unchanged when nothing matches.
    """
    candidates = list(Path(model_dir).glob("*_updated_yolo.pt"))

    # Guard clause: nothing to rank, so warn and hand back the fallback.
    if not candidates:
        print(f"[WARN] Updated model not found. Falling back to: {fallback_model}")
        return fallback_model

    # Highest stem first; the head of the list is the chosen model.
    candidates.sort(key=lambda candidate: candidate.stem, reverse=True)
    return str(candidates[0])
def load_train_config(config_path: str) -> dict:
    """
    Load the training configuration from a JSON file and validate it.

    The file must contain the top-level keys ``training_config``,
    ``data_yaml_path`` and ``initial_model_path``.

    Args:
        config_path (str): Path to the train_config.json file.

    Returns:
        dict: The parsed configuration dictionary.

    Raises:
        FileNotFoundError: If no file exists at `config_path`.
        AssertionError: If one of the required top-level keys is missing.
    """
    config_path = Path(config_path)

    # Check if file exists
    if not config_path.exists():
        raise FileNotFoundError(f"Error: Config file not found at {config_path}")

    # Load JSON into dictionary (explicit encoding: do not depend on the locale default)
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Validate with explicit raises rather than `assert` statements: asserts are
    # removed when Python runs with -O, which would let an incomplete config through.
    # AssertionError is kept as the exception type so existing callers are unaffected.
    for section in ("training_config", "data_yaml_path", "initial_model_path"):
        if section not in config:
            raise AssertionError(
                f"Error: '{section}' section missing in train_config.json"
            )

    return config
def train_model(config: dict) -> str:
    """
    Trains a YOLO model using the Ultralytics library and saves the trained
    model and its metadata.

    The starting weights are `config["model_path"]` when set; otherwise the
    result of `find_latest_model` on the model registry directory, with
    `initial_model_path` as the fallback.

    Args:
        config (dict): Loaded config dictionary from train_config.json, containing:
            - data_yaml_path (str): Path to `data.yaml`.
            - training_config (dict): Keyword arguments forwarded to
              `model.train` (e.g. epochs, lr0, imgsz, batch, workers). This
              dictionary is copied and is not modified by this function.
            - torch_device (str): (Optional) Device to train on; defaults to 'cpu'.
            - model_path (str): (Optional) Path to a pre-trained model to fine-tune.
            - initial_model_path (str): (Optional) Fallback model used when no
              `*_updated_yolo.pt` file exists in the model directory.
            - output_model_name (str): (Optional) File name for the saved
              model; defaults to `<timestamp>_updated_yolo.pt`.

    Returns:
        str: Path to the saved trained model (.pt)

    Raises:
        KeyError: If `training_config` or `data_yaml_path` is missing from `config`.
    """
    # Define paths
    model_dir = Path("automl_workspace/model_registry/model")
    model_dir.mkdir(parents=True, exist_ok=True)

    user_model_path = config.get("model_path")
    initial_model_path = config.get(
        "initial_model_path",
        "automl_workspace/model_registry/model/nano_trained_model.pt",
    )

    if user_model_path:
        model_path = user_model_path
        print(f"[INFO] Using model specified in config: {model_path}")
    else:
        # find_latest_model is annotated to take a str directory.
        model_path = find_latest_model(str(model_dir), initial_model_path)
        print(f"[INFO] Using latest model: {model_path}")

    # Load YOLO model
    model = YOLO(model_path)

    # Extract training parameters. Work on a shallow copy so the keys added
    # below (data/device/project/name) do not leak into the caller's config.
    train_args = dict(config["training_config"])
    train_args["data"] = config["data_yaml_path"]
    train_args["device"] = config.get("torch_device", "cpu")

    # Generate a timestamped name if user did not specify one
    timestamp = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
    output_model_name = config.get("output_model_name") or f"{timestamp}_updated_yolo.pt"
    trained_model_path = model_dir / output_model_name

    # Strip only a trailing ".pt" extension; a plain str.replace would also
    # remove ".pt" occurring elsewhere in the name.
    if output_model_name.endswith(".pt"):
        model_info_name = output_model_name[: -len(".pt")]
    else:
        model_info_name = output_model_name

    # Define metadata and runs output directory
    model_info_dir = model_dir / "model_info" / model_info_name
    model_info_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = model_info_dir / "metadata.json"
    run_output_dir = model_info_dir / "runs"
    run_output_dir.mkdir(parents=True, exist_ok=True)

    # Send the Ultralytics run output to <model_info_dir>/train
    train_args["project"] = str(model_info_dir)
    train_args["name"] = "train"

    # Run training with all arguments from config
    model.train(**train_args)

    # Save trained model
    model.save(str(trained_model_path))

    # Save metadata with training info
    metadata = {
        "model_name": output_model_name,
        "trained_from": model_path,
        "timestamp": timestamp,
        "training_args": train_args,
    }
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    # Evaluate on test set if defined in data.yaml. This step is best-effort:
    # any failure is reported as a warning and the metadata written above is kept.
    try:
        test_results = model.val(split='test')
        # Coerce to plain floats before touching the file so the rewrite of
        # metadata.json below cannot fail halfway on a non-serializable value.
        test_metrics = {
            "map_50": float(test_results.box.map50),
            "map_75": float(test_results.box.map75),
            "map_50_95": float(test_results.box.map),
        }
        metadata["test_metrics"] = test_metrics
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
    except Exception as e:
        print(f"[WARN] Test evaluation failed or skipped: {e}")

    print(f"[INFO] Training complete. Model saved to {trained_model_path}")
    print(f"[INFO] Metadata saved to {metadata_path}")

    return str(trained_model_path)
# Script entry point: when this module is executed directly, read
# "train_config.json" from the current working directory and run one training job.
if __name__ == "__main__":
    config = load_train_config("train_config.json")
    train_model(config)