diff --git a/contrastive-pretraining/mr_rate/mr_rate/mr_rate.py b/contrastive-pretraining/mr_rate/mr_rate/mr_rate.py index fe7b434..c0f347b 100644 --- a/contrastive-pretraining/mr_rate/mr_rate/mr_rate.py +++ b/contrastive-pretraining/mr_rate/mr_rate/mr_rate.py @@ -275,9 +275,11 @@ def _encode_visual_tokens(self, image, real_volume_mask, vis_proj_layer, merged = vis_proj_layer(enc) elif self.fusion_mode == "mid_cnn": - flat_img = rearrange(image, 'b r c d h w -> (b r) c d h w') - cnn_features = self.run_checkpoint(self.visual_transformer.forward_cnn, flat_img) - cnn_features = rearrange(cnn_features, '(b r) t h w d -> b r t h w d', r=r) + cnn_list = [] + for i in range(r): + feat_i = self.run_checkpoint(self.visual_transformer.forward_cnn, image[:, i]) + cnn_list.append(feat_i) + cnn_features = torch.stack(cnn_list, dim=1) m = real_volume_mask.view(b, r, 1, 1, 1, 1).to(cnn_features.dtype) merged = (cnn_features * m).sum(1) / m.sum(1).clamp(min=1.0) enc = self.run_checkpoint(self.visual_transformer.forward_transformer, merged) @@ -347,7 +349,7 @@ def load_state_dict(self, *args, **kwargs): def load(self, path): path = Path(path) assert path.exists() - pt = torch.load(str(path)) + pt = torch.load(str(path), map_location="cpu") clean_state = {} for k, v in pt.items(): if k.startswith("module."):