Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/pipeline_parallel/scheduler.py: 14%
848 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""pipeline schedule"""
16from abc import ABC, abstractmethod
17from enum import Enum, auto
18from collections import defaultdict
19import itertools
20import bisect
21import logging
22import re
23import hyper_parallel
24from hyper_parallel.platform import get_platform
25from hyper_parallel.core.fully_shard.api import HSDPModule
26platform = get_platform()
27logger = logging.getLogger(__name__)
class MetaStepType(Enum):
    """Specify the enumeration type for MetaStep.

    * Compute: ``FWD``, ``BWD``, ``BWD_INPUT``, ``BWD_WEIGHT``. The runtime
      has built-in execution only for ``FWD`` and ``BWD``; the split-backward
      types need a handler registered via ``register_custom_function``.
    * P2P communication: ``FWD_RECV``, ``FWD_SEND``, ``BWD_RECV``, ``BWD_SEND``.
    * Composite: ``OVERLAP_F_B`` / ``OVERLAP_B_F``, whose real work lives in
      ``MetaStep.sub_steps``.
    * FSDP control: ``FSDP_UNSHARD``, ``FSDP_RESHARD``, ``FSDP_REDUCE_GRAD``.
    """
    FWD = auto()
    BWD = auto()
    BWD_INPUT = auto()
    BWD_WEIGHT = auto()
    FWD_RECV = auto()
    FWD_SEND = auto()
    BWD_RECV = auto()
    BWD_SEND = auto()
    OVERLAP_F_B = auto()
    OVERLAP_B_F = auto()
    FSDP_UNSHARD = auto()
    FSDP_RESHARD = auto()
    FSDP_REDUCE_GRAD = auto()


class MetaStep:
    """
    Meta step of PipelineSchedule.
    An execution list composed of MetaStep can be constructed
    and fed into the PipelineSchedule for execution.

    Args:
        micro_index (int | None): The index of micro-batch. ``None`` for
            composite types (``OVERLAP_F_B`` / ``OVERLAP_B_F``) whose real
            micro index lives in each ``sub_steps`` entry.
        meta_type (MetaStepType): Specify the type of current step.
        stage_index (int | None): Stage index of current step. ``None``
            for composite types; use ``sub_steps`` to get each direction's
            stage.
        sub_steps (Iterable[MetaStep] | None): For composite types
            only: ``(fwd, bwd)`` for ``OVERLAP_F_B``, ``(bwd, fwd)`` for
            ``OVERLAP_B_F``. Stored as a tuple, so steps built from a list
            and from a tuple with the same contents compare equal.

    ``str(step)`` and :meth:`from_str` round-trip:
    ``MetaStep.from_str(str(step)) == step``.
    """

    # Leading ``MetaStep(type=..., micro_index=..., stage_index=...`` part of
    # ``str(MetaStep)``. ``sub_steps`` nests recursively, so it is parsed by
    # hand in ``_parse_at``.
    _STR_HEAD = re.compile(
        r"MetaStep\(type=(?:MetaStepType\.)?(?P<type>\w+), "
        r"micro_index=(?P<micro>None|-?\d+), "
        r"stage_index=(?P<stage>None|-?\d+)"
    )
    _SUB_STEPS_PREFIX = ", sub_steps=["

    def __init__(self, micro_index, meta_type, stage_index, sub_steps=None):
        self._type = meta_type
        self._micro_index = micro_index
        self._stage_index = stage_index
        # Normalize to a tuple so equality does not depend on the container type.
        self._sub_steps = tuple(sub_steps) if sub_steps is not None else None

    @property
    def micro_index(self):
        """Micro-batch index, or ``None`` for composite steps."""
        return self._micro_index

    @property
    def stage_index(self):
        """Stage index, or ``None`` for composite steps."""
        return self._stage_index

    @property
    def type(self):
        """The :class:`MetaStepType` of this step."""
        return self._type

    @property
    def sub_steps(self):
        """Sub-steps for composite types: ``(fwd, bwd)`` for OVERLAP_F_B,
        ``(bwd, fwd)`` for OVERLAP_B_F, or ``None``."""
        return self._sub_steps

    def __eq__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        return (self.type == value.type
                and self.micro_index == value.micro_index
                and self.stage_index == value.stage_index
                and self.sub_steps == value.sub_steps)

    def __ne__(self, value):
        if not isinstance(value, MetaStep):
            return NotImplemented
        return not self.__eq__(value)

    def __hash__(self):
        # sub_steps is deliberately excluded: equal steps still hash equally,
        # because __eq__ compares a superset of these fields.
        return hash((self.type, self.micro_index, self.stage_index))

    def __str__(self):
        if self.sub_steps:
            sub = ", ".join(str(s) for s in self.sub_steps)
            return (f"MetaStep(type={self.type}, micro_index={self.micro_index}, "
                    f"stage_index={self.stage_index}, sub_steps=[{sub}])")
        return f"MetaStep(type={self.type}, micro_index={self.micro_index}, stage_index={self.stage_index})"

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def _parse_at(text, pos):
        """Parse one MetaStep starting at ``text[pos]``.

        Returns:
            tuple[MetaStep, int]: The parsed step and the index just past its
            closing ``)``.

        Raises:
            ValueError: If ``text`` does not contain a well-formed MetaStep at ``pos``.
        """
        match = MetaStep._STR_HEAD.match(text, pos)
        if match is None:
            raise ValueError(f"Cannot parse MetaStep from {text!r} at position {pos}.")
        type_name = match.group("type")
        try:
            meta_type = MetaStepType[type_name]
        except KeyError as err:
            raise ValueError(f"Unknown MetaStepType {type_name!r} in {text!r}.") from err
        micro_index = None if match.group("micro") == "None" else int(match.group("micro"))
        stage_index = None if match.group("stage") == "None" else int(match.group("stage"))
        pos = match.end()

        sub_steps = None
        if text.startswith(MetaStep._SUB_STEPS_PREFIX, pos):
            pos += len(MetaStep._SUB_STEPS_PREFIX)
            children = []
            while True:
                child, pos = MetaStep._parse_at(text, pos)
                children.append(child)
                if text.startswith(", ", pos):
                    pos += 2
                elif text.startswith("]", pos):
                    pos += 1
                    break
                else:
                    raise ValueError(f"Malformed sub_steps in {text!r} at position {pos}.")
            sub_steps = tuple(children)

        if not text.startswith(")", pos):
            raise ValueError(f"Expected ')' in {text!r} at position {pos}.")
        return MetaStep(micro_index, meta_type, stage_index, sub_steps), pos + 1

    @staticmethod
    def from_str(step_str):
        """Build a MetaStep from the text produced by ``str(MetaStep)``.

        Surrounding whitespace is ignored. The ``MetaStepType.`` prefix of the
        type name is optional.

        Args:
            step_str (str): e.g. ``"MetaStep(type=MetaStepType.FWD, micro_index=0, stage_index=1)"``.

        Returns:
            MetaStep: The parsed step, including nested ``sub_steps``.

        Raises:
            ValueError: If ``step_str`` is not in the expected format.
        """
        text = step_str.strip()
        step, end = MetaStep._parse_at(text, 0)
        if end != len(text):
            raise ValueError(f"Unexpected trailing characters in {step_str!r}.")
        return step
def generate_stage_to_rank_mapping(real_stage_num, stage_num, style='loop'):
    """Map every virtual stage index to the physical rank that runs it.

    * ``'loop'``: stage ``s`` runs on rank ``s % real_stage_num``.
    * ``'v'``: stages are grouped into chunks of ``real_stage_num``. Even
      chunks are laid out on ranks ``0 .. real_stage_num - 1`` and odd chunks
      in reverse order, so consecutive chunks meet on the same rank.

    Args:
        real_stage_num (int): Number of physical ranks.
        stage_num (int): Total number of virtual stages.
        style (str): ``'loop'`` or ``'v'``.

    Returns:
        dict[int, int]: ``{stage_index: rank}`` for every stage.

    Raises:
        ValueError: For an unknown ``style``, or for ``'v'`` when ``stage_num``
            is not a multiple of ``real_stage_num``.
    """
    if style == 'loop':
        return {stage: stage % real_stage_num for stage in range(stage_num)}
    if style != 'v':
        raise ValueError(f"Unsupported stage rank mapping style: {style}")
    if stage_num % real_stage_num != 0:
        raise ValueError(
            f"stage_num {stage_num} must be evenly divisible by real_stage_num {real_stage_num} for V schedules."
        )
    placement = {}
    for stage in range(stage_num):
        chunk, offset = divmod(stage, real_stage_num)
        # Odd chunks walk back down the ranks.
        placement[stage] = offset if chunk % 2 == 0 else real_stage_num - 1 - offset
    return placement
def generate_rank_to_stage_mapping(real_stage_num, stage_num, style='loop'):
    """Return ``{rank: [stage_index, ...]}``, the inverse of
    :func:`generate_stage_to_rank_mapping`, with each rank's stages in
    ascending order."""
    grouped = {}
    for stage_index, rank in generate_stage_to_rank_mapping(real_stage_num, stage_num, style).items():
        grouped.setdefault(rank, []).append(stage_index)
    return {rank: sorted(stage_list) for rank, stage_list in grouped.items()}
def iter_leaf_meta_steps(step):
    """Yield the leaf (non-composite) MetaSteps contained in ``step``.

    Composite ``OVERLAP_F_B`` and ``OVERLAP_B_F`` steps are expanded
    recursively into their ``sub_steps``, in stored order. Any other step is
    yielded as-is. ``None`` yields nothing, and so does a composite step that
    has no sub-steps.
    """
    if step is None:
        return
    if step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F):
        # Both overlap flavors are containers. Treating OVERLAP_B_F as a leaf
        # would hand callers a step whose micro/stage index is None.
        for sub_step in step.sub_steps or ():
            yield from iter_leaf_meta_steps(sub_step)
        return
    yield step
class PipelineContext:
    """Execution context handed to custom step handlers.

    :meth:`PipelineScheduleRuntime.run_microbatches` builds at most one
    instance per run. It passes that instance as the second argument to every
    function registered through
    :meth:`PipelineScheduleRuntime.register_custom_function`. Handlers use it
    to reach the schedule (e.g. ``forward_one_chunk`` / ``backward_one_chunk``
    on its stages), the micro-batch inputs, the loss list and the P2P handle
    bookkeeping.

    Attributes:
        schedule: The owning :class:`PipelineScheduleRuntime`.
        arg_mbs: Per-micro-batch positional args.
        kwarg_mbs: Per-micro-batch keyword args.
        losses: Mutable list that collects losses.
        fwd_recv_ops: ``{(stage_index, micro_index): [handle, ...]}`` of
            pending forward recv handles (populated when ``overlap_p2p=True``).
        bwd_recv_ops: The same, for backward recv handles.
        send_handles: Mutable list of outstanding send handles.
    """

    def __init__(self, schedule, arg_mbs, kwarg_mbs, losses, send_handles):
        # Run state supplied by the caller, shared by reference (not copied).
        self.schedule = schedule
        self.arg_mbs, self.kwarg_mbs = arg_mbs, kwarg_mbs
        self.losses = losses
        self.send_handles = send_handles
        # Aliases of the schedule's own recv-handle caches, so entries a handler
        # pops are no longer waited on by the built-in FWD/BWD logic.
        self.fwd_recv_ops = schedule.fwd_handle_cache
        self.bwd_recv_ops = schedule.bwd_handle_cache
class PipelineScheduleRuntime(ABC):
    """
    Base class for pipeline schedules.

    Provides micro-batch splitting (:meth:`split_microbatches`), the
    :meth:`run` entry point, and a generic executor (:meth:`run_microbatches`).
    The executor walks this rank's entry in ``self.exec_order`` step by step.
    Derived classes implement :meth:`_build_stage_to_rank_index` and
    :meth:`construct_exec_order`, and call :meth:`build_exec_order` once
    construction is complete (as ``ScheduleGPipe`` does at the end of its
    ``__init__``).

    Supports registering **custom execution functions** for any
    :class:`MetaStepType` via :meth:`register_custom_function`. When
    ``run_microbatches`` encounters a step whose type has a registered
    handler, it creates a :class:`PipelineContext` and delegates execution
    to the handler instead of using the built-in logic.

    Args:
        stages (list[PipelineStage], PipelineStage): PipelineStage used to run_microbatches.
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Default ``None``.
        output_concat_dim (optional): Stored as ``self._output_concat_dim``; not
            read by this base class. Default ``None``.
        overlap_p2p (bool, optional): If ``True``, P2P handles are not waited
            right away. Recv handles are cached and waited just before the
            consuming FWD/BWD step; send handles are waited at the end of
            :meth:`run_microbatches`. Default ``False``.

    Raises:
        TypeError: If ``stages`` is not a PipelineStage or a list/tuple of them.
        ValueError: If ``stages`` is an empty list/tuple.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False):
        self.stages = self._check_stages(stages)
        self.micro_batch_num = micro_batch_num
        self._args_batch_dim = args_batch_dim
        self._kwargs_batch_dim = kwargs_batch_dim
        self._output_concat_dim = output_concat_dim
        self.split_micro_batch = platform.micro_batch(self.micro_batch_num,
                                                      self._args_batch_dim, self._kwargs_batch_dim)
        self.n_local_stages = len(self.stages)
        self._stage_dict = self.convert_stages_dict()
        # Assumes every rank holds the same number of local stages.
        self.real_stage_num = self.stages[0].stage_num // self.n_local_stages
        self._stage_num = self.stages[0].stage_num
        self._stage_to_rank_index = None
        self._overlap_p2p = overlap_p2p
        self.exec_order = {}
        self._init_stages()
        self._build_stage_to_rank_index()
        self.fwd_handle_cache = {}
        self.bwd_handle_cache = {}
        self._custom_fn_map = {}

    def register_custom_function(self, step_type: MetaStepType, fn) -> None:
        """Register a custom execution function for the given step type.

        When :meth:`run_microbatches` encounters a :class:`MetaStep` whose
        ``type`` matches ``step_type``, it calls ``fn(step, ctx)`` instead
        of the built-in logic. Registering again for the same type replaces
        the previous handler.

        Args:
            step_type: The :class:`MetaStepType` to intercept.
            fn: A callable with signature ``(step: MetaStep, ctx: PipelineContext) -> None``.

        Example:
            >>> def my_overlap_callback(step, ctx):
            ...     fwd_step, bwd_step = step.sub_steps
            ...     # custom parallel execution logic
            >>> schedule.register_custom_function(MetaStepType.OVERLAP_F_B, my_overlap_callback)
        """
        self._custom_fn_map[step_type] = fn

    def _resolve_local_rank(self):
        """Return the rank whose ``exec_order`` entry this process executes.

        Uses ``_stage_to_rank_index`` so that V-style placements whose first
        local stage lies in a later chunk still resolve to the correct rank.
        Falls back to ``stage_index % real_stage_num`` when no mapping entry
        exists.
        """
        first_stage_index = self.stages[0].stage_index
        mapping = self._stage_to_rank_index
        if mapping and first_stage_index in mapping:
            return mapping[first_stage_index]
        return first_stage_index % self.real_stage_num

    def _inject_local_fsdp_actions(self):
        """Annotate the local rank schedule with optional FSDP control actions.

        A no-op when no local stage wraps an ``HSDPModule``.

        Raises:
            RuntimeError: If only some of the local stages are ``HSDPModule``.
        """
        current_rank = self._resolve_local_rank()
        managed_stage_indices = {
            stage.stage_index
            for stage in self.stages
            if isinstance(stage.submodule, HSDPModule)
        }
        if not managed_stage_indices:
            return
        if len(managed_stage_indices) != len(self.stages):
            raise RuntimeError(
                "When injecting fsdp_action, expect all stages to be HSDPModule. "
                "Check whether all separated modules are wrapped with 'fully_shard'."
            )
        rank_actions = add_fsdp_unshard_reshard(self.exec_order[current_rank], managed_stage_indices)
        self.exec_order[current_rank] = add_fsdp_reduce_grad(
            rank_actions,
            managed_stage_indices,
            self.micro_batch_num,
        )

    @abstractmethod
    def _build_stage_to_rank_index(self) -> None:
        """
        Build attribute of _stage_to_rank_index.
        Each subclass constructs it according to its own schedule style.
        """

    @abstractmethod
    def construct_exec_order(self) -> None:
        """Build exec order, PP compute and PP comms(Send/Recv)"""

    def build_exec_order(self) -> None:
        """Build the execution order and inject FSDP actions."""
        self.construct_exec_order()
        self._inject_local_fsdp_actions()

    def convert_stages_dict(self):
        """Return ``{stage_index: stage}`` for the local stages."""
        stage_dict = {}
        for stage in self.stages:
            stage_dict[stage.stage_index] = stage
        return stage_dict

    def split_microbatches(self, args, kwargs):
        """Split ``args``/``kwargs`` into per-micro-batch lists.

        With no inputs at all, returns ``micro_batch_num`` empty arg lists and
        kwarg dicts.
        """
        if args or kwargs:
            args_split, kwargs_split = self.split_micro_batch(args, kwargs)
            return args_split, kwargs_split
        return [[] for _ in range(self.micro_batch_num)], [{} for _ in range(self.micro_batch_num)]

    def _check_stages(self, stages):
        """Validate ``stages`` and return it as a list or tuple of PipelineStage."""
        if isinstance(stages, hyper_parallel.PipelineStage):
            return [stages]
        if not isinstance(stages, (list, tuple)):
            raise TypeError(
                "Argument 'stages' must be type of PipelineStage, list or tuple of "
                f"PipelineStage, but got type of {type(stages)}."
            )
        for stage in stages:
            if not isinstance(stage, hyper_parallel.PipelineStage):
                raise TypeError(
                    "Argument 'stages' must be type of PipelineStage, list or tuple of "
                    f"PipelineStage, but got list or tuple of {type(stage)}."
                )
        if not stages:
            # The constructor indexes stages[0]; fail with a clear message instead.
            raise ValueError("Argument 'stages' must contain at least one PipelineStage.")
        return stages

    def _init_stages(self):
        """Initialize every local stage with the local stage count."""
        for stage in self.stages:
            stage.init(self.n_local_stages)

    def run(self, *args, **kwargs):
        """Split the inputs into micro-batches, run the schedule and return the losses."""
        split_args, split_kwargs = self.split_microbatches(args, kwargs)
        losses = []
        self.run_microbatches(split_args, split_kwargs, losses)
        return losses

    def sync_shared_parameters_grad(self):
        """Delegate shared-parameter gradient sync to every local stage."""
        for stage in self.stages:
            stage.sync_shared_parameters_grad()

    def update_losses(self, stage, loss, losses):
        """Append ``loss`` to ``losses`` only when ``stage`` is the last stage."""
        if stage.is_last_stage:
            losses.append(loss)

    def _wait_p2p(self, handles):
        """Wait on every non-``None`` handle in ``handles``."""
        for handle in handles:
            if handle is not None:
                handle.wait()

    def _assert_in_unshard_if_needed(self, stage, check_step):
        """Raise if ``stage`` is an HSDPModule whose parameters are still sharded."""
        if not isinstance(stage.submodule, HSDPModule):
            return
        submodule_hsdp_scheduler = stage.submodule.hsdp_scheduler
        scheduler_state = submodule_hsdp_scheduler.hsdp_state
        if scheduler_state.is_shard:
            raise RuntimeError(
                f"Executing MetaStep: {check_step}, expected HSDPModule parameters in unsharded "
                f"state, but got sharded parameters."
            )

    def _exec_step(self, cur_step, arg_mbs, kwarg_mbs, losses, send_handles):
        """Execute a single built-in step (FWD/BWD/SEND/RECV/FSDP).

        Raises:
            NotImplementedError: For step types without built-in logic
                (e.g. ``BWD_INPUT`` / ``BWD_WEIGHT``). Previously these were
                silently skipped, which dropped work from the schedule.
        """
        stage = self._stage_dict[cur_step.stage_index]
        stage_index = cur_step.stage_index
        micro_index = cur_step.micro_index

        if cur_step.type == MetaStepType.FWD_RECV:
            comm_handle = stage.exec_fwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                self.fwd_handle_cache[(stage_index, micro_index)] = comm_handle

        elif cur_step.type == MetaStepType.FWD:
            self._assert_in_unshard_if_needed(stage, cur_step)
            key = (stage_index, micro_index)
            if self._overlap_p2p and key in self.fwd_handle_cache:
                self._wait_p2p(self.fwd_handle_cache.pop(key))
            out = stage.forward_one_chunk(micro_index, arg_mbs[micro_index], kwarg_mbs[micro_index])
            self.update_losses(stage, out, losses)

        elif cur_step.type == MetaStepType.FWD_SEND:
            comm_handle = stage.exec_fwd_send_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                send_handles.append(comm_handle)

        elif cur_step.type == MetaStepType.BWD_RECV:
            comm_handle = stage.exec_bwd_recv_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                self.bwd_handle_cache[(stage_index, micro_index)] = comm_handle

        elif cur_step.type == MetaStepType.BWD:
            self._assert_in_unshard_if_needed(stage, cur_step)
            key = (stage_index, micro_index)
            if self._overlap_p2p and key in self.bwd_handle_cache:
                self._wait_p2p(self.bwd_handle_cache.pop(key))
            is_last_microbatch = micro_index == self.micro_batch_num - 1
            stage.backward_one_chunk(micro_index, is_last_microbatch)

        elif cur_step.type == MetaStepType.BWD_SEND:
            comm_handle = stage.exec_bwd_send_ops(micro_index)
            if not self._overlap_p2p:
                self._wait_p2p(comm_handle)
            else:
                send_handles.append(comm_handle)

        elif cur_step.type in (
                MetaStepType.FSDP_UNSHARD,
                MetaStepType.FSDP_RESHARD,
                MetaStepType.FSDP_REDUCE_GRAD,
        ):
            self._exec_fsdp_step(cur_step, stage)

        else:
            raise NotImplementedError(
                f"MetaStep type {cur_step.type} has no built-in execution logic; "
                f"register a handler via register_custom_function(). Step: {cur_step}"
            )

    def _exec_fsdp_step(self, cur_step, stage):
        """Execute an FSDP control step (unshard, reshard, or reduce-grad)."""
        if cur_step.type == MetaStepType.FSDP_UNSHARD:
            for _, module in platform.get_cells_and_names(stage.submodule):
                if isinstance(module, HSDPModule):
                    module.unshard()
        elif cur_step.type == MetaStepType.FSDP_RESHARD:
            for _, module in platform.get_cells_and_names(stage.submodule):
                if isinstance(module, HSDPModule):
                    module.reshard()
        elif cur_step.type == MetaStepType.FSDP_REDUCE_GRAD:
            stage.execute_reduce_grad()

    def run_microbatches(self, arg_mbs, kwarg_mbs, losses):
        """Execute the schedule step by step.

        Steps whose :attr:`MetaStep.type` has a registered custom function
        are delegated to that function with a :class:`PipelineContext`.
        Composite ``OVERLAP_F_B`` / ``OVERLAP_B_F`` steps without a
        registered handler fall back to executing their ``sub_steps``
        sequentially via :meth:`_exec_step`. This is correct but gives no
        comm/compute overlap. All other steps are executed by
        :meth:`_exec_step`. Outstanding send handles are waited at the end.
        """
        local_rank = self._resolve_local_rank()
        send_handles = []
        ctx = None  # lazily created

        for cur_step in self.exec_order[local_rank]:
            if cur_step is None:
                continue

            # Check for registered custom function
            custom_fn = self._custom_fn_map.get(cur_step.type)
            if custom_fn is not None:
                if ctx is None:
                    ctx = PipelineContext(self, arg_mbs, kwarg_mbs, losses, send_handles)
                custom_fn(cur_step, ctx)
                continue

            # Default for composite OVERLAP steps: run sub_steps sequentially.
            # P2P send/recv around these steps are already laid out in two
            # virtual slots by ``add_send_recv``, so sequential execution is
            # semantically equivalent to non-overlapped 1F1B.
            if (cur_step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
                    and cur_step.sub_steps):
                for sub in cur_step.sub_steps:
                    self._exec_step(sub, arg_mbs, kwarg_mbs, losses, send_handles)
                continue

            self._exec_step(cur_step, arg_mbs, kwarg_mbs, losses, send_handles)

        self.sync_shared_parameters_grad()
        while send_handles:
            self._wait_p2p(send_handles.pop())
475class _OverlapPhantom:
476 """Internal marker used by :func:`add_send_recv` to expand an
477 ``OVERLAP_F_B`` or ``OVERLAP_B_F`` step into two virtual time slots.
479 An overlap step composes two sub-steps (``B + F`` or ``F + B``) that
480 execute concurrently on the GPU but occupy **two** logical time slots
481 in the column-scan sender timeline — the sender can only finish
482 emitting the second sub-step's output after the first sub-step has
483 completed. Treating an overlap step as a single slot places the RECV
484 triggered by the second sub-step too early on the receiver.
486 Each overlap step is expanded into two phantoms:
487 * ``is_first_half=True`` — represents the first sub-step's emission
488 slot; the original overlap step is emitted into the output
489 schedule here (only once).
490 * ``is_first_half=False`` — represents the second sub-step's emission
491 slot; only its send/recv comms are inserted.
492 """
494 __slots__ = ('obf_step', 'sub_step', 'is_first_half')
496 def __init__(self, obf_step, sub_step, is_first_half: bool):
497 self.obf_step = obf_step
498 self.sub_step = sub_step
499 self.is_first_half = is_first_half
def _expand_overlap_slots(scheduler, real_stage_num):
    """Give every OVERLAP step in a per-rank schedule two virtual time slots.

    Args:
        scheduler: ``{rank: [MetaStep | None, ...]}``.
        real_stage_num: Number of physical ranks (ranks ``0 .. n - 1`` are read).

    Returns:
        A new ``{rank: [MetaStep | _OverlapPhantom | None, ...]}`` dict. Each
        ``OVERLAP_F_B`` / ``OVERLAP_B_F`` step that has ``sub_steps`` is
        replaced by a first-half and a second-half phantom; all other entries
        (including ``None`` bubbles) are copied through unchanged.
    """
    overlap_types = (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F)
    result = {}
    for rank in range(real_stage_num):
        slots = []
        for entry in scheduler[rank]:
            is_overlap = entry is not None and entry.type in overlap_types and entry.sub_steps
            if not is_overlap:
                slots.append(entry)
                continue
            slots.append(_OverlapPhantom(entry, entry.sub_steps[0], is_first_half=True))
            slots.append(_OverlapPhantom(entry, entry.sub_steps[1], is_first_half=False))
        result[rank] = slots
    return result
def _process_rank_items(real_stage_num, current_items, insert_step_comms, new_schedule):
    """Feed each rank's item for the current time slot to ``insert_step_comms``.

    All even ranks are handled before any odd rank, which keeps adjacent
    ranks' P2P operations from deadlocking. Ranks whose item is ``None``
    (bubble, or schedule already exhausted) are skipped. An
    ``_OverlapPhantom`` contributes its underlying ``sub_step``.
    """
    ordered_ranks = [*range(0, real_stage_num, 2), *range(1, real_stage_num, 2)]
    for rank in ordered_ranks:
        entry = current_items.get(rank)
        if entry is None:
            continue
        step = entry.sub_step if isinstance(entry, _OverlapPhantom) else entry
        insert_step_comms(step, rank, new_schedule)
def _column_scan_insert_comms(expanded, real_stage_num, insert_step_comms):
    """Walk an OVERLAP-expanded schedule column by column, inserting SEND/RECV.

    For each time slot, every rank first writes its own entry to the output:
      * a ``None`` bubble is kept as ``None``, so per-rank slot indexing stays
        aligned. The runtime skips ``None``, so this does not change execution.
      * a plain step is copied as-is.
      * a first-half phantom writes its original overlap step; a second-half
        phantom writes nothing.
    Then :func:`_process_rank_items` asks ``insert_step_comms`` to append the
    communication for each rank's step (or the phantom's sub-step), with even
    ranks handled before odd ranks to avoid adjacent-rank P2P deadlocks.

    Args:
        expanded: Result of :func:`_expand_overlap_slots`.
        real_stage_num: Number of physical ranks.
        insert_step_comms: Callable ``(step, rank, new_schedule) -> None``
            that inserts SEND/RECV for a single FWD/BWD step.

    Returns:
        ``{rank: [MetaStep, ...]}`` final schedule.
    """
    horizon = max(len(order) for order in expanded.values())
    new_schedule = {rank: [] for rank in range(real_stage_num)}

    for time_step in range(horizon):
        column = {}
        for rank in range(real_stage_num):
            order = expanded[rank]
            if time_step >= len(order):
                # This rank's schedule is exhausted.
                column[rank] = None
                continue
            entry = order[time_step]
            column[rank] = entry
            if entry is None:
                new_schedule[rank].append(None)
            elif not isinstance(entry, _OverlapPhantom):
                new_schedule[rank].append(entry)
            elif entry.is_first_half:
                new_schedule[rank].append(entry.obf_step)
        _process_rank_items(real_stage_num, column, insert_step_comms, new_schedule)

    return new_schedule
def add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """Insert P2P send/recv operations into a per-rank compute schedule.

    For each FWD or BWD step that requires cross-rank communication, a
    ``FWD_SEND`` / ``BWD_SEND`` is appended to the sender's schedule and a
    ``FWD_RECV`` / ``BWD_RECV`` is appended to the receiver's schedule.

    ``OVERLAP_F_B`` / ``OVERLAP_B_F`` composite steps are expanded into
    **two** virtual time slots during the column scan. The RECV triggered by
    the **second** sub-step then lands in the receiver's schedule one slot
    later, since the sender can only finish emitting the second sub-step's
    output after the first completes.

    Even ranks are processed before odd ranks at each time step to avoid
    P2P deadlocks between adjacent ranks.

    Stage placement follows :func:`generate_stage_to_rank_mapping`, the same
    mapping the runtime uses. ``'v'`` therefore also supports more than two
    chunks per rank; the previous ``stage_num - 1 - stage_index`` formula was
    only valid for exactly two chunks.

    Args:
        scheduler: ``{rank: [MetaStep | None, ...]}``, the compute schedule
            with ``None`` for bubble slots.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        style: Topology mapping, ``'loop'`` or ``'v'``.

    Returns:
        ``{rank: [MetaStep, ...]}``: the schedule with communication ops inserted.

    Raises:
        ValueError: For an unknown ``style``, or for ``'v'`` when ``stage_num``
            is not a multiple of ``real_stage_num``.
    """
    if style not in ('loop', 'v'):
        raise ValueError(f"Argument 'style' must be 'loop' or 'v', but got {style!r}.")
    rank_of_stage = generate_stage_to_rank_mapping(real_stage_num, stage_num, style)

    def _fwd_peer(stage_index: int):
        """Return the rank that receives this stage's forward output, or None."""
        if stage_index >= stage_num - 1:
            return None
        peer = rank_of_stage[stage_index + 1]
        return peer if peer != rank_of_stage[stage_index] else None

    def _bwd_peer(stage_index: int):
        """Return the rank that receives this stage's backward gradient, or None."""
        if stage_index <= 0:
            return None
        peer = rank_of_stage[stage_index - 1]
        return peer if peer != rank_of_stage[stage_index] else None

    def _insert_comms_for_step(step, rank, new_schedule):
        """Insert send/recv for a single FWD, BWD, or composite OVERLAP step."""
        if step is None:
            return

        if step.type == MetaStepType.FWD:
            peer = _fwd_peer(step.stage_index)
            if peer is not None:
                new_schedule[rank].append(
                    MetaStep(step.micro_index, MetaStepType.FWD_SEND, step.stage_index))
                new_schedule[peer].append(
                    MetaStep(step.micro_index, MetaStepType.FWD_RECV, step.stage_index + 1))

        elif step.type == MetaStepType.BWD:
            peer = _bwd_peer(step.stage_index)
            if peer is not None:
                new_schedule[rank].append(
                    MetaStep(step.micro_index, MetaStepType.BWD_SEND, step.stage_index))
                new_schedule[peer].append(
                    MetaStep(step.micro_index, MetaStepType.BWD_RECV, step.stage_index - 1))

        elif step.type in (MetaStepType.OVERLAP_F_B, MetaStepType.OVERLAP_B_F) and step.sub_steps:
            for sub in step.sub_steps:
                _insert_comms_for_step(sub, rank, new_schedule)

    # --- Main logic: expand OVERLAP steps into 2 virtual slots, then scan ---
    expanded = _expand_overlap_slots(scheduler, real_stage_num)
    return _column_scan_insert_comms(expanded, real_stage_num, _insert_comms_for_step)
675_ALIGN_PAD = object()
676"""Sentinel marking a forced 1F1B-boundary bubble produced during alignment."""
def _step_dep_ready(step, rank, t, done, stage_num, stage_to_rank):
    """Report whether ``step`` may run on ``rank`` at time ``t``.

    FWD at stage ``s`` needs FWD of stage ``s - 1``; BWD at stage ``s`` needs
    BWD of stage ``s + 1``. The producer must have completed strictly before
    ``t``, according to ``done`` (``{(type, stage, micro): time}``). A step is
    always ready when it is at the pipeline boundary (first stage for FWD,
    last for BWD), when its producer lives on the same rank, or when it is
    neither FWD nor BWD.
    """
    stage, micro = step.stage_index, step.micro_index
    if step.type == MetaStepType.FWD:
        producer, edge_stage = stage - 1, 0
    elif step.type == MetaStepType.BWD:
        producer, edge_stage = stage + 1, stage_num - 1
    else:
        return True
    if stage == edge_stage or stage_to_rank(producer) == rank:
        return True
    finished_at = done.get((step.type, producer, micro))
    return finished_at is not None and finished_at < t
def _simulate_aligned_schedule(padded, stage_num, real_stage_num, stage_to_rank):
    """Simulate execution time-step by time-step, inserting bubbles where
    a step is not yet ready (cross-rank dep) or where the cooldown
    rhythm requires it.

    At every time step, each rank looks at the next entry of its padded
    order. It either emits that entry or a ``None`` bubble. A bubble is
    emitted for an ``_ALIGN_PAD`` / ``None`` entry (which is then consumed),
    for a cooldown-rhythm skip, or for an unmet cross-rank dependency. A step
    completed at time ``t`` only unblocks consumers from ``t + 1`` on, so the
    order in which ranks are visited within one time step does not matter.

    Args:
        padded: ``{rank: [step | _ALIGN_PAD | None, ...]}`` after
            1F1B-boundary padding.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        stage_to_rank: Topology mapping from stage to rank.

    Returns:
        ``{rank: [step | None, ...]}`` ready for the column-scan SEND/RECV
        insertion phase.

    Raises:
        RuntimeError: If some steps cannot be placed within the time budget
            (e.g. the per-rank orders contain a cyclic cross-rank dependency).
            Previously the truncated schedule was returned silently, dropping
            the unplaced steps.
    """
    def _is_step(entry):
        # A real compute step, as opposed to an alignment pad or a bubble.
        return entry is not _ALIGN_PAD and entry is not None

    remaining_fwd = {
        rank: sum(
            1 for s in padded[rank]
            if _is_step(s) and s.type == MetaStepType.FWD
        )
        for rank in range(real_stage_num)
    }
    cursors = {r: 0 for r in range(real_stage_num)}
    aligned = {r: [] for r in range(real_stage_num)}
    done = {}
    last_was_cooldown_bwd = {r: False for r in range(real_stage_num)}
    max_t = sum(len(v) for v in padded.values()) + real_stage_num * 20

    def _emit_bubble(rank):
        aligned[rank].append(None)
        last_was_cooldown_bwd[rank] = False

    def _emit_step(rank, step, t, in_cooldown):
        aligned[rank].append(step)
        done[(step.type, step.stage_index, step.micro_index)] = t
        cursors[rank] += 1
        if step.type == MetaStepType.FWD:
            remaining_fwd[rank] -= 1
        last_was_cooldown_bwd[rank] = in_cooldown and step.type == MetaStepType.BWD

    def _step_rank_at(t, rank):
        if cursors[rank] >= len(padded[rank]):
            return
        item = padded[rank][cursors[rank]]
        if not _is_step(item):
            # Forced padding (or an explicit bubble) occupies exactly one slot.
            _emit_bubble(rank)
            cursors[rank] += 1
            return
        in_cooldown = remaining_fwd[rank] == 0
        # Cooldown rhythm: alternate None / BWD in pure-BWD phase.
        cooldown_skip = (
            in_cooldown
            and item.type == MetaStepType.BWD
            and last_was_cooldown_bwd[rank]
        )
        if cooldown_skip:
            _emit_bubble(rank)
            return
        if not _step_dep_ready(item, rank, t, done, stage_num, stage_to_rank):
            _emit_bubble(rank)
            return
        _emit_step(rank, item, t, in_cooldown)

    def _all_consumed():
        return all(cursors[r] >= len(padded[r]) for r in range(real_stage_num))

    for t in range(max_t):
        if _all_consumed():
            break
        for rank in range(real_stage_num):
            _step_rank_at(t, rank)

    if not _all_consumed():
        pending = {
            rank: padded[rank][cursors[rank]]
            for rank in range(real_stage_num)
            if cursors[rank] < len(padded[rank])
        }
        raise RuntimeError(
            f"Pipeline alignment did not converge within {max_t} time steps; "
            f"next unplaced entry per rank: {pending}. Check that the per-rank "
            f"compute orders have no cyclic cross-rank dependencies."
        )
    return aligned
def auto_align_and_add_send_recv(scheduler, stage_num, real_stage_num, style='loop'):
    """Auto-insert bubble alignment and P2P send/recv into a pure-compute schedule.

    Unlike :func:`add_send_recv`, which requires the caller to pre-insert
    ``None`` bubble slots for time-step alignment, this function accepts a
    **pure compute order** (``FWD`` / ``BWD`` only, no ``None`` needed). It
    determines bubble placement automatically via execution simulation.

    Three constraints are enforced:

    1. **Data dependency**: a ``FWD(stage_k)`` cannot execute until
       ``FWD(stage_{k-1})`` on its source rank has completed (and
       analogously for ``BWD``).
    2. **1F1B transition alignment**: ``real_stage_num - 1 - rank`` padding
       slots are inserted at the warmup → 1F1B boundary, so that all ranks
       enter the 1F1B steady state in lockstep. The boundary is detected as
       the first ``FWD`` immediately followed by a ``BWD`` in the compute
       order.
    3. **Cooldown rhythm**: once a rank exhausts its ``FWD`` ops and enters
       the pure-``BWD`` cooldown, consecutive ``BWD`` steps are separated by
       a ``None`` slot. This keeps the column-phase-sync property: no rank
       does ``BWD`` while another does ``FWD`` at the same time step.

    After alignment, :func:`add_send_recv` inserts ``FWD_SEND`` / ``FWD_RECV``
    and ``BWD_SEND`` / ``BWD_RECV``.

    Stage placement follows :func:`generate_stage_to_rank_mapping`, so ``'v'``
    also supports more than two chunks per rank.

    Args:
        scheduler: ``{rank: [MetaStep, ...]}``, a pure compute schedule.
            ``None`` entries are silently stripped before processing.
        stage_num: Total number of virtual pipeline stages.
        real_stage_num: Number of physical ranks.
        style: Topology mapping, ``'loop'`` or ``'v'``.

    Returns:
        ``{rank: [MetaStep, ...]}``: a fully aligned schedule with bubbles
        and communication ops inserted.

    Raises:
        ValueError: For an unknown ``style``, or for ``'v'`` when ``stage_num``
            is not a multiple of ``real_stage_num``.
        RuntimeError: If the alignment simulation cannot place every step.
    """
    if style not in ('loop', 'v'):
        raise ValueError(f"Argument 'style' must be 'loop' or 'v', but got {style!r}.")
    # Same placement as the runtime; the old 'v' formula
    # (stage_num - 1 - stage_index) was only valid for two chunks.
    rank_of_stage = generate_stage_to_rank_mapping(real_stage_num, stage_num, style)

    def stage_to_rank(stage_index: int) -> int:
        return rank_of_stage[stage_index]

    # ---- Phase 1: strip None, detect 1F1B boundary, insert transition padding ----

    def _find_1f1b_boundary(order):
        """Index of the first FWD followed by BWD; ``len(order)`` if absent."""
        for i in range(len(order) - 1):
            if (order[i].type == MetaStepType.FWD
                    and order[i + 1].type == MetaStepType.BWD):
                return i
        return len(order)

    padded = {}
    for rank in range(real_stage_num):
        order = [s for s in scheduler[rank] if s is not None]
        boundary = _find_1f1b_boundary(order)
        pad_count = real_stage_num - 1 - rank
        padded[rank] = order[:boundary] + [_ALIGN_PAD] * pad_count + order[boundary:]

    # ---- Phase 2: simulate execution with data deps + cooldown rhythm ----

    aligned = _simulate_aligned_schedule(padded, stage_num, real_stage_num, stage_to_rank)

    # ---- Phase 3: column-scan SEND/RECV insertion ----
    # ``aligned`` has exactly the bubble-annotated shape add_send_recv expects.
    # Reusing it keeps a single implementation of the OVERLAP expansion and
    # the comm insertion.
    return add_send_recv(aligned, stage_num, real_stage_num, style)
class ScheduleGPipe(PipelineScheduleRuntime):
    """
    GPipe pipeline schedule.

    Every rank runs the forward pass of all micro batches first and only then
    runs the backward pass of all micro batches, in the same micro order.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(
            stages,
            micro_batch_num,
            args_batch_dim=args_batch_dim,
            kwargs_batch_dim=kwargs_batch_dim,
            output_concat_dim=output_concat_dim,
        )
        self.build_exec_order()

    def _build_stage_to_rank_index(self) -> None:
        # GPipe places stages on ranks cyclically ('loop' mapping).
        mapping = generate_stage_to_rank_mapping(
            self.real_stage_num, self._stage_num, style='loop'
        )
        self._stage_to_rank_index = mapping

    def construct_exec_order(self):
        """Fill ``self.exec_order`` with the GPipe step list of every rank."""

        def sweep(stage_index, recv_type, compute_type, send_type, with_recv, with_send):
            # One pass over all micro batches: optional RECV, compute, optional SEND.
            seq = []
            for micro in range(self.micro_batch_num):
                if with_recv:
                    seq.append(MetaStep(micro, recv_type, stage_index))
                seq.append(MetaStep(micro, compute_type, stage_index))
                if with_send:
                    seq.append(MetaStep(micro, send_type, stage_index))
            return seq

        final_stage = self.real_stage_num - 1
        for stage_index in range(self.real_stage_num):
            is_first = stage_index == 0
            is_last = stage_index == final_stage
            # Forwards flow first -> last; the first stage receives nothing and
            # the last stage sends nothing forward.
            forward_part = sweep(stage_index,
                                 MetaStepType.FWD_RECV, MetaStepType.FWD, MetaStepType.FWD_SEND,
                                 not is_first, not is_last)
            # Backwards flow last -> first, so the roles are mirrored.
            backward_part = sweep(stage_index,
                                  MetaStepType.BWD_RECV, MetaStepType.BWD, MetaStepType.BWD_SEND,
                                  not is_last, not is_first)
            self.exec_order[stage_index] = forward_part + backward_part
class Schedule1F1B(PipelineScheduleRuntime):
    """
    One-forward-one-backward (1F1B) pipeline schedule.

    Each rank runs a warmup of forwards, then a steady state that alternates
    one backward with one forward, and finally a cooldown that drains the
    remaining backwards.
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None):
        super().__init__(
            stages,
            micro_batch_num,
            args_batch_dim=args_batch_dim,
            kwargs_batch_dim=kwargs_batch_dim,
            output_concat_dim=output_concat_dim,
        )
        self.build_exec_order()

    def _build_stage_to_rank_index(self) -> None:
        # Stages are assigned to ranks cyclically ('loop' mapping).
        mapping = generate_stage_to_rank_mapping(
            self.real_stage_num, self._stage_num, style='loop'
        )
        self._stage_to_rank_index = mapping

    def construct_exec_order(self):
        """Fill ``self.exec_order`` with the 1F1B step list of every rank."""
        num_stages = self.real_stage_num
        num_micro = self.micro_batch_num
        final_stage = num_stages - 1
        for stage_index in range(num_stages):
            steps = []
            emit = steps.append
            has_prev = stage_index != 0            # receives activations / sends grads back
            has_next = stage_index != final_stage  # sends activations / receives grads
            next_fwd = 0
            next_bwd = 0

            # ---- warmup: forwards only ----
            warmup = min(num_stages - stage_index, num_micro)
            # Odd stages emit the pending FWD_SEND of the previous micro before
            # computing; even stages compute first and send the same micro right
            # after (except the final warmup micro). NOTE(review): presumably this
            # alternation keeps neighbouring ranks' blocking P2P calls paired.
            send_first = stage_index % 2 != 0
            for _ in range(warmup):
                if has_prev:
                    emit(MetaStep(next_fwd, MetaStepType.FWD_RECV, stage_index))
                if send_first:
                    if next_fwd > 0:
                        emit(MetaStep(next_fwd - 1, MetaStepType.FWD_SEND, stage_index))
                    emit(MetaStep(next_fwd, MetaStepType.FWD, stage_index))
                else:
                    emit(MetaStep(next_fwd, MetaStepType.FWD, stage_index))
                    if next_fwd != warmup - 1:
                        emit(MetaStep(next_fwd, MetaStepType.FWD_SEND, stage_index))
                next_fwd += 1

            # Warmup could not be filled (too few micro batches): flush the last
            # pending FWD_SEND now and push next_fwd past num_micro so the
            # cooldown guard below does not emit that send a second time.
            if num_stages - stage_index > num_micro:
                emit(MetaStep(next_fwd - 1, MetaStepType.FWD_SEND, stage_index))
                next_fwd += 1

            # ---- steady state: one backward then one forward per micro ----
            for _ in range(num_micro - warmup):
                if has_next:
                    emit(MetaStep(next_bwd, MetaStepType.BWD_RECV, stage_index))
                    emit(MetaStep(next_fwd - 1, MetaStepType.FWD_SEND, stage_index))
                emit(MetaStep(next_bwd, MetaStepType.BWD, stage_index))
                if has_prev:
                    emit(MetaStep(next_bwd, MetaStepType.BWD_SEND, stage_index))
                    emit(MetaStep(next_fwd, MetaStepType.FWD_RECV, stage_index))
                emit(MetaStep(next_fwd, MetaStepType.FWD, stage_index))
                next_fwd += 1
                next_bwd += 1

            # ---- cooldown: drain the remaining backwards ----
            first_cooldown_bwd = num_micro - warmup
            for _ in range(warmup):
                if has_next:
                    emit(MetaStep(next_bwd, MetaStepType.BWD_RECV, stage_index))
                    # The last forward's SEND is still pending on the first
                    # cooldown backward unless it was flushed above.
                    if next_bwd == first_cooldown_bwd and next_fwd <= num_micro:
                        emit(MetaStep(next_fwd - 1, MetaStepType.FWD_SEND, stage_index))
                emit(MetaStep(next_bwd, MetaStepType.BWD, stage_index))
                if has_prev:
                    emit(MetaStep(next_bwd, MetaStepType.BWD_SEND, stage_index))
                next_bwd += 1
            self.exec_order[stage_index] = steps
class ScheduleInterleaved1F1B(PipelineScheduleRuntime):
    """The Interleaved 1F1B schedule.

    Supports multiple stages per rank. In steady state, performs one
    forward followed by one backward on each micro-batch. Handles the
    cases where ``micro_batch_num`` is less than, equal to, or greater
    than the stage count, including non-evenly-divisible micro counts.

    Two orthogonal overlap modes can be enabled via constructor flags:

    * ``overlap_p2p=True``: defer P2P recv ``handle.wait()`` until the
      consuming FWD/BWD step (or the OVERLAP_B_F callback when
      ``overlap_b_f=True``), letting recv overlap with prior compute.
    * ``overlap_b_f=True``: in the 1F1B steady state, pair consecutive
      ``(B_i, F_{i+1})`` steps into ``OVERLAP_B_F`` composite steps so
      a registered callback can drive comm/compute overlap (typically
      via :class:`CommComputeOverlap` for MoE EP A2A). Users register
      the callback through :meth:`register_custom_function`.

    The two flags are independent and can be combined.

    Example:
        >>> # Plain interleaved 1F1B
        >>> sched = ScheduleInterleaved1F1B(stages, 8)
        >>> # With B/F overlap (dual-pipe-style comm/compute overlap)
        >>> sched = ScheduleInterleaved1F1B(stages, 8, overlap_b_f=True)
        >>> sched.register_custom_function(MetaStepType.OVERLAP_B_F, callback)
    """
    def __init__(self,
                 stages,
                 micro_batch_num,
                 args_batch_dim=None,
                 kwargs_batch_dim=None,
                 output_concat_dim=None,
                 overlap_p2p=False,
                 overlap_b_f=False):
        super().__init__(stages,
                         micro_batch_num,
                         args_batch_dim=args_batch_dim,
                         kwargs_batch_dim=kwargs_batch_dim,
                         output_concat_dim=output_concat_dim,
                         overlap_p2p=overlap_p2p)
        # _overlap_b_f selects between plain F/B emission and OVERLAP_B_F
        # pairing in the 1F1B steady-state phase. Must be set before
        # ``construct_stage_exec_order`` is called below.
        self._overlap_b_f = overlap_b_f
        # Split the micro batches into rounds of about ``real_stage_num``.
        # With fewer micro batches than ranks, ``base`` is negative and the
        # single round holds exactly ``micro_batch_num`` micro batches.
        # Otherwise the ``micro_batch_num % real_stage_num`` extra micro
        # batches are spread over the rounds; the first ``remainder`` rounds
        # take one more.
        self.n_rounds = max(1, self.micro_batch_num // self.real_stage_num)
        if self.micro_batch_num < self.real_stage_num:
            base = self.micro_batch_num - self.real_stage_num
            remainder = 0
        else:
            n_extra_microbatch = self.micro_batch_num % self.real_stage_num
            base = n_extra_microbatch // self.n_rounds
            remainder = n_extra_microbatch % self.n_rounds
        self.n_microbatch_per_round = [
            self.real_stage_num + base + 1 if i < remainder else self.real_stage_num + base
            for i in range(self.n_rounds)
        ]
        # Cumulative op counts (micro batches x local stages) at round
        # boundaries, with a leading 0; bisected to map an op index to its round.
        self.n_microbatch_per_round_accu = [
            x * self.n_local_stages for x in itertools.accumulate(self.n_microbatch_per_round)
        ]
        self.n_microbatch_per_round_accu.insert(0, 0)
        self.build_exec_order()

    def construct_exec_order(self):
        """Build every rank's compute order, then add SEND/RECV steps via ``add_send_recv``."""
        for stage_index in range(self.real_stage_num):
            self.exec_order[stage_index] = self.construct_stage_exec_order(stage_index)
        self.exec_order = add_send_recv(self.exec_order, self._stage_num, self.real_stage_num, style='loop')

    def _build_stage_to_rank_index(self) -> None:
        self._stage_to_rank_index = generate_stage_to_rank_mapping(
            self.real_stage_num, self._stage_num, style='loop'
        )

    def warmup_ops(self, stage_index):
        """warmup phase: number of pure-FWD ops run by ``stage_index`` before 1F1B."""
        warmup_ops_last_stage = (self.n_local_stages - 1) * self.n_microbatch_per_round[0]
        warmup_ops = warmup_ops_last_stage + 2 * (self.real_stage_num - 1 - stage_index)
        return min(warmup_ops, self.micro_batch_num * self.n_local_stages)

    def _local_chunk_in_round(self, op_index):
        """Return the local chunk slot that ``op_index`` falls into.

        ``op_index`` is first mapped to its round via
        ``n_microbatch_per_round_accu``. Inside that round, every
        ``n_microbatch_per_round[round]`` consecutive ops share one local
        chunk, so for ``0 <= op_index < accu[-1]`` the result lies in
        ``[0, n_local_stages)``.
        """
        round_index = bisect.bisect_right(self.n_microbatch_per_round_accu, op_index) - 1
        offset_in_round = op_index - self.n_microbatch_per_round_accu[round_index]
        return offset_in_round // self.n_microbatch_per_round[round_index]

    def forward_stage_index(self, op_index, stage_index):
        """obtain forward stage_index based on op_index (chunks visited in ascending order)."""
        local_index = self._local_chunk_in_round(op_index)
        return (local_index * self.real_stage_num) + stage_index

    def backward_stage_index(self, op_index, stage_index):
        """obtain backward stage_index based on op_index (chunks visited in descending order)."""
        local_index = self.n_local_stages - 1 - self._local_chunk_in_round(op_index)
        return (local_index * self.real_stage_num) + stage_index

    def _short_micro(self) -> bool:
        """True when ``micro_batch_num < real_stage_num`` (extra-bubble regime)."""
        return self.micro_batch_num < self.real_stage_num

    def _trailing_bubble(self) -> int:
        """Bubble count appended after a BWD with ``micro == micro_batch_num - 1``
        in the short-micro regime.
        """
        return self.real_stage_num - self.micro_batch_num

    def _emit_warmup_ops(self, stage_index, warmup_ops, fwd_stage_micro_index):
        """Emit pure-FWD warmup ops with optional short-micro bubble padding.

        Advances ``fwd_stage_micro_index`` (per virtual stage) in place.
        """
        ops = []
        short = self._short_micro()
        last_micro = self.micro_batch_num - 1
        last_stage = self.real_stage_num - 1
        bubble = self._trailing_bubble()
        for op_idx in range(warmup_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            ops.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            need_pad = (
                short
                and fwd_micro_idx == last_micro
                and (op_idx != warmup_ops - 1 or stage_index == last_stage)
            )
            if need_pad:
                ops.extend([None] * bubble)
            fwd_stage_micro_index[fwd_stage_idx] += 1
        return ops

    def _emit_cooldown_ops(self, stage_index, warmup_ops, fwd_bwd_ops, total_ops,
                           bwd_stage_micro_index):
        """Emit pure-BWD cooldown ops (each preceded by a bubble) with
        optional short-micro trailing padding.

        Advances ``bwd_stage_micro_index`` (per virtual stage) in place.
        """
        ops = []
        short = self._short_micro()
        last_micro = self.micro_batch_num - 1
        bubble = self._trailing_bubble()
        for op_idx in range(warmup_ops + fwd_bwd_ops, total_ops):
            ops.append(None)
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            ops.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            if short and bwd_micro_idx == last_micro:
                ops.extend([None] * bubble)
            bwd_stage_micro_index[bwd_stage_idx] += 1
        return ops

    def _emit_1f1b_ops(self, stage_index, warmup_ops, fwd_bwd_ops,
                       fwd_stage_micro_index, bwd_stage_micro_index):
        """Emit interleaved (FWD, BWD) pairs for the 1F1B steady-state phase."""
        ops = []
        short = self._short_micro()
        last_micro = self.micro_batch_num - 1
        last_stage = self.real_stage_num - 1
        bubble = self._trailing_bubble()
        for op_idx in range(warmup_ops, warmup_ops + fwd_bwd_ops):
            fwd_stage_idx = self.forward_stage_index(op_idx, stage_index)
            fwd_micro_idx = fwd_stage_micro_index[fwd_stage_idx]
            ops.append(MetaStep(fwd_micro_idx, MetaStepType.FWD, fwd_stage_idx))
            fwd_stage_micro_index[fwd_stage_idx] += 1
            bwd_stage_idx = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_micro_idx = bwd_stage_micro_index[bwd_stage_idx]
            ops.append(MetaStep(bwd_micro_idx, MetaStepType.BWD, bwd_stage_idx))
            need_pad = (
                short
                and bwd_micro_idx == last_micro
                and stage_index == last_stage
            )
            if need_pad:
                ops.extend([None] * bubble)
            bwd_stage_micro_index[bwd_stage_idx] += 1
        return ops

    @staticmethod
    def _collect_fwd_bwd_steps(emit_fwd, emit_bwd, fwd_bwd_ops, warmup_ops):
        """Walk the 1F1B range collecting parallel ``fwd_steps`` / ``bwd_steps``.

        ``emit_fwd(op_idx)`` and ``emit_bwd(op_idx)`` build a single
        :class:`MetaStep` and advance their respective per-stage micro
        counters as a side effect.
        """
        fwd_steps = []
        bwd_steps = []
        for op_idx in range(warmup_ops, warmup_ops + fwd_bwd_ops):
            fwd_steps.append(emit_fwd(op_idx))
            bwd_steps.append(emit_bwd(op_idx))
        return fwd_steps, bwd_steps

    @staticmethod
    def _pair_into_overlap_b_f(fwd_steps, bwd_steps):
        """Build ``F₁, [B_i, F_{i+1}], B_n`` ordering with OVERLAP_B_F pairs.

        ``sub_steps`` carry the ``(bwd, fwd)`` tuple — callbacks access
        them via ``step.sub_steps`` to recover per-direction stage /
        micro info.
        """
        ops = []
        if fwd_steps:
            ops.append(fwd_steps[0])  # F₁ runs alone
        for i in range(len(bwd_steps) - 1):
            ops.append(MetaStep(
                None, MetaStepType.OVERLAP_B_F, None,
                sub_steps=(bwd_steps[i], fwd_steps[i + 1]),
            ))
        if bwd_steps:
            ops.append(bwd_steps[-1])  # B_n runs alone
        return ops

    def _emit_1f1b_overlap_ops(self, stage_index, warmup_ops, fwd_bwd_ops,
                               fwd_stage_micro_index, bwd_stage_micro_index):
        """Emit ``F₁, [B_i, F_{i+1}], B_n`` for the 1F1B phase under
        ``overlap_b_f=True``. Each ``[B_i, F_{i+1}]`` becomes an
        ``OVERLAP_B_F`` composite step; a registered callback drives the
        actual concurrent execution. Short-micro extra-bubble padding
        on the last rank is appended after ``B_n``.
        """
        def emit_fwd(op_idx):
            fwd_si = self.forward_stage_index(op_idx, stage_index)
            fwd_mi = fwd_stage_micro_index[fwd_si]
            fwd_stage_micro_index[fwd_si] += 1
            return MetaStep(fwd_mi, MetaStepType.FWD, fwd_si)

        def emit_bwd(op_idx):
            bwd_si = self.backward_stage_index(op_idx - warmup_ops, stage_index)
            bwd_mi = bwd_stage_micro_index[bwd_si]
            bwd_stage_micro_index[bwd_si] += 1
            return MetaStep(bwd_mi, MetaStepType.BWD, bwd_si)

        fwd_steps, bwd_steps = self._collect_fwd_bwd_steps(
            emit_fwd, emit_bwd, fwd_bwd_ops, warmup_ops,
        )
        ops = self._pair_into_overlap_b_f(fwd_steps, bwd_steps)

        last_stage = self.real_stage_num - 1
        if self._short_micro() and stage_index == last_stage and bwd_steps:
            if bwd_steps[-1].micro_index == self.micro_batch_num - 1:
                ops.extend([None] * self._trailing_bubble())
        return ops

    def construct_stage_exec_order(self, stage_index):
        """Construct the execution order for ``stage_index``.

        Builds: ``stage_index`` leading bubbles → warmup → bubbles → 1F1B
        steady state → cooldown. ``None`` entries are bubbles. The
        1F1B segment switches between :meth:`_emit_1f1b_ops` (plain) and
        :meth:`_emit_1f1b_overlap_ops` (OVERLAP_B_F pairing) based on
        the ``overlap_b_f`` constructor flag.
        """
        warmup_ops = self.warmup_ops(stage_index)
        fwd_bwd_ops = self.n_local_stages * self.micro_batch_num - warmup_ops
        # Cooldown runs as many BWD ops as warmup ran FWD ops.
        total_ops = 2 * warmup_ops + fwd_bwd_ops
        order_list = [None for _ in range(stage_index)]
        fwd_stage_micro_index = defaultdict(int)
        bwd_stage_micro_index = defaultdict(int)
        order_list.extend(self._emit_warmup_ops(stage_index, warmup_ops, fwd_stage_micro_index))
        bubbles_before_1f1b = max(
            0,
            2 * (self.real_stage_num - stage_index - 1) - self.micro_batch_num,
        )
        order_list.extend([None] * bubbles_before_1f1b)
        order_list.extend([None] * (self.real_stage_num - 1 - stage_index))
        if self._overlap_b_f:
            order_list.extend(self._emit_1f1b_overlap_ops(
                stage_index, warmup_ops, fwd_bwd_ops,
                fwd_stage_micro_index, bwd_stage_micro_index,
            ))
        else:
            order_list.extend(self._emit_1f1b_ops(
                stage_index, warmup_ops, fwd_bwd_ops,
                fwd_stage_micro_index, bwd_stage_micro_index,
            ))
        order_list.extend(self._emit_cooldown_ops(
            stage_index, warmup_ops, fwd_bwd_ops, total_ops, bwd_stage_micro_index,
        ))
        return order_list
def detect_cycle_in_graph(ranks_map):
    """
    Search the directed graph implied by ``ranks_map`` for a cycle.

    Consecutive entries of each rank's node list form an edge
    ``nodes[i] -> nodes[i + 1]`` labelled with that rank (if the same edge
    occurs under several ranks, the last one wins). The graph is walked
    depth-first with an explicit stack of neighbour iterators, so deep
    graphs do not hit the recursion limit. The first back edge found stops
    the search.

    Args:
        ranks_map: A dictionary where keys are rank names and values are lists of nodes.

    Returns:
        tuple: ``(cycle_path, cycle_ranks)``. ``cycle_path`` lists the cycle's
        nodes with the starting node repeated at the end. ``cycle_ranks``
        holds one ``"<rank> <u> -> <v>"`` string per edge of the cycle.
        Returns ``(None, None)`` when no cycle exists.
    """
    successors = defaultdict(list)
    edge_owner = {}
    for rank, nodes in ranks_map.items():
        for src, dst in zip(nodes, nodes[1:]):
            successors[src].append(dst)
            edge_owner[(src, dst)] = rank

    exhausted = object()  # sentinel returned by next() once a frame has no neighbours left
    explored = set()      # every node ever entered
    depth_of = {}         # node -> position on the current DFS trail
    trail = []            # current DFS path, root first

    for root in list(successors):
        if root in explored:
            continue
        explored.add(root)
        depth_of[root] = len(trail)
        trail.append(root)
        frames = [iter(successors[root])]
        while frames:
            child = next(frames[-1], exhausted)
            if child is exhausted:
                # Current node is finished: leave it.
                frames.pop()
                del depth_of[trail.pop()]
                continue
            if child in depth_of:
                # Back edge to a node on the trail -> cycle found.
                start = depth_of[child]
                loop_nodes = trail[start:]
                cycle_path = loop_nodes + [child]
                cycle_ranks = [
                    f"{edge_owner[(u, v)]} {u} -> {v}"
                    for u, v in zip(loop_nodes, cycle_path[1:])
                ]
                return cycle_path, cycle_ranks
            if child in explored:
                continue
            explored.add(child)
            depth_of[child] = len(trail)
            trail.append(child)
            frames.append(iter(successors.get(child, ())))

    return None, None
def output_cycle_results(cycle_path, cycle_ranks):
    """
    Log the outcome of a cycle search such as :func:`detect_cycle_in_graph`.

    Args:
        cycle_path (list | None): Nodes forming a cycle, or a falsy value when
            no cycle was found. A path that is already closed (first node
            repeated at the end, as ``detect_cycle_in_graph`` returns it) is
            logged as-is. An open path is closed by appending its first node.
        cycle_ranks (list | None): Per-edge descriptions of the ranks involved
            in the cycle.

    Returns:
        None: Results are emitted through the module logger.
    """
    if not cycle_path:
        logger.warning("Cycle Check succeeded. There is no cycle in the graph.")
        return

    nodes = list(cycle_path)
    # Only close the cycle if the caller passed an open path; otherwise the
    # start node would be printed twice ("a -> b -> a -> a").
    if len(nodes) < 2 or nodes[0] != nodes[-1]:
        nodes.append(nodes[0])

    logger.error("Cycle detected:")
    logger.error("%s", " -> ".join(str(node) for node in nodes))
    logger.error("Involving ranks:")
    for rank in cycle_ranks or ():
        # Pass as an argument, never as the format string: entries may contain '%'.
        logger.error("%s", rank)
def parse_and_validate(data: dict, all_rank: bool = True):
    """
    Check that every communication token is listed by each rank it names.

    Each value string is scanned for parenthesised integers (at most the first
    two), e.g. ``Send_Receive_(0)->(1)_...`` references ranks ``"0"`` and
    ``"1"``. Each referenced key must also list the same string.

    Args:
        data (dict): Maps string identifiers to lists of strings.
        all_rank (bool): If True, every referenced key must exist in ``data``
            and must contain the value. If False, only referenced keys that
            exist are checked.

    Returns:
        None: Problems are logged as errors; nothing is raised or returned.

    Raises:
        ValueError: Raised indirectly if `parse_elements` encounters malformed input strings.
        TypeError: Raised indirectly if data contains unexpected types.
    """
    rank_pattern = re.compile(r'\((\d+)\)')

    def parse_elements(value: str, max_groups: int = 2) -> set:
        """Extract unique elements inside the first one or two parentheses from a string."""
        first_groups = rank_pattern.findall(value)[:max_groups]
        return {group.strip() for group in first_groups}

    def is_str_list(candidate):
        return isinstance(candidate, list) and all(isinstance(v, str) for v in candidate)

    if not isinstance(data, dict):
        logger.error("Input must be a dictionary with string keys and lists of strings as values.")
        return

    key_to_values = {key: set(values) for key, values in data.items() if is_str_list(values)}

    for key, values in data.items():
        if not is_str_list(values):
            logger.error("Values for key '%s' must be a list of strings.", key)
            continue

        for value in values:
            try:
                elements = parse_elements(value)
            except (ValueError, TypeError, AttributeError) as err:
                logger.error("Unable to parse elements from value '%s' in key '%s'. Error: %s", value, key, err)
                continue

            if all_rank:
                # Strict mode: every referenced key must exist.
                missing_keys = elements - key_to_values.keys()
                if missing_keys:
                    logger.error("The following keys are missing for value '%s': %s", value, missing_keys)
                    continue
                referenced = elements
            else:
                # Partial mode: only check the referenced keys that are present.
                referenced = elements & key_to_values.keys()

            for element in referenced:
                if value not in key_to_values[element]:
                    logger.error("Key '%s' is missing the value '%s'.", element, value)
def generate_operations(order_list: dict[int, list[MetaStep]],
                        chunk_num: int,
                        com_type: str = 'loop') -> dict[str, list[str]]:
    """
    Generate formatted operations dictionary from pipeline execution order.

    Every SEND/RECV step becomes a token
    ``Send_Receive_({src_rank})->({dst_rank})_micro{m}_{k}th``, where ``k``
    counts how often the same ``(src_rank, dst_rank, m)`` triple has appeared
    on the current rank so far. Compute steps and ``None`` bubble placeholders
    are skipped. Steps whose peer stage lies outside the ``[0, total_stages)``
    chain are skipped too.

    Args:
        order_list (dict): Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num (int): Number of chunks (virtual pipeline stages)
        com_type (str): Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary where keys are rank IDs (as strings) and values are lists of formatted operation strings

    Raises:
        ValueError: If ``com_type`` is neither ``'loop'`` nor ``'v'`` and a
            communication step needs its peer rank resolved.
    """

    def stage_to_rank(stage_index, style, stage_num, real_stage_num):
        """Map stage index to rank"""
        if style == 'loop':
            return stage_index % real_stage_num
        if style == 'v':
            if stage_index < real_stage_num:
                return stage_index
            return stage_num - 1 - stage_index
        raise ValueError("Invalid style")

    real_stage = len(order_list)
    total_stages = real_stage * chunk_num

    # Linear stage chain: stage i sends activations to i + 1 and gradients to i - 1.
    forward_comm = {i: i + 1 for i in range(total_stages - 1)}
    backward_comm = {i: i - 1 for i in range(1, total_stages)}
    # Both maps are injective, so their inverses give the RECV source directly
    # (O(1) instead of scanning the forward map for every RECV step).
    forward_src = {dst: src for src, dst in forward_comm.items()}
    backward_src = {dst: src for src, dst in backward_comm.items()}

    send_types = (MetaStepType.FWD_SEND, MetaStepType.BWD_SEND)
    recv_types = (MetaStepType.FWD_RECV, MetaStepType.BWD_RECV)
    formatted_operations = defaultdict(list)

    for rank, steps in order_list.items():
        operation_counter = defaultdict(int)

        for step in steps:
            if step is None:
                # Bubble placeholder: no communication to record.
                continue
            if step.type in send_types:
                comm = forward_comm if step.type == MetaStepType.FWD_SEND else backward_comm
                target_stage = comm.get(step.stage_index)
                if target_stage is None:
                    continue
                src_rank = rank
                dst_rank = stage_to_rank(target_stage, com_type, total_stages, real_stage)
            elif step.type in recv_types:
                sources = forward_src if step.type == MetaStepType.FWD_RECV else backward_src
                source_stage = sources.get(step.stage_index)
                if source_stage is None:
                    continue
                src_rank = stage_to_rank(source_stage, com_type, total_stages, real_stage)
                dst_rank = rank
            else:
                continue

            comm_pair = (src_rank, dst_rank, step.micro_index)
            operation_counter[comm_pair] += 1
            count = operation_counter[comm_pair]
            formatted_op = f"Send_Receive_({src_rank})->({dst_rank})_micro{step.micro_index}_{count}th"
            formatted_operations[str(rank)].append(formatted_op)

    # Convert defaultdict to dict
    return dict(formatted_operations)
def validate_pipeline_execution(order_list: dict[int, list[MetaStep]],
                                chunk_num: int,
                                com_type: str = 'loop') -> dict[str, object]:
    """
    Validate the SEND/RECV ordering of a pipeline-parallel execution order.

    Steps performed:
    1. Render every SEND/RECV step into a ``Send_Receive_(src)->(dst)_micro{m}_{k}th``
       token per rank (:func:`generate_operations`).
    2. Check that each token is listed by both ranks it references
       (:func:`parse_and_validate`). Problems are only logged; they do not
       appear in the returned dict.
    3. Link each rank's consecutive tokens into a directed graph, search it
       for a cycle (:func:`detect_cycle_in_graph`), and log the outcome
       (:func:`output_cycle_results`).

    Args:
        order_list: Dictionary where keys are rank IDs and values are MetaStep execution sequences
        chunk_num: Number of chunks (virtual pipeline stages)
        com_type: Stage-to-rank mapping type ('loop' for cyclic, 'v' for V-shaped)

    Returns:
        Dictionary with the following keys:
            - formatted_operations: ``{rank_str: [token, ...]}`` from step 1
            - cycle_path: nodes of the first cycle found, or ``None``
            - cycle_ranks: per-edge ``"<rank> <u> -> <v>"`` strings, or ``None``
            - has_cycle: ``True`` when a cycle was found
    """
    formatted_operations = generate_operations(order_list, chunk_num, com_type)

    # Cross-rank consistency check (results are logged only).
    parse_and_validate(formatted_operations, all_rank=True)

    cycle_path, cycle_ranks = detect_cycle_in_graph(formatted_operations)
    output_cycle_results(cycle_path, cycle_ranks)

    return {
        'formatted_operations': formatted_operations,
        'cycle_path': cycle_path,
        'cycle_ranks': cycle_ranks,
        'has_cycle': bool(cycle_path),
    }
# Leaf step types that perform real compute. Used to skip send/recv and FSDP
# control steps when looking ahead for the next active stages.
_COMPUTE_META_STEP_TYPES = frozenset((
    MetaStepType.FWD,
    MetaStepType.BWD,
    MetaStepType.BWD_INPUT,
    MetaStepType.BWD_WEIGHT,
))
1566def _next_active_stage_indices(actions, start_index, max_active_stages, managed_stage_indices):
1567 """Find the next distinct managed stages that will execute compute work.
1569 Send/recv and previously injected FSDP control steps are skipped so that the
1570 lookahead window only counts real compute, otherwise communication-only
1571 actions would consume the budget and shrink the effective prefetch depth.
1572 """
1573 stage_indices = []
1574 seen = set()
1575 for action in actions[start_index:]:
1576 for leaf_step in iter_leaf_meta_steps(action):
1577 if leaf_step.type not in _COMPUTE_META_STEP_TYPES:
1578 continue
1579 if leaf_step.stage_index not in managed_stage_indices or leaf_step.stage_index in seen:
1580 continue
1581 seen.add(leaf_step.stage_index)
1582 stage_indices.append(leaf_step.stage_index)
1583 if len(stage_indices) == max_active_stages:
1584 return stage_indices
1585 return stage_indices
def add_fsdp_unshard_reshard(actions, managed_stage_indices, max_active_stages=3):
    """Insert FSDP unshard/reshard steps around the actions of managed stages.

    Before each action, the lookahead window (the next ``max_active_stages``
    distinct managed stages that compute, starting at this action) is
    compared with the currently unsharded stages. Stages that left the window
    get an ``FSDP_RESHARD``; stages that entered it get an ``FSDP_UNSHARD``.
    Any stage still unsharded at the end is resharded in unshard order.
    Returns ``actions`` itself when no stage is managed.
    """
    if not managed_stage_indices:
        return actions

    result = []
    resident = []  # stages currently unsharded, in the order they were fetched
    for position, action in enumerate(actions):
        window = _next_active_stage_indices(
            actions, position, max_active_stages, managed_stage_indices
        )
        to_release = [stage for stage in resident if stage not in window]
        to_gather = [stage for stage in window if stage not in resident]
        for stage in to_release:
            result.append(MetaStep(None, MetaStepType.FSDP_RESHARD, stage))
            resident.remove(stage)
        for stage in to_gather:
            result.append(MetaStep(None, MetaStepType.FSDP_UNSHARD, stage))
            resident.append(stage)
        result.append(action)

    result.extend(MetaStep(None, MetaStepType.FSDP_RESHARD, stage) for stage in resident)
    return result
def add_fsdp_reduce_grad(actions, managed_stage_indices, micro_batch_num):
    """Append an ``FSDP_REDUCE_GRAD`` step after each action containing a final backward.

    A leaf step triggers a reduce for its stage when the stage is managed, the
    step is ``BWD`` or ``BWD_WEIGHT``, and its micro index is
    ``micro_batch_num - 1``. The reduce steps follow the triggering action, one
    per distinct stage in first-seen order. Returns ``actions`` itself when no
    stage is managed.
    """
    if not managed_stage_indices:
        return actions

    result = []
    for action in actions:
        result.append(action)
        finishing = []
        for leaf in iter_leaf_meta_steps(action):
            is_final_grad_step = (
                leaf.stage_index in managed_stage_indices
                and leaf.type in (MetaStepType.BWD, MetaStepType.BWD_WEIGHT)
                and leaf.micro_index == micro_batch_num - 1
            )
            if is_final_grad_step and leaf.stage_index not in finishing:
                finishing.append(leaf.stage_index)
        result.extend(MetaStep(None, MetaStepType.FSDP_REDUCE_GRAD, stage) for stage in finishing)
    return result