From e93fe087325c1435aaaea59968ab24e033d7ab3b Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 25 May 2026 10:12:27 +0000 Subject: [PATCH 1/3] vectorize promo logit computation --- maia3/models.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/maia3/models.py b/maia3/models.py index 3b23ed5..2ea775b 100644 --- a/maia3/models.py +++ b/maia3/models.py @@ -333,6 +333,8 @@ def __init__(self, cfg): self.promo_bias_proj = nn.Linear(cfg.head_hid_dim, 4, bias=False) # 4 promotion types: q, r, b, n + self.rank7_indices = [chess.square(file, 6) for file in range(8)] # squares 48-55 + self.rank8_indices = [chess.square(file, 7) for file in range(8)] # squares 56-63 def interpolate_elo(self, elos): @@ -375,22 +377,12 @@ def forward(self, tokens, self_elos, oppo_elos): scores_base = torch.einsum("bid,bjd->bij", sq_from, sq_to) / math.sqrt(self.cfg.head_hid_dim) scores_flat = scores_base.reshape(x.size(0), 64 * 64) # (B, 4096) - rank7_indices = [chess.square(file, 6) for file in range(8)] # squares 48-55 - rank8_indices = [chess.square(file, 7) for file in range(8)] # squares 56-63 - - rank8_features = sq_to[:, rank8_indices, :] # (B, 8, head_hid_dim) + rank8_features = sq_to[:, self.rank8_indices, :] # (B, 8, head_hid_dim) promo_biases = self.promo_bias_proj(rank8_features) * math.sqrt(self.cfg.head_hid_dim) # (B, 8, 4) for q,r,b,n - promotion_logits = [] - for from_file in range(8): # source file (a-h) - from_sq = rank7_indices[from_file] - for to_file in range(8): # target file (a-h) - to_sq = rank8_indices[to_file] - base_score = scores_base[:, from_sq, to_sq] # (B,) - for piece_idx in range(4): # q=0, r=1, b=2, n=3 - bias = promo_biases[:, to_file, piece_idx] # (B,) - promotion_logits.append((base_score + bias).unsqueeze(1)) - promotion_logits = torch.cat(promotion_logits, dim=1) # (B, 256) + base = scores_base[:, self.rank7_indices][:, :, self.rank8_indices] # (B, 8, 8) + promotion_logits = (base.unsqueeze(-1) + promo_biases.unsqueeze(1)).reshape(x.size(0), 256) + logits_move = torch.cat([scores_flat, promotion_logits], dim=1) # (B, 4352) x = self.last_ln(x.mean(dim=1)) From ed4a0ad0bfe1421ad4c8bccbcbf936fdc99584bb Mon Sep 17 00:00:00 2001 From: Marco Date: Mon, 25 May 2026 19:17:30 +0900 Subject: [PATCH 2/3] add shape information --- maia3/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/maia3/models.py b/maia3/models.py index 2ea775b..23e62c8 100644 --- a/maia3/models.py +++ b/maia3/models.py @@ -381,7 +381,7 @@ def forward(self, tokens, self_elos, oppo_elos): promo_biases = self.promo_bias_proj(rank8_features) * math.sqrt(self.cfg.head_hid_dim) # (B, 8, 4) for q,r,b,n base = scores_base[:, self.rank7_indices][:, :, self.rank8_indices] # (B, 8, 8) - promotion_logits = (base.unsqueeze(-1) + promo_biases.unsqueeze(1)).reshape(x.size(0), 256) + promotion_logits = (base.unsqueeze(-1) + promo_biases.unsqueeze(1)).reshape(x.size(0), 256) # (B, 256) logits_move = torch.cat([scores_flat, promotion_logits], dim=1) # (B, 4352) @@ -389,4 +389,4 @@ def forward(self, tokens, self_elos, oppo_elos): logits_value = self.fc_value(F.relu(self.fc_value_hid(x))) # (B, 3) logits_ponder = self.fc_ponder(F.relu(self.fc_ponder_hid(x))) # (B, 1) - return logits_move, logits_value, logits_ponder.squeeze(1) # (B, 4352), (B, 3), (B,) \ No newline at end of file + return logits_move, logits_value, logits_ponder.squeeze(1) # (B, 4352), (B, 3), (B,) From f24d5531f8acfeddbb0ae879268bda0864b8d0af Mon Sep 17 00:00:00 2001 From: Marco Date: Mon, 25 May 2026 19:18:34 +0900 Subject: [PATCH 3/3] remove trailing newline