Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / random.py: 36%
205 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""RNG state management for distributed tensor operations.
17Provides utilities for tracking and synchronizing random number generator states
18across multiple devices in distributed training scenarios.
19"""
21__all__ = [
22 "is_rng_supported_mesh",
23 "manual_seed",
24 "OffsetBasedRNGTracker",
25]
27import contextlib
28import warnings
29from logging import getLogger
30import typing
31from typing import Optional
32import functools
33import operator
35from hyper_parallel.core.dtensor.placement_types import Shard
36from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
37from hyper_parallel.platform import get_platform
# Module logger, named after this module.
logger = getLogger(__name__)

# Resolve the active backend once at import time. ``DTensorBase`` and ``Tensor``
# are short aliases for the backend's DTensor base class and tensor factory.
platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.tensor
def is_rng_supported_mesh(device_mesh: Optional[DeviceMesh] = None) -> bool:
    """Report whether DTensor random operations can be used.

    A mesh whose ``device_type`` is ``"cpu"`` is always reported as unsupported.
    Otherwise support is decided by whether the platform device handle is truthy
    and exposes ``set_rng_state``. Users should call this function before using
    DTensor random APIs to verify compatibility.

    Args:
        device_mesh: Optional :class:`DeviceMesh` to check. If omitted, only the
            active platform device handle is inspected and no warning is emitted.

    Returns:
        bool: ``True`` if DTensor random operations are supported, ``False``
        otherwise. A warning is issued on every ``False`` result for which a
        mesh was supplied.
    """
    mesh_given = device_mesh is not None

    # CPU meshes are rejected up front, before the device handle is consulted.
    if mesh_given and device_mesh.device_type == "cpu":
        warnings.warn(
            f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh",
            stacklevel=2,
        )
        return False

    handle = platform.get_device_handle()
    if handle and hasattr(handle, "set_rng_state"):
        return True

    # Only warn when the caller asked about a specific mesh.
    if mesh_given:
        warnings.warn(
            f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh",
            stacklevel=2,
        )
    return False
class _PhiloxState:
    """
    Accessor for the (seed, offset) pair packed into a philox RNG state.

    The state is exposed as a size-16 uint8 tensor: bytes ``[:8]`` hold the seed and
    bytes ``[8:]`` hold the offset.

    The state is moved to CPU on construction, since it needs to be on CPU before it
    is applied back to a generator.
    """

    def __init__(self, state: Tensor):
        self._state = state.to("cpu")

    @staticmethod
    def _pack_uint64(value: int):
        """Return ``value`` as a one-element uint64 tensor reinterpreted as uint8 bytes."""
        packed = Tensor([value], dtype=platform.tensor_dtype.uint64)
        return packed.view(platform.tensor_dtype.uint8)

    @property
    def state(self):
        """The underlying CPU state tensor (mutated in place by the setters)."""
        return self._state

    @property
    def offset(self) -> int:
        """Offset stored in the upper 8 bytes; note it is read back through an int64 view."""
        upper_half = self._state[8:]
        return int(upper_half.view(dtype=platform.tensor_dtype.int64).item())

    @offset.setter
    def offset(self, offset: int) -> None:
        self._state[8:] = self._pack_uint64(offset)

    @property
    def seed(self) -> int:
        """Seed stored in the lower 8 bytes, read back through a uint64 view."""
        lower_half = self._state[:8]
        return int(lower_half.view(dtype=platform.tensor_dtype.uint64).item())

    @seed.setter
    def seed(self, seed: int) -> None:
        self._state[:8] = self._pack_uint64(seed)
117class _RNGStateTracker:
118 """
119 Tracks and manages RNG states for DTensor random operations.
121 Maintains a mapping from operation tags to RNG state tensors (ByteTensor),
122 providing standardized interfaces for state access and modification.
124 The core method `_distribute_region` establishes the proper RNG context
125 when DTensor executes random operators across distributed devices.
126 """
128 def __init__(self, device):
129 self._device = device
130 self._device_handle = platform.get_device_handle()
131 if not self._device_handle:
132 raise RuntimeError(
133 f"{self.__class__.__name__} instantiation requires the presence of "
134 )
135 self._use_distribute_region = True
137 @property
138 def distribute_region_enabled(self) -> bool:
139 return self._use_distribute_region
141 @distribute_region_enabled.setter
142 def distribute_region_enabled(self, value) -> None:
143 self._use_distribute_region = value
145 def _distribute_region(
146 self, device_mesh, placements, global_shape, generator = None
147 ):
148 pass
150 def _manual_seed(self, parallel_seed: int) -> None:
151 pass
class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    Default ``_RNGStateTracker`` policy for DTensor random operators.

    Every rank starts a random op from the same RNG offset, shifts that offset by an
    amount derived from the shard it holds, runs the op, and finally advances the
    offset by the (rounded) global element count so that all ranks end up in the
    same state again.
    """

    def __init__(
        self,
        run_state_sync: bool = True,
    ):
        """Create the tracker, optionally adopting rank 0's RNG state.

        Args:
            run_state_sync: When ``True``, broadcast the RNG state from rank 0,
                log a warning if the local state differed, and install the
                broadcast state on this rank.
        """
        super().__init__(_resolve_device())
        synced_state = self._get_device_state()
        if not run_state_sync:
            return

        # synchronize RNG state using rank 0's current one
        platform.broadcast(synced_state, 0)
        local_state = self._get_device_state()
        if not all(local_state == synced_state):
            logger.warning(
                "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
                "This behavior is deprecated. "
                "Please call ``manual_seed(seed, device_mesh)`` from "
                "``hyper_parallel.core.dtensor.random`` on every rank that participates in SPMD DTensor "
                "operations with the same seed. If using Pipeline Parallelism, each pipelining state would use "
                "a different seed, but all ranks belonging to one pipeline stage would use the same seed."
            )
        self._set_device_state(synced_state)

    def _manual_seed(self, parallel_seed: int) -> None:
        """Seed the default RNG through ``platform.manual_seed``."""
        platform.manual_seed(parallel_seed)

    def _get_device_state(self):
        """Return the device handle's RNG state, moved onto the tracked device."""
        return self._device_handle.get_rng_state().to(self._device)

    def _set_device_state(self, state: Tensor):
        """Install ``state`` as the device handle's RNG state.

        ``_get_device_state`` hands out a tensor on the tracked device, so the
        state is converted back to CPU here before it is given to the handle.
        """
        cpu_state = state.to("cpu")
        self._device_handle.set_rng_state(cpu_state)

    @contextlib.contextmanager
    def _distribute_region(
        self, device_mesh, placements, global_shape, generator=None
    ):
        """Context manager that sets up the RNG for one DTensor random op.

        Args:
            device_mesh: The device mesh describing the device topology.
            placements: The placement for each mesh dimension.
            global_shape: Global shape of the distributed tensor.
            generator: Optional user generator. When given, its state is used
                instead of the device state and is written back on normal exit.
        """
        use_generator = generator is not None
        raw_state = generator.get_state() if use_generator else self._get_device_state()
        state = _PhiloxState(raw_state)

        if not self.distribute_region_enabled:
            yield
        else:
            offset_before_op = state.offset
            self._set_pre_op_offset(state, device_mesh, placements, global_shape)
            with fork_rng(
                devices=[self._device], device_type=platform.device_type()
            ):
                self._device_handle.set_rng_state(state.state)
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(state, global_shape, offset_before_op)

        # Propagate the resulting state so that later RNG use (DTensor or not)
        # observes it: back into the user's generator if one was supplied,
        # otherwise into the device RNG.
        if use_generator:
            generator.set_state(state.state)
        else:
            self._set_device_state(state.state)

    def compute_offset_incr(self, device_mesh, placements, global_shape) -> int:
        """Compute the per-shard RNG offset increment for the current rank.

        The increment is the shard's linear index multiplied by the element count
        of the first shard, so that each shard draws from its own portion of the
        random stream.

        Args:
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
            global_shape: input global shape

        Returns:
            int: The offset increment, rounded up to a multiple of 4.
        """
        coordinate = device_mesh.get_coordinate()
        shard_index, shard_counts = _calc_shard_info(coordinate, device_mesh, placements)
        linear_idx = self._calc_shard_linear_idx(shard_index, shard_counts)

        first_shard_shape = _calc_first_shard_size(device_mesh, placements, global_shape)
        shard_numel = functools.reduce(operator.mul, first_shard_shape, 1)

        unaligned = linear_idx * shard_numel
        return (unaligned + 3) // 4 * 4

    def _set_pre_op_offset(self, state: _PhiloxState, device_mesh, placements, global_shape) -> None:
        """Move ``state`` to the starting RNG offset of this rank's local shard.

        The new offset is the current offset plus ``compute_offset_incr(...)``.
        Ranks that hold replicas of the same shard therefore get identical
        starting offsets.

        Args:
            state (_PhiloxState): The generator state to modify.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
                Each element should be a Placement object (Shard, Replicate, Partial, etc.).
            global_shape: input global shape

        Returns:
            None

        .. warning::
            The current implementation does not consider memory layout contiguity.

        Example:
            Take a DTensor of shape [8, 16] placed on a device mesh with placements
            ([Shard(1), Replicate(), Shard(0)]), where the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Besides the rank coordinate there is the shard coordinate. Each rank
            holds a local shard of the DTensor; here the DTensor is partitioned into
            4 shards of shape [4, 8]. Rank 0 (coord (0, 0, 0)) and rank 2
            (coord (0, 1, 0)) each hold a replica of the same shard. A shard is
            denoted by a tuple (i, j), where shard (i, j) is the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2.

            Tensor dim 0 is sharded over mesh dim 2 and tensor dim 1 over mesh dim 0,
            so a rank with coordinate (x, y, z) holds shard (z, x):
                rank 0 and rank 2 hold the shard of coord (0, 0);
                rank 1 and rank 3 hold the shard of coord (1, 0);
                rank 4 and rank 6 hold the shard of coord (0, 1);
                rank 5 and rank 7 hold the shard of coord (1, 1);

            The shard coordinate is then flattened into a shard linear index. The
            offset added for each rank is shard_linear_index * local_tensor_numel,
            rounded up to a multiple of 4.
        """
        starting_offset = state.offset
        state.offset = starting_offset + self.compute_offset_incr(
            device_mesh, placements, global_shape
        )

    def _set_post_op_offset(
        self, state: _PhiloxState, global_shape, old_offset: int
    ) -> None:
        """Set the RNG to a synchronized state after running the local random op.

        The offset becomes ``old_offset`` plus the element count of the global
        tensor rounded up to a multiple of 4, independent of which shard this
        rank processed.

        Args:
            state (_PhiloxState): The generator state to modify.
            global_shape: The global shape of the distributed tensor.
            old_offset (int): The RNG offset before the operation.

        Returns:
            None
        """
        total = functools.reduce(operator.mul, global_shape, 1)
        aligned_total = (total + 3) // 4 * 4
        state.offset = old_offset + aligned_total

    def _calc_shard_linear_idx(
        self, shard_coord: list[int], shard_size: list[int]
    ) -> int:
        """Delegate to the module-level ``_calc_shard_linear_idx``."""
        return _calc_shard_linear_idx(shard_coord, shard_size)
340def _calc_first_shard_size(device_mesh, placements, global_shape) -> list[int]:
341 """Calculate the size of the first shard on rank 0.
343 Args:
344 device_mesh: The device mesh describing the device topology.
345 placements: Sequence of Placement objects (Shard, Replicate, etc.).
346 global_shape: input global shape
348 Returns:
349 list[int]: Shape of rank 0's local shard.
350 """
351 local_size_on_rank_0 = list(global_shape)
352 for idx, placement in enumerate(placements):
353 if isinstance(placement, Shard):
354 mesh_dim_size = device_mesh.size(idx)
355 shard_dim = placement.dim
356 local_size_on_rank_0[shard_dim], _ = local_shard_size_and_offset(
357 global_shape[shard_dim],
358 mesh_dim_size,
359 0,
360 )
361 return local_size_on_rank_0
364def _calc_shard_info(
365 mesh_coordinate, device_mesh, placements
366):
367 """Calculate shard information for a specific rank."""
368 mesh_size = device_mesh.mesh_shape
369 # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
370 # case. Replace the custom logic with dim_map once we support it.
371 dim_map = [-1] * device_mesh.ndim
372 for i, placement in enumerate(placements):
373 if isinstance(placement, Shard):
374 shard_dim = placement.dim
375 if dim_map[shard_dim] == -1:
376 dim_map[shard_dim] = [i]
377 else:
378 mesh_dim_list = dim_map[shard_dim]
379 if not isinstance(mesh_dim_list, list):
380 raise TypeError(f"Expected mesh_dim_list to be a list, got {type(mesh_dim_list)}")
381 mesh_dim_list.append(i)
383 # Compute shard coordinate:
384 # The coordinate on each tensor dim is a tuple (idx, range)
385 # If a DTensor is partitioned on its dim i into n shards, and the current rank
386 # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i
387 if mesh_coordinate is None:
388 raise ValueError("mesh_coordinate must not be None")
389 shard_idx_by_dim = []
390 total_num_shards_by_dim = [] # total number of shards on each tensor dim
391 for mesh_dim in dim_map:
392 shard_idx = 0
393 total_num_shards = 1
394 # the tensor dim is sharded on more than 1 mesh dim
395 if isinstance(mesh_dim, list):
396 rank_coord = [mesh_coordinate[d] for d in mesh_dim]
397 num_shards = [mesh_size[d] for d in mesh_dim]
398 # compute the shard idx and total number of shards
399 for idx, size in zip(rank_coord, num_shards):
400 shard_idx = shard_idx * size + idx
401 total_num_shards *= size
403 shard_idx_by_dim.append(shard_idx)
404 total_num_shards_by_dim.append(total_num_shards)
405 return shard_idx_by_dim, total_num_shards_by_dim
408def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int:
409 # compute shard linear index
410 shard_linear_idx = 0
411 shard_coord_stride = 1
412 for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
413 shard_linear_idx += idx * shard_coord_stride
414 shard_coord_stride *= size
416 return shard_linear_idx
def _resolve_device():
    """Return the platform device for this process.

    The device index is the process rank modulo the number of devices reported
    for the platform device handle.
    """
    handle = platform.get_device_handle()
    local_index = platform.get_rank() % platform.device_count(handle)
    return platform.device(local_index)
def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
    """Set the seed for generating random numbers on the calling rank.

    Seeds the platform RNG used by DTensor random ops. If the op dispatcher has
    no RNG tracker yet, an :class:`OffsetBasedRNGTracker` is created with
    ``run_state_sync=False`` so ranks are not synchronized from rank 0's prior
    RNG state. When the mesh does not support DTensor RNG, a warning is issued
    and nothing is seeded.

    Args:
        seed: Desired RNG seed (must be agreed across ranks in the mesh for SPMD).
        device_mesh: Mesh that must include the current process rank.

    Raises:
        RuntimeError: If the current rank is not part of ``device_mesh`` (undefined DTensor
            RNG behavior in that case).

    Warning:
        Does not validate that ``seed`` matches across ranks; callers must ensure SPMD
        consistency. Pipeline parallel: use one seed per pipeline stage group.
    """
    supported = is_rng_supported_mesh(device_mesh)
    if not supported:
        warnings.warn(
            "DTensor manual_seed() may not have complete support "
            f"on {device_mesh.device_type} device mesh",
            stacklevel=2,
        )
        return

    # Local import avoids import cycle: _op_dispatch imports this module at load time.
    from hyper_parallel.core.shard._op_dispatch import _OP_DISPATCHER  # pylint: disable=C0415

    dispatcher = _OP_DISPATCHER
    if dispatcher._rng_tracker is None:
        dispatcher._rng_tracker = OffsetBasedRNGTracker(run_state_sync=False)

    coordinate = device_mesh.get_coordinate()
    if coordinate is None:
        raise RuntimeError(
            "manual_seed requires the current rank to be a part of the device mesh "
            "otherwise DTensor RNG state on the rank will not be initialized and "
            "the behavior of DTensor random ops is undefined."
        )

    platform.manual_seed(seed)
def local_shard_size_and_offset(
    curr_local_size: int,
    num_chunks: int,
    rank,
):
    """
    Split ``curr_local_size`` elements into ``num_chunks`` chunks and describe chunk ``rank``.

    ``curr_local_size`` is the size of the current local tensor dim (which may already
    be sharded on other dimensions); ``num_chunks`` is generally the size of the mesh
    dimension being sharded over. When the size does not divide evenly, chunks use the
    ceiling size, so the trailing chunks may be shorter or empty.

    Note: the returned offset is relative to the current (possibly already sharded)
    tensor, not the global tensor.

    Returns (new local shard size, offset)
    """
    if curr_local_size % num_chunks == 0:
        # Even split: every chunk has the same size.
        even_chunk = curr_local_size // num_chunks
        return even_chunk, even_chunk * rank

    # Uneven split: chunks are sized by ceiling division.
    ceil_chunk = (curr_local_size + num_chunks - 1) // num_chunks
    start = ceil_chunk * rank
    if start > curr_local_size:
        # This chunk lies entirely past the end of the data.
        return 0, curr_local_size

    stop = min(curr_local_size, start + ceil_chunk)
    return stop - start, start
# Set once the "forking every device" warning has been issued, so that it is
# emitted at most once per process.
_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    device_type="npu",
):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning (once per process) if more than
            one device is present, since saving and restoring every device's state
            can be slow. If you explicitly specify devices, this warning is suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `npu`. It is only used in
            diagnostic messages; the device module itself comes from the platform.

    Raises:
        RuntimeError: If the platform provides no device handle. This is checked
            even when ``enabled`` is ``False``.
    """
    global _fork_rng_warned_already

    device_mod = platform.get_device_handle()
    if device_mod is None:
        # The message used to stop mid-sentence after "register".
        raise RuntimeError(
            f"{platform} has no module of `{device_type}`, you should register "
            "a device module for it before forking the RNG."
        )

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = platform.device_count(device_mod)
        if num_devices > 1 and not _fork_rng_warned_already:
            # The flag was previously set without ever emitting the warning the
            # docstring describes. stacklevel=3 skips this generator frame and
            # contextlib's ``__enter__`` so the warning points at the ``with``.
            warnings.warn(
                f"{device_type} reports {num_devices} devices and fork_rng was called "
                "without an explicit `devices` argument; the RNG state of every device "
                "will be saved and restored, which can be slow. Pass `devices` to limit "
                "the set of devices and suppress this warning.",
                stacklevel=3,
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = platform.get_rng_state()
    device_rng_states = [platform.get_rng_state(device, device_mod) for device in devices]

    try:
        yield
    finally:
        # Restore the CPU state first, then each device in the order captured.
        platform.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            platform.set_rng_state(device_rng_state, device, device_mod)