
Cross Zamirski Model Training #26

Open
MattsonCam wants to merge 8 commits into wgan_gp_cross_zamirski from wgan_gp_cross_zamirski_review


Conversation

@MattsonCam

Member

This PR includes the Cross-Zamirski model and the structure needed for training. It also includes per-batch logging and removes irrelevant code. This code may change in the future to allow for training on Alpine because of the CUDA memory constraint. As a result of that constraint, the batch size has been reduced.

Cameron Mattson added 8 commits June 8, 2026 13:21
wasserstein GAN GP models
Wrap the discriminator-step generator forward in torch.no_grad()
and remove the now-redundant detach on the fake samples passed to
the critic loss.

This preserves the two-step WGAN-GP training behavior while avoiding
construction of an unnecessary generator autograd graph during the
critic update, reducing memory and compute overhead.
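
The effect described in this commit message can be checked with a minimal sketch. The two `nn.Linear` modules below are hypothetical stand-ins for the real generator and critic, not the models in this PR; the point is only that a forward pass under `torch.no_grad()` yields a tensor with no autograd graph, so the critic update never backpropagates into the generator and an extra `detach()` would be a no-op.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the real generator and critic (assumption).
generator = nn.Linear(4, 4)
critic = nn.Linear(4, 1)
inputs = torch.randn(8, 4)

# Generator forward under no_grad: no autograd graph is recorded,
# so the result is already detached from the generator parameters.
with torch.no_grad():
    fake = generator(inputs)

fake_has_no_graph = (not fake.requires_grad) and fake.grad_fn is None

# The critic forward runs outside no_grad, so its parameters still
# receive gradients from the backward pass.
critic_loss = critic(fake).mean()
critic_loss.backward()

critic_got_grads = all(p.grad is not None for p in critic.parameters())
generator_got_grads = any(p.grad is not None for p in generator.parameters())
```

Here `fake_has_no_graph` and `critic_got_grads` are true while `generator_got_grads` is false, which is the behavior the commit relies on.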

@MattsonCam MattsonCam requested a review from wli51 June 8, 2026 21:14

@wli51 wli51 left a comment


LGTM! Maybe the trainer should support different stepping frequencies for the discriminator and the generator?

Comment thread trainers/WGANGPTrainer.py
Comment on lines +90 to +122
with torch.no_grad():
    fake_targets_for_discriminator = self.image_postprocessor(
        self.generator(inputs)
    )
discriminator_outputs = self.discriminator_loss(
    critic=self.discriminator,
    real_samples=targets,
    fake_samples=fake_targets_for_discriminator,
)
discriminator_loss, discriminator_components = self._detach_components(
    discriminator_outputs
)

self.discriminator_optimizer.zero_grad()
discriminator_loss.backward()
self.discriminator_optimizer.step()

generated_predictions = self.image_postprocessor(self.generator(inputs))
fake_classification_outputs = self.discriminator(generated_predictions)
generator_outputs = self.generator_loss(
    fake_classification_outputs=fake_classification_outputs,
    generated_predictions=generated_predictions,
    targets=targets,
    epoch=epoch,
    loss_mask=batch_data.get("loss_mask"),
)
generator_loss, generator_components = self._detach_components(
    generator_outputs
)

self.generator_optimizer.zero_grad()
generator_loss.backward()
self.generator_optimizer.step()


I don't remember the Cross-Zamirski trainer implementation that well. Is an equal update frequency for both networks what they decided on? I believe that in classical WGAN training the discriminator (critic) gets updated more frequently than the generator for stability.

Member Author


I think you're right, I will update this. I know it degrades the loss contribution by normalizing by the epoch.
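
One way the differential stepping frequency discussed in this thread could look is sketched below. It is an illustration under stated assumptions, not the trainer's actual API: `run_epoch`, `critic_step`, `generator_step`, and `n_critic` are hypothetical names, and the two step callables stand in for the discriminator and generator update blocks shown in the diff above. The critic is updated on every batch and the generator only on every `n_critic`-th batch.

```python
from typing import Any, Callable, Iterable, Tuple


def run_epoch(
    batches: Iterable[Any],
    critic_step: Callable[[Any], None],
    generator_step: Callable[[Any], None],
    n_critic: int = 5,
) -> Tuple[int, int]:
    """Update the critic every batch and the generator every n_critic-th batch.

    Returns the number of critic and generator updates performed.
    """
    if n_critic < 1:
        raise ValueError("n_critic must be at least 1")

    critic_updates = 0
    generator_updates = 0
    for batch_index, batch in enumerate(batches):
        # Critic (discriminator) update happens on every batch.
        critic_step(batch)
        critic_updates += 1

        # Generator update happens once per n_critic critic updates.
        if (batch_index + 1) % n_critic == 0:
            generator_step(batch)
            generator_updates += 1
    return critic_updates, generator_updates


# Usage with no-op steps: 10 batches and n_critic=5 give 10 critic
# updates and 2 generator updates.
counts = run_epoch(range(10), lambda b: None, lambda b: None, n_critic=5)
```

Setting `n_critic=1` reproduces the current one-to-one behavior, so a parameter like this could be added without changing existing runs. A value of 5 is a commonly used default for WGAN-GP, but the right setting for this model is an open question.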
