Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / pipeline_parallel / scheduler.py: 14%

848 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""pipeline schedule""" 

16from abc import ABC, abstractmethod 

17from enum import Enum, auto 

18from collections import defaultdict 

19import itertools 

20import bisect 

21import logging 

22import re 

23import hyper_parallel 

24from hyper_parallel.platform import get_platform 

25from hyper_parallel.core.fully_shard.api import HSDPModule 

# Platform object returned by ``get_platform()``. Within this module it is
# used for ``platform.micro_batch(...)`` and ``platform.get_cells_and_names(...)``.
platform = get_platform()
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)

28 

29 

class MetaStepType(Enum):
    """Specify the enumeration type for MetaStep."""
    # Member order matters: ``auto()`` assigns increasing values in
    # declaration order, so reordering members changes their values.

    # Compute steps.
    FWD = auto()
    BWD = auto()
    # NOTE(review): presumably the input-gradient / weight-gradient halves of
    # a split backward pass -- confirm against the schedules that emit them.
    BWD_INPUT = auto()
    BWD_WEIGHT = auto()
    # Point-to-point communication steps for forward and backward data.
    FWD_RECV = auto()
    FWD_SEND = auto()
    BWD_RECV = auto()
    BWD_SEND = auto()
    # Composite steps that carry two sub-steps: ``(fwd, bwd)`` for
    # OVERLAP_F_B and ``(bwd, fwd)`` for OVERLAP_B_F.
    OVERLAP_F_B = auto()
    OVERLAP_B_F = auto()
    # FSDP control steps (unshard / reshard parameters, reduce gradients).
    FSDP_UNSHARD = auto()
    FSDP_RESHARD = auto()
    FSDP_REDUCE_GRAD = auto()

45 

46 

class MetaStep:
    """
    A single entry of a pipeline execution list.

    A list of ``MetaStep`` objects describes, in order, what one rank has to
    do; such a list can be fed into the PipelineSchedule for execution.

    Args:
        micro_index (int | None): Micro-batch index of this step. ``None``
            for composite steps (``OVERLAP_F_B`` / ``OVERLAP_B_F``), where
            each entry of ``sub_steps`` carries its own micro index.
        meta_type (MetaStepType): Kind of step; exposed through the ``type``
            property.
        stage_index (int | None): Stage this step belongs to. ``None`` for
            composite steps; look at ``sub_steps`` for each direction's stage.
        sub_steps (tuple[MetaStep, MetaStep] | None): Composite steps only:
            ``(fwd, bwd)`` for ``OVERLAP_F_B`` and ``(bwd, fwd)`` for
            ``OVERLAP_B_F``.
    """

    def __init__(self, micro_index, meta_type, stage_index, sub_steps=None):
        self._micro_index = micro_index
        self._stage_index = stage_index
        self._type = meta_type
        self._sub_steps = sub_steps

    @property
    def micro_index(self):
        """Micro-batch index given at construction."""
        return self._micro_index

    @property
    def stage_index(self):
        """Stage index given at construction."""
        return self._stage_index

    @property
    def type(self):
        """Step type given at construction (``meta_type``)."""
        return self._type

    @property
    def sub_steps(self):
        """Pair of sub-steps of a composite step, or ``None``.

        ``(fwd, bwd)`` for OVERLAP_F_B and ``(bwd, fwd)`` for OVERLAP_B_F.
        """
        return self._sub_steps

    def __eq__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        # Compare field by field and stop at the first mismatch.
        for field in ("type", "micro_index", "stage_index", "sub_steps"):
            if not getattr(self, field) == getattr(value, field):
                return False
        return True

    def __ne__(self, value):
        if isinstance(value, MetaStep):
            return not self.__eq__(value)
        return NotImplemented

    def __hash__(self):
        # ``sub_steps`` is deliberately left out of the hash; equal steps
        # still hash equally because the hashed fields are part of __eq__.
        identity = (self.type, self.micro_index, self.stage_index)
        return hash(identity)

    def __str__(self):
        if not self.sub_steps:
            return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"
        sub = ", ".join(map(str, self.sub_steps))
        return (f"MetaStep(type={self.type}, micro_index={self.micro_index}, "
                f"stage_index={self.stage_index}, sub_steps=[{sub}])")

    def __repr__(self):
        return str(self)

    @staticmethod
    def from_str(step_str):
        """Placeholder: currently performs no parsing and returns ``None``."""
        return None

118 

def generate_stage_to_rank_mapping(real_stage_num, stage_num, style='loop'):
    """Build the ``{stage_index: rank}`` table for a loop or V schedule.

    Args:
        real_stage_num: Number of ranks the stages are spread over.
        stage_num: Total number of stages to place.
        style: ``'loop'`` assigns stages round-robin; ``'v'`` walks the ranks
            up, then down, then up again, repeating the rank at each turn.

    Returns:
        dict mapping every stage index in ``range(stage_num)`` to a rank.

    Raises:
        ValueError: If ``style`` is neither ``'loop'`` nor ``'v'``, or if
            ``style`` is ``'v'`` and ``stage_num`` is not a multiple of
            ``real_stage_num``.
    """
    if style == 'loop':
        table = {}
        for stage_index in range(stage_num):
            table[stage_index] = stage_index % real_stage_num
        return table

    if style != 'v':
        raise ValueError(f"Unsupported stage rank mapping style: {style}")

    if stage_num % real_stage_num != 0:
        raise ValueError(
            f"stage_num {stage_num} must be evenly divisible by real_stage_num {real_stage_num} for V schedules."
        )

    table = {}
    current_rank = 0
    for stage_index in range(stage_num):
        table[stage_index] = current_rank
        # The last stage of each sweep shares its rank with the first stage
        # of the next sweep, so the rank does not move at a turnaround.
        at_turnaround = (stage_index + 1) % real_stage_num == 0
        if not at_turnaround:
            ascending = (stage_index // real_stage_num) % 2 == 0
            current_rank += 1 if ascending else -1
    return table

140 

141 

def generate_rank_to_stage_mapping(real_stage_num, stage_num, style='loop'):
    """Return ``{rank: [stage_index, ...]}``, the inverse of the stage-to-rank table.

    The arguments are forwarded unchanged to
    ``generate_stage_to_rank_mapping``; each rank's stage list is returned in
    ascending order.
    """
    grouped = {}
    stage_to_rank = generate_stage_to_rank_mapping(real_stage_num, stage_num, style)
    for stage_index, rank in stage_to_rank.items():
        grouped.setdefault(rank, []).append(stage_index)
    return {rank: sorted(stage_list) for rank, stage_list in grouped.items()}

151 

def iter_leaf_meta_steps(step):
    """Yield the leaf MetaSteps of ``step``, recursively expanding composites.

    Both composite types, ``OVERLAP_F_B`` and ``OVERLAP_B_F``, are expanded
    into their ``sub_steps`` so that the two are handled alike, as they are
    by the other composite checks in this module. A composite whose
    ``sub_steps`` is ``None`` or empty has no leaves and yields nothing.

    Args:
        step: A ``MetaStep`` or ``None``.

    Yields:
        Every non-composite ``MetaStep`` reachable from ``step``, in
        ``sub_steps`` order. ``None`` yields nothing.
    """
    if step is None:
        return
    if step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F):
        # ``or ()`` guards against a composite built without sub_steps.
        for sub_step in step.sub_steps or ():
            yield from iter_leaf_meta_steps(sub_step)
        return
    yield step

161 

class PipelineContext:
    """Bundle of runtime state handed to custom step handlers.

    An instance is passed as the second argument to functions registered via
    :meth:`PipelineScheduleRuntime.register_custom_function`, so a handler
    (for example an OVERLAP_F_B callback) can reach the schedule, the
    per-micro-batch inputs, the loss list and the P2P handle bookkeeping.

    Attributes:
        schedule: The :class:`PipelineScheduleRuntime` instance.
        arg_mbs: Per-micro-batch positional args.
        kwarg_mbs: Per-micro-batch keyword args.
        losses: Mutable list for loss collection.
        fwd_recv_ops: The schedule's ``fwd_handle_cache``, keyed by
            ``(stage_index, micro_index)``; holds cached forward recv
            handles (when ``overlap_p2p=True``).
        bwd_recv_ops: The schedule's ``bwd_handle_cache``; same for backward
            recv handles.
        send_handles: Mutable list of outstanding send handles.
    """

    def __init__(self, schedule, arg_mbs, kwarg_mbs, losses, send_handles):
        # Inputs and outputs of the current run.
        self.arg_mbs = arg_mbs
        self.kwarg_mbs = kwarg_mbs
        self.losses = losses
        self.send_handles = send_handles
        # The schedule and aliases to its recv-handle caches: the very same
        # dict objects, not copies.
        self.schedule = schedule
        self.fwd_recv_ops = schedule.fwd_handle_cache
        self.bwd_recv_ops = schedule.bwd_handle_cache

190 

191 

class PipelineScheduleRuntime(ABC):
    """
    Base class for pipeline schedules.

    Implements :meth:`split_microbatches`, :meth:`run` and
    :meth:`run_microbatches`. Derived classes must implement
    :meth:`_build_stage_to_rank_index` and :meth:`construct_exec_order`.

    Supports registering **custom execution functions** for any
    :class:`MetaStepType` via :meth:`register_custom_function`. When
    ``run_microbatches`` encounters a step whose type has a registered
    handler, it creates a :class:`PipelineContext` and delegates execution
    to the handler instead of using the built-in logic.

    Args:
        stages (list[PipelineStage], PipelineStage): PipelineStage used to run_microbatches.
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Default ``None``.
        output_concat_dim (optional): Stored as ``_output_concat_dim``; it is
            not read by the code of this class. Default ``None``.
        overlap_p2p (bool): When ``False`` every send/recv handle is waited
            on as soon as it is issued. When ``True`` recv handles are cached
            and waited on right before the matching FWD/BWD step, and send
            handles are collected and waited on at the end of
            :meth:`run_microbatches`. Default ``False``.

    Raises:
        TypeError: If ``stages`` is not a PipelineStage or a list/tuple of
            PipelineStage.
        ValueError: If ``stages`` is an empty list or tuple.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False):
        self.stages = self._check_stages(stages)
        self.micro_batch_num = micro_batch_num
        self._args_batch_dim = args_batch_dim
        self._kwargs_batch_dim = kwargs_batch_dim
        self._output_concat_dim = output_concat_dim
        self.split_micro_batch = platform.micro_batch(self.micro_batch_num,
                                                      self._args_batch_dim, self._kwargs_batch_dim)
        self.n_local_stages = len(self.stages)
        self._stage_dict = self.convert_stages_dict()
        # Number of ranks: total stages divided by the stages held locally.
        self.real_stage_num = self.stages[0].stage_num // self.n_local_stages
        self._stage_num = self.stages[0].stage_num
        # Filled in by the subclass hook ``_build_stage_to_rank_index`` below.
        self._stage_to_rank_index = None
        self._overlap_p2p = overlap_p2p
        # ``{rank: [MetaStep | None, ...]}``; populated by ``build_exec_order``.
        self.exec_order = {}
        self._init_stages()
        self._build_stage_to_rank_index()
        # ``{(stage_index, micro_index): handles}`` for deferred recv waits.
        self.fwd_handle_cache = {}
        self.bwd_handle_cache = {}
        # ``{MetaStepType: callable}`` registered custom handlers.
        self._custom_fn_map = {}

    def register_custom_function(self, step_type: "MetaStepType", fn) -> None:
        """Register a custom execution function for the given step type.

        When :meth:`run_microbatches` encounters a :class:`MetaStep` whose
        ``type`` matches ``step_type``, it calls ``fn(step, ctx)`` instead
        of the built-in logic. Registering again for the same type replaces
        the previous handler.

        Args:
            step_type: The :class:`MetaStepType` to intercept.
            fn: A callable with signature ``(step: MetaStep, ctx: PipelineContext) -> None``.

        Example:
            >>> def my_overlap_callback(step, ctx):
            ...     fwd_step, bwd_step = step.sub_steps
            ...     # custom parallel execution logic
            >>> schedule.register_custom_function(MetaStepType.OVERLAP_F_B, my_overlap_callback)
        """
        self._custom_fn_map[step_type] = fn

    def _inject_local_fsdp_actions(self):
        """Annotate the local rank schedule with optional FSDP control actions.

        Does nothing when no local stage wraps an ``HSDPModule``.

        Raises:
            RuntimeError: If only some of the local stages wrap an
                ``HSDPModule``.
        """
        current_rank = self._stage_to_rank_index[self.stages[0].stage_index]
        managed_stage_indices = {
            stage.stage_index
            for stage in self.stages
            if isinstance(stage.submodule, HSDPModule)
        }
        if not managed_stage_indices:
            return
        if len(managed_stage_indices) != len(self.stages):
            raise RuntimeError(
                "When injecting fsdp_action, expect all stages to be HSDPModule. "
                "Check whether all separated modules are wrapped with 'fully_shard'."
            )
        rank_actions = add_fsdp_unshard_reshard(self.exec_order[current_rank], managed_stage_indices)
        self.exec_order[current_rank] = add_fsdp_reduce_grad(
            rank_actions,
            managed_stage_indices,
            self.micro_batch_num,
        )

    @abstractmethod
    def _build_stage_to_rank_index(self) -> None:
        """
        Build attribute of _stage_to_rank_index.
        Each subclass constructs it according to its own schedule style.
        """

    @abstractmethod
    def construct_exec_order(self) -> None:
        """Build exec order, PP compute and PP comms (Send/Recv)."""

    def build_exec_order(self) -> None:
        """Build the execution order and inject FSDP actions."""
        self.construct_exec_order()
        self._inject_local_fsdp_actions()

    def convert_stages_dict(self):
        """Return ``{stage.stage_index: stage}`` for the local stages."""
        stage_dict = {}
        for stage in self.stages:
            stage_dict[stage.stage_index] = stage
        return stage_dict

    def split_microbatches(self, args, kwargs):
        """Split the inputs into ``micro_batch_num`` micro-batches.

        Returns:
            ``(args_split, kwargs_split)``. When both ``args`` and ``kwargs``
            are empty, returns one empty list and one empty dict per
            micro-batch without calling the platform splitter.
        """
        if args or kwargs:
            args_split, kwargs_split = self.split_micro_batch(args, kwargs)
            return args_split, kwargs_split
        return [[] for _ in range(self.micro_batch_num)], [{} for _ in range(self.micro_batch_num)]

    def _check_stages(self, stages):
        """Validate ``stages`` and normalise a single stage to a one-element list.

        Raises:
            TypeError: If ``stages`` (or one of its elements) is not a
                PipelineStage.
            ValueError: If ``stages`` is an empty list or tuple.
        """
        if isinstance(stages, hyper_parallel.PipelineStage):
            return [stages]
        if isinstance(stages, (list, tuple)):
            if not stages:
                # Fail here with a clear message instead of an IndexError on
                # ``self.stages[0]`` further down in ``__init__``.
                raise ValueError("Argument 'stages' must contain at least one PipelineStage, "
                                 "but got an empty list or tuple.")
            for stage in stages:
                if not isinstance(stage, hyper_parallel.PipelineStage):
                    # Adjacent literals instead of a backslash continuation,
                    # which embedded the source indentation in the message.
                    raise TypeError("Argument 'stages' must be type of PipelineStage, "
                                    f"list or tuple of PipelineStage, but got list or tuple of {type(stage)}.")
            return stages
        raise TypeError("Argument 'stages' must be type of PipelineStage, "
                        f"list or tuple of PipelineStage, but got type of {type(stages)}.")

    def _init_stages(self):
        """Call ``stage.init(n_local_stages)`` on every local stage."""
        for stage in self.stages:
            stage.init(self.n_local_stages)

    def run(self, *args, **kwargs):
        """Split the inputs, execute the schedule and return the collected losses."""
        split_args, split_kwargs = self.split_microbatches(args, kwargs)
        losses = []
        self.run_microbatches(split_args, split_kwargs, losses)
        return losses

    def sync_shared_parameters_grad(self):
        """Call ``sync_shared_parameters_grad`` on every local stage."""
        for stage in self.stages:
            stage.sync_shared_parameters_grad()

    def update_losses(self, stage, loss, losses):
        """Append ``loss`` to ``losses`` when ``stage`` is the last stage."""
        if stage.is_last_stage:
            losses.append(loss)

    def _wait_p2p(self, handles):
        """Wait on every non-``None`` handle in ``handles``."""
        for handle in handles:
            if handle is not None:
                handle.wait()

    def _assert_in_unshard_if_needed(self, stage, check_step):
        """Raise if an ``HSDPModule`` stage is still sharded before compute.

        Raises:
            RuntimeError: If ``stage.submodule`` is an ``HSDPModule`` whose
                scheduler state reports ``is_shard``.
        """
        if not isinstance(stage.submodule, HSDPModule):
            return
        submodule_hsdp_scheduler = stage.submodule.hsdp_scheduler
        scheduler_state = submodule_hsdp_scheduler.hsdp_state
        if scheduler_state.is_shard:
            raise RuntimeError(
                f"Executing MetaStep: {check_step}, expected HSDPModule parameters in unsharded "
                f"state, but got sharded parameters."
            )

    def _exec_step(self, cur_step, arg_mbs, kwarg_mbs, losses, send_handles):
        """Execute a single built-in step (FWD/BWD/SEND/RECV/FSDP).

        Step types without a branch below (for example ``BWD_INPUT`` and
        ``BWD_WEIGHT``) are not executed by this method; they fall through
        without any action.
        """
        stage = self._stage_dict[cur_step.stage_index]
        stage_index = cur_step.stage_index
        micro_index = cur_step.micro_index

        if cur_step.type == MetaStepType.FWD_RECV:
            comm_handle = stage.exec_fwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                # Defer the wait until the matching FWD step.
                self.fwd_handle_cache[(stage_index, micro_index)] = comm_handle

        elif cur_step.type == MetaStepType.FWD:
            self._assert_in_unshard_if_needed(stage, cur_step)
            key = (stage_index, micro_index)
            if self._overlap_p2p and key in self.fwd_handle_cache:
                self._wait_p2p(self.fwd_handle_cache.pop(key))
            out = stage.forward_one_chunk(micro_index, arg_mbs[micro_index], kwarg_mbs[micro_index])
            self.update_losses(stage, out, losses)

        elif cur_step.type == MetaStepType.FWD_SEND:
            comm_handle = stage.exec_fwd_send_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                # Waited on at the end of ``run_microbatches``.
                send_handles.append(comm_handle)

        elif cur_step.type == MetaStepType.BWD_RECV:
            comm_handle = stage.exec_bwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                # Defer the wait until the matching BWD step.
                self.bwd_handle_cache[(stage_index, micro_index)] = comm_handle

        elif cur_step.type == MetaStepType.BWD:
            self._assert_in_unshard_if_needed(stage, cur_step)
            key = (stage_index, micro_index)
            if self._overlap_p2p and key in self.bwd_handle_cache:
                self._wait_p2p(self.bwd_handle_cache.pop(key))
            is_last_microbatch = micro_index == self.micro_batch_num - 1
            stage.backward_one_chunk(micro_index, is_last_microbatch)

        elif cur_step.type == MetaStepType.BWD_SEND:
            comm_handle = stage.exec_bwd_send_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                send_handles.append(comm_handle)

        elif cur_step.type in (
                MetaStepType.FSDP_UNSHARD,
                MetaStepType.FSDP_RESHARD,
                MetaStepType.FSDP_REDUCE_GRAD,
        ):
            self._exec_fsdp_step(cur_step, stage)

    def _exec_fsdp_step(self, cur_step, stage):
        """Execute an FSDP control step (unshard, reshard, or reduce-grad)."""
        if cur_step.type == MetaStepType.FSDP_UNSHARD:
            for _, module in platform.get_cells_and_names(stage.submodule):
                if isinstance(module, HSDPModule):
                    module.unshard()
        elif cur_step.type == MetaStepType.FSDP_RESHARD:
            for _, module in platform.get_cells_and_names(stage.submodule):
                if isinstance(module, HSDPModule):
                    module.reshard()
        elif cur_step.type == MetaStepType.FSDP_REDUCE_GRAD:
            stage.execute_reduce_grad()

    def run_microbatches(self, arg_mbs, kwarg_mbs, losses):
        """Execute the schedule step by step.

        Steps whose :attr:`MetaStep.type` has a registered custom function
        are delegated to that function with a :class:`PipelineContext`.
        Composite ``OVERLAP_F_B`` / ``OVERLAP_B_F`` steps without a
        registered handler fall back to executing their ``sub_steps``
        sequentially via :meth:`_exec_step` -- correct but without
        comm/compute overlap. All other steps are executed by
        :meth:`_exec_step`.

        After the last step, shared-parameter gradients are synchronised and
        every outstanding send handle is waited on.
        """
        real_stage_index = self.stages[0].stage_index % self.real_stage_num
        send_handles = []
        ctx = None  # lazily created

        for cur_step in self.exec_order[real_stage_index]:
            # ``None`` entries are bubble slots; nothing to execute.
            if cur_step is None:
                continue

            # Check for registered custom function
            custom_fn = self._custom_fn_map.get(cur_step.type)
            if custom_fn is not None:
                if ctx is None:
                    ctx = PipelineContext(self, arg_mbs, kwarg_mbs, losses, send_handles)
                custom_fn(cur_step, ctx)
                continue

            # Default for composite OVERLAP steps: run sub_steps sequentially.
            # P2P send/recv around these steps are already laid out in two
            # virtual slots by ``add_send_recv``, so sequential execution is
            # semantically equivalent to non-overlapped 1F1B.
            if (cur_step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
                    and cur_step.sub_steps):
                for sub in cur_step.sub_steps:
                    self._exec_step(sub, arg_mbs, kwarg_mbs, losses, send_handles)
                continue

            self._exec_step(cur_step, arg_mbs, kwarg_mbs, losses, send_handles)

        self.sync_shared_parameters_grad()
        while send_handles:
            self._wait_p2p(send_handles.pop())

473 

474 

class _OverlapPhantom:
    """Placeholder for one half of an overlap step during comm insertion.

    :func:`add_send_recv` replaces each ``OVERLAP_F_B`` / ``OVERLAP_B_F``
    step by two phantoms, one per sub-step, so that the step spans **two**
    logical time slots of the column scan. The two sub-steps run
    concurrently on the device, but the second sub-step's output can only
    be sent after the first has completed; counting the overlap step as a
    single slot would place the RECV triggered by the second sub-step too
    early on the receiver.

    Attributes:
        obf_step: The original composite step.
        sub_step: The sub-step this phantom stands for.
        is_first_half: ``True`` for the first sub-step's slot -- this is
            where the original composite step is emitted into the output
            schedule (exactly once). ``False`` for the second sub-step's
            slot, which only contributes its send/recv comms.
    """

    __slots__ = ('obf_step', 'sub_step', 'is_first_half')

    def __init__(self, obf_step, sub_step, is_first_half: bool):
        self.is_first_half = is_first_half
        self.sub_step = sub_step
        self.obf_step = obf_step

500 

501 

def _expand_overlap_slots(scheduler, real_stage_num):
    """Give every OVERLAP step two virtual time slots in each rank's schedule.

    Args:
        scheduler: ``{rank: [MetaStep | None, ...]}`` per-rank schedule.
        real_stage_num: Number of ranks; ranks ``0 .. real_stage_num - 1``
            are read from ``scheduler``.

    Returns:
        A new ``{rank: [MetaStep | _OverlapPhantom | None, ...]}`` dict in
        which each OVERLAP step that has sub-steps is replaced by a
        first-half / second-half pair of phantoms. All other entries,
        including ``None``, are copied over unchanged.
    """

    def _slots_for(op):
        # Entries that are not composite (or carry no sub-steps) keep
        # exactly one slot.
        is_overlap = (
            op is not None
            and op.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
            and op.sub_steps
        )
        if not is_overlap:
            return [op]
        leading, trailing = op.sub_steps[0], op.sub_steps[1]
        return [
            _OverlapPhantom(op, leading, is_first_half=True),
            _OverlapPhantom(op, trailing, is_first_half=False),
        ]

    return {
        rank: [slot for op in scheduler[rank] for slot in _slots_for(op)]
        for rank in range(real_stage_num)
    }

523 

524 

def _process_rank_items(real_stage_num, current_items, insert_step_comms, new_schedule):
    """Call ``insert_step_comms`` for each rank's current item, even ranks first.

    All even ranks are handled before any odd rank; this ordering avoids
    P2P deadlocks between adjacent ranks. Ranks whose item is missing or
    ``None`` are skipped, and a phantom is unwrapped to its ``sub_step``
    before the call.
    """
    for parity in (0, 1):
        for rank in range(parity, real_stage_num, 2):
            entry = current_items.get(rank)
            if entry is None:
                continue
            target = entry.sub_step if isinstance(entry, _OverlapPhantom) else entry
            insert_step_comms(target, rank, new_schedule)

540 

541 

def _column_scan_insert_comms(expanded, real_stage_num, insert_step_comms):
    """Walk an OVERLAP-expanded schedule slot by slot and insert SEND/RECV.

    For every time slot, each rank's entry is first copied into the output
    schedule -- an overlap step exactly once, at its first-half phantom --
    and only then are the comms for that slot inserted. Comm insertion is
    delegated to ``insert_step_comms`` for each plain step or phantom's
    underlying sub-step.

    Even ranks are processed before odd ranks at each time step to avoid
    P2P deadlocks between adjacent ranks.

    Args:
        expanded: Result of :func:`_expand_overlap_slots`.
        real_stage_num: Number of physical ranks.
        insert_step_comms: Callable ``(step, rank, new_schedule) -> None``
            that inserts SEND/RECV for a single FWD/BWD step.

    Returns:
        ``{rank: [MetaStep, ...]}`` final schedule.
    """
    ranks = range(real_stage_num)
    slot_count = max(len(timeline) for timeline in expanded.values())
    new_schedule = {rank: [] for rank in ranks}

    for slot in range(slot_count):
        column = {}
        for rank in ranks:
            timeline = expanded[rank]
            if slot >= len(timeline):
                # This rank's timeline has already ended: nothing to emit.
                column[rank] = None
                continue

            entry = timeline[slot]
            column[rank] = entry
            if isinstance(entry, _OverlapPhantom):
                # The overlap step itself goes out once, with its first half;
                # the second half contributes comms only.
                if entry.is_first_half:
                    new_schedule[rank].append(entry.obf_step)
            else:
                # Plain steps pass through. ``None`` bubbles are kept as well
                # so per-rank indexing stays aligned with the column scan;
                # the runtime loop skips ``None`` entries, so this does not
                # change what gets executed.
                new_schedule[rank].append(entry)

        _process_rank_items(real_stage_num, column, insert_step_comms, new_schedule)

    return new_schedule

592 

593 

def add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """Insert P2P send/recv operations into a per-rank compute schedule.

    Every FWD or BWD step whose neighbouring stage lives on another rank
    gets a ``FWD_SEND`` / ``BWD_SEND`` appended to the sender's schedule and
    a matching ``FWD_RECV`` / ``BWD_RECV`` appended to the receiver's.

    ``OVERLAP_F_B`` / ``OVERLAP_B_F`` composite steps occupy **two** virtual
    time slots during the column scan, so the RECV triggered by the
    **second** sub-step lands one slot later on the receiver -- the sender
    can only finish emitting the second sub-step's output after the first
    completes.

    Even ranks are processed before odd ranks at each time step to avoid
    P2P deadlocks between adjacent ranks.

    Args:
        scheduler: ``{rank: [MetaStep | None, ...]}`` -- compute schedule
            with ``None`` for bubble slots.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        style: Topology mapping -- ``'loop'`` or ``'v'``.

    Returns:
        ``{rank: [MetaStep, ...]}`` -- schedule with communication ops inserted.

    Raises:
        ValueError: If ``style`` is neither ``'loop'`` nor ``'v'``; raised
            the first time a stage has to be mapped to a rank.
    """

    def stage_to_rank(stage_index: int) -> int:
        """Map a virtual stage index to its physical rank."""
        if style == 'loop':
            return stage_index % real_stage_num
        if style != 'v':
            raise ValueError(f"Argument 'style' must be 'loop' or 'v', but got {style!r}.")
        # 'v': stages below ``real_stage_num`` map to themselves, the
        # remaining ones are mirrored from the end of the stage range.
        return stage_index if stage_index < real_stage_num else stage_num - 1 - stage_index

    def _remote_owner(stage_index: int, neighbour_index: int):
        """Rank of ``neighbour_index``, or None when it shares the rank of ``stage_index``."""
        owner = stage_to_rank(neighbour_index)
        return None if owner == stage_to_rank(stage_index) else owner

    def _fwd_peer(stage_index: int):
        """Return the rank that receives this stage's forward output, or None."""
        if stage_index >= stage_num - 1:
            return None
        return _remote_owner(stage_index, stage_index + 1)

    def _bwd_peer(stage_index: int):
        """Return the rank that receives this stage's backward gradient, or None."""
        if stage_index <= 0:
            return None
        return _remote_owner(stage_index, stage_index - 1)

    def _insert_comms_for_step(step, rank, new_schedule):
        """Insert send/recv for a single FWD, BWD, or composite OVERLAP step."""
        if step is None:
            return

        kind = step.type
        if kind == MetaStepType.FWD:
            peer = _fwd_peer(step.stage_index)
            send_kind, recv_kind = MetaStepType.FWD_SEND, MetaStepType.FWD_RECV
            recv_stage = step.stage_index + 1
        elif kind == MetaStepType.BWD:
            peer = _bwd_peer(step.stage_index)
            send_kind, recv_kind = MetaStepType.BWD_SEND, MetaStepType.BWD_RECV
            recv_stage = step.stage_index - 1
        else:
            # Composite steps recurse into their sub-steps; every other step
            # type needs no communication.
            if kind in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F) and step.sub_steps:
                for child in step.sub_steps:
                    _insert_comms_for_step(child, rank, new_schedule)
            return

        if peer is None:
            return
        new_schedule[rank].append(MetaStep(step.micro_index, send_kind, step.stage_index))
        new_schedule[peer].append(MetaStep(step.micro_index, recv_kind, recv_stage))

    # --- Main logic: expand OVERLAP steps into 2 virtual slots, then scan ---
    two_slot_schedule = _expand_overlap_slots(scheduler, real_stage_num)
    return _column_scan_insert_comms(two_slot_schedule, real_stage_num, _insert_comms_for_step)

673 

674 

# Unique sentinel object. It is compared by identity (``is _ALIGN_PAD``), so
# it cannot be confused with a real step or with a ``None`` bubble.
_ALIGN_PAD = object()
"""Sentinel marking a forced 1F1B-boundary bubble produced during alignment."""

677 

678 

def _step_dep_ready(step, rank, t, done, stage_num, stage_to_rank):
    """Tell whether ``step`` may run on ``rank`` at time ``t``.

    FWD at stage ``s`` consumes the output of FWD at stage ``s-1``; BWD at
    stage ``s`` consumes the output of BWD at stage ``s+1``. The step is
    ready when it sits at the pipeline boundary, when its producer stage
    lives on the same rank, or when the producer is recorded in ``done``
    with a time strictly earlier than ``t``. Any other step type is always
    ready.

    Args:
        step: The step to check.
        rank: Rank the step would run on.
        t: Current simulation time step.
        done: ``{(type, stage_index, micro_index): time}`` of emitted steps.
        stage_num: Total number of virtual pipeline stages.
        stage_to_rank: Callable mapping a stage index to its rank.
    """
    stage, micro = step.stage_index, step.micro_index

    if step.type == MetaStepType.FWD:
        kind, producer_stage = MetaStepType.FWD, stage - 1
        at_boundary = stage == 0
    elif step.type == MetaStepType.BWD:
        kind, producer_stage = MetaStepType.BWD, stage + 1
        at_boundary = stage == stage_num - 1
    else:
        return True

    if at_boundary or stage_to_rank(producer_stage) == rank:
        return True
    producer_key = (kind, producer_stage, micro)
    return producer_key in done and done[producer_key] < t

699 

700 

def _simulate_aligned_schedule(padded, stage_num, real_stage_num, stage_to_rank):
    """Simulate execution time-step by time-step, inserting bubbles where
    a step is not yet ready (cross-rank dep) or where the cooldown
    rhythm requires it.

    Args:
        padded: ``{rank: [step | _ALIGN_PAD | None, ...]}`` after
            1F1B-boundary padding. ``_ALIGN_PAD`` and ``None`` entries are
            both consumed as a single bubble.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        stage_to_rank: Topology mapping from stage to rank.

    Returns:
        ``{rank: [step | None, ...]}`` ready for the column-scan SEND/RECV
        insertion phase.

    Raises:
        RuntimeError: If some rank still has unscheduled steps when the
            simulation's time-step budget is exhausted (for example because
            a step's producer never runs). Previously the schedule was
            returned truncated, silently dropping those steps.
    """
    ranks = range(real_stage_num)
    remaining_fwd = {
        rank: sum(
            1 for s in padded[rank]
            if s is not _ALIGN_PAD and s is not None and s.type == MetaStepType.FWD
        )
        for rank in ranks
    }
    cursors = {r: 0 for r in ranks}
    aligned = {r: [] for r in ranks}
    # ``{(type, stage_index, micro_index): time step at which it was emitted}``
    done = {}
    last_was_cooldown_bwd = {r: False for r in ranks}
    # Upper bound on simulated time steps; guards against looping forever
    # when a dependency can never be satisfied.
    max_t = sum(len(v) for v in padded.values()) + real_stage_num * 20

    def _all_consumed():
        return all(cursors[r] >= len(padded[r]) for r in ranks)

    def _emit_bubble(rank):
        aligned[rank].append(None)
        last_was_cooldown_bwd[rank] = False

    def _emit_step(rank, step, t, in_cooldown):
        aligned[rank].append(step)
        done[(step.type, step.stage_index, step.micro_index)] = t
        cursors[rank] += 1
        if step.type == MetaStepType.FWD:
            remaining_fwd[rank] -= 1
        last_was_cooldown_bwd[rank] = in_cooldown and step.type == MetaStepType.BWD

    def _step_rank_at(t, rank):
        if cursors[rank] >= len(padded[rank]):
            return
        item = padded[rank][cursors[rank]]
        if item is _ALIGN_PAD or item is None:
            # Forced padding or a caller-supplied bubble: consume it as one
            # bubble slot.
            _emit_bubble(rank)
            cursors[rank] += 1
            return
        in_cooldown = remaining_fwd[rank] == 0
        # Cooldown rhythm: alternate None / BWD in pure-BWD phase.
        cooldown_skip = (
            in_cooldown
            and item.type == MetaStepType.BWD
            and last_was_cooldown_bwd[rank]
        )
        if cooldown_skip:
            _emit_bubble(rank)
            return
        if not _step_dep_ready(item, rank, t, done, stage_num, stage_to_rank):
            # Producer has not finished yet: wait one slot without advancing.
            _emit_bubble(rank)
            return
        _emit_step(rank, item, t, in_cooldown)

    for t in range(max_t):
        if _all_consumed():
            break
        for rank in ranks:
            _step_rank_at(t, rank)

    if not _all_consumed():
        pending = {r: len(padded[r]) - cursors[r] for r in ranks if cursors[r] < len(padded[r])}
        raise RuntimeError(
            f"Schedule alignment did not finish within {max_t} time steps; "
            f"unscheduled step count per rank: {pending}. "
            "Check the compute order for unsatisfiable cross-rank dependencies."
        )
    return aligned

771 

772 

def auto_align_and_add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """Align a compute-only schedule with bubbles and add P2P send/recv steps.

    In contrast to :func:`add_send_recv`, the caller does not have to place
    ``None`` bubble slots by hand: the input is a plain compute order and the
    bubbles are derived here.

    The work is done in three phases:

    1. ``None`` entries are dropped from every rank's order, the warmup → 1F1B
       boundary is located (the first ``FWD`` directly followed by a ``BWD``)
       and ``real_stage_num - 1 - rank`` alignment pads are inserted in front
       of that ``FWD``.
    2. The padded orders are handed to ``_simulate_aligned_schedule``, which
       produces the time-aligned schedule (data dependencies and cooldown
       rhythm are handled by that helper).
    3. OVERLAP steps are expanded into separate slots and a column scan
       inserts ``FWD_SEND`` / ``FWD_RECV`` and ``BWD_SEND`` / ``BWD_RECV``
       for every compute step whose neighbouring stage lives on another rank.

    Args:
        scheduler: ``{rank: [MetaStep, ...]}`` compute schedule. ``None``
            entries are removed before processing.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        style: Stage-to-rank topology, ``'loop'`` or ``'v'``.

    Returns:
        ``{rank: [MetaStep, ...]}`` with bubbles and communication steps
        inserted.

    Raises:
        ValueError: If ``style`` is neither ``'loop'`` nor ``'v'`` (raised the
            first time a stage has to be mapped to a rank).
    """

    # ---- topology helpers (also used by the column-scan phase) ----

    def stage_to_rank(stage_index: int) -> int:
        """Return the rank hosting ``stage_index`` under the chosen topology."""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            # First half walks down the ranks, second half walks back up.
            return stage_index if stage_index < real_stage_num else stage_num - 1 - stage_index
        raise ValueError(f"Argument 'style' must be 'loop' or 'v', but got {style!r}.")

    def _remote_owner(stage_index, neighbor_stage):
        """Rank of ``neighbor_stage``, or ``None`` when it shares a rank with ``stage_index``."""
        owner = stage_to_rank(neighbor_stage)
        return None if owner == stage_to_rank(stage_index) else owner

    def _fwd_peer(stage_index: int):
        """Rank receiving the forward output of ``stage_index`` (``None`` if no P2P is needed)."""
        if stage_index >= stage_num - 1:
            return None
        return _remote_owner(stage_index, stage_index + 1)

    def _bwd_peer(stage_index: int):
        """Rank receiving the backward output of ``stage_index`` (``None`` if no P2P is needed)."""
        if stage_index <= 0:
            return None
        return _remote_owner(stage_index, stage_index - 1)

    # ---- Phase 1: strip None, detect 1F1B boundary, insert transition padding ----

    def _find_1f1b_boundary(order):
        """Index of the first FWD followed by BWD; ``len(order)`` if absent."""
        for position, (current, following) in enumerate(zip(order, order[1:])):
            if current.type == MetaStepType.FWD and following.type == MetaStepType.BWD:
                return position
        return len(order)

    def _pad_rank(rank):
        """Compute-only order of ``rank`` with alignment pads at the 1F1B boundary."""
        compute_only = [entry for entry in scheduler[rank] if entry is not None]
        split_at = _find_1f1b_boundary(compute_only)
        pads = [_ALIGN_PAD] * (real_stage_num - 1 - rank)
        return compute_only[:split_at] + pads + compute_only[split_at:]

    padded = {}
    for rank in range(real_stage_num):
        padded[rank] = _pad_rank(rank)

    # ---- Phase 2: simulate execution with data deps + cooldown rhythm ----

    aligned = _simulate_aligned_schedule(padded, stage_num, real_stage_num, stage_to_rank)

    # ---- Phase 3: column-scan SEND/RECV insertion (same as add_send_recv) ----

    def _insert_comms_for_step(step, rank, new_schedule):
        """Append the send/recv pair implied by ``step`` (recursing into OVERLAP sub-steps)."""
        if step is None:
            return
        kind = step.type
        if kind in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F):
            for sub in step.sub_steps or ():
                _insert_comms_for_step(sub, rank, new_schedule)
            return
        if kind == MetaStepType.FWD:
            peer = _fwd_peer(step.stage_index)
            send_kind, recv_kind, stage_offset = MetaStepType.FWD_SEND, MetaStepType.FWD_RECV, 1
        elif kind == MetaStepType.BWD:
            peer = _bwd_peer(step.stage_index)
            send_kind, recv_kind, stage_offset = MetaStepType.BWD_SEND, MetaStepType.BWD_RECV, -1
        else:
            return
        if peer is None:
            return
        # Sender first, then the matching receive on the peer rank.
        new_schedule[rank].append(MetaStep(step.micro_index, send_kind, step.stage_index))
        new_schedule[peer].append(
            MetaStep(step.micro_index, recv_kind, step.stage_index + stage_offset))

    # Expand OVERLAP steps into 2 virtual slots before the column scan so
    # the RECV triggered by an overlap's second sub-step lands one slot
    # later on the receiver — matching the fact that the sender can only
    # finish emitting the second sub-step after the first completes.
    expanded = _expand_overlap_slots(aligned, real_stage_num)
    return _column_scan_insert_comms(expanded, real_stage_num, _insert_comms_for_step)

884 

885 

class ScheduleGPipe(PipelineScheduleRuntime):
    """
    GPipe pipeline schedule.

    Each rank runs the forward pass of every micro batch first and only then
    runs the backward pass of every micro batch.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(
            stages,
            micro_batch_num,
            args_batch_dim=args_batch_dim,
            kwargs_batch_dim=kwargs_batch_dim,
            output_concat_dim=output_concat_dim,
        )
        self.build_exec_order()

    def _build_stage_to_rank_index(self) -> None:
        """Build the stage → rank lookup using the ``'loop'`` topology."""
        mapping = generate_stage_to_rank_mapping(
            self.real_stage_num, self._stage_num, style='loop'
        )
        self._stage_to_rank_index = mapping

    def construct_exec_order(self):
        """Fill ``self.exec_order`` with the GPipe order of every stage.

        Per micro batch the forward group is ``FWD_RECV`` (not on stage 0),
        ``FWD``, ``FWD_SEND`` (not on the last stage); the backward group is
        ``BWD_RECV`` (not on the last stage), ``BWD``, ``BWD_SEND`` (not on
        stage 0). All forward groups precede all backward groups.
        """
        final_stage = self.real_stage_num - 1
        for stage_index in range(self.real_stage_num):
            has_upstream = stage_index != 0
            has_downstream = stage_index != final_stage

            forward_kinds = [MetaStepType.FWD]
            backward_kinds = [MetaStepType.BWD]
            if has_upstream:
                forward_kinds.insert(0, MetaStepType.FWD_RECV)
                backward_kinds.append(MetaStepType.BWD_SEND)
            if has_downstream:
                forward_kinds.append(MetaStepType.FWD_SEND)
                backward_kinds.insert(0, MetaStepType.BWD_RECV)

            steps = [
                MetaStep(mb_index, kind, stage_index)
                for mb_index in range(self.micro_batch_num)
                for kind in forward_kinds
            ]
            steps.extend(
                MetaStep(mb_index, kind, stage_index)
                for mb_index in range(self.micro_batch_num)
                for kind in backward_kinds
            )
            self.exec_order[stage_index] = steps

926 

927 

class Schedule1F1B(PipelineScheduleRuntime):
    """
    1F1B pipeline schedule.

    After a forward-only warmup, each rank alternates one backward and one
    forward step in the steady state and finishes with a backward-only
    cooldown.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(
            stages,
            micro_batch_num,
            args_batch_dim=args_batch_dim,
            kwargs_batch_dim=kwargs_batch_dim,
            output_concat_dim=output_concat_dim,
        )
        self.build_exec_order()

    def _build_stage_to_rank_index(self) -> None:
        """Build the stage → rank lookup using the ``'loop'`` topology."""
        mapping = generate_stage_to_rank_mapping(
            self.real_stage_num, self._stage_num, style='loop'
        )
        self._stage_to_rank_index = mapping

    def construct_exec_order(self):
        """Fill ``self.exec_order`` with the 1F1B order of every stage.

        Even stages send a forward result right after computing it, odd stages
        send the previous forward result just before computing the next one;
        the send of the last warmup forward is deferred to the following
        phase.
        """
        stage_count = self.real_stage_num
        micro_count = self.micro_batch_num
        final_stage = stage_count - 1

        for stage_index in range(stage_count):
            steps = []

            def emit(micro, kind, _steps=steps, _stage=stage_index):
                """Append one MetaStep for the stage currently being built."""
                _steps.append(MetaStep(micro, kind, _stage))

            is_first = stage_index == 0
            is_last = stage_index == final_stage
            send_right_after_fwd = stage_index % 2 == 0
            fwd_index = 0
            bwd_index = 0

            # warmup phase
            warmup_micro_batches = min(stage_count - stage_index, micro_count)
            for _ in range(warmup_micro_batches):
                if not is_first:
                    emit(fwd_index, MetaStepType.FWD_RECV)
                if send_right_after_fwd:
                    emit(fwd_index, MetaStepType.FWD)
                    # The send of the final warmup forward is deferred.
                    if fwd_index != warmup_micro_batches - 1:
                        emit(fwd_index, MetaStepType.FWD_SEND)
                else:
                    if fwd_index > 0:
                        emit(fwd_index - 1, MetaStepType.FWD_SEND)
                    emit(fwd_index, MetaStepType.FWD)
                fwd_index += 1

            # if warmup phase cannot be filled up, execute the pending fwd send in advance
            if stage_count - stage_index > micro_count:
                emit(fwd_index - 1, MetaStepType.FWD_SEND)
                # Bump past micro_count so the cooldown does not send it again.
                fwd_index += 1

            # steady phase
            for _ in range(micro_count - warmup_micro_batches):
                if not is_last:
                    emit(bwd_index, MetaStepType.BWD_RECV)
                    emit(fwd_index - 1, MetaStepType.FWD_SEND)
                emit(bwd_index, MetaStepType.BWD)

                if not is_first:
                    emit(bwd_index, MetaStepType.BWD_SEND)
                    emit(fwd_index, MetaStepType.FWD_RECV)
                emit(fwd_index, MetaStepType.FWD)
                fwd_index += 1
                bwd_index += 1

            # cooldown phase
            first_cooldown_bwd = micro_count - warmup_micro_batches
            for _ in range(warmup_micro_batches):
                if not is_last:
                    emit(bwd_index, MetaStepType.BWD_RECV)
                    # Flush the still-pending forward send on the first cooldown step.
                    if bwd_index == first_cooldown_bwd and fwd_index <= micro_count:
                        emit(fwd_index - 1, MetaStepType.FWD_SEND)
                emit(bwd_index, MetaStepType.BWD)

                if not is_first:
                    emit(bwd_index, MetaStepType.BWD_SEND)
                bwd_index += 1
            self.exec_order[stage_index] = steps

1004 

1005 

class ScheduleInterleaved1F1B(PipelineScheduleRuntime):
    """The Interleaved 1F1B schedule.

    Each rank may host several stages. The steady state runs one forward
    followed by one backward. ``micro_batch_num`` may be smaller than, equal
    to, or larger than the number of ranks, and does not have to be a
    multiple of it.

    Two independent overlap modes are selected through constructor flags:

    * ``overlap_p2p=True``: forwarded to the base class; defers the P2P recv
      ``handle.wait()`` until the consuming FWD/BWD step (or the OVERLAP_B_F
      callback when ``overlap_b_f=True``) so recv can overlap with earlier
      compute.
    * ``overlap_b_f=True``: in the 1F1B steady state consecutive
      ``(B_i, F_{i+1})`` steps are paired into ``OVERLAP_B_F`` composite
      steps so that a registered callback can drive comm/compute overlap
      (typically via :class:`CommComputeOverlap` for MoE EP A2A). The
      callback is registered through :meth:`register_custom_function`.

    Both flags can be combined.

    Example:
        >>> # Plain interleaved 1F1B
        >>> sched = ScheduleInterleaved1F1B(stages, 8)
        >>> # With B/F overlap (dual-pipe-style comm/compute overlap)
        >>> sched = ScheduleInterleaved1F1B(stages, 8, overlap_b_f=True)
        >>> sched.register_custom_function(MetaStepType.OVERLAP_B_F, callback)
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False,
                 overlap_b_f=False):
        super().__init__(
            stages,
            micro_batch_num,
            args_batch_dim=args_batch_dim,
            kwargs_batch_dim=kwargs_batch_dim,
            output_concat_dim=output_concat_dim,
            overlap_p2p=overlap_p2p,
        )
        # Must be assigned before ``build_exec_order`` runs below, because
        # ``construct_stage_exec_order`` reads it to pick the 1F1B emitter.
        self._overlap_b_f = overlap_b_f

        rank_count = self.real_stage_num
        micro_count = self.micro_batch_num
        self.n_rounds = max(1, micro_count // rank_count)
        if micro_count < rank_count:
            # Single short round holding exactly ``micro_count`` micro batches.
            base, remainder = micro_count - rank_count, 0
        else:
            # Spread the leftover micro batches over the rounds.
            base, remainder = divmod(micro_count % rank_count, self.n_rounds)
        round_size = rank_count + base
        self.n_microbatch_per_round = [
            round_size + 1 if round_idx < remainder else round_size
            for round_idx in range(self.n_rounds)
        ]
        # Prefix sums (scaled by the local stage count) with a leading 0.
        self.n_microbatch_per_round_accu = [0] + [
            count * self.n_local_stages
            for count in itertools.accumulate(self.n_microbatch_per_round)
        ]
        self.build_exec_order()

    def construct_exec_order(self):
        """Build the compute order of every rank, then insert send/recv steps."""
        for stage_index in range(self.real_stage_num):
            self.exec_order[stage_index] = self.construct_stage_exec_order(stage_index)
        self.exec_order = add_send_recv(
            self.exec_order, self._stage_num, self.real_stage_num, style='loop'
        )

    def _build_stage_to_rank_index(self) -> None:
        """Build the stage → rank lookup using the ``'loop'`` topology."""
        mapping = generate_stage_to_rank_mapping(
            self.real_stage_num, self._stage_num, style='loop'
        )
        self._stage_to_rank_index = mapping

    def warmup_ops(self, stage_index):
        """Number of forward-only warmup ops for ``stage_index``."""
        on_last_rank = (self.n_local_stages - 1) * self.n_microbatch_per_round[0]
        extra_for_rank = 2 * (self.real_stage_num - 1 - stage_index)
        upper_bound = self.micro_batch_num * self.n_local_stages
        return min(on_last_rank + extra_for_rank, upper_bound)

    def _round_local_index(self, op_index):
        """Position of ``op_index`` inside its round, in units of that round's size."""
        round_idx = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        offset_in_round = op_index - self.n_microbatch_per_round_accu[round_idx]
        return offset_in_round // self.n_microbatch_per_round[round_idx]

    def forward_stage_index(self, op_index, stage_index):
        """Return the stage index of the ``op_index``-th forward op on ``stage_index``."""
        return self._round_local_index(op_index) * self.real_stage_num + stage_index

    def backward_stage_index(self, op_index, stage_index):
        """Return the stage index of the ``op_index``-th backward op on ``stage_index``."""
        # Backward visits the local stages in reverse order.
        mirrored = self.n_local_stages - 1 - self._round_local_index(op_index)
        return mirrored * self.real_stage_num + stage_index

    def _short_micro(self) -> bool:
        """True when ``micro_batch_num < real_stage_num`` (extra-bubble regime)."""
        return self.micro_batch_num < self.real_stage_num

    def _trailing_bubble(self) -> int:
        """Bubble count appended after the step with ``micro == micro_batch_num - 1``
        in the short-micro regime.
        """
        return self.real_stage_num - self.micro_batch_num

    def _emit_warmup_ops(self, stage_index, warmup_ops, fwd_stage_micro_index):
        """Emit the forward-only warmup ops, padding with bubbles in the short-micro regime."""
        emitted = []
        pad_enabled = self._short_micro()
        final_micro = self.micro_batch_num - 1
        on_last_rank = stage_index == self.real_stage_num - 1
        padding = [None] * self._trailing_bubble()
        for op_idx in range(warmup_ops):
            stage = self.forward_stage_index(op_idx, stage_index)
            micro = fwd_stage_micro_index[stage]
            emitted.append(MetaStep(micro, MetaStepType.FWD, stage))
            is_final_warmup_op = op_idx == warmup_ops - 1
            if pad_enabled and micro == final_micro and (on_last_rank or not is_final_warmup_op):
                emitted.extend(padding)
            fwd_stage_micro_index[stage] += 1
        return emitted

    def _emit_cooldown_ops(self, stage_index, warmup_ops, fwd_bwd_ops, total_ops,
                           bwd_stage_micro_index):
        """Emit the backward-only cooldown ops, each preceded by one bubble, plus
        the short-micro trailing padding where it applies.
        """
        emitted = []
        pad_enabled = self._short_micro()
        final_micro = self.micro_batch_num - 1
        padding = [None] * self._trailing_bubble()
        for op_idx in range(warmup_ops + fwd_bwd_ops, total_ops):
            stage = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            micro = bwd_stage_micro_index[stage]
            emitted.extend((None, MetaStep(micro, MetaStepType.BWD, stage)))
            if pad_enabled and micro == final_micro:
                emitted.extend(padding)
            bwd_stage_micro_index[stage] += 1
        return emitted

    def _emit_1f1b_ops(self, stage_index, warmup_ops, fwd_bwd_ops,
                       fwd_stage_micro_index, bwd_stage_micro_index):
        """Emit alternating (FWD, BWD) pairs for the 1F1B steady-state phase."""
        emitted = []
        pad_on_this_rank = self._short_micro() and stage_index == self.real_stage_num - 1
        final_micro = self.micro_batch_num - 1
        padding = [None] * self._trailing_bubble()
        for op_idx in range(warmup_ops, warmup_ops + fwd_bwd_ops):
            fwd_stage = self.forward_stage_index(op_idx, stage_index)
            emitted.append(
                MetaStep(fwd_stage_micro_index[fwd_stage], MetaStepType.FWD, fwd_stage))
            fwd_stage_micro_index[fwd_stage] += 1

            bwd_stage = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro = bwd_stage_micro_index[bwd_stage]
            emitted.append(MetaStep(bwd_micro, MetaStepType.BWD, bwd_stage))
            if pad_on_this_rank and bwd_micro == final_micro:
                emitted.extend(padding)
            bwd_stage_micro_index[bwd_stage] += 1
        return emitted

    @staticmethod
    def _collect_fwd_bwd_steps(emit_fwd, emit_bwd, fwd_bwd_ops, warmup_ops):
        """Collect the forward and backward steps of the 1F1B range as two lists.

        For every op index in ``[warmup_ops, warmup_ops + fwd_bwd_ops)``,
        ``emit_fwd(op_idx)`` is called first and ``emit_bwd(op_idx)`` second;
        each returns one :class:`MetaStep`.
        """
        pairs = [
            (emit_fwd(op_idx), emit_bwd(op_idx))
            for op_idx in range(warmup_ops, warmup_ops + fwd_bwd_ops)
        ]
        fwd_steps = [fwd for fwd, _ in pairs]
        bwd_steps = [bwd for _, bwd in pairs]
        return fwd_steps, bwd_steps

    @staticmethod
    def _pair_into_overlap_b_f(fwd_steps, bwd_steps):
        """Build the ``F₁, [B_i, F_{i+1}], B_n`` ordering with OVERLAP_B_F pairs.

        Each composite step stores its ``(bwd, fwd)`` tuple in ``sub_steps`` so
        callbacks can read the per-direction stage / micro information.
        """
        ordered = list(fwd_steps[:1])  # F₁ runs alone
        for nxt in range(1, len(bwd_steps)):
            composite = MetaStep(
                None, MetaStepType.OVERLAP_B_F, None,
                sub_steps=(bwd_steps[nxt - 1], fwd_steps[nxt]),
            )
            ordered.append(composite)
        if bwd_steps:
            ordered.append(bwd_steps[-1])  # B_n runs alone
        return ordered

    def _emit_1f1b_overlap_ops(self, stage_index, warmup_ops, fwd_bwd_ops,
                               fwd_stage_micro_index, bwd_stage_micro_index):
        """Emit ``F₁, [B_i, F_{i+1}], B_n`` for the 1F1B phase when
        ``overlap_b_f=True``. Every ``[B_i, F_{i+1}]`` is an ``OVERLAP_B_F``
        composite step. In the short-micro regime the last rank gets trailing
        bubbles after ``B_n`` when ``B_n`` belongs to the final micro batch.
        """
        def emit_fwd(op_idx):
            stage = self.forward_stage_index(op_idx, stage_index)
            micro = fwd_stage_micro_index[stage]
            fwd_stage_micro_index[stage] = micro + 1
            return MetaStep(micro, MetaStepType.FWD, stage)

        def emit_bwd(op_idx):
            stage = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            micro = bwd_stage_micro_index[stage]
            bwd_stage_micro_index[stage] = micro + 1
            return MetaStep(micro, MetaStepType.BWD, stage)

        fwd_steps, bwd_steps = self._collect_fwd_bwd_steps(
            emit_fwd, emit_bwd, fwd_bwd_ops, warmup_ops,
        )
        ordered = self._pair_into_overlap_b_f(fwd_steps, bwd_steps)

        on_last_rank = stage_index == self.real_stage_num - 1
        if (self._short_micro() and on_last_rank and bwd_steps
                and bwd_steps[-1].micro_index == self.micro_batch_num - 1):
            ordered.extend([None] * self._trailing_bubble())
        return ordered

    def construct_stage_exec_order(self, stage_index):
        """Construct the execution order for ``stage_index``.

        Layout: ``stage_index`` leading bubbles → warmup → bubbles → 1F1B
        steady state → cooldown. The 1F1B segment comes from
        :meth:`_emit_1f1b_ops` (plain) or :meth:`_emit_1f1b_overlap_ops`
        (OVERLAP_B_F pairing) depending on the ``overlap_b_f`` constructor
        flag.
        """
        warmup_ops = self.warmup_ops(stage_index)
        fwd_bwd_ops = self.n_local_stages * self.micro_batch_num - warmup_ops
        total_ops = 2 * warmup_ops + fwd_bwd_ops
        ranks_after = self.real_stage_num - 1 - stage_index

        fwd_stage_micro_index = defaultdict(int)
        bwd_stage_micro_index = defaultdict(int)

        order_list = [None] * stage_index
        order_list += self._emit_warmup_ops(stage_index, warmup_ops, fwd_stage_micro_index)

        bubbles_before_1f1b = max(0, 2 * ranks_after - self.micro_batch_num)
        order_list += [None] * bubbles_before_1f1b
        order_list += [None] * ranks_after

        emit_steady = self._emit_1f1b_overlap_ops if self._overlap_b_f else self._emit_1f1b_ops
        order_list += emit_steady(
            stage_index, warmup_ops, fwd_bwd_ops,
            fwd_stage_micro_index, bwd_stage_micro_index,
        )
        order_list += self._emit_cooldown_ops(
            stage_index, warmup_ops, fwd_bwd_ops, total_ops, bwd_stage_micro_index,
        )
        return order_list

1276 

1277 

def detect_cycle_in_graph(ranks_map):
    """
    Look for a cycle in the directed graph described by ``ranks_map``.

    Every pair of consecutive nodes in a rank's list becomes a directed edge
    labelled with that rank. The graph is explored with an iterative
    depth-first search; the first back edge found ends the search.

    Args:
        ranks_map: A dictionary mapping a rank name to its ordered list of nodes.

    Returns:
        tuple: ``(cycle_path, cycle_ranks)``. ``cycle_path`` lists the nodes of
        the cycle and ends with the node it started from; ``cycle_ranks`` holds
        one ``"<rank> <u> -> <v>"`` string per edge of that path. Both are
        ``None`` when the graph has no cycle.
    """
    successors = defaultdict(list)
    edge_owner = {}
    for rank, nodes in ranks_map.items():
        for src, dst in zip(nodes, nodes[1:]):
            successors[src].append(dst)
            edge_owner[(src, dst)] = rank

    seen = set()          # every node ever entered
    trail = []            # nodes on the current DFS path
    trail_position = {}   # node -> index in ``trail`` while it is on the path
    pending = []          # explicit DFS stack of (node, is_exit_marker)

    # Snapshot the keys: looking up leaf nodes below adds entries to the defaultdict.
    for root in list(successors):
        if root in seen:
            continue
        pending.append((root, False))
        while pending:
            node, leaving = pending.pop()

            if leaving:
                # All descendants handled: take the node off the current path.
                trail.pop()
                del trail_position[node]
                continue

            if node in trail_position:
                # Back edge to a node still on the path: that closes a cycle.
                cycle_path = trail[trail_position[node]:] + [node]
                cycle_ranks = [
                    f"{edge_owner[(src, dst)]} {src} -> {dst}"
                    for src, dst in zip(cycle_path, cycle_path[1:])
                ]
                return cycle_path, cycle_ranks

            if node in seen:
                continue

            seen.add(node)
            trail_position[node] = len(trail)
            trail.append(node)

            pending.append((node, True))
            # Reversed so that neighbours are popped in their original order.
            pending.extend((nxt, False) for nxt in reversed(successors[node]))

    return None, None

1337 

1338 

def output_cycle_results(cycle_path, cycle_ranks):
    """
    Log the outcome of a cycle detection run.

    Args:
        cycle_path (list): Nodes forming a cycle, or an empty/``None`` value
            when no cycle was found. The path may already end with its first
            node (as returned by ``detect_cycle_in_graph``); it is closed here
            only when it is not.
        cycle_ranks (list): Descriptions of the rank edges involved in the
            cycle; ignored when no cycle was found.

    Returns:
        None: Results are written to the module logger.
    """
    if not cycle_path:
        logger.warning("Cycle Check succeeded. There is no cycle in the graph.")
        return

    logger.error("Cycle detected:")
    closed_path = list(cycle_path)
    # Close the cycle only if the caller has not done so already, otherwise
    # the start node would be printed twice at the end.
    if len(closed_path) < 2 or closed_path[-1] != closed_path[0]:
        closed_path.append(closed_path[0])
    logger.error("%s", " -> ".join(str(node) for node in closed_path))

    logger.error("Involving ranks:")
    for rank in cycle_ranks or ():
        # Pass the text as an argument so it is never interpreted as a format string.
        logger.error("%s", rank)

1359 

1360 

def parse_and_validate(data: dict, all_rank: bool = True):
    """
    Check that every operation string is listed by all ranks it mentions.

    Each value string is scanned for numbers wrapped in parentheses (at most
    the first two); those numbers are treated as keys of ``data``. Problems
    are reported through the module logger, nothing is raised or returned.

    Args:
        data (dict): Maps a string key to a list of strings. Entries whose
            value is not a list of strings are reported and skipped.
        all_rank (bool): If True, every key referenced by a value must exist
            in ``data``; a value referencing a missing key is reported and not
            checked further. If False, only referenced keys that do exist are
            checked.

    Returns:
        None: Validation failures are logged as errors.
    """

    def parse_elements(value: str, max_groups: int = 2) -> set:
        """Return the distinct numbers found in the first ``max_groups`` ``(<digits>)`` groups."""
        matches = re.findall(r'\((\d+)\)', value)
        return {match.strip() for match in matches[:max_groups]}

    def is_str_list(candidate) -> bool:
        """True when ``candidate`` is a list containing only strings."""
        return isinstance(candidate, list) and all(isinstance(item, str) for item in candidate)

    if not isinstance(data, dict):
        logger.error("Input must be a dictionary with string keys and lists of strings as values.")
        return

    # Only well-formed entries take part in the cross checks.
    key_to_values = {key: set(values) for key, values in data.items() if is_str_list(values)}

    for key, values in data.items():
        if not is_str_list(values):
            logger.error("Values for key '%s' must be a list of strings.", key)
            continue

        for value in values:
            try:
                elements = parse_elements(value)
            except (ValueError, TypeError, AttributeError) as e:
                logger.error("Unable to parse elements from value '%s' in key '%s'. Error: %s", value, key, e)
                continue

            if all_rank:
                # Every referenced key has to exist.
                missing_keys = elements - key_to_values.keys()
                if missing_keys:
                    logger.error("The following keys are missing for value '%s': %s", value, missing_keys)
                    continue
                referenced = elements
            else:
                # Only look at the referenced keys that are actually present.
                referenced = elements & key_to_values.keys()

            # The value must appear in the list of every referenced key.
            for element in referenced:
                if value not in key_to_values[element]:
                    logger.error("Key '%s' is missing the value '%s'.", element, value)

1421 

1422 

def generate_operations(order_list: dict[int, list[MetaStep]],
                        chunk_num: int,
                        com_type: str = 'loop') -> dict[str, list[str]]:
    """
    Generate formatted operations dictionary from pipeline execution order.

    Only SEND/RECV steps produce output. Each one is rendered as
    ``Send_Receive_(<src rank>)->(<dst rank>)_micro<m>_<k>th`` where ``k``
    counts, per rank, how many times that (src, dst, micro) triple has been
    seen so far. ``None`` bubble slots and all other step types are skipped.

    Args:
        order_list (dict): Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num (int): Number of chunks (virtual pipeline stages)
        com_type (str): Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary where keys are rank IDs (as strings) and values are lists of formatted operation strings.
        Ranks without any SEND/RECV step have no entry.

    Raises:
        ValueError: If ``com_type`` is neither 'loop' nor 'v' and a rank has to be resolved.
    """

    def stage_to_rank(stage_index, style, stage_num, real_stage_num):
        """Map stage index to rank"""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError("Invalid style")

    real_stage = len(order_list)
    total_stages = real_stage * chunk_num

    # Communication rules: forward goes to the next stage, backward to the previous one.
    forward_comm = {stage: stage + 1 for stage in range(total_stages - 1)}
    backward_comm = {stage: stage - 1 for stage in range(1, total_stages)}
    # Inverted maps give the sending stage of a RECV without a linear scan.
    forward_source = {dst: src for src, dst in forward_comm.items()}
    backward_source = {dst: src for src, dst in backward_comm.items()}

    # For each communication step type: where to find the peer stage.
    peer_stage_of = {
        MetaStepType.FWD_SEND: forward_comm,
        MetaStepType.BWD_SEND: backward_comm,
        MetaStepType.FWD_RECV: forward_source,
        MetaStepType.BWD_RECV: backward_source,
    }
    send_types = (MetaStepType.FWD_SEND, MetaStepType.BWD_SEND)

    formatted_operations = defaultdict(list)

    for rank, steps in order_list.items():
        operation_counter = defaultdict(int)

        for step in steps:
            # Bubble slots carry no communication.
            if step is None:
                continue
            lookup = peer_stage_of.get(step.type)
            if lookup is None:
                continue
            peer_stage = lookup.get(step.stage_index)
            if peer_stage is None:
                # First/last stage: there is no peer in that direction.
                continue

            peer_rank = stage_to_rank(peer_stage, com_type, total_stages, real_stage)
            if step.type in send_types:
                source_rank, target_rank = rank, peer_rank
            else:
                source_rank, target_rank = peer_rank, rank

            comm_pair = (source_rank, target_rank, step.micro_index)
            operation_counter[comm_pair] += 1
            count = operation_counter[comm_pair]
            formatted_op = f"Send_Receive_({source_rank})->({target_rank})_micro{step.micro_index}_{count}th"
            formatted_operations[str(rank)].append(formatted_op)

    # Convert defaultdict to dict
    return dict(formatted_operations)

1509 

1510 

def validate_pipeline_execution(order_list: dict[int, list[MetaStep]],
                                chunk_num: int,
                                com_type: str = 'loop') -> dict[str, object]:
    """
    Validate the communication structure of a pipeline parallel execution order.

    The function:
        1. Renders every SEND/RECV step as a formatted operation string
           (``generate_operations``).
        2. Checks that each operation is listed by the ranks it references
           (``parse_and_validate``); problems are logged, not returned.
        3. Searches the per-rank operation sequences for a cycle
           (``detect_cycle_in_graph``) and logs the outcome
           (``output_cycle_results``).

    Args:
        order_list: Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num: Number of chunks (virtual pipeline stages)
        com_type: Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary with the following keys:
            - formatted_operations: Generated formatted operations per rank
            - cycle_path: Nodes of the detected cycle, or None
            - cycle_ranks: Rank edges of the detected cycle, or None
            - has_cycle: Boolean indicating whether a cycle was found
    """

    # Generate operations
    formatted_operations = generate_operations(order_list, chunk_num, com_type)

    # Cross-check SEND/RECV pairing between ranks (errors are only logged).
    parse_and_validate(formatted_operations, True)

    # Detect cycles
    cycle_path, cycle_ranks = detect_cycle_in_graph(formatted_operations)

    # Output results
    output_cycle_results(cycle_path, cycle_ranks)

    return {
        'formatted_operations': formatted_operations,
        'cycle_path': cycle_path,
        'cycle_ranks': cycle_ranks,
        'has_cycle': bool(cycle_path),
    }

1556 

1557 

# Step types counted as compute work by the FSDP lookahead below.
_COMPUTE_META_STEP_TYPES = frozenset((
    MetaStepType.FWD,
    MetaStepType.BWD,
    MetaStepType.BWD_INPUT,
    MetaStepType.BWD_WEIGHT,
))

1564 

1565 

def _next_active_stage_indices(actions, start_index, max_active_stages, managed_stage_indices):
    """Return the next distinct managed stages that are about to run compute.

    The scan starts at ``actions[start_index]`` and walks the leaf steps of
    every action. Only compute leaf steps are counted, so send/recv and FSDP
    control steps do not use up the lookahead budget. The scan stops as soon
    as ``max_active_stages`` distinct stages have been collected.
    """
    upcoming = []
    collected = set()
    leaf_steps = (
        leaf
        for action in actions[start_index:]
        for leaf in iter_leaf_meta_steps(action)
    )
    for leaf in leaf_steps:
        is_new_managed_compute = (
            leaf.type in _COMPUTE_META_STEP_TYPES
            and leaf.stage_index in managed_stage_indices
            and leaf.stage_index not in collected
        )
        if not is_new_managed_compute:
            continue
        collected.add(leaf.stage_index)
        upcoming.append(leaf.stage_index)
        if len(upcoming) == max_active_stages:
            break
    return upcoming

1586 

1587 

def add_fsdp_unshard_reshard(actions, managed_stage_indices, max_active_stages=3):
    """Interleave FSDP unshard/reshard steps with the given action list.

    Before each action, the stages returned by ``_next_active_stage_indices``
    (looking ahead from that action) are compared with the stages currently
    kept unsharded. Stages that dropped out of the lookahead window get an
    ``FSDP_RESHARD`` step, newly appearing ones get an ``FSDP_UNSHARD`` step,
    and then the action itself is emitted. Reshards are emitted before
    unshards. Once all actions are processed, every stage that is still
    unsharded receives a trailing ``FSDP_RESHARD``.

    Args:
        actions: Sequence of MetaStep actions for the local rank.
        managed_stage_indices: Stage indices handled locally; when empty the
            input ``actions`` object is returned untouched.
        max_active_stages: Upper bound passed to the lookahead helper.
            Default: 3.

    Returns:
        A new list holding the original actions plus the injected FSDP steps,
        or ``actions`` itself when ``managed_stage_indices`` is empty.
    """
    if not managed_stage_indices:
        return actions

    scheduled = []
    resident = []  # stages currently unsharded, oldest first
    for position, action in enumerate(actions):
        wanted = _next_active_stage_indices(
            actions, position, max_active_stages, managed_stage_indices
        )
        to_drop = [stage for stage in resident if stage not in wanted]
        to_load = [stage for stage in wanted if stage not in resident]

        scheduled.extend(
            MetaStep(None, MetaStepType.FSDP_RESHARD, stage) for stage in to_drop
        )
        scheduled.extend(
            MetaStep(None, MetaStepType.FSDP_UNSHARD, stage) for stage in to_load
        )
        # Survivors keep their relative order; fresh stages are appended.
        resident = [stage for stage in resident if stage in wanted] + to_load
        scheduled.append(action)

    # Release whatever is still unsharded, in the order it was fetched.
    scheduled.extend(
        MetaStep(None, MetaStepType.FSDP_RESHARD, stage) for stage in resident
    )
    return scheduled

1612 

1613 

def add_fsdp_reduce_grad(actions, managed_stage_indices, micro_batch_num):
    """Append FSDP reduce-grad steps after the final backward of each stage.

    Every action is copied to the output. Directly after an action, one
    ``FSDP_REDUCE_GRAD`` step is added for each distinct managed stage that
    has, among the action's leaf steps, a ``BWD`` or ``BWD_WEIGHT`` step whose
    ``micro_index`` equals ``micro_batch_num - 1``.

    Args:
        actions: Sequence of MetaStep actions for the local rank.
        managed_stage_indices: Stage indices handled locally; when empty the
            input ``actions`` object is returned untouched.
        micro_batch_num: Number of micro batches; the backward of micro batch
            ``micro_batch_num - 1`` triggers the reduction.

    Returns:
        A new list holding the original actions plus the injected reduce-grad
        steps, or ``actions`` itself when ``managed_stage_indices`` is empty.
    """
    if not managed_stage_indices:
        return actions

    grad_producing_types = (MetaStepType.BWD, MetaStepType.BWD_WEIGHT)
    scheduled = []
    for action in actions:
        scheduled.append(action)
        finished_stages = []  # first-seen order, no duplicates
        for step in iter_leaf_meta_steps(action):
            is_final_backward = (
                step.stage_index in managed_stage_indices
                and step.type in grad_producing_types
                and step.micro_index == micro_batch_num - 1
            )
            if is_final_backward and step.stage_index not in finished_stages:
                finished_stages.append(step.stage_index)
        scheduled.extend(
            MetaStep(None, MetaStepType.FSDP_REDUCE_GRAD, stage)
            for stage in finished_stages
        )
    return scheduled