Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/integration/llamafactory/trainer.py: 0%

325 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HyperParallel trainer backend for LlamaFactory.""" 

16import json 

17import logging 

18import os 

19import types 

20from contextlib import nullcontext 

21from dataclasses import dataclass 

22from pathlib import Path 

23from typing import Any, Optional, Union 

24 

25import numpy as np 

26import torch 

27from torch import nn 

28from transformers import Seq2SeqTrainer 

29 

30from hyper_parallel import SkipDTensorDispatch 

31from hyper_parallel.core.fully_shard.api import HSDPModule, hsdp_sync_stream 

32from hyper_parallel.core.utils import clip_grad_norm_ as hp_clip_grad_norm_ 

33from hyper_parallel.integration.llamafactory.utils import fsdp2_prepare_model 

34from hyper_parallel.platform import get_platform 

35 

# Module-level logger, namespaced under this module's import path.
logger = logging.getLogger(__name__)

# Accepted spellings for ``param_dtype`` / ``reduce_dtype``: a long and a short
# alias for each supported floating-point precision.
_VALID_DTYPES = {
    "float32", "fp32",
    "float16", "fp16",
    "bfloat16", "bf16",
}

# Stem of the checkpoint sub-directory that holds HSDP model shards
# (used as ``f"{_HSDP_MODEL_NAME}_0"``).
_HSDP_MODEL_NAME = "hsdp_model"

# Stem of the per-rank optimizer shard files
# (used as ``f"{_HSDP_OPTIMIZER_NAME}_rank{rank}.pt"``).
_HSDP_OPTIMIZER_NAME = "optimizer"

41 

42 

@dataclass
class HyperParallelArguments:
    """Minimal HyperParallel configuration needed by the trainer backend.

    The ``from_dict`` / ``from_finetuning_args`` constructors always return an
    instance that has passed :meth:`validate`.
    """

    # Tensor-parallel degree; only 1 is accepted (this backend only replaces FSDP/fully_shard).
    tp_size: int = 1
    # Device family: one of "auto", "npu", "cuda", "cpu".
    device_type: str = "auto"
    # Optional parameter dtype name (must be in ``_VALID_DTYPES``); None means not specified.
    param_dtype: Optional[str] = None
    # Optional reduction dtype name (must be in ``_VALID_DTYPES``); None means not specified.
    reduce_dtype: Optional[str] = None
    # Optional reshard-after-forward flag; must be a real bool when given, None means not specified.
    reshard_after_forward: Optional[bool] = None

    def validate(self) -> None:
        """Validate supported argument values.

        Raises:
            ValueError: If ``tp_size`` is not 1, a dtype name is not in
                ``_VALID_DTYPES``, ``device_type`` is unknown, or
                ``reshard_after_forward`` is neither None nor a bool.
        """
        if self.tp_size != 1:
            raise ValueError(
                "Current trainer backend only supports replacing FSDP/fully_shard. "
                f"Expected tp_size=1, got {self.tp_size}."
            )
        if self.param_dtype is not None and self.param_dtype not in _VALID_DTYPES:
            raise ValueError(
                f"param_dtype must be one of {sorted(_VALID_DTYPES)}, got {self.param_dtype!r}."
            )
        if self.reduce_dtype is not None and self.reduce_dtype not in _VALID_DTYPES:
            raise ValueError(
                f"reduce_dtype must be one of {sorted(_VALID_DTYPES)}, got {self.reduce_dtype!r}."
            )
        if self.device_type not in {"auto", "npu", "cuda", "cpu"}:
            raise ValueError(
                f"device_type must be one of ['auto', 'cpu', 'cuda', 'npu'], got {self.device_type!r}."
            )
        # Reject truthy non-bools such as 1 or "true" coming from loosely typed configs.
        if self.reshard_after_forward is not None and not isinstance(self.reshard_after_forward, bool):
            raise ValueError(
                "reshard_after_forward must be a bool when provided, "
                f"got {type(self.reshard_after_forward).__name__}."
            )

    @classmethod
    def from_dict(cls, config: dict) -> "HyperParallelArguments":
        """Build validated arguments from a plain dict.

        Keys that are not dataclass fields are ignored, but a warning lists
        them so that misspelled options do not go unnoticed.

        Args:
            config: Mapping of field name to value.

        Returns:
            A validated ``HyperParallelArguments``.

        Raises:
            ValueError: If a recognised value fails :meth:`validate`.
        """
        known_fields = set(cls.__dataclass_fields__)  # pylint: disable=no-member
        unknown_keys = sorted(str(key) for key in config if key not in known_fields)
        if unknown_keys:
            logger.warning("Ignoring unknown HyperParallel argument(s): %s", ", ".join(unknown_keys))
        hp_args = cls(**{key: value for key, value in config.items() if key in known_fields})
        hp_args.validate()
        return hp_args

    @classmethod
    def from_finetuning_args(cls, finetuning_args) -> "HyperParallelArguments":
        """Extract HyperParallel arguments from LlamaFactory finetuning args.

        ``finetuning_args.hyper_parallel_args`` may be missing/None (defaults
        are used), a dict, or a path to a UTF-8 JSON file containing a dict.

        Raises:
            ValueError: If the resolved value is not a dict, or validation fails.
        """
        raw = getattr(finetuning_args, "hyper_parallel_args", None)
        if raw is None:
            hp_args = cls()
            hp_args.validate()
            return hp_args
        if isinstance(raw, str):
            with open(raw, "r", encoding="utf-8") as file:
                raw = json.load(file)
        if not isinstance(raw, dict):
            raise ValueError(
                "finetuning_args.hyper_parallel_args must be a dict or JSON file path, "
                f"got {type(raw).__name__}."
            )
        return cls.from_dict(raw)

103 

104 

def _localize_optimizer_state(optim_sd: dict) -> dict:
    """Produce a serialisation-ready copy of an optimizer state dict.

    Every per-parameter state value that is a HyperParallel ``DTensor`` is
    replaced by its local shard; every other tensor is detached; both end up
    on CPU. Non-tensor values (e.g. Python scalars) are kept as-is.

    Args:
        optim_sd: Optimizer state dict from ``optimizer.state_dict()``.

    Returns:
        A new dict with only ``"state"`` (same parameter indices and keys, with
        localized values) and ``"param_groups"`` (passed through unchanged,
        defaulting to an empty list).
    """
    from hyper_parallel.core.dtensor.dtensor import DTensor as _DTensor  # pylint: disable=C0415

    def _as_local_cpu(item):
        # The DTensor check must come first: keep only this rank's shard.
        if isinstance(item, _DTensor):
            return item.to_local().detach().cpu()
        if isinstance(item, torch.Tensor):
            return item.detach().cpu()
        return item

    localized = {
        idx: {name: _as_local_cpu(entry) for name, entry in per_param.items()}
        for idx, per_param in optim_sd.get("state", {}).items()
    }
    return {"state": localized, "param_groups": optim_sd.get("param_groups", [])}

129 

130 

def _load_local_optimizer_state(optimizer, saved_sd: dict) -> None:
    """Copy saved local optimizer state into the optimizer's current (possibly DTensor-backed) state.

    Saved entries are matched to parameters by flat index over
    ``optimizer.param_groups`` (groups in order, then parameters in order
    within each group). For each matched parameter:

    * existing DTensor state is updated by copying into its local shard;
    * existing plain-tensor state is updated in place with ``copy_``;
    * keys missing from the current state are added; tensors are moved to
      the parameter's (local) device;
    * any other existing value is overwritten.

    Entries are skipped when no parameter has that index, or when the optimizer
    holds no state for the parameter yet. This is presumably the case before
    the first ``step()`` for optimizers that create state lazily. Skipped
    entries are reported through a single warning rather than dropped silently.

    Finally, every non-``params`` hyper-parameter (lr, betas, ...) of each saved
    param group is copied onto the corresponding current group.

    Args:
        optimizer: The optimizer whose state to restore.
        saved_sd: State dict saved by ``_localize_optimizer_state`` (local CPU tensors).
    """
    from hyper_parallel.core.dtensor.dtensor import DTensor as _DTensor  # pylint: disable=C0415

    # Flat index -> parameter. Assumes the same ordering that
    # ``optimizer.state_dict()`` used when the shard was saved.
    param_by_idx: dict[int, torch.nn.Parameter] = dict(
        enumerate(param for group in optimizer.param_groups for param in group["params"])
    )

    saved_entries = saved_sd.get("state", {})
    num_skipped = 0
    for param_idx, saved_state in saved_entries.items():
        # Accept string index keys as well as ints.
        param_idx = int(param_idx) if isinstance(param_idx, str) else param_idx
        param = param_by_idx.get(param_idx)
        if param is None or param not in optimizer.state:
            num_skipped += 1
            continue
        current_state = optimizer.state[param]
        for key, saved_val in saved_state.items():
            current_val = current_state.get(key)
            if current_val is None:
                # New state entry (e.g. step counter added later)
                if isinstance(saved_val, torch.Tensor):
                    device = param.to_local().device if isinstance(param, _DTensor) else param.device
                    current_state[key] = saved_val.to(device)
                else:
                    current_state[key] = saved_val
            elif isinstance(current_val, _DTensor):
                # Assumes to_local() returns a view of the shard, so the copy lands in the DTensor.
                local = current_val.to_local()
                local.copy_(saved_val.to(local.device))
            elif isinstance(current_val, torch.Tensor):
                current_val.copy_(saved_val.to(current_val.device))
            else:
                current_state[key] = saved_val

    if num_skipped:
        logger.warning(
            "Optimizer state for %d of %d saved parameter entries was not restored "
            "(no matching parameter, or the optimizer has no state for it yet).",
            num_skipped,
            len(saved_entries),
        )

    # Restore hyper-parameters (lr, betas, etc.)
    for saved_group, current_group in zip(saved_sd.get("param_groups", []), optimizer.param_groups):
        for key, val in saved_group.items():
            if key != "params":
                current_group[key] = val

176 

177 

178def _wrap_optimizer_step_with_skip_dtensor_dispatch(optimizer) -> None: 

179 """Wrap optimizer.step so DTensor dispatch is skipped during parameter updates.""" 

180 if getattr(optimizer, "_hp_step_wrapped", False): 

181 return 

182 

183 original_step = optimizer.step 

184 

185 def _hp_step(bound_optimizer, *args, **kwargs): 

186 del bound_optimizer 

187 with SkipDTensorDispatch(): 

188 return original_step(*args, **kwargs) 

189 

190 optimizer.step = types.MethodType(_hp_step, optimizer) 

191 setattr(optimizer, "_hp_step_wrapped", True) 

192 

193 

def _export_to_hf_format(model: nn.Module, tokenizer, save_dir: str):
    """Gather the full model state dict and write a HuggingFace-compatible export.

    The state dict comes from HyperParallel's ``get_model_state_dict`` with
    ``full_state_dict=True`` and ``cpu_offload=True``, which all-gathers each
    sharded parameter via ``DTensor.full_tensor()``. Rank 0 receives the full
    weights on CPU; other ranks get an empty dict. Floating tensors are passed
    through ``_normalize_hf_export_state_dict`` before saving.

    Only rank 0 writes files: ``model.save_pretrained`` when the model has it,
    otherwise ``pytorch_model.bin`` via ``torch.save``. The tokenizer is also
    saved when one is given. With more than one rank, all ranks then wait at a
    barrier.

    Args:
        model: The model to export (possibly HSDP-wrapped).
        tokenizer: Object with ``save_pretrained`` (tokenizer/processor), or None.
        save_dir: Output directory; created on rank 0 if missing.
    """
    from hyper_parallel.core.fully_shard.api import (  # pylint: disable=C0415
        get_model_state_dict as hp_get_model_state_dict,
    )
    from torch.distributed.checkpoint.state_dict import StateDictOptions  # pylint: disable=C0415

    out_path = Path(save_dir)
    gathered = _normalize_hf_export_state_dict(
        hp_get_model_state_dict(
            model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
    )

    is_writer = get_platform().get_rank() == 0
    if is_writer:
        out_path.mkdir(parents=True, exist_ok=True)
        if not hasattr(model, "save_pretrained"):
            torch.save(gathered, out_path / "pytorch_model.bin")
        else:
            model.save_pretrained(str(out_path), state_dict=gathered)
        if tokenizer is not None:
            tokenizer.save_pretrained(str(out_path))

    # Keep every rank here until rank 0 has finished writing.
    if get_platform().get_world_size() > 1:
        torch.distributed.barrier()

224 

225 

226def _normalize_hf_export_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]: 

227 """Normalize gathered tensors to match the baseline HF/LlamaFactory export. 

228 

229 HyperParallel mixed precision can leave the live parameters in reduced 

230 precision, which halves the on-disk checkpoint size compared with the 

231 baseline FSDP2 export. The baseline path saves full-precision weights, so 

232 cast floating tensors back to fp32 before forwarding them to HF save logic. 

233 

234 Shared/tied tensors are cast once and then reused to preserve aliasing. 

235 """ 

236 normalized: dict[str, Any] = {} 

237 cast_cache: dict[tuple[Any, ...], torch.Tensor] = {} 

238 

239 for key, value in state_dict.items(): 

240 if not isinstance(value, torch.Tensor) or not torch.is_floating_point(value): 

241 normalized[key] = value 

242 continue 

243 

244 if value.dtype == torch.float32: 

245 normalized[key] = value 

246 continue 

247 

248 storage = value.untyped_storage() 

249 cache_key = ( 

250 storage.data_ptr(), 

251 value.storage_offset(), 

252 tuple(value.size()), 

253 tuple(value.stride()), 

254 value.device.type, 

255 str(value.dtype), 

256 ) 

257 casted = cast_cache.get(cache_key) 

258 if casted is None: 

259 casted = value.to(dtype=torch.float32) 

260 cast_cache[cache_key] = casted 

261 normalized[key] = casted 

262 

263 return normalized 

264 

265 

class HyperParallelTrainer(Seq2SeqTrainer):
    """Trainer backend that swaps FSDP2 prepare for HyperParallel fully_shard.

    Accelerate must run in FSDP2 mode. While :meth:`train` runs, two temporary
    patches are active:

    * ``accelerate.accelerator.fsdp2_prepare_model`` is replaced with this
      integration's HyperParallel ``fsdp2_prepare_model``;
    * ``accelerator.clip_grad_norm_`` uses HyperParallel's ``clip_grad_norm_``
      when all clipped parameters belong to one ``HSDPModule``.

    Checkpoints store HSDP model shards plus per-rank optimizer shards for
    resuming, while :meth:`save_model` writes HF-format weights.
    """

    def __init__(
        self,
        hp_args: HyperParallelArguments,
        finetuning_args=None,
        processor=None,
        ref_model: Optional[nn.Module] = None,
        **kwargs,
    ):
        """Initialize the trainer.

        Args:
            hp_args: HyperParallel configuration.
            finetuning_args: LlamaFactory finetuning args; ``disable_shuffling``
                and ``use_asft_loss`` are read from it when present.
            processor: Optional processor; when given,
                ``model_accepts_loss_kwargs`` is forced to False.
            ref_model: Optional reference model (used by the ASFT loss path);
                it is prepared with HyperParallel immediately.
            **kwargs: Forwarded to ``Seq2SeqTrainer``. A legacy ``tokenizer``
                entry is mapped onto ``processing_class``; ``gen_kwargs`` is
                removed and stored as ``self._gen_kwargs``.

        Raises:
            ValueError: If Accelerate is not running in FSDP2 mode.
        """
        # Accept the legacy ``tokenizer`` kwarg by mapping it onto ``processing_class``.
        kwargs["processing_class"] = kwargs.pop("tokenizer", kwargs.get("processing_class", None))
        gen_kwargs = kwargs.pop("gen_kwargs", None)
        self._hp_args = hp_args
        self.finetuning_args = finetuning_args
        super().__init__(**kwargs)
        if not getattr(self.accelerator, "is_fsdp2", False):
            raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
        if gen_kwargs is not None:
            self._gen_kwargs = gen_kwargs
        self.ref_model = ref_model

        if processor is not None:
            self.model_accepts_loss_kwargs = False

        if self.ref_model is not None:
            self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
        self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
        self._orig_fsdp2_prepare_model = None
        self._accelerator_patches_active = False
        # Checkpoint dir whose HSDP shards ``_load_from_checkpoint`` loaded; consumed
        # (and cleared) by ``_load_optimizer_and_scheduler``.
        self._pending_hsdp_checkpoint: Optional[str] = None

    def _activate_accelerator_patches(self) -> None:
        """Activate temporary Accelerate patches for HyperParallel training.

        The originals are remembered so :meth:`_restore_accelerator_patches`
        can undo both patches. Calling this while patches are active is a no-op.
        """
        if self._accelerator_patches_active:
            return

        import accelerate.accelerator as acc_module  # pylint: disable=C0415

        hp_args = self._hp_args

        self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model

        def _hp_fsdp2_prepare_model(accelerator, model):
            return fsdp2_prepare_model(accelerator, model, hp_args)

        acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model

        def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
            if getattr(accelerator, "is_fsdp2", False):
                accelerator.unscale_gradients()
                # ``parameters`` may be a one-shot iterator (e.g. ``model.parameters()``).
                # Materialise it once and use the list from here on, including in the
                # fallback below, which would otherwise receive an exhausted iterator.
                parameters = list(parameters)
                parameter_ids = {id(param) for param in parameters}
                for model in accelerator._models:  # pylint: disable=protected-access
                    if not isinstance(model, HSDPModule):
                        continue
                    model_param_ids = {id(param) for param in model.parameters()}
                    # Use HyperParallel clipping only when every parameter belongs to this HSDP model.
                    if parameter_ids and parameter_ids.issubset(model_param_ids):
                        return hp_clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
            return self._orig_accelerator_clip_grad_norm(parameters, max_norm, norm_type=norm_type)

        self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
        self._accelerator_patches_active = True

    def _restore_accelerator_patches(self) -> None:
        """Restore Accelerate patches to avoid cross-trainer contamination."""
        if not self._accelerator_patches_active:
            return

        import accelerate.accelerator as acc_module  # pylint: disable=C0415

        if self._orig_fsdp2_prepare_model is not None:
            acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
        self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
        self._accelerator_patches_active = False

    def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
        """Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct.

        Already-wrapped ``HSDPModule`` models, and any model during FSDP2
        training, are returned unchanged; otherwise the parent logic applies.
        """
        del dataloader
        if isinstance(model, HSDPModule):
            return model
        if training and getattr(self.accelerator, "is_fsdp2", False):
            # Trainer usually wraps here, but FSDP2 must be prepared by Accelerate.
            return model
        return super()._wrap_model(model, training=training)

    def _get_train_sampler(self, *args, **kwargs):
        """Use a sequential sampler when ``finetuning_args.disable_shuffling`` is set."""
        if getattr(self.finetuning_args, "disable_shuffling", False):
            return torch.utils.data.SequentialSampler(self.train_dataset)
        return super()._get_train_sampler(*args, **kwargs)

    def compute_loss(self, model, inputs, *args, **kwargs):
        """Support ASFT-style loss when a reference model is configured.

        With ``use_asft_loss`` and a reference model, the reference logits are
        computed without gradients and passed to ``compute_loss_func`` together
        with the model outputs and labels; otherwise the parent loss is used.
        """
        if getattr(self.finetuning_args, "use_asft_loss", False) and self.ref_model is not None:
            with torch.no_grad():
                ref_outputs = self.ref_model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask", None),
                )
                ref_logits = ref_outputs.logits
            outputs = model(**inputs)
            return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
        return super().compute_loss(model, inputs, *args, **kwargs)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Remove the prompt span from generated tokens during generation-based eval.

        With ``predict_with_generate``, labels are popped from ``inputs`` before
        the parent step. The first ``input_ids.size(-1)`` generated positions
        are then overwritten with the pad token id.
        """
        if self.args.predict_with_generate:
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model,
            inputs,
            prediction_loss_only=prediction_loss_only,
            ignore_keys=ignore_keys,
            **gen_kwargs,
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def save_predictions(self, dataset, predict_results, skip_special_tokens: bool = True) -> None:
        """Save generation results to `generated_predictions.jsonl`.

        Only the main process writes. Label-pad ids (default -100) in labels
        and predictions are replaced with the pad token id, and leading pad
        tokens of each prediction are rotated to the end before decoding.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info("Saving prediction results to %s", output_prediction_file)

        label_pad_id = getattr(self.data_collator, "label_pad_token_id", -100)
        pad_token_id = self.processing_class.pad_token_id
        labels = np.where(predict_results.label_ids != label_pad_id, predict_results.label_ids, pad_token_id)
        preds = np.where(predict_results.predictions != label_pad_id, predict_results.predictions, pad_token_id)

        for index, pred in enumerate(preds):
            pad_len = np.nonzero(pred != pad_token_id)[0]
            if len(pad_len):
                # Move leading padding to the end of the sequence.
                preds[index] = np.concatenate((pred[pad_len[0] :], pred[: pad_len[0]]), axis=-1)

        input_ids_column = dataset["input_ids"]
        try:
            input_ids_list = input_ids_column.to_pylist()
        except AttributeError:
            input_ids_list = list(input_ids_column)

        decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as file:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                file.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

    def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
        """Skip redundant device moves for HSDP-wrapped models (and when no device is given)."""
        if isinstance(model, HSDPModule):
            return model
        if device is None:
            return model
        return model.to(device)

    def train(self, *args, **kwargs):
        """Activate HP-specific Accelerate patches only during training."""
        self._activate_accelerator_patches()
        try:
            return super().train(*args, **kwargs)
        finally:
            self._restore_accelerator_patches()

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """Keep Accelerate training flow and only add HSDP sync hooks.

        For ``HSDPModule`` models, gradient synchronisation is enabled only on
        steps where Accelerate syncs gradients, and ``hsdp_sync_stream`` runs
        after backward on those steps.

        Returns:
            The detached (possibly accumulation-scaled) loss.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        sync_gradients = getattr(self.accelerator, "sync_gradients", True)
        if isinstance(model, HSDPModule):
            model.set_is_last_backward(sync_gradients)
            model.set_requires_gradient_sync(sync_gradients)

        compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
        with compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        # Mirror HF Trainer: average over accumulation steps unless the loss was already
        # normalised by ``num_items_in_batch`` inside the model or by a custom
        # ``compute_loss_func``. A model that accepts loss kwargs but received
        # ``num_items_in_batch=None`` still returns a per-micro-batch mean.
        model_normalized = getattr(self, "model_accepts_loss_kwargs", False) and num_items_in_batch is not None
        if not model_normalized and getattr(self, "compute_loss_func", None) is None:
            accumulation_steps = getattr(
                self, "current_gradient_accumulation_steps", self.args.gradient_accumulation_steps
            )
            loss = loss / accumulation_steps

        self.accelerator.backward(loss)

        if isinstance(model, HSDPModule) and sync_gradients:
            hsdp_sync_stream()

        return loss.detach()

    def create_optimizer(self):
        """Create optimizer and wrap step with SkipDTensorDispatch."""
        optimizer = super().create_optimizer()
        _wrap_optimizer_step_with_skip_dtensor_dispatch(optimizer)
        return optimizer

    # ---- Checkpoint save/load via HyperParallel native APIs ----

    def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
        """Save model/optimizer shards per-rank and scheduler via torch.save.

        - Model: saved via HyperParallel's ``hp_save(use_collectives=False)`` so each
          rank writes its own shard independently (no collective communication).
        - Optimizer: DTensor state values are converted to local CPU tensors and
          saved per-rank via ``torch.save``.
        - Scheduler: standard ``torch.save`` (same as Trainer default), only when
          ``args.should_save``.
        """
        from hyper_parallel.core.distributed_checkpoint.api import save as hp_save  # pylint: disable=C0415

        os.makedirs(output_dir, exist_ok=True)
        rank = get_platform().get_rank()

        # Model shards (for checkpoint resuming, separate from save_model HF export)
        model_dir = os.path.join(output_dir, f"{_HSDP_MODEL_NAME}_0")
        os.makedirs(model_dir, exist_ok=True)
        logger.info("Saving HSDP model shards to %s (rank %d)", model_dir, rank)
        model_sd = self.model.state_dict()
        hp_save(model_sd, checkpoint_id=model_dir, use_collectives=False)

        # Optimizer shards (per-rank, local tensors)
        if self.optimizer is not None:
            optim_file = os.path.join(output_dir, f"{_HSDP_OPTIMIZER_NAME}_rank{rank}.pt")
            logger.info("Saving optimizer shard to %s", optim_file)
            local_optim_sd = _localize_optimizer_state(self.optimizer.state_dict())
            torch.save(local_optim_sd, optim_file)

        # Scheduler (standard torch.save, same as Trainer default)
        if self.args.should_save and self.lr_scheduler is not None:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
        """Load model from HSDP sharded checkpoint saved by ``hp_save``.

        Falls back to the standard Trainer loader when the checkpoint has no
        HSDP model-shard directory. On success the checkpoint dir is remembered
        for :meth:`_load_optimizer_and_scheduler`.
        """
        from hyper_parallel.core.distributed_checkpoint.api import load as hp_load  # pylint: disable=C0415

        target = model if model is not None else self.model
        model_dir = os.path.join(resume_from_checkpoint, f"{_HSDP_MODEL_NAME}_0")

        if not os.path.isdir(model_dir):
            # Not an HSDP checkpoint: drop any stale hand-off from an earlier load,
            # then fall back to standard Trainer load (HF weights / FSDP checkpoint).
            self._pending_hsdp_checkpoint = None
            return super()._load_from_checkpoint(resume_from_checkpoint, model=model)

        logger.info("Loading HSDP model shards from %s", model_dir)
        state_dict = target.state_dict()
        hp_load(state_dict, checkpoint_id=model_dir, use_collectives=False)
        # hp_load modifies DTensor local storage in-place via the planner;
        # call load_state_dict to ensure consistency with HSDP internal bookkeeping.
        target.load_state_dict(state_dict)

        # Remember checkpoint dir for optimizer loading in _load_optimizer_and_scheduler
        self._pending_hsdp_checkpoint = resume_from_checkpoint
        return None

    def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
        """Load optimizer/scheduler from per-rank checkpoint files.

        Uses the directory remembered by :meth:`_load_from_checkpoint`, falling
        back to ``checkpoint``. The remembered directory is consumed so a later
        ``train()`` call cannot reuse it. Missing files are skipped.
        """
        ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
        # One-shot hand-off: clear it so a stale dir is never reused.
        self._pending_hsdp_checkpoint = None
        if ckpt_dir is None:
            return

        rank = get_platform().get_rank()
        optim_file = os.path.join(ckpt_dir, f"{_HSDP_OPTIMIZER_NAME}_rank{rank}.pt")

        if os.path.isfile(optim_file) and self.optimizer is not None:
            logger.info("Loading optimizer shard from %s", optim_file)
            saved_sd = torch.load(optim_file, map_location="cpu", weights_only=True)
            _load_local_optimizer_state(self.optimizer, saved_sd)

        # Scheduler
        scheduler_file = os.path.join(ckpt_dir, "scheduler.pt")
        if os.path.isfile(scheduler_file) and self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(torch.load(scheduler_file, map_location="cpu", weights_only=True))

    def save_model(  # pylint: disable=invalid-name
        self, output_dir: Optional[str] = None, _internal_call: bool = False
    ):
        """Save model weights.

        Match the baseline LlamaFactory behavior by exporting HF-format weights
        both for intermediate checkpoints and the final output directory.
        HSDP-native shards for resume are still handled separately by
        ``_save_optimizer_and_scheduler``. Must be called on all ranks because
        the export gathers the full state dict.
        """
        save_dir = output_dir or self.args.output_dir
        os.makedirs(save_dir, exist_ok=True)
        _export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)

577 

578 

# Public API exported by ``from ... import *``.
__all__ = [
    "HyperParallelArguments",
    "HyperParallelTrainer",
]