Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/symmetric_memory/symmetric_memory.py: 0%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Symmetric Memory""" 

16import os 

17import mindspore as ms 

18from mindspore.runtime import Stream, StreamCtx, Event 

19from hyper_parallel.platform import get_platform 

20from . import aclshmem_ms 

21 

# Directory that contains this module; the aclshmem extension is looked up
# relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Expected location of the compiled aclshmem_ms shared library.
_lib_path = os.path.join(current_dir, 'aclshmem_ms/aclshmem_ms.so')

# True only when the shared library was present on disk at import time.
_is_shmem_available = os.path.exists(_lib_path)

27 

28 

class MSSymmetricMemoryHandler:
    """Symmetric-memory handler used for one-sided communication on MindSpore.

    All state lives on the class: a memory pool backed by the ``aclshmem_ms``
    shared library and one stream per rank. Both are created lazily by
    ``_init_shmem()`` the first time ``empty()`` is called. The communication
    primitives are thin wrappers that validate ``target_rank`` and forward
    their arguments unchanged to the ``aclshmem_ms`` extension module.
    """

    # Set to True once _init_shmem() has built the pool and the streams.
    _is_init = False
    # ms.runtime.MemPool wrapping the pluggable allocator; None until initialised.
    _mem_pool = None
    # One Stream per rank (length == world size); filled in by _init_shmem().
    streamlist = []

    @classmethod
    def _init_shmem(cls):
        """Create the symmetric memory pool and the per-rank streams once.

        Subsequent calls return immediately.
        """
        if cls._is_init:
            return
        # NOTE(review): this does not consult is_shmem_available(); if the
        # shared library is missing, the failure surfaces from the allocator
        # constructor below -- confirm whether an explicit check is wanted.
        allocator = ms.runtime.PluggableAllocator(_lib_path, "Alloc", "Free")
        cls._mem_pool = ms.runtime.MemPool(allocator)
        cls.streamlist = [Stream() for _ in range(get_platform().get_world_size())]
        cls._is_init = True

    @staticmethod
    def _check_target_rank(target_rank, world_size):
        """Raise ``ValueError`` unless ``0 <= target_rank < world_size``."""
        if target_rank < 0 or target_rank >= world_size:
            raise ValueError(f"target_rank must be in range [0, {world_size - 1}], but get {target_rank}")

    @staticmethod
    def is_shmem_available():
        """Return whether the aclshmem shared library was found at import time."""
        return _is_shmem_available

    @staticmethod
    def empty(size, dtype):
        """Allocate an uninitialised tensor from the symmetric memory pool.

        Initialises the pool on first use.

        Args:
            size: Forwarded to ``ms.mint.empty`` as its size argument.
            dtype: Element type of the new tensor.

        Returns:
            The tensor produced by ``ms.mint.empty`` while the symmetric
            memory pool is active.
        """
        handler = MSSymmetricMemoryHandler
        # _init_shmem() is a no-op after the first successful call.
        handler._init_shmem()
        with ms.runtime.use_mem_pool(handler._mem_pool):
            return ms.mint.empty(size, dtype=dtype)

    @staticmethod
    def rendezvous(tensor, group):
        """Not supported on this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, rendezvous is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "so this function is not implemented. ")

    @staticmethod
    def set_signal_pad_size(size: int) -> None:
        """Not supported on this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, set_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def get_signal_pad_size() -> int:
        """Not supported on this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError("In CANN SHMEM v1.0.0, get_signal_pad_size is not needed, "
                                  "symmetric memory are allocated at init time by SYMMETRIC_MEMORY_HEAP_SIZE, "
                                  "you can create symmetric signal memory by empty() "
                                  "so this function is not implemented. ")

    @staticmethod
    def shmem_put(target, target_offset, src, src_offset, size, target_rank):
        """shmem_put operator: shmem_put(target, target_offset, src, src_offset, size, target_rank)

        Validates ``target_rank`` and forwards every argument to
        ``aclshmem_ms.put_mem``.

        Raises:
            ValueError: If ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.put_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_get(target, target_offset, src, src_offset, size, target_rank):
        """shmem_get operator: shmem_get(target, target_offset, src, src_offset, size, target_rank)

        Validates ``target_rank`` and forwards every argument to
        ``aclshmem_ms.get_mem``.

        Raises:
            ValueError: If ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.get_mem(target, target_offset, src, src_offset, size, target_rank)

    @staticmethod
    def shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank):
        """shmem_signal_op operator: shmem_signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

        Validates ``target_rank`` and forwards every argument to
        ``aclshmem_ms.signal_op``.

        Raises:
            ValueError: If ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.signal_op(signal, signal_offset, signal_value, signal_op, target_rank)

    @staticmethod
    def shmem_wait_for_signal(depend_tensor, signal, signal_offset, compare_value, compare_op):
        """Forward all arguments unchanged to ``aclshmem_ms.signal_wait_until``."""
        aclshmem_ms.signal_wait_until(depend_tensor, signal, signal_offset, compare_value, compare_op)

    @staticmethod
    def shmem_put_with_signal(target, target_offset, src, src_offset,
                              size, signal, signal_offset, signal_value, signal_op, target_rank):
        """shmem_put_with_signal operator

        Validates ``target_rank`` and forwards every argument to
        ``aclshmem_ms.put_mem_signal``.

        Raises:
            ValueError: If ``target_rank`` is outside ``[0, world_size - 1]``.
        """
        MSSymmetricMemoryHandler._check_target_rank(target_rank, get_platform().get_world_size())
        aclshmem_ms.put_mem_signal(target, target_offset, src, src_offset,
                                   size, signal, signal_offset, signal_value, signal_op, target_rank)

    @staticmethod
    def barrier():
        """Call ``ms.mint.distributed.barrier()``."""
        ms.mint.distributed.barrier()

    @classmethod
    def shmem_allgather(cls, output_tensor, input_tensor):
        """
        allgather operator: shmem_allgather(output_tensor, input_tensor)

        Performs an allgather collective operation using symmetric memory (SHMEM).
        Each process/rank contributes its local 'input_tensor', and upon completion,
        all processes receive the concatenated data from all ranks in 'output_tensor'.

        This rank issues one ``shmem_put_with_signal`` per rank, each inside the
        stream context of ``streamlist[i]``, writing ``input_tensor`` at element
        offset ``rank_id * input_tensor.numel()`` of ``output_tensor``. It then
        calls ``shmem_wait_for_signal`` with a compare value of ``world_size``.

        Parameters:
            output_tensor: Output tensor that will contain the gathered data from all ranks.
                Its size should be (world_size * local_input_size).
            input_tensor: Local input tensor contributed by this rank.

        Raises:
            ValueError: If ``input_tensor.numel() * world_size`` differs from
                ``output_tensor.numel()``.
        """
        platform = get_platform()
        rank_id = platform.get_rank()
        world_size = platform.get_world_size()
        size = input_tensor.numel()
        if size * world_size != output_tensor.numel():
            # Report the local rank id here (the message used to print world_size).
            raise ValueError(f"All tensor must have same size, but in rank {rank_id}, the size "
                             f"of input_tensor is {size}, the size of output_tensor is {output_tensor.numel()}")

        # cls.empty() also triggers _init_shmem(), so cls.streamlist is
        # populated before the loop below indexes it.
        signal = cls.empty(1, dtype=ms.int32)
        signal.fill_(0)

        # Recorded before the loop; each per-rank stream waits on it before
        # issuing its put.
        event_signal = Event()
        event_signal.record()

        end_event_list = []
        for peer_rank in range(world_size):
            with StreamCtx(cls.streamlist[peer_rank]):
                event_signal.wait()
                # NOTE(review): signal_op=1 with signal_value=1 presumably
                # means "add 1 to the peer's signal" -- confirm against aclshmem_ms.
                cls.shmem_put_with_signal(output_tensor, ms.Tensor(rank_id * size, dtype=ms.int64), input_tensor,
                                          ms.Tensor(0, dtype=ms.int64), ms.Tensor(size, dtype=ms.int64),
                                          signal, ms.Tensor(0, dtype=ms.int64), ms.Tensor(1, dtype=ms.int32),
                                          1, peer_rank)
                end_event = Event()
                end_event.record()
                end_event_list.append(end_event)

        # Back outside the per-rank stream contexts: wait on every put's end event.
        for event in end_event_list:
            event.wait()

        # NOTE(review): compare_op=0 presumably means "equal", i.e. wait until
        # the local signal reaches world_size -- confirm against aclshmem_ms.
        cls.shmem_wait_for_signal(output_tensor, signal, ms.Tensor(0, dtype=ms.int64),
                                  ms.Tensor(world_size, dtype=ms.int32), 0)

    @classmethod
    def shmem_alltoall(cls, send_tensor_list, receive_tensor, receive_list):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Mindspore SymmetricMemory will implement shmem_alltoall later")

    @classmethod
    def fused_all_gather_matmul(cls, a, b, c, gather_out, signal, block_size=None):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Mindspore SymmetricMemory will implement fused_all_gather_matmul later")

    @classmethod
    def fused_matmul_reduce_scatter(cls, x1, x2, symm_tensor, signal, reduce_op):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Mindspore SymmetricMemory will implement fused_matmul_reduce_scatter later")