Why doesn't the AutoEncoder tutorial use Lightning? #266

@taimoorsohail

Hi everyone,

I am building an ocean-dataset-based autoencoder on top of the Autoencoder tutorial, which was very helpful, thank you!

But I am curious why the tutorial uses PET to set up the pipeline, but then defines and trains the Autoencoder directly with torch.nn and related modules. Based on my reading of the current PET implementation, we could instead define a Lightning wrapper like this (ignore the masking, which is specific to my ocean routine):

import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim

class LitAutoEncoder(L.LightningModule):
    def __init__(self, model, mask, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.criterion = nn.L1Loss()

        # Persist mask with device movement/checkpoints
        self.register_buffer(
            "mask",
            torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x):
        # x: [B, C, H, W]
        mask = self.mask.expand(x.shape[0], 1, x.shape[2], x.shape[3])
        return self.model(x, mask)

    def training_step(self, batch, batch_idx):
        x = batch
        if isinstance(x, (tuple, list)):
            x = x[0]

        x_hat = self.forward(x)

        # optional: ocean-only loss
        mask = self.mask.expand_as(x[:, :1])
        loss = torch.abs(x_hat - x)
        loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)

        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch

        if isinstance(x, (tuple, list)):
            x = x[0]

        x_hat = self.forward(x)

        mask = self.mask.expand_as(x[:, :1])
        loss = torch.abs(x_hat - x)
        loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)

        self.log("valid_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

then define a base model and data module as:

base_model = AutoEncoder(
    input_channel_count=2,
    output_channel_count=2,
)

lightning_model = LitAutoEncoder(
    model=base_model,
    mask=mask.values,
    lr=1e-4,
)

datamodule = pyearthtools.training.data.lightning.PipelineLightningDataModule(
    pipeline_i,
    **splits,
    batch_size=8,
    num_workers=0,
)

and then set up the trainer and fit:

trainer = pyearthtools.training.lightning.Train(
    lightning_model,
    datamodule,
    path="/g/data/v46/txs156/OM2-emulator/data/",
    trainer_kwargs={
        "max_epochs": 10,
        "num_sanity_val_steps": 1,
    },
)

trainer.fit()

Am I missing something here? Perhaps the aim is to keep the tutorial focussed only on pipeline preparation and not the other implemented PET routines?
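
One side note on the ocean-only loss in the wrapper above, in case it is useful to anyone copying it: because the mask is expanded to a single channel before `mask.sum()`, the denominator counts ocean cells once while the numerator sums errors over every channel, so the result scales with the channel count. The sketch below illustrates this on toy tensors. `masked_l1` and the toy data are hypothetical names made up for this illustration; they are not part of PET or of the tutorial.

```python
import torch


def masked_l1(x_hat, x, mask):
    """Mean absolute error over unmasked cells, averaged over channels as well.

    Hypothetical helper for illustration only (not a PET function).
    mask has shape [1, 1, H, W]; x and x_hat have shape [B, C, H, W].
    """
    full_mask = mask.expand_as(x)  # broadcast mask over batch and channels
    abs_err = torch.abs(x_hat - x) * full_mask
    return abs_err.sum() / full_mask.sum().clamp_min(1.0)


# Toy data: batch of 1, 2 channels, 2x2 grid; the right-hand column is land (mask 0).
mask = torch.tensor([[1.0, 0.0], [1.0, 0.0]]).unsqueeze(0).unsqueeze(0)
x = torch.zeros(1, 2, 2, 2)
x_hat = torch.ones(1, 2, 2, 2)
x_hat[..., 1] = 100.0  # large errors on land cells, which the mask should ignore

# Normalised by the number of ocean cells across all channels.
per_element = masked_l1(x_hat, x, mask)

# Normalisation as written in the post: mask expanded to one channel only.
single = mask.expand_as(x[:, :1])
post_style = (torch.abs(x_hat - x) * single).sum() / single.sum().clamp_min(1.0)

print(per_element.item(), post_style.item())
```

With an absolute error of 1 on every ocean cell, `per_element` comes out as 1.0 while `post_style` comes out as 2.0, i.e. the per-cell error multiplied by the two channels. Either normalisation trains fine; it only matters when comparing loss values across models with different channel counts.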
