Source code for detectools.train.trainer

import shutil
from pathlib import Path
from typing import Dict, List, Literal, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Metric
from tqdm import tqdm

from detectools.data.dataset import DetectionLoader
from detectools.models.base import BaseModel
from detectools.formats import BatchedFormats
from detectools.train.loss_aggregator import Aggregator


[docs] class Trainer: """Trainer wrap the whole training process. Args: model (``BaseModel``): Detectools model. otpimizer (``Optimizer``): Optimizer (from ``torch.optim``). log_dir (``str``, **optional**): Path to store tensorboard logs. Defaults to "" (no log). metrics (``List[Metric]``, **optional**): List of detectools metrics that will be computed at valid. Defaults to []. device (``Literal['cpu', 'cuda']``, **optional**): Device to use for trainning. Defaults to "cpu". nms_thr (``float``, **optional**): IoU threshold to consider 2 boxes as overlapping in NMS algorithm using for valid loop. Defaults to 0.45. confidence_thr (``float``, **optional**): Minimum confidence score for each predicted object to be kept, used for valid loop. Defaults to 0.5. Example: Trainning loop. ----------------------------- .. highlight:: python .. code-block:: python >>> from detectools.dataset import DetectionDataset, DetectionLoader >>> from torch.utils.data import random_split >>> import torch >>> trainer : Trainer # Trainer already built >>> data_path = \"path/to/data\" >>> dataset = DetectionDataset(data_path) >>> train_set, valid_set, test_set = random_split(dataset, 0.8,0.2,0.0) >>> train_loader, valid_loader = DetectionLoader(train_set), DetectionLoader(valid_set) >>> for e in range(10): >>> trainer.train_epoch(train_loader) >>> trainer.valid_epoch(valid_loader) >>> torch.save(trainer.model, "/model/path/model.pth") Attributes: ----------- Attributes: model (``BaseModel``): Detectools model. otpimizer (``Optimizer``): Optimizer (from ``torch.optim``). log_dir (``str``, **optional**): Path to store tensorboard logs. metrics (``List[Metric]``, **optional**): List of detectools metrics that will be computed at valid. device (``Literal['cpu', 'cuda']``, **optional**): Device to use for trainning. nms_thr (``float``, **optional**): IoU threshold to consider 2 boxes as overlapping in NMS algorithm using for valid loop. confidence_thr (``float``, **optional**): Minimum confidence score for each predicted object to be kept, used for valid loop. Methods ---------- """ model: BaseModel otpimizer: Optimizer log_dir: str metrics: List[Metric] device: Literal["cpu", "cuda"] nms_threshold: float confidence_threshold: float def __init__( self, model: BaseModel, otpimizer: Optimizer, log_dir: str = "", metrics: List[Metric] = [], device: Literal["cpu", "cuda"] = "cpu", nms_threshold: float = 0.45, confidence_threshold: float = 0.5, ): self.model = model self.model.to_device(device) self.model.confidence_thr = confidence_threshold self.model.nms_threshold = nms_threshold self.optimizer = otpimizer metrics = [metric.to(device) for metric in metrics] self.metrics = metrics self.log_dir = log_dir self.device = device self.nms_threshold = nms_threshold self.confidence_threshold = confidence_threshold # create log dir and board for tensorboard if log_dir: # if log dir exist remove it if Path(log_dir).exists(): shutil.rmtree(log_dir) Path(log_dir).mkdir(parents=True) self.board = SummaryWriter(log_dir) else: self.board = False
[docs] def train_step(self, images: Tensor, targets: BatchedFormats) -> Dict[str, Tensor]: """Run forward pass, loss computation and backward pass. Args: images (``Tensor``): Batch images targets (``BatchedFormats``): Batch targets. Returns: ``Dict[str, Tensor]``: - Dict of losses containing (total loss at key 'loss'). """ loss_dict = self.model.run_forward(images, targets, predict=False) loss = loss_dict["loss"] loss.backward() self.optimizer.step() self.optimizer.zero_grad() return loss_dict
[docs] def compute_metrics( self, predictions: BatchedFormats, targets: BatchedFormats ) -> Dict[str, Tensor]: """Compute metrics. Args: predictions (BatchedFormats): Predictions. targets (BatchedFormats): Targets. Args: predictions (``BatchedFormats``): Predictions. targets (``BatchedFormats``): Targets. Returns: ``Dict[str, Tensor]``: - Metric values. """ # split predictions to compute metric individual images # TODO make batchification possible splitted_predictions = predictions.split() splitted_targets = targets.split() # for each pair pred/target for i, target in enumerate(splitted_targets): pred = splitted_predictions[i] # update each metrics for metric in self.metrics: metric.update(pred, target) # after all updates recompute to get averaged values of metrics metric_dict = {} for metric in self.metrics: results = metric.compute() metric_dict.update({metric.name: results}) return metric_dict
[docs] def valid_step( self, images: Tensor, targets: BatchedFormats ) -> Tuple[Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]: """Run train step, return loss dict. Args: images (``Tensor``): Batch images. targets (``BatchedFormats``): Targets. Returns: ``Tuple[Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]``: - Losses and metrics values. """ if self.model.training: self.model.eval() loss_dict, predictions = self.model.run_forward(images, targets, predict=True) predictions.apply("set_device", self.device) metrics = self.compute_metrics(predictions, targets) return loss_dict, metrics
[docs] def log_string(self, epoch_dict: Dict[str, Tensor]) -> str: """Transform epoch dict in string. Args: epoch_dict (``Dict[str, Tensor]``): Dict of epoch values to display. Returns: ``str``: - String to print with epoch values. """ flattened_dict = epoch_dict.copy() for key, value in flattened_dict.items(): if isinstance(value, dict): flattened_dict[key] = value[list(value.keys())[0]] log = "" for key, value in flattened_dict.items(): log += f"{key} : {str(round(value.item(), 3))} " return log
[docs] def epoch( self, loader: DetectionLoader, ep_number: int, mode: Literal["Train", "Valid"] = "Train", tag: str = "", ) -> Dict[str, Tensor]: """Run trainning epoch. Args: loader (``DetectionLoader``): DetectionLoader. ep_number (``int``): Epoch number. mode (``Literal['Train', 'Valid']``, **optional**): Mode of epoch, if "Valid" metrics are computed and gradient shuted down. Defaults to "Train". tag (``str``, **optional**): Tag to link to epoch. Defaults to "". Returns: ``Dict[str, Tensor]``: - Epochs values (Losses & metrics). """ # create aggregator for loss averged accros samples loss_aggregator = Aggregator() iterator = tqdm(loader, total=len(loader), desc=f"Epoch {ep_number}/{tag}") # iterate over batches for images, targets, _ in iterator: batch_size = images.shape[0] # send to device images = images.to(self.device) targets: BatchedFormats targets.apply("set_device", self.device) # gather loss & metrics (if valid) if mode == "Train": loss_dict = self.train_step(images, targets) loss_aggregator(loss_dict, batch_size) epoch_dict = loss_aggregator.compute() elif mode == "Valid": loss_dict, metric_dict = self.valid_step(images, targets) loss_aggregator(loss_dict, batch_size) epoch_dict = loss_aggregator.compute() epoch_dict.update(metric_dict) # extract str from log to display in terminal log_str = self.log_string(epoch_dict) iterator.set_postfix_str(f"{log_str}") # reset all metrics for metric in self.metrics: metric.reset() # write tensorboard if self.log_dir: for key, value in epoch_dict.items(): if isinstance(value, dict): self.board.add_scalars(key, value, ep_number) else: self.board.add_scalars(key, {tag: value}, ep_number) return epoch_dict
[docs] def train_epoch( self, loader: DetectionLoader, ep_number: int, tag: str = "Train", *args, **kwargs, ) -> Dict[str, Tensor]: """Run train epoch. Args: loader (``DetectionLoader``): DetectionLoader. ep_number (``int``): Epoch number. tag (``str``, **optional**): Tag to link to epoch. Defaults to "Train". Returns: ``Dict[str, Tensor]``: - Epochs values (Losses). """ torch.set_grad_enabled(True) self.model.train() epoch_dict = self.epoch( loader, ep_number, mode="Train", tag=tag, *args, **kwargs ) return epoch_dict
[docs] def valid_epoch( self, loader: DetectionLoader, ep_number: int, tag: str = "Valid", *args, **kwargs, ) -> Dict[str, Tensor]: """Run train epoch. Args: loader (``DetectionLoader``): DetectionLoader. ep_number (``int``): Epoch number. tag (``str``, **optional**): Tag to link to epoch. Defaults to "Valid". Returns: ``Dict[str, Tensor]``: - Epochs values (Losses & metrics). """ torch.set_grad_enabled(False) self.model.eval() epoch_dict = self.epoch( loader, ep_number, mode="Valid", tag=tag, *args, **kwargs ) torch.set_grad_enabled(True) return epoch_dict