Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / mindspore / pipeline_parallel / _utils.py: 18%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# http://www.apache.org/licenses/LICENSE-2.0 

2# 

3# Unless required by applicable law or agreed to in writing, software 

4# distributed under the License is distributed on an "AS IS" BASIS, 

5# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

6# See the License for the specific language governing permissions and 

7# limitations under the License. 

8# ============================================================================ 

9"""pipeline parallel utils""" 

10import io 

11import pickle 

12 

13from mindspore import nn, Tensor, mint, ops 

14from mindspore.common import dtype as mstype 

15from mindspore.communication import GlobalComm 

16from mindspore.mint.distributed.distributed import _object_to_tensor, send, recv 

17 

18import hyper_parallel 

19from hyper_parallel.core.shard.custom_shard import custom_shard 

20 

21 

class _MicroBatch(nn.Cell):
    """
    Split inputs into micro_batch in pipeline parallel.

    Args:
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            An entry may be ``None`` (batch dim 0 is used for that arg) or an
            object exposing a ``batch_dim`` attribute. Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            A key may be absent or mapped to ``None`` (batch dim 0 is used for
            that kwarg); otherwise the value must expose a ``batch_dim``
            attribute. Default ``None``.

    Inputs:
        - **args** (list) - Input args.
        - **kwargs** (dict) - Input kwargs.

    Outputs:
        - **args_after_split** (list) - Input args after split into micro_batches.
          One inner list of args per micro-batch.
        - **kwargs_after_split** (list) - Input kwargs after split into micro_batches.
          One dict of kwargs per micro-batch.
    """

    def __init__(self, micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        super().__init__()
        self.micro_batch_num = micro_batch_num
        self.args_batch_dim = args_batch_dim
        self.kwargs_batch_dim = kwargs_batch_dim

    def construct(self, args, kwargs):
        """Construct of _MicroBatch"""
        args_after_split = []
        kwargs_after_split = []
        for micro_idx in range(self.micro_batch_num):
            micro_args = []
            micro_kwargs = {}
            for arg_idx, cur_arg in enumerate(args):
                # Batch dim defaults to 0 unless a spec entry overrides it.
                cur_arg_batch_dim = 0
                if self.args_batch_dim and self.args_batch_dim[arg_idx] is not None:
                    cur_arg_batch_dim = self.args_batch_dim[arg_idx].batch_dim
                if isinstance(cur_arg, hyper_parallel.DTensor):
                    micro_arg = self.split_inputs_with_custom_shard(cur_arg, cur_arg_batch_dim, micro_idx)
                else:
                    micro_arg = self.split_inputs(cur_arg, cur_arg_batch_dim, micro_idx)
                micro_args.append(micro_arg)
            args_after_split.append(micro_args)

            for key, cur_kwarg in kwargs.items():
                # Mirror the positional-arg handling: an empty spec, a key that
                # is not listed, or a ``None`` entry all mean "batch dim 0"
                # instead of raising KeyError / AttributeError.
                cur_kwarg_batch_dim = 0
                if self.kwargs_batch_dim and self.kwargs_batch_dim.get(key) is not None:
                    cur_kwarg_batch_dim = self.kwargs_batch_dim[key].batch_dim
                if isinstance(cur_kwarg, hyper_parallel.DTensor):
                    micro_kwarg = self.split_inputs_with_custom_shard(cur_kwarg, cur_kwarg_batch_dim, micro_idx)
                else:
                    micro_kwarg = self.split_inputs(cur_kwarg, cur_kwarg_batch_dim, micro_idx)
                micro_kwargs[key] = micro_kwarg
            kwargs_after_split.append(micro_kwargs)
        return args_after_split, kwargs_after_split

    def split_inputs_with_custom_shard(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split a DTensor input by calling ``split_inputs`` through ``custom_shard``.

        The wrapper is built with the input's own layout: its mesh is used as the
        device mesh and its placements are used for both the tensor input and the
        output. The batch dim and micro index arguments get ``None`` placements.

        Args:
            input_tensor (DTensor): The input to be split.
            cur_arg_batch_dim (int): The batch dim, forwarded to ``split_inputs``.
            micro_idx (int): The micro-batch index, forwarded to ``split_inputs``.

        Returns:
            The result of the wrapped ``split_inputs`` call.

        Raises:
            TypeError: If `input_tensor` is not a DTensor.
        """
        if not isinstance(input_tensor, hyper_parallel.DTensor):
            raise TypeError(f"Input type {type(input_tensor)} is not DTensor.")
        input_layout = input_tensor.layout
        func_wrap = custom_shard(self.split_inputs,
                                 device_mesh=input_layout.mesh,
                                 out_placements=(input_layout.placements,),
                                 in_placements=(input_layout.placements, None, None)
                                 )
        return func_wrap(input_tensor, cur_arg_batch_dim, micro_idx)

    def split_inputs(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split the input along the specified batch_dim and micro_idx

        Args:
            input_tensor: The input to slice; must expose ``shape`` and ``ndim``.
            cur_arg_batch_dim (int): The dim to slice along. ``-1`` means the
                input is not split and is returned unchanged.
            micro_idx (int): Index of the micro-batch to extract.

        Returns:
            The slice of `input_tensor` belonging to micro-batch `micro_idx`.
        """
        if cur_arg_batch_dim == -1:
            return input_tensor
        batch_dim_shape = input_tensor.shape[cur_arg_batch_dim]
        # Floor division: when the batch dim size is not a multiple of
        # micro_batch_num, the trailing remainder belongs to no micro-batch.
        micro_batch_size = batch_dim_shape // self.micro_batch_num
        micro_batch_begin = micro_batch_size * micro_idx
        micro_batch_end = micro_batch_size * (micro_idx + 1)
        # Full range on every dim, then narrow only the batch dim.
        strided_slice_begin = [0] * input_tensor.ndim
        strided_slice_strides = [1] * input_tensor.ndim
        strided_slice_end = list(input_tensor.shape)
        strided_slice_begin[cur_arg_batch_dim] = micro_batch_begin
        strided_slice_end[cur_arg_batch_dim] = micro_batch_end
        micro_input = ops.strided_slice(input_tensor, strided_slice_begin, strided_slice_end, strided_slice_strides)
        return micro_input

104 

105 

def send_object_list(obj, dst=0, group=None):
    """
    Send the input Python object to dst rank.

    The object is converted with ``_object_to_tensor``; its size is sent first
    as a one-element int32 tensor, followed by the payload tensor.

    NOTE(review): ``recv_object_list`` in this module indexes ``[0]`` into the
    deserialized object, so `obj` is presumably expected to be a list whose
    first element is the real payload - confirm against callers.

    Args:
        obj (Any): The Python object to be sent.
        dst (int, optional): Specifies the global rank that send the Python object to.
            Default: ``0``.
        group (str, optional): Communication group. Default: ``None``, which
            selects ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        TypeError: If `group` is not a string or `dst` is not an int.
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        # Implicit concatenation instead of a backslash continuation, which
        # would embed the next line's indentation in the message.
        raise TypeError("For 'send_object', the argument 'group' must be type of string, "
                        f"but got 'group' type : {type(group)}.")
    if not isinstance(dst, int):
        raise TypeError("For send_object, the dst must be int.")
    obj_tensor, tensor_size = _object_to_tensor(obj)
    # The receiver needs the payload length before it can allocate a buffer.
    obj_size = Tensor([tensor_size], dtype=mstype.int32)
    send(obj_size, dst, group)
    send(obj_tensor, dst, group)

127 

128 

def recv_object_list(recv_obj, src=0, group=None):
    """
    receive Python object from src rank.

    A one-element int32 size tensor is received first, then an int8 payload of
    that size. The payload is unpickled and element ``[0]`` of the result
    replaces the contents of `recv_obj`.

    Args:
        recv_obj (list): list to recv python objects. It is modified in place:
            cleared, then given exactly one element.
        src (int, optional): Specifies the global rank that receive the Python object.
            Default: ``0`` .
        group (str, optional): Communication group. Default: ``None``, which
            selects ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        TypeError: If `group` is not a string or `src` is not an int.
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        # Implicit concatenation instead of a backslash continuation, which
        # would embed the next line's indentation in the message.
        raise TypeError("For 'recv_object', the argument 'group' must be type of string, "
                        f"but got 'group' type : {type(group)}.")
    if not isinstance(src, int):
        raise TypeError("For recv_object, the src must be int.")
    # Receive the payload length first so a buffer of that size can be allocated.
    obj_size = Tensor([0], dtype=mstype.int32)
    recv(obj_size, src, group)
    size_val = obj_size.item()
    obj_tensor = mint.empty([size_val], dtype=mstype.int8)
    recv(obj_tensor, src, group)
    buf = obj_tensor.asnumpy().tobytes()[:size_val]
    # NOTE(review): unpickling executes arbitrary code embedded in the payload;
    # this is only safe if every rank in `group` is trusted - confirm.
    received = pickle.Unpickler(io.BytesIO(buf)).load()
    recv_obj.clear()
    recv_obj.append(received[0])