Trainers

class BaseTrainer

Bases: ABC

Abstract base class for all trainers. Subclasses implement train_step and validate_step.

__init__(model, train_loader, val_loader=None, optimizer=None, scheduler=None, device='cpu', monitor_metric='loss', logger=None)
Parameters:
  • model (Any) – Model to train

  • train_loader (DataLoader) – Training data loader

  • val_loader (DataLoader | None) – Validation data loader

  • optimizer (Optimizer | None) – Optimizer to use

  • scheduler (Any | None) – Learning rate scheduler

  • device (str) – Device to use for training

  • monitor_metric (str) – Metric name to monitor for early stopping and model saving

  • logger (BaseLogger | None) – Logger instance for experiment tracking

train_epoch()

Train for one epoch.

Return type:

Dict[str, float]

abstract train_step(batch)

Perform a single training step.

Parameters:

batch (Any)

Return type:

Dict[str, Tensor]

validate()

Validate the model.

Return type:

Dict[str, float]

abstract validate_step(batch)

Perform a single validation step.

Parameters:

batch (Any)

Return type:

Dict[str, Tensor]
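train_epoch() and validate() are concrete, while train_step() and validate_step() are abstract, so a subclass only supplies the per-batch logic (the template-method pattern). The self-contained sketch below is not the library's code: MiniTrainer and LinearTrainer are hypothetical, epoch metrics are assumed to be plain means over batches, and floats stand in for Tensors so it runs without PyTorch.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional


class MiniTrainer(ABC):
    """Hypothetical stand-in for BaseTrainer: loops are concrete, steps are abstract."""

    def __init__(self, train_loader: Iterable[Any],
                 val_loader: Optional[Iterable[Any]] = None) -> None:
        self.train_loader = train_loader
        self.val_loader = val_loader

    @abstractmethod
    def train_step(self, batch: Any) -> Dict[str, float]:
        """Process one training batch and return per-step metrics."""

    @abstractmethod
    def validate_step(self, batch: Any) -> Dict[str, float]:
        """Process one validation batch and return per-step metrics."""

    @staticmethod
    def _average(outputs: List[Dict[str, float]]) -> Dict[str, float]:
        # Assumption: epoch metrics are the unweighted mean over batches.
        if not outputs:
            return {}
        return {k: sum(o[k] for o in outputs) / len(outputs) for k in outputs[0]}

    def train_epoch(self) -> Dict[str, float]:
        # Call the subclass hook once per batch, then aggregate.
        return self._average([self.train_step(b) for b in self.train_loader])

    def validate(self) -> Dict[str, float]:
        # No validation loader means no validation metrics.
        if self.val_loader is None:
            return {}
        return self._average([self.validate_step(b) for b in self.val_loader])


class LinearTrainer(MiniTrainer):
    """Hypothetical toy subclass: fits y = w * x by gradient descent on squared error."""

    def __init__(self, train_loader, val_loader=None, lr: float = 0.1) -> None:
        super().__init__(train_loader, val_loader)
        self.w = 0.0
        self.lr = lr

    def train_step(self, batch):
        x, y = batch
        err = self.w * x - y
        self.w -= self.lr * 2 * err * x  # gradient of err**2 with respect to w
        return {"loss": err ** 2}  # loss measured before the update

    def validate_step(self, batch):
        x, y = batch
        return {"loss": (self.w * x - y) ** 2}
```

A real subclass would run the model, compute a loss, and step the optimizer inside train_step; only the division of responsibilities is meant to carry over.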

train(num_epochs, save_path=None, early_stopping=None)

Train the model for the specified number of epochs.

Parameters:
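  • num_epochs (int) – Number of epochs to train for

  • save_path (str | None) – Path for saving the model; which model is saved is governed by monitor_metric

  • early_stopping (int | None) – Early-stopping setting, evaluated on monitor_metric; None disables early stopping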

Return type:

Dict[str, list]
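One plausible reading of train() is an epoch loop that records a per-epoch history and stops once monitor_metric stops improving. The sketch below assumes early_stopping is a patience (epochs without improvement), that lower values of the monitored metric are better, and that the history maps metric names to per-epoch lists; fit, its callable arguments, the on_improve hook, and the "val_" key prefix are all hypothetical.

```python
from typing import Callable, Dict, List, Optional


def fit(train_epoch: Callable[[], Dict[str, float]],
        validate: Callable[[], Dict[str, float]],
        num_epochs: int,
        monitor_metric: str = "loss",
        early_stopping: Optional[int] = None,
        on_improve: Optional[Callable[[int], None]] = None) -> Dict[str, List[float]]:
    """Hypothetical epoch loop; early_stopping is treated as a patience in epochs."""
    history: Dict[str, List[float]] = {}
    best = float("inf")  # assumption: a lower monitored value is better
    stale = 0            # epochs since the last improvement
    for epoch in range(num_epochs):
        metrics = dict(train_epoch())
        # Prefix validation metrics so they do not collide with training ones.
        metrics.update({f"val_{k}": v for k, v in validate().items()})
        for name, value in metrics.items():
            history.setdefault(name, []).append(value)
        # Prefer the validation value of the monitored metric when present.
        current = metrics.get(f"val_{monitor_metric}", metrics.get(monitor_metric))
        if current is None:
            continue
        if current < best:
            best, stale = current, 0
            if on_improve is not None:
                on_improve(epoch)  # stand-in for saving the model to save_path
        else:
            stale += 1
            if early_stopping is not None and stale >= early_stopping:
                break
    return history
```

With early_stopping=2 and validation losses 1.0, 0.5, 0.6, 0.7, the loop stops after the fourth epoch because two consecutive epochs failed to beat 0.5.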

class SegmentationTrainer

Bases: BaseTrainer

Trainer class for semantic segmentation models.

__init__(model, train_loader, val_loader=None, optimizer=None, scheduler=None, device='cpu', metrics=None, ignore_index=None, monitor_metric='seg_loss', logger=None)

Initialize the trainer.

Parameters:
  • model – Model to train

  • train_loader – Training data loader

  • val_loader – Validation data loader

  • optimizer – Optimizer

  • scheduler – Learning rate scheduler

  • device – Device to use

  • metrics (List[Callable] | None) – List of metric functions to compute during validation

  • ignore_index (int | None) – Target label index to ignore when computing metrics

  • monitor_metric (str) – Metric name to monitor for early stopping and model saving

  • logger (BaseLogger | None) – Logger instance for experiment tracking
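In segmentation, ignore_index typically marks pixels (for example unlabeled or boundary regions) that should not count toward a score. A minimal sketch of a metric that honors it, assuming predictions and targets arrive as flat sequences of class indices; the trainer itself works with Tensors, and the exact signature it expects from entries in metrics is not documented here, so mean_iou and its arguments are hypothetical.

```python
from typing import Optional, Sequence


def mean_iou(pred: Sequence[int], target: Sequence[int], num_classes: int,
             ignore_index: Optional[int] = None) -> float:
    """Hypothetical metric: mean IoU over classes present in pred or target."""
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(pred, target):
        if ignore_index is not None and t == ignore_index:
            continue  # drop pixels whose ground-truth label is ignored
        if p == t:
            inter[t] += 1  # true positive for class t
            union[t] += 1
        else:
            union[p] += 1  # false positive for class p
            union[t] += 1  # false negative for class t
    # Average only over classes that actually appear.
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious) if ious else 0.0
```

Here pixels labeled 255 (a common "void" value) are skipped before counting, so they contribute to neither the intersection nor the union of any class.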

train_step(batch)

Perform a single training step.

Parameters:

batch (Tuple[Tensor, Tensor])

Return type:

Dict[str, Tensor]

validate_step(batch)

Perform a single validation step.

Parameters:

batch (Tuple[Tensor, Tensor])

Return type:

Dict[str, Tensor]
