Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/pipeline_parallel/_utils.py: 18%
84 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# http://www.apache.org/licenses/LICENSE-2.0
2#
3# Unless required by applicable law or agreed to in writing, software
4# distributed under the License is distributed on an "AS IS" BASIS,
5# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6# See the License for the specific language governing permissions and
7# limitations under the License.
8# ============================================================================
9"""pipeline parallel utils"""
10import io
11import pickle
13from mindspore import nn, Tensor, mint, ops
14from mindspore.common import dtype as mstype
15from mindspore.communication import GlobalComm
16from mindspore.mint.distributed.distributed import _object_to_tensor, send, recv
18import hyper_parallel
19from hyper_parallel.core.shard.custom_shard import custom_shard
class _MicroBatch(nn.Cell):
    """
    Split inputs into micro_batch in pipeline parallel.

    Every positional and keyword input is sliced along its batch dimension into
    ``micro_batch_num`` equally sized pieces. The piece size is computed with
    floor division, so when the batch dimension is not divisible by
    ``micro_batch_num`` the trailing remainder is not part of any micro batch.

    Args:
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Each entry is either ``None`` (batch dim ``0`` is used) or an object
            exposing a ``batch_dim`` attribute. Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Values are either ``None`` (batch dim ``0`` is used) or an object
            exposing a ``batch_dim`` attribute; keys that are absent also fall
            back to batch dim ``0``. Default ``None``.

    Note:
        A batch dim of ``-1`` means "do not split": the input is passed to every
        micro batch unchanged.

    Inputs:
        - **args** (list) - Input args.
        - **kwargs** (dict) - Input kwargs.

    Outputs:
        - **args_after_split** (list) - Input args after split into micro_batches.
        - **kwargs_after_split** (list) - Input kwargs after split into micro_batches.
    """

    def __init__(self, micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        super().__init__()
        self.micro_batch_num = micro_batch_num
        self.args_batch_dim = args_batch_dim
        self.kwargs_batch_dim = kwargs_batch_dim

    def construct(self, args, kwargs):
        """
        Split ``args`` and ``kwargs`` into ``micro_batch_num`` micro batches.

        Returns:
            tuple: ``(args_after_split, kwargs_after_split)``. Both are lists of
            length ``micro_batch_num``; element ``i`` holds the list of
            positional inputs, respectively the dict of keyword inputs, of
            micro batch ``i``.
        """
        args_after_split = []
        kwargs_after_split = []
        for micro_idx in range(self.micro_batch_num):
            micro_args = []
            for arg_idx, cur_arg in enumerate(args):
                cur_arg_batch_dim = 0
                if self.args_batch_dim and self.args_batch_dim[arg_idx] is not None:
                    cur_arg_batch_dim = self.args_batch_dim[arg_idx].batch_dim
                micro_args.append(self._split_one(cur_arg, cur_arg_batch_dim, micro_idx))
            args_after_split.append(micro_args)

            micro_kwargs = {}
            for key, cur_kwarg in kwargs.items():
                # Mirror the positional path: an empty/None spec container, a
                # missing key or a None entry all fall back to batch dim 0
                # instead of raising KeyError/AttributeError.
                kwarg_spec = None
                if self.kwargs_batch_dim:
                    kwarg_spec = self.kwargs_batch_dim.get(key)
                cur_kwarg_batch_dim = 0 if kwarg_spec is None else kwarg_spec.batch_dim
                micro_kwargs[key] = self._split_one(cur_kwarg, cur_kwarg_batch_dim, micro_idx)
            kwargs_after_split.append(micro_kwargs)
        return args_after_split, kwargs_after_split

    def _split_one(self, value, batch_dim, micro_idx):
        """Slice one input, routing DTensor inputs through the custom-shard wrapper."""
        if isinstance(value, hyper_parallel.DTensor):
            return self.split_inputs_with_custom_shard(value, batch_dim, micro_idx)
        return self.split_inputs(value, batch_dim, micro_idx)

    def split_inputs_with_custom_shard(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split a DTensor by running ``split_inputs`` through ``custom_shard``.

        The wrapper is built with the input's own device mesh and uses the
        input's placements for both the tensor argument and the output.

        Raises:
            TypeError: If ``input_tensor`` is not a ``hyper_parallel.DTensor``.
        """
        if not isinstance(input_tensor, hyper_parallel.DTensor):
            raise TypeError(f"Input type {type(input_tensor)} is not DTensor.")
        input_layout = input_tensor.layout
        # NOTE(review): the two ``None`` in_placements line up with the plain
        # ``cur_arg_batch_dim`` / ``micro_idx`` arguments; presumably ``None``
        # marks a non-tensor argument - confirm against custom_shard's contract.
        func_wrap = custom_shard(self.split_inputs,
                                 device_mesh=input_layout.mesh,
                                 out_placements=(input_layout.placements,),
                                 in_placements=(input_layout.placements, None, None)
                                 )
        return func_wrap(input_tensor, cur_arg_batch_dim, micro_idx)

    def split_inputs(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split the input along the specified batch_dim and micro_idx

        Returns ``input_tensor`` unchanged when ``cur_arg_batch_dim`` is ``-1``;
        otherwise returns the slice covering
        ``[micro_idx * size, (micro_idx + 1) * size)`` along that dimension,
        where ``size = shape[cur_arg_batch_dim] // micro_batch_num``.
        """
        # -1 is reserved as the "do not split" marker, so it cannot be used to
        # address the last dimension here.
        if cur_arg_batch_dim == -1:
            return input_tensor
        # Floor division: any remainder at the end of the batch dim is dropped.
        micro_batch_size = input_tensor.shape[cur_arg_batch_dim] // self.micro_batch_num
        strided_slice_begin = [0] * input_tensor.ndim
        strided_slice_strides = [1] * input_tensor.ndim
        strided_slice_end = list(input_tensor.shape)
        # Only the batch dimension is narrowed; all other dims keep full extent.
        strided_slice_begin[cur_arg_batch_dim] = micro_batch_size * micro_idx
        strided_slice_end[cur_arg_batch_dim] = micro_batch_size * (micro_idx + 1)
        return ops.strided_slice(input_tensor, strided_slice_begin, strided_slice_end, strided_slice_strides)
def send_object_list(obj, dst=0, group=None):
    """
    Send the input Python object to dst rank.

    The object is converted with ``_object_to_tensor`` and transferred as two
    messages: first the payload size as a one-element int32 tensor, then the
    payload tensor itself. ``recv_object_list`` in this module reads the two
    messages in the same order and keeps element ``[0]`` of the decoded
    payload, so ``obj`` is expected to be a list (or other indexable object).

    Args:
        obj (Any): The Python object to be sent.
        dst (int, optional): Specifies the global rank that send the Python object to.
            Default: ``0``.
        group (str, optional): Communication group. Default: ``None``, which
            selects ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        TypeError: If `group` is not a string or `dst` is not an int.
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        # Adjacent literals instead of a backslash continuation inside the
        # string, which leaked the source indentation into the message.
        raise TypeError(
            "For 'send_object', the argument 'group' must be type of string, "
            f"but got 'group' type : {type(group)}."
        )
    if not isinstance(dst, int):
        raise TypeError("For send_object, the dst must be int.")
    obj_tensor, tensor_size = _object_to_tensor(obj)
    # The size goes first so the receiver can allocate a buffer of the right length.
    obj_size = Tensor([tensor_size], dtype=mstype.int32)
    send(obj_size, dst, group)
    send(obj_tensor, dst, group)
def recv_object_list(recv_obj, src=0, group=None):
    """
    receive Python object from src rank.

    Two messages are read from ``src``: a one-element int32 tensor holding the
    payload size, then an int8 tensor of that size holding the pickled
    payload. On success ``recv_obj`` is emptied and element ``[0]`` of the
    unpickled payload is appended to it, so ``recv_obj`` ends up with exactly
    one item. If receiving or unpickling fails, ``recv_obj`` is left untouched.

    Args:
        recv_obj (list): list to recv python objects.
        src (int, optional): Specifies the global rank that receive the Python object.
            Default: ``0`` .
        group (str, optional): Communication group. Default: ``None``, which
            selects ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        TypeError: If `group` is not a string or `src` is not an int.
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        # Adjacent literals instead of a backslash continuation inside the
        # string, which leaked the source indentation into the message.
        raise TypeError(
            "For 'recv_object', the argument 'group' must be type of string, "
            f"but got 'group' type : {type(group)}."
        )
    if not isinstance(src, int):
        raise TypeError("For recv_object, the src must be int.")
    # First message: payload length, needed to size the receive buffer.
    obj_size = Tensor([0], dtype=mstype.int32)
    recv(obj_size, src, group)
    size_val = obj_size.item()
    # Second message: the serialized payload bytes.
    obj_tensor = mint.empty([size_val], dtype=mstype.int8)
    recv(obj_tensor, src, group)
    buf = obj_tensor.asnumpy().tobytes()[:size_val]
    # SECURITY NOTE: unpickling can execute arbitrary code contained in the
    # payload; this must only be used between mutually trusted ranks.
    payload = pickle.Unpickler(io.BytesIO(buf)).load()
    # Decode before mutating the caller's list so a failure above does not
    # discard its previous contents.
    received = payload[0]
    recv_obj.clear()
    recv_obj.append(received)