# Python standard library
import json
import os
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Literal
import copy
import csv
import shutil
from datetime import datetime
# Set environment variable for MPS fallback
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
# Third-party libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torchvision.ops.ciou_loss import complete_box_iou_loss
import torch.nn.functional as F
from tqdm import tqdm
# Add parent directory to path
sys.path.append("..")
sys.path.append("../..") # Add this to reach src/
# Ultralytics imports
from ultralytics import YOLO
from ultralytics.utils import YAML
from ultralytics.models.yolo.model import DetectionModel
from ultralytics.cfg import get_cfg
from ultralytics.utils.loss import v8DetectionLoss, BboxLoss
from ultralytics.data.build import build_yolo_dataset, build_dataloader, YOLODataset
from ultralytics.models.yolo.detect.train import DetectionTrainer
from ultralytics.models.yolo.detect.val import DetectionValidator
from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.tal import make_anchors
from ultralytics.utils.ops import non_max_suppression
from ultralytics.utils.torch_utils import one_cycle, EarlyStopping, ModelEMA
# Custom modules
from utils import load_config, detect_device, create_distill_yaml
SCRIPT_DIR = Path(__file__).parent
def load_models(device: str, base_dir: Path, distillation_config: Dict[str, Any]) -> Tuple[YOLO, YOLO]:
"""
Load teacher and student models.
Args:
device: Device to load models on
base_dir: Base directory for model paths
distillation_config: Configuration dictionary for distillation
Returns:
Tuple of (teacher_yolo, student_yolo) models
"""
# Load the teacher model (our pre-trained model)
teacher_yolo = YOLO(
distillation_config["teacher_model"],
).to(device)
# Load the student model (our new model, random initialized weights)
student_yolo = (
YOLO(base_dir / "pipeline/distillation/yolov8n-5class.yaml")
.load(base_dir / "pipeline/distillation/yolov8n.pt")
).to(device)
student_yolo.yaml["nc"] = 5
student_model = student_yolo.model
student_model.nc = 5
# Set model args from distillation config
student_model.args = get_cfg(distillation_config)
# Sanity check
assert student_yolo.nc == 5, "student_yolo.nc should be 5"
assert student_model.nc == 5, "student_model.nc should be 5"
assert isinstance(teacher_yolo, nn.Module)
assert isinstance(student_yolo, nn.Module)
return teacher_yolo, student_yolo
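# Example (illustrative sketch, kept as comments): how load_models might be called.
# The config keys and paths below are assumptions for illustration, not values shipped with this module.
#
# example_config = {
#     "teacher_model": "automl_workspace/model_registry/teacher/best.pt",  # assumed path
#     "epochs": 100,
#     "batch": 16,
# }
# teacher_yolo, student_yolo = load_models(
#     device=detect_device(),
#     base_dir=SCRIPT_DIR.parent.parent,  # assumed repository layout
#     distillation_config=example_config,
# )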
def prepare_dataset(img_path: Path, student_model: nn.Module, batch_size: int = 16, mode: str = "train") -> Tuple[YOLODataset, DataLoader]:
"""
Prepare dataset and dataloader for training.
Notes:
Let num_labels be the total number of annotated objects across all images in the batch,
and batch_size the number of images in the batch.
Each batch yielded by the dataloader contains:
- batch_idx: tensor of shape (num_labels,); each value in [0, batch_size - 1] is the index
of the image within the batch that the object belongs to
- img: image tensor of shape (batch_size, 3, 640, 640)
- bboxes: tensor of shape (num_labels, 4) with normalized xywh boxes (center-x, center-y, width, height)
- cls: tensor of shape (num_labels, 1) containing the class label of every object in the batch
- resized_shape: resized 2D shape of the image; a list of tensors, first tensor is the first dim, second tensor is the second dim
- ori_shape: original 2D shape of the image; a list of tensors, first tensor is the first dim, second tensor is the second dim
(see the commented example after this function)
Args:
img_path: Directory containing images
student_model: Student model instance
batch_size: Batch size for training
mode: Dataset mode ("train" or "val")
Returns:
Tuple of (dataset, dataloader)
"""
data = {
"names": {
0: "FireBSI",
1: "LightningBSI",
2: "PersonBSI",
3: "SmokeBSI",
4: "VehicleBSI"
},
"channels": 3,
}
train_dataset = build_yolo_dataset(
cfg=student_model.args,
img_path = img_path,
batch=batch_size,
data=data,
mode=mode,
)
train_dataloader = build_dataloader(
train_dataset,
batch=batch_size,
workers=0,
shuffle=False,
)
return train_dataset, train_dataloader
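# Example (illustrative sketch, kept as comments): iterating the dataloader returned above and
# inspecting the batch keys described in the docstring. It assumes student_model comes from
# load_models above (so that its .args are set) and that the directory path exists; shapes assume
# 640x640 training images.
#
# train_dataset, train_dataloader = prepare_dataset(
#     img_path=Path("dataset/train"),  # assumed directory
#     student_model=student_model,
#     batch_size=16,
# )
# for batch in train_dataloader:
#     print(batch["img"].shape)        # (16, 3, 640, 640)
#     print(batch["batch_idx"].shape)  # (num_labels,) image index per annotated object
#     print(batch["bboxes"].shape)     # (num_labels, 4) normalized xywh
#     print(batch["cls"].shape)        # (num_labels, 1)
#     break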
def head_features_decoder(
head_feats: List[torch.Tensor],
nc: int,
detection_criterion: v8DetectionLoss,
reg_max: int = 16,
strides: List[int] = [8, 16, 32],
device: str = "cpu"
) -> torch.Tensor:
"""
Decode the head features into bounding boxes and class scores.
Args:
head_feats: List of tensors, each representing a feature map from a detection head
nc: Number of classes
detection_criterion: Detection loss criterion
reg_max: Number of DFL bins per bounding box side (the head predicts reg_max * 4 distribution values per anchor)
strides: List of strides for the feature maps
device: Device to perform computations on
Returns:
pred_concatted: Decoded bounding boxes concatenated with raw class logits,
of shape (batch_size, 4 + num_classes, total_predictions); see the shape sketch after this function
"""
b = head_feats[0].shape[0] # batch size
dfl_vals = reg_max * 4 # number of dfl encoded channels for bounding boxes
no = nc + dfl_vals # number of out channels
pred_dist, pred_scores = torch.cat(
[feat.view(b, no, -1) for feat in head_feats], dim=2
).split(
(dfl_vals, nc), dim=1
)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_dist = pred_dist.permute(0, 2, 1).contiguous()
anchor_points, _ = make_anchors(head_feats, strides, 0.5)
anchor_points = anchor_points.to(device)
pred_bboxes = detection_criterion.bbox_decode(anchor_points, pred_dist)
# heuristic sanity check: raw logits are expected to have values outside [0, 1]
assert torch.any(pred_scores < 0) or torch.any(pred_scores > 1), "pred_scores should be logits, not sigmoid"
pred_concatted = torch.permute(torch.cat((pred_bboxes, pred_scores), dim=2), (0, 2, 1))
return pred_concatted
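# Shape sketch (illustrative, kept as comments): for a 640x640 input and strides [8, 16, 32],
# the three head feature maps are 80x80, 40x40 and 20x20, i.e. 6400 + 1600 + 400 = 8400 predictions.
# With nc = 5 and reg_max = 16, each feature map has no = 5 + 4 * 16 = 69 channels, and the decoder
# returns a tensor of shape (batch_size, 4 + nc, 8400) = (batch_size, 9, 8400). This assumes
# detection_criterion = v8DetectionLoss(model=student_model) has already been built for the 5-class student.
#
# head_feats = [
#     torch.randn(2, 69, 80, 80),
#     torch.randn(2, 69, 40, 40),
#     torch.randn(2, 69, 20, 20),
# ]
# decoded = head_features_decoder(head_feats, nc=5, detection_criterion=detection_criterion)
# assert decoded.shape == (2, 9, 8400)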
## This is the original distillation loss function, but it is not used in the current implementation
## Distillation loss with NMS preprocessing
# def compute_distillation_loss(
# student_preds: torch.Tensor,
# teacher_preds: torch.Tensor,
# args: Dict[str, Any],
# nc: int = 80,
# device: str = "cpu",
# eps: float = 1e-7,
# reduction: Literal["batchmean", "sum"] = "batchmean",
# hyperparams: Dict[str, float] = {
# "lambda_dist_ciou": 1.0,
# "lambda_dist_kl": 2.0,
# "temperature": 2.0
# }
# ) -> torch.Tensor:
# """
# Compute the distillation loss between the student and teacher predictions.
# Args:
# student_preds: The student predictions
# teacher_preds: The teacher predictions
# args: Configuration arguments
# nc: Number of classes
# device: Device to perform computations on
# eps: Small epsilon value for numerical stability
# reduction: Reduction method for the loss ("batchmean" or "sum")
# hyperparams: Dictionary of hyperparameters for loss functions
# Returns:
# Total distillation loss
# """
# if not isinstance(student_preds, torch.Tensor) or not isinstance(teacher_preds, torch.Tensor):
# raise ValueError("student_preds and teacher_preds must be tensors")
# batch_size = student_preds.shape[0]
# dtype = teacher_preds.dtype
# kldivloss = nn.KLDivLoss(reduction=reduction, log_target=True)
# # eps = torch.finfo().eps
# eps = 1e-7
# # Split the concatenated bounding boxes and class scores
# s_bbox, s_cls_logits = torch.split(student_preds, (4, nc), dim=1)
# s_cls_sigmoid = torch.sigmoid(s_cls_logits)
# assert torch.all(s_cls_sigmoid >= 0) and torch.all(s_cls_sigmoid <= 1), "s_cls_sigmoid should be sigmoid, not logits"
# teacher_preds_full = teacher_preds.clone()
# assert torch.all(teacher_preds_full[:, 4:, :] >= 0) and torch.all(teacher_preds_full[:, 4:, :] <= 1), "teacher_preds_full should be sigmoid, not logits"
# student_preds_full = torch.cat((s_bbox, s_cls_sigmoid), dim=1)
# common_nms_args = {
# "conf_thres": args.get("conf", 0.25) if args.get("conf") else 0.25,
# "iou_thres": args.get("iou", 0.7) if args.get("iou") else 0.7,
# "classes": args.get("classes", None),
# "agnostic": args.get("agnostic_nms", False) if args.get("agnostic_nms") else False,
# "max_det": args.get("max_det", 300) if args else 300,
# "nc": 0,
# "return_idxs": True,
# "max_time_img": 1
# }
# _, teacher_preds_final_idxs = non_max_suppression(
# prediction=teacher_preds_full,
# **common_nms_args,
# )
# selected_student_raw_predictions_list = []
# selected_teacher_raw_predictions_list = []
# for i in range(batch_size):
# student_preds_for_image_i = student_preds_full[i, ...].transpose(0, 1)
# teacher_preds_for_image_i = teacher_preds_full[i, ...].transpose(0, 1)
# indices_to_select = teacher_preds_final_idxs[i]
# if indices_to_select.numel() > 0:
# selected_student_preds = student_preds_for_image_i[indices_to_select]
# selected_teacher_preds = teacher_preds_for_image_i[indices_to_select]
# selected_student_raw_predictions_list.append(selected_student_preds)
# selected_teacher_raw_predictions_list.append(selected_teacher_preds)
# # get the actual batch size (batches with results)
# actual_batch_size = len(selected_student_raw_predictions_list)
# batch_box_regression_loss = torch.zeros(actual_batch_size, dtype=dtype)
# batch_cls_loss = torch.zeros(actual_batch_size, dtype=dtype)
# for i in range(actual_batch_size):
# tp, sp = selected_teacher_raw_predictions_list[i], selected_student_raw_predictions_list[i]
# s_bboxes, s_cls_sigmoid = torch.split(sp, (4, nc), dim=1)
# t_bboxes, t_cls_sigmoid = torch.split(tp, (4, nc), dim=1)
# assert torch.all(s_bboxes[..., 0] <= s_bboxes[..., 2]), "x1 coordinate should be less than x2 coordinate"
# assert torch.all(t_bboxes[..., 0] <= t_bboxes[..., 2]), "x1 coordinate should be less than x2 coordinate"
# assert torch.all(s_bboxes[..., 1] <= s_bboxes[..., 3]), "y1 coordinate should be less than y2 coordinate"
# assert torch.all(t_bboxes[..., 1] <= t_bboxes[..., 3]), "y1 coordinate should be less than y2 coordinate"
# ciou_loss = complete_box_iou_loss(s_bboxes, t_bboxes, reduction="mean")
# s_cls_logit = torch.logit(s_cls_sigmoid, eps=eps)
# s_cls_log_softmax = F.log_softmax(s_cls_logit / hyperparams["temperature"], dim=1)
# t_cls_logit = torch.logit(t_cls_sigmoid, eps=eps)
# t_cls_log_softmax = F.log_softmax(t_cls_logit / hyperparams["temperature"], dim=1)
# kl_div_loss = (
# kldivloss(s_cls_log_softmax, t_cls_log_softmax) *
# (hyperparams["temperature"]**2)
# )
# batch_box_regression_loss[i] = ciou_loss
# batch_cls_loss[i] = kl_div_loss
# if reduction == "batchmean":
# total_loss = (
# hyperparams["lambda_dist_ciou"] * batch_box_regression_loss.mean() +
# hyperparams["lambda_dist_kl"] * batch_cls_loss.mean()
# )
# else:
# total_loss = (
# hyperparams["lambda_dist_ciou"] * batch_box_regression_loss.sum() +
# hyperparams["lambda_dist_kl"] * batch_cls_loss.sum()
# )
# return total_loss
def compute_distillation_loss(
student_preds: torch.Tensor,
teacher_preds: torch.Tensor,
args: Dict[str, Any],
nc: int = 80,
device: str = "cpu",
eps: float = 1e-7,
reduction: Literal["batchmean", "sum"] = "batchmean",
hyperparams: Dict[str, float] = {
"lambda_dist_ciou": 1.0,
"lambda_dist_kl": 1.0,
"temperature": 2.0
}
) -> torch.Tensor:
"""
Compute the distillation loss between student and teacher predictions.
The top-k teacher proposals (ranked by maximum class confidence) are selected for both models,
and the loss is a weighted sum of a CIoU loss on the selected boxes and a temperature-scaled
KL divergence on the selected class distributions.
Args:
student_preds: Student predictions of shape (batch_size, 4 + nc, num_proposals),
decoded boxes followed by raw class logits
teacher_preds: Teacher predictions of the same shape, with sigmoid class scores
args: Configuration arguments (max_det is used to pick k)
nc: Number of classes
device: Device to perform computations on
eps: Small epsilon value for numerical stability
reduction: Reduction method for the loss ("batchmean" or "sum")
hyperparams: Loss weights and the distillation temperature
Returns:
Total distillation loss
"""
if not isinstance(student_preds, torch.Tensor) or not isinstance(teacher_preds, torch.Tensor):
raise ValueError("student_preds and teacher_preds must be tensors")
batch_size, _, num_proposals = student_preds.shape
k = args.get("max_det", 300)
k = min(k, num_proposals)
temp = hyperparams.get("temperature", 2.0)
s_bbox, s_cls_logits = torch.split(student_preds, (4, nc), dim=1)
s_cls_sigmoid = torch.sigmoid(s_cls_logits)
student_preds_full = torch.cat((s_bbox, s_cls_sigmoid), dim=1)
assert torch.all(s_cls_sigmoid >= 0) and torch.all(s_cls_sigmoid <= 1), "s_cls_sigmoid should be sigmoid, not logits"
assert torch.all(teacher_preds[:, 4:, :] >= 0) and torch.all(teacher_preds[:, 4:, :] <= 1), "teacher_preds should be sigmoid, not logits"
t_bbox, t_cls = torch.split(teacher_preds, (4, nc), dim=1)
teacher_conf, _ = t_cls.max(dim=1) # (b, 8400)
_, topk_indices = torch.topk(teacher_conf, k, dim=1) # (b, k)
indices_to_gather = topk_indices.unsqueeze(1).expand(-1, 4 + nc, -1) # (batch_size, 4 + nc, k)
s_preds_topk = torch.gather(student_preds_full, 2, indices_to_gather) # (batch_size, 4 + nc, k)
t_preds_topk = torch.gather(teacher_preds, 2, indices_to_gather) # (batch_size, 4 + nc, k)
s_preds = s_preds_topk.permute(0, 2, 1) # (batch_size, k, 4 + nc)
t_preds = t_preds_topk.permute(0, 2, 1) # (batch_size, k, 4 + nc)
s_bbox, s_cls = torch.split(s_preds, (4, nc), dim=2) # (batch_size, k, 4), (batch_size, k, nc)
t_bbox, t_cls = torch.split(t_preds, (4, nc), dim=2) # (batch_size, k, 4), (batch_size, k, nc)
# ensure the bounding boxes are valid
s_bbox = torch.stack([
torch.min(s_bbox[..., 0], s_bbox[..., 2]),
torch.min(s_bbox[..., 1], s_bbox[..., 3]),
torch.max(s_bbox[..., 0], s_bbox[..., 2]),
torch.max(s_bbox[..., 1], s_bbox[..., 3]),
], dim=-1)
t_bbox = torch.stack([
torch.min(t_bbox[..., 0], t_bbox[..., 2]),
torch.min(t_bbox[..., 1], t_bbox[..., 3]),
torch.max(t_bbox[..., 0], t_bbox[..., 2]),
torch.max(t_bbox[..., 1], t_bbox[..., 3]),
], dim=-1)
assert torch.all(s_bbox[..., 0] <= s_bbox[..., 2]), "x1 coordinate should be <= x2 coordinate"
assert torch.all(t_bbox[..., 0] <= t_bbox[..., 2]), "x1 coordinate should be <= x2 coordinate"
assert torch.all(s_bbox[..., 1] <= s_bbox[..., 3]), "y1 coordinate should be <= y2 coordinate"
assert torch.all(t_bbox[..., 1] <= t_bbox[..., 3]), "y1 coordinate should be <= y2 coordinate"
if reduction == "batchmean":
loss_ciou = complete_box_iou_loss(s_bbox, t_bbox, reduction="mean")
else:
loss_ciou = complete_box_iou_loss(s_bbox, t_bbox, reduction=reduction)
s_cls_logit = torch.logit(s_cls, eps=eps)
t_cls_logit = torch.logit(t_cls, eps=eps)
# softmax over the class dimension (last dim of (batch_size, k, nc)), not over the k proposals
s_log_softmax = F.log_softmax(s_cls_logit / temp, dim=-1)
t_log_softmax = F.log_softmax(t_cls_logit / temp, dim=-1)
kldiv_loss_fn = nn.KLDivLoss(reduction=reduction, log_target=True)
loss_kldiv = kldiv_loss_fn(s_log_softmax, t_log_softmax) * (temp**2)
if reduction == "batchmean":
ciou_term = loss_ciou.mean()
kl_term = loss_kldiv.mean()
elif reduction == "sum":
ciou_term = loss_ciou.sum()
kl_term = loss_kldiv.sum()
else:
raise ValueError(f"Invalid reduction type: {reduction}")
total_loss = (
hyperparams["lambda_dist_ciou"] * ciou_term +
hyperparams["lambda_dist_kl"] * kl_term
)
return total_loss
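# Example (illustrative sketch, kept as comments): calling the loss with random tensors that follow
# the expected layout, i.e. (batch_size, 4 + nc, num_proposals) with decoded boxes in the first 4
# channels, raw logits for the student classes and sigmoid scores for the teacher classes.
# All values below are made up for illustration.
#
# b, nc, n = 2, 5, 8400
# student_preds = torch.cat([torch.rand(b, 4, n) * 640, torch.randn(b, nc, n)], dim=1)
# teacher_preds = torch.cat([torch.rand(b, 4, n) * 640, torch.rand(b, nc, n)], dim=1)
# loss = compute_distillation_loss(
#     student_preds,
#     teacher_preds,
#     args={"max_det": 300},
#     nc=nc,
#     reduction="batchmean",
#     hyperparams={"lambda_dist_ciou": 1.0, "lambda_dist_kl": 1.0, "temperature": 2.0},
# )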
def calculate_gradient_norm(model: nn.Module) -> float:
"""
Calculate the total gradient norm across all parameters.
Args:
model: The model to calculate gradient norm for
Returns:
Total gradient norm as a float
"""
total_norm = 0.0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
return total_norm ** 0.5
def log_training_metrics(
log_file: Path,
epoch: int,
batch_idx: Optional[int],
losses: Dict[str, float],
grad_norm_before: Optional[float] = None,
grad_norm_after: Optional[float] = None,
is_new_file: bool = False,
log_level: Literal["batch", "epoch"] = "epoch"
) -> None:
"""
Log training metrics to a CSV file.
Args:
log_file: Path to the log file
epoch: Current epoch number
batch_idx: Current batch index (None for epoch-level logging)
losses: Dictionary of loss values
grad_norm_before: Gradient norm before clipping
grad_norm_after: Gradient norm after clipping
is_new_file: Whether this is the first write to the file
log_level: Whether to log at batch or epoch level
"""
fieldnames = [
'timestamp', 'epoch', 'batch',
'total_loss', 'bbox_loss', 'cls_loss', 'dfl_loss', 'dist_loss',
'grad_norm_before', 'grad_norm_after'
]
# Prepare the row data
row = {
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'epoch': epoch,
'batch': batch_idx if log_level == "batch" else "epoch",
'total_loss': losses.get('total_loss', ''),
'bbox_loss': losses.get('bbox_loss', ''),
'cls_loss': losses.get('cls_loss', ''),
'dfl_loss': losses.get('dfl_loss', ''),
'dist_loss': losses.get('dist_loss', ''),
'grad_norm_before': grad_norm_before if grad_norm_before is not None else '',
'grad_norm_after': grad_norm_after if grad_norm_after is not None else ''
}
# Write to file
mode = 'w' if is_new_file else 'a'
with open(log_file, mode, newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
if is_new_file:
writer.writeheader()
writer.writerow(row)
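# Example (illustrative sketch, kept as comments): writing one epoch-level row to a CSV log.
# The path and loss values are made up for illustration.
#
# log_training_metrics(
#     log_file=Path("logs/training_log.csv"),
#     epoch=1,
#     batch_idx=None,
#     losses={"total_loss": 5.21, "bbox_loss": 1.10, "cls_loss": 2.05, "dfl_loss": 1.06, "dist_loss": 1.00},
#     is_new_file=True,
#     log_level="epoch",
# )
# # Resulting CSV columns:
# # timestamp,epoch,batch,total_loss,bbox_loss,cls_loss,dfl_loss,dist_loss,grad_norm_before,grad_norm_after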
def train_epoch(
student_model: nn.Module,
teacher_model: nn.Module,
train_dataloader: DataLoader,
detection_trainer: DetectionTrainer,
optimizer: optim.Optimizer,
detection_criterion: v8DetectionLoss,
config_dict: Dict[str, Any],
device: str = "cpu",
nc: int = 5,
hyperparams: Dict[str, float] = {
"lambda_distillation": 2.0,
"lambda_detection": 1.0,
"lambda_dist_ciou": 1.0,
"lambda_dist_kl": 2.0
},
epoch: int = 1,
log_file: Optional[Path] = None,
log_level: Literal["batch", "epoch"] = "batch",
debug: bool = False
) -> Dict[str, float]:
"""
Train for one epoch.
"""
student_model.train()
teacher_model.eval()
batch_loss_dict = {
"total_loss": np.array([]),
"bbox_loss": np.array([]),
"cls_loss": np.array([]),
"dfl_loss": np.array([]),
"distillation_loss": np.array([]),
"grad_norm_before": np.array([]),
"grad_norm_after": np.array([])
}
for batch_idx, batch in enumerate(train_dataloader):
optimizer.zero_grad()
try:
preprocessed_batch = detection_trainer.preprocess_batch(batch)
inputs = preprocessed_batch["img"].to(device)
targets = preprocessed_batch["cls"].to(device)
# Additional validation after preprocessing
if torch.isnan(inputs).any() or torch.isinf(inputs).any():
print(f"NaN/Inf detected in preprocessed images at batch {batch_idx}")
continue
if torch.isnan(targets).any() or torch.isinf(targets).any():
print(f"NaN/Inf detected in preprocessed targets at batch {batch_idx}")
continue
student_head_feats = student_model(inputs)
detection_losses, detection_losses_detached = detection_criterion(preds=student_head_feats, batch=batch)
bbox_loss, cls_loss, dfl_loss = detection_losses_detached.cpu()
with torch.no_grad():
# preprocess_batch mutates the batch dict in place, so batch["img"] here is already the
# normalized float tensor that was fed to the student
teacher_inputs = batch["img"].to(device)
teacher_preds, _ = teacher_model(teacher_inputs)
teacher_preds = teacher_preds.to(device)
student_preds = head_features_decoder(
head_feats=student_head_feats,
nc=nc,
detection_criterion=detection_criterion,
device=device
).to(device)
distillation_loss = compute_distillation_loss(
student_preds,
teacher_preds,
config_dict,
nc=nc,
device=device,
reduction="batchmean",
hyperparams=hyperparams
).to(device)
# Calculate total loss with proper scaling and type conversion
detection_loss = detection_losses.sum()
bbox_loss = bbox_loss.to(device)
cls_loss = cls_loss.to(device)
dfl_loss = dfl_loss.to(device)
# Debug print individual losses
if debug:
print(f"\nBatch {batch_idx} Loss Components:")
print(f"Detection Loss: {detection_loss.item():.4f}")
print(f"Bbox Loss: {bbox_loss.item():.4f}")
print(f"Cls Loss: {cls_loss.item():.4f}")
print(f"DFL Loss: {dfl_loss.item():.4f}")
print(f"Distillation Loss: {distillation_loss.item():.4f}")
# Calculate weighted components
weighted_detection = hyperparams["lambda_detection"] * detection_loss
weighted_dist = hyperparams["lambda_distillation"] * distillation_loss
# Debug print weighted components
if debug:
print(f"\nWeighted Components:")
print(f"Weighted Detection: {weighted_detection.item():.4f}")
print(f"Weighted Dist: {weighted_dist.item():.4f}")
# Calculate total loss
total_loss = weighted_detection + weighted_dist
if debug:
print(f"Total Loss: {total_loss.item():.4f}")
# Final NaN check before backward pass
if torch.isnan(total_loss).any() or torch.isinf(total_loss).any():
print(f"NaN detected in total loss at batch {batch_idx}")
print(f"Component losses: bbox={bbox_loss}, cls={cls_loss}, dfl={dfl_loss}, dist={distillation_loss}")
continue
total_loss.backward()
# Calculate gradient norm before clipping
grad_norm_before = None
# grad_norm_before = calculate_gradient_norm(student_model)
# Clip gradients to prevent exploding gradients
clip_grad_norm_(student_model.parameters(), max_norm=10.0)
# Calculate gradient norm after clipping
grad_norm_after = None
# grad_norm_after = calculate_gradient_norm(student_model)
optimizer.step()
# Store losses and gradient norms
batch_loss_dict["bbox_loss"] = np.append(batch_loss_dict["bbox_loss"], bbox_loss.cpu().detach().numpy())
batch_loss_dict["cls_loss"] = np.append(batch_loss_dict["cls_loss"], cls_loss.cpu().detach().numpy())
batch_loss_dict["dfl_loss"] = np.append(batch_loss_dict["dfl_loss"], dfl_loss.cpu().detach().numpy())
batch_loss_dict["distillation_loss"] = np.append(
batch_loss_dict["distillation_loss"],
distillation_loss.cpu().detach().numpy()
)
batch_loss_dict["total_loss"] = np.append(batch_loss_dict["total_loss"], total_loss.cpu().detach().numpy())
if grad_norm_before is not None:
batch_loss_dict["grad_norm_before"] = np.append(batch_loss_dict["grad_norm_before"], grad_norm_before)
if grad_norm_after is not None:
batch_loss_dict["grad_norm_after"] = np.append(batch_loss_dict["grad_norm_after"], grad_norm_after)
# Log metrics if log_file is provided and log_level is batch
if log_file is not None and log_level == "batch":
current_losses = {
'total_loss': float(total_loss.cpu().detach().numpy()),
'bbox_loss': float(bbox_loss.cpu().detach().numpy()),
'cls_loss': float(cls_loss.cpu().detach().numpy()),
'dfl_loss': float(dfl_loss.cpu().detach().numpy()),
'dist_loss': float(distillation_loss.cpu().detach().numpy())
}
log_training_metrics(
log_file=log_file,
epoch=epoch,
batch_idx=batch_idx,
losses=current_losses,
grad_norm_before=grad_norm_before,
grad_norm_after=grad_norm_after,
is_new_file=not log_file.exists(),
log_level=log_level
)
except Exception as e:
print(f"Error processing batch {batch_idx}: {str(e)}")
continue
# Log epoch-level metrics if log_level is epoch
if log_file is not None and log_level == "epoch":
epoch_losses = {
'total_loss': float(np.mean(batch_loss_dict["total_loss"])),
'bbox_loss': float(np.mean(batch_loss_dict["bbox_loss"])),
'cls_loss': float(np.mean(batch_loss_dict["cls_loss"])),
'dfl_loss': float(np.mean(batch_loss_dict["dfl_loss"])),
'dist_loss': float(np.mean(batch_loss_dict["distillation_loss"]))
}
log_training_metrics(
log_file=log_file,
epoch=epoch,
batch_idx=None,
losses=epoch_losses,
grad_norm_before=float(np.mean(batch_loss_dict["grad_norm_before"])) if batch_loss_dict["grad_norm_before"].size else None,
grad_norm_after=float(np.mean(batch_loss_dict["grad_norm_after"])) if batch_loss_dict["grad_norm_after"].size else None,
is_new_file=not log_file.exists(),
log_level=log_level
)
return batch_loss_dict
def save_checkpoint(
checkpoint_dir: Path,
epoch: int,
student_model: nn.Module,
optimizer: optim.Optimizer,
learning_rate_scheduler: LambdaLR,
losses: Dict[str, float],
) -> None:
"""
Save model checkpoint.
Args:
checkpoint_dir: Directory to save checkpoint
epoch: Current epoch number
student_model: Student model to save
optimizer: Optimizer state to save
learning_rate_scheduler: Learning rate scheduler state to save
losses: Dictionary of loss values
"""
checkpoint = {
'epoch': epoch,
'model_state_dict': student_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': learning_rate_scheduler.state_dict(),
'loss': losses['total_loss'],
'bbox_loss': losses['bbox_loss'],
'cls_loss': losses['cls_loss'],
'dfl_loss': losses['dfl_loss']
}
torch.save(checkpoint, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')
def freeze_layers(model: nn.Module, num_layers: int = 10) -> None:
"""
Freeze the first n layers of the model.
For example, if num_layers = 10, the first 10 layers (the backbone) are frozen.
https://community.ultralytics.com/t/guidance-on-freezing-layers-for-yolov8x-seg-transfer-learning/189/2
https://github.com/ultralytics/ultralytics/blob/3e669d53067ff1ed97e0dad0a4063b156f66686d/ultralytics/engine/trainer.py#L258
Args:
model: The model to freeze layers in
num_layers: Number of layers to freeze from the start
"""
# Build the name prefixes of the first n layers to freeze
manual_freeze = [f"model.{i}." for i in range(num_layers)]
perm_freeze = ['.dfl.']
total_freeze = set(manual_freeze + perm_freeze)
for k, v in model.named_parameters():
for layer in total_freeze:
if layer in k:
v.requires_grad = False
# Report how many parameter tensors were frozen
frozen_count = sum(1 for p in model.parameters() if not p.requires_grad)
total_count = sum(1 for p in model.parameters())
print(f"Frozen {frozen_count}/{total_count} parameter tensors in the model")
def save_final_model(
model: nn.Module,
output_dir: Path,
model_name: str = "model.pt"
) -> None:
"""
Save the final model after training.
Args:
model: The model to save
output_dir: Directory to save the model
model_name: Name of the saved model file
"""
model_path = output_dir / model_name
output_dir.mkdir(parents=True, exist_ok=True)
# Save the model using YOLO's save method
model.save(model_path)
print(f"Final model saved to {model_path}")
def train_loop(
num_epochs: int,
student_model: nn.Module,
student_yolo: YOLO, # YOLO wrapper, used for saving the final model
teacher_model: nn.Module,
train_dataloader: DataLoader,
detection_trainer: DetectionTrainer,
detection_validator: DetectionValidator,
optimizer: optim.Optimizer,
stopper: EarlyStopping,
learning_rate_scheduler: LambdaLR,
detection_criterion: v8DetectionLoss,
config_dict: Dict[str, Any],
device: str,
checkpoint_dir: Path,
save_checkpoint_every: int,
hyperparams: Dict[str, float] = {
"lambda_distillation": 2.0,
"lambda_detection": 1.0,
"lambda_dist_ciou": 1.0,
"lambda_dist_kl": 2.0
},
start_epoch: int = 1,
log_file: Optional[Path] = None,
log_level: Literal["batch", "epoch"] = "epoch",
final_model_dir: Path = Path("automl_workspace/model_registry/distilled/latest"),
debug: bool = False
) -> Dict[str, List[float]]:
"""
Execute the complete training process including all epochs.
Args:
num_epochs: Number of epochs to train
student_model: Student model to train
student_yolo: Student YOLO model instance for saving
teacher_model: Teacher model for distillation
train_dataloader: DataLoader for training data
detection_trainer: Detection trainer instance
detection_validator: Detection validator instance used for per-epoch validation
optimizer: Optimizer for training
stopper: Early stopping handler
learning_rate_scheduler: Learning rate scheduler
detection_criterion: Detection loss criterion
config_dict: Configuration dictionary
device: Device to train on
checkpoint_dir: Directory to save checkpoints
save_checkpoint_every: Save checkpoint every n epochs
hyperparams: Dictionary of hyperparameters for loss functions
start_epoch: Start training from this epoch
log_file: Optional path to log file for metrics
log_level: Whether to log at batch or epoch level
final_model_dir: Directory to save final model
debug: Whether to print debug information
Returns:
Dictionary containing lists of loss values for each epoch
"""
epoch_losses = {
'total_loss': [],
'bbox_loss': [],
'cls_loss': [],
'dfl_loss': [],
'dist_loss': []
}
# Track best model state and fitness
best_fitness = -float('inf')
best_model_state = None
student_model_copy = None
for epoch in tqdm(range(start_epoch, num_epochs + 1), desc="Epochs", position=0):
# Train one epoch
batch_loss_dict = train_epoch(
student_model=student_model,
teacher_model=teacher_model,
train_dataloader=train_dataloader,
detection_trainer=detection_trainer,
optimizer=optimizer,
detection_criterion=detection_criterion,
config_dict=config_dict,
device=device,
hyperparams=hyperparams,
epoch=epoch,
log_file=log_file,
log_level=log_level,
debug=debug
)
learning_rate_scheduler.step()
# Calculate average losses
batch_loss_bbox = np.mean(batch_loss_dict["bbox_loss"]).round(4)
batch_loss_cls = np.mean(batch_loss_dict["cls_loss"]).round(4)
batch_loss_dfl = np.mean(batch_loss_dict["dfl_loss"]).round(4)
batch_loss_dist = np.mean(batch_loss_dict["distillation_loss"]).round(4)
batch_loss_total = batch_loss_bbox + batch_loss_cls + batch_loss_dfl + batch_loss_dist
# Store losses
epoch_losses['total_loss'].append(batch_loss_total)
epoch_losses['bbox_loss'].append(batch_loss_bbox)
epoch_losses['cls_loss'].append(batch_loss_cls)
epoch_losses['dfl_loss'].append(batch_loss_dfl)
epoch_losses['dist_loss'].append(batch_loss_dist)
print(
f"Epoch {epoch}: (Overall: {batch_loss_total}, bbox_loss: {batch_loss_bbox}, "
f"cls_loss: {batch_loss_cls}, dfl_loss: {batch_loss_dfl}, dist_loss: {batch_loss_dist})"
)
# Save checkpoint
if save_checkpoint_every > 0 and epoch % save_checkpoint_every == 0:
save_checkpoint(
checkpoint_dir=checkpoint_dir,
epoch=epoch,
student_model=student_model,
optimizer=optimizer,
learning_rate_scheduler=learning_rate_scheduler,
losses={
'total_loss': batch_loss_total,
'bbox_loss': batch_loss_bbox,
'cls_loss': batch_loss_cls,
'dfl_loss': batch_loss_dfl
}
)
# Workaround: running the validator puts the model into eval mode (and may fuse layers),
# which would prevent it from training properly in the next epoch,
# so validation is run on a deep copy of the student model instead.
with torch.no_grad():
student_model_copy = copy.deepcopy(student_model)
validation_results = detection_validator(model=student_model_copy)
current_fitness = validation_results["fitness"]
# Update best model if current fitness is better
if current_fitness > best_fitness:
best_fitness = current_fitness
best_model_state = copy.deepcopy(student_model.state_dict())
print(f"New best model found with fitness: {best_fitness:.4f}")
stop = stopper(epoch=epoch, fitness=current_fitness)
if stop:
print(f"Early stopping triggered at epoch {epoch}")
print(f"Restoring best model with fitness: {best_fitness:.4f}")
# Restore best model state
student_model.load_state_dict(best_model_state)
# Save the best model
save_final_model(student_yolo, final_model_dir)
return epoch_losses
# If training completes without early stopping, save final model
save_final_model(student_yolo, final_model_dir)
return epoch_losses
def build_optimizer_and_scheduler(
model: DetectionModel,
detection_trainer: DetectionTrainer,
model_args: Any  # namespace returned by get_cfg; accessed by attribute
) -> Tuple[optim.Optimizer, LambdaLR]:
"""
Build the optimizer and learning rate scheduler.
Args:
model: DetectionModel instance
detection_trainer: DetectionTrainer instance
model_args: Model arguments
Returns:
Tuple of optimizer and learning rate scheduler
"""
optimizer = detection_trainer.build_optimizer(
model=model,
name=model_args.optimizer,
lr=model_args.lr0,
momentum=model_args.momentum,
decay=model_args.weight_decay,
)
# https://github.com/ultralytics/ultralytics/blob/487e27639595047cff8775dab5e2ff268d8647c4/ultralytics/engine/trainer.py#L229
if model_args.cos_lr:
lambda_func = one_cycle(1, model_args.lrf, model_args.epochs)
else:
lambda_func = lambda x: max(1 - x / model_args.epochs, 0) * (1.0 - model_args.lrf) + model_args.lrf
learning_rate_scheduler = LambdaLR(
optimizer=optimizer,
lr_lambda=lambda_func,
)
return optimizer, learning_rate_scheduler
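# Worked example (illustrative, kept as comments): with cos_lr disabled and assumed values
# lr0=0.01, lrf=0.01, epochs=100, the linear lambda above scales the base learning rate as
#   epoch 0:   max(1 - 0/100, 0)   * (1 - 0.01) + 0.01 = 1.000 -> lr = 0.0100
#   epoch 50:  max(1 - 50/100, 0)  * (1 - 0.01) + 0.01 = 0.505 -> lr = 0.00505
#   epoch 100: max(1 - 100/100, 0) * (1 - 0.01) + 0.01 = 0.010 -> lr = 0.0001
# With cos_lr enabled, one_cycle(1, lrf, epochs) produces a cosine decay from 1.0 to lrf instead.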
def load_checkpoint(
checkpoint_path: Path,
student_model: nn.Module,
optimizer: optim.Optimizer,
learning_rate_scheduler: LambdaLR,
device: str = "cpu"
) -> int:
"""
Load a checkpoint and restore model and optimizer state.
Args:
checkpoint_path: Path to the checkpoint file
student_model: Student model to restore state to
optimizer: Optimizer to restore state to
learning_rate_scheduler: Learning rate scheduler to restore state to
device: Device to map the checkpoint tensors to
Returns:
The epoch number from the checkpoint
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)
# Restore model state
student_model.load_state_dict(checkpoint['model_state_dict'])
# Restore optimizer state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Restore learning rate scheduler state if it exists
if 'scheduler_state_dict' in checkpoint:
learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
return checkpoint['epoch']
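# Example (illustrative sketch, kept as comments): resuming from a saved checkpoint. The checkpoint
# path is an assumption; the optimizer and scheduler must be the same objects that training will
# continue with.
#
# start_epoch = load_checkpoint(
#     checkpoint_path=Path("distillation_out/20240101_000000/checkpoints/checkpoint_epoch_25.pt"),
#     student_model=student_model,
#     optimizer=optimizer,
#     learning_rate_scheduler=learning_rate_scheduler,
#     device="cpu",
# ) + 1  # resume from the following epoch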
def start_distillation(
device: str = "cpu",
base_dir: Path = Path(".."),
img_dir: Path = Path("dataset"),
save_checkpoint_every: int = 25,
frozen_layers: int = 10, # freeze the Backbone layers
hyperparams: Dict[str, float] = {
"lambda_distillation": 2.0,
"lambda_detection": 1.0,
"lambda_dist_ciou": 1.0,
"lambda_dist_kl": 2.0
},
resume_checkpoint: Optional[Path] = None,
output_dir: Path = Path("distillation_out"),
final_model_dir: Path = Path("automl_workspace/model_registry/distilled/latest"),
log_level: Literal["batch", "epoch"] = "batch",
debug: bool = False,
distillation_config: Optional[Dict[str, Any]] = None,
pipeline_config: Optional[Dict[str, Any]] = None
) -> Dict[str, List[float]]:
"""
Start the distillation training process.
Args:
device: Device to train on
base_dir: Base directory for paths (should be SCRIPT_DIR from main.py)
img_dir: Directory containing training images
save_checkpoint_every: Save checkpoint every n epochs
frozen_layers: Number of layers to freeze in the backbone
hyperparams: Dictionary of hyperparameters for loss functions
resume_checkpoint: Optional path to checkpoint to resume training from
output_dir: Directory to save output
final_model_dir: Directory to save final model
log_level: Whether to log at batch or epoch level
debug: Whether to print debug information
distillation_config: Configuration dictionary for distillation
pipeline_config: Configuration dictionary for pipeline
Returns:
Dictionary containing lists of loss values for each epoch
"""
if distillation_config is None:
raise ValueError("distillation_config is required")
# Ensure distillation dataset directories exist
distillation_base_dir = Path(distillation_config["distillation_dataset"])
distillation_dataset_dir = distillation_base_dir / "distillation_dataset"
distillation_dataset_dir.mkdir(parents=True, exist_ok=True)
# Create train and valid directories if they don't exist
(distillation_dataset_dir / "train").mkdir(exist_ok=True)
(distillation_dataset_dir / "valid").mkdir(exist_ok=True)
# Create the distillation_data.yaml file
yaml_path = distillation_base_dir / "distillation_data.yaml"
create_distill_yaml(
output_dir=str(distillation_dataset_dir),
yaml_path=str(yaml_path)
)
# Create output directory structure
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_dir = base_dir / output_dir / timestamp
logs_dir = output_dir / "logs"
checkpoint_dir = output_dir / "checkpoints"
# Create directories
output_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(exist_ok=True)
checkpoint_dir.mkdir(exist_ok=True)
# Create log file
log_file = logs_dir / f'training_log_{timestamp}.csv'
# Load models
teacher_yolo, student_yolo = load_models(device, base_dir, distillation_config)
teacher_model = teacher_yolo.model
student_model = student_yolo.model
# Use distillation config for model args
model_args = get_cfg(distillation_config)
model_args.mode = "train"
BATCH_SIZE = model_args.batch
EPOCHS = model_args.epochs
# Freeze backbone layers if specified
if frozen_layers > 0:
freeze_layers(student_model, frozen_layers)
# Prepare dataset
train_dataset, train_dataloader = prepare_dataset(
img_path=img_dir / "train",
student_model=student_model,
batch_size=BATCH_SIZE,
mode="train"
)
valid_dataset, valid_dataloader = prepare_dataset(
img_path=img_dir / "valid",
student_model=student_model,
batch_size=BATCH_SIZE,
mode="val"
)
# Setup training
detection_trainer = DetectionTrainer(
cfg=model_args,
overrides={"data": Path(distillation_config["distillation_dataset"]) / "distillation_data.yaml"}
)
distillation_config_copy = copy.deepcopy(distillation_config)
# Remove keys that are not needed for validation (they would raise errors, since YOLO validates config keys)
del distillation_config_copy["teacher_model"]
del distillation_config_copy["distillation_dataset"]
del distillation_config_copy["distillation_hyperparams"]
distillation_config_copy["mode"] = "val"
distillation_config_copy["data"] = Path(distillation_config["distillation_dataset"]) / "distillation_data.yaml"
model_args_validator = get_cfg(distillation_config_copy)
detection_validator = DetectionValidator(dataloader=valid_dataloader, args=model_args_validator)
stopper = EarlyStopping(patience=model_args_validator.patience)
optimizer, learning_rate_scheduler = build_optimizer_and_scheduler(
model=student_model,
detection_trainer=detection_trainer,
model_args=model_args
)
detection_criterion = v8DetectionLoss(model=student_model)
# Load checkpoint if specified
start_epoch = 1
if resume_checkpoint is not None:
print(f"Resuming from checkpoint: {resume_checkpoint}")
start_epoch = load_checkpoint(
checkpoint_path=resume_checkpoint,
student_model=student_model,
optimizer=optimizer,
learning_rate_scheduler=learning_rate_scheduler,
device=device
) + 1 # Start from next epoch
print(f"Resuming training and distillation from epoch {start_epoch}")
else:
print("Starting training and distillation from scratch")
# Run training loop
return train_loop(
num_epochs=EPOCHS,
student_model=student_model,
student_yolo=student_yolo, # YOLO wrapper used for saving the final model
teacher_model=teacher_model,
train_dataloader=train_dataloader,
detection_trainer=detection_trainer,
detection_validator=detection_validator,
optimizer=optimizer,
stopper=stopper,
learning_rate_scheduler=learning_rate_scheduler,
detection_criterion=detection_criterion,
config_dict=model_args,
device=device,
checkpoint_dir=checkpoint_dir,
save_checkpoint_every=save_checkpoint_every,
hyperparams=hyperparams,
start_epoch=start_epoch,
log_file=log_file,
log_level=log_level,
final_model_dir=final_model_dir,
debug=debug
)
if __name__ == "__main__":
# Load configurations
base_dir = Path(__file__).parent.parent.parent
distillation_config = YAML.load(base_dir / "distillation_config.yaml")
hyperparams = {
"lambda_distillation": 1.0,
"lambda_detection": 2.0,
"lambda_dist_ciou": 1.0,
"lambda_dist_kl": 1.0,
"temperature": 2.0
}
start_distillation(
device=detect_device(),
base_dir=base_dir,
img_dir=Path("automl_workspace/data_pipeline/distillation"),
frozen_layers=10,
save_checkpoint_every=25,
hyperparams=hyperparams,
resume_checkpoint=None,
output_dir=Path("automl_workspace/model_registry/distilled"),
final_model_dir=Path("automl_workspace/model_registry/distilled/latest"),
log_level="epoch",
debug=False,
distillation_config=distillation_config
)