Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/expert_parallel/expert_parallel.py 75.8% 269-270,275-279,281,331,346,360,367-369,624,630-635,640,643-644,646,649,674,710,756,776,804-811,925
hyper_parallel/trainer/config.py 100%  
hyper_parallel/trainer/parallel_dims.py 100%  
hyper_parallel/core/expert_parallel/expert_parallel.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
            iep_size=device_mesh.size(),
            outer_rank=0,
            inner_rank=device_mesh.get_local_rank(),
        )
    if ndim != 2:
        raise ValueError(
            "DeredundencyTokenDispatcher expects a 1-D EP mesh or a 2-D "
            f"[oep, iep] EP mesh, but got ndim={ndim}."
        )

    mesh_dim_names = getattr(device_mesh, "mesh_dim_names", None) or ()
    oep_dim = mesh_dim_names.index("oep") if "oep" in mesh_dim_names else 0
    iep_dim = mesh_dim_names.index("iep") if "iep" in mesh_dim_names else 1
    if oep_dim == iep_dim:
        raise ValueError("DeredundencyTokenDispatcher requires distinct oep and iep mesh dimensions.")

    return _DeredundencyMeshInfo(
        oep_group=device_mesh.get_group(oep_dim),
        iep_group=device_mesh.get_group(iep_dim),
        oep_size=device_mesh.size(oep_dim),
        iep_size=device_mesh.size(iep_dim),
327
328
329
330
331
332
333
334
335
    block_counts = counts_by_destination.view(-1)
    token_counts_by_destination_expert = selected_counts.sum(dim=0).contiguous().view(-1)
    total = int(block_counts.sum())
    if total == 0:
        return block_counts.new_zeros(0, dtype=block_counts.dtype).long(), token_counts_by_destination_expert

    block_starts = offsets_by_destination.view(-1).repeat_interleave(block_counts)
    block_offsets = block_counts.cumsum(0) - block_counts
    block_offsets_per_token = block_offsets.repeat_interleave(block_counts)
342
343
344
345
346
347
348
349
350
    """Scale routed expert outputs by optional router coefficients."""
    if router_coeff is None:
        return tokens
    if router_coeff.shape[0] != tokens.shape[0]:
        raise ValueError(
            "router_coeff length must match routed token count, got "
            f"{router_coeff.shape[0]} and {tokens.shape[0]}."
        )
    coeff = router_coeff
356
357
358
359
360
361
362
363
364
def _scatter_add_first_dim(src, indices, output_shape):
    """Scatter-add rows of ``src`` into a zero tensor along dim 0."""
    result = src.new_zeros(*output_shape)
    if len(src.shape) == 1:
        scatter_indices = indices
    else:
        scatter_indices = indices.reshape((-1,) + (1,) * (len(src.shape) - 1)).expand(
            -1, *src.shape[1:],
        )
363
364
365
366
367
368
369
370
371
372
            -1, *src.shape[1:],
        )
    if hasattr(result, "scatter_add"):
        return result.scatter_add(0, scatter_indices, src)
    if hasattr(result, "index_add"):
        return result.index_add(0, indices, src)
    raise RuntimeError(
        "DeredundencyTokenDispatcher.combine requires tensor scatter_add or "
        "index_add support for exdispatch_idx accumulation."
    )
620
621
622
623
624
625
626
627
628
        if mesh_info.oep_size == 1:
            gathered_counts = num_tokens_per_expert.view(1, num_tokens_per_expert.shape[0])
            return gathered_counts, routed_input, router_coeff

        gathered_counts, handle = platform.all_gather_single(
            num_tokens_per_expert,
            output_shape=[mesh_info.oep_size * num_tokens_per_expert.shape[0]],
            group=mesh_info.oep_group,
            async_op=True,
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
            output_shape=[mesh_info.oep_size * num_tokens_per_expert.shape[0]],
            group=mesh_info.oep_group,
            async_op=True,
        )
        if handle is not None:
            handle.wait()
        gathered_counts = gathered_counts.view(mesh_info.oep_size, num_tokens_per_expert.shape[0])
        source_token_totals = gathered_counts.sum(dim=1).tolist()
        if any(total != routed_input.shape[0] for total in source_token_totals):
            raise ValueError(
                "DeredundencyTokenDispatcher requires equal routed token "
                "counts within each OEP group because the shared token view "
                f"uses all-gather, got totals {source_token_totals}."
            )
        gathered_routed = platform.differentiable_all_gather_concat(
            routed_input, mesh_info.oep_group, mesh_info.oep_size, 0,
        )
        if router_coeff is None:
            gathered_router_coeff = None
        else:
            gathered_router_coeff = platform.differentiable_all_gather_concat(
                router_coeff, mesh_info.oep_group, mesh_info.oep_size, 0,
            )
        return gathered_counts, gathered_routed, gathered_router_coeff

    @staticmethod
    def dispatch(module: Module, inputs: tuple, device_mesh: DeviceMesh) -> tuple:
        """Dispatch tokens using OEP all-gather and IEP all-to-all.
670
671
672
673
674
675
676
677
678
        del module
        routed_input, num_tokens_per_expert = inputs[0], inputs[1]
        router_coeff = inputs[2] if len(inputs) > 2 else None
        if router_coeff is not None and router_coeff.shape[0] != routed_input.shape[0]:
            raise ValueError(
                "router_coeff length must match routed_input token count, got "
                f"{router_coeff.shape[0]} and {routed_input.shape[0]}."
            )
        mesh_info = _get_deredundency_mesh_info(device_mesh)
706
707
708
709
710
711
712
713
714
            group=mesh_info.iep_group,
            async_op=True,
        )
        if handle is not None:
            handle.wait()
        iep_output_splits = iep_counts_out.view(mesh_info.iep_size, num_local_experts).sum(dim=1).tolist()

        outer_routed_input = gathered_routed[dispatch_indices]
        outer_router_coeff = (
752
753
754
755
756
757
758
759
        """
        del module
        mesh_info = _get_deredundency_mesh_info(device_mesh)
        if mesh_info.oep_size != ctx.oep_size:
            raise ValueError(
                "DeredundencyTokenDispatcher.combine received a context for "
                f"oep_size={ctx.oep_size}, but the mesh resolves to oep_size={mesh_info.oep_size}."
            )
772
773
774
775
776
777
778
779
780
        )
        if ctx.oep_size == 1:
            return combine_whiteboard

        return platform.differentiable_reduce_scatter(
            combine_whiteboard,
            ctx.oep_size,
            0,
            "sum",
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814


def _get_flattened_ep_mesh(device_mesh: DeviceMesh) -> DeviceMesh:
    """Return a 1-D EP mesh, flattening a 2-D deredundency mesh if needed."""
    if getattr(device_mesh, "ndim", 1) == 1:
        return device_mesh
    mesh_dim_names = getattr(device_mesh, "mesh_dim_names", None) or ()
    if "ep" in mesh_dim_names or "ep" in device_mesh.get_flatten_mapping():
        return device_mesh["ep"]
    if set(mesh_dim_names) == {"oep", "iep"}:
        return device_mesh.flatten("ep")
    raise ValueError(
        "Deredundency ExpertParallel expects a 1-D EP mesh or a 2-D "
        "[oep, iep] mesh when partitioning expert weights."
    )
921
922
923
924
925
926
927
928
929

    def _partition_mesh(self, device_mesh: DeviceMesh) -> DeviceMesh:
        """Pick the mesh over which expert weights are sharded.

        When the token dispatcher is ``"deredundency"``, the mesh may be a
        2-D ``[oep, iep]`` mesh, so weights are sharded over the 1-D EP view
        produced by ``_get_flattened_ep_mesh``. Every other dispatcher shards
        over ``device_mesh`` as given.
        """
        uses_deredundency = self._token_dispatcher_name == "deredundency"
        return _get_flattened_ep_mesh(device_mesh) if uses_deredundency else device_mesh

    def _partition_fn(
        self, name: str, module: Module, device_mesh: DeviceMesh