Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/expert_parallel/expert_parallel.py 79.3% 269-270,275-279,281,331,346,360,367-369,402,711,717-722,727,730-731,733,736,761,797,859,874,946-953,1124,1335
hyper_parallel/trainer/config.py 100%  
hyper_parallel/trainer/parallel_dims.py 100%  
hyper_parallel/core/expert_parallel/expert_parallel.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
            iep_size=device_mesh.size(),
            outer_rank=0,
            inner_rank=device_mesh.get_local_rank(),
        )
    if ndim != 2:
        raise ValueError(
            "DeredundencyTokenDispatcher expects a 1-D EP mesh or a 2-D "
            f"[oep, iep] EP mesh, but got ndim={ndim}."
        )

    mesh_dim_names = getattr(device_mesh, "mesh_dim_names", None) or ()
    oep_dim = mesh_dim_names.index("oep") if "oep" in mesh_dim_names else 0
    iep_dim = mesh_dim_names.index("iep") if "iep" in mesh_dim_names else 1
    if oep_dim == iep_dim:
        raise ValueError("DeredundencyTokenDispatcher requires distinct oep and iep mesh dimensions.")

    return _DeredundencyMeshInfo(
        oep_group=device_mesh.get_group(oep_dim),
        iep_group=device_mesh.get_group(iep_dim),
        oep_size=device_mesh.size(oep_dim),
        iep_size=device_mesh.size(iep_dim),
327
328
329
330
331
332
333
334
335
    block_counts = counts_by_destination.view(-1)
    token_counts_by_destination_expert = selected_counts.sum(dim=0).contiguous().view(-1)
    total = int(block_counts.sum())
    if total == 0:
        return block_counts.new_zeros(0, dtype=block_counts.dtype).long(), token_counts_by_destination_expert

    block_starts = offsets_by_destination.view(-1).repeat_interleave(block_counts)
    block_offsets = block_counts.cumsum(0) - block_counts
    block_offsets_per_token = block_offsets.repeat_interleave(block_counts)
342
343
344
345
346
347
348
349
350
    """Scale routed expert outputs by optional router coefficients."""
    if router_coeff is None:
        return tokens
    if router_coeff.shape[0] != tokens.shape[0]:
        raise ValueError(
            "router_coeff length must match routed token count, got "
            f"{router_coeff.shape[0]} and {tokens.shape[0]}."
        )
    coeff = router_coeff
356
357
358
359
360
361
362
363
364
def _scatter_add_first_dim(src, indices, output_shape):
    """Scatter-add rows of ``src`` into a zero tensor along dim 0."""
    result = src.new_zeros(*output_shape)
    if len(src.shape) == 1:
        scatter_indices = indices
    else:
        scatter_indices = indices.reshape((-1,) + (1,) * (len(src.shape) - 1)).expand(
            -1, *src.shape[1:],
        )
363
364
365
366
367
368
369
370
371
372
            -1, *src.shape[1:],
        )
    if hasattr(result, "scatter_add"):
        return result.scatter_add(0, scatter_indices, src)
    if hasattr(result, "index_add"):
        return result.index_add(0, indices, src)
    raise RuntimeError(
        "DeredundencyTokenDispatcher.combine requires tensor scatter_add or "
        "index_add support for exdispatch_idx accumulation."
    )
398
399
400
401
402
403
404
405
406
            )
            if self._ctx.oep_size == 1:
                self._combined = combine_whiteboard
            else:
                self._combined = platform.differentiable_reduce_scatter(
                    combine_whiteboard,
                    self._ctx.oep_size,
                    0,
                    "sum",
707
708
709
710
711
712
713
714
715
        if mesh_info.oep_size == 1:
            gathered_counts = num_tokens_per_expert.view(1, num_tokens_per_expert.shape[0])
            return gathered_counts, routed_input, router_coeff

        gathered_counts, handle = platform.all_gather_single(
            num_tokens_per_expert,
            output_shape=[mesh_info.oep_size * num_tokens_per_expert.shape[0]],
            group=mesh_info.oep_group,
            async_op=True,
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
            output_shape=[mesh_info.oep_size * num_tokens_per_expert.shape[0]],
            group=mesh_info.oep_group,
            async_op=True,
        )
        if handle is not None:
            handle.wait()
        gathered_counts = gathered_counts.view(mesh_info.oep_size, num_tokens_per_expert.shape[0])
        source_token_totals = gathered_counts.sum(dim=1).tolist()
        if any(total != routed_input.shape[0] for total in source_token_totals):
            raise ValueError(
                "DeredundencyTokenDispatcher requires equal routed token "
                "counts within each OEP group because the shared token view "
                f"uses all-gather, got totals {source_token_totals}."
            )
        gathered_routed = platform.differentiable_all_gather_concat(
            routed_input, mesh_info.oep_group, mesh_info.oep_size, 0,
        )
        if router_coeff is None:
            gathered_router_coeff = None
        else:
            gathered_router_coeff = platform.differentiable_all_gather_concat(
                router_coeff, mesh_info.oep_group, mesh_info.oep_size, 0,
            )
        return gathered_counts, gathered_routed, gathered_router_coeff

    @staticmethod
    def dispatch(module: Module, inputs: tuple, device_mesh: DeviceMesh) -> tuple:
        """Dispatch tokens using OEP all-gather and IEP all-to-all.
757
758
759
760
761
762
763
764
765
        del module
        routed_input, num_tokens_per_expert = inputs[0], inputs[1]
        router_coeff = inputs[2] if len(inputs) > 2 else None
        if router_coeff is not None and router_coeff.shape[0] != routed_input.shape[0]:
            raise ValueError(
                "router_coeff length must match routed_input token count, got "
                f"{router_coeff.shape[0]} and {routed_input.shape[0]}."
            )
        mesh_info = _get_deredundency_mesh_info(device_mesh)
793
794
795
796
797
798
799
800
801
            group=mesh_info.iep_group,
            async_op=True,
        )
        if handle is not None:
            handle.wait()
        iep_output_splits = iep_counts_out.view(mesh_info.iep_size, num_local_experts).sum(dim=1).tolist()

        outer_routed_input = gathered_routed[dispatch_indices]
        outer_router_coeff = (
855
856
857
858
859
860
861
862
863
        )
        if ctx.oep_size == 1:
            return combine_whiteboard

        return platform.differentiable_reduce_scatter(
            combine_whiteboard,
            ctx.oep_size,
            0,
            "sum",
870
871
872
873
874
875
876
877
        ctx: DeredundencyDispatchContext,
    ) -> None:
        """Validate that dispatch context and combine mesh are compatible."""
        if mesh_info.oep_size != ctx.oep_size:
            raise ValueError(
                "DeredundencyTokenDispatcher.combine received a context for "
                f"oep_size={ctx.oep_size}, but the mesh resolves to oep_size={mesh_info.oep_size}."
            )
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956


def _get_flattened_ep_mesh(device_mesh: DeviceMesh) -> DeviceMesh:
    """Return a 1-D EP mesh, flattening a 2-D deredundency mesh if needed."""
    if getattr(device_mesh, "ndim", 1) == 1:
        return device_mesh
    mesh_dim_names = getattr(device_mesh, "mesh_dim_names", None) or ()
    if "ep" in mesh_dim_names or "ep" in device_mesh.get_flatten_mapping():
        return device_mesh["ep"]
    if set(mesh_dim_names) == {"oep", "iep"}:
        return device_mesh.flatten("ep")
    raise ValueError(
        "Deredundency ExpertParallel expects a 1-D EP mesh or a 2-D "
        "[oep, iep] mesh when partitioning expert weights."
    )
1120
1121
1122
1123
1124
1125
1126
1127
1128

    def _partition_mesh(self, device_mesh: DeviceMesh) -> DeviceMesh:
        """Choose the mesh over which expert weights are sharded.

        When the configured token dispatcher is ``"deredundency"``, the mesh is
        passed through ``_get_flattened_ep_mesh`` so weight partitioning sees a
        single EP dimension. Every other dispatcher shards over
        ``device_mesh`` exactly as given.
        """
        if self._token_dispatcher_name != "deredundency":
            return device_mesh
        return _get_flattened_ep_mesh(device_mesh)

    def _partition_fn(
        self, name: str, module: Module, device_mesh: DeviceMesh
1331
1332
1333
1334
1335
1336
1337
1338
1339

        dispatch_mesh = self._dispatch_mesh(device_mesh)

        if self.async_combine:
            handle = self._token_dispatcher.combine_start(
                routed_output, dispatch_mesh, ctx
            )
            # pylint: disable=W0212
            module._ep_combine_handle = handle