Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/__init__.py 100%  
hyper_parallel/core/dtensor/_from_local_utils.py 63.8% 29-31,46-53,88-93
hyper_parallel/core/dtensor/dtensor.py 94.4% 807
hyper_parallel/platform/mindspore/platform.py 33.3% 938-941,946-949
hyper_parallel/platform/torch/platform.py 66.7% 626,631
hyper_parallel/core/dtensor/_from_local_utils.py
25
26
27
28
29
30
31
32
33
34
35
# Alias for the tensor class exposed by the active ``platform`` backend;
# used in the type annotations of the helpers in this module.
Tensor = platform.Tensor


def _ensure_mesh_process_groups(mesh: DeviceMesh) -> None:
    """Lazily build and cache the per-dimension process groups of *mesh*.

    If ``mesh._dim_group_names`` already holds a non-``None`` value, nothing
    happens. Otherwise the groups are created with
    ``DeviceMesh._init_process_groups`` from the mesh's shape, dimension names
    and rank list, and the result is stored on ``mesh._dim_group_names``.
    """
    # A missing attribute and an explicit ``None`` both mean "not initialised".
    cached_group_names = getattr(mesh, "_dim_group_names", None)
    if cached_group_names is None:
        # pylint: disable=protected-access
        mesh._dim_group_names = DeviceMesh._init_process_groups(
            mesh._mesh_shape,
            mesh.mesh_dim_names,
            mesh._rank_list,
        )
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
    *,
    group_src: int = 0,
) -> Tensor:
    """Broadcast *tensor* along one mesh dimension."""
    _ensure_mesh_process_groups(mesh)
    group = mesh.get_group(mesh_dim)
    if hasattr(tensor, "is_contiguous") and not tensor.is_contiguous():
        tensor = tensor.contiguous()
    rank_list = mesh.get_rank_list_along_axis(mesh_dim)
    src = rank_list[group_src]
    platform.broadcast(tensor, src, group=group)
    return tensor


def _tensor_meta(local_tensor: Tensor, *, check_shape_stride: bool) -> dict:
    meta = {
84
85
86
87
88
89
90
91
92
93
94
95
96
97


def _mesh_check_group(device_mesh: DeviceMesh):
    """Return ``(group, size)`` for a process group spanning every rank of *device_mesh*.

    A 1-D mesh is used directly. A multi-dimensional mesh is first flattened
    into a 1-D mesh, and dimension 0 of that flattened mesh provides the group
    and its size. Process groups are initialised on demand for every mesh
    that is queried.
    """
    _ensure_mesh_process_groups(device_mesh)
    covering_mesh = device_mesh
    if device_mesh.ndim != 1:
        # Collapse all mesh dimensions so a single group covers every rank.
        covering_mesh = device_mesh.flatten()
        _ensure_mesh_process_groups(covering_mesh)
    return covering_mesh.get_group(0), covering_mesh.size(0)


def run_from_local_checks(
    local_tensor: Tensor,
hyper_parallel/core/dtensor/dtensor.py
803
804
805
806
807
808
809
810
811

        layout = _build_layout(device_mesh, placements, len(local_shape))
        if is_rng_supported_mesh(device_mesh):
            if _OP_DISPATCHER._rng_tracker is None:
                _OP_DISPATCHER._rng_tracker = OffsetBasedRNGTracker(run_state_sync=False)
            with _OP_DISPATCHER._rng_tracker._distribute_region(
                device_mesh,
                layout.placements,
                size,
hyper_parallel/platform/mindspore/platform.py
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953

    @staticmethod
    def rand(size, dtype=None, device=None):
        """Create a tensor filled with uniform random values in ``[0, 1)``.

        Args:
            size: Output size, forwarded unchanged to ``mint.rand``.
            dtype: Optional element type, forwarded to ``mint.rand``.
            device: Optional target device. Only ``"GPU"`` and ``"Ascend"``
                trigger an explicit ``tensor.to(device)``; any other value
                (including ``None``) returns the tensor exactly as created.

        Returns:
            The sampled tensor.
        """
        tensor = mint.rand(size, dtype=dtype)
        # NOTE(review): in some MindSpore releases ``Tensor.to`` performs a dtype
        # conversion, while ``Tensor.move_to`` is the API that accepts
        # "Ascend"/"GPU"/"CPU" — confirm the installed version accepts a
        # device string here.
        if device in ("GPU", "Ascend"):
            return tensor.to(device)
        return tensor

    @staticmethod
    def randn(size, dtype=None, device=None):
        """Create a tensor filled with standard-normal random values.

        Args:
            size: Output size, forwarded unchanged to ``mint.randn``.
            dtype: Optional element type, forwarded to ``mint.randn``.
            device: Optional target device. Only ``"GPU"`` and ``"Ascend"``
                trigger an explicit ``tensor.to(device)``; any other value
                (including ``None``) returns the tensor exactly as created.

        Returns:
            The sampled tensor.
        """
        tensor = mint.randn(size, dtype=dtype)
        # NOTE(review): in some MindSpore releases ``Tensor.to`` performs a dtype
        # conversion, while ``Tensor.move_to`` is the API that accepts
        # "Ascend"/"GPU"/"CPU" — confirm the installed version accepts a
        # device string here.
        if device in ("GPU", "Ascend"):
            return tensor.to(device)
        return tensor

    @staticmethod
    def get_rank():
        """
hyper_parallel/platform/torch/platform.py
622
623
624
625
626
627
628
629
630
631
632
633
634
635

    @staticmethod
    def rand(size, dtype=None, device=None):
        """Create a tensor filled with uniform random values in ``[0, 1)``."""
        return torch.rand(size, dtype=dtype, device=device)

    @staticmethod
    def randn(size, dtype=None, device=None):
        """Create a tensor filled with standard-normal random values."""
        return torch.randn(size, dtype=dtype, device=device)

    @staticmethod
    def get_rank():
        """