Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/symmetric_memory/__init__.py: 0%
43 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Symmetric memory module for hyper-parallel."""
16import os
17import pathlib
18from hyper_parallel.platform import get_platform
# Put the shmem libraries bundled under <package root>/lib/shmem first on the
# library search path, so that the shmem build compiled for this package is
# the one that gets picked up.
project_root = pathlib.Path(__file__).parent.parent.parent
shmem_lib_dir = project_root.joinpath("lib", "shmem")

if not shmem_lib_dir.exists():
    raise FileNotFoundError(f"shmem库目录不存在: {shmem_lib_dir}")

# NOTE(review): the glibc dynamic loader reads LD_LIBRARY_PATH once, at
# process start-up. Updating os.environ here therefore mainly affects child
# processes and any native code that reads the variable itself. Confirm that
# the in-process shmem library really is resolved from shmem_lib_dir.
ld_path = os.environ.get("LD_LIBRARY_PATH", "")
if ld_path:
    new_ld_path = f"{shmem_lib_dir}:{ld_path}"
else:
    new_ld_path = str(shmem_lib_dir)
os.environ["LD_LIBRARY_PATH"] = new_ld_path

# Resolve the backend-specific symmetric-memory handler once, at import time.
# Every public function in this module delegates to it.
platform = get_platform()
_symm_handler = platform.get_symmetric_memory_handler()
def is_shmem_available() -> bool:
    """Report whether the symmetric-memory (shmem) backend can be used.

    Returns:
        bool: the availability flag reported by the platform's
        symmetric-memory handler.
    """
    handler = _symm_handler
    return handler.is_shmem_available()
def empty(shape, dtype):
    r"""
    Allocate a tensor in symmetric memory (the symmetric counterpart of ``empty()``).

    The tensor is allocated in memory that is symmetric across the
    participating processes. It can be used directly for one-sided
    communication, for normal computation, or passed to :func:`rendezvous`.

    Args:
        shape (tuple[int] or list[int]): the shape of the output tensor, given
            as a single sequence of integers such as ``(2, 3)``. This signature
            does not accept a variadic form like ``empty(2, 3, dtype=...)``.
        dtype (:class:`mindspore.dtype` or :class:`torch.dtype`): the data type
            of the returned tensor. Required; it may be passed positionally or
            by keyword.

    Returns:
        The tensor produced by the platform's symmetric-memory handler.

    Example::
        >>> # doctest: +SKIP
        >>> # Create a symmetric memory tensor of shape (2, 3) with float32 data type
        >>> symm_tensor = symmetric_memory.empty((2, 3), dtype=mindspore.float32)
    """
    symm_tensor = _symm_handler.empty(shape, dtype)
    return symm_tensor
def barrier():
    r"""
    Block until every process in the default process group reaches this call.

    Use this to make sure all processes are synchronized at the same point in
    the program, for example before starting a new phase of computation.

    Returns:
        Whatever the platform's symmetric-memory handler returns for the barrier.

    Example::
        >>> # doctest: +SKIP
        >>> # Synchronize all processes before starting a new phase of computation
        >>> symmetric_memory.barrier()
    """
    outcome = _symm_handler.barrier()
    return outcome
def rendezvous(tensor, group):
    r"""
    rendezvous(tensor, group) -> _SymmetricMemory

    Not supported yet: the underlying shmem rendezvous is still under development.

    Establish a symmetric memory tensor among the participating processes.
    This is a collective operation. It allocates a symmetric-memory signal
    buffer coupled with ``tensor`` and collects the pointers to every
    participant's buffer and signal.

    Args:
        tensor: the local tensor used to establish the symmetric memory tensor.
            It must be allocated via :func:`symmetric_memory.empty()`. Its shape,
            dtype, and device type must be identical on all participating processes.
        group: the participating processes, given either as a group name or
            as a process group object.

    Returns:
        The object produced by the platform's symmetric-memory handler
        (documented as ``_SymmetricMemory``).
    """
    symm_mem = _symm_handler.rendezvous(tensor, group)
    return symm_mem
def set_signal_pad_size(size: int) -> None:
    r"""
    Set the signal pad size used by future symmetric memory allocations.

    Not supported yet; allocate the signal tensor in your own program instead.

    Signal pads are P2P-accessible memory regions used for synchronization in
    symmetric memory. This setter lets callers size the signal pad in
    proportion to their workload.

    .. warning::
        Call this before any symmetric memory allocation is made. The size
        cannot be changed once allocations have been performed.

    Args:
        size (int): signal pad size in bytes. It should scale with the
            number of launched blocks and with the world size.
    """
    outcome = _symm_handler.set_signal_pad_size(size)
    return outcome
def get_signal_pad_size() -> int:
    r"""
    Return the current signal pad size for symmetric memory allocations.

    Not supported yet; allocate the signal tensor in your own program instead.

    The value is the size configured through :func:`set_signal_pad_size` if
    one was set, otherwise the default size.

    Returns:
        int: the signal pad size in bytes.
    """
    pad_size = _symm_handler.get_signal_pad_size()
    return pad_size
def shmem_put(target, target_offset, src, src_offset, size, target_rank):
    r"""
    One-sided put: write data from the local ``src`` tensor into ``target`` on ``target_rank``.

    Args:
        target (tensor): the symmetric memory tensor to write to.
        target_offset (tensor): byte offset within ``target`` where the data is written.
        src (tensor): the local source tensor holding the data to send. Its
            dtype must match the dtype of ``target``.
        src_offset (tensor): byte offset within ``src`` where the data to send starts.
        size (tensor): number of bytes to send.
        target_rank (int64): rank of the process that owns ``target``.

    Returns:
        Whatever the platform's symmetric-memory handler returns for the put.
        It is forwarded rather than discarded, matching wrappers such as
        ``empty`` and ``barrier``. It is ``None`` if the backend returns nothing.
    """
    return _symm_handler.shmem_put(target, target_offset, src, src_offset, size, target_rank)
def shmem_get(target, target_offset, src, src_offset, size, target_rank):
    r"""
    One-sided get: read data held by ``target_rank`` into local memory.

    As described for this API, data is read from the ``target`` tensor into
    the local ``src`` tensor.
    NOTE(review): the copy direction relative to the parameter names is not
    visible here; confirm it against the backend handler.

    Args:
        target (tensor): the symmetric memory tensor on ``target_rank`` to read from.
        target_offset (tensor): byte offset within ``target`` where the data is read.
        src (tensor): the local tensor that receives the data.
        src_offset (tensor): byte offset within ``src`` where the data is placed.
        size (tensor): number of bytes to transfer.
        target_rank (int64): rank of the process that owns ``target``.

    Returns:
        Whatever the platform's symmetric-memory handler returns for the get.
        It is forwarded rather than discarded, matching wrappers such as
        ``empty`` and ``barrier``. It is ``None`` if the backend returns nothing.
    """
    return _symm_handler.shmem_get(target, target_offset, src, src_offset, size, target_rank)
def shmem_signal_op(signal, signal_offset, signal_value, signal_op=0, target_rank=0):
    r"""
    Atomically update a signal stored in symmetric memory on ``target_rank``.

    The signal value at ``signal_offset`` inside ``signal`` is updated on the
    target rank's memory, enabling synchronization between processes.

    Args:
        signal (tensor, int32): the symmetric memory tensor containing the signal to update.
        signal_offset (tensor): byte offset within ``signal`` where the signal value is located.
        signal_value (tensor, int32): the value used to update the signal.
        signal_op (int64, optional): operation applied to the signal, 0: set, 1: add.
            Defaults to 0.
        target_rank (int64, optional): rank of the process that owns ``signal``.
            Defaults to 0.

    Returns:
        Whatever the platform's symmetric-memory handler returns for the signal
        operation. It is ``None`` if the backend returns nothing.
    """
    # The defaults above were previously documented but missing from the
    # signature. Explicit arguments are still forwarded unchanged.
    return _symm_handler.shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank)
def shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op=0):
    r"""
    Block until a signal satisfies the requested comparison.

    The calling process waits until the value at ``signal_offset`` in
    ``signal`` meets the condition defined by ``compare_value`` and ``compare_op``.

    Args:
        depend_tensor (tensor): a tensor the wait depends on, used to ensure
            proper ordering of operations.
        signal (tensor, int32): the symmetric memory tensor containing the signal to wait on.
        signal_offset (tensor): byte offset within ``signal`` where the signal value is located.
        compare_value (tensor, int32): the value compared against the signal value.
        compare_op (int64, optional): comparison to apply,
            0: equal, 1: greater than, 2: less than. Defaults to 0.

    Returns:
        Whatever the platform's symmetric-memory handler returns. It is now
        forwarded so that backends producing an output are not ignored
        (e.g. an ordering dependency for later ops; NOTE(review): confirm per
        backend). It is ``None`` if the backend returns nothing.
    """
    return _symm_handler.shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op)
def shmem_put_with_signal(target, target_offset, src, src_offset,
                          size, signal, signal_offset, signal_value, signal_op=0, target_rank=0):
    r"""
    One-sided put followed by an atomic signal update.

    Writes data from the local ``src`` tensor into ``target``, then updates the
    signal value at ``signal_offset`` with ``signal_op``. This combines the data
    transfer of :func:`shmem_put` with an atomic signal update, so the receiver
    can synchronize on the put's completion.

    Args:
        target (tensor): the target symmetric memory tensor to write to.
        target_offset (tensor): byte offset within ``target`` where the data is written.
        src (tensor): the local source tensor containing the data to send.
            Its dtype must match the dtype of ``target``.
        src_offset (tensor): byte offset within ``src`` where the data to send starts.
        size (tensor): number of bytes to send.
        signal (tensor, int32): the symmetric memory tensor containing the
            signal to update after the put.
        signal_offset (tensor): byte offset within ``signal`` where the signal value is located.
        signal_value (tensor, int32): the value used to update the signal after the put.
        signal_op (int64, optional): operation applied to the signal, 0: set, 1: add.
            Defaults to 0.
        target_rank (int64, optional): rank of the process that owns ``target``
            and ``signal``. Defaults to 0.

    Returns:
        Whatever the platform's symmetric-memory handler returns. It is
        ``None`` if the backend returns nothing.
    """
    return _symm_handler.shmem_put_with_signal(target, target_offset, src, src_offset,
                                               size, signal, signal_offset, signal_value,
                                               signal_op, target_rank)
def shmem_allgather(output_tensor, input_tensor):
    """
    Gather ``input_tensor`` from every rank into ``output_tensor``.

    Only the torch backend is supported for now; the MindSpore version is
    still under development.

    The gathered data is concatenated in rank order. Every rank must provide
    an input tensor of the same shape and dtype.

    Args:
        output_tensor (tensor): the symmetric memory tensor that receives the
            data gathered from all ranks. It must be large enough to hold
            ``world_size`` copies of ``input_tensor``, i.e. ``world_size * local_shape``
            where ``local_shape`` is the shape of ``input_tensor``.
            NOTE(review): confirm which dimension is concatenated.
        input_tensor (tensor): the local tensor contributed by this rank.

    Returns:
        Whatever the platform's symmetric-memory handler returns. It is
        forwarded rather than discarded and is ``None`` if the backend returns nothing.
    """
    return _symm_handler.shmem_allgather(output_tensor, input_tensor)
def shmem_alltoall(send_tensor_list, receive_tensor, receive_list):
    """
    All-to-all exchange: every rank sends one tensor to each rank and receives one from each rank.

    Only the torch backend is supported for now; the MindSpore version is
    still under development.

    ``send_tensor_list[i]`` is sent to rank ``i``. The data received from all
    ranks is stored in ``receive_tensor`` in rank order.

    Args:
        send_tensor_list (list of tensors): tensors to send, where
            ``send_tensor_list[i]`` is the tensor sent to rank ``i``.
        receive_tensor (tensor): the symmetric memory tensor that holds the
            data received from all ranks.
        receive_list (list of int): the size of the data received from each
            rank, where ``receive_list[i]`` is the size received from rank ``i``.
            NOTE(review): an earlier description called this a list of
            tensors; the int-size interpretation follows the documented type.

    Returns:
        Whatever the platform's symmetric-memory handler returns. It is
        forwarded rather than discarded and is ``None`` if the backend returns nothing.
    """
    return _symm_handler.shmem_alltoall(send_tensor_list, receive_tensor, receive_list)
def fused_all_gather_matmul(a, b, c, gather_out, signal, block_size=None):
    """
    fused_all_gather_matmul(a, b, c, gather_out, signal, block_size=None)

    Fused operator combining all-gather and matmul.

    Only the torch backend is supported for now; the MindSpore version is
    still under development.

    Computational flow:
        1. gather_out = allgather(a)   # gather local tensor ``a`` from all ranks
        2. c = gather_out @ b          # matmul on the gathered input

    NOTE(review): this flow was previously written as
    ``c = ReduceScatter(gather_out @ b)``. That conflicts with the documented
    (M, N) shape of ``c``, which equals the shape of ``gather_out @ b``.
    Confirm against the backend implementation.

    Args:
        a: local input tensor with shape (M_local, K).
        b: weight matrix with shape (K, N).
        c: output tensor with shape (M, N).
        gather_out: output tensor containing ``a`` gathered from all ranks,
            with shape (M, K) where M = M_local * world_size.
        signal: symmetric memory tensor with shape (world_size,) and dtype int32.
        block_size: optional block size for tiled computation. Defaults to
            None, which is forwarded to the backend as-is.

    Returns:
        Whatever the platform's symmetric-memory handler returns.
    """
    return _symm_handler.fused_all_gather_matmul(a, b, c, gather_out, signal, block_size)
def fused_matmul_reduce_scatter(x1, x2, symm_tensor, signal, reduce_op="sum"):
    """
    Fused operator: matmul followed by reduce-scatter, ``output = ReduceScatter(x1 @ x2)``.

    Only the torch backend is supported for now; the MindSpore version is
    still under development.

    Args:
        x1: left matrix with shape (m, k). ``m`` must be an integer multiple
            of the number of devices (rank size).
        x2: right matrix with shape (k, n).
        symm_tensor: symmetric memory tensor with shape (m, n).
        signal: symmetric memory tensor with shape (world_size,) and dtype int32.
        reduce_op: reduction applied during the reduce-scatter. Only 'sum' and
            'avg' are supported. Defaults to 'sum'. This default was documented
            before but missing from the signature.

    Returns:
        The output matrix with shape (m / rank_size, n), as produced by the
        platform's symmetric-memory handler.
    """
    return _symm_handler.fused_matmul_reduce_scatter(x1, x2, symm_tensor, signal, reduce_op)
266__all__ = [
267 "is_shmem_available",
268 "empty",
269 "rendezvous",
270 "set_signal_pad_size",
271 "get_signal_pad_size",
272 "barrier",
273 "shmem_put",
274 "shmem_get",
275 "shmem_wait_for_signal",
276 "shmem_put_with_signal",
277 "shmem_signal_op",
278 "shmem_allgather",
279 "shmem_alltoall",
280 "fused_all_gather_matmul",
281 "fused_matmul_reduce_scatter",
282 # "overlap_launch_all_to_all_v",
283]