Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/core/context_parallel/async_context_parallel.py 100%  
hyper_parallel/core/context_parallel/context_parallel.py 100%  
hyper_parallel/core/context_parallel/dsa_context_parallel.py 100%  
hyper_parallel/core/dtensor/_mesh_layout.py 100%  
hyper_parallel/core/dtensor/device_mesh.py 0.0% 943,949
hyper_parallel/core/dtensor/layout.py 100%  
hyper_parallel/core/fully_shard/api.py 0.0% 719-724
hyper_parallel/core/fully_shard/hsdp_param.py 100%  
hyper_parallel/core/shard/api.py 66.7% 351
hyper_parallel/platform/torch/fully_shard/param.py 100%  
hyper_parallel/platform/torch/fully_shard/state.py 100%  
hyper_parallel/core/dtensor/device_mesh.py
939
940
941
942
943
944
945
946
947
            try:
                candidate = anchor_mesh[mesh_dim_names]
            except (KeyError, ValueError, RuntimeError, NotImplementedError):
                continue
            candidate_attrs = (
                candidate.device_type == mesh.device_type,
                candidate.mesh_shape == mesh.mesh_shape,
                candidate.rank_list == mesh.rank_list,
                candidate._flatten_rank_map == flatten_rank_map,  # pylint: disable=protected-access
945
946
947
948
949
950
951
952
953
                candidate.mesh_shape == mesh.mesh_shape,
                candidate.rank_list == mesh.rank_list,
                candidate._flatten_rank_map == flatten_rank_map,  # pylint: disable=protected-access
            )
            if all(candidate_attrs):
                return candidate
        return None

    @staticmethod
hyper_parallel/core/fully_shard/api.py
715
716
717
718
719
720
721
722
723
724
725
726
727
        mesh = init_device_mesh(device_type="npu", mesh_shape=(platform.get_world_size(),))
    if mesh is not None:
        device = _get_device_from_mesh(mesh)
    else:
        compat_mesh = None
        for param in params:
            dtensor_mesh = get_dtensor_managed_mesh(param)
            if dtensor_mesh is not None:
                compat_mesh = dtensor_mesh
                break
        if compat_mesh is None:
            raise ValueError("fully_shard could not resolve a DTensor mesh for compatibility mode.")
        device = _get_device_from_mesh(compat_mesh)
hyper_parallel/core/shard/api.py
347
348
349
350
351
352
353
354
355

    for cell_name in return_local_tensor_list:
        register_cell = cell_dict.get(cell_name)
        if register_cell is None:
            raise KeyError(f"Cannot find cell {cell_name!r} in sharded cell.")
        register_cell.register_forward_hook(hook_func)


def _shard_callable(func: Callable, sharding_plan: Dict):