1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Symmetric Memory"""
16import os
17import mindspore as ms
18from mindspore.runtime import Stream, StreamCtx, Event
19from hyper_parallel.platform import get_platform
20from . import aclshmem_ms
21
# Directory that contains this module; the extension library is resolved relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Expected location of the aclshmem_ms shared library.
_lib_path = os.path.join(current_dir, 'aclshmem_ms/aclshmem_ms.so')
# True only when the shared library file is present on disk at import time.
_is_shmem_available = os.path.exists(_lib_path)
27
28
class MSSymmetricMemoryHandler:
    """SymmetricMemory is used for one-sided communication.

    Every method is a static or class method. The class itself holds the
    shared state (memory pool and per-rank streams), which is created lazily
    by the first call to ``empty()``.
    """

    # Set to True once _init_shmem() has created the pool and the streams.
    _is_init = False
    # ms.runtime.MemPool built on the pluggable allocator loaded from _lib_path.
    _mem_pool = None
    # One Stream per rank; rebound (not mutated in place) by _init_shmem().
    streamlist = []

    @classmethod
    def _init_shmem(cls):
        """Create the memory pool and one stream per rank; no-op after the first call."""
        if cls._is_init:
            return
        allocator = ms.runtime.PluggableAllocator(_lib_path, "Alloc", "Free")
        cls._mem_pool = ms.runtime.MemPool(allocator)
        cls.streamlist = [Stream() for _ in range(get_platform().get_world_size())]
        cls._is_init = True

    @staticmethod
    def _check_target_rank(target_rank, world_size):
        """Raise ValueError unless ``0 <= target_rank < world_size``."""
        if target_rank < 0 or target_rank >= world_size:
            raise ValueError(f"target_rank must be in range [0, {world_size - 1}], but get {target_rank}")

    @staticmethod
    def is_shmem_available():
        """Return the module-level flag telling whether the shmem library file was found."""
        return _is_shmem_available

    @staticmethod
    def empty(size, dtype):
        """Allocate an uninitialized tensor of ``size``/``dtype`` inside the shmem memory pool.

        Initializes the pool and the per-rank streams on first use.
        """
        # _init_shmem() guards itself, so calling it unconditionally is safe.
        MSSymmetricMemoryHandler._init_shmem()
        with ms.runtime.use_mem_pool(MSSymmetricMemoryHandler._mem_pool):
            return ms.mint.empty(size, dtype=dtype)

    @staticmethod
    def rendezvous(tensor, group):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, rendezvous is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "so this function is not implemented. ")

    @staticmethod
    def set_signal_pad_size(size: int) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, set_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def get_signal_pad_size() -> int:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, get_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def shmem_put(target, target_offset, src, src_offset, size, target_rank):
        """shmem_put operator: shmem_put(target, target_offset, src, src_offset, size, target_rank)

        Validates ``target_rank`` against the world size, then forwards all
        arguments to ``aclshmem_ms.put_mem``.

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.put_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_get(target, target_offset, src, src_offset, size, target_rank):
        """shmem_get operator: shmem_get(target, target_offset, src, src_offset, size, target_rank)

        Validates ``target_rank`` against the world size, then forwards all
        arguments to ``aclshmem_ms.get_mem``.

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.get_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank):
        """shmem_signal_op operator: shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

        Validates ``target_rank`` against the world size, then forwards all
        arguments to ``aclshmem_ms.signal_op``.

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

    @staticmethod
    def shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op):
        """Forward all arguments unchanged to ``aclshmem_ms.signal_wait_until``."""
        aclshmem_ms.signal_wait_until(depend_tensor, signal, signal_offset, compare_value, compare_op)

    @staticmethod
    def shmem_put_with_signal(target, target_offset, src, src_offset,
                              size, signal, signal_offset, signal_value, signal_op, target_rank):
        """shmem_put_with_signal operator

        Validates ``target_rank`` against the world size, then forwards all
        arguments to ``aclshmem_ms.put_mem_signal``.

        Raises:
            ValueError: if ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.put_mem_signal(target, target_offset, src, src_offset,
                                   size, signal, signal_offset, signal_value, signal_op, target_rank)

    @staticmethod
    def barrier():
        """Call ``ms.mint.distributed.barrier()``."""
        ms.mint.distributed.barrier()

    @classmethod
    def shmem_allgather(cls, output_tensor, input_tensor):
        """
        allgather operator: shmem_allgather(output_tensor, input_tensor)

        Performs an allgather collective operation using symmetric memory (SHMEM).
        Each process/rank contributes its local 'input_tensor', and upon completion,
        all processes receive the concatenated data from all ranks in 'output_tensor'.

        Parameters:
            output_tensor: Output tensor that will contain the gathered data from all ranks.
                           Its size should be (world_size * local_input_size).
            input_tensor: Local input tensor contributed by this rank.

        Raises:
            ValueError: if output_tensor.numel() != world_size * input_tensor.numel().
        """
        rank_id = get_platform().get_rank()
        world_size = get_platform().get_world_size()
        size = input_tensor.numel()
        if size * world_size != output_tensor.numel():
            # Report the local rank id (the original message printed world_size here).
            raise ValueError(f"All tensor must have same size, but in rank {rank_id}, the size "
                             f"of input_tensor is {size}, the size of output_tensor is {output_tensor.numel()}")

        # Values forwarded verbatim to the extension. The final wait compares the
        # signal against world_size, so op 1 presumably accumulates (ADD) and
        # compare op 0 presumably means "equal" -- TODO confirm against aclshmem_ms.
        signal_op = 1
        compare_op = 0

        # cls.empty() also triggers lazy init, which populates cls.streamlist.
        signal = cls.empty(1, dtype=ms.int32)
        # NOTE(review): nothing synchronizes this reset with peers; confirm a remote
        # rank cannot update this signal before the local fill_(0) executes.
        signal.fill_(0)

        # Recorded on the current stream so each side stream starts only after
        # the work queued so far (including the signal reset).
        event_signal = Event()
        event_signal.record()

        end_event_list = []
        for i in range(world_size):
            # NOTE(review): the event record below is assumed to belong inside this
            # stream context (original indentation was lost) -- confirm.
            with StreamCtx(cls.streamlist[i]):
                event_signal.wait()
                # Write this rank's data into slot `rank_id` of rank i's output and bump its signal.
                cls.shmem_put_with_signal(output_tensor, ms.Tensor(rank_id * size, dtype=ms.int64), input_tensor,
                                          ms.Tensor(0, dtype=ms.int64), ms.Tensor(size, dtype=ms.int64),
                                          signal, ms.Tensor(0, dtype=ms.int64), ms.Tensor(1, dtype=ms.int32),
                                          signal_op, i)
                end_event = Event()
                end_event.record()
                end_event_list.append(end_event)

        # Back on the caller's stream: wait for every side stream's put to be issued.
        for event in end_event_list:
            event.wait()
        # Then wait until the local signal matches world_size.
        cls.shmem_wait_for_signal(output_tensor, signal, ms.Tensor(0, dtype=ms.int64),
                                  ms.Tensor(world_size, dtype=ms.int32), compare_op)

    @classmethod
    def shmem_alltoall(cls, send_tensor_list, receive_tensor, receive_list):
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError("Mindspore SymmetricMemory will implement shmem_alltoall later")

    @classmethod
    def fused_all_gather_matmul(cls, a, b, c, gather_out, signal, block_size=None):
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError("Mindspore SymmetricMemory will implement fused_all_gather_matmul later")

    @classmethod
    def fused_matmul_reduce_scatter(cls, x1, x2, symm_tensor, signal, reduce_op):
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError("Mindspore SymmetricMemory will implement fused_matmul_reduce_scatter later")