diff --git a/tensorpack/coupled.py b/tensorpack/coupled.py index fab8342..52648c4 100644 --- a/tensorpack/coupled.py +++ b/tensorpack/coupled.py @@ -253,7 +253,7 @@ def normalize_factors(self, method="max"): for mmode in self.modes: sol = self.x["_" + mmode] if method == "max": # the one with maximum absolute value - norm_vec = np.array([max(sol[:, ii].min(), sol[:, ii].max(), key=abs) for ii in range(sol.shape[1])]) + norm_vec = abs(sol).max(axis=0).to_numpy() elif method == "norm": norm_vec = norm(sol, axis=0) else: