1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Symmetric Memory"""
16
17import os
18from logging import getLogger
19import torch
20import torch.distributed as dist
21from hyper_parallel.platform import get_platform
22
logger = getLogger(__name__)

# Handles to the optional ACL SHMEM torch extension. They stay None, and
# _is_shmem_available stays False, unless libaclshmem_torch.so is found next to
# this file and loaded below.
_is_shmem_available = False
_manager = None
_ops = None

current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(current_dir, 'libaclshmem_torch.so')
if os.path.exists(file_path):
    # Registers the torch.classes.SymmetricMemory.* custom classes used below.
    torch.ops.load_library(file_path)
    _manager, _ops = torch.classes.SymmetricMemory.Manager(), torch.classes.SymmetricMemory.Ops()
    _is_shmem_available = True
36
37
class TorchSymmetricMemoryHandler:
    """SymmetricMemory is used for one-sided communication.

    All state lives on the class. The symmetric heap is initialised lazily by
    ``empty()``, which also creates the side streams used by the collective and
    fused operators. Those operators therefore expect ``empty()`` to have been
    called at least once before they run.
    """
    # True once _init_shmem() has initialised the heap and created the side streams.
    _is_init = False
    # Streams used to issue puts; _init_shmem() creates min(world_size, _MAX_PARALLEL_PUTS) of them.
    comm_streams = []
    # Streams used by the fused operators; _init_shmem() creates world_size of them.
    compute_streams = []
    # Upper bound on the number of comm streams, i.e. on puts issued concurrently.
    _MAX_PARALLEL_PUTS = 16

    @classmethod
    def _init_shmem(cls):
        """Initialise the symmetric heap and create the side streams (no-op once done).

        Rank and world size come from the platform abstraction. The heap size is read
        from the ``SYMMETRIC_MEMORY_HEAP_SIZE`` environment variable and defaults to
        1024 * 1024 * 1024 (presumably bytes -- confirm against ``attr_init``).
        """
        if cls._is_init:
            return
        logger.info("start init torch symmetric memory")
        platform = get_platform()
        rank_id = platform.get_rank()
        world_size = platform.get_world_size()
        local_mem_size = os.getenv("SYMMETRIC_MEMORY_HEAP_SIZE")
        if local_mem_size is not None:
            local_mem_size = int(local_mem_size)
        else:
            local_mem_size = 1024 * 1024 * 1024
        # NOTE(review): the bootstrap address is hard-coded to loopback, which presumably
        # restricts this to a single node -- confirm before multi-node use.
        ipports = "tcp://127.0.0.1:8662"
        logger.info("start init ach shmem: rank_id:%d, rank size:%d, heap size:%d, ipports:%s",
                    rank_id, world_size, local_mem_size, ipports)
        _manager.attr_init(rank_id, world_size, local_mem_size, ipports)
        cls.comm_streams = [torch.npu.Stream() for _ in range(min(world_size, cls._MAX_PARALLEL_PUTS))]
        cls.compute_streams = [torch.npu.Stream() for _ in range(world_size)]
        cls._is_init = True
        logger.info("init symmetric memory success!")

    @staticmethod
    def _check_target_rank(target_rank, world_size):
        """Raise ValueError unless ``0 <= target_rank < world_size``."""
        if target_rank < 0 or target_rank >= world_size:
            raise ValueError(f"target_rank must be in range [0, {world_size - 1}], but get {target_rank}")

    @staticmethod
    def _device_scalar(value, dtype=torch.int64):
        """Return a one-element NPU tensor holding ``value`` (offsets, sizes, signal values)."""
        return torch.tensor([value], dtype=dtype, device='npu')

    @staticmethod
    def _wait_for_current_stream(streams):
        """Make every stream in ``streams`` wait for work already queued on the current stream."""
        current = torch.npu.current_stream()
        for stream in streams:
            stream.wait_stream(current)

    @classmethod
    def _issue_to_all_ranks(cls, world_size, issue):
        """Call ``issue(target_pe)`` for every rank ``0..world_size-1``, in order.

        ``target_pe`` is issued on ``comm_streams[target_pe % _MAX_PARALLEL_PUTS]``, i.e.
        targets are handled in groups of at most ``_MAX_PARALLEL_PUTS`` and the i-th
        target of each group uses comm stream i.
        """
        for target_pe in range(world_size):
            with torch.npu.stream(cls.comm_streams[target_pe % cls._MAX_PARALLEL_PUTS]):
                issue(target_pe)

    @staticmethod
    def is_shmem_available():
        """Return whether libaclshmem_torch.so was found and loaded at import time."""
        return _is_shmem_available

    @staticmethod
    def empty(shape, dtype):
        """Allocate an uninitialised tensor on the symmetric heap.

        Initialises the heap on first use.

        Parameters:
            shape: int, tuple or list of dimensions. Ints and tuples are converted to a list.
            dtype: torch dtype of the allocation.
        """
        if isinstance(shape, int):
            shape = [shape]
        elif isinstance(shape, tuple):
            shape = list(shape)
        if not TorchSymmetricMemoryHandler._is_init:
            TorchSymmetricMemoryHandler._init_shmem()
        return _manager.malloc(shape, dtype)

    @staticmethod
    def barrier():
        """Synchronise all ranks via ``torch.distributed.barrier()``."""
        dist.barrier()

    @staticmethod
    def rendezvous(tensor, group):
        """Not supported: the heap is pre-allocated, so there is nothing to exchange."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, rendezvous is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "so this function is not implemented. ")

    @staticmethod
    def set_signal_pad_size(size: int) -> None:
        """Not supported: allocate signal buffers with ``empty()`` instead."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, set_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def get_signal_pad_size() -> int:
        """Not supported: allocate signal buffers with ``empty()`` instead."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, get_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def shmem_put(target, target_offset, src, src_offset, size, target_rank):
        """shmem_put operator: shmem_put(target, target_offset, src, src_offset, size, target_rank)

        Forwards to the extension's ``put_mem`` after validating ``target_rank``. Presumably
        copies ``size`` elements of ``src`` (from ``src_offset``) into ``target`` (at
        ``target_offset``) on rank ``target_rank``. The collectives in this class pass
        offsets and sizes as one-element int64 NPU tensors.

        Raises:
            ValueError: if target_rank is outside [0, world_size).
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, dist.get_world_size())
        _ops.put_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_get(target, target_offset, src, src_offset, size, target_rank):
        """shmem_get operator: shmem_get(target, target_offset, src, src_offset, size, target_rank)

        Forwards to the extension's ``get_mem`` after validating ``target_rank``. The
        argument order mirrors ``shmem_put``. The copy direction is defined by ``get_mem``
        -- confirm against the extension.

        Raises:
            ValueError: if target_rank is outside [0, world_size).
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, dist.get_world_size())
        _ops.get_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank):
        """shmem_signal_op operator: shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

        Forwards to the extension's ``signal_op`` after validating ``target_rank``.

        Raises:
            ValueError: if target_rank is outside [0, world_size).
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, dist.get_world_size())
        _ops.signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

    @staticmethod
    def shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op):
        """
        shmem_wait_for_signal operator:
        shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op)

        Forwards to the extension's ``signal_wait_until``. No rank validation is needed
        because the wait is on local signal memory.
        """
        _ops.signal_wait_until(depend_tensor, signal, signal_offset, compare_value, compare_op)

    @staticmethod
    def shmem_put_with_signal(target, target_offset, src, src_offset,
                              size, signal, signal_offset, signal_value, signal_op, target_rank):
        """
        shmem_put_with_signal operator:
        shmem_put_with_signal(target, target_offset, src, src_offset,
                              size, signal, signal_offset, signal_value, signal_op, target_rank)

        Forwards to the extension's ``put_mem_signal`` after validating ``target_rank``
        against the platform world size.

        Raises:
            ValueError: if target_rank is outside [0, world_size).
        """
        TorchSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        _ops.put_mem_signal(target, target_offset, src, src_offset,
                            size, signal, signal_offset, signal_value, signal_op, target_rank)

    @classmethod
    def shmem_allgather(cls, output_tensor, input_tensor):
        """
        allgather operator: shmem_allgather(output_tensor, input_tensor)

        Performs an allgather collective operation using symmetric memory (SHMEM).
        Every rank puts its local ``input_tensor`` into slot ``rank_id`` of
        ``output_tensor`` on every rank (itself included). It then waits until all
        ``world_size`` contributions have arrived locally.

        Parameters:
            output_tensor: Tensor receiving the gathered data (presumably allocated with
                ``empty()``). It must hold world_size * input_tensor.numel() elements.
                Rank r's data lands at element offset r * input_tensor.numel().
            input_tensor: Local input tensor contributed by this rank.

        Raises:
            ValueError: if output_tensor does not hold world_size * input_tensor.numel() elements.
        """
        rank_id = dist.get_rank()
        world_size = dist.get_world_size()
        size = input_tensor.numel()
        if size * world_size != output_tensor.numel():
            raise ValueError(f"All tensor must have same size, but in rank {rank_id}, the size "
                             f"of input_tensor is {size}, the size of output_tensor is {output_tensor.numel()}")
        # Loop-invariant device scalars, created once and kept alive for the whole call.
        zero = cls._device_scalar(0)
        dst_offset = cls._device_scalar(size * rank_id)
        numel = cls._device_scalar(size)
        one = cls._device_scalar(1, torch.int32)
        expected = cls._device_scalar(world_size, torch.int32)

        signal = _manager.malloc(1, torch.int32)
        torch.zero_(signal)
        dist.barrier()

        def put(target_pe):
            # signal_op=1 as in shmem_alltoall: presumably "add", so the signal counts arrivals.
            _ops.put_mem_signal(output_tensor, dst_offset, input_tensor, zero,
                                numel, signal, zero, one, 1, target_pe)

        cls._issue_to_all_ranks(world_size, put)
        # compare_op=0 (presumably "equal"): wait until every rank's put has landed here.
        _ops.signal_wait_until(output_tensor, signal, zero, expected, 0)
        _manager.free(signal)

    @classmethod
    def shmem_alltoall(cls, send_tensor_list, receive_tensor, receive_list):
        """
        alltoall operator: shmem_alltoall(send_tensor_list, receive_tensor, receive_list)

        Performs an all-to-all collective operation using symmetric memory (SHMEM).
        Each process sends distinct data blocks to all other processes and receives
        corresponding blocks from all other processes.

        Parameters:
            send_tensor_list: List of tensors to send to each rank.
                send_tensor_list[i] is the tensor sent to rank i.
            receive_tensor: Output tensor that will contain all received data.
                Its total size should be sum(receive_list).
            receive_list: 1-D tensor with one entry per rank; receive_list[i] is the
                number of elements to receive from rank i.
        """
        world_size = dist.get_world_size()
        # Exclusive prefix sum: receive_offsets[i] is where rank i's block starts locally.
        receive_offsets = torch.cumsum(receive_list, dim=0, dtype=receive_list.dtype) - receive_list
        send_offsets = torch.zeros_like(receive_list)
        # After the exchange, send_offsets[j] is where rank j expects this rank's block.
        dist.all_to_all_single(send_offsets, receive_offsets)

        # Device scalars built up front so they stay alive while the comm streams use them.
        zero = cls._device_scalar(0)
        one = cls._device_scalar(1, torch.int32)
        expected = cls._device_scalar(world_size, torch.int32)
        send_sizes = [cls._device_scalar(send_tensor_list[pe].numel()) for pe in range(world_size)]

        signal = _manager.malloc(1, torch.int32)
        torch.zero_(signal)
        dist.barrier()

        def put(target_pe):
            # signal_op=1: presumably "add", so the signal counts arrivals.
            _ops.put_mem_signal(receive_tensor, send_offsets[target_pe], send_tensor_list[target_pe],
                                zero, send_sizes[target_pe], signal, zero, one, 1, target_pe)

        cls._issue_to_all_ranks(world_size, put)
        _ops.signal_wait_until(receive_tensor, signal, zero, expected, 0)
        _manager.free(signal)

    @classmethod
    def fused_all_gather_matmul(cls, a, b, c, gather_out, signal, block_size=None):
        """
        fused_all_gather_matmul(a, b, c, gather_out, signal, block_size=None)
        Fused operator overlapping an allgather with a block-wise matmul.

        Computational flow:
            1. gather_out = allgather(a)   # rank r's rows land at gather_out[r*m_local:(r+1)*m_local]
            2. c = gather_out @ b          # computed block by block as soon as each block arrives

        Parameters:
            a: Contiguous local input tensor with shape (m_local, k).
            b: Weight matrix with shape (k, n).
            c: Output tensor with shape (m, n), where m = m_local * world_size.
            gather_out: Tensor receiving the gathered 'a' from all ranks, shape (m, k)
                (presumably allocated with ``empty()``).
            signal: int32 symmetric tensor with at least world_size * num_blocks elements.
                Here num_blocks = ceil(m_local / block_size). Entries are set to 1 and
                are not reset here; presumably the caller zeroes them before each call.
            block_size: Rows per block. Defaults to max(1, m_local // min(world_size, 4))
                and is capped at m_local.

        Returns:
            (gather_out, c)

        Raises:
            ValueError: on a shape mismatch, a non-positive block_size, or a too-small signal.
        """
        def to_tensor(value, dtype=torch.int64):
            return torch.tensor(value, dtype=dtype, device='npu')

        world_size = dist.get_world_size()
        rank_id = dist.get_rank()

        m, k = a.shape

        if m * world_size != gather_out.shape[0]:
            raise ValueError(f"gather_out shape mismatch: expected [{a.shape[0] * world_size}, {k}], "
                             f"got {gather_out.shape}")

        if gather_out.shape[0] != c.shape[0] or b.shape[1] != c.shape[1]:
            raise ValueError(f"Matmul output shape mismatch: expected [{gather_out.shape[0]}, {b.shape[1]}], "
                             f"got {c.shape}")

        if m == 0:
            # Nothing to gather or multiply (and a zero block size would divide by zero below).
            return gather_out, c

        if block_size is None:
            block_size = max(1, m // min(world_size, 4))
        elif block_size < 1:
            raise ValueError(f"block_size must be a positive integer, but got {block_size}")
        block_size = min(block_size, m)
        num_blocks = (m + block_size - 1) // block_size

        # One signal slot per (source rank, block); a smaller buffer would be written out of bounds.
        if signal.numel() < world_size * num_blocks:
            raise ValueError(f"signal must hold at least world_size * num_blocks = {world_size * num_blocks} "
                             f"elements, but got {signal.numel()}")

        # The side streams read a, b and gather_out and write c: order them after the current stream.
        cls._wait_for_current_stream(cls.comm_streams + cls.compute_streams)

        int32_1 = torch.ones(1, dtype=torch.int32, device='npu')
        signal_offsets = torch.arange(0, world_size * num_blocks, dtype=torch.int64, device='npu')
        # (first local row, row count) of every block; the last block takes the remainder.
        row_blocks = [(start, min(block_size, m - start)) for start in range(0, m, block_size)]
        # Per-block device scalars are kept in lists so none is released while a comm
        # stream may still be reading it.
        src_offsets = [to_tensor(start * k) for start, _ in row_blocks]
        block_numels = [to_tensor(rows * k) for _, rows in row_blocks]
        dst_offsets = [to_tensor((rank_id * m + start) * k) for start, _ in row_blocks]
        a_flat = a.view(-1)

        # Phase 1: push every local block of 'a' into slot rank_id of gather_out on all ranks.
        for block_idx in range(num_blocks):
            def put(target_pe, idx=block_idx):
                # signal_op=0 with value 1 into a slot unique to (rank_id, idx).
                _ops.put_mem_signal(gather_out, dst_offsets[idx], a_flat, src_offsets[idx],
                                    block_numels[idx], signal, signal_offsets[rank_id * num_blocks + idx],
                                    int32_1, 0, target_pe)

            cls._issue_to_all_ranks(world_size, put)

        # Phase 2: multiply each block as soon as it arrives, one compute stream per source rank.
        for rank in range(world_size):
            with torch.npu.stream(cls.compute_streams[rank]):
                for block_idx, (start, rows) in enumerate(row_blocks):
                    _ops.signal_wait_until(gather_out, signal,
                                           signal_offsets[rank * num_blocks + block_idx],
                                           int32_1, 0)
                    start_row = rank * m + start
                    end_row = start_row + rows
                    c[start_row:end_row, :] = torch.matmul(gather_out[start_row:end_row, :], b)
        for stream in cls.comm_streams:
            stream.synchronize()
        for stream in cls.compute_streams:
            stream.synchronize()

        return gather_out, c

    @classmethod
    def fused_matmul_reduce_scatter(cls, x1, x2, symm_tensor, signal, reduce_op='sum'):
        """
        Fusion operator: Fuses Matmul and ReduceScatter operations.
        Computation formula: output = ReduceScatter(x1 @ x2)

        Each rank computes the row blocks of x1 @ x2 owned by the other ranks and puts
        them into the owners' ``symm_tensor``. Each rank then accumulates the received
        blocks onto its own block.

        Parameters:
            x1: Left matrix with shape (m, k). 'm' must be an integer multiple of the number of devices (world_size).
            x2: Right matrix with shape (k, n).
            symm_tensor: Symmetric memory tensor with shape (m, n). Rank r's contribution
                lands in rows [r * m / world_size, (r + 1) * m / world_size).
            signal: Symmetric memory tensor with shape (world_size) and dtype int32. Entries
                are set to 1 and are not reset here; presumably the caller zeroes them
                before each call.
            reduce_op: Operator of scatter, only support 'sum' and 'avg'. Default value is 'sum'.

        Returns:
            Output matrix with shape (m / world_size, n).

        Raises:
            ValueError: on mismatched k or dtypes, m not divisible by world_size, or an
                unsupported reduce_op.
        """
        world_size = dist.get_world_size()
        rank_id = dist.get_rank()

        m, k = x1.shape
        k2, n = x2.shape

        if k != k2:
            raise ValueError(f"Dimension k of x1 and x2 does not match: x1.k={k}, x2.k={k2}.")

        if x1.dtype != x2.dtype:
            raise ValueError(f"Matrix multiplication requires both tensors to have the same data type:"
                             f"x1.dtype={x1.dtype}, x2.dtype={x2.dtype}.")

        if m % world_size != 0:
            raise ValueError(f"The number of rows m={m} in x1 must be divisible by the world size (number of devices) "
                             f"world_size={world_size}.")

        if reduce_op not in ['sum', 'avg']:
            raise ValueError(f"The operator of scatter only supports sum and avg, but get {reduce_op}.")

        block_size = m // world_size

        size_tensor = cls._device_scalar(block_size * n)
        int32_1 = torch.ones(1, dtype=torch.int32, device='npu')
        offsets = torch.arange(0, world_size, dtype=torch.int64, device='npu')
        dst_offset_tensor = cls._device_scalar(block_size * n * rank_id)

        # This rank's own row block, computed locally; it doubles as the accumulator.
        own_start = rank_id * block_size
        output = torch.matmul(x1[own_start:own_start + block_size, :], x2)

        # The compute streams read x1 and x2: order them after the current stream.
        cls._wait_for_current_stream(cls.compute_streams)
        for rank in range(1, world_size):
            block_idx = (rank_id + rank) % world_size
            start_row = block_idx * block_size
            with torch.npu.stream(cls.compute_streams[rank]):
                block_result = torch.matmul(x1[start_row:start_row + block_size, :], x2)
                _ops.put_mem_signal(
                    symm_tensor, dst_offset_tensor,
                    block_result, offsets[0],
                    size_tensor, signal,
                    offsets[rank_id], int32_1,
                    0, block_idx
                )

        # Accumulate peer contributions one at a time on the current stream. Running the
        # add_ calls on several streams at once would race on `output`.
        for rank in range(1, world_size):
            src_rank = (rank_id - rank) % world_size
            _ops.signal_wait_until(symm_tensor, signal, offsets[src_rank], int32_1, 0)
            src_start = src_rank * block_size
            output.add_(symm_tensor[src_start:src_start + block_size, :])

        for stream in cls.compute_streams:
            stream.synchronize()
        # Keep the original guarantee that all reads of symm_tensor are done on return.
        torch.npu.current_stream().synchronize()
        return output if reduce_op == 'sum' else output / world_size