Hi everyone,
I am building an ocean-dataset-based Autoencoder from the Autoencoder tutorial, which was very helpful, thank you!
However, I am curious why the tutorial, as it stands, uses PET to set up the pipeline but then defines the Autoencoder and trains it with plain `torch.nn` and related modules. Based on my reading of the current PET implementation, we could instead define a Lightning wrapper like this (ignore the masked implementation, which is specific to my ocean routine):
```python
import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim


class LitAutoEncoder(L.LightningModule):
    def __init__(self, model, mask, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.criterion = nn.L1Loss()  # not used by the masked loss below
        # Persist mask with device movement/checkpoints: [H, W] -> [1, 1, H, W]
        self.register_buffer(
            "mask",
            torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, x):
        # x: [B, C, H, W]
        mask = self.mask.expand(x.shape[0], 1, x.shape[2], x.shape[3])
        return self.model(x, mask)

    def training_step(self, batch, batch_idx):
        x = batch
        if isinstance(x, (tuple, list)):
            x = x[0]
        x_hat = self.forward(x)
        # optional: ocean-only loss
        mask = self.mask.expand_as(x[:, :1])
        loss = torch.abs(x_hat - x)
        loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch
        if isinstance(x, (tuple, list)):
            x = x[0]
        x_hat = self.forward(x)
        mask = self.mask.expand_as(x[:, :1])
        loss = torch.abs(x_hat - x)
        loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
        self.log("valid_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)
```
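A side note on the masked loss: the mask is expanded to a single channel (`x[:, :1]`), so `mask.sum()` counts each ocean pixel once while the numerator sums the error over all C channels. For a 2-channel input the logged value is therefore C times the per-element mean. That is only a constant factor (Adam is largely insensitive to it), but it makes `train_loss` harder to compare with an unmasked `nn.L1Loss`. The NumPy sketch below illustrates the mask shapes and both normalisations; the helper names are hypothetical and only mirror the tensor operations in the wrapper.

```python
import numpy as np


def expand_mask(mask_2d, batch_size):
    """[H, W] land/ocean mask -> [B, 1, H, W], like unsqueeze(0).unsqueeze(0).expand(...)."""
    m = np.asarray(mask_2d, dtype=np.float32)[None, None]  # [1, 1, H, W]
    return np.broadcast_to(m, (batch_size, 1) + m.shape[2:])


def masked_l1_as_posted(x_hat, x, mask):
    """Denominator counts ocean pixels once (single-channel mask), as in the post."""
    return float((np.abs(x_hat - x) * mask).sum() / max(mask.sum(), 1.0))


def masked_l1_per_element(x_hat, x, mask):
    """Mean over ocean points *and* channels, i.e. an ocean-only nn.L1Loss."""
    full = np.broadcast_to(mask, x.shape)  # repeat the mask across channels
    return float((np.abs(x_hat - x) * full).sum() / max(full.sum(), 1.0))


# Toy case: B=2, C=2, 2x2 grid with 3 ocean cells, every error equal to 1
mask = expand_mask([[1, 0], [1, 1]], batch_size=2)
x = np.zeros((2, 2, 2, 2), dtype=np.float32)
x_hat = np.ones_like(x)
print(masked_l1_as_posted(x_hat, x, mask))    # 2.0, i.e. C times the per-element error
print(masked_l1_per_element(x_hat, x, mask))  # 1.0
```

Dividing by `mask.sum() * x.shape[1]` in the torch code would give the same per-element value, if that is preferred.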
Then define a base model and data module:
```python
base_model = AutoEncoder(
    input_channel_count=2,
    output_channel_count=2,
)

lightning_model = LitAutoEncoder(
    model=base_model,
    mask=mask.values,
    lr=1e-4,
)

datamodule = pyearthtools.training.data.lightning.PipelineLightningDataModule(
    pipeline_i,
    **splits,
    batch_size=8,
    num_workers=0,
)
```
And then train and fit:
```python
trainer = pyearthtools.training.lightning.Train(
    lightning_model,
    datamodule,
    path="/g/data/v46/txs156/OM2-emulator/data/",
    trainer_kwargs={
        "max_epochs": 10,
        "num_sanity_val_steps": 1,
    },
)
trainer.fit()
```
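For comparison with the tutorial's hand-written `torch.nn` loop, here is a rough pure-Python sketch of the control flow `trainer.fit()` drives, assuming `pyearthtools.training.lightning.Train` forwards `trainer_kwargs` to a standard Lightning `Trainer`: `num_sanity_val_steps` validation batches first, then `max_epochs` passes of `training_step` over the training split, each followed by `validation_step` over the validation split. Everything here (`fit_loop`, the scalar toy model, the SGD update) is hypothetical and only illustrates the loop structure; it is not how Lightning or PET implement it.

```python
from statistics import mean


def fit_loop(training_step, validation_step, train_batches, val_batches,
             max_epochs=10, num_sanity_val_steps=1):
    """Rough picture of the loop trainer.fit() automates.

    Real Lightning also calls backward(), steps the optimizer from
    configure_optimizers(), moves data to devices, logs and checkpoints;
    here training_step is assumed to update the model itself.
    """
    # Sanity check: a few validation batches before any training
    for batch in val_batches[:num_sanity_val_steps]:
        validation_step(batch)
    history = []
    for epoch in range(max_epochs):
        train_loss = mean(training_step(b) for b in train_batches)
        # By default, validation runs at the end of every training epoch
        valid_loss = mean(validation_step(b) for b in val_batches)
        history.append({"epoch": epoch, "train_loss": train_loss, "valid_loss": valid_loss})
    return history


# Hypothetical toy "autoencoder": x_hat = w * x, trained with an L1 loss and plain SGD
state = {"w": 0.0}
calls = {"val": 0}


def training_step(batch):
    errors = [state["w"] * x - x for x in batch]
    loss = mean(abs(e) for e in errors)
    # Subgradient of mean |w*x - x| with respect to w
    grad = mean(((e > 0) - (e < 0)) * x for e, x in zip(errors, batch))
    state["w"] -= 0.1 * grad
    return loss


def validation_step(batch):
    calls["val"] += 1
    return mean(abs(state["w"] * x - x) for x in batch)


train_batches = [[1.0, 2.0], [1.0, 2.0]]
val_batches = [[1.5, 0.5]]
history = fit_loop(training_step, validation_step, train_batches, val_batches)
print(history[0]["train_loss"], history[-1]["train_loss"])
```

The point of the wrapper is that only the per-batch logic (`training_step`, `validation_step`, `configure_optimizers`) has to be written; the epoch loop, sanity checks, device handling and checkpointing come from the `Trainer`.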
Am I missing something here? Perhaps the aim is to keep the tutorial focussed only on pipeline preparation rather than on the other routines PET implements?