Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/__init__.py 60.0% 46-49
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_backbone.py 100%  
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_hook_manager.py 100%  
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_ppb.py 95.5% 208,212-214
hyper_parallel/auto_parallel/sapp_nd/nd/dimensions.py 84.4% 283-286,291
hyper_parallel/auto_parallel/sapp_nd/nd/parallelize.py 100%  
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/__init__.py 100%  
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/layer_recompute.py 64.7% 91,102,123,132,134,137,139-147,153,165,171,228,234,262-263,267-268,270-271,273,300-301,303-304,306-312,314-315,317,319-324,326,384
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_optimizer.py 88.9% 155-157,198-199,204-207,246,351-352,442
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_strategy_search.py 72.8% 85,123,130-131,142,145-146,150-151,156,257,289,292,295,304,311,407-410,412,414-416,418-419,421-424,432-433,435-436,483,550,559,561-562,564-567,574,576,622,628
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_types.py 100%  
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_utils.py 0.0% 24,26,28,35,46,48-52,55-56,59,84-86,88-101,103,106,113-114,117,127-132,135,142-149,152,160-165,167-172,174-179,182,187-190,195,200,203,210-216,222,230,232-236,241-246,248-253,256,296,298-299,301,304,306-309,311-315,317,319-326,328,332
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_visualizer.py 60.2% 31-33,78,93,98,157-160,166-168,174,214-218,225-228,239,266,269,288,310-313,345,367-370,407,440-441,450-451,453-457,459,464,470-471,473-475,477,705-706,708-709,711-712,714-715,717-718,724,726,731-732,760-761,767-768,770,772
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/ppb_load_balancer.py 86.2% 67,98,189,191-193,198-199,204-207,212-213,218,248,309,313,344,386,419,423,426-427,580-581,653,698,708-709,720
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/sapp_ppb_adapter.py 80.1% 35-45,101,249,262,304,365-371,429-430,461-462,464-465,496,506-508,558,562,613,659,661,673,783,827-828,831,834,935,977-978,1083,1117-1119,1146-1147,1169,1197-1213,1240-1241,1271,1275,1287,1296,1370,1374,1404,1415-1418,1448,1451,1455,1465-1468,1490-1496,1529,1538,1547-1548,1577,1607-1608,1610-1614,1616,1618-1619,1621,1623,1629,1631-1635,1637
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/stage_partition.py 83.6% 50,54,113,146,150,182,185,212,265,331,334,340,374-375,396-400,416
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/yaml_config.py 100%  
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/yaml_config_builder.py 33.5% 86-89,91,93-95,97-98,108-109,111-112,114-116,118,126-127,129-130,156,158,160,167-170,172,193-196,217,219-225,231-234,236-243,249-252,254,281-284,286,288-290,292-293,303-304,306-307,309-311,313,320-321,323-324,348,350,352,359-362,364,385-388,407,409-415,421-424,426,481-484,486-487,489-491,498,500-501,520,596,604,608,616
hyper_parallel/auto_parallel/sapp_ppb/__init__.py 88.0% 45,52-53
hyper_parallel/auto_parallel/sapp_ppb/sapp/sapp_solver.py 88.9% 601
hyper_parallel/auto_parallel/sapp_ppb/simulator/pp_simulator.py 100%  
hyper_parallel/auto_parallel/sapp_ppb/utils/__init__.py 100%  
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/__init__.py
42
43
44
45
46
47
48
49

def __getattr__(name):  # pylint: disable=C0103
    """Resolve an exported name lazily on first attribute access (PEP 562).

    Names listed in ``_EXPORTS`` are looked up by importing the module that
    ``_EXPORTS`` maps them to (via ``_import_module``, relative to this
    package) and reading the attribute of the same name from it. The resolved
    object is stored in this module's globals, so later lookups find it
    directly and no longer reach this hook.

    Args:
        name: Attribute name requested on this module.

    Returns:
        The attribute object taken from the providing module.

    Raises:
        AttributeError: If ``name`` is not a key of ``_EXPORTS``.
    """
    if name in _EXPORTS:
        provider = _import_module(_EXPORTS[name], __name__)
        resolved = getattr(provider, name)
        # Cache the result so the import machinery is only exercised once.
        globals()[name] = resolved
        return resolved
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
hyper_parallel/auto_parallel/sapp_nd/memory_estimation/_ppb.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
            desc["options"] = ["NONE", "FULL"]
            desc["forward_time"] = {"NONE": 1, "FULL": 1}
            desc["backward_time"] = {"NONE": 1, "FULL": 1}
        elif ctx.current_node == ctx.tail_node:
            desc["memory_activation"] = {"NONE": 0, "FULL": 0}
            d_out = self.mb(sum(self._inner_dynamic_mem(ppb=True)))
            desc["memory_parameter"] = self.mb(res_stat) + d_out
            desc["type"] = "TAIL"
            desc["options"] = ["NONE", "FULL"]
            desc["forward_time"] = {"NONE": 1, "FULL": 1}
            desc["backward_time"] = {"NONE": 1, "FULL": 1}
        else:
            desc["memory_activation"] = {"NONE": 0, "COMM": 0, "SLCT": 0, "BOTH": 0, "FULL": 0}
            synthetic_rec_op = False
            if not hasattr(ccfg, 'rec_op'):
hyper_parallel/auto_parallel/sapp_nd/nd/dimensions.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295

        if user_global_batch_size is not None:
            calculated_gbs = self.global_batch_size()
            if user_global_batch_size != calculated_gbs:
                dp = self.dims_val.get(DP, 1)
                mbs = self.dims_val.get(MBS, 1)
                mbn = self.dims_val.get(MBN, 1)
                logger.warning(
                    "Batch size constraint violated: user_global_batch_size (%d) != "
                    "DP (%d) * MBS (%d) * MBN (%d) = %d",
                    user_global_batch_size, dp, mbs, mbn, calculated_gbs
                )
                return False

        return True

    def val(self, dim: Dimension) -> int:
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/layer_recompute.py
87
88
89
90
91
92
93
94
95
                provided, other per-layer parameters are ignored.
        """
        if layers is not None:
            if not layers:
                raise ValueError(
                    "For LayerRecomputeModel, layers should not be empty."
                )
            self._layers = layers
            self._num_layer = len(layers)
 98
 99
100
101
102
103
104
105
106
            }
            self._homogeneous = False
        else:
            if num_layer <= 0:
                raise ValueError(
                    f"For LayerRecomputeModel, num_layer should be positive, got {num_layer}."
                )
            self._layers = None
            self._num_layer = num_layer
119
120
121
122
123
124
125
126
127

    @property
    def num_layer(self) -> int:
        """Expose the body-layer count stored in ``self._num_layer``."""
        layer_count = self._num_layer
        return layer_count

    def _get_layer_act(self, layer_id: int, mode: RecomputeMode) -> float:
        """Get activation memory for a layer under a recompute mode."""
        if not self._homogeneous:
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
            layer = self._layers[self._layer_id_to_idx[layer_id]]
            if mode == RecomputeMode.SLCT:
                return layer.activation_memory_slct
            if mode == RecomputeMode.COMM:
                return layer.activation_memory_comm
            if mode == RecomputeMode.BOTH:
                return layer.activation_memory_both
            if mode == RecomputeMode.FULL:
                return layer.activation_memory_full
            return layer.activation_memory

        if mode == RecomputeMode.SLCT:
            return self._activation_memory_slct
        if mode == RecomputeMode.COMM:
            return self._activation_memory_comm
        if mode == RecomputeMode.BOTH:
            return self._activation_memory_both
        if mode == RecomputeMode.FULL:
            return self._activation_memory_full
        return self._activation_memory

    def _get_layer_act_none(self, layer_id: int) -> float:
        """Get activation memory without recompute for a layer."""
        if not self._homogeneous:
149
150
151
152
153
154
155
156
157
    def _get_layer_act_none(self, layer_id: int) -> float:
        """Get activation memory without recompute for a layer."""
        if not self._homogeneous:
            return self._layers[self._layer_id_to_idx[layer_id]].activation_memory
        return self._activation_memory

    def _get_layer_can_recompute(self, layer_id: int) -> bool:
        """Get whether a layer supports recompute."""
        if not self._homogeneous:
161
162
163
164
165
166
167
168
169
    def _get_layer_forward_time(self, layer_id: int) -> float:
        """Get forward time for a layer."""
        if not self._homogeneous:
            return self._layers[self._layer_id_to_idx[layer_id]].forward_time
        return self._forward_time

    def _get_layer_recompute_forward_overhead(self, layer_id: int) -> float:
        """Get recompute forward overhead for a layer."""
        if not self._homogeneous:
167
168
169
170
171
172
173
174
175
    def _get_layer_recompute_forward_overhead(self, layer_id: int) -> float:
        """Get recompute forward overhead for a layer."""
        if not self._homogeneous:
            return self._layers[self._layer_id_to_idx[layer_id]].recompute_forward_overhead
        return self._recompute_forward_overhead

    def validate_config(
        self,
        config: LayerRecomputeConfig,
224
225
226
227
228
229
230
231
232
            ... )
            >>> saving = model.estimate_memory_saving(config)
        """
        if not config.enabled or config.recompute_mode == RecomputeMode.NONE:
            return 0.0

        total_saving = 0.0

        for layer_id in config.recompute_layers:
230
231
232
233
234
235
236
237
238
        total_saving = 0.0

        for layer_id in config.recompute_layers:
            if layer_id not in self._layer_id_to_idx:
                continue

            act_none = self._get_layer_act_none(layer_id)
            act_mode = self._get_layer_act(layer_id, config.recompute_mode)
            saving = act_none - act_mode
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
            >>> model = LayerRecomputeModel(layers)
            >>> config = LayerRecomputeConfig(recompute_mode=RecomputeMode.FULL)
            >>> mem = model.get_layer_activation_memory(0, config)
        """
        if layer_id not in self._layer_id_to_idx:
            raise ValueError(
                f"For get_layer_activation_memory, layer_id {layer_id} is invalid."
            )

        if not config.enabled or config.recompute_mode == RecomputeMode.NONE:
            return self._get_layer_act_none(layer_id)

        if layer_id not in config.recompute_layers:
            return self._get_layer_act_none(layer_id)

        return self._get_layer_act(layer_id, config.recompute_mode)

    def generate_recompute_config_for_memory_target(
        self,
        memory_target: float,
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
            ...     memory_target=1000.0,
            ...     current_activation_memory=2000.0,
            ... )
        """
        if current_activation_memory <= memory_target:
            return LayerRecomputeConfig(enabled=False)

        memory_to_save = current_activation_memory - memory_target
        recompute_layers = set()

        layer_savings = []
        for layer_id in range(self._num_layer):
            act_none = self._get_layer_act_none(layer_id)
            act_mode = self._get_layer_act(layer_id, recompute_mode)
            saving = act_none - act_mode
            if saving > 0:
                layer_savings.append((layer_id, saving))

        if strategy == "greedy":
            layer_savings.sort(key=lambda x: x[1], reverse=True)
        else:
            layer_savings.sort(key=lambda x: x[1])

        accumulated_saving = 0.0
        for layer_id, saving in layer_savings:
            if accumulated_saving >= memory_to_save:
                break
            recompute_layers.add(layer_id)
            accumulated_saving += saving

        return LayerRecomputeConfig(
            recompute_layers=recompute_layers,
            recompute_mode=recompute_mode,
            enabled=True,
        )
380
381
382
383
384
385
386
387
388
                rec_fwd = self._get_layer_recompute_forward_overhead(layer_id)
                fwd = self._get_layer_forward_time(layer_id)
                overhead = rec_fwd if rec_fwd > 0.0 else fwd * 0.5
            else:
                overhead = 0.0

            total_overhead += max(0.0, overhead)

        return total_overhead
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_optimizer.py
151
152
153
154
155
156
157
158
159
160
161

        feasible_results = [r for r in results if r.is_feasible]

        if not feasible_results:
            if results:
                return results[0]
            raise RuntimeError("No feasible PP strategy found.")

        if optimize_for == "throughput":
            feasible_results.sort(key=lambda r: r.estimated_step_time)
        elif optimize_for == "memory":
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211

        output_files = {}

        if output_format in ["png", "all"]:
            try:
                viz_data = self._visualizer.visualize_with_sapp_ppb(
                    result=result,
                    file_name=str(output_path / "pipeline_timeline.png"),
                    show=False,
                )
                if viz_data is not None:
                    output_files["pipeline_timeline"] = str(output_path / "pipeline_timeline.png")
            except (ImportError, ValueError, RuntimeError):
                pass

        if output_format in ["json", "all", "both"]:
            json_data = self._visualizer.export_to_json(result)
            json_file = output_path / "pp_strategy_visualization.json"
242
243
244
245
246
247
248
249
250
            >>> optimizer = PPOptimizer()
            >>> result = optimizer.optimize(dry_run_path="dry_run.yaml")
            >>> viz_data = optimizer.generate_pipeline_visualization(result, show=True)
        """
        return self._visualizer.generate_pipeline_timeline_plot(
            result=result,
            file_name=file_name,
            show=show,
        )
347
348
349
350
351
352
353
354
355
356
                     backgroundColor: '{slc['y_axes'][0]['color']}', yAxisID: 'y'}},
                    {{label: '{slc['y_axes'][1]['label']}', data: {slc['y_axes'][1]['data']},
                     backgroundColor: '{slc['y_axes'][1]['color']}', yAxisID: 'y1'}}
                ]"""
        elif len(slc.get("y_axes", [])) == 1:
            slc_datasets = f"""
                datasets: [
                    {{label: '{slc['y_axes'][0]['label']}', data: {slc['y_axes'][0]['data']},
                     backgroundColor: '{slc['y_axes'][0]['color']}'}}
                ]"""
438
439
440
441
442
443
444
445
446
        num_layers = yaml_config.num_layer

        recommended_strategy = "greedy"
        if num_layers > 100:
            recommended_strategy = "dp"

        recommended_mode = "fast"
        if pp_range[1] - pp_range[0] <= 4 and num_layers <= 50:
            recommended_mode = "precise"
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_strategy_search.py
81
82
83
84
85
86
87
88
89
                "Please provide a valid dry-run YAML path."
            )

        if num_layer <= 0:
            raise ValueError(
                "PSStrategySearcher requires positive num_layer."
            )

        self.num_layer = num_layer
119
120
121
122
123
124
125
126
127
            >>> result = searcher.evaluate_strategy(strategy)
        """
        validation_result = self._validate_strategy(strategy)
        if not validation_result[0]:
            return PPStrategyResult(
                strategy=strategy,
                is_feasible=False,
                infeasibility_reason=validation_result[1],
            )
126
127
128
129
130
131
132
133
134
135
                infeasibility_reason=validation_result[1],
            )

        if not strategy.stage_partition:
            partition = StagePartition(self.num_layers, strategy.pp_degree)
            strategy.stage_partition = partition.uniform_partition()

        stages = self._build_stage_info_from_ppb(
            strategy, ppb_output,
        )
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        infeasibility_reason = ""
        if ppb_output is not None:
            is_feasible = ppb_output.is_feasible
            if not is_feasible:
                infeasibility_reason = str(ppb_output.infeasibility_details)
        else:
            if memory_limit is not None:
                for stage in stages:
                    if (
                        stage.memory_breakdown
                        and stage.memory_breakdown.total_memory > memory_limit
                    ):
                        is_feasible = False
                        infeasibility_reason = (
                            f"Stage {stage.stage_id} memory "
                            f"({stage.memory_breakdown.total_memory:.1f} MB) "
                            f"exceeds limit ({memory_limit:.1f} MB)."
                        )
                        break

        pipeline_bubble = 0.0
        estimated_step_time = 0.0
        imbalance_score = 0.0
253
254
255
256
257
258
259
260
261
        else:
            for stage_id, layer_ids in enumerate(strategy.stage_partition):
                recompute_layers = set()
                if strategy.layer_recompute_config:
                    recompute_layers = (
                        strategy.layer_recompute_config.recompute_layers & set(layer_ids)
                    )

                stages.append(StageInfo(
285
286
287
288
289
290
291
292
293
294
295
296
297
298
            >>> searcher = PSStrategySearcher(layers, dry_run_path="dry_run.yaml")
            >>> is_valid, msg = searcher._validate_strategy(strategy)
        """
        if strategy.pp_degree <= 0:
            return False, f"pp_degree must be positive, got {strategy.pp_degree}."

        if strategy.micro_batch_num <= 0:
            return False, f"micro_batch_num must be positive, got {strategy.micro_batch_num}."

        if strategy.pp_degree > self.num_layers:
            return (
                False,
                f"pp_degree ({strategy.pp_degree}) exceeds number of layers ({self.num_layers}).",
            )
300
301
302
303
304
305
306
307
308
        if strategy.stage_partition:
            partition = StagePartition(self.num_layers, strategy.pp_degree)
            is_valid, error_msg = partition.validate_partition(strategy.stage_partition)
            if not is_valid:
                return False, error_msg

        if strategy.layer_recompute_config:
            is_valid, error_msg = self._recompute_model.validate_config(
                strategy.layer_recompute_config
307
308
309
310
311
312
313
314
315
            is_valid, error_msg = self._recompute_model.validate_config(
                strategy.layer_recompute_config
            )
            if not is_valid:
                return False, error_msg

        return True, ""

    def _generate_visualization_data(
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
            ...     pp_degrees=[2, 4],
            ...     micro_batch_nums=[4, 8],
            ... )
        """
        if schedule_types is None:
            schedule_types = [PipelineScheduleType.ONE_F_ONE_B]
        if vpps is None:
            vpps = [1]

        results = []

        for pp_degree in pp_degrees:
            if pp_degree > self.num_layers:
                continue

            partition = StagePartition(self.num_layers, pp_degree)
            base_partition = partition.uniform_partition()

            for vpp in vpps:
                for micro_batch_num in micro_batch_nums:
                    for schedule_type in schedule_types:
                        strategy = PPStrategy(
                            pp_degree=pp_degree,
                            micro_batch_num=micro_batch_num,
                            stage_partition=base_partition,
                            schedule_type=schedule_type,
428
429
430
431
432
433
434
435
436
437
438
439
440
                            schedule_type=schedule_type,
                            num_model_chunks=vpp,
                        )

                        result = self.evaluate_strategy(strategy, memory_limit)
                        results.append(result)

        results.sort(key=lambda r: (not r.is_feasible, r.estimated_step_time))
        return results

    def search_with_load_balancing(
        self,
        pp_degrees: List[int],
479
480
481
482
483
484
485
486
487
        results = []

        for pp_degree in pp_degrees:
            if pp_degree > self.num_layers:
                continue

            for vpp in vpps:
                for micro_batch_num in micro_batch_nums:
                    ppb_input = PPBInput(
546
547
548
549
550
551
552
553
554
            ...     pp_degrees=[2, 4],
            ...     micro_batch_nums=[4, 8],
            ... )
        """
        results = self.search_with_load_balancing(
            pp_degrees=pp_degrees,
            micro_batch_nums=micro_batch_nums,
            memory_limit=memory_limit,
            vpps=vpps,
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
            comm_time=comm_time,
            constant_memory=constant_memory,
        )

        feasible_results = [r for r in results if r.is_feasible]

        if not feasible_results:
            return None

        if optimize_for == "throughput":
            feasible_results.sort(key=lambda r: r.estimated_step_time)
        elif optimize_for == "memory":
            feasible_results.sort(
                key=lambda r: max(
                    s.memory_breakdown.total_memory if s.memory_breakdown else 0.0
                    for s in r.stages
                )
570
571
572
573
574
575
576
577
578
579
580
                    for s in r.stages
                )
            )
        else:
            feasible_results.sort(key=lambda r: r.estimated_step_time)

        return feasible_results[0]

    def search_with_adaptive_optimization(
        self,
        pp_degrees: List[int],
618
619
620
621
622
623
624
625
626
            deterministically determined by the ILP solver based on
            the (PP, VPP) configuration, following the paradise approach.
        """
        if vpps is None:
            vpps = [1]

        results = []

        for pp_degree in pp_degrees:
624
625
626
627
628
629
630
631
632
        results = []

        for pp_degree in pp_degrees:
            if pp_degree > self.num_layers:
                continue

            for vpp in vpps:
                for micro_batch_num in micro_batch_nums:
                    ppb_input = PPBInput(
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_utils.py
20
21
22
23
24
25
26
27
28
29
30
31
32
any active code path within ``pp_modeling`` since the dual-path estimator
architecture was replaced by the single ILP-path design.
"""

from __future__ import annotations

from typing import List, Dict, Any, Optional, Tuple

from hyper_parallel.auto_parallel.sapp_nd.pp_modeling.pp_types import (
    LayerInfo,
    LayerRecomputeConfig,
    RecomputeMode,
)
31
32
33
34
35
36
37
38
    RecomputeMode,
)


def recompute_mode_to_layer_type(
    mode: RecomputeMode,
) -> "LayerType":
    """Map a :class:`RecomputeMode` to the corresponding :class:`LayerType`.
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

    Returns:
        LayerType enum value understood by EvaluatorV2 and perf_estimation.
    """
    from hyper_parallel.auto_parallel.sapp_nd.nd.common.layer_type import LayerType  # pylint: disable=C0415

    if mode in (RecomputeMode.SLCT, RecomputeMode.COMM, RecomputeMode.BOTH):
        return LayerType.SEL_REC_LAYER
    if mode == RecomputeMode.FULL:
        return LayerType.FULL_REC_LAYER
    return LayerType.NOT_REC_LAYER


# Names of all ``rec_op`` flags that the helpers in this module read or write.
REC_OP_KEYS = ['attBMM', 'headCast', 'dropout', 'softmax', 'normOp', 'gather', 'ffAct']
# The same flags without 'gather', which is toggled on its own (separately from
# the operator flags) when a recompute mode is applied.
REC_OP_OPERATOR_KEYS = [flag for flag in REC_OP_KEYS if flag != 'gather']


def apply_recompute_to_rec_op(ccfg: Any, mode: RecomputeMode) -> Dict[str, int]:
    """Set ``ccfg.rec_op`` flags according to *mode* and return the original values.

    For :class:`EvaluatorV2`, ``rec_op`` flags control which sub-operators
    are recomputed under ``SEL_REC_LAYER``:
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110

    Returns:
        Dictionary of original ``rec_op`` values for later restoration.
    """
    original: Dict[str, int] = {}
    for key in REC_OP_KEYS:
        original[key] = getattr(ccfg.rec_op, key, 1)

    if mode == RecomputeMode.NONE:
        for key in REC_OP_KEYS:
            setattr(ccfg.rec_op, key, 1)
    elif mode == RecomputeMode.SLCT:
        for key in REC_OP_OPERATOR_KEYS:
            setattr(ccfg.rec_op, key, 0)
        setattr(ccfg.rec_op, 'gather', 1)
    elif mode == RecomputeMode.COMM:
        for key in REC_OP_OPERATOR_KEYS:
            setattr(ccfg.rec_op, key, 1)
        setattr(ccfg.rec_op, 'gather', 0)
    elif mode == RecomputeMode.BOTH:
        for key in REC_OP_KEYS:
            setattr(ccfg.rec_op, key, 0)

    return original


def restore_rec_op(ccfg: Any, original: Dict[str, int]) -> None:
    """Restore ``ccfg.rec_op`` flags from *original*.

    Args:
        ccfg: EvaluatorV2's ``_ccfg`` object.
109
110
111
112
113
114
115
116
117
118
119
120
121
    Args:
        ccfg: EvaluatorV2's ``_ccfg`` object.
        original: Original values returned by :func:`apply_recompute_to_rec_op`.
    """
    for key, val in original.items():
        setattr(ccfg.rec_op, key, val)


def _extract_body_ids(
    layer_ids: List[int],
    is_first_stage: bool,
    is_last_stage: bool,
    head_layer_id: Any,
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
    has_head_in_partition: bool,
    has_tail_in_partition: bool,
) -> List[int]:
    """Strip HEAD/TAIL layer IDs from *layer_ids* when they are already in the partition."""
    body_ids = list(layer_ids)
    if is_first_stage and has_head_in_partition and head_layer_id in body_ids:
        body_ids.remove(head_layer_id)
    if is_last_stage and has_tail_in_partition and tail_layer_id in body_ids:
        body_ids.remove(tail_layer_id)
    return body_ids


def _build_non_vpp_chunk(
    n_body: int,
    is_first_stage: bool,
    is_last_stage: bool,
    body_layer_type: "LayerType",
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
    is_last_stage: bool,
    body_layer_type: "LayerType",
) -> List[List["LayerType"]]:
    """Build a single-chunk stage (non-VPP)."""
    from hyper_parallel.auto_parallel.sapp_nd.nd.common.layer_type import LayerType  # pylint: disable=C0415
    chunk: List[LayerType] = []
    if is_first_stage:
        chunk.append(LayerType.EMBEDDING_LAYER)
    chunk.extend([body_layer_type] * n_body)
    if is_last_stage:
        chunk.append(LayerType.OUTPUT_LAYER)
    return [chunk]


def _build_vpp_chunks(
    n_body: int,
    is_first_stage: bool,
    is_last_stage: bool,
    num_model_chunks: int,
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
    num_model_chunks: int,
    body_layer_type: "LayerType",
) -> List[List["LayerType"]]:
    """Build multi-chunk stage (VPP) by splitting layers across virtual chunks."""
    from hyper_parallel.auto_parallel.sapp_nd.nd.common.layer_type import LayerType  # pylint: disable=C0415
    extra = int(is_first_stage) + int(is_last_stage)
    total_layers_in_stage = n_body + extra
    base = total_layers_in_stage // num_model_chunks
    remainder = total_layers_in_stage % num_model_chunks
    chunk_sizes = [base + (1 if c < remainder else 0) for c in range(num_model_chunks)]

    all_layers: List[LayerType] = []
    if is_first_stage:
        all_layers.append(LayerType.EMBEDDING_LAYER)
    all_layers.extend([body_layer_type] * n_body)
    if is_last_stage:
        all_layers.append(LayerType.OUTPUT_LAYER)

    stage_chunks: List[List[LayerType]] = []
    offset = 0
    for size in chunk_sizes:
        stage_chunks.append(all_layers[offset:offset + size])
        offset += size
    return stage_chunks


def _check_head_tail_in_partition(
    layers: List[LayerInfo],
    stage_partition: List[List[int]],
) -> Tuple[Optional[int], Optional[int], bool, bool]:
    """Return (head_layer_id, tail_layer_id, has_head, has_tail)."""
    head_layer_id = layers[0].layer_id if layers else None
    tail_layer_id = layers[-1].layer_id if len(layers) > 1 else None
    pp_degree = len(stage_partition)
    has_head = (
        head_layer_id is not None
        and pp_degree > 0
        and head_layer_id in stage_partition[0]
    )
    has_tail = (
        tail_layer_id is not None
        and pp_degree > 0
        and tail_layer_id in stage_partition[-1]
    )
    return head_layer_id, tail_layer_id, has_head, has_tail


def _validate_body_layer_count(
    stage_partition: List[List[int]],
    has_head: bool,
    has_tail: bool,
    evaluator_body_layers: int,
206
207
208
209
210
211
212
213
214
215
216
217
218
219
    has_tail: bool,
    evaluator_body_layers: int,
) -> None:
    """Raise ValueError if body-layer count does not match evaluator config."""
    total = sum(len(s) for s in stage_partition)
    if has_head:
        total -= 1
    if has_tail:
        total -= 1
    if total != evaluator_body_layers:
        raise ValueError(
            f"Stage partition has {total} body layers but the "
            f"model config defines {evaluator_body_layers}."
        )
218
219
220
221
222
223
224
225
226
            f"model config defines {evaluator_body_layers}."
        )


def _build_vpp_stage_chunks(
    body_layer_type: "LayerType",
    n_body: int,
    is_first_stage: bool,
    is_last_stage: bool,
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
    is_last_stage: bool,
    num_model_chunks: int,
) -> List[List["LayerType"]]:
    """Build VPP (interleaved) stage chunks for a single stage."""
    from hyper_parallel.auto_parallel.sapp_nd.nd.common.layer_type import LayerType  # pylint: disable=C0415

    extra = int(is_first_stage) + int(is_last_stage)
    total_layers_in_stage = n_body + extra
    base = total_layers_in_stage // num_model_chunks
    remainder = total_layers_in_stage % num_model_chunks
    chunk_sizes = [
        base + (1 if c < remainder else 0)
        for c in range(num_model_chunks)
    ]

    all_layers: List[LayerType] = []
    if is_first_stage:
        all_layers.append(LayerType.EMBEDDING_LAYER)
    all_layers.extend([body_layer_type] * n_body)
    if is_last_stage:
        all_layers.append(LayerType.OUTPUT_LAYER)

    stage_chunks: List[List[LayerType]] = []
    offset = 0
    for size in chunk_sizes:
        stage_chunks.append(all_layers[offset:offset + size])
        offset += size
    return stage_chunks


def build_stages_from_partition(
    layers: List[LayerInfo],
    evaluator_body_layers: int,
    stage_partition: List[List[int]],
    recompute_config: LayerRecomputeConfig,
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    Raises:
        ValueError: If the total number of body layers in *stage_partition*
            does not match *evaluator_body_layers*.
    """
    from hyper_parallel.auto_parallel.sapp_nd.nd.common.layer_type import LayerType  # pylint: disable=C0415

    body_layer_type = recompute_mode_to_layer_type(recompute_config.recompute_mode)
    pp_degree = len(stage_partition)

    head_layer_id, tail_layer_id, has_head, has_tail = _check_head_tail_in_partition(
        layers, stage_partition
    )
    _validate_body_layer_count(stage_partition, has_head, has_tail, evaluator_body_layers)

    stages: List[List[List[LayerType]]] = []
    for stage_id, layer_ids in enumerate(stage_partition):
        is_first_stage = not stage_id
        is_last_stage = stage_id == pp_degree - 1

        body_ids = list(layer_ids)
        if is_first_stage and has_head and head_layer_id in body_ids:
            body_ids.remove(head_layer_id)
        if is_last_stage and has_tail and tail_layer_id in body_ids:
            body_ids.remove(tail_layer_id)

        n_body = len(body_ids)

        if num_model_chunks <= 1:
            chunk: List[LayerType] = []
            if is_first_stage:
                chunk.append(LayerType.EMBEDDING_LAYER)
            chunk.extend([body_layer_type] * n_body)
            if is_last_stage:
                chunk.append(LayerType.OUTPUT_LAYER)
            stages.append([chunk])
        else:
            stages.append(_build_vpp_stage_chunks(
                body_layer_type, n_body, is_first_stage, is_last_stage, num_model_chunks
            ))

    return stages
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/pp_visualizer.py
27
28
29
30
31
32
33
34
35
36
37
    from hyper_parallel.auto_parallel.sapp_nd.pp_modeling.sapp_ppb_adapter import (
        SappPPBAdapter,
        SAPP_PPB_AVAILABLE,
    )
except ImportError:
    SAPP_PPB_AVAILABLE = False
    SappPPBAdapter = None


class PPVisualizer:
    """Visualization utilities for Pipeline Parallelism.
74
75
76
77
78
79
80
81
82
            >>> visualizer = PPVisualizer()
            >>> chart = visualizer.generate_stage_load_chart(result)
        """
        if not result.stages:
            return {
                "chart_type": "bar",
                "title": "Stage Load Distribution",
                "x_axis": {"label": "Pipeline Stage", "data": []},
                "y_axes": [],
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
        for stage in result.stages:
            if stage.memory_breakdown:
                memory_data.append(stage.memory_breakdown.total_memory)
            else:
                memory_data.append(0.0)

            if stage.performance_breakdown:
                compute_data.append(stage.performance_breakdown.total_time)
            else:
                compute_data.append(0.0)

        return {
            "chart_type": "bar",
            "title": "Stage Load Distribution",
153
154
155
156
157
158
159
160
161
162
163
164
    def _build_gpipe_schedule(
        stages: List[int], micro_batch_num: int,
    ) -> List[Dict[str, Any]]:
        """Build GPipe schedule entries."""
        schedule = []
        for mb_id in range(micro_batch_num):
            for stage_id in stages:
                schedule.append({
                    "micro_batch": mb_id,
                    "stage": stage_id,
                    "type": "forward",
                    "time_step": mb_id * len(stages) + stage_id,
162
163
164
165
166
167
168
169
170
171
172
                    "stage": stage_id,
                    "type": "forward",
                    "time_step": mb_id * len(stages) + stage_id,
                })
        for mb_id in range(micro_batch_num):
            for stage_id in reversed(stages):
                schedule.append({
                    "micro_batch": mb_id,
                    "stage": stage_id,
                    "type": "backward",
                    "time_step": (micro_batch_num + mb_id) * len(stages) + stage_id,
170
171
172
173
174
175
176
177
178
                    "stage": stage_id,
                    "type": "backward",
                    "time_step": (micro_batch_num + mb_id) * len(stages) + stage_id,
                })
        return schedule

    @staticmethod
    def _build_1f1b_schedule(
        stages: List[int], micro_batch_num: int,
210
211
212
213
214
215
216
217
218
219
220
221
222
    def _build_interleaved_schedule(
        stages: List[int], micro_batch_num: int, num_model_chunks: int = 2,
    ) -> List[Dict[str, Any]]:
        """Build interleaved (VPP) schedule entries."""
        schedule = []
        for mb_id in range(micro_batch_num):
            for chunk_id in range(num_model_chunks):
                for stage_id in stages:
                    schedule.append({
                        "micro_batch": mb_id,
                        "stage": stage_id,
                        "type": "forward",
                        "chunk": chunk_id,
221
222
223
224
225
226
227
228
229
230
231
232
                        "type": "forward",
                        "chunk": chunk_id,
                        "time_step": (mb_id * num_model_chunks + chunk_id) * len(stages) + stage_id,
                    })
        for mb_id in range(micro_batch_num):
            for chunk_id in range(num_model_chunks):
                for stage_id in reversed(stages):
                    schedule.append({
                        "micro_batch": mb_id,
                        "stage": stage_id,
                        "type": "backward",
                        "chunk": chunk_id,
235
236
237
238
239
240
241
242
243
                            + mb_id * num_model_chunks
                            + chunk_id
                        ) * len(stages) + stage_id,
                    })
        return schedule

    def _build_schedule(
        self,
        stages: List[int],
262
263
264
265
266
267
268
269
270
271
272
273
            >>> visualizer = PPVisualizer()
            >>> schedule = visualizer._build_schedule([0, 1], 4, PipelineScheduleType.ONE_F_ONE_B)
        """
        if schedule_type == PipelineScheduleType.GPipe:
            return self._build_gpipe_schedule(stages, micro_batch_num)
        if schedule_type == PipelineScheduleType.ONE_F_ONE_B:
            return self._build_1f1b_schedule(stages, micro_batch_num)
        return self._build_interleaved_schedule(stages, micro_batch_num)

    def generate_memory_breakdown_chart(
        self,
        result: PPStrategyResult,
284
285
286
287
288
289
290
291
292
            >>> visualizer = PPVisualizer()
            >>> chart = visualizer.generate_memory_breakdown_chart(result)
        """
        if not result.stages:
            return {
                "chart_type": "stacked_bar",
                "title": "Memory Breakdown by Stage",
                "x_axis": {"label": "Pipeline Stage", "data": []},
                "y_axis": {"label": "Memory (MB)"},
306
307
308
309
310
311
312
313
314
315
316
317
                grad_data.append(stage.memory_breakdown.grad_memory)
                optimizer_data.append(stage.memory_breakdown.optimizer_state_memory)
                activation_data.append(stage.memory_breakdown.activation_memory)
            else:
                param_data.append(0.0)
                grad_data.append(0.0)
                optimizer_data.append(0.0)
                activation_data.append(0.0)

        return {
            "chart_type": "stacked_bar",
            "title": "Memory Breakdown by Stage",
341
342
343
344
345
346
347
348
349
            >>> visualizer = PPVisualizer()
            >>> chart = visualizer.generate_performance_breakdown_chart(result)
        """
        if not result.stages:
            return {
                "chart_type": "stacked_bar",
                "title": "Performance Breakdown by Stage",
                "x_axis": {"label": "Pipeline Stage", "data": []},
                "y_axis": {"label": "Time (ms)"},
363
364
365
366
367
368
369
370
371
372
373
374
                backward_data.append(stage.performance_breakdown.backward_time)
                comm_data.append(stage.performance_breakdown.communication_time)
                recompute_data.append(stage.performance_breakdown.recompute_overhead)
            else:
                forward_data.append(0.0)
                backward_data.append(0.0)
                comm_data.append(0.0)
                recompute_data.append(0.0)

        return {
            "chart_type": "stacked_bar",
            "title": "Performance Breakdown by Stage",
403
404
405
406
407
408
409
410
411

        bubble_stages = pp_degree - 1

        if not result.stages:
            ideal_step_time = 0.0
        else:
            ideal_step_time = micro_batch_num * max(
                s.performance_breakdown.total_time if s.performance_breakdown else 0.0
                for s in result.stages
436
437
438
439
440
441
442
443
444
445
        Example:
            >>> visualizer = PPVisualizer()
            >>> heatmap = visualizer.generate_imbalance_heatmap(result)
        """
        if not result.stages:
            return {
                "chart_type": "heatmap",
                "title": "Stage Load Imbalance Heatmap",
                "x_axis": {"label": "Stage", "data": []},
                "y_axis": {"label": "Stage", "data": []},
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
                "data": [],
                "colorscale": "Viridis",
            }

        num_stages = len(result.stages)
        heatmap_data = []

        for i, stage_i in enumerate(result.stages):
            row = []
            for j, stage_j in enumerate(result.stages):
                if i == j:
                    row.append(0.0)
                else:
                    mem_i = (
                        stage_i.memory_breakdown.total_memory
                        if stage_i.memory_breakdown
                        else 0.0
                    )
                    mem_j = (
                        stage_j.memory_breakdown.total_memory
                        if stage_j.memory_breakdown
                        else 0.0
                    )
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
                        if stage_j.memory_breakdown
                        else 0.0
                    )

                    if mem_i + mem_j > 0:
                        imbalance = abs(mem_i - mem_j) / (mem_i + mem_j)
                    else:
                        imbalance = 0.0
                    row.append(imbalance)
            heatmap_data.append(row)

        return {
            "chart_type": "heatmap",
            "title": "Stage Load Imbalance Heatmap",
            "x_axis": {"label": "Stage", "data": list(range(num_stages))},
            "y_axis": {"label": "Stage", "data": list(range(num_stages))},
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
            >>> visualizer = PPVisualizer()
            >>> viz_data = visualizer.visualize_with_sapp_ppb(
            ...     result, dry_run_path="dry_run.yaml", show=True)
        """
        if not self.use_sapp_ppb or not SAPP_PPB_AVAILABLE:
            return None

        if not dry_run_path:
            return None

        try:
            from hyper_parallel.auto_parallel.sapp_nd.pp_modeling.pp_types import PPBInput  # pylint: disable=C0415

            if not result.stages or not result.strategy.stage_partition:
                return None

            if not self._sapp_ppb_adapter:
                ppb_input = PPBInput(
                    num_layer=len(result.stages) * 2,
                    pp_degree=result.strategy.pp_degree,
                    micro_batch_num=result.strategy.micro_batch_num,
                    dry_run_path=dry_run_path,
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
                    pp_degree=result.strategy.pp_degree,
                    micro_batch_num=result.strategy.micro_batch_num,
                    dry_run_path=dry_run_path,
                )
                self._sapp_ppb_adapter = SappPPBAdapter(ppb_input)

            return self._sapp_ppb_adapter.visualize(
                result=result,
                file_name=file_name,
                show=show,
            )
        except (ImportError, ValueError, RuntimeError):
            return None

    def generate_pipeline_timeline_plot(
        self,
        result: PPStrategyResult,
756
757
758
759
760
761
762
763
764
765
            >>> visualizer = PPVisualizer()
            >>> data = visualizer.generate_pipeline_timeline_plot(
            ...     result, dry_run_path="dry_run.yaml")
        """
        if self.use_sapp_ppb:
            sapp_ppb_result = self.visualize_with_sapp_ppb(
                result=result,
                file_name=file_name,
                show=show,
                dry_run_path=dry_run_path,
763
764
765
766
767
768
769
770
771
772
                file_name=file_name,
                show=show,
                dry_run_path=dry_run_path,
            )
            if sapp_ppb_result is not None:
                return sapp_ppb_result

        basic_data = self.generate_pipeline_schedule_diagram(result)

        return basic_data
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/ppb_load_balancer.py
63
64
65
66
67
68
69
70
71
            ValueError: If num_layer is not positive, or if pp_degree /
                micro_batch_num are invalid.
        """
        if ppb_input.num_layer <= 0:
            raise ValueError(
                "For PPBLoadBalancer, ppb_input.num_layer must be positive."
            )

        if ppb_input.pp_degree <= 0:
 94
 95
 96
 97
 98
 99
100
101
102
        if SAPP_PPB_AVAILABLE:
            try:
                self._sapp_ppb_adapter = SappPPBAdapter(ppb_input)
            except ImportError:
                self._sapp_ppb_adapter = None
            except ValueError:  # pylint: disable=W0706
                raise

    def validate_dimensions(
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
        Example:
            >>> balancer = PPBLoadBalancer(ppb_input)
            >>> is_valid, msg = balancer.validate_pp_degree(machine_device_num=8, tp=2)
        """
        pp = self.pp_degree

        devices_per_node = machine_device_num // tp
        if devices_per_node <= 0:
            return (
                False,
                f"Invalid device configuration: machine_device_num ({machine_device_num}) // TP ({tp}) <= 0.",
            )

        if pp > devices_per_node:
            return (
                False,
                f"PP ({pp}) exceeds maximum pipeline bound ({devices_per_node}).",
            )

        if num_layers is not None:
            max_pp_by_layers = num_layers // vpp
            if pp > max_pp_by_layers:
                return (
                    False,
                    f"PP ({pp}) exceeds layer count limit ({max_pp_by_layers} = {num_layers} // {vpp}).",
                )

        if devices_per_node % pp:
            return (
                False,
                f"PP ({pp}) must divide devices_per_node ({devices_per_node}).",
            )

        return True, ""

    def _ilp_balance(self, num_of_interleave: int = 1, comm_time: float = 0.0) -> PPBOutput:
        """Create ILP-optimized load-balanced partition using sapp-ppb.
244
245
246
247
248
249
250
251
            This method uses ILP (Integer Linear Programming) solver exclusively.
            No fallback to greedy/dp algorithms.
        """
        if not SAPP_PPB_AVAILABLE or self._sapp_ppb_adapter is None:
            raise ImportError(
                "sapp-ppb module is required for ILP-based load balancing. "
                "Please ensure sapp-ppb is installed and accessible."
            )
305
306
307
308
309
310
311
312
313
314
315
316
317
            >>> ilp_output = balancer._ilp_balance()
            >>> refined = balancer._ilp_then_offset_refine(ilp_output)
        """
        if self._sapp_ppb_adapter is None or self._sapp_ppb_adapter._pipeline is None:  # pylint: disable=W0212
            return ilp_output

        base_partition = ilp_output.stage_partition
        if not base_partition:
            return ilp_output

        offset_ranges = self._partition.get_offset_range(base_partition)

        best_partition = base_partition
340
341
342
343
344
345
346
347
348
                candidate_score = self._evaluate_partition_score(
                    candidate_partition, pipeline,
                )
                if candidate_score is None:
                    continue

                if candidate_score < best_score:
                    best_partition = candidate_partition
                    best_score = candidate_score
382
383
384
385
386
387
388
389
390
                lay for lay in self._sapp_ppb_adapter.layers_sapp_ppb
                if lay.type_ == _get_pipeline_layer_class().type_enum.BODY
            ]
            if not body_layers:
                return None

            body_layer_num = sum(lay.nb_layer_ for lay in body_layers)
            nass_flat = self._partition_to_nass(stage_partition, body_layer_num)
            nass = [nass_flat]
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
                for stage in range(num_of_stage)
            ]

            if not stage_times or all(t == 0.0 for t in stage_times):
                return None

            avg_time = sum(stage_times) / len(stage_times)
            if avg_time == 0.0:
                return None
            return max(stage_times) / avg_time - 1.0

        except (ValueError, KeyError, AttributeError, IndexError):
            return None

    def _partition_to_nass(
        self,
        stage_partition: List[List[int]],
576
577
578
579
580
581
582
583
584
585
                comm_cost_in_objective=ilp_output.comm_cost_in_objective,
                known_limitations=ilp_output.known_limitations,
                optimization_comparison=optimization_comparison,
            )
        except (ValueError, KeyError, AttributeError, IndexError):
            return ilp_output

    def balance(
        self,
        algorithm: str = "ilp",  # pylint: disable=W0613
649
650
651
652
653
654
655
656
            deterministically determined by the ILP solver based on
            the (PP, VPP) configuration, following the paradise approach.
        """
        if not SAPP_PPB_AVAILABLE or self._sapp_ppb_adapter is None:
            raise ImportError(
                "sapp-ppb module is required for adaptive optimization. "
                "Please ensure sapp-ppb is installed and accessible."
            )
694
695
696
697
698
699
700
701
702
        Example:
            >>> balancer = PPBLoadBalancer(ppb_input)
            >>> d = balancer.to_dict(output)
        """
        result: Dict[str, Any] = {
            "stage_partition": output.stage_partition,
            "stage_compute_cost": output.stage_compute_cost,
            "stage_memory_cost": output.stage_memory_cost,
            "stage_comm_cost": output.stage_comm_cost,
704
705
706
707
708
709
710
711
712
713
            "imbalance_score": output.imbalance_score,
            "is_feasible": output.is_feasible,
            "infeasibility_details": output.infeasibility_details,
        }
        if output.optimization_comparison is not None:
            result["optimization_comparison"] = {
                "pre_opt_time": output.optimization_comparison.pre_opt_time,
                "post_opt_time": output.optimization_comparison.post_opt_time,
                "pre_opt_mem_parameter": output.optimization_comparison.pre_opt_mem_parameter,
                "post_opt_mem_parameter": output.optimization_comparison.post_opt_mem_parameter,
716
717
718
719
720
                "pre_opt_simulator_end_time": output.optimization_comparison.pre_opt_simulator_end_time,
                "post_opt_simulator_end_time": output.optimization_comparison.post_opt_simulator_end_time,
                "improvement_ratio": output.optimization_comparison.improvement_ratio,
            }
        return result
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/sapp_ppb_adapter.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    from sapp_ppb.sapp.sapp_pipeline import SappPipeline
    from sapp_ppb.utils import recompute as Recompute
    from sapp_ppb.simulator.pp_simulator import PipelineSimulator
    SAPP_PPB_AVAILABLE = True
except ImportError:
    try:
        from hyper_parallel.auto_parallel.sapp_ppb.sapp.sapp_pipeline import SappPipeline  # pylint: disable=C0412
        from hyper_parallel.auto_parallel.sapp_ppb.utils import recompute as Recompute
        from hyper_parallel.auto_parallel.sapp_ppb.simulator.pp_simulator import PipelineSimulator
        SAPP_PPB_AVAILABLE = True
    except ImportError:
        SAPP_PPB_AVAILABLE = False
        SappPipeline = None
        Recompute = None
        PipelineSimulator = None


@dataclass
class YamlConstraints:
 97
 98
 99
100
101
102
103
104
105
    identity mismatches, we always retrieve ``Layer`` from
    ``SappPipeline``'s own module globals.
    """
    if SappPipeline is None:
        return None
    return SappPipeline.__init__.__globals__.get("Layer")


def _build_backward_time_rec(
245
246
247
248
249
250
251
252
            ImportError: If sapp-ppb module is not available.
            ValueError: If ppb_input.dry_run_path is empty or required fields are missing.
        """
        if not SAPP_PPB_AVAILABLE:
            raise ImportError(
                "sapp-ppb module is not available. "
                "Please ensure sapp-ppb is installed and accessible."
            )
258
259
260
261
262
263
264
265
266
                "Please provide a valid dry-run YAML path."
            )

        if ppb_input.num_layer <= 0:
            raise ValueError(
                "ppb_input.num_layer must be positive."
            )

        self.ppb_input = ppb_input
300
301
302
303
304
305
306
307
308

            pipeline_cfg = cfg.get("pipeline_config", {})
            num_layer = pipeline_cfg.get("num_layer", 0)
            if num_layer <= 0:
                raise ValueError(
                    "pipeline_config.num_layer is required when using "
                    "memory_parameter_config"
                )
        else:
361
362
363
364
365
366
367
368
369
370
371
372
373
374
            Integer PP degree used during dry-run, or 0 if not specified.
        """
        if "pipeline_num" in pipeline_cfg:
            return int(pipeline_cfg["pipeline_num"])
        if "pipeline_num_range" in pipeline_cfg:
            val = pipeline_cfg["pipeline_num_range"]
            if isinstance(val, list) and len(val) > 0:
                return int(val[0])
            if isinstance(val, (int, float)):
                return int(val)
        return 0

    def _parse_yaml_constraints(self, num_of_interleave: int) -> YamlConstraints:
        """Parse the dry-run YAML for user-specified offset/recompute values.
425
426
427
428
429
430
431
432
433
434
        pp_for_offset = yaml_pp if yaml_pp > 0 else target_pp
        raw_offset = pipeline_cfg.get("offset", 0)
        try:
            offset_normalized, rounds = process_offset(raw_offset, pp_for_offset)
        except ValueError:
            return YamlConstraints(
                offset_specified=False,
                recompute_specified=False,
                offset_per_stage=[[0] * target_pp for _ in range(num_of_interleave)],
                recompute_per_type_per_stage={},
457
458
459
460
461
462
463
464
465
466
467
468
469
            if val is None:
                continue
            if isinstance(val, bool):
                if val:
                    recompute_specified = True
                    break
            elif isinstance(val, (int, float)) and val != 0:
                recompute_specified = True
                break
            elif isinstance(val, list):
                flat_vals: List[int] = []
                for item in val:
                    if isinstance(item, list):
492
493
494
495
496
497
498
499
500
                            yaml_format[yaml_key] = val
                        else:
                            yaml_format[yaml_key] = [list(val) for _ in range(num_of_interleave)]
                    else:
                        yaml_format[yaml_key] = val

                layer_per_recompute = Recompute.internal_from_yaml(
                    num_of_interleave, target_pp, yaml_format, nass,
                )
502
503
504
505
506
507
508
509
510
511
512
                    recompute_per_type_per_stage[rec] = [
                        list(layer_per_recompute[rec][inter])
                        for inter in range(num_of_interleave)
                    ]
            except (ValueError, KeyError, IndexError):
                recompute_specified = False
                recompute_per_type_per_stage = {}

        return YamlConstraints(
            offset_specified=offset_specified,
            recompute_specified=recompute_specified,
554
555
556
557
558
559
560
561
562
563
564
565
            ImportError: If sapp-ppb module is not available.
            ValueError: If YAML parsing or memory computation fails.
        """
        if not SAPP_PPB_AVAILABLE:
            raise ImportError("sapp-ppb module is not available")

        pipeline_layer = _get_pipeline_layer_class()
        if pipeline_layer is None:
            raise ImportError("Cannot resolve the Layer class used by SappPipeline")

        comp_mem, mem_par, mem_head, mem_tail, mem_act, _ = self._parse_dry_run_memory()
        head_time, body_time, tail_time = self._parse_dry_run_timing()
609
610
611
612
613
614
615
616
617
            PPBOutput with is_feasible=False.
        """
        details: Dict[str, Any] = {"reason": reason}
        if error:
            details["error"] = error
        if solver_status is not None:
            details["solver_status"] = solver_status
        return PPBOutput(
            stage_partition=[],
655
656
657
658
659
660
661
662
663
664
665
        Returns:
            PPBOutput with all computed data.
        """
        if layer_offset is None:
            layer_offset = []
        if layer_recompute is None:
            layer_recompute = self.ppb_input.layer_recompute_config

        limitations = []
        if not comm_cost_in_objective:
            limitations.append(
669
670
671
672
673
674
675
676
677
            limitations.append(
                "Memory parameters from memory_parameter_config (ComputeMemory decomposition bypassed)"
            )
        else:
            limitations.append(
                "Layer grouping driven by dry-run ComputeMemory (HEAD/BODY/TAIL)"
            )

        return PPBOutput(
779
780
781
782
783
784
785
786
787
            lay for lay in self.layers_sapp_ppb
            if lay.type_ == _get_pipeline_layer_class().type_enum.BODY
        ]
        if not body_layers:
            return (
                [0.0] * self.ppb_input.pp_degree,
                [0.0] * self.ppb_input.pp_degree,
                [0.0] * self.ppb_input.pp_degree,
            )
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
                            [exclusive[inter][s] + full_counts[inter][s] + both_counts[inter][s]
                             for s in range(pp)]
                            for inter in range(num_of_interleave)
                        ]
                    elif rec == Recompute.TYPE.COMM:
                        full_counts = constraints.recompute_per_type_per_stage.get(
                            Recompute.TYPE.FULL, zeros_2d,
                        )
                        both_counts = constraints.recompute_per_type_per_stage.get(
                            Recompute.TYPE.BOTH, zeros_2d,
                        )
                        cumulative = [
                            [exclusive[inter][s] + full_counts[inter][s] + both_counts[inter][s]
                             for s in range(pp)]
                            for inter in range(num_of_interleave)
                        ]
931
932
933
934
935
936
937
938
939
            >>> adapter = SappPPBAdapter(ppb_input)
            >>> output = adapter.balance_with_ilp(time_limit=60)
        """
        if not SAPP_PPB_AVAILABLE:
            raise ImportError("sapp-ppb module is not available")

        self._pipeline = SappPipeline(
            model_name="sapp_nd_model",
            num_of_stage=self.ppb_input.pp_degree,
973
974
975
976
977
978
979
980
981
        result = self._pipeline.get_result()

        try:
            stage_partition = self._extract_stage_partition(result)
        except RuntimeError as e:
            return self._make_infeasible_output(
                "Failed to extract valid partition from ILP solution",
                error=str(e),
            )
1079
1080
1081
1082
1083
1084
1085
1086
1087
        """
        stage_partition = [[] for _ in range(self.ppb_input.pp_degree)]

        if self._pipeline is None or self._pipeline.problem_ is None:
            raise RuntimeError("Pipeline not constructed or solved yet")

        solver = self._pipeline.problem_

        body_layer_ids = list(range(1, self.ppb_input.num_layer + 1))
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
                stage_partition,
                allow_empty_stages=self.ppb_input.allow_empty_stages,
            )
            if not is_valid:
                raise RuntimeError(f"Invalid partition: {error_msg}")
        except (RuntimeError, ValueError) as e:
            raise RuntimeError(
                f"Invalid partition extracted from ILP solver: {str(e)}"
            ) from e

        return stage_partition
1142
1143
1144
1145
1146
1147
1148
1149
1150
                        try:
                            var_value = solver.variables_[group_name][rec][inter][stage_id].varValue
                            if var_value is not None:
                                total_count += round(var_value)
                        except (KeyError, AttributeError):
                            continue
                stage_counts[stage_id] += total_count

        return stage_counts
1165
1166
1167
1168
1169
1170
1171
1172
1173
        Returns:
            List of offsets per stage boundary (length = pp_degree - 1).
        """
        if not stage_partition or len(stage_partition) <= 1:
            return []

        uniform = StagePartition(
            self.ppb_input.num_layer + 2, self.ppb_input.pp_degree,
        ).uniform_partition()
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
            Kept for backward compatibility only.  New code should call
            :meth:`_extract_per_stage_recompute` which preserves mixed
            recompute assignments.
        """
        group_rec_type = Recompute.TYPE.NONE
        max_count = 0
        recompute_considered = getattr(solver, 'recompute_considered_', None)
        for inter in range(num_interleave):
            for stage_id in range(pp_degree):
                for rec in Recompute.TYPE:
                    if recompute_considered and not recompute_considered.get(rec, False):
                        continue
                    try:
                        var_value = solver.variables_[group_name][rec][inter][stage_id].varValue
                        count = round(var_value) if var_value is not None else 0
                        if count > max_count:
                            max_count = count
                            group_rec_type = rec
                    except (KeyError, AttributeError):
                        continue
        return group_rec_type if group_rec_type != Recompute.TYPE.NONE else None

    @staticmethod
    def _extract_per_stage_recompute(
        solver: Any,
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
                for inter in range(num_interleave):
                    try:
                        var_value = solver.variables_[group_name][rec][inter][stage_id].varValue
                        total += round(var_value) if var_value is not None else 0
                    except (KeyError, AttributeError):
                        pass
                if total > 0:
                    rec_counts[rec.name] = total
            if rec_counts:
                stage_detail[stage_id] = rec_counts
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
            RecomputeMode,
        )

        if self._pipeline is None or self._pipeline.problem_ is None:
            return self.ppb_input.layer_recompute_config

        solver = self._pipeline.problem_
        if not hasattr(solver, 'variables_') or not solver.variables_:
            return self.ppb_input.layer_recompute_config

        recompute_type_map = {
            Recompute.TYPE.NONE: RecomputeMode.NONE,
            Recompute.TYPE.SLCT: RecomputeMode.SLCT,
1283
1284
1285
1286
1287
1288
1289
1290
1291
        }

        group_name = "layer_group_1"
        if group_name not in solver.variables_:
            return self.ppb_input.layer_recompute_config

        stage_detail = self._extract_per_stage_recompute(
            solver, group_name,
            self._pipeline.num_of_interleave_,
1292
1293
1294
1295
1296
1297
1298
1299
1300
            self.ppb_input.pp_degree,
        )

        if not stage_detail:
            return self.ppb_input.layer_recompute_config

        recompute_layers = set()
        stage_recompute_modes: Dict[int, RecomputeMode] = {}
        global_max_count = 0
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
            >>> adapter = SappPPBAdapter(ppb_input)
            >>> score = adapter._calculate_imbalance_score([100, 110, 105])
        """
        if not costs or all(not c for c in costs):
            return 0.0

        avg_cost = sum(costs) / len(costs)
        if not avg_cost:
            return 0.0

        max_cost = max(costs)
        return max_cost / avg_cost - 1.0
1400
1401
1402
1403
1404
1405
1406
1407
1408
            >>> adapter = SappPPBAdapter(ppb_input)
            >>> total_time, bubbles = adapter.simulate_pipeline(partition, block_time=[10.0, 12.0])
        """
        if not SAPP_PPB_AVAILABLE:
            raise ImportError("sapp-ppb module is not available")

        simulator = PipelineSimulator(
            block_time=block_time,
            micro_num=self.ppb_input.micro_batch_num,
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422

        simulator.run(print_info=False)

        if show or file_name:
            if file_name:
                simulator.save(file_name, comm=True)
            elif show:
                simulator.show(comm=True)

        total_time = max(
            (line[-1].end for line in simulator.lines if line),
            default=0.0
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
        Returns:
            :class:`SimulatorResult` or ``None`` if simulation fails.
        """
        if not SAPP_PPB_AVAILABLE or self._pipeline is None:
            return None

        if self.ppb_input.micro_batch_num < self.ppb_input.pp_degree:
            return None

        try:
            if not fw_time_2d or not fw_time_2d[0]:
                return None

            num_interleave = self._pipeline.num_of_interleave_

            if num_interleave == 1:
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
                layer_recompute_input = list(rec_time_2d[0])
                block_mem_input = list(mem_act_2d[0])
                block_mem_par_input = list(mem_par_2d[0])
            else:
                block_time_input = [list(fw_time_2d[inter]) for inter in range(num_interleave)]
                layer_recompute_input = [list(rec_time_2d[inter]) for inter in range(num_interleave)]
                block_mem_input = [list(mem_act_2d[inter]) for inter in range(num_interleave)]
                block_mem_par_input = [list(mem_par_2d[inter]) for inter in range(num_interleave)]

            use_comm = comm_time > 0.0

            simulator = PipelineSimulator(
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
                end_time=simulator.end_time,
                bubbles=dict(simulator.bubbles),
                peak_memory=list(simulator.peak_memory),
            )
        except (ValueError, KeyError, AttributeError, IndexError):
            return None
        except Exception as _e:
            from sapp_ppb.simulator.causal_error import CausalCommError, CausalError  # pylint: disable=C0415
            if isinstance(_e, (CausalCommError, CausalError)):
                return None
            raise

    def simulate_from_ilp(
        self,
        comm_time: float = 0.0,
1525
1526
1527
1528
1529
1530
1531
1532
1533
            >>> sim.end_time
            1234.5
        """
        if not SAPP_PPB_AVAILABLE or self._pipeline is None:
            return None

        try:
            fw_time = self._pipeline.get_fw_time()
            rec_time = self._pipeline.get_recompute_time()
1534
1535
1536
1537
1538
1539
1540
1541
1542
            mem_act = self._pipeline.get_memory_activation()
            mem_par = self._pipeline.get_memory_parameter()

            if not fw_time or not fw_time[0]:
                return None

            return self._run_simulator(
                fw_time_2d=fw_time,
                rec_time_2d=rec_time,
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
                mem_act_2d=mem_act,
                mem_par_2d=mem_par,
                comm_time=comm_time,
            )
        except (ValueError, KeyError, AttributeError, IndexError):
            return None

    def simulate_pre_opt_baseline(
        self,
        comm_time: float = 0.0,
1573
1574
1575
1576
1577
1578
1579
1580
1581
            or self._pre_opt_rec_time_2d is None
            or self._pre_opt_mem_act_2d is None
            or self._pre_opt_mem_par_2d is None
        ):
            return None

        return self._run_simulator(
            fw_time_2d=self._pre_opt_fw_time_2d,
            rec_time_2d=self._pre_opt_rec_time_2d,
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
        Example:
            >>> adapter = SappPPBAdapter(ppb_input)
            >>> viz_data = adapter.visualize(result)
        """
        if not SAPP_PPB_AVAILABLE:
            raise ImportError("sapp-ppb module is not available")

        block_time = []
        block_mem = []
        for stage in result.stages:
            if stage.performance_breakdown:
                block_time.append(stage.performance_breakdown.forward_time)
            else:
                block_time.append(0.0)

            if stage.memory_breakdown:
                block_mem.append(stage.memory_breakdown.total_memory)
            else:
                block_mem.append(0.0)

        simulator = PipelineSimulator(
            block_time=block_time,
            micro_num=result.strategy.micro_batch_num,
            block_mem=block_mem,
        )
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
            micro_num=result.strategy.micro_batch_num,
            block_mem=block_mem,
        )

        simulator.run(print_info=False)

        if show or file_name:
            if file_name:
                simulator.save(file_name, comm=True)
            elif show:
                simulator.show(comm=True)

        return {
            "total_time": simulator.lines[0][-1].finish if simulator.lines else 0.0,
            "bubbles": dict(simulator.bubbles),
            "peak_memory": list(simulator.peak_memory),
            "stage_partition": result.strategy.stage_partition,
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/stage_partition.py
46
47
48
49
50
51
52
53
54
55
56
57
58
            num_layers: Total number of layers (HEAD + body layers + TAIL).
            pp_degree: Number of pipeline stages.
        """
        if pp_degree <= 0:
            raise ValueError(
                f"For StagePartition, pp_degree should be positive, but got {pp_degree}."
            )
        if num_layers <= 0:
            raise ValueError(
                f"For StagePartition, num_layers should be positive, but got {num_layers}."
            )
        if pp_degree > num_layers:
            raise ValueError(
109
110
111
112
113
114
115
116
            >>> partition = StagePartition(num_layers=4, pp_degree=2)
            >>> is_valid, msg = partition.validate_partition([[0, 1], [2, 3]])
        """
        if len(stage_partition) != self.pp_degree:
            return (
                False,
                f"Expected {self.pp_degree} stages, got {len(stage_partition)}.",
            )
142
143
144
145
146
147
148
149
150
151
152
153
            )

        for stage_id in range(self.pp_degree - 1):
            if not stage_partition[stage_id] or not stage_partition[stage_id + 1]:
                continue
            last_layer_in_stage = max(stage_partition[stage_id])
            first_layer_in_next = min(stage_partition[stage_id + 1])
            if last_layer_in_stage >= first_layer_in_next:
                return (
                    False,
                    f"Stage {stage_id} and {stage_id + 1} are not ordered correctly.",
                )
178
179
180
181
182
183
184
185
186
187
188
            >>> adjusted
            [[0], [1, 2, 3]]
        """
        if not offsets:
            return stage_partition

        if len(offsets) != self.pp_degree - 1:
            raise ValueError(
                f"Expected {self.pp_degree - 1} offsets for {self.pp_degree} stages, "
                f"got {len(offsets)}."
            )
208
209
210
211
212
213
214
215
216
                    partition[stage_from].append(layer_id)
            else:
                for _ in range(-offset):
                    if not partition[stage_from]:
                        raise ValueError(
                            f"Cannot apply offset {offset}: stage {stage_from} is empty."
                        )
                    layer_id = partition[stage_from].pop()
                    partition[stage_to].insert(0, layer_id)
261
262
263
264
265
266
267
268
269
                    "offsets": offsets,
                },
            )

        return PPBOutput(
            stage_partition=adjusted,
            is_feasible=True,
            layer_offset=offsets,
        )
327
328
329
330
331
332
333
334
335
336
337
            >>> stages = [[0, 1], [2, 3, 4]]
            >>> offset_range = partition.get_offset_range(stages)
        """
        if offset_config and offset_config.offset_range:
            return offset_config.offset_range

        if stage_partition is None:
            raise ValueError(
                "For get_offset_range, stage_partition is required when "
                "offset_config.offset_range is not provided."
            )
336
337
338
339
340
341
342
343
                "offset_config.offset_range is not provided."
            )

        if len(stage_partition) != self.pp_degree:
            raise ValueError(
                f"For get_offset_range, expected stage_partition length "
                f"{self.pp_degree}, got {len(stage_partition)}."
            )
370
371
372
373
374
375
376
377
378
379
        Example:
            >>> partition = StagePartition(num_layers=4, pp_degree=2)
            >>> stages = partition.generate_partition_with_offset([-1])
        """
        base_partition = self.uniform_partition()
        return self.apply_offset(base_partition, offsets)

    def get_stage_boundaries(
        self,
        stage_partition: List[List[int]],
392
393
394
395
396
397
398
399
400
401
402
403
            >>> boundaries = partition.get_stage_boundaries(stages)
            >>> boundaries
            [1, 2]
        """
        boundaries = []
        for stage_id in range(self.pp_degree - 1):
            last_layer = max(stage_partition[stage_id])
            boundaries.append(last_layer)
        return boundaries

    def to_dict(self, stage_partition: List[List[int]]) -> Dict[str, Any]:
        """Convert stage partition to dictionary format.
412
413
414
415
416
417
418
419
420
            >>> partition = StagePartition(num_layers=4, pp_degree=2)
            >>> stages = [[0, 1], [2, 3]]
            >>> d = partition.to_dict(stages)
        """
        return {
            "pp_degree": self.pp_degree,
            "num_layers": self.num_layers,
            "stage_partition": stage_partition,
            "layers_per_stage": [len(stage) for stage in stage_partition],
hyper_parallel/auto_parallel/sapp_nd/pp_modeling/yaml_config_builder.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
            ...     model=resnet50,
            ...     input_shape=(1, 3, 224, 224),
            ... )
        """
        if model is None:
            model = self.model
        if model is None:
            raise ValueError("No model provided for profiling.")

        try:
            # Import inside method: optional dependency (torch may not be installed)
            from torch import nn  # pylint: disable=C0415,C9002
        except ImportError as exc:
            raise ImportError("PyTorch is required for PyTorch model profiling.") from exc

        if layer_filter is None:
            def default_filter(name: str, module: nn.Module) -> bool:  # pylint: disable=W0613
                """Filter leaf modules that have no children.

                Args:
                    name: Module name (unused).
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

                Returns:
                    True if the module is a leaf module.
                """
                return not any(module.children())
            layer_filter = default_filter

        layers = []
        layer_id = 0

        for name, module in model.named_modules():
            if not layer_filter(name, module):
                continue

            layer_info = self._profile_pytorch_layer(
                layer_id=layer_id,
                layer_name=name,
                module=module,
                input_shape=input_shape,
122
123
124
125
126
127
128
129
130
131
132
133
134
                input_shape=input_shape,
                input_tensor=input_tensor,
                use_symbolic=use_symbolic,
            )
            layers.append(layer_info)
            layer_id += 1

        self._profiled_layers = layers
        return layers

    def _profile_pytorch_layer(
        self,
        layer_id: int,
152
153
154
155
156
157
158
159
160
161
162
163
164
            LayerInfo for this layer.
        """
        # Import inside method: optional dependency (torch may not be installed)
        # nn is already imported in the caller (profile_pytorch_model)
        layer_type = module.__class__.__name__

        param_memory = self._estimate_param_memory(module)

        activation_memory = self._estimate_activation_memory(
            module, input_shape, input_tensor, use_symbolic
        )

        # NOTE: These recompute activation coefficients are rough placeholders.
163
164
165
166
167
168
169
170
171
172
173
174
175
176

        # NOTE: These recompute activation coefficients are rough placeholders.
        # Downstream consumers that use config_path get precise values from
        # perf_estimation; the legacy backfill path also overwrites them.
        activation_memory_slct = activation_memory * 0.04
        activation_memory_comm = activation_memory * 0.125
        activation_memory_both = activation_memory * 0.165
        activation_memory_full = activation_memory * 0.5

        return LayerInfo(
            layer_id=layer_id,
            layer_name=layer_name,
            layer_type=layer_type,
            param_memory=param_memory,
189
190
191
192
193
194
195
196
197
198
199
200

        Returns:
            Parameter memory in MB.
        """
        param_size = 0
        for param in module.parameters():
            param_size += param.numel() * param.element_size()
        return param_size / (1024 * 1024)

    def _estimate_activation_memory(
        self,
        module: Any,
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
        Returns:
            Activation memory in MB.
        """
        # Import inside method: optional dependency (torch may not be installed)
        import torch  # pylint: disable=C0415,C9002

        if input_tensor is not None:
            try:
                with torch.no_grad():
                    input_size = input_tensor.numel() * input_tensor.element_size()
                    output = module(input_tensor)
                    if isinstance(output, (tuple, list)):
                        output_size = sum(
                            obj.numel() * obj.element_size()
                            for obj in output
                            if isinstance(obj, torch.Tensor)
                        )
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
                            for obj in output
                            if isinstance(obj, torch.Tensor)
                        )
                    else:
                        output_size = output.numel() * output.element_size()
                    return (input_size + output_size) / (1024 * 1024)
            except Exception:  # pylint: disable=W0718
                pass

        if input_shape is not None:
            try:
                dummy_input = torch.randn(*input_shape)
                with torch.no_grad():
                    input_size = dummy_input.numel() * dummy_input.element_size()
                    output = module(dummy_input)
                    if isinstance(output, (tuple, list)):
                        output_size = sum(
                            obj.numel() * obj.element_size()
                            for obj in output
                            if isinstance(obj, torch.Tensor)
                        )
245
246
247
248
249
250
251
252
253
254
255
256
257
258
                            for obj in output
                            if isinstance(obj, torch.Tensor)
                        )
                    else:
                        output_size = output.numel() * output.element_size()
                    return (input_size + output_size) / (1024 * 1024)
            except Exception:  # pylint: disable=W0718
                pass

        return 1.0

    def profile_from_mindspore_model(
        self,
        model: Optional[Any] = None,
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
            ...     model=gpt_model,
            ...     input_shape=(1, 1024),
            ... )
        """
        if model is None:
            model = self.model
        if model is None:
            raise ValueError("No model provided for profiling.")

        try:
            # Import inside method: optional dependency (mindspore may not be installed)
            from mindspore import nn  # pylint: disable=C0415,C9002
        except ImportError as exc:
            raise ImportError("MindSpore is required for MindSpore model profiling.") from exc

        if layer_filter is None:
            def default_filter(name: str, cell: nn.Cell) -> bool:  # pylint: disable=W0613
                """Filter leaf cells that have no children.

                Args:
                    name: Cell name (unused).
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317

                Returns:
                    True if the cell is a leaf cell.
                """
                return not any(cell.cells())
            layer_filter = default_filter

        layers = []
        layer_id = 0

        for name, cell in model.cells_and_names():
            if not layer_filter(name, cell):
                continue

            layer_info = self._profile_mindspore_layer(
                layer_id=layer_id,
                layer_name=name,
                cell=cell,
                input_shape=input_shape,
316
317
318
319
320
321
322
323
324
325
326
327
328
                cell=cell,
                input_shape=input_shape,
                input_tensor=input_tensor,
            )
            layers.append(layer_info)
            layer_id += 1

        self._profiled_layers = layers
        return layers

    def _profile_mindspore_layer(
        self,
        layer_id: int,
344
345
346
347
348
349
350
351
352
353
354
355
356
            LayerInfo for this layer.
        """
        # Import inside method: optional dependency (mindspore may not be installed)
        # nn is already imported in the caller (profile_mindspore_model)
        layer_type = cell.__class__.__name__

        param_memory = self._estimate_mindspore_param_memory(cell)

        activation_memory = self._estimate_mindspore_activation_memory(
            cell, input_shape, input_tensor
        )

        # NOTE: These recompute activation coefficients are rough placeholders.
355
356
357
358
359
360
361
362
363
364
365
366
367
368

        # NOTE: These recompute activation coefficients are rough placeholders.
        # Downstream consumers that use config_path get precise values from
        # perf_estimation; the legacy backfill path also overwrites them.
        activation_memory_slct = activation_memory * 0.04
        activation_memory_comm = activation_memory * 0.125
        activation_memory_both = activation_memory * 0.165
        activation_memory_full = activation_memory * 0.5

        return LayerInfo(
            layer_id=layer_id,
            layer_name=layer_name,
            layer_type=layer_type,
            param_memory=param_memory,
381
382
383
384
385
386
387
388
389
390
391
392

        Returns:
            Parameter memory in MB.
        """
        param_size = 0
        for param in cell.get_parameters():
            param_size += param.numel() * param.element_size()
        return param_size / (1024 * 1024)

    def _estimate_mindspore_activation_memory(
        self,
        cell: Any,
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
        Returns:
            Activation memory in MB.
        """
        # Import inside method: optional dependency (mindspore may not be installed)
        from mindspore import ops  # pylint: disable=C0415,C9002

        if input_shape is not None:
            try:
                dummy_input = ops.randn(*input_shape)
                input_size = dummy_input.numel() * dummy_input.element_size()
                output = cell(dummy_input)
                if isinstance(output, (tuple, list)):
                    output_size = sum(
                        obj.numel() * obj.element_size()
                        for obj in output
                        if hasattr(obj, 'numel')
                    )
417
418
419
420
421
422
423
424
425
426
427
428
429
430
                        for obj in output
                        if hasattr(obj, 'numel')
                    )
                else:
                    output_size = output.numel() * output.element_size()
                return (input_size + output_size) / (1024 * 1024)
            except Exception:  # pylint: disable=W0718
                pass

        return 1.0

    def profile_from_json_file(
        self,
        json_file: str,
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
        Example:
            >>> builder = YamlConfigBuilder()
            >>> layers = profiler.profile_from_sapp_ppb_yaml("model.yaml")
        """
        try:
            import yaml as _yaml  # pylint: disable=C0415
        except ImportError as exc:
            raise ImportError("PyYAML is required for YAML file loading.") from exc

        with open(yaml_file, 'r', encoding='utf-8') as f:
            data = _yaml.safe_load(f)

        layers = []
        for idx, layer_data in enumerate(data.get('layers', [])):
            layer = LayerInfo(
                layer_id=idx,
                layer_name=layer_data.get('name', f'layer_{idx}'),
                layer_type=layer_data.get('type', 'unknown'),
                param_memory=layer_data.get('parameter_memory', 0.0),
494
495
496
497
498
499
500
501
502
503
504
505
                layer_type=layer_data.get('type', 'unknown'),
                param_memory=layer_data.get('parameter_memory', 0.0),
                activation_memory=layer_data.get('activation_memory', 0.0),
            )
            layers.append(layer)

        self._profiled_layers = layers
        return layers

    def save_to_json(
        self,
        output_file: str,
516
517
518
519
520
521
522
523
524
            >>> layers = profiler.profile_from_pytorch_model(model)
            >>> builder.save_to_json("layers.json", layers)
        """
        if layers is None:
            layers = self._profiled_layers

        data = {
            'layers': [
                {
592
593
594
595
596
597
598
599
600
            ...     memory_limit=80000,
            ... )
        """
        if pipeline_num_range[0] <= 0 or pipeline_num_range[1] <= 0:
            raise ValueError(
                f"pipeline_num_range values must be positive, got {pipeline_num_range}"
            )
        if pipeline_num_range[0] > pipeline_num_range[1]:
            raise ValueError(
600
601
602
603
604
605
606
607
608
609
610
611
612
            raise ValueError(
                f"pipeline_num_range min must not exceed max, got {pipeline_num_range}"
            )
        if micro_batch_num_range[0] <= 0 or micro_batch_num_range[1] <= 0:
            raise ValueError(
                f"micro_batch_num_range values must be positive, got {micro_batch_num_range}"
            )
        if micro_batch_num_range[0] > micro_batch_num_range[1]:
            raise ValueError(
                f"micro_batch_num_range min must not exceed max, got {micro_batch_num_range}"
            )
        if memory_limit <= 0:
            raise ValueError(
612
613
614
615
616
617
618
619
620
            raise ValueError(
                f"memory_limit must be positive, got {memory_limit}"
            )
        if num_of_interleave <= 0:
            raise ValueError(
                f"num_of_interleave must be positive, got {num_of_interleave}"
            )

        num_layer_val = 0
hyper_parallel/auto_parallel/sapp_ppb/__init__.py
41
42
43
44
45
46
47
48
49

    def load_module(self, fullname):
        """Load *fullname* by importing its canonical long-name counterpart."""
        if fullname in _sys.modules:
            return _sys.modules[fullname]
        long_name = _PREFIX + fullname[len(_SHORT_PREFIX):]
        mod = _import_module(long_name)
        _sys.modules[fullname] = mod
        for _attr in getattr(mod, "__all__", ()):
48
49
50
51
52
53
54
55
56
57
        _sys.modules[fullname] = mod
        for _attr in getattr(mod, "__all__", ()):
            try:
                setattr(mod, _attr, getattr(mod, _attr))
            except AttributeError:
                pass
        return mod


_alias_finder = _SappPpbAliasFinder()
hyper_parallel/auto_parallel/sapp_ppb/sapp/sapp_solver.py
597
598
599
600
601
602
603
604
605
        """
        body_layers = layers_sorted.get(Layer.type_enum.BODY, [])
        if body_layers:
            return body_layers[0].recompute_considered_
        return {rec_type: False for rec_type in Recompute.TYPE}

    def max_stage_micro_eq_stage(self, prob: Any,
                                 layers_sorted: Dict[Layer.type_enum, List[Layer]]) -> Any:
        """Apply additional VPP optimisations when ``pp == num_of_micro_batch``."""