Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / pipeline_parallel / _utils.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""pipeline parallel utils""" 

16from torch import nn 

17 

18import hyper_parallel 

19 

20 

21class _MicroBatch(nn.Module): 

22 """ 

23 Split inputs into micro_batch in pipeline parallel. 

24 

25 Args: 

26 micro_batch_num (int): The number of micro-batch. 

27 args_batch_dim (list, optional): Specify the batch dim of the args. 

28 Default ``None``. 

29 kwargs_batch_dim(dict, optional): Specify the batch dim of the kwargs. 

30 Default ``None``. 

31 Inputs: 

32 - **args** (list) - Input args. 

33 - **kwargs** (dict) - Input kwargs. 

34 

35 Outputs: 

36 - **args_after_split** (list) - Input args after split into micro_batches. 

37 - **kwargs_after_split** (list) - Input kwargs after split into micro_batches. 

38 """ 

39 

40 def __init__(self, micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None): 

41 super().__init__() 

42 self.micro_batch_num = micro_batch_num 

43 self.args_batch_dim = args_batch_dim 

44 self.kwargs_batch_dim = kwargs_batch_dim 

45 

46 def forward(self, args, kwargs): 

47 """forward of _MicroBatch""" 

48 args_after_split = [] 

49 kwargs_after_split = [] 

50 for micro_idx in range(self.micro_batch_num): 

51 micro_args = [] 

52 micro_kwargs = {} 

53 for arg_idx, cur_arg in enumerate(args): 

54 cur_arg_batch_dim = 0 

55 if self.args_batch_dim and self.args_batch_dim[arg_idx] is not None: 

56 cur_arg_batch_dim = self.args_batch_dim[arg_idx].batch_dim 

57 if isinstance(cur_arg, hyper_parallel.DTensor): 

58 micro_arg = self.split_inputs_with_custom_shard(cur_arg, cur_arg_batch_dim, micro_idx) 

59 else: 

60 micro_arg = self.split_inputs(cur_arg, cur_arg_batch_dim, micro_idx) 

61 micro_args.append(micro_arg) 

62 args_after_split.append(micro_args) 

63 

64 for key, cur_kwarg in kwargs.items(): 

65 cur_kwarg_batch_dim = 0 

66 if self.kwargs_batch_dim is not None: 

67 cur_kwarg_batch_dim = self.kwargs_batch_dim[key].batch_dim 

68 if isinstance(cur_kwarg, hyper_parallel.DTensor): 

69 micro_kwarg = self.split_inputs_with_custom_shard(cur_kwarg, cur_kwarg_batch_dim, micro_idx) 

70 else: 

71 micro_kwarg = self.split_inputs(cur_kwarg, cur_kwarg_batch_dim, micro_idx) 

72 micro_kwargs[key] = micro_kwarg 

73 kwargs_after_split.append(micro_kwargs) 

74 return args_after_split, kwargs_after_split 

75 

76 def split_inputs_with_custom_shard(self, input_tensor, cur_arg_batch_dim, micro_idx): 

77 input_layout = input_tensor.layout 

78 func_wrap = hyper_parallel.custom_shard(self.split_inputs, 

79 device_mesh=input_layout.mesh, 

80 out_placements=(input_layout.placements,), 

81 in_placements=(input_layout.placements, None, None) 

82 ) 

83 return func_wrap(input_tensor, cur_arg_batch_dim, micro_idx) 

84 

85 def split_inputs(self, input_tensor, cur_arg_batch_dim, micro_idx): 

86 """ 

87 Split the input along the specified batch_dim and micro_idx 

88 """ 

89 if cur_arg_batch_dim == -1: 

90 return input_tensor 

91 batch_dim_shape = input_tensor.shape[cur_arg_batch_dim] 

92 if batch_dim_shape % self.micro_batch_num != 0: 

93 raise ValueError(f"Batch dimension size {batch_dim_shape} is not divisible by \ 

94 micro_batch_num {self.micro_batch_num}") 

95 micro_batch_size = batch_dim_shape // self.micro_batch_num 

96 

97 # Calculate start and end idx 

98 start = micro_batch_size * micro_idx 

99 end = micro_batch_size * (micro_idx + 1) 

100 

101 # Create slicing tuple 

102 slices = [slice(None)] * input_tensor.ndim 

103 slices[cur_arg_batch_dim] = slice(start, end) 

104 return input_tensor[slices]