Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 112 additions & 98 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,85 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather):
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
return self.apply_ffn_activation(layer_w0, layer_w1)

def get_gmm_for_local_experts(x, routing, route_metadata):
"""Return `gmm` with its routing keyword arguments already bound.

The result is a `functools.partial` of `gmm` with `group_sizes`,
`expert_assignments` and `group_offset` filled in, so callers only supply
the remaining arguments.

Args:
x: Input activations. Only `x.shape[0]` is read here, to detect the
truncated-buffer case.
routing: Routing result. The fields `sorted_selected_experts`,
`selected_experts`, `group_sizes` and `local_group_sizes` are read.
route_metadata: Routing metadata. Only `expert_shard_id` is read.

Returns:
A `functools.partial` wrapping `gmm`. In the truncated-buffer case it binds
`routing.local_group_sizes` with `group_offset=0`; otherwise it binds
`routing.group_sizes` with `group_offset` set to
`expert_shard_id * num_experts_per_shard` when both `use_ragged_sort` and
`use_ring_of_experts` are enabled, and to 0 in every other configuration.
"""
num_ep = self.get_expert_parallelism_size()
num_experts_per_shard = self.config.num_experts // num_ep
# Truncated buffer: ring-of-experts is enabled and `x` has fewer rows than
# there are sorted expert assignments.
use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0]
if use_truncated_buffer:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can merge line 1647 and 1648?

# Truncated-buffer branch: bind `routing.local_group_sizes` and a zero offset.
local_group_sizes = routing.local_group_sizes
return functools.partial(
gmm,
group_sizes=local_group_sizes,
expert_assignments=routing.selected_experts,
group_offset=0,
)
# Full-buffer branch: a non-zero offset is used only when ragged sort and
# ring-of-experts are both on. NOTE(review): the product below is presumably
# the index of the first expert owned by this shard -- confirm against `gmm`.
if self.config.use_ragged_sort and self.config.use_ring_of_experts:
experts_start = route_metadata.expert_shard_id * num_experts_per_shard
else:
experts_start = 0
return functools.partial(
gmm,
group_sizes=routing.group_sizes,
expert_assignments=routing.selected_experts,
group_offset=experts_start,
)

def unsort_output_with_ra2a(intermediate_output, routing, route_metadata, output_shape, is_batch_sharded_by_expert):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function includes both unsort and ra2a.. Maybe we should name it unsort_output_and_ra2a?

"""Unsort tokens and return them to original shards using ragged all-to-all.

Two paths are taken depending on `is_batch_sharded_by_expert`:

* True: `intermediate_output` is first un-permuted locally with the inverse
(argsort) of `route_metadata.local_sorted_indices`, then exchanged with
`jax.lax.ragged_all_to_all` using parameters derived from the transposed
`route_metadata.all_shards_group_sizes`.
* False: `intermediate_output` is exchanged directly, using parameters derived
from `route_metadata.reshaped_group_sizes` with `is_batch_sharded=False`.

Args:
intermediate_output: Array to send back; the operand of the all-to-all.
routing: Routing result. Only `group_sizes` is read, and only on the
ragged-sort path.
route_metadata: Routing metadata. The fields `local_sorted_indices`,
`all_shards_group_sizes`, `reshaped_group_sizes` and `expert_shard_id`
are read.
output_shape: Passed as the second positional argument of
`jax.lax.ragged_all_to_all` (the `output` operand in the JAX API).
is_batch_sharded_by_expert: Selects between the two paths above.

Returns:
The result of `jax.lax.ragged_all_to_all` over `self._expert_parallelism_name`.
"""
if is_batch_sharded_by_expert:
# locally unpermute back to the original order
if self.config.use_ragged_sort:
# Mirror the ragged-prefix gather used in `local_permute`. The
# un-permute can use the same valid-prefix length because the
# routed token count is identical for forward and backward.
valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32)
local_output = a2a_ragged_unsort(
intermediate_output,
jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable
valid_end,
)
else:
# Non-ragged path: apply the inverse permutation to the whole buffer.
local_output = _sort_activations(
intermediate_output,
jnp.argsort(route_metadata.local_sorted_indices),
self.config.use_custom_sort_vjp,
)

# Offsets and sizes for the exchange are computed from the transposed
# group-size table; `is_batch_sharded` is left at its default here.
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
jnp.transpose(route_metadata.all_shards_group_sizes),
route_metadata.expert_shard_id,
self.get_expert_parallelism_size(),
)
return jax.lax.ragged_all_to_all(
local_output,
output_shape,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=self._expert_parallelism_name,
)

# If batch is replicated across EP shards then each shard should send
# 0..local_shard_size data to the other shards and receive the
# local_shard data from all of the other shards using ragged_all_to_all.
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
route_metadata.reshaped_group_sizes,
route_metadata.expert_shard_id,
self.get_expert_parallelism_size(),
is_batch_sharded=False,
)
# No local un-permute on this path: the buffer is exchanged as-is.
return jax.lax.ragged_all_to_all(
intermediate_output,
output_shape,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=self._expert_parallelism_name,
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why does ragged_all_to_all show up twice: once inside a function and once outside it?


@functools.partial(
jax.shard_map,
mesh=self.mesh,
Expand All @@ -1663,36 +1742,16 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather):
),
check_vma=self.config.check_vma,
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs):
def sparse_matmul_route_and_compute(
x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs
):
batch_size, sequence_length, _ = x.shape
x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs, input_ids=sharded_input_ids)

if self.config.mlp_bias:
w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias)

num_ep = self.get_expert_parallelism_size()
num_experts_per_shard = self.config.num_experts // num_ep

use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0]
if use_truncated_buffer:
local_group_sizes = routing.local_group_sizes
gmm_fn = functools.partial(
gmm,
group_sizes=local_group_sizes,
expert_assignments=routing.selected_experts,
group_offset=0,
)
else:
if self.config.use_ragged_sort and self.config.use_ring_of_experts:
experts_start = route_metadata.expert_shard_id * num_experts_per_shard
else:
experts_start = 0
gmm_fn = functools.partial(
gmm,
group_sizes=routing.group_sizes,
expert_assignments=routing.selected_experts,
group_offset=experts_start,
)
gmm_fn = get_gmm_for_local_experts(x, routing, route_metadata)
intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather)

wo_gather_axes, wo_tile_size = get_wo_gmm_params()
Expand Down Expand Up @@ -1727,83 +1786,38 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s
output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size())
)
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
return output, routing.lb_loss, routing.bias_updates

if self.get_expert_parallelism_size() > 1:
original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim:
raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
output_shape = jax.lax.empty(
(
original_inputs_first_dim,
self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
),
dtype=intermediate_output.dtype,
)

else:
if self.get_expert_parallelism_size() > 1:
original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim:
raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
output_shape = jax.lax.empty(
(
original_inputs_first_dim,
self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
),
dtype=intermediate_output.dtype,
)

if is_batch_sharded_by_expert:
# locally unpermute back to the original order
if self.config.use_ragged_sort:
# Mirror the ragged-prefix gather used in `local_permute`. The
# un-permute can use the same valid-prefix length because the
# routed token count is identical for forward and backward.
valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32)
local_output = a2a_ragged_unsort(
intermediate_output,
jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable
valid_end,
)
else:
local_output = _sort_activations(
intermediate_output,
jnp.argsort(route_metadata.local_sorted_indices),
self.config.use_custom_sort_vjp,
)

input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
jnp.transpose(route_metadata.all_shards_group_sizes),
route_metadata.expert_shard_id,
self.get_expert_parallelism_size(),
)
intermediate_output = jax.lax.ragged_all_to_all(
local_output,
output_shape,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=self._expert_parallelism_name,
)
else:
# If batch is replicated across EP shards then each shard should send
# 0..local_shard_size data to the other shards and receive the
# local_shard data from all of the other shards using ragged_all_to_all.
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
route_metadata.reshaped_group_sizes,
route_metadata.expert_shard_id,
self.get_expert_parallelism_size(),
is_batch_sharded=False,
)
intermediate_output = jax.lax.ragged_all_to_all(
intermediate_output,
output_shape,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=self._expert_parallelism_name,
)

output = self.unpermute(
intermediate_output = unsort_output_with_ra2a(
intermediate_output,
routing.sorted_selected_experts,
routing.weights,
batch_size=batch_size,
sequence_length=sequence_length,
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
group_sizes=routing.group_sizes,
routing,
route_metadata,
output_shape,
is_batch_sharded_by_expert,
)

output = self.unpermute(
intermediate_output,
routing.sorted_selected_experts,
routing.weights,
batch_size=batch_size,
sequence_length=sequence_length,
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
group_sizes=routing.group_sizes,
)

return output, routing.lb_loss, routing.bias_updates

if self.config.moe_fsdp_use_two_stage_all_gather:
Expand Down Expand Up @@ -1851,7 +1865,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s
if wo_bias is not None:
wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec)

return wrapper(
return sparse_matmul_route_and_compute(
inputs,
gate_logits,
pre_bias_logits,
Expand Down
Loading