Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / pipeline_parallel / _utils.py: 0%
49 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""pipeline parallel utils"""
16from torch import nn
18import hyper_parallel
class _MicroBatch(nn.Module):
    """
    Split inputs into micro_batch in pipeline parallel.

    Args:
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Each entry is either ``None`` (split along dim 0) or an object
            exposing a ``batch_dim`` attribute. Default ``None``.
        kwargs_batch_dim(dict, optional): Specify the batch dim of the kwargs.
            Each value is either ``None`` (split along dim 0) or an object
            exposing a ``batch_dim`` attribute. Keys that are absent are
            split along dim 0 as well. Default ``None``.

    Inputs:
        - **args** (list) - Input args.
        - **kwargs** (dict) - Input kwargs.

    Outputs:
        - **args_after_split** (list) - Input args after split into micro_batches.
        - **kwargs_after_split** (list) - Input kwargs after split into micro_batches.
    """

    def __init__(self, micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        super().__init__()
        self.micro_batch_num = micro_batch_num
        self.args_batch_dim = args_batch_dim
        self.kwargs_batch_dim = kwargs_batch_dim

    def forward(self, args, kwargs):
        """
        Split every positional and keyword input into ``micro_batch_num`` slices.

        Returns:
            tuple: ``(args_after_split, kwargs_after_split)``. Both are lists
            with one entry per micro batch: a list of sliced args and a dict
            of sliced kwargs respectively.
        """
        args_after_split = []
        kwargs_after_split = []
        for micro_idx in range(self.micro_batch_num):
            micro_args = []
            for arg_idx, cur_arg in enumerate(args):
                spec = self.args_batch_dim[arg_idx] if self.args_batch_dim else None
                micro_args.append(
                    self._split_one(cur_arg, self._batch_dim_of(spec), micro_idx)
                )
            args_after_split.append(micro_args)

            micro_kwargs = {}
            for key, cur_kwarg in kwargs.items():
                # Use .get() so a partially specified dict behaves like a None
                # entry in args_batch_dim instead of raising KeyError.
                spec = None
                if self.kwargs_batch_dim is not None:
                    spec = self.kwargs_batch_dim.get(key)
                micro_kwargs[key] = self._split_one(
                    cur_kwarg, self._batch_dim_of(spec), micro_idx
                )
            kwargs_after_split.append(micro_kwargs)
        return args_after_split, kwargs_after_split

    @staticmethod
    def _batch_dim_of(spec):
        """Return ``spec.batch_dim``, or 0 when no spec is given."""
        return 0 if spec is None else spec.batch_dim

    def _split_one(self, value, batch_dim, micro_idx):
        """Slice one input, routing ``hyper_parallel.DTensor`` through custom_shard."""
        if isinstance(value, hyper_parallel.DTensor):
            return self.split_inputs_with_custom_shard(value, batch_dim, micro_idx)
        return self.split_inputs(value, batch_dim, micro_idx)

    def split_inputs_with_custom_shard(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split a DTensor by running ``split_inputs`` through ``hyper_parallel.custom_shard``.

        The wrapper is built with the input's own layout: its mesh is used as
        the device mesh and its placements are used for both the tensor input
        and the output. The two scalar arguments are given ``None`` placements.
        """
        input_layout = input_tensor.layout
        func_wrap = hyper_parallel.custom_shard(
            self.split_inputs,
            device_mesh=input_layout.mesh,
            out_placements=(input_layout.placements,),
            in_placements=(input_layout.placements, None, None),
        )
        return func_wrap(input_tensor, cur_arg_batch_dim, micro_idx)

    def split_inputs(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split the input along the specified batch_dim and micro_idx

        Args:
            input_tensor: Object exposing ``shape``, ``ndim`` and slice indexing.
            cur_arg_batch_dim (int): Dimension to split. ``-1`` means the input
                is returned unchanged.
            micro_idx (int): Index of the micro batch to extract.

        Returns:
            The ``micro_idx``-th of ``micro_batch_num`` equal slices along
            ``cur_arg_batch_dim``, or ``input_tensor`` itself when
            ``cur_arg_batch_dim`` is ``-1``.

        Raises:
            ValueError: If the size of the batch dimension is not divisible
                by ``micro_batch_num``.
        """
        if cur_arg_batch_dim == -1:
            return input_tensor
        batch_dim_shape = input_tensor.shape[cur_arg_batch_dim]
        if batch_dim_shape % self.micro_batch_num != 0:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            raise ValueError(
                f"Batch dimension size {batch_dim_shape} is not divisible by "
                f"micro_batch_num {self.micro_batch_num}"
            )
        micro_batch_size = batch_dim_shape // self.micro_batch_num

        # Calculate start and end idx
        start = micro_batch_size * micro_idx
        end = micro_batch_size * (micro_idx + 1)

        # Create slicing tuple. It is built as a list so that list assignment
        # resolves a negative batch dim, then converted because indexing with
        # a list of slices is not portable across array libraries.
        slices = [slice(None)] * input_tensor.ndim
        slices[cur_arg_batch_dim] = slice(start, end)
        return input_tensor[tuple(slices)]