Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/symmetric_memory/__init__.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Symmetric memory module for hyper-parallel.""" 

16import os 

17import pathlib 

18from hyper_parallel.platform import get_platform 

19 

# Prefer the shmem library located under <project_root>/lib/shmem so that the
# pre-compiled shmem library is the one that gets used.
project_root = pathlib.Path(__file__).parent.parent.parent
shmem_lib_dir = project_root.joinpath("lib", "shmem")

# Importing this module fails when the library directory is missing.
if not shmem_lib_dir.exists():
    raise FileNotFoundError(f"shmem库目录不存在: {shmem_lib_dir}")

# Prepend the directory to LD_LIBRARY_PATH; when the variable is unset or
# empty the directory becomes its only entry.
# NOTE(review): on glibc the dynamic loader usually reads LD_LIBRARY_PATH at
# process start-up, so this assignment may only influence child processes --
# confirm that the handler loads the library in a way that honours it.
ld_path = os.environ.get("LD_LIBRARY_PATH", "")
if ld_path:
    new_ld_path = f"{shmem_lib_dir}:{ld_path}"
else:
    new_ld_path = str(shmem_lib_dir)
os.environ["LD_LIBRARY_PATH"] = new_ld_path

# Resolve the platform's symmetric memory handler once, at import time; the
# functions of this module forward their calls to it.
platform = get_platform()
_symm_handler = platform.get_symmetric_memory_handler()

33 

34 

def is_shmem_available() -> bool:
    r"""
    Report whether shmem is available, as answered by the platform's
    symmetric memory handler.

    Returns:
        bool: the result of the handler's ``is_shmem_available()`` call.
    """
    available = _symm_handler.is_shmem_available()
    return available

37 

def empty(shape, dtype):
    r"""
    Allocate a tensor in symmetric memory, analogous to a framework-level ``empty()``.

    The allocation is delegated to the platform's symmetric memory handler.
    The returned tensor lives in memory that is symmetric among the
    participating processes, so it can be used directly for one-sided
    communication, for normal computation, or be handed to :func:`rendezvous`.

    Args:
        shape: a sequence of integers (for example a list or a tuple) defining
            the shape of the output tensor. It is passed as a single argument.
        dtype (:class:`mindspore.dtype` or :class:`torch.dtype`): the desired data type of
            the returned tensor. It may be given positionally or by keyword.

    Returns:
        The tensor created by the handler.

    Example::
        >>> # doctest: +SKIP
        >>> # Create a symmetric memory tensor of shape (2, 3) with float32 data type
        >>> symm_tensor = symmetric_memory.empty((2, 3), dtype=mindspore.float32)
    """
    symm_tensor = _symm_handler.empty(shape, dtype)
    return symm_tensor

57 

def barrier():
    r"""
    Synchronize all processes of the default process group.

    The call blocks until every process has reached it, so all processes are
    known to be at the same point of the program once it returns.

    Returns:
        Whatever the handler's ``barrier()`` call returns.

    Example::
        >>> # doctest: +SKIP
        >>> # Synchronize all processes before starting a new phase of computation
        >>> symmetric_memory.barrier()
    """
    result = _symm_handler.barrier()
    return result

70 

def rendezvous(tensor, group):
    r"""
    rendezvous(tensor, group) -> _SymmetricMemory

    Establish a symmetric memory tensor among the participating processes.

    Note:
        This interface is not compatible yet, because shmem rendezvous is
        still in development.

    This is a collective operation. It allocates a signal symmetric memory
    coupled with ``tensor`` and collects all corresponding pointers for the
    buffer and the signal.

    Args:
        tensor: the local tensor used to establish the symmetric memory tensor.
            It must have been allocated via :func:`symmetric_memory.empty()`.
            Shape, dtype and device type must be identical across all
            participating processes.
        group: identifies the participating processes. Either a group name or
            a process group object.

    Returns:
        The object returned by the handler's ``rendezvous`` call.
    """
    symm_mem = _symm_handler.rendezvous(tensor, group)
    return symm_mem

89 

def set_signal_pad_size(size: int) -> None:
    r"""
    Set the signal pad size used by future symmetric memory allocations.

    Note:
        This interface is not compatible yet; allocate the signal tensor in
        your own program instead.

    Signal pads are P2P-accessible memory regions used for synchronization in
    symmetric memory. This function lets users size the signal pad in
    proportion to the requirements of their workload.

    .. warning::
        Call this before any symmetric memory allocation is made. The size
        cannot be changed once allocations have been performed.

    Args:
        size (int): the signal pad size in bytes. It should be proportional to
            the number of blocks launched and to the world size.
    """
    outcome = _symm_handler.set_signal_pad_size(size)
    return outcome

108 

109 

def get_signal_pad_size() -> int:
    r"""
    Return the signal pad size currently used for symmetric memory allocations.

    Note:
        This interface is not compatible yet; allocate the signal tensor in
        your own program instead.

    The value is the size configured through :func:`set_signal_pad_size` when
    one was set, otherwise the default size.

    Returns:
        int: the signal pad size in bytes.
    """
    pad_size = _symm_handler.get_signal_pad_size()
    return pad_size

122 

def shmem_put(target, target_offset, src, src_offset, size, target_rank):
    r"""
    One-sided send: write data from the local source tensor into a target tensor.

    All arguments are forwarded positionally, in the order given, to the
    platform's symmetric memory handler. Nothing is returned.

    NOTE(review): the parameters presumably have the same meaning as the
    identically named ones of :func:`shmem_put_with_signal` -- confirm.
    """
    _symm_handler.shmem_put(
        target, target_offset, src, src_offset, size, target_rank
    )

128 

def shmem_get(target, target_offset, src, src_offset, size, target_rank):
    r"""
    One-sided receive: read data from a target tensor into the local source tensor.

    All arguments are forwarded positionally, in the order given, to the
    platform's symmetric memory handler. Nothing is returned.

    NOTE(review): which of ``target`` / ``src`` is the remote side is not
    visible from this wrapper -- confirm against the handler implementation.
    """
    _symm_handler.shmem_get(
        target, target_offset, src, src_offset, size, target_rank
    )

134 

def shmem_signal_op(signal, signal_offset, signal_value, signal_op=0, target_rank=0):
    r"""
    Perform an atomic operation on a signal in the symmetric memory.

    The signal value found at ``signal_offset`` inside ``signal`` is updated
    atomically with ``signal_value``. The operation is carried out on the
    memory of the target rank, which allows processes to synchronize with each
    other. The call is forwarded to the platform's symmetric memory handler
    and nothing is returned.

    ``signal_op`` and ``target_rank`` carry the default values this docstring
    has always documented; callers that pass every argument are unaffected.

    Args:
        signal (tensor, int32): The symmetric memory tensor that contains the signal to be updated.
        signal_offset (tensor): The byte offset within the signal tensor
            where the signal value is located.
        signal_value (tensor, int32): The value to update the signal with.
        signal_op (int64, optional): The operation to perform on the signal value,
            0: set, 1: add. Defaults to 0.
        target_rank (int64, optional): The rank of the target process that owns
            the signal tensor. Defaults to 0.
    """
    _symm_handler.shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

151 

def shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op=0):
    r"""
    Wait for a signal to satisfy a specified condition before proceeding.

    The calling process is blocked until the value at ``signal_offset`` inside
    ``signal`` meets the condition described by ``compare_value`` and
    ``compare_op``. The call is forwarded to the platform's symmetric memory
    handler and nothing is returned.

    ``compare_op`` carries the default value this docstring has always
    documented; callers that pass every argument are unaffected.

    Args:
        depend_tensor (tensor): A tensor that the wait operation depends on.
            It is used to ensure proper ordering of operations.
        signal (tensor, int32): The symmetric memory tensor that contains the signal to wait on.
        signal_offset (tensor): The byte offset within the signal tensor where the signal value is located.
        compare_value (tensor, int32): The value to compare against the
            signal value at the specified offset.
        compare_op (int64, optional): The comparison operator to use.
            0: equal, 1: greater than, 2: less than. Defaults to 0.
    """
    _symm_handler.shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op)

168 

def shmem_put_with_signal(target, target_offset, src, src_offset,
                          size, signal, signal_offset, signal_value, signal_op=0, target_rank=0):
    r"""
    One-sided send followed by a signal update.

    Data is written from the local source tensor into a target tensor, then
    the signal value at ``signal_offset`` is updated according to
    ``signal_op``. Combining the data transfer of :func:`shmem_put` with an
    atomic signal update allows efficient synchronization after the put. The
    call is forwarded to the platform's symmetric memory handler and nothing
    is returned.

    ``signal_op`` and ``target_rank`` carry the default values this docstring
    has always documented; callers that pass every argument are unaffected.

    Args:
        target (tensor): The target symmetric memory tensor to write to.
        target_offset (tensor): The byte offset within the target tensor where the data should be written.
        src (tensor): The local source tensor containing the data to be sent.
            Its dtype must match the dtype of the target tensor.
        src_offset (tensor): The byte offset within the source tensor where the data to be sent is located.
        size (tensor): The size of the data to be sent in bytes.
        signal (tensor, int32): The symmetric memory tensor that contains
            the signal to be updated after the put operation.
        signal_offset (tensor): The byte offset within the signal tensor where the signal value is located.
        signal_value (tensor, int32): The value to update the signal with after the put operation.
        signal_op (int64, optional): The operation to perform on the signal value,
            0: set, 1: add. Defaults to 0.
        target_rank (int64, optional): The rank of the target process that
            owns the target tensor and signal. Defaults to 0.
    """
    _symm_handler.shmem_put_with_signal(target, target_offset, src, src_offset,
                                        size, signal, signal_offset, signal_value, signal_op, target_rank)

193 

def shmem_allgather(output_tensor, input_tensor):
    """
    Gather ``input_tensor`` from all ranks into ``output_tensor``.

    Note:
        This interface only supports torch for now; the mindspore version is
        still in development.

    The input tensors of all ranks are concatenated into the output tensor,
    ordered by rank. Every rank must supply an input tensor of the same shape
    and dtype, and the output tensor must be large enough to hold the data of
    all ranks. The call is forwarded to the platform's symmetric memory
    handler and nothing is returned.

    Args:
        output_tensor (tensor): The symmetric memory tensor that will hold the gathered data from all ranks.
            Its shape should be (world_size * local_shape), where local_shape is the shape of input_tensor.
        input_tensor (tensor): The local tensor to be gathered from each rank.
    """
    _symm_handler.shmem_allgather(
        output_tensor,
        input_tensor,
    )

208 

def shmem_alltoall(send_tensor_list, receive_tensor, receive_list):
    """
    All-to-all exchange: every rank sends a tensor to, and receives a tensor
    from, every other rank.

    Note:
        This interface only supports torch for now; the mindspore version is
        still in development.

    The data received in ``receive_tensor`` is ordered by rank, matching the
    order of ``receive_list``. The call is forwarded to the platform's
    symmetric memory handler and nothing is returned.

    Args:
        send_tensor_list (list of tensors): Tensors to send, where
            send_tensor_list[i] is the tensor to be sent to rank i.
        receive_tensor (tensor): The symmetric memory tensor that will hold the received data from all ranks.
        receive_list (list of int): Sizes of the data to receive, where
            receive_list[i] is the size of the data to be received from rank i.
    """
    _symm_handler.shmem_alltoall(
        send_tensor_list,
        receive_tensor,
        receive_list,
    )

226 

def fused_all_gather_matmul(a, b, c, gather_out, signal, block_size=None):
    """
    fused_all_gather_matmul(a, b, c, gather_out, signal, block_size=None)

    Fused operator combining allgather and matmul operations.

    Note:
        This interface only supports torch for now; the mindspore version is
        still in development.

    Computational flow:
        1. gather_out = allgather(a)           # Gather local tensor 'a' from all ranks
        2. c = ReduceScatter(gather_out @ b)   # Matrix multiplication followed by reduce-scatter

    NOTE(review): ``c`` is documented with shape (M, N), which matches a plain
    ``gather_out @ b`` rather than a reduce-scattered result -- confirm whether
    step 2 really includes a reduce-scatter.

    ``block_size`` carries the ``None`` default shown in the signature line
    above; callers that pass every argument are unaffected.

    Parameters:
        a: Local input tensor with shape (M_local, K).
        b: Weight matrix with shape (K, N).
        c: Output tensor with shape (M, N).
        gather_out: Output tensor containing gathered 'a' from all ranks,
            shape (M, K) where M = M_local * world_size.
        signal: Symmetric memory tensor with shape (world_size) and dtype int32.
        block_size: Optional block size for tiled computation. Defaults to None.

    Returns:
        Whatever the handler's ``fused_all_gather_matmul`` call returns.
    """
    return _symm_handler.fused_all_gather_matmul(a, b, c, gather_out, signal, block_size)

247 

def fused_matmul_reduce_scatter(x1, x2, symm_tensor, signal, reduce_op="sum"):
    """
    Fusion operator: fuses Matmul and ReduceScatter operations.

    Note:
        This interface only supports torch for now; the mindspore version is
        still in development.

    Computation formula: output = ReduceScatter(x1 @ x2)

    ``reduce_op`` carries the ``'sum'`` default this docstring has always
    documented; callers that pass every argument are unaffected.

    Parameters:
        x1: Left matrix with shape (m, k). 'm' must be an integer multiple of the number of devices (rank size).
        x2: Right matrix with shape (k, n).
        symm_tensor: Symmetric memory tensor with shape (m, n).
        signal: Symmetric memory tensor with shape (world_size) and dtype int32.
        reduce_op: Operator of scatter, only 'sum' and 'avg' are supported. Defaults to 'sum'.

    Returns:
        output: Output matrix with shape (m / rank_size, n), as returned by the handler.
    """
    return _symm_handler.fused_matmul_reduce_scatter(x1, x2, symm_tensor, signal, reduce_op)

264 

265 

# Names exported by ``from ... import *``. The entry for
# "overlap_launch_all_to_all_v" is kept commented out, so it is not exported.
__all__ = [
    "is_shmem_available",
    "empty",
    "rendezvous",
    "set_signal_pad_size",
    "get_signal_pad_size",
    "barrier",
    "shmem_put",
    "shmem_get",
    "shmem_wait_for_signal",
    "shmem_put_with_signal",
    "shmem_signal_op",
    "shmem_allgather",
    "shmem_alltoall",
    "fused_all_gather_matmul",
    "fused_matmul_reduce_scatter",
    # "overlap_launch_all_to_all_v",
]