Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/trainer/base.py 41.7% 691,770,786,809,814,1077,1416
hyper_parallel/trainer/callbacks/base.py 72.0% 240,761,767,776,823-824,826-829,833,839,841,848,878-883,887,891-893,900-912,917,923,931,936-942,992,995-998,1004-1005,1008-1009,1016-1020,1024,1027,1040-1041,1051,1065,1148-1154,1160-1161,1168-1169,1176,1178-1179,1227,1241-1243,1258-1259,1294-1297,1303-1304,1308,1317
hyper_parallel/trainer/config.py 100%  
hyper_parallel/trainer/base.py
687
688
689
690
691
692
693
694
695
        self.tensorboard_callback = TensorBoardCallback(self)
        self.progress_callback = ProgressCallback(self)
        self.moe_monitor_callback = MoEMonitorCallback(self)
        # Health + operability (no-ops unless enabled in cfg.train.debug / .memory_monitor).
        self.training_state_monitor_callback = TrainingStateMonitorCallback(self)
        self.gradient_health_callback = GradientHealthCallback(self)
        self.memory_monitor_callback = MemoryMonitorCallback(self)
        self.gc_callback = GCCallback(self)
        # ``user_callbacks`` lets external code append extra Callback instances
766
767
768
769
770
771
772
773
774
        """Dispatch on_train_begin to all callbacks."""
        # Memory monitor first so it captures the truly-initial peak.
        self.memory_monitor_callback.on_train_begin(self.state)
        self.moe_monitor_callback.on_train_begin(self.state)
        self.training_state_monitor_callback.on_train_begin(self.state)
        self.profiler_callback.on_train_begin(self.state)
        self.wandb_callback.on_train_begin(self.state)
        self.tensorboard_callback.on_train_begin(self.state)
        self.progress_callback.on_train_begin(self.state)
782
783
784
785
786
787
788
789
790
        """Dispatch on_train_end to all callbacks."""
        self.checkpoint_callback.on_train_end(self.state)
        self.hf_export_callback.on_train_end(self.state)
        self.progress_callback.on_train_end(self.state)
        self.training_state_monitor_callback.on_train_end(self.state)
        self.tensorboard_callback.on_train_end(self.state)
        self.wandb_callback.on_train_end(self.state)
        self.profiler_callback.on_train_end(self.state)
        for cb in self.user_callbacks:
805
806
807
808
809
810
811
812
813
814
815
816
817
818
        """Dispatch on_substep_end (after each micro-batch forward/backward)."""
        self.moe_monitor_callback.on_substep_end(self.state, **kwargs)
        self.training_state_monitor_callback.on_substep_end(self.state, **kwargs)
        for cb in self.user_callbacks:
            cb.on_substep_end(self.state, **kwargs)

    def on_pre_optimizer_step(self, grad_norm=None):
        """Dispatch on_pre_optimizer_step (after grad clip, before optimizer.step)."""
        # Monitor first so an abnormal step is captured before health checks abort.
        self.training_state_monitor_callback.on_pre_optimizer_step(
            self.state, grad_norm=grad_norm,
        )
        self.gradient_health_callback.on_pre_optimizer_step(
            self.state, grad_norm=grad_norm,
1073
1074
1075
1076
1077
1078
1079
1080
1081
        for epoch in range(num_epochs):
            if self.state.global_step >= self.state.max_steps:
                break
            self.state.epoch = epoch
            if hasattr(self, 'sampler'):
                self.sampler.set_epoch(epoch)
            self.on_epoch_begin()

            # Build micro-batch iterator from the stateful dataloader.
1412
1413
1414
1415
1416
1417
1418
1419
1420
        for hf_name, hf_tensor in hf_sd.items():
            real_name = logical_to_real.get(hf_name)
            if real_name is None:
                continue
            tgt = tuple(real_to_param[real_name].shape)
            src = tuple(hf_tensor.shape)
            if src == tgt:
                valid_sd[real_name] = hf_tensor
            else:
hyper_parallel/trainer/callbacks/base.py
236
237
238
239
240
241
242
243
244
        message = " | ".join(f"{k}={v}" for k, v in metrics.items())
        monitor_cb = getattr(self.trainer, 'training_state_monitor_callback', None)
        monitor_active = monitor_cb is not None and getattr(monitor_cb, 'active', False)
        if monitor_active:
            logger.info(message)
        else:
            logger.info_rank0(message)

        record = {
757
758
759
760
761
762
763
764
765
        super().__init__(trainer)
        train_cfg = getattr(trainer.args, "train", None)
        self.cfg = getattr(train_cfg, "monitor", None)
        if self.cfg is None:
            self.cfg = getattr(trainer.args, "monitor", None)

        self.enabled = bool(getattr(self.cfg, "monitor_on", False)) if self.cfg else False
        self.dump_path = getattr(self.cfg, "dump_path", "./dump") if self.cfg else "./dump"
        self.step_interval = int(getattr(self.cfg, "step_interval", 1) if self.cfg else 1)
763
764
765
766
767
768
769
770
771
        self.enabled = bool(getattr(self.cfg, "monitor_on", False)) if self.cfg else False
        self.dump_path = getattr(self.cfg, "dump_path", "./dump") if self.cfg else "./dump"
        self.step_interval = int(getattr(self.cfg, "step_interval", 1) if self.cfg else 1)
        if self.step_interval < 1:
            raise ValueError("train.monitor.step_interval must be >= 1.")

        self.local_loss_format = self._parse_formats("local_loss_format")
        self.device_local_loss_format = self._parse_formats("device_local_loss_format")
        self.local_norm_format = self._parse_formats("local_norm_format")
772
773
774
775
776
777
778
779
780
        self.device_local_norm_format = self._parse_formats("device_local_norm_format")

        raw_patterns = getattr(self.cfg, "target", None) if self.cfg else None
        if isinstance(raw_patterns, str):
            raw_patterns = [raw_patterns]
        self._target_patterns = (
            [re.compile(pattern) for pattern in raw_patterns]
            if raw_patterns else None
        )
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
        """Parse configured monitor output formats for a metric field."""
        value = getattr(self.cfg, field_name, None) if self.cfg else None
        if value is None:
            return ()
        if isinstance(value, str):
            formats = (value,)
        else:
            formats = tuple(value)
        unknown = sorted(set(formats) - self._SUPPORTED_FORMATS)
        if unknown:
            raise ValueError(
                f"train.monitor.{field_name} only supports "
                f"{sorted(self._SUPPORTED_FORMATS)}, got {unknown}."
            )
        return formats

    @staticmethod
    def _to_scalar(value) -> float:
        """Convert tensor-like values to a Python float."""
835
836
837
838
839
840
841
842
843
844
845
    @staticmethod
    def _to_scalar(value) -> float:
        """Convert tensor-like values to a Python float."""
        if value is None:
            return 0.0
        if hasattr(value, "to_local"):
            value = value.to_local()
        if hasattr(value, "detach"):
            value = value.detach()
        if hasattr(value, "float"):
            value = value.float()
844
845
846
847
848
849
850
851
852
        if hasattr(value, "float"):
            value = value.float()
        if hasattr(value, "item"):
            return float(value.item())
        return float(value)

    @staticmethod
    def _sanitize_tag(value: str) -> str:
        value = value.replace(".", "/").replace(" ", "_")
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
    def _write(self, formats: tuple[str, ...], tag: str, value: float,
               step: int, *, global_metric: bool = False) -> None:
        """Send one scalar to every output sink named in ``formats``.

        Args:
            formats: Sink names to emit to; an empty value makes this a no-op.
            tag: Metric name used as the TensorBoard tag and the log key.
            value: Scalar value to record.
            step: Step index passed to the TensorBoard writer.
            global_metric: When true, use the global writer instead of the
                per-rank writer for the ``"tensorboard"`` sink.
        """
        if not formats:
            return

        if "tensorboard" in formats:
            if global_metric:
                sink = self._global_writer
            else:
                sink = self._rank_writer
            # A writer that was never created (None) is silently skipped.
            if sink is not None:
                sink.add_scalar(tag, value, step)

        if "log" in formats:
            # Buffered here; the pending dict is printed and reset elsewhere.
            self._pending_log_metrics[tag] = value

    @staticmethod
    def _format_monitor_value(value: float) -> str:
        return f"{float(value):.8f}"

    def _flush_step_log(self, step: int) -> None:
        """Print one compact rank-local monitor line for console output."""
        if not self._pending_log_metrics:
            return
        field_map = (
            ("loss", "loss/local_loss"),
            ("accum_loss", "loss/device_accum_local_loss"),
            ("grad_norm", "grad/device_local_norm"),
            ("grad_nan_count", "grad/device_nan_count"),
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
            ("grad_norm", "grad/device_local_norm"),
            ("grad_nan_count", "grad/device_nan_count"),
            ("grad_inf_count", "grad/device_inf_count"),
        )
        raw_tags = {tag for _, tag in field_map}
        parts = []
        for display_name, tag in field_map:
            if tag in self._pending_log_metrics:
                value = self._pending_log_metrics[tag]
                parts.append(f"{display_name}={self._format_monitor_value(value)}")
        for tag in sorted(self._pending_log_metrics):
            if tag in raw_tags:
                continue
            value = self._pending_log_metrics[tag]
            parts.append(f"{tag}={self._format_monitor_value(value)}")
        timestamp = time.strftime("%H:%M:%S")
        print(
            f"[{timestamp}][rank{self._rank}][INFO] local_state: "
            f"step={step} | " + " | ".join(parts),
            flush=True,
        )
        self._pending_log_metrics = {}

    def _abnormal_record_path(self, *, global_record: bool = False) -> str:
        configured = getattr(self.cfg, "abnormal_record_path", None) if self.cfg else None
        base_path = configured or os.path.join(self.dump_path, "abnormal_training_state.json")
919
920
921
922
923
924
925
926
    def _abnormal_record_path(self, *, global_record: bool = False) -> str:
        """Return the JSON file path used for abnormal-state records.

        The base path is ``cfg.abnormal_record_path`` when it is set to a
        truthy value, otherwise ``abnormal_training_state.json`` under
        ``self.dump_path``.

        Args:
            global_record: When true, return the base path unchanged. Otherwise
                ``_rank_<rank>`` is inserted before the extension (``.json`` is
                used when the base path has no extension).
        """
        override = None
        if self.cfg:
            override = getattr(self.cfg, "abnormal_record_path", None)

        if override:
            target = override
        else:
            target = os.path.join(self.dump_path, "abnormal_training_state.json")

        if global_record:
            return target

        stem, extension = os.path.splitext(target)
        if not extension:
            extension = ".json"
        return f"{stem}_rank_{self._rank}{extension}"
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945

    def _record_abnormal(self, record: dict, *, global_record: bool = False) -> None:
        """Append an abnormal training-state record to the JSON log.

        The log file holds a JSON list. Existing content is loaded, ``record``
        is appended, and the whole list is written back. A file that cannot be
        read, is not valid JSON, or does not contain a list is treated as empty.

        Args:
            record: JSON-serializable mapping describing the abnormal event.
            global_record: When true, only rank 0 writes and the un-suffixed
                path is used; other ranks return without touching the disk.

        Raises:
            TypeError: If ``record`` holds a value ``json`` cannot serialize.
                The existing file is left untouched in that case.
        """
        if global_record and self._rank != 0:
            return
        path = self._abnormal_record_path(global_record=global_record)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        records = []
        if os.path.isfile(path):
            try:
                with open(path, encoding="utf-8") as file:
                    loaded = json.load(file)
                if isinstance(loaded, list):
                    records = loaded
            except (OSError, json.JSONDecodeError):
                records = []
        records.append(record)
        # Serialize before opening for write: opening with "w" truncates the
        # file, so a serialization error must surface while the previously
        # stored records are still intact on disk.
        payload = json.dumps(records, indent=2)
        with open(path, "w", encoding="utf-8") as file:
            file.write(payload)
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
        return param_ids

    def on_train_begin(self, state: "TrainerState", **kwargs) -> None:
        if not (self.active or self.health_active):
            return
        self._rank = platform.get_rank()
        if self._uses_tensorboard():
            try:
                from torch.utils.tensorboard import SummaryWriter  # pylint: disable=C0415
            except ImportError as exc:
                raise RuntimeError(
                    "train.monitor requested tensorboard output, but "
                    "torch.utils.tensorboard.SummaryWriter is unavailable. "
                    "Install tensorboard or remove 'tensorboard' from monitor formats."
                ) from exc
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
                    "torch.utils.tensorboard.SummaryWriter is unavailable. "
                    "Install tensorboard or remove 'tensorboard' from monitor formats."
                ) from exc

            tb_root = os.path.join(self.dump_path, "tensorboard")
            self._rank_writer = SummaryWriter(
                os.path.join(tb_root, f"rank_{self._rank}")
            )
            if self._rank == 0:
                self._global_writer = SummaryWriter(os.path.join(tb_root, "global"))
        logger.info_rank0(
            "TrainingStateMonitor enabled: dump_path=%s step_interval=%d health=%s",
            self.dump_path, self.step_interval, self.health_active,
        )
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
            self.dump_path, self.step_interval, self.health_active,
        )

    def on_train_end(self, state: "TrainerState", **kwargs) -> None:
        """Close whichever writers are open, then clear both writer references."""
        open_writers = [
            candidate
            for candidate in (self._rank_writer, self._global_writer)
            if candidate is not None
        ]
        for candidate in open_writers:
            candidate.close()
        self._rank_writer = self._global_writer = None

    def on_substep_end(self, state: "TrainerState", **kwargs) -> None:
        if not (self.active or self.health_active) or state.global_step % self.step_interval != 0:
            return
        raw_loss = kwargs.get("raw_loss")
        if raw_loss is None:
            return
        if int(kwargs.get("micro_step", 0)) == 0:
            self._step_loss_sum = 0.0
            self._step_loss_count = 0
            self.last_local_loss = None
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
        self.last_local_loss = self._step_loss_sum / self._step_loss_count
        self.last_monitor_step = state.global_step

        self._check_nonfinite(state, "local_loss", loss_value)
        tag = "loss/local_loss"
        self._write(
            self.local_loss_format,
            tag,
            loss_value,
            state.global_step,
1047
1048
1049
1050
1051
1052
1053
1054
1055
        )

    def on_pre_optimizer_step(self, state: "TrainerState", **kwargs) -> None:
        if not (self.active or self.health_active) or state.global_step % self.step_interval != 0:
            return
        device_sum_sq = 0.0
        device_nan_count = 0
        device_inf_count = 0
        embedding_sum_sq = 0.0
1061
1062
1063
1064
1065
1066
1067
1068
1069
            if not self._should_record_param(name):
                continue
            grad = getattr(param, "grad", None)
            if grad is None:
                continue
            is_dtensor = hasattr(grad, "to_local")
            local_grad = grad.to_local() if is_dtensor else grad
            local_float = local_grad.detach().float()
            sum_sq = local_float.pow(2).sum()
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
                )

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        if not (self.active or self.health_active):
            return
        if state.global_step % self.step_interval == 0:
            if self._step_loss_count > 0:
                device_loss = self._step_loss_sum / self._step_loss_count
                self._check_nonfinite(state, "device_accum_local_loss", device_loss)
                self._write(
                    self.device_local_loss_format,
                    "loss/device_accum_local_loss",
                    device_loss,
                    state.global_step,
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
                    "loss/device_accum_local_loss",
                    device_loss,
                    state.global_step,
                )
            if self._rank == 0 and loss is not None:
                self._write(
                    ("tensorboard",) if self._uses_tensorboard() else (),
                    "loss/global_loss",
                    float(loss),
                    state.global_step,
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
                    float(loss),
                    state.global_step,
                    global_metric=True,
                )
            if self._rank == 0 and grad_norm is not None:
                self._write(
                    ("tensorboard",) if self._uses_tensorboard() else (),
                    "grad/global_grad_norm",
                    float(grad_norm),
                    state.global_step,
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
                    float(grad_norm),
                    state.global_step,
                    global_metric=True,
                )
            self._flush_step_log(state.global_step)

        self._step_loss_sum = 0.0
        self._step_loss_count = 0


class GradientHealthCallback(Callback):
    """Detect NaN / Inf grad_norm and raise / warn.
1223
1224
1225
1226
1227
1228
1229
1230
1231
            int(getattr(monitor_cfg, "global_norm_spike_count_threshold", 1))
            if monitor_cfg is not None else 1
        )
        if self.global_norm_spike_count_threshold < 1:
            raise ValueError("train.monitor.global_norm_spike_count_threshold must be >= 1.")
        self._global_norm_spike_count = 0
        self.monitor_cfg = monitor_cfg

    def _abnormal_record_path(self, *, global_record: bool = False) -> str:
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
        )
        base_path = configured or os.path.join(dump_path, "abnormal_training_state.json")
        if global_record:
            return base_path
        root, ext = os.path.splitext(base_path)
        suffix = ext or ".json"
        return f"{root}_rank_{platform.get_rank()}{suffix}"

    def _record_abnormal(self, record: dict, *, global_record: bool = False) -> None:
        """Append an abnormal gradient-health record to the JSON log."""
        if global_record and platform.get_rank() != 0:
1254
1255
1256
1257
1258
1259
1260
1261
1262
                with open(path, encoding="utf-8") as file:
                    loaded = json.load(file)
                if isinstance(loaded, list):
                    records = loaded
            except (OSError, json.JSONDecodeError):
                records = []
        records.append(record)
        with open(path, "w", encoding="utf-8") as file:
            json.dump(records, file, indent=2)
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
            self._check_global_norm_spike(state, grad_norm)

    def on_step_end(self, state: "TrainerState", *, loss: Optional[float] = None,
                    grad_norm: Optional[float] = None, **kwargs) -> None:
        if not self.nan_inf_check_enabled or loss is None:
            return
        if math.isnan(loss) or math.isinf(loss):
            record = {
                "step": state.global_step,
                "rank": platform.get_rank(),
                "metric": "loss",
                "value": loss,
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
                "rank": platform.get_rank(),
                "metric": "loss",
                "value": loss,
            }
            self._record_abnormal(record)
            logger.error(
                "GradientHealthCallback: loss=%s at step %d (NaN/Inf).",
                loss, state.global_step,
            )
            raise RuntimeError(
                f"Non-finite loss={loss} at step {state.global_step}. "
                "Disable train.monitor.check_for_nan_in_loss_and_grad or "
                "train.debug.check_nan_inf to skip this guard."
            )
1313
1314
1315
1316
1317
1318
1319
1320
1321

    def _check_global_norm_spike(self, state: "TrainerState", grad_norm: float) -> None:
        """Check whether global grad norm exceeds the configured spike threshold."""
        if not self.global_norm_spike_enabled:
            return
        threshold = self.global_norm_spike_threshold
        if grad_norm >= threshold:
            self._global_norm_spike_count += 1
            record = {