diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py index a9164da..f91bb10 100644 --- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py +++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py @@ -9,6 +9,7 @@ import mlflow from torch import nn +from torch.optim import Optimizer from ..models.model import BaseModel from ..trainers.trainer_protocol import TrainerProtocol @@ -34,6 +35,7 @@ class MlflowLogger: """ def __init__( self, + *, name: str, experiment_name: str, tracking_uri: Optional[path_type] = None, @@ -228,6 +230,33 @@ def on_train_start(self): except Exception as e: print(f"Fail to log model config as artifact: {e}") + optimizers = self._get_optimizers() + for idx, optimizer in enumerate(optimizers): + if not isinstance(optimizer, Optimizer): + continue + try: + opt_config = { + "class_path": f"{optimizer.__class__.__module__}.{optimizer.__class__.__name__}", + "defaults": dict(optimizer.defaults), + } + except Exception as e: + print(f"Could not get optimizer config for logging: {e}") + opt_config = None + + if opt_config: + mlflow.set_tag( + f"optimizer.{idx}.class_path", + str(opt_config.get("class_path")) + ) + try: + self.log_config( + tag=f"optimizer_{idx}", + config=opt_config, + stage=None + ) + except Exception as e: + print(f"Fail to log optimizer config as artifact: {e}") + self._log_loss_groups_config_and_tags() @@ -540,6 +569,26 @@ def _get_loss_groups(self) -> Dict[str, Any]: return loss_groups + def _get_optimizers(self) -> List[Optimizer]: + """ + Discover optimizer(s) attached to the bound trainer. + """ + + if self.trainer is None: + return [] + + optimizers: List[Optimizer] = [] + + explicit_optimizers = getattr(self.trainer, 'optimizers', None) + if isinstance(explicit_optimizers, list): + optimizers.extend(explicit_optimizers) + + explicit_optimizer = getattr(self.trainer, 'optimizer', None) + if explicit_optimizer is not None: + optimizers.append(explicit_optimizer) + + return optimizers + def _log_loss_groups_config_and_tags(self) -> None: """ Log loss item names and weights as flattened string mlflow tags and