Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / random.py: 36%

205 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""RNG state management for distributed tensor operations. 

16 

17Provides utilities for tracking and synchronizing random number generator states 

18across multiple devices in distributed training scenarios. 

19""" 

20 

21__all__ = [ 

22 "is_rng_supported_mesh", 

23 "manual_seed", 

24 "OffsetBasedRNGTracker", 

25] 

26 

27import contextlib 

28import warnings 

29from logging import getLogger 

30import typing 

31from typing import Optional 

32import functools 

33import operator 

34 

35from hyper_parallel.core.dtensor.placement_types import Shard 

36from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

37from hyper_parallel.platform import get_platform 

38 

39platform = get_platform() 

40DTensorBase = platform.DTensorBase 

41Tensor = platform.tensor 

42 

43logger = getLogger(__name__) 

44 

45 

def is_rng_supported_mesh(device_mesh: Optional[DeviceMesh] = None) -> bool:
    """Report whether DTensor random operations can be used.

    The decision is made in two steps:

    1. A mesh whose ``device_type`` is ``"cpu"`` is rejected without consulting
       the platform.
    2. Otherwise support requires ``platform.get_device_handle()`` to return a
       truthy handle that exposes a ``set_rng_state`` attribute.

    Args:
        device_mesh: Optional :class:`DeviceMesh` to check. When omitted, only the
            active platform device handle is inspected and no warning is issued.

    Returns:
        bool: ``True`` when random operations are supported, ``False`` otherwise.
        Whenever ``False`` is returned for an explicitly passed mesh, a warning
        naming the mesh's device type is emitted first.
    """
    if device_mesh is None:
        # No mesh to complain about: answer silently from the platform handle.
        handle = platform.get_device_handle()
        return bool(handle and hasattr(handle, "set_rng_state"))

    if device_mesh.device_type != "cpu":
        handle = platform.get_device_handle()
        if handle and hasattr(handle, "set_rng_state"):
            return True

    # Either a CPU mesh or a handle without RNG-state support.
    # stacklevel=2 attributes the warning to the caller of this function.
    warnings.warn(
        f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh",
        stacklevel=2,
    )
    return False

77 

78 

class _PhiloxState:
    """View over a packed Philox RNG state holding ``(seed, offset)``.

    The wrapped state is indexed as a byte buffer: bytes ``[0, 8)`` are read and
    written as the seed, bytes ``[8, ...)`` as the offset. The state is moved to
    CPU on construction because it has to live on CPU before it can be applied
    back to a generator.
    """

    # Both fields are 8 bytes wide: the seed comes first, the offset second.
    _SEED_REGION = slice(None, 8)
    _OFFSET_REGION = slice(8, None)

    def __init__(self, state: Tensor):
        self._state = state.to("cpu")

    def _store_word(self, region: slice, value: int) -> None:
        """Pack ``value`` as a uint64 and write its bytes into ``region`` of the state."""
        word = Tensor([value], dtype=platform.tensor_dtype.uint64)
        self._state[region] = word.view(platform.tensor_dtype.uint8)

    @property
    def state(self):
        """The underlying (CPU) state tensor; setters mutate it in place."""
        return self._state

    @property
    def offset(self) -> int:
        """Current offset, decoded from the trailing bytes of the state."""
        # NOTE(review): the offset is decoded as int64 but encoded as uint64 by
        # the setter; the two agree only below 2**63 -- confirm this is intended.
        raw = self._state[self._OFFSET_REGION]
        return int(raw.view(dtype=platform.tensor_dtype.int64).item())

    @offset.setter
    def offset(self, offset: int) -> None:
        self._store_word(self._OFFSET_REGION, offset)

    @property
    def seed(self) -> int:
        """Current seed, decoded from the leading 8 bytes of the state."""
        raw = self._state[self._SEED_REGION]
        return int(raw.view(dtype=platform.tensor_dtype.uint64).item())

    @seed.setter
    def seed(self, seed: int) -> None:
        self._store_word(self._SEED_REGION, seed)

115 

116 

117class _RNGStateTracker: 

118 """ 

119 Tracks and manages RNG states for DTensor random operations. 

120 

121 Maintains a mapping from operation tags to RNG state tensors (ByteTensor), 

122 providing standardized interfaces for state access and modification. 

123 

124 The core method `_distribute_region` establishes the proper RNG context 

125 when DTensor executes random operators across distributed devices. 

126 """ 

127 

128 def __init__(self, device): 

129 self._device = device 

130 self._device_handle = platform.get_device_handle() 

131 if not self._device_handle: 

132 raise RuntimeError( 

133 f"{self.__class__.__name__} instantiation requires the presence of " 

134 ) 

135 self._use_distribute_region = True 

136 

137 @property 

138 def distribute_region_enabled(self) -> bool: 

139 return self._use_distribute_region 

140 

141 @distribute_region_enabled.setter 

142 def distribute_region_enabled(self, value) -> None: 

143 self._use_distribute_region = value 

144 

145 def _distribute_region( 

146 self, device_mesh, placements, global_shape, generator = None 

147 ): 

148 pass 

149 

150 def _manual_seed(self, parallel_seed: int) -> None: 

151 pass 

152 

153 

class OffsetBasedRNGTracker(_RNGStateTracker):
    """Default RNG policy for DTensor random operators.

    Every rank starts a random op from the same base offset. Before the op runs,
    the rank's offset is advanced by an amount derived from the linear index of
    the shard it holds, and after the op all ranks move to the same offset again
    (base offset plus the padded global element count).
    """

    def __init__(
        self,
        run_state_sync: bool = True,
    ):
        """Create the tracker on the device resolved by ``_resolve_device``.

        Args:
            run_state_sync: When ``True``, the device RNG state is passed through
                ``platform.broadcast(..., 0)`` and installed locally; a warning is
                logged if that changed the local state.
        """
        super().__init__(_resolve_device())
        shared_state = self._get_device_state()
        if not run_state_sync:
            return

        # Take rank 0's RNG state as the common one.
        # NOTE(review): relies on platform.broadcast updating its argument in
        # place -- confirm against the platform implementation.
        platform.broadcast(shared_state, 0)
        local_state = self._get_device_state()
        if not all(local_state == shared_state):
            logger.warning(
                "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
                "This behavior is deprecated. "
                "Please call ``manual_seed(seed, device_mesh)`` from "
                "``hyper_parallel.core.dtensor.random`` on every rank that participates in SPMD DTensor "
                "operations with the same seed. If using Pipeline Parallelism, each pipelining state would use "
                "a different seed, but all ranks belonging to one pipeline stage would use the same seed."
            )
        self._set_device_state(shared_state)

    def _manual_seed(self, parallel_seed: int) -> None:
        """Seed the default RNG through ``platform.manual_seed``."""
        platform.manual_seed(parallel_seed)

    def _get_device_state(self):
        """Return the device handle's RNG state, moved to the tracker's device."""
        return self._device_handle.get_rng_state().to(self._device)

    def _set_device_state(self, state: Tensor):
        """Install ``state`` as the device handle's RNG state."""
        # `_get_device_state` hands out a device tensor, so convert back to CPU
        # before giving the state to the handle.
        cpu_state = state.to("cpu")
        self._device_handle.set_rng_state(cpu_state)

    @contextlib.contextmanager
    def _distribute_region(
        self, device_mesh, placements, global_shape, generator=None
    ):
        """Context manager that runs a random op with a shard-specific RNG offset.

        Args:
            device_mesh: The device mesh describing the device topology.
            placements: The placement of the tensor on each mesh dimension.
            global_shape: Global shape of the distributed tensor.
            generator: Optional user generator. When given, its state is used
                instead of the device state and receives the final state.
        """
        user_generator = generator is not None
        raw_state = generator.get_state() if user_generator else self._get_device_state()
        state = _PhiloxState(raw_state)

        if not self.distribute_region_enabled:
            yield
        else:
            offset_before_op = state.offset
            self._set_pre_op_offset(state, device_mesh, placements, global_shape)
            with fork_rng(
                devices=[self._device], device_type=platform.device_type()
            ):
                self._device_handle.set_rng_state(state.state)
                try:
                    yield  # the caller's random op runs here
                finally:
                    # Move the tracked offset to the value shared by all ranks.
                    self._set_post_op_offset(state, global_shape, offset_before_op)

        # Publish the advanced state so later random ops (DTensor or not) see it.
        # This is skipped when the region body raised, as the exception
        # propagates out of the yield above.
        if user_generator:
            generator.set_state(state.state)
        else:
            self._set_device_state(state.state)

    def compute_offset_incr(self, device_mesh, placements, global_shape) -> int:
        """Compute how far this rank's RNG offset is advanced before a random op.

        The increment is ``shard_linear_idx * numel(first shard)``, rounded up to
        a multiple of 4.

        Args:
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement for each mesh dimension.
            global_shape: Global shape of the distributed tensor.

        Returns:
            int: The offset increment, a multiple of 4.
        """
        shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info(
            device_mesh.get_coordinate(), device_mesh, placements
        )
        linear_idx = self._calc_shard_linear_idx(
            shard_idx_by_dim, total_num_shards_by_dim
        )
        first_shard_shape = _calc_first_shard_size(device_mesh, placements, global_shape)
        first_shard_numel = functools.reduce(operator.mul, first_shard_shape, 1)
        unaligned_incr = linear_idx * first_shard_numel
        return (unaligned_incr + 3) // 4 * 4

    def _set_pre_op_offset(self, state: _PhiloxState, device_mesh, placements, global_shape) -> None:
        """Advance ``state.offset`` to this rank's starting position for the op.

        The new offset is the current offset plus ``compute_offset_incr(...)``,
        so ranks holding replicas of the same shard get the same starting offset.

        Args:
            state (_PhiloxState): The generator state to modify in place.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement for each mesh dimension.
            global_shape: Global shape of the distributed tensor.

        .. warning::
            Memory layout contiguity is not taken into account.

        Example:
            A DTensor of shape ``[8, 16]`` with placements
            ``[Shard(1), Replicate(), Shard(0)]`` on the mesh
            ``[[[0, 1], [2, 3]], [[4, 5], [6, 7]]]`` is cut into four ``[4, 8]``
            shards. Tensor dim 0 is sharded over mesh dim 2 and tensor dim 1 over
            mesh dim 0, so a rank at mesh coordinate ``(x, y, z)`` holds the shard
            with shard coordinate ``(z, x)``:

            * ranks 0 and 2 hold shard ``(0, 0)``;
            * ranks 1 and 3 hold shard ``(0, 1)``;
            * ranks 4 and 6 hold shard ``(1, 0)``;
            * ranks 5 and 7 hold shard ``(1, 1)``.

            Rank 5 (coordinate ``(1, 0, 1)``) therefore has shard linear index 3,
            and its increment is ``3 * (4 * 8) = 96``.
        """
        base_offset = state.offset
        state.offset = base_offset + self.compute_offset_incr(
            device_mesh, placements, global_shape
        )

    def _set_post_op_offset(
        self, state: _PhiloxState, global_shape, old_offset: int
    ) -> None:
        """Move ``state.offset`` to the position shared by all ranks after the op.

        The offset becomes ``old_offset`` plus the global element count rounded up
        to a multiple of 4, independent of which shard the rank holds.

        Args:
            state (_PhiloxState): The generator state to modify in place.
            global_shape: The global shape of the distributed tensor.
            old_offset (int): The RNG offset recorded before the operation.
        """
        global_numel = functools.reduce(operator.mul, global_shape, 1)
        padded_numel = (global_numel + 3) // 4 * 4
        state.offset = old_offset + padded_numel

    def _calc_shard_linear_idx(
        self, shard_coord: list[int], shard_size: list[int]
    ) -> int:
        """Delegate to the module-level ``_calc_shard_linear_idx``."""
        return _calc_shard_linear_idx(shard_coord, shard_size)

338 

339 

340def _calc_first_shard_size(device_mesh, placements, global_shape) -> list[int]: 

341 """Calculate the size of the first shard on rank 0. 

342 

343 Args: 

344 device_mesh: The device mesh describing the device topology. 

345 placements: Sequence of Placement objects (Shard, Replicate, etc.). 

346 global_shape: input global shape 

347 

348 Returns: 

349 list[int]: Shape of rank 0's local shard. 

350 """ 

351 local_size_on_rank_0 = list(global_shape) 

352 for idx, placement in enumerate(placements): 

353 if isinstance(placement, Shard): 

354 mesh_dim_size = device_mesh.size(idx) 

355 shard_dim = placement.dim 

356 local_size_on_rank_0[shard_dim], _ = local_shard_size_and_offset( 

357 global_shape[shard_dim], 

358 mesh_dim_size, 

359 0, 

360 ) 

361 return local_size_on_rank_0 

362 

363 

364def _calc_shard_info( 

365 mesh_coordinate, device_mesh, placements 

366): 

367 """Calculate shard information for a specific rank.""" 

368 mesh_size = device_mesh.mesh_shape 

369 # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP 

370 # case. Replace the custom logic with dim_map once we support it. 

371 dim_map = [-1] * device_mesh.ndim 

372 for i, placement in enumerate(placements): 

373 if isinstance(placement, Shard): 

374 shard_dim = placement.dim 

375 if dim_map[shard_dim] == -1: 

376 dim_map[shard_dim] = [i] 

377 else: 

378 mesh_dim_list = dim_map[shard_dim] 

379 if not isinstance(mesh_dim_list, list): 

380 raise TypeError(f"Expected mesh_dim_list to be a list, got {type(mesh_dim_list)}") 

381 mesh_dim_list.append(i) 

382 

383 # Compute shard coordinate: 

384 # The coordinate on each tensor dim is a tuple (idx, range) 

385 # If a DTensor is partitioned on its dim i into n shards, and the current rank 

386 # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i 

387 if mesh_coordinate is None: 

388 raise ValueError("mesh_coordinate must not be None") 

389 shard_idx_by_dim = [] 

390 total_num_shards_by_dim = [] # total number of shards on each tensor dim 

391 for mesh_dim in dim_map: 

392 shard_idx = 0 

393 total_num_shards = 1 

394 # the tensor dim is sharded on more than 1 mesh dim 

395 if isinstance(mesh_dim, list): 

396 rank_coord = [mesh_coordinate[d] for d in mesh_dim] 

397 num_shards = [mesh_size[d] for d in mesh_dim] 

398 # compute the shard idx and total number of shards 

399 for idx, size in zip(rank_coord, num_shards): 

400 shard_idx = shard_idx * size + idx 

401 total_num_shards *= size 

402 

403 shard_idx_by_dim.append(shard_idx) 

404 total_num_shards_by_dim.append(total_num_shards) 

405 return shard_idx_by_dim, total_num_shards_by_dim 

406 

407 

408def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: 

409 # compute shard linear index 

410 shard_linear_idx = 0 

411 shard_coord_stride = 1 

412 for idx, size in zip(reversed(shard_coord), reversed(shard_size)): 

413 shard_linear_idx += idx * shard_coord_stride 

414 shard_coord_stride *= size 

415 

416 return shard_linear_idx 

417 

418 

def _resolve_device():
    """Return the platform device for the current process.

    The device index is the process rank modulo the number of devices reported
    by the platform for its device handle.
    """
    handle = platform.get_device_handle()
    local_index = platform.get_rank() % platform.device_count(handle)
    return platform.device(local_index)

427 

428 

def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
    """Seed the RNG used by DTensor random ops on the calling rank.

    If the op dispatcher has no RNG tracker yet, an :class:`OffsetBasedRNGTracker`
    is created with ``run_state_sync=False``, so ranks are not synchronized from
    rank 0's prior RNG state. The seed itself is applied through
    ``platform.manual_seed``.

    Args:
        seed: Desired RNG seed. It is not validated across ranks; callers must
            pass the same value on every rank that takes part in the same SPMD
            computation.
        device_mesh: Mesh that must include the current process rank.

    Raises:
        RuntimeError: If the current rank is not part of ``device_mesh``. The
            tracker, when missing, has already been created at that point.

    Warning:
        When the mesh does not support DTensor random ops, a warning is emitted
        and the call returns without seeding anything.
    """
    supported = is_rng_supported_mesh(device_mesh)
    if not supported:
        message = (
            "DTensor manual_seed() may not have complete support "
            f"on {device_mesh.device_type} device mesh"
        )
        warnings.warn(message, stacklevel=2)
        return

    # Imported here to break an import cycle: _op_dispatch imports this module at load time.
    from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER as dispatcher  # pylint: disable=C0415

    tracker_missing = dispatcher._rng_tracker is None
    if tracker_missing:
        dispatcher._rng_tracker = OffsetBasedRNGTracker(run_state_sync=False)

    rank_coordinate = device_mesh.get_coordinate()
    if rank_coordinate is None:
        raise RuntimeError(
            "manual_seed requires the current rank to be a part of the device mesh "
            "otherwise DTensor RNG state on the rank will not be initialized and "
            "the behavior of DTensor random ops is undefined."
        )

    platform.manual_seed(seed)

470 

471 

def local_shard_size_and_offset(
    curr_local_size: int,
    num_chunks: int,
    rank,
):
    """Split a dimension of ``curr_local_size`` into chunks and describe one of them.

    The dimension is cut into ``num_chunks`` chunks of ``ceil(size / num_chunks)``
    elements; trailing chunks may be shorter or empty when the split is uneven.

    Note: the returned offset is relative to the tensor being split (which may
    already be a shard), not to the global tensor.

    Args:
        curr_local_size: Extent of the dimension being split.
        num_chunks: Number of chunks to split it into.
        rank: Index of the chunk to describe.

    Returns:
        tuple: ``(shard size, shard offset)`` for chunk ``rank``. A chunk that lies
        entirely past the end is reported as ``(0, curr_local_size)``.
    """
    if curr_local_size % num_chunks == 0:
        # Even split: every chunk has the same size.
        even_chunk = curr_local_size // num_chunks
        return even_chunk, even_chunk * rank

    # Uneven split: chunks are sized by ceiling division.
    chunk = (curr_local_size + num_chunks - 1) // num_chunks
    start = chunk * rank
    if start > curr_local_size:
        return 0, curr_local_size

    stop = min(curr_local_size, start + chunk)
    return stop - start, start

505 

506 

507_fork_rng_warned_already = False 

508 

509 

@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    device_type="npu",
):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    The state obtained from ``platform.get_rng_state()`` (called without a device)
    is always saved and restored, in addition to the state of each listed device.

    Args:
        devices (iterable of Device IDs): devices for which to fork the RNG. By
            default all devices reported by the platform are forked; the first
            time this happens on a machine with more than one device a warning is
            emitted, because saving and restoring every device can be slow.
            Passing ``devices`` explicitly suppresses the warning.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `npu`. It is only used in
            diagnostic messages; the device handle comes from the platform.

    Raises:
        RuntimeError: If the platform has no device handle (checked even when
            ``enabled`` is ``False``).
    """
    global _fork_rng_warned_already

    device_mod = platform.get_device_handle()
    if device_mod is None:
        # The original message stopped mid-sentence; say what has to be registered.
        raise RuntimeError(
            f"{platform} has no module of `{device_type}`, you should register "
            "a device handle for it with the platform before forking the RNG state."
        )

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = platform.device_count(device_mod)
        if num_devices > 1 and not _fork_rng_warned_already:
            # Previously the flag was flipped without ever warning. stacklevel=3
            # skips this generator frame and contextlib's __enter__, so the
            # warning points at the caller's `with` statement.
            warnings.warn(
                f"fork_rng was called without `devices`, so the RNG state of all "
                f"{num_devices} `{device_type}` devices will be saved and restored, "
                "which can be slow. Pass `devices` explicitly to fork only the "
                "devices in use and to silence this warning.",
                stacklevel=3,
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = platform.get_rng_state()
    device_rng_states = [platform.get_rng_state(device, device_mod) for device in devices]

    try:
        yield
    finally:
        # Restore in the same order the states were captured.
        platform.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            platform.set_rng_state(device_rng_state, device, device_mod)