Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/device_mesh.py: 58%

795 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""device mesh""" 

16 

17import copy 

18import os 

19import threading 

20from types import TracebackType 

21from typing import Any, List, Literal, Optional, Sequence, Type, Union 

22import numpy as np 

23 

24from hyper_parallel.core.dtensor._mesh_layout import IntTuple, _MeshLayout, _contiguous_strides, _is_int 

25from hyper_parallel.platform import get_platform 

26from hyper_parallel.platform.platform import EXISTING_COMM_GROUPS, PlatformType 

27 

# Resolve the active backend platform once at import time; the helpers below
# call through it for tensor conversion, rank queries and group creation.
platform = get_platform()
# Backend-specific tensor class, used in annotations and isinstance checks.
Tensor = platform.Tensor

30 

31 

32class _MeshEnv(threading.local): 

33 """Per-thread stack of active :class:`DeviceMesh` (PyTorch ``_mesh_resources`` parity).""" 

34 

35 def __init__(self) -> None: 

36 super().__init__() 

37 self.mesh_stack: List["DeviceMesh"] = [] 

38 

39 def get_current_mesh(self) -> "DeviceMesh": 

40 """Return the innermost active :class:`DeviceMesh` for this thread (PyTorch parity).""" 

41 if len(self.mesh_stack) == 0: 

42 raise RuntimeError("No device mesh is currently active!") 

43 return self.mesh_stack[-1] 

44 

45 

# Module-wide handle to the per-thread stack of active meshes.
_mesh_resources = _MeshEnv()

# Per-dimension backend choice: a backend name (e.g. ``"fake"``) or ``None``
# when no override was given for that dimension.
BackendConfig = Optional[str]

49 

50 

51def _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, sub_mesh_dim_names, current_rank): 

52 """ 

53 Get the sub rank list for a sub mesh. 

54 

55 Args: 

56 mesh_shape (tuple[int]): The shape of the original mesh. 

57 mesh_dim_names (tuple[str]): The mesh dim names of the original mesh dimensions. 

58 rank_list (tuple[int]): A tuple of ranks that participate in this mesh. 

59 sub_mesh_dim_names (tuple[str]): The mesh dim names of the sub mesh to extract. 

60 current_rank (int): The current process rank. 

61 

62 Returns: 

63 list: The sub rank list for the sub mesh. 

64 """ 

65 mesh_tensor = np.array(rank_list).reshape(mesh_shape) 

66 

67 for dim_index, dim_name in enumerate(mesh_dim_names): 

68 if dim_name in sub_mesh_dim_names: 

69 continue 

70 

71 dim_size = mesh_shape[dim_index] 

72 sliced_tensors = np.split(mesh_tensor, dim_size, axis=dim_index) 

73 

74 for sliced_tensor in sliced_tensors: 

75 rank_exists = np.isin(np.array([current_rank]), sliced_tensor).any() 

76 if rank_exists: 

77 mesh_tensor = sliced_tensor 

78 break 

79 

80 sub_rank_list = mesh_tensor.reshape(-1).tolist() 

81 return sub_rank_list 

82 

83 

def _normalize_backend_value(value: Any) -> BackendConfig:
    """Reduce one user-supplied backend override to a backend name or ``None``.

    Accepted forms are ``None``, a backend-name string, or a non-empty tuple
    whose first element is a string or ``None``; trailing tuple elements are
    ignored. Every other input normalizes to ``None``.
    """
    if value is None or isinstance(value, str):
        return value
    if isinstance(value, tuple) and value:
        head = value[0]
        if head is None or isinstance(head, str):
            return head
    return None

94 

95 

def _normalize_backend_override(
    backend_override: dict[Union[int, str], Any],
    ndim: int,
    mesh_dim_names: Optional[tuple[str, ...]] = None,
) -> tuple[BackendConfig, ...]:
    """Convert a ``{dim index or dim name: backend}`` mapping into a per-dim tuple.

    Args:
        backend_override: Overrides keyed by integer dim index or by dim name.
        ndim: Number of mesh dimensions.
        mesh_dim_names: Optional names of the mesh dimensions.

    Returns:
        A tuple of length ``ndim``. Each entry is the normalized backend for that
        dimension, or ``None`` when the dimension has no override.

    Raises:
        RuntimeError: If a dimension is given both by index and by name, or if
            any key matches neither a valid index nor a dim name.
    """
    pending = dict(backend_override)
    names = mesh_dim_names or ()
    resolved: list[BackendConfig] = []

    for idx in range(ndim):
        name = names[idx] if idx < len(names) else None
        named_hit = name is not None and name in pending
        if named_hit and idx in pending:
            raise RuntimeError(
                f"Found redundant dim index {idx} and name {name} in backend_override"
            )
        if named_hit:
            key = name
        elif idx in pending:
            key = idx
        else:
            resolved.append(None)
            continue
        resolved.append(_normalize_backend_value(pending.pop(key)))

    if pending:
        raise RuntimeError(
            f"Found invalid keys in backend_override: got {list(pending.keys())}, "
            f"expected integers in range [0, {ndim}) or one of {names}"
        )
    return tuple(resolved)

125 

126 

def _should_defer_group_init(sub_layout: _MeshLayout, backend_override: BackendConfig) -> bool:
    """Return True when no process group should be created eagerly for this mesh dim.

    Eager creation is skipped for the ``"fake"`` backend and for dimensions
    whose sub-layout has a single element (``sub_layout.numel() == 1``).
    """
    if backend_override == "fake":
        return True
    return sub_layout.numel() == 1

130 

131 

132class DeviceMesh: 

133 """ 

134 Topological abstraction describing cluster devices. 

135 

136 Args: 

137 device_type (str): Device type. Valid values depend on the active platform: 

138 

139 - **PyTorch** (same as ``torch.distributed.device_mesh.DeviceMesh``): 

140 ``"cpu"``, ``"cuda"``, ``"npu"``. 

141 - **MindSpore** (mapped to the corresponding communication backend): 

142 ``"cpu"`` → mccl, ``"gpu"`` → nccl, ``"npu"`` → hccl. 

143 mesh (Union[Tensor, list, tuple, np.ndarray, None]): A multi-dimensional array, list, or integer 

144 tensor describing the device layout. The IDs in the mesh are global IDs of the 

145 default process group, representing the multi-dimensional networking structure 

146 of devices in distributed training (e.g., [[0,1],[2,3]] represents a 2x2 device mesh). 

147 If a list or non-int32 tensor is provided, it will be automatically converted 

148 to an int32 tensor. If None, a 1D mesh containing all ranks 

149 (i.e., ``[0, 1, ..., world_size-1]``) will be created automatically. 

150 mesh_dim_names (tuple[str]): A tuple[str] of mesh dim names for each dimension of mesh. 

151 _init_backend (boolean): Whether initial process group. 

152 

153 Attributes: 

154 ndim (int): Number of dimensions in the mesh. 

155 mesh_shape (tuple[int]): Shape of the device mesh. 

156 rank_list (tuple[int]): Flattened list of ranks from the mesh. 

157 root_mesh (DeviceMesh): The parent mesh if this is a sub mesh, None otherwise. 

158 sub_mesh (list[DeviceMesh]): List of child meshes created from this mesh. 

159 

160 Context manager: 

161 Use ``with device_mesh:`` to set the **current** mesh for this thread. 

162 """ 

163 

164 device_type: Literal["cpu", "cuda", "gpu", "npu"] 

165 mesh: Union[Tensor, list, tuple, np.ndarray] 

166 mesh_dim_names: Union[tuple[str, ...], list[str], None] 

167 

168 _VALID_DEVICE_TYPES = { 

169 PlatformType.PYTORCH: {"cpu", "cuda", "npu"}, 

170 PlatformType.MINDSPORE: {"cpu", "gpu", "npu"}, 

171 } 

172 

173 def __init__(self, 

174 device_type: Literal["cpu", "cuda", "gpu", "npu"], 

175 mesh: Union[Tensor, list, tuple, np.ndarray, None] = None, 

176 *, 

177 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None, 

178 _init_backend: bool = True, 

179 _layout: Optional[_MeshLayout] = None, 

180 _rank_map: Optional[Tensor] = None, 

181 _root_mesh: Optional['DeviceMesh'] = None, 

182 ): 

183 self._validate_device_type(device_type) 

184 self.device_type = device_type 

185 

186 if _init_backend: 

187 platform.init_process_group() 

188 

189 self._layout, self._rank_map = self._resolve_layout_and_rank_map(mesh, _layout, _rank_map) 

190 self._rank = platform.get_rank() 

191 self._root_mesh = _root_mesh 

192 self._refresh_mesh_view() 

193 self._set_mesh_dim_names(mesh_dim_names) 

194 self._initialize_runtime_state(_init_backend) 

195 if os.getenv("MS_SIMULATION_LEVEL") is None: 

196 self._coordinate_on_dim = self._compute_coordinate_on_dim() 

197 

198 @classmethod 

199 def _validate_device_type(cls, device_type: str) -> None: 

200 """Validate that the requested device type is supported on the active platform.""" 

201 valid_device_types = cls._VALID_DEVICE_TYPES.get(platform.platform_type) 

202 if valid_device_types is not None and device_type not in valid_device_types: 

203 raise ValueError( 

204 f"Invalid device_type '{device_type}' for {platform.platform_type.name} platform. " 

205 f"Valid device types are: {sorted(valid_device_types)}" 

206 ) 

207 

208 @classmethod 

209 def _resolve_layout_and_rank_map( 

210 cls, 

211 mesh: Union[Tensor, list, tuple, np.ndarray, None], 

212 layout: Optional[_MeshLayout], 

213 rank_map: Optional[Tensor], 

214 ) -> tuple[_MeshLayout, Tensor]: 

215 """Build the internal layout and rank map from either public or private constructor inputs.""" 

216 if mesh is not None and (layout is not None or rank_map is not None): 

217 raise TypeError("Cannot provide both explicit mesh and private _layout/_rank_map arguments.") 

218 

219 if mesh is None and (layout is None or rank_map is None): 

220 world_size = platform.get_world_size() 

221 mesh = list(range(world_size)) 

222 

223 if mesh is not None: 

224 mesh_tensor = cls._convert_mesh_to_tensor(mesh) 

225 if mesh_tensor.ndim == 0: 

226 raise ValueError("mesh must be at least 1-dimensional") 

227 return cls._build_layout_from_mesh(mesh_tensor), cls._build_rank_map_from_mesh(mesh_tensor) 

228 

229 rank_map_tensor = cls._convert_rank_map_to_tensor(rank_map) 

230 if layout is None or rank_map_tensor is None: 

231 raise TypeError("The mesh argument is required except for private _layout/_rank_map construction.") 

232 if not layout.check_non_overlap(): 

233 raise ValueError(f"Invalid overlapping layout {layout}.") 

234 return layout, rank_map_tensor 

235 

236 def _refresh_mesh_view(self) -> None: 

237 """Materialize the visible mesh tensor and the derived shape/rank metadata.""" 

238 # Compute everything in numpy first so the intermediate ops don't need 

239 # a real device. Otherwise the call would fail (or SIGSEGV on Ascend) 

240 # when DeviceMesh is constructed inside a ``ms.DeviceCtx("meta")`` 

241 # block — e.g., from ``DeviceMesh.concatenate`` invoked under 

242 # ``fully_shard``, which forces fresh ``Tensor()`` constructions onto 

243 # the meta device and any subsequent op (asnumpy, nonzero, …) crashes. 

244 rank_map_np = platform.tensor_to_numpy(self._rank_map).reshape(-1) 

245 full_mesh_np = self._layout.remap_to_numpy(rank_map_np) 

246 if full_mesh_np.shape[0] == 1: 

247 per_rank_mesh_np = full_mesh_np[0] 

248 else: 

249 coords = np.argwhere(full_mesh_np == self._rank) 

250 if coords.shape[0] == 0: 

251 raise RuntimeError( 

252 "In order to get the mesh tensor of a DeviceMesh it needs to " 

253 "either have all its original dimensions or contain the local rank." 

254 ) 

255 per_rank_mesh_np = full_mesh_np[coords[0, 0]] 

256 # Cache the numpy view so ``_compute_coordinate_on_dim`` doesn't need 

257 # to operate on ``self.mesh`` (which may be on the meta device). 

258 self._per_rank_mesh_np = per_rank_mesh_np 

259 self.mesh = Tensor(per_rank_mesh_np.astype(np.int32)).int() 

260 self._mesh_shape = tuple(per_rank_mesh_np.shape) 

261 self._rank_list = tuple(per_rank_mesh_np.reshape(-1).tolist()) 

262 self._flatten_rank_map = tuple(rank_map_np.tolist()) 

263 self._dev_num = np.prod(np.array(self._mesh_shape)) 

264 self._dev_rank = len(self._mesh_shape) 

265 

266 def _set_mesh_dim_names( 

267 self, 

268 mesh_dim_names: Union[tuple[str, ...], list[str], None], 

269 ) -> None: 

270 """Validate mesh dim names and build lookup tables for named access.""" 

271 self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None 

272 if self.mesh_dim_names is None: 

273 return 

274 

275 if len(self._mesh_shape) != len(self.mesh_dim_names): 

276 raise ValueError( 

277 f'mesh dimensions ({len(self._mesh_shape)}) should be equal to ' 

278 f'mesh_dim_names length ({len(self.mesh_dim_names)})' 

279 ) 

280 if len(set(self.mesh_dim_names)) != len(self.mesh_dim_names): 

281 raise ValueError(f'Each element of mesh_dim_names {self.mesh_dim_names} should be different') 

282 inter_key = "interleaved_parallel" 

283 if inter_key in self.mesh_dim_names and self.mesh_dim_names.index(inter_key) != len(self.mesh_dim_names) - 1: 

284 raise ValueError( 

285 "'interleaved_parallel' should be at the last dim of mesh_dim_names, means virtual sharding." 

286 ) 

287 self._dev_name_to_dev_id = { 

288 name: self._dev_rank - i - 1 for i, name in enumerate(self.mesh_dim_names) 

289 } 

290 self._dev_name_to_index = {name: i for i, name in enumerate(self.mesh_dim_names)} 

291 

292 def _initialize_runtime_state(self, init_backend: bool) -> None: 

293 """Initialize caches and optional process-group state for the mesh view.""" 

294 self._cache_rank_list_along_axis = {} 

295 self._global_shape_map = {} 

296 self._sub_mesh_cache = {} 

297 self._flatten_mapping: dict[str, 'DeviceMesh'] = {} 

298 self._ndim = len(self._mesh_shape) 

299 self._dim_group_backends = (None,) * self._ndim 

300 self._dim_group_sources = tuple((self, dim) for dim in range(self._ndim)) 

301 self._sub_mesh: List['DeviceMesh'] = [] 

302 if not init_backend: 

303 return 

304 self._dim_group_names = self._init_process_groups( 

305 self._mesh_shape, 

306 self.mesh_dim_names, 

307 self._rank_list, 

308 ) 

309 

310 @staticmethod 

311 def _build_layout_from_mesh(mesh: Tensor) -> _MeshLayout: 

312 mesh_shape = tuple(mesh.shape) 

313 return _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape)) 

314 

315 @staticmethod 

316 def _build_rank_map_from_mesh(mesh: Tensor) -> Tensor: 

317 return Tensor(platform.tensor_to_numpy(mesh).reshape(-1)).int() 

318 

319 @staticmethod 

320 def _convert_rank_map_to_tensor(rank_map: Tensor) -> Tensor: 

321 """Normalize a rank-map input into the flat int32 Tensor stored on the mesh. 

322 

323 Tensor input is returned as-is to preserve its original device; list / 

324 tuple / numpy input is built into a fresh flat int32 Tensor. 

325 """ 

326 if isinstance(rank_map, Tensor): 

327 # Reuse the existing tensor as-is so we preserve its real device. 

328 # Going through ``Tensor(np_array)`` would re-create on whatever 

329 # device context is active (e.g. ``ms.DeviceCtx("meta")`` while 

330 # ``DeviceMesh.concatenate`` runs under ``fully_shard``), which then 

331 # breaks the immediate ``asnumpy()`` in ``_refresh_mesh_view``. 

332 # All in-tree callers that pass a Tensor pass an existing 

333 # ``DeviceMesh._rank_map`` — already a flat int32 tensor, so no 

334 # reshape/cast is needed. 

335 return rank_map 

336 rank_map_np = np.array(rank_map) 

337 return Tensor(rank_map_np.reshape(-1).astype(np.int32)).int() 

338 

339 @staticmethod 

340 def _get_mesh_tensor_from_full_mesh(full_mesh: Tensor, current_rank: Optional[int] = None) -> Tensor: 

341 """Select the per-rank mesh view from a fully materialized layout remap.""" 

342 if full_mesh.shape[0] == 1: 

343 return full_mesh[0] 

344 

345 if current_rank is None: 

346 current_rank = platform.get_rank() 

347 

348 rank_coords = (full_mesh == current_rank).nonzero() 

349 if rank_coords.shape[0] > 0: 

350 return full_mesh[rank_coords[0, 0]] 

351 raise RuntimeError( 

352 "In order to get the mesh tensor of a DeviceMesh it needs to " 

353 "either have all its original dimensions or contain the local rank." 

354 ) 

355 

356 def _compute_coordinate_on_dim(self): 

357 """Compute the current rank coordinates inside this mesh view.""" 

358 # Use the cached numpy view rather than ``self.mesh`` so this works 

359 # even when the mesh tensor lives on the meta device (DeviceMesh 

360 # constructed under ``ms.DeviceCtx("meta")`` via ``fully_shard``). 

361 per_rank_mesh_np = getattr(self, "_per_rank_mesh_np", None) 

362 if per_rank_mesh_np is not None: 

363 rank_coords = np.argwhere(per_rank_mesh_np == self._rank) 

364 if rank_coords.shape[0] not in (0, 1): 

365 raise AssertionError( 

366 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}" 

367 ) 

368 if rank_coords.shape[0] == 0: 

369 return None 

370 return tuple(int(x) for x in rank_coords[0]) 

371 return self._compute_coordinates_from_mesh(self.mesh, self._rank) 

372 

373 @staticmethod 

374 def _compute_coordinates_from_mesh( 

375 mesh_tensor: Tensor, 

376 rank: int, 

377 ): 

378 """Locate one rank inside a mesh tensor and return its coordinates.""" 

379 rank_coords = (mesh_tensor == rank).nonzero() 

380 if rank_coords.shape[0] not in (0, 1): 

381 raise AssertionError( 

382 f"rank_coords.shape[0] must be 0 or 1, got {rank_coords.shape[0]}" 

383 ) 

384 

385 if rank_coords.shape[0] == 0: 

386 return None 

387 

388 coords = rank_coords[0].tolist() 

389 return tuple(coords) 

390 

391 def size(self, mesh_dim=None) -> int: 

392 if mesh_dim is not None: 

393 return self.mesh.shape[mesh_dim] 

394 return self.mesh.numel() 

395 

396 def get_coordinate(self): 

397 return self._coordinate_on_dim if self._coordinate_on_dim else None 

398 

399 def __enter__(self) -> "DeviceMesh": 

400 _mesh_resources.mesh_stack.append(self) 

401 return self 

402 

403 def __exit__( 

404 self, 

405 exc_type: Optional[Type[BaseException]], 

406 exc_val: Optional[BaseException], 

407 exc_tb: Optional[TracebackType], 

408 ) -> None: 

409 _mesh_resources.mesh_stack.pop() 

410 

411 @staticmethod 

412 def _convert_mesh_to_tensor(mesh: Union[Tensor, list, tuple, np.ndarray]) -> Tensor: 

413 """Convert a public mesh input into an int32 platform tensor.""" 

414 if isinstance(mesh, Tensor): 

415 mesh = platform.tensor_to_numpy(mesh) 

416 elif isinstance(mesh, (list, tuple)): 

417 mesh = np.array(mesh) 

418 elif not isinstance(mesh, np.ndarray): 

419 raise TypeError( 

420 f"mesh must be Tensor, list, tuple or numpy array, but got {type(mesh)}" 

421 ) 

422 

423 mesh = mesh.astype(np.int32) 

424 return Tensor(mesh).int() 

425 

426 @staticmethod 

427 def _init_one_process_group(mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...], 

428 dim_name: str, rank_list: tuple[int, ...]) -> str: 

429 """Create one process-group family for the named mesh dimension.""" 

430 group_key = None 

431 split_ranks = set() 

432 if not isinstance(dim_name, tuple): 

433 dim_name = (dim_name,) 

434 for rank in rank_list: 

435 split_rank = _get_sub_rank_list(mesh_shape, mesh_dim_names, rank_list, dim_name, rank) 

436 sorted_rank = tuple(sorted(split_rank)) 

437 split_ranks.add(sorted_rank) 

438 if rank == platform.get_rank(): 

439 group_key = str(sorted_rank) 

440 split_ranks = sorted([list(item) for item in split_ranks]) 

441 platform.split_group(split_ranks=split_ranks) 

442 return group_key 

443 

444 @staticmethod 

445 def _build_dim_split_ranks( 

446 sub_layout: _MeshLayout, 

447 rank_map: Tensor, 

448 ) -> tuple[list[list[int]], Optional[str]]: 

449 """Build rank lists and the local cache key for one logical mesh axis.""" 

450 pg_ranks_by_dim = sub_layout.remap_to_numpy(platform.tensor_to_numpy(rank_map)) 

451 current_rank = platform.get_rank() 

452 split_ranks = [] 

453 split_ranks_set = set() 

454 group_key = None 

455 for dim_mesh in np.array(pg_ranks_by_dim): 

456 subgroup_ranks = tuple(int(rank) for rank in np.array(dim_mesh).reshape(-1).tolist()) 

457 subgroup_ranks_sorted = tuple(sorted(subgroup_ranks)) 

458 if subgroup_ranks_sorted not in split_ranks_set: 

459 split_ranks_set.add(subgroup_ranks_sorted) 

460 split_ranks.append(list(subgroup_ranks_sorted)) 

461 if current_rank in subgroup_ranks: 

462 if group_key is not None: 

463 raise RuntimeError( 

464 "Each device mesh dimension should get only one process group per rank." 

465 ) 

466 group_key = str(subgroup_ranks_sorted) 

467 split_ranks = sorted(split_ranks) 

468 return split_ranks, group_key 

469 

470 @staticmethod 

471 def _cache_group_if_needed(group_key: Optional[str], group: Any) -> None: 

472 if group_key is not None and group is not None and group_key not in EXISTING_COMM_GROUPS: 

473 EXISTING_COMM_GROUPS[group_key] = group 

474 

475 @staticmethod 

476 def _init_process_groups_for_layout( 

477 layout: _MeshLayout, 

478 rank_map: Tensor, 

479 mesh_dim_names: Union[tuple[str, ...], None], 

480 backend_override: Optional[tuple[BackendConfig, ...]] = None, 

481 ) -> list: 

482 """Initialize process groups for each top-level axis in the given layout.""" 

483 if mesh_dim_names is None: 

484 mesh_dim_names = tuple(f"dim_{dim}" for dim in range(len(layout))) 

485 if backend_override is None: 

486 backend_override = (None,) * len(layout) 

487 if len(backend_override) != len(layout): 

488 raise ValueError( 

489 f"backend_override length {len(backend_override)} must match layout rank {len(layout)}" 

490 ) 

491 

492 dim_group_names = [] 

493 for dim, sub_layout in enumerate(layout): 

494 split_ranks, group_key = DeviceMesh._build_dim_split_ranks(sub_layout, rank_map) 

495 if _should_defer_group_init(sub_layout, backend_override[dim]): 

496 dim_group_names.append(None) 

497 continue 

498 group = platform.split_group(split_ranks=split_ranks) 

499 DeviceMesh._cache_group_if_needed(group_key, group) 

500 dim_group_names.append(group_key) 

501 return dim_group_names 

502 

503 @staticmethod 

504 def _init_process_groups(mesh_shape: tuple[int, ...], mesh_dim_names: Union[tuple[str, ...], None], 

505 rank_list: tuple[int, ...], 

506 backend_override: Optional[tuple[BackendConfig, ...]] = None) -> list: 

507 layout = _MeshLayout(mesh_shape, _contiguous_strides(mesh_shape)) 

508 rank_map = DeviceMesh._convert_rank_map_to_tensor(rank_list) 

509 return DeviceMesh._init_process_groups_for_layout( 

510 layout, 

511 rank_map, 

512 mesh_dim_names, 

513 backend_override=backend_override, 

514 ) 

515 

516 @property 

517 def rank(self): 

518 return self._rank 

519 

520 @property 

521 def mesh_shape(self): 

522 return self._mesh_shape 

523 

524 @property 

525 def rank_list(self): 

526 return self._rank_list 

527 

528 @property 

529 def ndim(self) -> int: 

530 return self._ndim 

531 

532 @property 

533 def shape(self) -> tuple: 

534 return self._mesh_shape 

535 

536 @property 

537 def root_mesh(self) -> Optional['DeviceMesh']: 

538 return self._root_mesh 

539 

540 @root_mesh.setter 

541 def root_mesh(self, value: Optional['DeviceMesh']): 

542 self._root_mesh = value 

543 

544 @property 

545 def sub_mesh(self) -> List['DeviceMesh']: 

546 return self._sub_mesh 

547 

548 def get_flatten_mapping(self) -> dict: 

549 return self._flatten_mapping 

550 

551 def add_flatten_mapping(self, name: str, mesh: 'DeviceMesh') -> None: 

552 self._flatten_mapping[name] = mesh 

553 

554 def __getitem__(self, sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> 'DeviceMesh': 

555 if not self.mesh_dim_names: 

556 raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") 

557 

558 sub_mesh_dim_names = DeviceMesh._normalize_sub_mesh_dim_names(sub_mesh_dim_names) 

559 flatten_mapping = self._get_root_mesh().get_flatten_mapping() 

560 

561 flattened_result = self._try_get_from_flatten_mapping(sub_mesh_dim_names, flatten_mapping) 

562 if flattened_result is not None: 

563 return flattened_result 

564 

565 layout = self._get_slice_mesh_layout(sub_mesh_dim_names) 

566 if sub_mesh_dim_names in self._sub_mesh_cache: 

567 return self._sub_mesh_cache[sub_mesh_dim_names] 

568 if layout == self._layout: 

569 return self 

570 return self._create_and_cache_sub_mesh(sub_mesh_dim_names, layout) 

571 

572 @staticmethod 

573 def _normalize_sub_mesh_dim_names(sub_mesh_dim_names: Union[str, tuple[str, ...]]) -> tuple[str, ...]: 

574 """Normalize a slice selector into a non-empty tuple of mesh dim names.""" 

575 if isinstance(sub_mesh_dim_names, str): 

576 sub_mesh_dim_names = (sub_mesh_dim_names,) 

577 

578 if not isinstance(sub_mesh_dim_names, tuple): 

579 raise TypeError( 

580 f"sub_mesh_dim_names must be str or tuple, but got {type(sub_mesh_dim_names)}" 

581 ) 

582 

583 if len(sub_mesh_dim_names) == 0: 

584 raise ValueError("sub_mesh_dim_names cannot be empty") 

585 

586 return sub_mesh_dim_names 

587 

588 @staticmethod 

589 def _try_get_from_flatten_mapping(sub_mesh_dim_names: tuple[str, ...], 

590 flatten_mapping: dict) -> Optional['DeviceMesh']: 

591 if len(sub_mesh_dim_names) == 1 and sub_mesh_dim_names[0] in flatten_mapping: 

592 return flatten_mapping[sub_mesh_dim_names[0]] 

593 return None 

594 

595 def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int: 

596 """Resolve a named mesh axis to its integer position.""" 

597 mesh_dim_names = self.mesh_dim_names or () 

598 if len(mesh_dim_names) == 0: 

599 raise KeyError("No mesh_dim_names found.") 

600 if mesh_dim_name not in mesh_dim_names: 

601 raise KeyError( 

602 f"Mesh dimension '{mesh_dim_name}' does not exist. " 

603 f"Available mesh dimensions are: {mesh_dim_names}" 

604 ) 

605 return mesh_dim_names.index(mesh_dim_name) 

606 

607 def _get_slice_mesh_layout(self, sub_mesh_dim_names: tuple[str, ...]) -> _MeshLayout: 

608 """Construct the layout corresponding to one named sub-mesh slice request.""" 

609 root_mesh = self._get_root_mesh() 

610 slice_from_root = self == root_mesh 

611 flatten_name_to_layout = ( 

612 {key: mesh._layout for key, mesh in root_mesh.get_flatten_mapping().items()} 

613 if slice_from_root else {} 

614 ) 

615 valid_dim_names = [*(self.mesh_dim_names or ()), *flatten_name_to_layout] 

616 if not all(name in valid_dim_names for name in sub_mesh_dim_names): 

617 raise KeyError( 

618 f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. " 

619 f"Valid mesh_dim_names are {valid_dim_names}." 

620 ) 

621 

622 if all(name in (self.mesh_dim_names or ()) for name in sub_mesh_dim_names): 

623 indices = [self.mesh_dim_names.index(name) for name in sub_mesh_dim_names] 

624 if indices != sorted(indices): 

625 raise ValueError( 

626 f"sub_mesh_dim_names {sub_mesh_dim_names} must follow the order of " 

627 f"original mesh_dim_names {self.mesh_dim_names}" 

628 ) 

629 

630 sliced_sizes: list[IntTuple] = [] 

631 sliced_strides: list[IntTuple] = [] 

632 for name in sub_mesh_dim_names: 

633 if name in (self.mesh_dim_names or ()): 

634 layout = self._layout[self.mesh_dim_names.index(name)] 

635 else: 

636 layout = flatten_name_to_layout[name] 

637 sliced_sizes.append(layout.sizes) 

638 sliced_strides.append(layout.strides) 

639 

640 pre_stride = -1 

641 for stride in reversed(sliced_strides): 

642 if not _is_int(stride): 

643 raise NotImplementedError( 

644 "Currently, this only allows slicing out a contiguous flattened dim." 

645 ) 

646 if stride < pre_stride: 

647 raise ValueError( 

648 f"Invalid mesh_dim_names {sub_mesh_dim_names} specified. " 

649 "Mesh dim indices should be in ascending order." 

650 ) 

651 pre_stride = stride 

652 

653 if len(sliced_sizes) == 1: 

654 layout = _MeshLayout(sliced_sizes[0], sliced_strides[0]) 

655 else: 

656 layout = _MeshLayout(tuple(sliced_sizes), tuple(sliced_strides)) 

657 if not layout.check_non_overlap(): 

658 raise RuntimeError(f"Slicing overlapping dim_names {sub_mesh_dim_names} is not allowed.") 

659 return layout 

660 

661 def _create_and_cache_sub_mesh(self, sub_mesh_dim_names: tuple[str, ...], layout: _MeshLayout) -> 'DeviceMesh': 

662 """Create a sub-mesh view, copy group metadata, and cache the result.""" 

663 root_mesh = self._get_root_mesh() 

664 sub_mesh = DeviceMesh( 

665 device_type=self.device_type, 

666 mesh_dim_names=sub_mesh_dim_names, 

667 _init_backend=False, 

668 _layout=layout, 

669 _rank_map=root_mesh._rank_map, 

670 _root_mesh=root_mesh, 

671 ) 

672 

673 slice_dim_group_name = [] 

674 slice_dim_group_backends: list[BackendConfig] = [] 

675 slice_dim_group_sources: list[tuple['DeviceMesh', int]] = [] 

676 for name in sub_mesh_dim_names: 

677 if name in (self.mesh_dim_names or ()): 

678 dim_index = self.mesh_dim_names.index(name) 

679 if hasattr(self, "_dim_group_names"): 

680 slice_dim_group_name.append(self._dim_group_names[dim_index]) 

681 slice_dim_group_backends.append(self._dim_group_backends[dim_index]) 

682 if hasattr(self, "_dim_group_sources"): 

683 slice_dim_group_sources.append(self._dim_group_sources[dim_index]) # pylint: disable=W0212 

684 else: 

685 slice_dim_group_sources.append((self, dim_index)) 

686 elif name in root_mesh.get_flatten_mapping(): 

687 flatten_mesh = root_mesh.get_flatten_mapping()[name] 

688 if hasattr(flatten_mesh, "_dim_group_names"): 

689 slice_dim_group_name.append(flatten_mesh._dim_group_names[0]) 

690 slice_dim_group_backends.append(flatten_mesh._dim_group_backends[0]) 

691 if hasattr(flatten_mesh, "_dim_group_sources"): 

692 slice_dim_group_sources.append(flatten_mesh._dim_group_sources[0]) # pylint: disable=W0212 

693 else: 

694 slice_dim_group_sources.append((flatten_mesh, 0)) 

695 if slice_dim_group_name: 

696 sub_mesh._dim_group_names = slice_dim_group_name # pylint: disable=W0212 

697 if slice_dim_group_backends: 

698 sub_mesh._dim_group_backends = tuple(slice_dim_group_backends) # pylint: disable=W0212 

699 if slice_dim_group_sources: 

700 sub_mesh._dim_group_sources = tuple(slice_dim_group_sources) # pylint: disable=W0212 

701 

702 self._sub_mesh_cache[sub_mesh_dim_names] = sub_mesh 

703 self.sub_mesh.append(sub_mesh) 

704 return sub_mesh 

705 

706 def get_group(self, mesh_dim: Optional[Union[int, str]] = None): 

707 """Return the communication group for one mesh axis.""" 

708 if not hasattr(self, "_dim_group_names"): 

709 raise RuntimeError("DeviceMesh process groups not initialized!") 

710 

711 if self.ndim > 1 and mesh_dim is None: 

712 raise RuntimeError( 

713 f"Found the DeviceMesh have {self.ndim} dimensions. " 

714 "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1." 

715 ) 

716 

717 root_mesh = self._get_root_mesh() 

718 if isinstance(mesh_dim, str) and mesh_dim in root_mesh.get_flatten_mapping(): 

719 flattened_mesh = root_mesh.get_flatten_mapping()[mesh_dim] 

720 return flattened_mesh.get_comm_group_by_axis(mesh_dim) 

721 

722 return self.get_comm_group_by_axis(mesh_dim) 

723 

724 def get_all_groups(self) -> list: 

725 if not hasattr(self, "_dim_group_names"): 

726 raise RuntimeError("DeviceMesh process groups not initialized!") 

727 

728 return [self.get_group(i) for i in range(self.ndim)] 

729 

@staticmethod
def from_group(group: Union[Any, list[Any]],
               device_type: str,
               mesh: Union[Tensor, list, tuple, np.ndarray] = None,
               mesh_dim_names: Union[tuple[str, ...], list[str]] = None
               ) -> 'DeviceMesh':
    """Build a DeviceMesh from an existing process group or a list of groups.

    Args:
        group: A single process group (giving a 1-D mesh) or a list holding one
            group per mesh dimension.
        device_type: Device type string forwarded to ``DeviceMesh``.
        mesh: Rank layout. Optional for a single group, in which case it must
            list exactly the group's ranks. Required when ``group`` is a list.
        mesh_dim_names: Optional names for the mesh dimensions.

    Returns:
        A DeviceMesh built with ``_init_backend=False``. Its ``_dim_group_names``
        point at the supplied groups, which are registered in
        ``EXISTING_COMM_GROUPS`` unless the platform already reports a created
        group for the same ranks.

    Raises:
        ValueError: if a 1-D ``mesh`` differs from the group's ranks, the group
            list is empty, ``mesh`` is missing for a group list, or the mesh
            dimension count does not match the number of groups.
    """

    def _register(dim_group):
        # Cache a user-supplied group under the sorted-rank key, unless the
        # platform already created a group for these ranks.
        ranks = platform.get_process_group_ranks(dim_group)
        key = str(tuple(sorted(ranks)))
        if not platform.get_created_group(ranks):
            EXISTING_COMM_GROUPS[key] = dim_group
        return ranks, key

    if not isinstance(group, list):
        group_ranks, group_key = _register(group)
        if mesh is not None:
            # Normalise every accepted container to a plain Python list.
            # Otherwise a tuple never equals a list and always raises, and a
            # numpy array's elementwise `!=` has an ambiguous truth value.
            mesh_ranks = mesh.tolist() if isinstance(mesh, Tensor) else np.asarray(mesh).tolist()
            if mesh_ranks != list(group_ranks):
                raise ValueError(
                    f"Invalid mesh_shape {str(mesh)} for 1D group with ranks {group_ranks}"
                )
        device_mesh = DeviceMesh(device_type, group_ranks, mesh_dim_names=mesh_dim_names, _init_backend=False)
        device_mesh._dim_group_names = [group_key]  # pylint: disable=W0212
        return device_mesh

    groups = list(group)
    if len(groups) == 0:
        raise ValueError("Expect at least one group be specified.")
    if mesh is None:
        raise ValueError("mesh_shape is must specified when group is a list.")
    mesh = DeviceMesh._convert_mesh_to_tensor(mesh)
    if mesh.ndim != len(groups):
        raise ValueError("mesh dimensions must match group dimensions.")
    device_mesh = DeviceMesh(device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False)
    device_mesh._dim_group_names = [_register(dim_group)[1] for dim_group in groups]  # pylint: disable=W0212
    return device_mesh

773 

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
    """Return the local coordinate of the current rank along one mesh dimension.

    The current rank's position in ``_rank_list`` is read as a row-major index
    into ``_mesh_shape``. The coordinate on the requested dimension is returned.

    Args:
        mesh_dim: Dimension index or name. It may be omitted only for a 1-D mesh.

    Returns:
        The coordinate in ``[0, mesh_shape[mesh_dim])``.

    Raises:
        RuntimeError: if ``mesh_dim`` is omitted on a mesh with more than one
            dimension.
        ValueError: if ``mesh_dim`` is an unknown name (including when the mesh
            has no names), is out of range, or the current rank is not in the
            mesh.
    """
    if self.ndim > 1 and mesh_dim is None:
        raise RuntimeError(
            f"Found the DeviceMesh have {self.ndim} dimensions. "
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
        )

    if mesh_dim is None:
        mesh_dim = 0

    if isinstance(mesh_dim, str):
        # Treat missing names as "no names" so this raises ValueError rather
        # than failing on `x not in None`.
        dim_names = self.mesh_dim_names or ()
        if mesh_dim not in dim_names:
            raise ValueError(
                f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {self.mesh_dim_names}"
            )
        dim_index = dim_names.index(mesh_dim)
    else:
        if not isinstance(mesh_dim, int) or mesh_dim < 0 or mesh_dim >= self.ndim:
            raise ValueError(
                f"mesh_dim must be an integer in range [0, {self.ndim}), "
                f"but got {mesh_dim}"
            )
        dim_index = mesh_dim

    if self._rank not in self._rank_list:
        raise ValueError(
            f"Current rank {self._rank} not found in rank_list {self._rank_list}"
        )

    flat_index = self._rank_list.index(self._rank)
    # Row-major: the stride of a dimension is the product of all later sizes.
    stride = 1
    for size in self._mesh_shape[dim_index + 1:]:
        stride *= size
    return (flat_index // stride) % self._mesh_shape[dim_index]

812 

def flatten(self, mesh_dim_name: Optional[str] = None) -> 'DeviceMesh':
    """Return a 1-D view of this mesh by delegating to ``_create_flatten_mesh``."""
    flattened = self._create_flatten_mesh(mesh_dim_name)
    return flattened

815 

816 def _get_root_mesh(self) -> 'DeviceMesh': 

817 """Return the canonical root mesh for this view.""" 

818 if self._root_mesh is None: 

819 return self 

820 return self._root_mesh._get_root_mesh() # pylint: disable=protected-access 

821 

822 @staticmethod 

823 def _validate_concatenate_inputs( 

824 meshes: Sequence['DeviceMesh'], 

825 ) -> tuple['DeviceMesh', tuple[str, ...], tuple[int, ...]]: 

826 """Validate concatenate inputs and return the shared root metadata.""" 

827 if len(meshes) == 0: 

828 raise ValueError("DeviceMesh.concatenate expects at least one mesh.") 

829 if len(meshes) == 1: 

830 return meshes[0]._get_root_mesh(), tuple(meshes[0].mesh_dim_names or ()), meshes[0]._flatten_rank_map 

831 

832 root_mesh = meshes[0]._get_root_mesh() # pylint: disable=protected-access 

833 requested_dim_names: list[str] = [] 

834 flatten_rank_map = meshes[0]._flatten_rank_map # pylint: disable=protected-access 

835 for mesh in meshes: 

836 if mesh._get_root_mesh().to_hash() != root_mesh.to_hash(): # pylint: disable=protected-access 

837 raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.") 

838 if mesh._flatten_rank_map != flatten_rank_map: # pylint: disable=protected-access 

839 raise ValueError("DeviceMesh.concatenate expects all meshes to share the same root mesh.") 

840 if not mesh.mesh_dim_names: 

841 raise ValueError("DeviceMesh.concatenate requires mesh_dim_names on every input mesh.") 

842 requested_dim_names.extend(mesh.mesh_dim_names) 

843 return root_mesh, tuple(requested_dim_names), flatten_rank_map 

844 

845 @staticmethod 

846 def _validate_concatenate_root_order(root_mesh: 'DeviceMesh', requested_dim_names: tuple[str, ...]) -> None: 

847 """Require original root dims to stay in root order when concatenating by name.""" 

848 root_dim_names = tuple(root_mesh.mesh_dim_names) if root_mesh.mesh_dim_names else () 

849 if not root_dim_names or not all(dim_name in root_dim_names for dim_name in requested_dim_names): 

850 return 

851 

852 requested_indices = [root_dim_names.index(dim_name) for dim_name in requested_dim_names] 

853 if requested_indices != sorted(requested_indices): 

854 raise ValueError( 

855 "DeviceMesh.concatenate expects meshes to follow the root mesh order. " 

856 f"Got root mesh dims {root_dim_names} and requested dims {requested_dim_names}." 

857 ) 

858 

@staticmethod
def _collect_concatenate_metadata(
        meshes: Sequence['DeviceMesh'],
) -> tuple[
    list[str],
    list[IntTuple],
    list[IntTuple],
    list[Optional[str]],
    list[BackendConfig],
    list[tuple['DeviceMesh', int]],
]:
    """Collect layout and process-group metadata from all concatenate inputs.

    Returns:
        ``(dim_names, sizes, strides, group_names, group_backends, group_sources)``.
        Every list except ``dim_names`` has one entry per concatenated layout
        axis. ``group_names`` and ``group_backends`` are empty when no input
        carries ``_dim_group_names``. Otherwise they hold one entry per axis,
        with ``None`` for inputs lacking that metadata, so indices stay aligned
        with the concatenated dimensions. ``group_sources`` entries are the
        input's own ``_dim_group_sources`` entry, or ``(mesh, dim)``.

    Raises:
        ValueError: if the inputs' dimension names overlap.
    """
    concat_dim_names: list[str] = []
    concat_sizes: list[IntTuple] = []
    concat_strides: list[IntTuple] = []
    concat_dim_group_names: list[Optional[str]] = []
    concat_dim_group_backends: list[BackendConfig] = []
    concat_dim_group_sources: list[tuple['DeviceMesh', int]] = []

    # Appending group metadata only for meshes that have it would leave the
    # group lists shorter than the dimension list, misaligning groups with axes.
    track_groups = any(hasattr(mesh, "_dim_group_names") for mesh in meshes)

    for mesh in meshes:
        group_names = getattr(mesh, "_dim_group_names", None)
        group_backends = getattr(mesh, "_dim_group_backends", None)
        group_sources = getattr(mesh, "_dim_group_sources", None)
        for dim, sub_layout in enumerate(mesh._layout):  # pylint: disable=protected-access
            concat_sizes.append(sub_layout.sizes)
            concat_strides.append(sub_layout.strides)
            if track_groups:
                concat_dim_group_names.append(group_names[dim] if group_names is not None else None)
                concat_dim_group_backends.append(group_backends[dim] if group_backends is not None else None)
            concat_dim_group_sources.append(group_sources[dim] if group_sources is not None else (mesh, dim))
        concat_dim_names.extend(mesh.mesh_dim_names)

    if len(set(concat_dim_names)) != len(concat_dim_names):
        raise ValueError(
            f"DeviceMesh.concatenate expects disjoint mesh dims, but got {tuple(concat_dim_names)}."
        )
    return (
        concat_dim_names,
        concat_sizes,
        concat_strides,
        concat_dim_group_names,
        concat_dim_group_backends,
        concat_dim_group_sources,
    )

903 

@staticmethod
def _build_concatenate_layout(concat_sizes: list[IntTuple], concat_strides: list[IntTuple]) -> _MeshLayout:
    """Build the layout represented by concatenated top-level mesh axes.

    A single axis is passed to ``_MeshLayout`` as-is. Several axes are passed
    as tuples of per-axis sizes and strides.
    """
    if len(concat_sizes) != 1:
        return _MeshLayout(tuple(concat_sizes), tuple(concat_strides))
    return _MeshLayout(concat_sizes[0], concat_strides[0])

910 

@staticmethod
def _set_concatenated_group_state(
        mesh: 'DeviceMesh',
        dim_group_names: list[Optional[str]],
        dim_group_backends: list[BackendConfig],
        dim_group_sources: list[tuple['DeviceMesh', int]],
) -> None:
    """Attach inherited process-group metadata to a concatenated mesh view.

    An empty input list leaves the matching attribute unset. The names list is
    stored as given; backends and sources are stored as tuples.
    """
    pending = {
        "_dim_group_names": dim_group_names,
        "_dim_group_backends": tuple(dim_group_backends),
        "_dim_group_sources": tuple(dim_group_sources),
    }
    for attr_name, value in pending.items():
        if value:
            setattr(mesh, attr_name, value)

925 

926 @staticmethod 

927 def concatenate(meshes: Sequence['DeviceMesh']) -> 'DeviceMesh': 

928 """Concatenate multiple sub-mesh views into one wider layout-backed mesh.""" 

929 if len(meshes) == 1: 

930 return meshes[0] 

931 root_mesh, requested_dim_names, _ = DeviceMesh._validate_concatenate_inputs(meshes) 

932 DeviceMesh._validate_concatenate_root_order(root_mesh, requested_dim_names) 

933 ( 

934 concat_dim_names, 

935 concat_sizes, 

936 concat_strides, 

937 concat_dim_group_names, 

938 concat_dim_group_backends, 

939 concat_dim_group_sources, 

940 ) = DeviceMesh._collect_concatenate_metadata(meshes) 

941 concat_layout = DeviceMesh._build_concatenate_layout(concat_sizes, concat_strides) 

942 if not concat_layout.check_non_overlap(): 

943 raise ValueError(f"Cannot concatenate overlapping meshes: {meshes}") 

944 

945 res_mesh = DeviceMesh( 

946 meshes[0].device_type, 

947 mesh_dim_names=tuple(concat_dim_names), 

948 _init_backend=False, 

949 _layout=concat_layout, 

950 _rank_map=meshes[0]._rank_map, # pylint: disable=protected-access 

951 _root_mesh=meshes[0]._get_root_mesh(), # pylint: disable=protected-access 

952 ) 

953 DeviceMesh._set_concatenated_group_state( 

954 res_mesh, 

955 concat_dim_group_names, 

956 concat_dim_group_backends, 

957 concat_dim_group_sources, 

958 ) 

959 return res_mesh 

960 

961 _concatenate = concatenate 

962 

def _create_flatten_mesh(
        self,
        mesh_dim_name: Optional[str] = None,
        backend_override: BackendConfig = None,
) -> 'DeviceMesh':
    """Create or reuse a flattened one-dimensional mesh view.

    The default name joins this mesh's dimension names with ``"_"``. A 1-D mesh
    asked for one of its own names is returned unchanged. A flattened mesh
    already registered on the root under the same name is reused when its
    layout matches.

    The new mesh's layout is this mesh's ``_layout.coalesce()``, nested when
    more than one mode remains. It is registered on the root's flatten mapping,
    ``_sub_mesh_cache`` and ``sub_mesh``. Process groups are created only when
    this mesh has ``_dim_group_names``.

    Raises:
        RuntimeError: if no name is given and this mesh has no mesh_dim_names.
        ValueError: if the name is already a root-mesh dimension, or a cached
            flattened mesh with that name has a different layout.
    """
    root_mesh = self._get_root_mesh()
    own_dim_names = tuple(self.mesh_dim_names or ())

    if mesh_dim_name is None:
        if not own_dim_names:
            # Previously "_".join(None) failed with an opaque TypeError.
            raise RuntimeError("Cannot flatten a DeviceMesh without mesh_dim_names!")
        mesh_dim_name = "_".join(own_dim_names)

    if self.ndim == 1 and mesh_dim_name in own_dim_names:
        return self

    invalid_dim_names = root_mesh.mesh_dim_names
    if mesh_dim_name in (invalid_dim_names or ()):
        raise ValueError(
            f"'{mesh_dim_name}' already exists in the root mesh mesh_dim_names "
            f"{invalid_dim_names}. Please specify another valid mesh_dim_name."
        )

    flattened_mesh_layout = self._layout.coalesce()
    if len(flattened_mesh_layout) > 1:
        flattened_mesh_layout = flattened_mesh_layout.nest()

    flatten_mapping = root_mesh.get_flatten_mapping()
    if mesh_dim_name in flatten_mapping:
        cached_mesh = flatten_mapping[mesh_dim_name]
        if cached_mesh._layout == flattened_mesh_layout:  # pylint: disable=protected-access
            return cached_mesh
        raise ValueError(
            f"Flatten mesh with mesh_dim_name '{mesh_dim_name}' has been created "
            f"before with different layout. Please specify another valid mesh_dim_name."
        )

    res_flattened_mesh = DeviceMesh(
        device_type=root_mesh.device_type,
        mesh_dim_names=(mesh_dim_name,),
        _init_backend=False,
        _layout=flattened_mesh_layout,
        _rank_map=root_mesh._rank_map,
        _root_mesh=root_mesh,
    )
    res_flattened_mesh._dim_group_backends = (backend_override,)  # pylint: disable=W0212
    if hasattr(self, "_dim_group_names"):
        res_flattened_mesh._dim_group_names = DeviceMesh._init_process_groups_for_layout(  # pylint: disable=W0212
            res_flattened_mesh._layout,
            root_mesh._rank_map,
            res_flattened_mesh.mesh_dim_names,
            backend_override=(backend_override,),
        )

    root_mesh.add_flatten_mapping(mesh_dim_name, res_flattened_mesh)
    root_mesh._sub_mesh_cache[(mesh_dim_name,)] = res_flattened_mesh  # pylint: disable=W0212
    root_mesh.sub_mesh.append(res_flattened_mesh)

    return res_flattened_mesh

1020 

def _create_unflatten_mesh(
        self,
        dim: int,
        mesh_sizes: tuple[int, ...],
        mesh_dim_names: tuple[str, ...],
        backend_override: tuple[BackendConfig, ...],
) -> 'DeviceMesh':
    """Split one logical mesh axis into multiple named axes.

    Axis ``dim`` is replaced by ``len(mesh_sizes)`` axes. Their layout is the
    composition of the original axis layout with a contiguous layout of
    ``mesh_sizes``. Names, backends and, when present, process groups are
    spliced in at the same position.

    Raises:
        ValueError: if the product of ``mesh_sizes`` differs from the size of
            axis ``dim``.
    """

    def _replace_at_dim(values, replacement):
        # Copy ``values`` and substitute position ``dim`` with ``replacement``.
        merged = list(values)
        merged[dim: dim + 1] = list(replacement)
        return merged

    inner_layout = _MeshLayout(mesh_sizes, _contiguous_strides(mesh_sizes))
    original_layout = self._layout[dim]
    inner_numel = inner_layout.numel()
    original_numel = original_layout.numel()
    if inner_numel != original_numel:
        raise ValueError(
            f"The product of mesh_sizes={mesh_sizes} is {inner_numel}, "
            f"but the original dimension at dim={dim} has size {original_numel}."
        )

    partial_layout = original_layout.composition(inner_layout)
    spliced_layout = self._layout.splice(dim, dim + 1, partial_layout)
    spliced_names = _replace_at_dim(self.mesh_dim_names or (), mesh_dim_names)

    root_mesh = self._get_root_mesh()
    result = DeviceMesh(
        self.device_type,
        mesh_dim_names=tuple(spliced_names),
        _init_backend=False,
        _layout=spliced_layout,
        _rank_map=root_mesh._rank_map,
        _root_mesh=root_mesh,
    )
    result._dim_group_backends = tuple(  # pylint: disable=W0212
        _replace_at_dim(self._dim_group_backends, backend_override)
    )

    if hasattr(self, "_dim_group_names"):
        previous_groups = list(self._dim_group_names)
        new_groups = DeviceMesh._init_process_groups_for_layout(
            partial_layout,
            root_mesh._rank_map,
            mesh_dim_names,
            backend_override=backend_override,
        )
        result._dim_group_names = _replace_at_dim(previous_groups, new_groups)  # pylint: disable=W0212

    return result

1067 

def _flatten(self, mesh_dim_name: Optional[str] = None, backend_override: Any = None) -> 'DeviceMesh':
    """Flatten this mesh, normalising ``backend_override`` before passing it on."""
    normalized_backend = _normalize_backend_value(backend_override)
    return self._create_flatten_mesh(mesh_dim_name, backend_override=normalized_backend)

1073 

1074 def _unflatten( 

1075 self, 

1076 dim: Union[int, str], 

1077 mesh_sizes: tuple[int, ...], 

1078 mesh_dim_names: tuple[str, ...], 

1079 backend_override: Optional[dict[Union[int, str], Any]] = None, 

1080 ) -> 'DeviceMesh': 

1081 """Torch-compatible helper that expands one mesh axis into a nested layout.""" 

1082 if isinstance(dim, int): 

1083 if dim < 0 or dim >= self.ndim: 

1084 raise ValueError(f"dim {dim} specified in `_unflatten` is out of range {self.ndim}") 

1085 else: 

1086 mesh_dim_names_tuple = self.mesh_dim_names or () 

1087 if dim not in mesh_dim_names_tuple: 

1088 raise ValueError(f"dim {dim} specified in `_unflatten` is not in {mesh_dim_names_tuple}") 

1089 dim = mesh_dim_names_tuple.index(dim) 

1090 

1091 if len(mesh_sizes) != len(mesh_dim_names): 

1092 raise RuntimeError("mesh_dim_names must have same length as mesh_sizes in _unflatten!") 

1093 

1094 backend_override_tuple = ( 

1095 _normalize_backend_override(backend_override, len(mesh_sizes), mesh_dim_names) 

1096 if backend_override is not None 

1097 else (None,) * len(mesh_dim_names) 

1098 ) 

1099 return self._create_unflatten_mesh(dim, mesh_sizes, mesh_dim_names, backend_override_tuple) 

1100 

def assert_axis(self, axis, operate_name):
    """Check that ``axis`` is one of this mesh's dimension names.

    Raises:
        RuntimeError: if the mesh has no dimension names (``operate_name`` is
            included in the message).
        ValueError: if ``axis`` is not among the names.
    """
    dim_names = self.mesh_dim_names
    if not dim_names:
        raise RuntimeError(f"mesh_dim_names not specified, {operate_name} is not supported.")
    if axis in dim_names:  # pylint: disable=E1135
        return
    raise ValueError(
        f"The axis name must be one of mesh dim name {self.mesh_dim_names}, but got {axis}"
    )

1108 

def axis_id(self, axis):
    """Return ``_dev_name_to_dev_id[axis]``; the literal string ``"None"`` maps to -1."""
    if axis != "None":
        self.assert_axis(axis, "axis_id")
        return self._dev_name_to_dev_id[axis]
    return -1

1114 

def axis_index(self, axis):
    """Return ``_dev_name_to_index[axis]`` after checking that ``axis`` is a mesh dimension name."""
    self.assert_axis(axis, "axis_index")
    index_by_name = self._dev_name_to_index
    return index_by_name[axis]

1118 

def get_device_num_along_axis(self, axis):
    """Return the size of the mesh dimension named ``axis``."""
    self.assert_axis(axis, "get_device_num_along_axis")
    position = self.mesh_dim_names.index(axis)
    return self.mesh_shape[position]

1122 

def get_rank_list_along_axis(self, mesh_dim):
    """Return the ranks that share every other coordinate with the current rank.

    The ranks are listed along ``mesh_dim`` from coordinate 0 upwards, with the
    current rank's other coordinates held fixed (row-major over ``rank_list``).
    Results are memoised per ``mesh_dim`` in ``_cache_rank_list_along_axis``.

    Raises:
        ValueError: if the current rank is not in ``rank_list``. ``assert_axis``
            raises for an unknown axis name.
    """
    memo = self._cache_rank_list_along_axis
    if mesh_dim in memo:
        return memo[mesh_dim]
    self.assert_axis(mesh_dim, "get_rank_list_along_axis")

    shape = self.mesh_shape
    dim_names = self.mesh_dim_names
    ranks = self.rank_list
    me = self.rank
    if me not in ranks:
        raise ValueError(f"Rank {me} not found in rank_list")

    # Row-major coordinates of the current rank.
    coords = []
    remaining = ranks.index(me)
    for size in reversed(shape):
        remaining, digit = divmod(remaining, size)
        coords.append(digit)
    coords.reverse()

    # Row-major strides: the last dimension advances by 1.
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    strides.reverse()

    axis = dim_names.index(mesh_dim)
    offset = sum(c * s for i, (c, s) in enumerate(zip(coords, strides)) if i != axis)
    peers = [ranks[offset + v * strides[axis]] for v in range(shape[axis])]
    memo[mesh_dim] = peers
    return peers

1161 

def get_global_shape(self, slice_shape, tensor_map):
    """Infer the global tensor shape from a shard shape and tensor-map metadata.

    Mesh dimension ``i`` is referenced in ``tensor_map`` by its reversed index
    ``ndim - 1 - i``. Each ``tensor_map`` entry is ``-1`` (contributes no
    factor), a single reversed index, or a tuple of reversed indices. Each mesh
    dimension multiplies the first tensor axis whose entry references it.

    Args:
        slice_shape: Local shard shape (any sequence of ints).
        tensor_map: One entry per tensor axis, as described above.

    Returns:
        The global shape as a tuple, memoised in ``_global_shape_map``.

    Raises:
        ValueError: if ``tensor_map`` is None or its length differs from
            ``slice_shape``.
    """
    if tensor_map is None:
        raise ValueError(
            "tensor_map is not set. Please configure the tensor map by calling the layout."
        )
    if len(slice_shape) != len(tensor_map):
        raise ValueError(
            f"Length of slice_shape ({len(slice_shape)}) must match "
            f"the length of tensor_map ({len(tensor_map)})."
        )

    # Key on the values themselves. Keying on hash(...) let distinct inputs with
    # colliding hashes share (and return) the wrong cached shape. tuple() also
    # lets callers pass lists.
    map_key = (tuple(slice_shape), tuple(tensor_map))
    if map_key in self._global_shape_map:
        return self._global_shape_map[map_key]

    n_dims = len(self._mesh_shape)
    factors = [1] * len(slice_shape)
    for dev_idx, size in enumerate(self._mesh_shape):
        reverse_idx = n_dims - 1 - dev_idx
        for axis_idx, mapping in enumerate(tensor_map):
            if isinstance(mapping, int):
                matched = mapping == reverse_idx  # -1 never equals a valid index
            elif isinstance(mapping, tuple):
                matched = reverse_idx in mapping
            else:
                matched = False
            if matched:
                factors[axis_idx] *= size
                break

    global_shape = tuple(dim * factor for dim, factor in zip(slice_shape, factors))
    self._global_shape_map[map_key] = global_shape
    return global_shape

1199 

1200 def _materialize_dim_group(self, mesh_dim: int) -> Optional[str]: 

1201 """Create a deferred process group for one mesh dimension on first use.""" 

1202 if not hasattr(self, "_dim_group_names"): 

1203 self._dim_group_names = [None] * self.ndim # pylint: disable=W0201 

1204 

1205 if hasattr(self, "_dim_group_sources"): 

1206 source_mesh, source_dim = self._dim_group_sources[mesh_dim] # pylint: disable=W0212 

1207 if source_mesh is not self or source_dim != mesh_dim: 

1208 source_group_key = source_mesh._materialize_dim_group(source_dim) # pylint: disable=W0212 

1209 self._dim_group_names[mesh_dim] = source_group_key 

1210 return source_group_key 

1211 

1212 group_key = self._dim_group_names[mesh_dim] 

1213 if group_key is not None and group_key in EXISTING_COMM_GROUPS: 

1214 return group_key 

1215 

1216 split_ranks, group_key = DeviceMesh._build_dim_split_ranks(self._layout[mesh_dim], self._rank_map) 

1217 group = platform.split_group(split_ranks=split_ranks) 

1218 DeviceMesh._cache_group_if_needed(group_key, group) 

1219 self._dim_group_names[mesh_dim] = group_key 

1220 return group_key 

1221 

def get_comm_group_by_axis(self, mesh_dim: Union[str, int]):
    """Return the cached or lazily materialized process group for one mesh axis.

    Args:
        mesh_dim: Axis name or index. ``None`` is accepted on a 1-D mesh.

    Raises:
        ValueError: for an unknown or unusable ``mesh_dim``, or if the group key
            is still missing from ``EXISTING_COMM_GROUPS`` after materialisation.
        RuntimeError: if the mesh has no ``_dim_group_names``.
    """
    if self.ndim == 1 and mesh_dim is None:
        mesh_dim = 0

    if isinstance(mesh_dim, str):
        names = self.mesh_dim_names
        if not names:
            raise ValueError(f"DeviceMesh mesh_dim_names is not set, string mesh_dim {mesh_dim}, is not support.")
        if mesh_dim not in names:  # pylint: disable=E1135
            raise ValueError(
                f"mesh_dim can pass a string or integer, but string mesh_dim '{mesh_dim}' not found in "
                f"mesh_dim_names {self.mesh_dim_names}"
            )
        mesh_dim = names.index(mesh_dim)
    elif not isinstance(mesh_dim, int) or not 0 <= mesh_dim < self.ndim:
        raise ValueError(
            f"mesh_dim can pass a string or integer, if not string, mesh_dim should be a integer in range "
            f"[0, {self.ndim}), but got {mesh_dim}"
        )

    if not hasattr(self, "_dim_group_names"):
        raise RuntimeError("DeviceMesh process groups not initialized!")

    key = self._dim_group_names[mesh_dim]
    if key is None or key not in EXISTING_COMM_GROUPS:
        key = self._materialize_dim_group(mesh_dim)
        if key not in EXISTING_COMM_GROUPS:
            raise ValueError(f"{key} not in group cache {EXISTING_COMM_GROUPS.keys()}")
    return EXISTING_COMM_GROUPS[key]

1252 

def get_devices_for_axis(self, mesh_dim: Union[str, int], rank: int):
    """List peer ranks that share all coordinates except the requested axis.

    ``rank``'s position in ``_rank_list`` is decomposed row-major over
    ``_mesh_shape``. The result steps through ``mesh_dim`` from coordinate 0
    while keeping the other coordinates fixed.

    Raises:
        ValueError: if ``mesh_dim`` is an unknown name or out of range, or
            ``rank`` is not in the mesh.
    """
    if isinstance(mesh_dim, str):
        names = self.mesh_dim_names
        if not names:
            raise ValueError("_mesh_dim_names is not set, string mesh_dim is not supported, please pass a integer.")
        if mesh_dim not in names:  # pylint: disable=E1135
            raise ValueError(f"mesh_dim '{mesh_dim}' not found in mesh_dim_names {names}")
        mesh_dim = names.index(mesh_dim)

    shape = self._mesh_shape
    if not 0 <= mesh_dim < self.ndim:
        raise ValueError(f"mesh_dim {mesh_dim} can not out of range [0, {self.ndim})")
    ranks = self._rank_list
    if rank not in ranks:
        raise ValueError(f"Rank {rank} not found in rank_list")

    ndim = len(shape)
    coords = [0] * ndim
    remaining = ranks.index(rank)
    for i in reversed(range(ndim)):
        remaining, coords[i] = divmod(remaining, shape[i])

    strides = [1] * ndim
    for i in reversed(range(ndim - 1)):
        strides[i] = strides[i + 1] * shape[i + 1]

    base = sum(coords[i] * strides[i] for i in range(ndim) if i != mesh_dim)
    return [ranks[base + v * strides[mesh_dim]] for v in range(shape[mesh_dim])]

1292 

def to_hash(self):
    """Return the ``(mesh_shape, mesh_dim_names, rank_list)`` tuple used as a comparison key."""
    return self.mesh_shape, self.mesh_dim_names, self.rank_list

1296 

1297 def __repr__(self): 

1298 return ( 

1299 f"DeviceMesh(device_type='{self.device_type}', mesh_shape={self._mesh_shape}, " 

1300 f"mesh_dim_names={self.mesh_dim_names}, rank_list={self._rank_list})" 

1301 ) 

1302 

1303 def __str__(self): 

1304 return self.__repr__() 

1305 

1306 def __deepcopy__(self, memo): 

1307 cls = self.__class__ 

1308 result = cls.__new__(cls) 

1309 memo[id(self)] = result 

1310 for k, v in self.__dict__.items(): 

1311 if k in ("_root_mesh", "_dim_group_sources"): 

1312 setattr(result, k, v) 

1313 else: 

1314 setattr(result, k, copy.deepcopy(v, memo)) 

1315 return result 

1316 

1317 

1318_DEVICE_MESH_MAP = {} 

1319 

1320 

1321def _create_device_mesh(device_type: str, 

1322 mesh_shape: tuple[int, ...], 

1323 *, 

1324 mesh_dim_names: Union[tuple[str, ...], list[str], None] = None, 

1325 rank_list: tuple[int, ...], 

1326 init_backend: bool = True, ): 

1327 """Create or reuse a cached DeviceMesh with the requested topology.""" 

1328 mesh = np.array(rank_list).reshape(mesh_shape) 

1329 mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None 

1330 map_key = hash((mesh_shape, mesh_dim_names, rank_list)) 

1331 if map_key not in _DEVICE_MESH_MAP: 

1332 _DEVICE_MESH_MAP[map_key] = DeviceMesh(device_type, mesh, 

1333 mesh_dim_names=mesh_dim_names, 

1334 _init_backend=init_backend) 

1335 return _DEVICE_MESH_MAP.get(map_key, None) 

1336 

1337 

def init_device_mesh(
        device_type: str,
        mesh_shape: tuple[int, ...],
        *,
        mesh_dim_names: Union[tuple[str, ...], list[str], None] = None,
        rank_list: Optional[tuple[int, ...]] = None,
        init_backend: bool = True,
) -> DeviceMesh:
    """Initialize a cached DeviceMesh from the provided shape, names, and ranks.

    All arguments are validated before any process-group work. An invalid
    ``mesh_shape``, such as one containing 0, therefore raises a clear error
    instead of a ZeroDivisionError or a half-initialised backend.

    Args:
        device_type: Device type string forwarded to the mesh.
        mesh_shape: Tuple of positive ints.
        mesh_dim_names: Optional unique, non-empty string names, one per dimension.
        rank_list: Ranks for the mesh, ``prod(mesh_shape)`` of them. When
            omitted, the block of ``prod(mesh_shape)`` consecutive ranks that
            contains the current rank is used.
        init_backend: When True and ``rank_list`` is omitted, the process group
            is initialised before querying the current rank. The value is also
            forwarded to the mesh constructor.

    Returns:
        The (possibly cached) DeviceMesh.

    Raises:
        TypeError: if ``mesh_shape`` is not a tuple, or ``mesh_dim_names`` is not
            a tuple or list.
        ValueError: for non-positive sizes, invalid ``mesh_dim_names``, or a
            ``rank_list`` of the wrong length.
        RuntimeError: if ``rank_list`` is omitted and the current rank cannot be
            queried.
    """
    if not isinstance(mesh_shape, tuple):
        raise TypeError(f'mesh_shape must be a tuple, but got {type(mesh_shape)}')

    for size in mesh_shape:
        if not isinstance(size, int) or size <= 0:
            raise ValueError(
                f"Each element of mesh_shape must be a positive integer, but got {mesh_shape}"
            )

    if mesh_dim_names is not None:
        if not isinstance(mesh_dim_names, (tuple, list)):
            raise TypeError(
                f'mesh_dim_names must be a tuple or list, but got {type(mesh_dim_names)}'
            )
        mesh_dim_names = tuple(mesh_dim_names)
        if len(mesh_shape) != len(mesh_dim_names):
            raise ValueError(
                f'mesh_shape ({len(mesh_shape)}) and mesh_dim_names '
                f'({len(mesh_dim_names)}) should have same length'
            )
        # Check the element type before set(), which fails on unhashable names.
        if any(not isinstance(name, str) or name == "" for name in mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be a non-empty string')
        if len(set(mesh_dim_names)) != len(mesh_dim_names):
            raise ValueError(f'Each element of mesh_dim_names {mesh_dim_names} should be different')

    total_devices = int(np.prod(np.array(mesh_shape)))
    if rank_list is not None:
        rank_list = tuple(rank_list)  # lists would otherwise break the hashed mesh cache
        if len(rank_list) != total_devices:
            raise ValueError(
                f"rank_list length ({len(rank_list)}) must equal mesh size ({total_devices})"
            )
    else:
        if init_backend:
            platform.init_process_group()
        try:
            current_rank = platform.get_rank()
        except Exception as exc:
            raise RuntimeError(
                "init_device_mesh: failed to get current rank for automatic rank_list generation. "
                "Either pass rank_list explicitly, or ensure the process group is initialized before calling "
                "init_device_mesh (or set init_backend=True to let init_device_mesh initialize it)."
            ) from exc
        # Use the block of `total_devices` consecutive ranks containing this rank.
        base = current_rank - (current_rank % total_devices)
        rank_list = tuple(range(base, base + total_devices))

    return _create_device_mesh(
        device_type,
        mesh_shape,
        mesh_dim_names=mesh_dim_names,
        rank_list=rank_list,
        init_backend=init_backend,
    )