Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/device_mesh.py: 58%
795 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""device mesh"""
17import copy
18import os
19import threading
20from types import TracebackType
21from typing import Any, List, Literal, Optional, Sequence, Type, Union
22import numpy as np
24from hyper_parallel.core.dtensor._mesh_layout import IntTuple, _MeshLayout, _contiguous_strides, _is_int
25from hyper_parallel.platform import get_platform
26from hyper_parallel.platform.platform import EXISTING_COMM_GROUPS, PlatformType
# Resolve the active backend platform once, at import time. ``Tensor`` aliases
# that platform's tensor class and is used for all tensor construction and
# ``isinstance`` checks in this module.
platform = get_platform()
Tensor = platform.Tensor
32class _MeshEnv(threading.local):
33 """Per-thread stack of active :class:`DeviceMesh` (PyTorch ``_mesh_resources`` parity)."""
35 def __init__(self) -> None:
36 super().__init__()
37 self.mesh_stack: List["DeviceMesh"] = []
39 def get_current_mesh(self) -> "DeviceMesh":
40 """Return the innermost active :class:`DeviceMesh` for this thread (PyTorch parity)."""
41 if len(self.mesh_stack) == 0:
42 raise RuntimeError("No device mesh is currently active!")
43 return self.mesh_stack[-1]
# Module-wide registry of active meshes. Because ``_MeshEnv`` subclasses
# ``threading.local``, each thread gets an independent ``mesh_stack``.
_mesh_resources = _MeshEnv()

# One per-dimension backend override: a backend name string, or ``None`` for
# "no override".
BackendConfig = Optional[str]
51def _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, sub_mesh_dim_names, current_rank):
52 """
53 Get the sub rank list for a sub mesh.
55 Args:
56 mesh_shape (tuple[int]): The shape of the original mesh.
57 mesh_dim_names (tuple[str]): The mesh dim names of the original mesh dimensions.
58 rank_list (tuple[int]): A tuple of ranks that participate in this mesh.
59 sub_mesh_dim_names (tuple[str]): The mesh dim names of the sub mesh to extract.
60 current_rank (int): The current process rank.
62 Returns:
63 list: The sub rank list for the sub mesh.
64 """
65 mesh_tensor = np.array(rank_list).reshape(mesh_shape)
67 for dim_index, dim_name in enumerate(mesh_dim_names):
68 if dim_name in sub_mesh_dim_names:
69 continue
71 dim_size = mesh_shape[dim_index]
72 sliced_tensors = np.split(mesh_tensor, dim_size, axis=dim_index)
74 for sliced_tensor in sliced_tensors:
75 rank_exists = np.isin(np.array([current_rank]), sliced_tensor).any()
76 if rank_exists:
77 mesh_tensor = sliced_tensor
78 break
80 sub_rank_list = mesh_tensor.reshape(-1).tolist()
81 return sub_rank_list
def _normalize_backend_value(value: Any) -> BackendConfig:
    """Coerce one ``backend_override`` entry into a backend name or ``None``.

    Accepted forms:

    * ``None`` or a ``str``, which is returned unchanged;
    * a non-empty ``tuple`` whose first element is a ``str`` or ``None``, in
      which case that first element is returned and the rest is ignored.

    Any other value normalizes to ``None``.
    """
    if value is None or isinstance(value, str):
        return value
    if isinstance(value, tuple) and value:
        head = value[0]
        return head if (head is None or isinstance(head, str)) else None
    return None
def _normalize_backend_override(
    backend_override: dict[Union[int, str], Any],
    ndim: int,
    mesh_dim_names: Optional[tuple[str, ...]] = None,
) -> tuple[BackendConfig, ...]:
    """Resolve per-dimension backend overrides into a tuple of length ``ndim``.

    A dimension may be addressed either by its integer index or, when
    ``mesh_dim_names`` is given, by its name. Dimensions without an entry map to
    ``None``. Each value is normalized with ``_normalize_backend_value``.

    Raises:
        RuntimeError: If a dimension is addressed by both index and name, or if
            any key does not match a dimension.
    """
    pending = dict(backend_override)
    names = mesh_dim_names or ()

    def _take(dim_idx: int) -> BackendConfig:
        # A name key takes precedence, but must not coexist with the index key.
        dim_name = names[dim_idx] if dim_idx < len(names) else None
        if dim_name is not None and dim_name in pending:
            if dim_idx in pending:
                raise RuntimeError(
                    f"Found redundant dim index {dim_idx} and name {dim_name} in backend_override"
                )
            return _normalize_backend_value(pending.pop(dim_name))
        if dim_idx in pending:
            return _normalize_backend_value(pending.pop(dim_idx))
        return None

    result = tuple(_take(dim_idx) for dim_idx in range(ndim))
    # Anything left over matched neither an index nor a name.
    if pending:
        raise RuntimeError(
            f"Found invalid keys in backend_override: got {list(pending.keys())}, "
            f"expected integers in range [0, {ndim}) or one of {names}"
        )
    return result
def _should_defer_group_init(sub_layout: _MeshLayout, backend_override: BackendConfig) -> bool:
    """Return True when this mesh dimension should skip eager process-group creation.

    A group is skipped when the dimension's backend override is ``"fake"``, or
    when its sub-layout has exactly one element (``numel() == 1``).
    """
    if backend_override == "fake":
        return True
    return sub_layout.numel() == 1
class DeviceMesh:
    """
    Topological abstraction describing cluster devices.

    Args:
        device_type (str): Device type. Valid values depend on the active platform:

            - **PyTorch** (same as ``torch.distributed.device_mesh.DeviceMesh``):
              ``"cpu"``, ``"cuda"``, ``"npu"``.
            - **MindSpore** (mapped to the corresponding communication backend):
              ``"cpu"`` → mccl, ``"gpu"`` → nccl, ``"npu"`` → hccl.

            An unsupported value raises ``ValueError``.
        mesh (Union[Tensor, list, tuple, np.ndarray, None]): A multi-dimensional array, list, or integer
            tensor describing the device layout. The IDs in the mesh are global IDs of the
            default process group, representing the multi-dimensional networking structure
            of devices in distributed training (e.g., [[0,1],[2,3]] represents a 2x2 device mesh).
            Any accepted input is converted to an int32 tensor. If None, a 1D mesh containing
            all ranks (i.e., ``[0, 1, ..., world_size-1]``) is created automatically.
        mesh_dim_names (tuple[str]): One unique name per mesh dimension. If
            ``"interleaved_parallel"`` is used, it must be the last name.
        _init_backend (bool): Whether to initialize the default process group and
            create the per-dimension process groups during construction.

    Attributes:
        ndim (int): Number of dimensions in the mesh.
        mesh_shape (tuple[int]): Shape of the device mesh.
        rank_list (tuple[int]): Flattened list of ranks from the mesh.
        root_mesh (DeviceMesh): The root mesh a sub mesh was sliced from; None for a root mesh.
        sub_mesh (list[DeviceMesh]): List of child meshes created from this mesh.

    Context manager:
        Use ``with device_mesh:`` to set the **current** mesh for this thread.
    """

    # Class-level annotations for the public attributes assigned in ``__init__``.
    device_type: Literal["cpu", "cuda", "gpu", "npu"]
    mesh: Union[Tensor, list, tuple, np.ndarray]
    mesh_dim_names: Union[tuple[str, ...], list[str], None]

    # Device types accepted per platform. ``_validate_device_type`` skips
    # validation for platforms missing from this table.
    _VALID_DEVICE_TYPES = {
        PlatformType.PYTORCH: {"cpu", "cuda", "npu"},
        PlatformType.MINDSPORE: {"cpu", "gpu", "npu"},
    }
173 def __init__(self,
174 device_type: Literal["cpu", "cuda", "gpu", "npu"],
175 mesh: Union[Tensor, list, tuple, np.ndarray, None] = None,
176 *,
177 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
178 _init_backend: bool = True,
179 _layout: Optional[_MeshLayout] = None,
180 _rank_map: Optional[Tensor] = None,
181 _root_mesh: Optional['DeviceMesh'] = None,
182 ):
183 self._validate_device_type(device_type)
184 self.device_type = device_type
186 if _init_backend:
187 platform.init_process_group()
189 self._layout, self._rank_map = self._resolve_layout_and_rank_map(mesh, _layout, _rank_map)
190 self._rank = platform.get_rank()
191 self._root_mesh = _root_mesh
192 self._refresh_mesh_view()
193 self._set_mesh_dim_names(mesh_dim_names)
194 self._initialize_runtime_state(_init_backend)
195 if os.getenv("MS_SIMULATION_LEVEL") is None:
196 self._coordinate_on_dim = self._compute_coordinate_on_dim()
198 @classmethod
199 def _validate_device_type(cls, device_type: str) -> None:
200 """Validate that the requested device type is supported on the active platform."""
201 valid_device_types = cls._VALID_DEVICE_TYPES.get(platform.platform_type)
202 if valid_device_types is not None and device_type not in valid_device_types:
203 raise ValueError(
204 f"Invalid device_type '{device_type}' for {platform.platform_type.name} platform. "
205 f"Valid device types are: {sorted(valid_device_types)}"
206 )
208 @classmethod
209 def _resolve_layout_and_rank_map(
210 cls,
211 mesh: Union[Tensor, list, tuple, np.ndarray, None],
212 layout: Optional[_MeshLayout],
213 rank_map: Optional[Tensor],
214 ) -> tuple[_MeshLayout, Tensor]:
215 """Build the internal layout and rank map from either public or private constructor inputs."""
216 if mesh is not None and (layout is not None or rank_map is not None):
217 raise TypeError("Cannot provide both explicit mesh and private _layout/_rank_map arguments.")
219 if mesh is None and (layout is None or rank_map is None):
220 world_size = platform.get_world_size()
221 mesh = list(range(world_size))
223 if mesh is not None:
224 mesh_tensor = cls._convert_mesh_to_tensor(mesh)
225 if mesh_tensor.ndim == 0:
226 raise ValueError("mesh must be at least 1-dimensional")
227 return cls._build_layout_from_mesh(mesh_tensor), cls._build_rank_map_from_mesh(mesh_tensor)
229 rank_map_tensor = cls._convert_rank_map_to_tensor(rank_map)
230 if layout is None or rank_map_tensor is None:
231 raise TypeError("The mesh argument is required except for private _layout/_rank_map construction.")
232 if not layout.check_non_overlap():
233 raise ValueError(f"Invalid overlapping layout {layout}.")
234 return layout, rank_map_tensor
    def _refresh_mesh_view(self) -> None:
        """Materialize the visible mesh tensor and the derived shape/rank metadata.

        Remaps the flat ``self._rank_map`` through ``self._layout``. If the
        remapped array has a single leading entry, that entry is the view.
        Otherwise the leading entry that contains ``self._rank`` is used. (Each
        leading entry is presumably one candidate view; confirm against
        ``_MeshLayout.remap_to_numpy``.)

        Sets ``_per_rank_mesh_np``, ``mesh``, ``_mesh_shape``, ``_rank_list``,
        ``_flatten_rank_map``, ``_dev_num`` and ``_dev_rank``.

        Raises:
            RuntimeError: If there are several leading entries and none contains
                the local rank.
        """
        # Compute everything in numpy first so the intermediate ops don't need
        # a real device. Otherwise the call would fail (or SIGSEGV on Ascend)
        # when DeviceMesh is constructed inside a ``ms.DeviceCtx("meta")``
        # block — e.g., from ``DeviceMesh.concatenate`` invoked under
        # ``fully_shard``, which forces fresh ``Tensor()`` constructions onto
        # the meta device and any subsequent op (asnumpy, nonzero, …) crashes.
        rank_map_np = platform.tensor_to_numpy(self._rank_map).reshape(-1)
        full_mesh_np = self._layout.remap_to_numpy(rank_map_np)
        if full_mesh_np.shape[0] == 1:
            # Only one leading entry: use it whether or not it holds the local rank.
            per_rank_mesh_np = full_mesh_np[0]
        else:
            # ``coords[0, 0]`` is the leading index of the first match of the local rank.
            coords = np.argwhere(full_mesh_np == self._rank)
            if coords.shape[0] == 0:
                raise RuntimeError(
                    "In order to get the mesh tensor of a DeviceMesh it needs to "
                    "either have all its original dimensions or contain the local rank."
                )
            per_rank_mesh_np = full_mesh_np[coords[0, 0]]
        # Cache the numpy view so ``_compute_coordinate_on_dim`` doesn't need
        # to operate on ``self.mesh`` (which may be on the meta device).
        self._per_rank_mesh_np = per_rank_mesh_np
        self.mesh = Tensor(per_rank_mesh_np.astype(np.int32)).int()
        self._mesh_shape = tuple(per_rank_mesh_np.shape)
        self._rank_list = tuple(per_rank_mesh_np.reshape(-1).tolist())
        # The full (root) rank map, flattened; used to check that meshes share a root.
        self._flatten_rank_map = tuple(rank_map_np.tolist())
        self._dev_num = np.prod(np.array(self._mesh_shape))
        self._dev_rank = len(self._mesh_shape)
266 def _set_mesh_dim_names(
267 self,
268 mesh_dim_names: Union[tuple[str, ...], list[str], None],
269 ) -> None:
270 """Validate mesh dim names and build lookup tables for named access."""
271 self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
272 if self.mesh_dim_names is None:
273 return
275 if len(self._mesh_shape) != len(self.mesh_dim_names):
276 raise ValueError(
277 f'mesh dimensions ({len(self._mesh_shape)}) should be equal to '
278 f'mesh_dim_names length ({len(self.mesh_dim_names)})'
279 )
280 if len(set(self.mesh_dim_names)) != len(self.mesh_dim_names):
281 raise ValueError(f'Each element of mesh_dim_names {self.mesh_dim_names} should be different')
282 inter_key = "interleaved_parallel"
283 if inter_key in self.mesh_dim_names and self.mesh_dim_names.index(inter_key) != len(self.mesh_dim_names) - 1:
284 raise ValueError(
285 "'interleaved_parallel' should be at the last dim of mesh_dim_names, means virtual sharding."
286 )
287 self._dev_name_to_dev_id = {
288 name: self._dev_rank - i - 1 for i, name in enumerate(self.mesh_dim_names)
289 }
290 self._dev_name_to_index = {name: i for i, name in enumerate(self.mesh_dim_names)}
292 def _initialize_runtime_state(self, init_backend: bool) -> None:
293 """Initialize caches and optional process-group state for the mesh view."""
294 self._cache_rank_list_along_axis = {}
295 self._global_shape_map = {}
296 self._sub_mesh_cache = {}
297 self._flatten_mapping: dict[str, 'DeviceMesh'] = {}
298 self._ndim = len(self._mesh_shape)
299 self._dim_group_backends = (None,) * self._ndim
300 self._dim_group_sources = tuple((self, dim) for dim in range(self._ndim))
301 self._sub_mesh: List['DeviceMesh'] = []
302 if not init_backend:
303 return
304 self._dim_group_names = self._init_process_groups(
305 self._mesh_shape,
306 self.mesh_dim_names,
307 self._rank_list,
308 )
310 @staticmethod
311 def _build_layout_from_mesh(mesh: Tensor) -> _MeshLayout:
312 mesh_shape = tuple(mesh.shape)
313 return _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape))
315 @staticmethod
316 def _build_rank_map_from_mesh(mesh: Tensor) -> Tensor:
317 return Tensor(platform.tensor_to_numpy(mesh).reshape(-1)).int()
319 @staticmethod
320 def _convert_rank_map_to_tensor(rank_map: Tensor) -> Tensor:
321 """Normalize a rank-map input into the flat int32 Tensor stored on the mesh.
323 Tensor input is returned as-is to preserve its original device; list /
324 tuple / numpy input is built into a fresh flat int32 Tensor.
325 """
326 if isinstance(rank_map, Tensor):
327 # Reuse the existing tensor as-is so we preserve its real device.
328 # Going through ``Tensor(np_array)`` would re-create on whatever
329 # device context is active (e.g. ``ms.DeviceCtx("meta")`` while
330 # ``DeviceMesh.concatenate`` runs under ``fully_shard``), which then
331 # breaks the immediate ``asnumpy()`` in ``_refresh_mesh_view``.
332 # All in-tree callers that pass a Tensor pass an existing
333 # ``DeviceMesh._rank_map`` — already a flat int32 tensor, so no
334 # reshape/cast is needed.
335 return rank_map
336 rank_map_np = np.array(rank_map)
337 return Tensor(rank_map_np.reshape(-1).astype(np.int32)).int()
339 @staticmethod
340 def _get_mesh_tensor_from_full_mesh(full_mesh: Tensor, current_rank: Optional[int] = None) -> Tensor:
341 """Select the per-rank mesh view from a fully materialized layout remap."""
342 if full_mesh.shape[0] == 1:
343 return full_mesh[0]
345 if current_rank is None:
346 current_rank = platform.get_rank()
348 rank_coords = (full_mesh == current_rank).nonzero()
349 if rank_coords.shape[0] > 0:
350 return full_mesh[rank_coords[0, 0]]
351 raise RuntimeError(
352 "In order to get the mesh tensor of a DeviceMesh it needs to "
353 "either have all its original dimensions or contain the local rank."
354 )
356 def _compute_coordinate_on_dim(self):
357 """Compute the current rank coordinates inside this mesh view."""
358 # Use the cached numpy view rather than ``self.mesh`` so this works
359 # even when the mesh tensor lives on the meta device (DeviceMesh
360 # constructed under ``ms.DeviceCtx("meta")`` via ``fully_shard``).
361 per_rank_mesh_np = getattr(self, "_per_rank_mesh_np", None)
362 if per_rank_mesh_np is not None:
363 rank_coords = np.argwhere(per_rank_mesh_np == self._rank)
364 if rank_coords.shape[0] not in (0, 1):
365 raise AssertionError(
366 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}"
367 )
368 if rank_coords.shape[0] == 0:
369 return None
370 return tuple(int(x) for x in rank_coords[0])
371 return self._compute_coordinates_from_mesh(self.mesh, self._rank)
373 @staticmethod
374 def _compute_coordinates_from_mesh(
375 mesh_tensor: Tensor,
376 rank: int,
377 ):
378 """Locate one rank inside a mesh tensor and return its coordinates."""
379 rank_coords = (mesh_tensor == rank).nonzero()
380 if rank_coords.shape[0] not in (0, 1):
381 raise AssertionError(
382 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}"
383 )
385 if rank_coords.shape[0] == 0:
386 return None
388 coords = rank_coords[0].tolist()
389 return tuple(coords)
391 def size(self, mesh_dim=None) -> int:
392 if mesh_dim is not None:
393 return self.mesh.shape[mesh_dim]
394 return self.mesh.numel()
396 def get_coordinate(self):
397 return self._coordinate_on_dim if self._coordinate_on_dim else None
399 def __enter__(self) -> "DeviceMesh":
400 _mesh_resources.mesh_stack.append(self)
401 return self
403 def __exit__(
404 self,
405 exc_type: Optional[Type[BaseException]],
406 exc_val: Optional[BaseException],
407 exc_tb: Optional[TracebackType],
408 ) -> None:
409 _mesh_resources.mesh_stack.pop()
411 @staticmethod
412 def _convert_mesh_to_tensor(mesh: Union[Tensor, list, tuple, np.ndarray]) -> Tensor:
413 """Convert a public mesh input into an int32 platform tensor."""
414 if isinstance(mesh, Tensor):
415 mesh = platform.tensor_to_numpy(mesh)
416 elif isinstance(mesh, (list, tuple)):
417 mesh = np.array(mesh)
418 elif not isinstance(mesh, np.ndarray):
419 raise TypeError(
420 f"mesh must be Tensor, list, tuple or numpy array, but got {type(mesh)}"
421 )
423 mesh = mesh.astype(np.int32)
424 return Tensor(mesh).int()
426 @staticmethod
427 def _init_one_process_group(mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...],
428 dim_name: str, rank_list: tuple[int, ...]) -> str:
429 """Create one process-group family for the named mesh dimension."""
430 group_key = None
431 split_ranks = set()
432 if not isinstance(dim_name, tuple):
433 dim_name = (dim_name,)
434 for rank in rank_list:
435 split_rank = _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, dim_name, rank)
436 sorted_rank = tuple(sorted(split_rank))
437 split_ranks.add(sorted_rank)
438 if rank == platform.get_rank():
439 group_key = str(sorted_rank)
440 split_ranks = sorted([list(item) for item in split_ranks])
441 platform.split_group(split_ranks=split_ranks)
442 return group_key
444 @staticmethod
445 def _build_dim_split_ranks(
446 sub_layout: _MeshLayout,
447 rank_map: Tensor,
448 ) -> tuple[list[list[int]], Optional[str]]:
449 """Build rank lists and the local cache key for one logical mesh axis."""
450 pg_ranks_by_dim = sub_layout.remap_to_numpy(platform.tensor_to_numpy(rank_map))
451 current_rank = platform.get_rank()
452 split_ranks = []
453 split_ranks_set = set()
454 group_key = None
455 for dim_mesh in np.array(pg_ranks_by_dim):
456 subgroup_ranks = tuple(int(rank) for rank in np.array(dim_mesh).reshape(-1).tolist())
457 subgroup_ranks_sorted = tuple(sorted(subgroup_ranks))
458 if subgroup_ranks_sorted not in split_ranks_set:
459 split_ranks_set.add(subgroup_ranks_sorted)
460 split_ranks.append(list(subgroup_ranks_sorted))
461 if current_rank in subgroup_ranks:
462 if group_key is not None:
463 raise RuntimeError(
464 "Each device mesh dimension should get only one process group per rank."
465 )
466 group_key = str(subgroup_ranks_sorted)
467 split_ranks = sorted(split_ranks)
468 return split_ranks, group_key
470 @staticmethod
471 def _cache_group_if_needed(group_key: Optional[str], group: Any) -> None:
472 if group_key is not None and group is not None and group_key not in EXISTING_COMM_GROUPS:
473 EXISTING_COMM_GROUPS[group_key] = group
475 @staticmethod
476 def _init_process_groups_for_layout(
477 layout: _MeshLayout,
478 rank_map: Tensor,
479 mesh_dim_names: Union[tuple[str, ...], None],
480 backend_override: Optional[tuple[BackendConfig, ...]] = None,
481 ) -> list:
482 """Initialize process groups for each top-level axis in the given layout."""
483 if mesh_dim_names is None:
484 mesh_dim_names = tuple(f"dim_{dim}" for dim in range(len(layout)))
485 if backend_override is None:
486 backend_override = (None,) * len(layout)
487 if len(backend_override) != len(layout):
488 raise ValueError(
489 f"backend_override length {len(backend_override)} must match layout rank {len(layout)}"
490 )
492 dim_group_names = []
493 for dim, sub_layout in enumerate(layout):
494 split_ranks, group_key = DeviceMesh._build_dim_split_ranks(sub_layout, rank_map)
495 if _should_defer_group_init(sub_layout, backend_override[dim]):
496 dim_group_names.append(None)
497 continue
498 group = platform.split_group(split_ranks=split_ranks)
499 DeviceMesh._cache_group_if_needed(group_key, group)
500 dim_group_names.append(group_key)
501 return dim_group_names
503 @staticmethod
504 def _init_process_groups(mesh_shape: tuple[int, ...], mesh_dim_names: Union[tuple[str, ...], None],
505 rank_list: tuple[int, ...],
506 backend_override: Optional[tuple[BackendConfig, ...]] = None) -> list:
507 layout = _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape))
508 rank_map = DeviceMesh._convert_rank_map_to_tensor(rank_list)
509 return DeviceMesh._init_process_groups_for_layout(
510 layout,
511 rank_map,
512 mesh_dim_names,
513 backend_override=backend_override,
514 )
516 @property
517 def rank(self):
518 return self._rank
520 @property
521 def mesh_shape(self):
522 return self._mesh_shape
524 @property
525 def rank_list(self):
526 return self._rank_list
528 @property
529 def ndim(self) -> int:
530 return self._ndim
532 @property
533 def shape(self) -> tuple:
534 return self._mesh_shape
536 @property
537 def root_mesh(self) -> Optional['DeviceMesh']:
538 return self._root_mesh
540 @root_mesh.setter
541 def root_mesh(self, value: Optional['DeviceMesh']):
542 self._root_mesh = value
544 @property
545 def sub_mesh(self) -> List['DeviceMesh']:
546 return self._sub_mesh
548 def get_flatten_mapping(self) -> dict:
549 return self._flatten_mapping
551 def add_flatten_mapping(self, name: str, mesh: 'DeviceMesh') -> None:
552 self._flatten_mapping[name] = mesh
554 def __getitem__(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> 'DeviceMesh':
555 if not self.mesh_dim_names:
556 raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")
558 sub_mesh_dim_names = DeviceMesh._normalize_sub_mesh_dim_names(sub_mesh_dim_names)
559 flatten_mapping = self._get_root_mesh().get_flatten_mapping()
561 flattened_result = self._try_get_from_flatten_mapping(sub_mesh_dim_names, flatten_mapping)
562 if flattened_result is not None:
563 return flattened_result
565 layout = self._get_slice_mesh_layout(sub_mesh_dim_names)
566 if sub_mesh_dim_names in self._sub_mesh_cache:
567 return self._sub_mesh_cache[sub_mesh_dim_names]
568 if layout == self._layout:
569 return self
570 return self._create_and_cache_sub_mesh(sub_mesh_dim_names, layout)
572 @staticmethod
573 def _normalize_sub_mesh_dim_names(sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> tuple[str, ...]:
574 """Normalize a slice selector into a non-empty tuple of mesh dim names."""
575 if isinstance(sub_mesh_dim_names, str):
576 sub_mesh_dim_names = (sub_mesh_dim_names,)
578 if not isinstance(sub_mesh_dim_names, tuple):
579 raise TypeError(
580 f"sub_mesh_dim_names must be str or tuple, but got {type(sub_mesh_dim_names)}"
581 )
583 if len(sub_mesh_dim_names) == 0:
584 raise ValueError("sub_mesh_dim_names cannot be empty")
586 return sub_mesh_dim_names
588 @staticmethod
589 def _try_get_from_flatten_mapping(sub_mesh_dim_names: tuple[str, ...],
590 flatten_mapping: dict) -> Optional['DeviceMesh']:
591 if len(sub_mesh_dim_names) == 1 and sub_mesh_dim_names[0] in flatten_mapping:
592 return flatten_mapping[sub_mesh_dim_names[0]]
593 return None
595 def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
596 """Resolve a named mesh axis to its integer position."""
597 mesh_dim_names = self.mesh_dim_names or ()
598 if len(mesh_dim_names) == 0:
599 raise KeyError("No mesh_dim_names found.")
600 if mesh_dim_name not in mesh_dim_names:
601 raise KeyError(
602 f"Mesh dimension '{mesh_dim_name}' does not exist. "
603 f"Available mesh dimensions are: {mesh_dim_names}"
604 )
605 return mesh_dim_names.index(mesh_dim_name)
607 def _get_slice_mesh_layout(self, sub_mesh_dim_names: tuple[str, ...]) -> _MeshLayout:
608 """Construct the layout corresponding to one named sub-mesh slice request."""
609 root_mesh = self._get_root_mesh()
610 slice_from_root = self == root_mesh
611 flatten_name_to_layout = (
612 {key: mesh._layout for key, mesh in root_mesh.get_flatten_mapping().items()}
613 if slice_from_root else {}
614 )
615 valid_dim_names = [*(self.mesh_dim_names or ()), *flatten_name_to_layout]
616 if not all(name in valid_dim_names for name in sub_mesh_dim_names):
617 raise KeyError(
618 f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. "
619 f"Valid mesh_dim_names are {valid_dim_names}."
620 )
622 if all(name in (self.mesh_dim_names or ()) for name in sub_mesh_dim_names):
623 indices = [self.mesh_dim_names.index(name) for name in sub_mesh_dim_names]
624 if indices != sorted(indices):
625 raise ValueError(
626 f"sub_mesh_dim_names {sub_mesh_dim_names} must follow the order of "
627 f"original mesh_dim_names {self.mesh_dim_names}"
628 )
630 sliced_sizes: list[IntTuple] = []
631 sliced_strides: list[IntTuple] = []
632 for name in sub_mesh_dim_names:
633 if name in (self.mesh_dim_names or ()):
634 layout = self._layout[self.mesh_dim_names.index(name)]
635 else:
636 layout = flatten_name_to_layout[name]
637 sliced_sizes.append(layout.sizes)
638 sliced_strides.append(layout.strides)
640 pre_stride = -1
641 for stride in reversed(sliced_strides):
642 if not _is_int(stride):
643 raise NotImplementedError(
644 "Currently, this only allows slicing out a contiguous flattened dim."
645 )
646 if stride < pre_stride:
647 raise ValueError(
648 f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. "
649 "Mesh dim indices should be in ascending order."
650 )
651 pre_stride = stride
653 if len(sliced_sizes) == 1:
654 layout = _MeshLayout(sliced_sizes[0], sliced_strides[0])
655 else:
656 layout = _MeshLayout(tuple(sliced_sizes), tuple(sliced_strides))
657 if not layout.check_non_overlap():
658 raise RuntimeError(f"Slicing overlapping dim_names {sub_mesh_dim_names} is not allowed.")
659 return layout
661 def _create_and_cache_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...], layout: _MeshLayout) -> 'DeviceMesh':
662 """Create a sub-mesh view, copy group metadata, and cache the result."""
663 root_mesh = self._get_root_mesh()
664 sub_mesh = DeviceMesh(
665 device_type=self.device_type,
666 mesh_dim_names=sub_mesh_dim_names,
667 _init_backend=False,
668 _layout=layout,
669 _rank_map=root_mesh._rank_map,
670 _root_mesh=root_mesh,
671 )
673 slice_dim_group_name = []
674 slice_dim_group_backends: list[BackendConfig] = []
675 slice_dim_group_sources: list[tuple['DeviceMesh', int]] = []
676 for name in sub_mesh_dim_names:
677 if name in (self.mesh_dim_names or ()):
678 dim_index = self.mesh_dim_names.index(name)
679 if hasattr(self, "_dim_group_names"):
680 slice_dim_group_name.append(self._dim_group_names[dim_index])
681 slice_dim_group_backends.append(self._dim_group_backends[dim_index])
682 if hasattr(self, "_dim_group_sources"):
683 slice_dim_group_sources.append(self._dim_group_sources[dim_index]) # pylint: disable=W0212
684 else:
685 slice_dim_group_sources.append((self, dim_index))
686 elif name in root_mesh.get_flatten_mapping():
687 flatten_mesh = root_mesh.get_flatten_mapping()[name]
688 if hasattr(flatten_mesh, "_dim_group_names"):
689 slice_dim_group_name.append(flatten_mesh._dim_group_names[0])
690 slice_dim_group_backends.append(flatten_mesh._dim_group_backends[0])
691 if hasattr(flatten_mesh, "_dim_group_sources"):
692 slice_dim_group_sources.append(flatten_mesh._dim_group_sources[0]) # pylint: disable=W0212
693 else:
694 slice_dim_group_sources.append((flatten_mesh, 0))
695 if slice_dim_group_name:
696 sub_mesh._dim_group_names = slice_dim_group_name # pylint: disable=W0212
697 if slice_dim_group_backends:
698 sub_mesh._dim_group_backends = tuple(slice_dim_group_backends) # pylint: disable=W0212
699 if slice_dim_group_sources:
700 sub_mesh._dim_group_sources = tuple(slice_dim_group_sources) # pylint: disable=W0212
702 self._sub_mesh_cache[sub_mesh_dim_names] = sub_mesh
703 self.sub_mesh.append(sub_mesh)
704 return sub_mesh
706 def get_group(self, mesh_dim: Optional[Union[int, str]] = None):
707 """Return the communication group for one mesh axis."""
708 if not hasattr(self, "_dim_group_names"):
709 raise RuntimeError("DeviceMesh process groups not initialized!")
711 if self.ndim > 1 and mesh_dim is None:
712 raise RuntimeError(
713 f"Found the DeviceMesh have {self.ndim} dimensions. "
714 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
715 )
717 root_mesh = self._get_root_mesh()
718 if isinstance(mesh_dim, str) and mesh_dim in root_mesh.get_flatten_mapping():
719 flattened_mesh = root_mesh.get_flatten_mapping()[mesh_dim]
720 return flattened_mesh.get_comm_group_by_axis(mesh_dim)
722 return self.get_comm_group_by_axis(mesh_dim)
724 def get_all_groups(self) -> list:
725 if not hasattr(self, "_dim_group_names"):
726 raise RuntimeError("DeviceMesh process groups not initialized!")
728 return [self.get_group(i) for i in range(self.ndim)]
730 @staticmethod
731 def from_group(group: Union[Any, list[Any]],
732 device_type: str,
733 mesh: Union[Tensor, list, tuple, np.ndarray] = None,
734 mesh_dim_names: Union[tuple[str, ...], list[str]] = None
735 ) -> 'DeviceMesh':
736 """Build a DeviceMesh from an existing process group or a list of groups."""
737 if not isinstance(group, list):
738 group_ranks = platform.get_process_group_ranks(group)
739 group_key = str(tuple(sorted(group_ranks)))
740 if not platform.get_created_group(group_ranks):
741 EXISTING_COMM_GROUPS[group_key] = group
742 if (
743 isinstance(mesh, Tensor) and mesh.tolist() != group_ranks
744 ) or (
745 mesh is not None
746 and not isinstance(mesh, Tensor)
747 and mesh != group_ranks
748 ):
749 raise ValueError(
750 f"Invalid mesh_shape {str(mesh)} for 1D group with ranks {group_ranks}"
751 )
752 device_mesh = DeviceMesh(device_type, group_ranks, mesh_dim_names=mesh_dim_names, _init_backend=False)
753 device_mesh._dim_group_names = [group_key] # pylint: disable=W0212
754 return device_mesh
756 groups = list(group)
757 if len(groups) == 0:
758 raise ValueError("Expect at least one group be specified.")
759 if mesh is None:
760 raise ValueError("mesh_shape is must specified when group is a list.")
761 mesh = DeviceMesh._convert_mesh_to_tensor(mesh)
762 if mesh.ndim != len(groups):
763 raise ValueError("mesh dimensions must match group dimensions.")
764 device_mesh = DeviceMesh(device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False)
765 device_mesh._dim_group_names = [] # pylint: disable=W0212
766 for dim_group in groups:
767 group_ranks = platform.get_process_group_ranks(dim_group)
768 group_key = str(tuple(sorted(group_ranks)))
769 if not platform.get_created_group(group_ranks):
770 EXISTING_COMM_GROUPS[group_key] = dim_group
771 device_mesh._dim_group_names.append(group_key) # pylint: disable=W0212
772 return device_mesh
774 def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
775 """Return the local coordinate of the current rank along one mesh dimension."""
776 if self.ndim > 1 and mesh_dim is None:
777 raise RuntimeError(
778 f"Found the DeviceMesh have {self.ndim} dimensions. "
779 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
780 )
782 if mesh_dim is None:
783 mesh_dim = 0
785 if isinstance(mesh_dim, str):
786 if mesh_dim not in self.mesh_dim_names: # pylint: disable=E1135
787 raise ValueError(
788 f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {self.mesh_dim_names}"
789 )
790 dim_index = self.mesh_dim_names.index(mesh_dim)
791 else:
792 if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim:
793 raise ValueError(
794 f"mesh_dim must be an integer in range [0, {self.ndim}), "
795 f"but got {mesh_dim}"
796 )
797 dim_index = mesh_dim
799 if self._rank not in self._rank_list:
800 raise ValueError(
801 f"Current rank {self._rank} not found in rank_list {self._rank_list}"
802 )
804 idx = self._rank_list.index(self._rank)
805 coord = [0] * len(self._mesh_shape)
806 temp = idx
807 for i in range(len(self._mesh_shape) - 1, -1, -1):
808 coord[i] = temp % self._mesh_shape[i]
809 temp //= self._mesh_shape[i]
811 return coord[dim_index]
def flatten(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
    """Collapse every axis of this mesh into one axis called ``mesh_dim_name``.

    Thin public wrapper over ``_create_flatten_mesh`` with no backend override.
    """
    return self._create_flatten_mesh(mesh_dim_name=mesh_dim_name)

def _get_root_mesh(self) -> 'DeviceMesh':
    """Return the top of the ``_root_mesh`` chain, or ``self`` when this mesh has no parent."""
    parent = self._root_mesh
    return self if parent is None else parent._get_root_mesh()  # pylint: disable=protected-access
822 @staticmethod
823 def _validate_concatenate_inputs(
824 meshes: Sequence['DeviceMesh'],
825 ) -> tuple['DeviceMesh', tuple[str, ...], tuple[int, ...]]:
826 """Validate concatenate inputs and return the shared root metadata."""
827 if len(meshes) == 0:
828 raise ValueError("DeviceMesh.concatenate expects at least one mesh.")
829 if len(meshes) == 1:
830 return meshes[0]._get_root_mesh(), tuple(meshes[0].mesh_dim_names or ()), meshes[0]._flatten_rank_map
832 root_mesh = meshes[0]._get_root_mesh() # pylint: disable=protected-access
833 requested_dim_names: list[str] = []
834 flatten_rank_map = meshes[0]._flatten_rank_map # pylint: disable=protected-access
835 for mesh in meshes:
836 if mesh._get_root_mesh().to_hash() != root_mesh.to_hash(): # pylint: disable=protected-access
837 raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.")
838 if mesh._flatten_rank_map != flatten_rank_map: # pylint: disable=protected-access
839 raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.")
840 if not mesh.mesh_dim_names:
841 raise ValueError("DeviceMesh.concatenate requires mesh_dim_names on every input mesh.")
842 requested_dim_names.extend(mesh.mesh_dim_names)
843 return root_mesh, tuple(requested_dim_names), flatten_rank_map
845 @staticmethod
846 def _validate_concatenate_root_order(root_mesh: 'DeviceMesh', requested_dim_names: tuple[str, ...]) -> None:
847 """Require original root dims to stay in root order when concatenating by name."""
848 root_dim_names = tuple(root_mesh.mesh_dim_names) if root_mesh.mesh_dim_names else ()
849 if not root_dim_names or not all(dim_name in root_dim_names for dim_name in requested_dim_names):
850 return
852 requested_indices = [root_dim_names.index(dim_name) for dim_name in requested_dim_names]
853 if requested_indices != sorted(requested_indices):
854 raise ValueError(
855 "DeviceMesh.concatenate expects meshes to follow the root mesh order. "
856 f"Got root mesh dims {root_dim_names} and requested dims {requested_dim_names}."
857 )
859 @staticmethod
860 def _collect_concatenate_metadata(
861 meshes: Sequence['DeviceMesh'],
862 ) -> tuple[
863 list[str],
864 list[IntTuple],
865 list[IntTuple],
866 list[Optional[str]],
867 list[BackendConfig],
868 list[tuple['DeviceMesh', int]],
869 ]:
870 """Collect layout and process-group metadata from all concatenate inputs."""
871 concat_dim_names: list[str] = []
872 concat_sizes: list[IntTuple] = []
873 concat_strides: list[IntTuple] = []
874 concat_dim_group_names: list[Optional[str]] = []
875 concat_dim_group_backends: list[BackendConfig] = []
876 concat_dim_group_sources: list[tuple['DeviceMesh', int]] = []
878 for mesh in meshes:
879 for dim, sub_layout in enumerate(mesh._layout): # pylint: disable=protected-access
880 concat_sizes.append(sub_layout.sizes)
881 concat_strides.append(sub_layout.strides)
882 if hasattr(mesh, "_dim_group_names"):
883 concat_dim_group_names.append(mesh._dim_group_names[dim]) # pylint: disable=protected-access
884 concat_dim_group_backends.append(mesh._dim_group_backends[dim]) # pylint: disable=protected-access
885 if hasattr(mesh, "_dim_group_sources"):
886 concat_dim_group_sources.append(mesh._dim_group_sources[dim]) # pylint: disable=protected-access
887 else:
888 concat_dim_group_sources.append((mesh, dim))
889 concat_dim_names.extend(mesh.mesh_dim_names)
891 if len(set(concat_dim_names)) != len(concat_dim_names):
892 raise ValueError(
893 f"DeviceMesh.concatenate expects disjoint mesh dims, but got {tuple(concat_dim_names)}."
894 )
895 return (
896 concat_dim_names,
897 concat_sizes,
898 concat_strides,
899 concat_dim_group_names,
900 concat_dim_group_backends,
901 concat_dim_group_sources,
902 )
@staticmethod
def _build_concatenate_layout(concat_sizes: list[IntTuple], concat_strides: list[IntTuple]) -> _MeshLayout:
    """Combine per-axis sizes and strides into one layout.

    A single axis is passed to ``_MeshLayout`` unwrapped; several axes are
    passed as tuples of the per-axis entries.
    """
    single_axis = len(concat_sizes) == 1
    sizes = concat_sizes[0] if single_axis else tuple(concat_sizes)
    strides = concat_strides[0] if single_axis else tuple(concat_strides)
    return _MeshLayout(sizes, strides)
@staticmethod
def _set_concatenated_group_state(
    mesh: 'DeviceMesh',
    dim_group_names: list[Optional[str]],
    dim_group_backends: list[BackendConfig],
    dim_group_sources: list[tuple['DeviceMesh', int]],
) -> None:
    """Copy inherited process-group metadata onto a concatenated mesh view.

    Each attribute is written only when its list is non-empty, so an empty list
    leaves the mesh's existing value untouched. Names are stored as the given
    list object; backends and sources are stored as tuples.
    """
    updates = {}
    if dim_group_names:
        updates["_dim_group_names"] = dim_group_names
    if dim_group_backends:
        updates["_dim_group_backends"] = tuple(dim_group_backends)
    if dim_group_sources:
        updates["_dim_group_sources"] = tuple(dim_group_sources)
    for attr_name, value in updates.items():
        setattr(mesh, attr_name, value)
@staticmethod
def concatenate(meshes: Sequence['DeviceMesh']) -> 'DeviceMesh':
    """Join several sub-mesh views of one root mesh into a single wider mesh view.

    A single input is returned unchanged. Otherwise the inputs are validated:
    they must share a root, have named axes in root order, and use disjoint
    names. Their per-axis layouts are then stacked into one layout, and a new
    view over the first input's ``_rank_map`` is built. That view inherits the
    inputs' process-group metadata.

    Raises:
        ValueError: On invalid inputs (including an empty sequence) or if the
            combined layout overlaps.
    """
    if len(meshes) == 1:
        return meshes[0]

    root_mesh, requested_dim_names, _ = DeviceMesh._validate_concatenate_inputs(meshes)
    DeviceMesh._validate_concatenate_root_order(root_mesh, requested_dim_names)
    metadata = DeviceMesh._collect_concatenate_metadata(meshes)
    dim_names, sizes, strides, group_names, group_backends, group_sources = metadata

    layout = DeviceMesh._build_concatenate_layout(sizes, strides)
    if not layout.check_non_overlap():
        raise ValueError(f"Cannot concatenate overlapping meshes: {meshes}")

    lead = meshes[0]
    combined = DeviceMesh(
        lead.device_type,
        mesh_dim_names=tuple(dim_names),
        _init_backend=False,
        _layout=layout,
        _rank_map=lead._rank_map,  # pylint: disable=protected-access
        _root_mesh=lead._get_root_mesh(),  # pylint: disable=protected-access
    )
    DeviceMesh._set_concatenated_group_state(combined, group_names, group_backends, group_sources)
    return combined

# The same staticmethod is also reachable under the underscored name.
_concatenate = concatenate
def _create_flatten_mesh(
    self,
    mesh_dim_name: Optional[str] = None,
    backend_override: BackendConfig = None,
) -> 'DeviceMesh':
    """Create or reuse a flattened one-dimensional mesh view.

    The new view covers the ranks of ``self`` with all axes coalesced into a
    single axis named ``mesh_dim_name``. The default name is ``self``'s axis
    names joined by ``"_"``. The view is registered on the root mesh
    (flatten mapping, sub-mesh cache and ``sub_mesh`` list). Asking again for
    the same name and layout returns the cached view.

    Args:
        mesh_dim_name: Name of the flattened axis.
        backend_override: Backend recorded for the flattened axis.

    Returns:
        ``self`` if it is already a 1-D mesh with that axis name; otherwise a
        cached or newly created flattened view.

    Raises:
        RuntimeError: If no name is given and ``self`` has no mesh_dim_names.
        ValueError: If the name collides with a root axis name, or with an
            earlier flattened view of a different layout.
    """
    root_mesh = self._get_root_mesh()
    # Treat a missing mesh_dim_names as "no names" rather than failing with a
    # TypeError on `join(None)` / `in None`.
    own_dim_names = self.mesh_dim_names or ()

    if mesh_dim_name is None:
        if not own_dim_names:
            raise RuntimeError(
                "mesh_dim_names not specified, please pass mesh_dim_name to flatten."
            )
        mesh_dim_name = "_".join(own_dim_names)

    if self.ndim == 1 and mesh_dim_name in own_dim_names:  # pylint: disable=E1135
        return self

    # Names of the root mesh's own axes are reserved.
    invalid_dim_names = root_mesh.mesh_dim_names or ()
    if mesh_dim_name in invalid_dim_names:
        raise ValueError(
            f"'{mesh_dim_name}' already exists in the root mesh mesh_dim_names "
            f"{invalid_dim_names}. Please specify another valid mesh_dim_name."
        )

    flattened_mesh_layout = self._layout.coalesce()
    if len(flattened_mesh_layout) > 1:
        # The view has a single axis name, so a multi-mode result is nested into one mode.
        flattened_mesh_layout = flattened_mesh_layout.nest()

    flatten_mapping = root_mesh.get_flatten_mapping()
    if mesh_dim_name in flatten_mapping:
        cached_mesh = flatten_mapping[mesh_dim_name]
        if cached_mesh._layout == flattened_mesh_layout:  # pylint: disable=protected-access
            return cached_mesh
        raise ValueError(
            f"Flatten mesh with mesh_dim_name '{mesh_dim_name}' has been created "
            f"before with different layout. Please specify another valid mesh_dim_name."
        )

    res_flattened_mesh = DeviceMesh(
        device_type=root_mesh.device_type,
        mesh_dim_names=(mesh_dim_name,),
        _init_backend=False,
        _layout=flattened_mesh_layout,
        _rank_map=root_mesh._rank_map,
        _root_mesh=root_mesh,
    )
    res_flattened_mesh._dim_group_backends = (backend_override,)  # pylint: disable=W0212
    # Only create process groups eagerly when the source mesh already had them.
    if hasattr(self, "_dim_group_names"):
        res_flattened_mesh._dim_group_names = DeviceMesh._init_process_groups_for_layout(  # pylint: disable=W0212
            res_flattened_mesh._layout,
            root_mesh._rank_map,
            res_flattened_mesh.mesh_dim_names,
            backend_override=(backend_override,),
        )

    root_mesh.add_flatten_mapping(mesh_dim_name, res_flattened_mesh)
    root_mesh._sub_mesh_cache[(mesh_dim_name,)] = res_flattened_mesh  # pylint: disable=W0212
    root_mesh.sub_mesh.append(res_flattened_mesh)

    return res_flattened_mesh
def _create_unflatten_mesh(
    self,
    dim: int,
    mesh_sizes: tuple[int, ...],
    mesh_dim_names: tuple[str, ...],
    backend_override: tuple[BackendConfig, ...],
) -> 'DeviceMesh':
    """Split axis ``dim`` into ``len(mesh_sizes)`` named axes.

    The sub-layout of axis ``dim`` is composed with a contiguous layout of
    ``mesh_sizes`` and spliced back in its place. Axis names, backends and,
    when present, process-group names are spliced at the same position.

    Raises:
        ValueError: If ``prod(mesh_sizes)`` differs from the size of axis ``dim``.
    """
    split_layout = _MeshLayout(mesh_sizes, _contiguous_strides(mesh_sizes))
    axis_layout = self._layout[dim]
    if split_layout.numel() != axis_layout.numel():
        raise ValueError(
            f"The product of mesh_sizes={mesh_sizes} is {split_layout.numel()}, "
            f"but the original dimension at dim={dim} has size {axis_layout.numel()}."
        )

    replacement_layout = axis_layout.composition(split_layout)
    new_layout = self._layout.splice(dim, dim + 1, replacement_layout)
    old_names = list(self.mesh_dim_names or ())
    new_names = old_names[:dim] + list(mesh_dim_names) + old_names[dim + 1:]

    root_mesh = self._get_root_mesh()
    res_mesh = DeviceMesh(
        self.device_type,
        mesh_dim_names=tuple(new_names),
        _init_backend=False,
        _layout=new_layout,
        _rank_map=root_mesh._rank_map,
        _root_mesh=root_mesh,
    )

    old_backends = list(self._dim_group_backends)
    res_mesh._dim_group_backends = tuple(  # pylint: disable=W0212
        old_backends[:dim] + list(backend_override) + old_backends[dim + 1:]
    )

    # Process groups are created eagerly only if this mesh already has them.
    if hasattr(self, "_dim_group_names"):
        old_groups = list(self._dim_group_names)
        new_groups = DeviceMesh._init_process_groups_for_layout(
            replacement_layout,
            root_mesh._rank_map,
            mesh_dim_names,
            backend_override=backend_override,
        )
        res_mesh._dim_group_names = old_groups[:dim] + list(new_groups) + old_groups[dim + 1:]  # pylint: disable=W0212

    return res_mesh
def _flatten(self, mesh_dim_name: Optional[str] = None, backend_override: Any = None) -> 'DeviceMesh':
    """Flatten all axes into one, normalizing ``backend_override`` with ``_normalize_backend_value`` first."""
    normalized_backend = _normalize_backend_value(backend_override)
    return self._create_flatten_mesh(mesh_dim_name, backend_override=normalized_backend)
def _unflatten(
    self,
    dim: Union[int, str],
    mesh_sizes: tuple[int, ...],
    mesh_dim_names: tuple[str, ...],
    backend_override: Optional[dict[Union[int, str], Any]] = None,
) -> 'DeviceMesh':
    """Torch-compatible helper that expands one mesh axis into a nested layout.

    Args:
        dim: Axis to split, by index or by name.
        mesh_sizes: Sizes of the new axes.
        mesh_dim_names: Names of the new axes (same length as ``mesh_sizes``).
        backend_override: Optional per-axis backends, normalized with
            ``_normalize_backend_override``. When omitted, every new axis gets ``None``.

    Raises:
        ValueError: If ``dim`` is out of range or not a known axis name.
        RuntimeError: If ``mesh_sizes`` and ``mesh_dim_names`` differ in length.
    """
    if not isinstance(dim, int):
        known_names = self.mesh_dim_names or ()
        if dim not in known_names:
            raise ValueError(f"dim {dim} specified in `_unflatten` is not in {known_names}")
        dim = known_names.index(dim)
    elif not 0 <= dim < self.ndim:
        raise ValueError(f"dim {dim} specified in `_unflatten` is out of range {self.ndim}")

    if len(mesh_sizes) != len(mesh_dim_names):
        raise RuntimeError("mesh_dim_names must have same length as mesh_sizes in _unflatten!")

    if backend_override is None:
        backends = (None,) * len(mesh_dim_names)
    else:
        backends = _normalize_backend_override(backend_override, len(mesh_sizes), mesh_dim_names)
    return self._create_unflatten_mesh(dim, mesh_sizes, mesh_dim_names, backends)
def assert_axis(self, axis, operate_name):
    """Validate that ``axis`` names one of this mesh's dimensions.

    Raises:
        RuntimeError: If the mesh has no ``mesh_dim_names``; the message names
            ``operate_name`` as the unsupported operation.
        ValueError: If ``axis`` is not among ``mesh_dim_names``.
    """
    dim_names = self.mesh_dim_names
    if not dim_names:
        raise RuntimeError(f"mesh_dim_names not specified, {operate_name} is not supported.")
    if axis in dim_names:  # pylint: disable=E1135
        return
    raise ValueError(
        f"The axis name must be one of mesh dim name {self.mesh_dim_names}, but got {axis}"
    )

def axis_id(self, axis):
    """Return the id stored in ``_dev_name_to_dev_id`` for ``axis``; the string ``"None"`` maps to ``-1``."""
    if axis != "None":
        self.assert_axis(axis, "axis_id")
        return self._dev_name_to_dev_id[axis]
    return -1

def axis_index(self, axis):
    """Return the index stored in ``_dev_name_to_index`` for the named ``axis``."""
    self.assert_axis(axis, "axis_index")
    index_lookup = self._dev_name_to_index
    return index_lookup[axis]

def get_device_num_along_axis(self, axis):
    """Return the size of the mesh along the named ``axis``."""
    self.assert_axis(axis, "get_device_num_along_axis")
    axis_position = self.mesh_dim_names.index(axis)
    return self.mesh_shape[axis_position]
def get_rank_list_along_axis(self, mesh_dim):
    """Return the ranks that match the current rank on every axis except ``mesh_dim``.

    ``mesh_dim`` is an axis name. Ranks are ordered by coordinate along that
    axis, using a row-major view of ``self.rank_list`` over ``self.mesh_shape``.
    Results are memoized per name in ``self._cache_rank_list_along_axis``, and
    a cached entry is returned before any validation runs.

    Raises:
        RuntimeError: If the mesh has no ``mesh_dim_names`` (from ``assert_axis``).
        ValueError: If ``mesh_dim`` is not an axis name (from ``assert_axis``) or
            the current rank is not in ``rank_list``.
    """
    memo = self._cache_rank_list_along_axis
    if mesh_dim in memo:
        return memo[mesh_dim]
    self.assert_axis(mesh_dim, "get_rank_list_along_axis")

    shape = self.mesh_shape
    dim_names = self.mesh_dim_names
    ranks = self.rank_list
    my_rank = self.rank
    if my_rank not in ranks:
        raise ValueError(f"Rank {my_rank} not found in rank_list")

    ndim = len(shape)
    # Unravel the current rank's flat position into per-axis coordinates.
    coord = [0] * ndim
    remainder = ranks.index(my_rank)
    for axis in reversed(range(ndim)):
        remainder, coord[axis] = divmod(remainder, shape[axis])

    # Row-major strides: strides[axis] is the flat-index step along that axis.
    strides = [1] * ndim
    for axis in reversed(range(ndim - 1)):
        strides[axis] = strides[axis + 1] * shape[axis + 1]

    target = dim_names.index(mesh_dim)
    offsets = [c * s for c, s in zip(coord, strides)]
    offsets[target] = 0
    base = sum(offsets)
    step = strides[target]
    result_ranks = [ranks[base + v * step] for v in range(shape[target])]

    memo[mesh_dim] = result_ranks
    return result_ranks
def get_global_shape(self, slice_shape, tensor_map):
    """Infer the global tensor shape from a local shard shape and its tensor map.

    Mesh axes are numbered from the right (mesh axis ``dev_idx`` has index
    ``ndim - 1 - dev_idx``), which is how tensor-map entries refer to them.
    For each mesh axis, the first tensor dimension whose mapping equals that
    index, or is a tuple containing it, is multiplied by the axis size. A
    mapping of ``-1`` means the dimension is not sharded. Results are memoized
    in ``self._global_shape_map``.

    Args:
        slice_shape: Local shard shape (any sequence of ints).
        tensor_map: One entry per tensor dimension: a right-indexed mesh axis
            (``-1`` for none) or a tuple of such indices.

    Returns:
        The global shape as a tuple.

    Raises:
        ValueError: If ``tensor_map`` is None or its length differs from ``slice_shape``.
    """
    if tensor_map is None:
        raise ValueError(
            "tensor_map is not set. Please configure the tensor map by calling the layout."
        )
    # Key on the values themselves, not hash(): distinct inputs whose hashes
    # collide must not share a cache entry, and list-typed shapes are accepted.
    slice_shape = tuple(slice_shape)
    map_key = (slice_shape, tuple(tensor_map))
    cached = self._global_shape_map.get(map_key)
    if cached is not None:
        return cached
    if len(slice_shape) != len(tensor_map):
        raise ValueError(
            f"Length of slice_shape ({len(slice_shape)}) must match "
            f"the length of tensor_map ({len(tensor_map)})."
        )

    n_dims = len(self._mesh_shape)
    factors = [1] * len(slice_shape)

    for dev_idx, size in enumerate(self._mesh_shape):
        reverse_idx = n_dims - 1 - dev_idx
        for axis_idx, mapping in enumerate(tensor_map):
            if isinstance(mapping, int):
                if mapping == -1:
                    continue
                if mapping == reverse_idx:
                    factors[axis_idx] *= size
                    break
            elif isinstance(mapping, tuple):
                if reverse_idx in mapping:
                    factors[axis_idx] *= size
                    break

    global_shape = tuple(dim * factor for dim, factor in zip(slice_shape, factors))
    self._global_shape_map[map_key] = global_shape
    return global_shape
def _materialize_dim_group(self, mesh_dim: int) -> Optional[str]:
    """Create, or look up, the process group for ``mesh_dim`` on first use.

    Resolution order:
      1. If ``_dim_group_sources`` maps this axis to a different ``(mesh, dim)``,
         materialize the group on that mesh and record its key here.
      2. If this axis already has a key present in ``EXISTING_COMM_GROUPS``, reuse it.
      3. Otherwise split a new group from this axis's layout, cache it, and
         record its key.

    Returns:
        The group key now stored in ``self._dim_group_names[mesh_dim]``.
    """
    if not hasattr(self, "_dim_group_names"):
        # One slot per axis; None marks a group that has not been created yet.
        self._dim_group_names = [None] * self.ndim  # pylint: disable=W0201

    if hasattr(self, "_dim_group_sources"):
        owner, owner_dim = self._dim_group_sources[mesh_dim]  # pylint: disable=W0212
        if owner is not self or owner_dim != mesh_dim:
            shared_key = owner._materialize_dim_group(owner_dim)  # pylint: disable=W0212
            self._dim_group_names[mesh_dim] = shared_key
            return shared_key

    existing_key = self._dim_group_names[mesh_dim]
    if existing_key is not None and existing_key in EXISTING_COMM_GROUPS:
        return existing_key

    split_ranks, new_key = DeviceMesh._build_dim_split_ranks(self._layout[mesh_dim], self._rank_map)
    DeviceMesh._cache_group_if_needed(new_key, platform.split_group(split_ranks=split_ranks))
    self._dim_group_names[mesh_dim] = new_key
    return new_key
def get_comm_group_by_axis(self, mesh_dim: Union[str, int]):
    """Return the process group for one mesh axis, materializing it lazily if needed.

    Args:
        mesh_dim: Axis name or index. ``None`` is accepted on a 1-D mesh and
            means axis 0.

    Raises:
        ValueError: If the axis is invalid, or the group key is still missing
            from ``EXISTING_COMM_GROUPS`` after materialization.
        RuntimeError: If the mesh has no ``_dim_group_names`` at all.
    """
    if mesh_dim is None and self.ndim == 1:
        mesh_dim = 0

    if isinstance(mesh_dim, str):
        dim_names = self.mesh_dim_names
        if dim_names is None or len(dim_names) == 0:
            raise ValueError(f"DeviceMesh mesh_dim_names is not set, string mesh_dim {mesh_dim}, is not support.")
        if mesh_dim not in dim_names:  # pylint: disable=E1135
            raise ValueError(
                f"mesh_dim can pass a string or integer, but string mesh_dim '{mesh_dim}' not found in "
                f"mesh_dim_names {self.mesh_dim_names}"
            )
        mesh_dim = dim_names.index(mesh_dim)
    elif not isinstance(mesh_dim, int) or not 0 <= mesh_dim < self.ndim:
        raise ValueError(
            f"mesh_dim can pass a string or integer, if not string, mesh_dim should be a integer in range "
            f"[0, {self.ndim}), but got {mesh_dim}"
        )

    if not hasattr(self, "_dim_group_names"):
        raise RuntimeError("DeviceMesh process groups not initialized!")

    group_key = self._dim_group_names[mesh_dim]
    needs_materialization = group_key is None or group_key not in EXISTING_COMM_GROUPS
    if needs_materialization:
        group_key = self._materialize_dim_group(mesh_dim)
    if group_key in EXISTING_COMM_GROUPS:
        return EXISTING_COMM_GROUPS[group_key]
    raise ValueError(f"{group_key} not in group cache {EXISTING_COMM_GROUPS.keys()}")
def get_devices_for_axis(self, mesh_dim: Union[str, int], rank: int):
    """List the ranks that share every mesh coordinate with ``rank`` except along ``mesh_dim``.

    ``mesh_dim`` may be an axis name or index. Ranks are returned in order of
    increasing coordinate along ``mesh_dim``, using a row-major view of
    ``self._rank_list`` over ``self._mesh_shape``.

    Raises:
        ValueError: If a string axis is given without ``mesh_dim_names`` or is
            unknown, the axis index is out of range, or ``rank`` is not in this mesh.
    """
    if isinstance(mesh_dim, str):
        dim_names = self.mesh_dim_names
        if not dim_names:
            raise ValueError("_mesh_dim_names is not set, string mesh_dim is not supported, please pass a integer.")
        if mesh_dim not in dim_names:  # pylint: disable=E1135
            raise ValueError(f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {dim_names}")
        mesh_dim = dim_names.index(mesh_dim)

    shape = self._mesh_shape
    if not 0 <= mesh_dim < self.ndim:
        raise ValueError(f"mesh_dim {mesh_dim} can not out of range [0, {self.ndim})")
    ranks = self._rank_list
    if rank not in ranks:
        raise ValueError(f"Rank {rank} not found in rank_list")

    ndim = len(shape)
    # Unravel the rank's flat position into per-axis coordinates.
    coord = [0] * ndim
    remainder = ranks.index(rank)
    for axis in reversed(range(ndim)):
        remainder, coord[axis] = divmod(remainder, shape[axis])

    # Row-major strides: strides[axis] is the flat-index step along that axis.
    strides = [1] * ndim
    for axis in reversed(range(ndim - 1)):
        strides[axis] = strides[axis + 1] * shape[axis + 1]

    offsets = [c * s for c, s in zip(coord, strides)]
    offsets[mesh_dim] = 0
    base = sum(offsets)
    step = strides[mesh_dim]
    return [ranks[base + v * step] for v in range(shape[mesh_dim])]
def to_hash(self):
    """Return the identity tuple ``(mesh_shape, mesh_dim_names, rank_list)`` used to compare meshes."""
    return (self.mesh_shape, self.mesh_dim_names, self.rank_list)

def __repr__(self):
    """Render device type, shape, axis names and ranks as ``DeviceMesh(...)``."""
    fields = (
        f"device_type='{self.device_type}'",
        f"mesh_shape={self._mesh_shape}",
        f"mesh_dim_names={self.mesh_dim_names}",
        f"rank_list={self._rank_list}",
    )
    return "DeviceMesh(" + ", ".join(fields) + ")"

def __str__(self):
    """Same text as ``repr(self)``."""
    return repr(self)
def __deepcopy__(self, memo):
    """Deep-copy this mesh while sharing references to related meshes.

    ``_root_mesh`` and ``_dim_group_sources`` hold references to other mesh
    objects (``(mesh, dim)`` pairs for the latter). They are copied by
    reference so the clone points at the same meshes instead of duplicating
    them. Every other instance attribute is deep-copied with the shared ``memo``.
    """
    shared_by_reference = ("_root_mesh", "_dim_group_sources")
    clone = self.__class__.__new__(self.__class__)
    # Register before recursing so self-references resolve to the clone.
    memo[id(self)] = clone
    for attr_name, attr_value in self.__dict__.items():
        if attr_name not in shared_by_reference:
            attr_value = copy.deepcopy(attr_value, memo)
        setattr(clone, attr_name, attr_value)
    return clone
# Module-level cache used by _create_device_mesh; keys are
# (mesh_shape, mesh_dim_names, rank_list) tuples.
_DEVICE_MESH_MAP = {}


def _create_device_mesh(device_type: str,
                        mesh_shape: tuple[int, ...],
                        *,
                        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
                        rank_list: tuple[int, ...],
                        init_backend: bool = True, ):
    """Create or reuse a cached DeviceMesh with the requested topology.

    ``mesh_shape`` and ``rank_list`` are normalized to tuples, so lists are
    accepted. Empty or missing ``mesh_dim_names`` become ``None``. The rank
    list is reshaped to ``mesh_shape`` with numpy (row-major).

    Note:
        ``device_type`` and ``init_backend`` are not part of the cache key. A
        previously cached mesh is returned as-is even if those arguments differ.
    """
    mesh_shape = tuple(mesh_shape)
    rank_list = tuple(rank_list)
    mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
    # Key on the values themselves rather than hash(): topologies whose hashes
    # happen to collide must never share a cached mesh.
    map_key = (mesh_shape, mesh_dim_names, rank_list)
    cached = _DEVICE_MESH_MAP.get(map_key)
    if cached is None:
        mesh = np.array(rank_list).reshape(mesh_shape)
        cached = DeviceMesh(device_type, mesh,
                            mesh_dim_names=mesh_dim_names,
                            _init_backend=init_backend)
        _DEVICE_MESH_MAP[map_key] = cached
    return cached
def init_device_mesh(
    device_type: str,
    mesh_shape: tuple[int, ...],
    *,
    mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
    rank_list: Optional[tuple[int, ...]] = None,
    init_backend: bool = True,
) -> 'DeviceMesh':
    """Initialize a cached DeviceMesh from the provided shape, names, and ranks.

    ``mesh_shape`` and ``mesh_dim_names`` are validated before any side effect.
    A malformed argument therefore fails immediately, without initializing the
    process group, querying the current rank, or dividing by a zero mesh size.

    Args:
        device_type: Device type forwarded to the mesh.
        mesh_shape: Tuple of positive ints, one per mesh axis.
        mesh_dim_names: Optional unique, non-empty axis names, one per axis.
        rank_list: Explicit global ranks (any sequence; normalized to a tuple).
            When omitted, the contiguous block of ``prod(mesh_shape)`` ranks
            containing the current rank is used.
        init_backend: If true, the process group is initialized when
            ``rank_list`` is omitted; the flag is also forwarded to mesh creation.

    Returns:
        The (possibly cached) DeviceMesh.

    Raises:
        TypeError: If ``mesh_shape`` is not a tuple or ``mesh_dim_names`` is
            not a tuple/list.
        ValueError: On non-positive sizes, mismatched/duplicate/empty names, or
            a ``rank_list`` whose length differs from the mesh size.
        RuntimeError: If the current rank cannot be determined.
    """
    if not isinstance(mesh_shape, tuple):
        raise TypeError(f'mesh_shape must be a tuple, but got {type(mesh_shape)}')

    for size in mesh_shape:
        if not isinstance(size, int) or size <= 0:
            raise ValueError(
                f"Each element of mesh_shape must be a positive integer, but got {mesh_shape}"
            )

    if mesh_dim_names is not None:
        if not isinstance(mesh_dim_names, (tuple, list)):
            raise TypeError(
                f'mesh_dim_names must be a tuple or list, but got {type(mesh_dim_names)}'
            )
        mesh_dim_names = tuple(mesh_dim_names)
        if len(mesh_shape) != len(mesh_dim_names):
            raise ValueError(
                f'mesh_shape ({len(mesh_shape)}) and mesh_dim_names '
                f'({len(mesh_dim_names)}) should have same length'
            )
        if len(set(mesh_dim_names)) != len(mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be different')
        if any(not isinstance(name, str) or name == "" for name in mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be a non-empty string')

    # Safe only after validation: every size is a positive int, so this is >= 1.
    total_devices = int(np.prod(np.array(mesh_shape)))
    if rank_list is not None:
        # Normalize so the downstream cache key is hashable for list inputs.
        rank_list = tuple(rank_list)
        if len(rank_list) != total_devices:
            raise ValueError(
                f"rank_list length ({len(rank_list)}) must equal mesh size ({total_devices})"
            )
    else:
        if init_backend:
            platform.init_process_group()
        try:
            current_rank = platform.get_rank()
        except Exception as exc:
            raise RuntimeError(
                "init_device_mesh: failed to get current rank for automatic rank_list generation. "
                "Either pass rank_list explicitly, or ensure the process group is initialized before calling "
                "init_device_mesh (or set init_backend=True to let init_device_mesh initialize it)."
            ) from exc
        # Start of the aligned block of `total_devices` ranks containing this rank.
        base = current_rank - (current_rank % total_devices)
        rank_list = tuple(range(base, base + total_devices))

    return _create_device_mesh(
        device_type,
        mesh_shape,
        mesh_dim_names=mesh_dim_names,
        rank_list=rank_list,
        init_backend=init_backend,
    )