Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / symmetric_memory / symmetric_memory.py: 0%

229 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Symmetric Memory""" 

16 

17import os 

18from logging import getLogger 

19import torch 

20import torch.distributed as dist 

21from hyper_parallel.platform import get_platform 

22 

logger = getLogger(__name__)

# The native SHMEM extension is optional: it is looked up next to this module
# and only loaded when the shared object is actually present on disk.
current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(current_dir, 'libaclshmem_torch.so')

# True only when the shared library was found (and therefore loaded below).
_is_shmem_available = os.path.exists(file_path)

# Handles to the custom classes exposed under torch.classes.SymmetricMemory;
# both stay None when the library is absent.
_manager = None
_ops = None
if _is_shmem_available:
    torch.ops.load_library(file_path)
    _manager = torch.classes.SymmetricMemory.Manager()
    _ops = torch.classes.SymmetricMemory.Ops()

36 

37 

class TorchSymmetricMemoryHandler:
    """Symmetric-memory (SHMEM) helper used for one-sided communication.

    All methods are static/class methods; state (initialisation flag and the
    stream pools) is kept on the class itself. The native entry points are
    reached through the module-level ``_manager`` and ``_ops`` objects.
    """
    _is_init = False
    # Stream pools created by _init_shmem(): at most _MAX_COMM_STREAMS streams
    # for issuing puts, and one compute stream per rank.
    comm_streams = []
    compute_streams = []

    # Upper bound on the number of communication streams.
    _MAX_COMM_STREAMS = 16
    # Heap size used when SYMMETRIC_MEMORY_HEAP_SIZE is not set.
    _DEFAULT_HEAP_SIZE = 1024 * 1024 * 1024

    @classmethod
    def _init_shmem(cls):
        """Initialise the symmetric heap and the stream pools (runs once).

        Reads the heap size from the ``SYMMETRIC_MEMORY_HEAP_SIZE`` environment
        variable, falling back to ``_DEFAULT_HEAP_SIZE``.

        Raises:
            RuntimeError: if the native library was not loaded.
            ValueError: if ``SYMMETRIC_MEMORY_HEAP_SIZE`` is not an integer.
        """
        if cls._is_init:
            return
        if _manager is None:
            raise RuntimeError("symmetric memory is unavailable: libaclshmem_torch.so was not loaded")
        logger.info("start init torch symmetric memory")
        platform = get_platform()
        rank_id = platform.get_rank()
        world_size = platform.get_world_size()

        heap_env = os.getenv("SYMMETRIC_MEMORY_HEAP_SIZE")
        if heap_env is None:
            local_mem_size = cls._DEFAULT_HEAP_SIZE
        else:
            try:
                local_mem_size = int(heap_env)
            except ValueError as err:
                raise ValueError(
                    f"SYMMETRIC_MEMORY_HEAP_SIZE must be an integer, but get {heap_env!r}") from err

        # NOTE(review): the bootstrap address is hard-coded to localhost;
        # confirm whether multi-node runs need a configurable endpoint.
        ipports = "tcp://127.0.0.1:8662"
        logger.info("start init acl shmem: rank_id:%d, rank size:%d, heap size:%d, ipports:%s",
                    rank_id, world_size, local_mem_size, ipports)
        _manager.attr_init(rank_id, world_size, local_mem_size, ipports)
        cls.comm_streams = [torch.npu.Stream() for _ in range(min(world_size, cls._MAX_COMM_STREAMS))]
        cls.compute_streams = [torch.npu.Stream() for _ in range(world_size)]
        cls._is_init = True
        logger.info("init symmetric memory success!")

    @staticmethod
    def _check_target_rank(target_rank, world_size):
        """Raise ValueError unless ``0 <= target_rank < world_size``."""
        if not 0 <= target_rank < world_size:
            raise ValueError(f"target_rank must be in range [0, {world_size - 1}], but get {target_rank}")

    @classmethod
    def _comm_stream(cls, index):
        """Return the communication stream for ``index``, wrapping around the pool.

        The pool holds at most ``_MAX_COMM_STREAMS`` streams, so peers beyond
        that count share streams round-robin.

        Raises:
            RuntimeError: if the stream pool has not been created yet.
        """
        if not cls.comm_streams:
            raise RuntimeError("symmetric memory is not initialized, "
                               "create symmetric tensors with empty() first")
        return cls.comm_streams[index % len(cls.comm_streams)]

    @staticmethod
    def _split_rows(total_rows, block_size):
        """Return the row count of every block when ``total_rows`` is cut into ``block_size`` chunks."""
        return [min(block_size, total_rows - start) for start in range(0, total_rows, block_size)]

    @staticmethod
    def is_shmem_available():
        """Return True when the native SHMEM library was found and loaded."""
        return _is_shmem_available

    @staticmethod
    def empty(shape, dtype):
        """Allocate a symmetric memory tensor, initialising SHMEM on first use.

        Args:
            shape: an int, tuple or list; ints and tuples are converted to a list.
            dtype: element type forwarded to the allocator.

        Returns:
            The tensor returned by the native allocator.
        """
        if isinstance(shape, int):
            shape = [shape]
        elif isinstance(shape, tuple):
            shape = list(shape)
        # _init_shmem() is a no-op after the first successful call.
        TorchSymmetricMemoryHandler._init_shmem()
        return _manager.malloc(shape, dtype)

    @staticmethod
    def barrier():
        """Block until every rank of the default process group reaches this point."""
        dist.barrier()

    @staticmethod
    def rendezvous(tensor, group):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, rendezvous is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "so this function is not implemented. ")

    @staticmethod
    def set_signal_pad_size(size: int) -> None:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, set_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def get_signal_pad_size() -> int:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, get_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def shmem_put(target, target_offset, src, src_offset, size, target_rank):
        """shmem_put operator: shmem_put(target, target_offset, src, src_offset, size, target_rank)

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size)``.
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, dist.get_world_size())
        _ops.put_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_get(target, target_offset, src, src_offset, size, target_rank):
        """shmem_get operator: shmem_get(target, target_offset, src, src_offset, size, target_rank)

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size)``.
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, dist.get_world_size())
        _ops.get_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank):
        """shmem_signal_op operator: shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size)``.
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, dist.get_world_size())
        _ops.signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

    @staticmethod
    def shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op):
        """
        shmem_wait_for_signal operator:
        shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op)
        """
        _ops.signal_wait_until(depend_tensor, signal, signal_offset, compare_value, compare_op)

    @staticmethod
    def shmem_put_with_signal(target, target_offset, src, src_offset,
                              size, signal, signal_offset, signal_value, signal_op, target_rank):
        """
        shmem_put_with_signal operator:
        shmem_put_with_signal(target, target_offset, src, src_offset,
                              size, signal, signal_offset, signal_value, signal_op, target_rank)

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size)``.
        """
        # This entry point takes the world size from the platform object
        # (the other put/get wrappers query torch.distributed).
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        _ops.put_mem_signal(target, target_offset, src, src_offset,
                            size, signal, signal_offset, signal_value, signal_op, target_rank)

    @classmethod
    def shmem_allgather(cls, output_tensor, input_tensor):
        """
        allgather operator: shmem_allgather(output_tensor, input_tensor)

        Performs an allgather collective operation using symmetric memory (SHMEM).
        Each rank puts its local 'input_tensor' into every rank's 'output_tensor'
        at element offset (rank_id * input_tensor.numel()), then waits until the
        local signal shows that all ranks have delivered their slice.

        Parameters:
            output_tensor: Output tensor that will contain the gathered data from all ranks.
                           Its size must be (world_size * local_input_size).
            input_tensor: Local input tensor contributed by this rank.

        Raises:
            ValueError: if output_tensor.numel() != world_size * input_tensor.numel().
        """
        def to_tensor(x, dtype=torch.int64):
            return torch.tensor([x], dtype=dtype, device='npu')

        rank_id = dist.get_rank()
        world_size = dist.get_world_size()
        size = input_tensor.numel()
        if size * world_size != output_tensor.numel():
            raise ValueError(f"All tensor must have same size, but in rank {rank_id}, the size "
                             f"of input_tensor is {size}, the size of output_tensor is {output_tensor.numel()}")

        signal = cls.empty(1, torch.int32)
        try:
            torch.zero_(signal)
            # Make sure every rank has zeroed its signal before any put lands.
            dist.barrier()
            for target_pe in range(world_size):
                with torch.npu.stream(cls._comm_stream(target_pe)):
                    # This rank's slice starts at size * rank_id on every peer.
                    # Signal op 1 with value 1 is presumably an atomic add — TODO confirm.
                    _ops.put_mem_signal(output_tensor, to_tensor(size * rank_id), input_tensor, to_tensor(0),
                                        to_tensor(size), signal, to_tensor(0), to_tensor(1, torch.int32),
                                        1, target_pe)
            # One increment is expected from each of the world_size ranks.
            _ops.signal_wait_until(output_tensor, signal, to_tensor(0), to_tensor(world_size, torch.int32), 0)
        finally:
            _manager.free(signal)

    @classmethod
    def shmem_alltoall(cls, send_tensor_list, receive_tensor, receive_list):
        """
        alltoall operator: shmem_alltoall(send_tensor_list, receive_tensor, receive_list)

        Performs an all-to-all collective operation using symmetric memory (SHMEM).
        Each process sends distinct data blocks to all other processes and receives
        corresponding blocks from all other processes.

        Parameters:
            send_tensor_list: List of tensors to send to each rank.
                              send_tensor_list[i] is the tensor sent to rank i.
            receive_tensor: Output tensor that will contain all received data.
                            Its total size should be sum(receive_list).
            receive_list: Tensor (it is passed to torch.zeros_like) holding one entry
                          per rank; receive_list[i] is the size (e.g., number of
                          elements) of data to receive from rank i.
        """
        def to_tensor(x, dtype=torch.int64):
            return torch.tensor([x], dtype=dtype, device='npu')

        world_size = dist.get_world_size()
        # Exclusive prefix sum: where the data from rank i starts in receive_tensor.
        receive_offsets = torch.zeros_like(receive_list)
        send_offsets = torch.zeros_like(receive_list)
        for i in range(1, world_size):
            receive_offsets[i] = receive_offsets[i - 1] + receive_list[i - 1]
        # After the exchange, send_offsets[j] is the offset inside rank j's
        # receive_tensor at which this rank must write.
        dist.all_to_all_single(send_offsets, receive_offsets)

        signal = cls.empty(1, torch.int32)
        try:
            torch.zero_(signal)
            dist.barrier()
            for target_pe in range(world_size):
                with torch.npu.stream(cls._comm_stream(target_pe)):
                    _ops.put_mem_signal(receive_tensor, send_offsets[target_pe], send_tensor_list[target_pe],
                                        to_tensor(0), to_tensor(send_tensor_list[target_pe].numel()),
                                        signal, to_tensor(0), to_tensor(1, torch.int32), 1, target_pe)
            # One increment is expected from each of the world_size ranks.
            _ops.signal_wait_until(receive_tensor, signal, to_tensor(0), to_tensor(world_size, torch.int32), 0)
        finally:
            _manager.free(signal)

    @classmethod
    def fused_all_gather_matmul(cls, a, b, c, gather_out, signal, block_size=None):
        """
        fused_all_gather_matmul(a, b, c, gather_out, signal, block_size=None)
        Fused operator combining allgather and matmul operations.

        Computational flow:
        1. gather_out = allgather(a)    # Gather local tensor 'a' from all ranks, block by block
        2. c = gather_out @ b           # Each block is multiplied as soon as its signal arrives

        Parameters:
            a: Local input tensor with shape (m_local, k).
            b: Weight matrix with shape (k, n).
            c: Output tensor with shape (m, n).
            gather_out: Output tensor containing gathered 'a' from all ranks,
                        shape (m, k) where m = m_local * world_size.
            signal: Symmetric memory signal tensor (documented dtype int32) with at
                    least world_size * num_blocks elements, where num_blocks is
                    ceil(m_local / block_size). It is not reset here.
            block_size: Optional number of rows per block. Defaults to
                        max(1, m_local // min(world_size, 4)) and is capped at m_local.

        Returns:
            The tuple (gather_out, c).

        Raises:
            ValueError: on shape mismatch, empty 'a', non-positive block_size,
                        or a signal tensor that is too small.
        """
        def to_tensor(value, dtype=torch.int64):
            return torch.tensor(value, dtype=dtype, device='npu')

        world_size = dist.get_world_size()
        rank_id = dist.get_rank()

        m, k = a.shape

        if m * world_size != gather_out.shape[0]:
            raise ValueError(f"gather_out shape mismatch: expected [{a.shape[0] * world_size}, {k}], "
                             f"got {gather_out.shape}")

        if gather_out.shape[0] != c.shape[0] or b.shape[1] != c.shape[1]:
            raise ValueError(f"Matmul output shape mismatch: expected [{gather_out.shape[0]}, {b.shape[1]}], "
                             f"got {c.shape}")

        if m == 0:
            raise ValueError("a must have at least one row")

        if block_size is None:
            block_size = max(1, m // min(world_size, 4))
        elif block_size <= 0:
            raise ValueError(f"block_size must be positive, but get {block_size}")
        block_size = min(block_size, m)

        # Rows per block; only the last block may be shorter than block_size.
        block_sizes = cls._split_rows(m, block_size)
        num_blocks = len(block_sizes)

        if signal.numel() < world_size * num_blocks:
            raise ValueError(f"signal must have at least {world_size * num_blocks} elements, "
                             f"but get {signal.numel()}")

        int32_1 = torch.ones(1, dtype=torch.int32, device='npu')
        # Signal slot (rank * num_blocks + block) marks arrival of that rank's block.
        signal_offsets = torch.arange(0, world_size * num_blocks, dtype=torch.int64, device='npu')
        flat_a = a.view(-1)

        for block_idx, rows in enumerate(block_sizes):
            start_row = block_idx * block_size
            start_idx_tensor = to_tensor(start_row * k)
            block_local_size_tensor = to_tensor(rows * k)
            dst_offset_tensor = to_tensor((rank_id * m + start_row) * k)
            signal_offset = signal_offsets[rank_id * num_blocks + block_idx]
            for target_pe in range(world_size):
                with torch.npu.stream(cls._comm_stream(target_pe)):
                    _ops.put_mem_signal(
                        gather_out, dst_offset_tensor,
                        flat_a, start_idx_tensor,
                        block_local_size_tensor, signal,
                        signal_offset, int32_1,
                        0, target_pe
                    )

        for rank in range(world_size):
            with torch.npu.stream(cls.compute_streams[rank]):
                for block_idx, rows in enumerate(block_sizes):
                    _ops.signal_wait_until(gather_out, signal,
                                           signal_offsets[rank * num_blocks + block_idx],
                                           int32_1, 0)
                    start_row = rank * m + block_size * block_idx
                    end_row = start_row + rows
                    c[start_row:end_row, :] = torch.matmul(gather_out[start_row:end_row, :], b)

        for stream in cls.comm_streams:
            stream.synchronize()
        for stream in cls.compute_streams:
            stream.synchronize()

        return gather_out, c

    @classmethod
    def fused_matmul_reduce_scatter(cls, x1, x2, symm_tensor, signal, reduce_op='sum'):
        """
        Fusion operator: Fuses Matmul and ReduceScatter operations.
        Computation formula: output = ReduceScatter(x1 @ x2)

        Parameters:
            x1: Left matrix with shape (m, k). 'm' must be an integer multiple of the number of devices (world_size).
            x2: Right matrix with shape (k, n).
            symm_tensor: Symmetric memory tensor with shape (m , n).
            signal: Symmetric memory tensor with shape (world_size) and dtype int32.
            reduce_op: Operator of scatter, only support 'sum' and 'avg'. Default value is 'sum'.

        Returns:
            Output matrix with shape (m / world_size, n).

        Raises:
            ValueError: on mismatched inner dimension or dtype, when m is not
                        divisible by world_size, or on an unsupported reduce_op.
        """
        def to_tensor(x, dtype=torch.int64):
            return torch.tensor([x], dtype=dtype, device='npu')

        world_size = dist.get_world_size()
        rank_id = dist.get_rank()

        m, k = x1.shape
        k2, n = x2.shape

        if k != k2:
            raise ValueError(f"Dimension k of x1 and x2 does not match: x1.k={k}, x2.k={k2}.")

        if x1.dtype != x2.dtype:
            raise ValueError(f"Matrix multiplication requires both tensors to have the same data type:"
                             f"x1.dtype={x1.dtype}, x2.dtype={x2.dtype}.")

        if m % world_size != 0:
            raise ValueError(f"The number of rows m={m} in x1 must be divisible by the world size (number of devices) "
                             f"world_size={world_size}.")

        if reduce_op not in ['sum', 'avg']:
            raise ValueError(f"The operator of scatter only supports sum and avg, but get {reduce_op}.")

        block_size = m // world_size

        size_tensor = to_tensor(block_size * n)
        int32_1 = torch.ones(1, dtype=torch.int32, device='npu')
        offsets = torch.arange(0, world_size, dtype=torch.int64, device='npu')
        # Every peer stores this rank's partial result in slot rank_id of its symm_tensor.
        dst_offset_tensor = to_tensor(block_size * n * rank_id)

        # The local block needs no communication.
        own_start = rank_id * block_size
        output = torch.matmul(x1[own_start:own_start + block_size, :], x2)

        for rank in range(1, world_size):
            with torch.npu.stream(cls.compute_streams[rank]):
                block_idx = (rank_id + rank) % world_size
                start_row = block_idx * block_size
                block_result = torch.matmul(x1[start_row:start_row + block_size, :], x2)
                _ops.put_mem_signal(
                    symm_tensor, dst_offset_tensor,
                    block_result, offsets[0],
                    size_tensor, signal,
                    offsets[rank_id], int32_1,
                    0, block_idx
                )

        for rank in range(1, world_size):
            block_idx = (rank_id - rank) % world_size
            # The pool has at most _MAX_COMM_STREAMS entries, so wrap the index
            # instead of indexing by rank directly.
            # NOTE(review): several streams accumulate into 'output' in place;
            # confirm that this is ordered as intended.
            with torch.npu.stream(cls._comm_stream(block_idx)):
                _ops.signal_wait_until(symm_tensor, signal,
                                       offsets[block_idx], int32_1, 0)
                start_row = block_idx * block_size
                output.add_(symm_tensor[start_row:start_row + block_size, :])

        for stream in cls.compute_streams:
            stream.synchronize()
        for stream in cls.comm_streams:
            stream.synchronize()

        # reduce_op was validated above, so it is either 'sum' or 'avg' here.
        return output if reduce_op == 'sum' else output / world_size