-
Notifications
You must be signed in to change notification settings - Fork 537
Refactor moe.py: gmm and a2a unsort #4170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1640,6 +1640,85 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather): | |
| layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") | ||
| return self.apply_ffn_activation(layer_w0, layer_w1) | ||
|
|
||
| def get_gmm_for_local_experts(x, routing, route_metadata): | ||
| """Return a partial GMM function with preconfigured routing params.""" | ||
| num_ep = self.get_expert_parallelism_size() | ||
| num_experts_per_shard = self.config.num_experts // num_ep | ||
| use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0] | ||
| if use_truncated_buffer: | ||
| local_group_sizes = routing.local_group_sizes | ||
| return functools.partial( | ||
| gmm, | ||
| group_sizes=local_group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=0, | ||
| ) | ||
| if self.config.use_ragged_sort and self.config.use_ring_of_experts: | ||
| experts_start = route_metadata.expert_shard_id * num_experts_per_shard | ||
| else: | ||
| experts_start = 0 | ||
| return functools.partial( | ||
| gmm, | ||
| group_sizes=routing.group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=experts_start, | ||
| ) | ||
|
|
||
| def unsort_output_with_ra2a(intermediate_output, routing, route_metadata, output_shape, is_batch_sharded_by_expert): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function includes both the unsort and the ra2a (ragged all-to-all). Maybe we should name it |
||
| """Unsort tokens and return them to original shards using ragged all-to-all.""" | ||
| if is_batch_sharded_by_expert: | ||
| # locally unpermute back to the original order | ||
| if self.config.use_ragged_sort: | ||
| # Mirror the ragged-prefix gather used in `local_permute`. The | ||
| # un-permute can use the same valid-prefix length because the | ||
| # routed token count is identical for forward and backward. | ||
| valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32) | ||
| local_output = a2a_ragged_unsort( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable | ||
| valid_end, | ||
| ) | ||
| else: | ||
| local_output = _sort_activations( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), | ||
| self.config.use_custom_sort_vjp, | ||
| ) | ||
|
|
||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| jnp.transpose(route_metadata.all_shards_group_sizes), | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| ) | ||
| return jax.lax.ragged_all_to_all( | ||
| local_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
|
|
||
| # If batch is replicated across EP shards then each shard should send | ||
| # 0..local_shard_size data to the other shards and receive the | ||
| # local_shard data from all of the other shards using ragged_all_to_all. | ||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| route_metadata.reshaped_group_sizes, | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| is_batch_sharded=False, | ||
| ) | ||
| return jax.lax.ragged_all_to_all( | ||
| intermediate_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, why does ragged_all_to_all show up twice, once inside a function and once outside the function? |
||
|
|
||
| @functools.partial( | ||
| jax.shard_map, | ||
| mesh=self.mesh, | ||
|
|
@@ -1663,36 +1742,16 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather): | |
| ), | ||
| check_vma=self.config.check_vma, | ||
| ) | ||
| def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs): | ||
| def sparse_matmul_route_and_compute( | ||
| x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs | ||
| ): | ||
| batch_size, sequence_length, _ = x.shape | ||
| x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs, input_ids=sharded_input_ids) | ||
|
|
||
| if self.config.mlp_bias: | ||
| w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias) | ||
|
|
||
| num_ep = self.get_expert_parallelism_size() | ||
| num_experts_per_shard = self.config.num_experts // num_ep | ||
|
|
||
| use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0] | ||
| if use_truncated_buffer: | ||
| local_group_sizes = routing.local_group_sizes | ||
| gmm_fn = functools.partial( | ||
| gmm, | ||
| group_sizes=local_group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=0, | ||
| ) | ||
| else: | ||
| if self.config.use_ragged_sort and self.config.use_ring_of_experts: | ||
| experts_start = route_metadata.expert_shard_id * num_experts_per_shard | ||
| else: | ||
| experts_start = 0 | ||
| gmm_fn = functools.partial( | ||
| gmm, | ||
| group_sizes=routing.group_sizes, | ||
| expert_assignments=routing.selected_experts, | ||
| group_offset=experts_start, | ||
| ) | ||
| gmm_fn = get_gmm_for_local_experts(x, routing, route_metadata) | ||
| intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather) | ||
|
|
||
| wo_gather_axes, wo_tile_size = get_wo_gmm_params() | ||
|
|
@@ -1727,83 +1786,38 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s | |
| output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size()) | ||
| ) | ||
| output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True) | ||
| return output, routing.lb_loss, routing.bias_updates | ||
|
|
||
| if self.get_expert_parallelism_size() > 1: | ||
| original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok | ||
| if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim: | ||
| raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!") | ||
| output_shape = jax.lax.empty( | ||
| ( | ||
| original_inputs_first_dim, | ||
| self.moe_expert_input_dim // self.get_tensor_parallelism_size(), | ||
| ), | ||
| dtype=intermediate_output.dtype, | ||
| ) | ||
|
|
||
| else: | ||
| if self.get_expert_parallelism_size() > 1: | ||
| original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok | ||
| if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim: | ||
| raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!") | ||
| output_shape = jax.lax.empty( | ||
| ( | ||
| original_inputs_first_dim, | ||
| self.moe_expert_input_dim // self.get_tensor_parallelism_size(), | ||
| ), | ||
| dtype=intermediate_output.dtype, | ||
| ) | ||
|
|
||
| if is_batch_sharded_by_expert: | ||
| # locally unpermute back to the original order | ||
| if self.config.use_ragged_sort: | ||
| # Mirror the ragged-prefix gather used in `local_permute`. The | ||
| # un-permute can use the same valid-prefix length because the | ||
| # routed token count is identical for forward and backward. | ||
| valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32) | ||
| local_output = a2a_ragged_unsort( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable | ||
| valid_end, | ||
| ) | ||
| else: | ||
| local_output = _sort_activations( | ||
| intermediate_output, | ||
| jnp.argsort(route_metadata.local_sorted_indices), | ||
| self.config.use_custom_sort_vjp, | ||
| ) | ||
|
|
||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| jnp.transpose(route_metadata.all_shards_group_sizes), | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| ) | ||
| intermediate_output = jax.lax.ragged_all_to_all( | ||
| local_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
| else: | ||
| # If batch is replicated across EP shards then each shard should send | ||
| # 0..local_shard_size data to the other shards and receive the | ||
| # local_shard data from all of the other shards using ragged_all_to_all. | ||
| input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( | ||
| route_metadata.reshaped_group_sizes, | ||
| route_metadata.expert_shard_id, | ||
| self.get_expert_parallelism_size(), | ||
| is_batch_sharded=False, | ||
| ) | ||
| intermediate_output = jax.lax.ragged_all_to_all( | ||
| intermediate_output, | ||
| output_shape, | ||
| input_offsets, | ||
| send_sizes, | ||
| output_offsets, | ||
| recv_sizes, | ||
| axis_name=self._expert_parallelism_name, | ||
| ) | ||
|
|
||
| output = self.unpermute( | ||
| intermediate_output = unsort_output_with_ra2a( | ||
| intermediate_output, | ||
| routing.sorted_selected_experts, | ||
| routing.weights, | ||
| batch_size=batch_size, | ||
| sequence_length=sequence_length, | ||
| use_custom_sort_vjp=self.config.use_custom_sort_vjp, | ||
| group_sizes=routing.group_sizes, | ||
| routing, | ||
| route_metadata, | ||
| output_shape, | ||
| is_batch_sharded_by_expert, | ||
| ) | ||
|
|
||
| output = self.unpermute( | ||
| intermediate_output, | ||
| routing.sorted_selected_experts, | ||
| routing.weights, | ||
| batch_size=batch_size, | ||
| sequence_length=sequence_length, | ||
| use_custom_sort_vjp=self.config.use_custom_sort_vjp, | ||
| group_sizes=routing.group_sizes, | ||
| ) | ||
|
|
||
| return output, routing.lb_loss, routing.bias_updates | ||
|
|
||
| if self.config.moe_fsdp_use_two_stage_all_gather: | ||
|
|
@@ -1851,7 +1865,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s | |
| if wo_bias is not None: | ||
| wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec) | ||
|
|
||
| return wrapper( | ||
| return sparse_matmul_route_and_compute( | ||
| inputs, | ||
| gate_logits, | ||
| pre_bias_logits, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can merge lines 1647 and 1648?