Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/integration/llamafactory/utils.py: 0%
313 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Accelerate-style FSDP2 utilities backed by HyperParallel fully_shard."""
16import copy
17import functools
18import re
19import warnings
20from collections.abc import Iterable
21from typing import cast
23import torch
24import torch.distributed as dist
25from torch import nn
26from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
28from hyper_parallel import init_device_mesh
29from hyper_parallel.core.dtensor.dtensor import DTensor, distribute_tensor
30from hyper_parallel.core.fully_shard.api import HSDPModule, fully_shard
31from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
32from hyper_parallel.platform import get_platform
# String spellings accepted for HyperParallel dtype options, mapped to the
# corresponding torch dtype. Each dtype is reachable by its full torch name
# and by its short alias.
_DTYPE_MAP = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
def _resolve_device_type(hp_args) -> str:
    """Pick the device type string used to build the HyperParallel mesh.

    Any value of ``hp_args.device_type`` other than ``"auto"`` is returned
    unchanged. For ``"auto"`` the first available backend wins, in the order
    NPU (only if ``torch.npu`` exists), CUDA, then CPU.
    """
    requested = hp_args.device_type
    if requested == "auto":
        npu_ready = hasattr(torch, "npu") and torch.npu.is_available()  # pylint: disable=no-member
        if npu_ready:
            return "npu"
        return "cuda" if torch.cuda.is_available() else "cpu"
    return requested
def _build_device_mesh(accelerator, hp_args):
    """Return the device mesh that ``fully_shard`` should use.

    If the accelerator already exposes ``torch_device_mesh``, that mesh is
    reused. It is sliced down to ``parallelism_config.fsdp_dim_names`` when
    those names are set and non-empty. Otherwise a fresh 1D mesh named
    ``("dp",)`` spanning the whole world is created on the resolved device type.
    """
    device_mesh = getattr(accelerator, "torch_device_mesh", None)
    if device_mesh is None:
        return init_device_mesh(
            _resolve_device_type(hp_args),
            (get_platform().get_world_size(),),
            mesh_dim_names=("dp",),
        )

    parallel_cfg = getattr(accelerator, "parallelism_config", None)
    dim_names = getattr(parallel_cfg, "fsdp_dim_names", None)
    return device_mesh[tuple(dim_names)] if dim_names else device_mesh
def _build_mp_policy(hp_args) -> MixedPrecisionPolicy:
    """Translate HyperParallel dtype strings into a ``MixedPrecisionPolicy``.

    ``hp_args.param_dtype`` drives both the parameter and the output dtype, and
    ``hp_args.reduce_dtype`` drives the gradient-reduction dtype. A ``None``
    setting stays ``None``. Unknown strings raise ``KeyError`` from the
    ``_DTYPE_MAP`` lookup. Forward inputs are always cast.
    """
    def _lookup(dtype_name):
        return None if dtype_name is None else _DTYPE_MAP[dtype_name]

    param_dtype = _lookup(hp_args.param_dtype)
    return MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=_lookup(hp_args.reduce_dtype),
        output_dtype=param_dtype,
        cast_forward_inputs=True,
    )
def _resolve_offload_policy(fsdp2_plugin) -> OffloadPolicy:
    """Map the plugin's ``cpu_offload`` setting onto a HyperParallel offload policy.

    - A HyperParallel ``OffloadPolicy`` instance is passed through untouched.
    - ``True``, or any object whose class is named ``CPUOffloadPolicy`` (for
      example the PyTorch one), yields a fresh HyperParallel ``CPUOffloadPolicy()``.
    - Everything else yields the default ``OffloadPolicy()``.
    """
    requested = getattr(fsdp2_plugin, "cpu_offload", None)
    if isinstance(requested, OffloadPolicy):
        return requested
    wants_cpu = requested is True or type(requested).__name__ == "CPUOffloadPolicy"
    return CPUOffloadPolicy() if wants_cpu else OffloadPolicy()
def _is_cpu_offload_enabled(cpu_offload) -> bool:
    """Report whether a ``cpu_offload`` setting actually requests CPU offload.

    Three cases count as enabled:
    - the literal ``True``;
    - a HyperParallel ``CPUOffloadPolicy``;
    - any object whose class is named ``CPUOffloadPolicy``.
    """
    return (
        cpu_offload is True
        or isinstance(cpu_offload, CPUOffloadPolicy)
        or type(cpu_offload).__name__ == "CPUOffloadPolicy"
    )
def _resolve_mp_policy(fsdp2_plugin, hp_args) -> MixedPrecisionPolicy:
    """Merge Accelerate's mixed-precision policy with HyperParallel overrides.

    The base is copied field by field from ``fsdp2_plugin.mixed_precision_policy``,
    or is a default ``MixedPrecisionPolicy()`` when the plugin has none. A
    non-None ``hp_args.param_dtype`` then overrides both ``param_dtype`` and
    ``output_dtype``. A non-None ``hp_args.reduce_dtype`` overrides
    ``reduce_dtype``.
    """
    accel_policy = getattr(fsdp2_plugin, "mixed_precision_policy", None)
    if accel_policy is None:
        merged = MixedPrecisionPolicy()
    else:
        merged = MixedPrecisionPolicy(
            param_dtype=getattr(accel_policy, "param_dtype", None),
            reduce_dtype=getattr(accel_policy, "reduce_dtype", None),
            output_dtype=getattr(accel_policy, "output_dtype", None),
            cast_forward_inputs=getattr(accel_policy, "cast_forward_inputs", True),
        )

    overrides = _build_mp_policy(hp_args)
    if hp_args.param_dtype is not None:
        merged.param_dtype = overrides.param_dtype
        merged.output_dtype = overrides.output_dtype
    if hp_args.reduce_dtype is not None:
        merged.reduce_dtype = overrides.reduce_dtype
    return merged
def _is_compiled_module(model: nn.Module) -> bool:
    """Heuristically detect a ``torch.compile`` wrapper.

    Returns True exactly when ``model`` has an ``_orig_mod`` attribute.
    """
    try:
        getattr(model, "_orig_mod")
    except AttributeError:
        return False
    return True
def _get_module_children_bottom_up(model: nn.Module, return_fqns: bool = False):
    """List ``model`` and all its descendants, deepest first.

    Each module appears after all of its children, siblings keep their
    ``named_children`` order, and ``model`` itself is always the last entry.
    With ``return_fqns=True`` each entry is a ``(dotted_name, module)`` pair,
    and the root's name is ``""``.
    """
    # Pre-order walk that pops the last-pushed child first; reversing that
    # sequence yields a post-order walk with children in their natural order.
    pending = [("", model)]
    top_down = []
    while pending:
        fqn, current = pending.pop()
        top_down.append((fqn, current))
        for child_name, child in current.named_children():
            pending.append((f"{fqn}.{child_name}" if fqn else child_name, child))
    top_down.reverse()

    if return_fqns:
        return top_down
    return [module for _, module in top_down]
def _get_non_persistent_buffers(model: nn.Module, recurse: bool = True, fqns: bool = True):
    """Return the names of buffers registered with ``persistent=False``.

    Args:
        model: Root module to scan.
        recurse: When False, only ``model``'s own buffers are considered and
            submodules are skipped.
        fqns: When True, buffers owned by a submodule are prefixed with that
            submodule's dotted path. Buffers owned by ``model`` itself, and all
            buffers when ``fqns`` is False, are reported by bare name.

    Returns:
        A set of buffer names.
    """
    return {
        f"{owner_name}.{buf_name}" if fqns and owner_name else buf_name
        for owner_name, owner in model.named_modules()
        if recurse or owner is model
        for buf_name in getattr(owner, "_non_persistent_buffers_set", set())
    }
def _get_module_class_from_name(module: nn.Module, class_name: str):
    """Look up a module class by its ``__name__`` within ``module``'s tree.

    The search walks ``module.modules()``, which includes ``module`` itself,
    and returns the class of the first match. Returns None if nothing matches.
    """
    matches = (sub.__class__ for sub in module.modules() if sub.__class__.__name__ == class_name)
    return next(matches, None)
def _move_model_to_meta(model: nn.Module) -> nn.Module:
    """Relocate ``model`` to the meta device ahead of ``fully_shard``.

    This matches Accelerate's FSDP2 loading order. When the model exposes
    ``tie_weights``, it is called after the move. Returns the module produced
    by ``model.to``.
    """
    meta_device = torch.device("meta")
    moved = model.to(meta_device)
    if not hasattr(moved, "tie_weights"):
        return moved
    moved.tie_weights()
    return moved
171def _get_parameters_from_modules(modules: Iterable[nn.Module] | str, model: nn.Module, device) -> set[nn.Parameter]:
172 """Convert ignored modules to ignored parameters, matching Accelerate behaviour."""
173 if modules is None:
174 return set()
176 parameters = []
177 if isinstance(modules, str):
178 pattern = re.compile(modules)
179 matched_modules = []
180 for name, module in model.named_modules():
181 if pattern.fullmatch(name):
182 module.to(device)
183 matched_modules.append(module)
184 modules = matched_modules
186 for module in modules:
187 parameters.extend(list(module.parameters()))
188 return set(parameters)
def _prepare_auto_wrap_policy(fsdp2_plugin, model: nn.Module):
    """Build a per-module "should wrap" predicate from the plugin's auto-wrap policy.

    This follows Accelerate's FSDP2 helper:

    - ``size_based_auto_wrap_policy`` (optionally inside a ``functools.partial``):
      wrap a module whose total parameter count exceeds
      ``fsdp2_plugin.min_num_params``.
    - ``transformer_auto_wrap_policy``: wrap instances of the classes named by
      ``fsdp2_plugin.transformer_cls_names_to_wrap``, falling back to
      ``model._no_split_modules``. The predicate still returns False whenever
      the plugin's ``transformer_cls_names_to_wrap`` is None at call time.
    - Anything else: return None (no auto-wrapping).

    Raises:
        ValueError: If a requested transformer class name is not present in ``model``.
    """
    base_fn = fsdp2_plugin.auto_wrap_policy
    if isinstance(base_fn, functools.partial):
        # Only the identity of the wrapped function matters; bound args are ignored.
        base_fn = base_fn.func

    if base_fn is size_based_auto_wrap_policy:
        def policy(module: nn.Module) -> bool:
            total_params = sum(p.numel() for p in module.parameters())
            return total_params > fsdp2_plugin.min_num_params

        return policy

    if base_fn is not transformer_auto_wrap_policy:
        return None

    fallback_names = list(getattr(model, "_no_split_modules", None) or [])
    configured_names = fsdp2_plugin.transformer_cls_names_to_wrap
    names_to_wrap = fallback_names if configured_names is None else configured_names

    wrap_classes = set()
    for class_name in names_to_wrap:
        layer_cls = _get_module_class_from_name(model, class_name)
        if layer_cls is None:
            raise ValueError(f"Could not find the transformer layer class {class_name} in the model.")
        wrap_classes.add(layer_cls)
    wrap_types = tuple(wrap_classes)

    def policy(module: nn.Module) -> bool:
        # Re-read lazily, matching Accelerate: a None setting disables wrapping.
        if fsdp2_plugin.transformer_cls_names_to_wrap is None:
            return False
        return isinstance(module, wrap_types)

    return policy
def fsdp2_load_full_state_dict(accelerator, model: nn.Module, full_sd: dict, cpu_offload: bool = False):
    """Load a full (unsharded) state dict into a HyperParallel-sharded model.

    This follows Accelerate's FSDP2 semantics:

    - The main process supplies the real tensors from ``full_sd`` and
      broadcasts each one from rank 0.
    - Every other rank allocates an empty receive buffer with the size and
      dtype of the model's own entry.
    - Each rank then keeps only its local piece, via ``distribute_tensor`` for
      DTensor entries or the whole tensor otherwise.
    - Each local piece is cast and made contiguous to match the model's
      current entry, cloned, and loaded with ``assign=True``.

    Every rank iterates ``model.state_dict()`` in the same order, so the
    sequence of ``dist.broadcast`` calls lines up across ranks even when
    ``full_sd`` is ordered differently. Entries of ``full_sd`` that the sharded
    model does not have cannot be loaded; the main process warns about them and
    skips them.

    Args:
        accelerator: Object providing ``is_main_process`` and ``device``.
        model: The sharded model whose ``state_dict()`` defines which entries are loaded.
        full_sd: Full state dict; only read on the main process.
        cpu_offload: If True, each loaded local tensor is moved to CPU.

    Returns:
        The same ``model`` object after ``load_state_dict(..., assign=True)``.

    Raises:
        KeyError: On the main process, if ``full_sd`` lacks an entry present in
            the sharded model's state dict. Other ranks may still be waiting in
            a broadcast at that point.
    """
    meta_sharded_sd = model.state_dict()
    local_sd = {}

    def _infer_parameter_dtype(target_model: nn.Module, param_name: str, empty_param: torch.Tensor):
        """Return ``(to_contiguous, casting_dtype)`` for ``param_name``.

        ``to_contiguous`` mirrors the contiguity of the model's current entry,
        using its local tensor for DTensors. ``casting_dtype`` is that entry's
        dtype when ``empty_param`` is floating point and not float8_e4m3fn;
        otherwise it is None.
        """
        try:
            old_param = target_model.get_parameter(param_name)
        except Exception:  # pylint: disable=broad-except
            old_param = None
        if old_param is None:
            try:
                old_param = target_model.get_buffer(param_name)
            except Exception:  # pylint: disable=broad-except
                old_param = None
        if old_param is None:
            # Fallback for entries that are plain attributes rather than registered
            # parameters/buffers. ``rpartition`` also copes with top-level names
            # (``get_submodule("")`` returns the model itself).
            base_name, _, local_name = param_name.rpartition(".")
            old_param = getattr(target_model.get_submodule(base_name), local_name)

        is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
        casting_dtype = None
        is_param_float8 = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
        if empty_param.dtype.is_floating_point and not is_param_float8:
            casting_dtype = old_param.dtype
        if isinstance(old_param, DTensor):
            local_param = old_param.to_local()
            return local_param is not None and local_param.is_contiguous(), casting_dtype
        return old_param is not None and old_param.is_contiguous(), casting_dtype

    def _cast_and_contiguous(tensor: torch.Tensor, to_contiguous: bool, dtype):
        """Optionally cast ``tensor`` to ``dtype`` and make it contiguous (DTensor-aware)."""
        if isinstance(tensor, DTensor):
            local_tensor = tensor.to_local()
            if dtype is not None:
                local_tensor = local_tensor.to(dtype=dtype)
            if to_contiguous:
                local_tensor = local_tensor.contiguous()
            return DTensor.from_local(local_tensor, tensor.device_mesh, tensor.placements)
        if dtype is not None:
            tensor = tensor.to(dtype=dtype)
        if to_contiguous:
            tensor = tensor.contiguous()
        return tensor

    is_main_process = accelerator.is_main_process
    if is_main_process:
        unused_keys = [name for name in full_sd if name not in meta_sharded_sd]
        if unused_keys:
            warnings.warn(
                f"fsdp2_load_full_state_dict: ignoring {len(unused_keys)} full state dict entries "
                f"not present in the sharded model (first few: {unused_keys[:5]})."
            )

    # All ranks walk the sharded state dict in the same order so that the
    # broadcasts below pair up one-to-one across ranks.
    for param_name, sharded_param in meta_sharded_sd.items():
        if is_main_process:
            if param_name not in full_sd:
                raise KeyError(
                    f"Full state dict has no entry for '{param_name}', which the sharded model expects."
                )
            full_param = full_sd[param_name]
        else:
            # Receive buffer shaped and typed like the model's own entry.
            full_param = torch.empty(sharded_param.size(), device=accelerator.device, dtype=sharded_param.dtype)

        if isinstance(full_param, DTensor):
            # dist.broadcast works on plain tensors; send the local tensor instead.
            full_param = full_param.to_local()

        full_param = full_param.detach().to(accelerator.device)
        dist.broadcast(full_param, src=0, group=dist.group.WORLD)

        if isinstance(sharded_param, DTensor):
            local_param = distribute_tensor(full_param, sharded_param.device_mesh, sharded_param.placements).to_local()
        else:
            local_param = full_param

        to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, local_param)
        local_param = _cast_and_contiguous(local_param, to_contiguous, casting_dtype)
        if isinstance(local_param, DTensor):
            local_param = local_param.to_local()
        # Clone so each state-dict entry owns its own storage before assign=True.
        local_param = local_param.detach().clone()
        if not local_param.is_contiguous():
            local_param = local_param.contiguous()
        if cpu_offload:
            local_param = local_param.to("cpu")

        local_sd[param_name] = local_param

    cast(nn.Module, model).load_state_dict(local_sd, assign=True)
    return model
def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: nn.Module):
    """Public, Accelerate-named entry point for building the auto-wrap predicate.

    Delegates to ``_prepare_auto_wrap_policy`` and returns its result: a
    callable ``module -> bool``, or None when no supported policy is configured.
    """
    return _prepare_auto_wrap_policy(fsdp2_plugin=fsdp2_plugin, model=model)
def get_parameters_from_modules(modules: Iterable[nn.Module] | str, model: nn.Module, device) -> set[nn.Parameter]:
    """Public, Accelerate-named entry point for resolving ignored parameters.

    Delegates to ``_get_parameters_from_modules`` and returns its set of parameters.
    """
    return _get_parameters_from_modules(modules=modules, model=model, device=device)
def _is_fsdp2_wrapped_model(model: nn.Module) -> bool:
    """Tell whether ``model`` has already been through HyperParallel ``fully_shard``.

    Accepts either an ``HSDPModule`` directly, or a compiled wrapper whose
    ``_orig_mod`` is an ``HSDPModule``.
    """
    if isinstance(model, HSDPModule):
        return True
    if not _is_compiled_module(model):
        return False
    return isinstance(model._orig_mod, HSDPModule)  # pylint: disable=protected-access
def _resolve_shard_size(mesh) -> int:
    """Return how many ways each parameter is sharded for the given mesh.

    HP ``fully_shard`` shards along the last mesh dimension for both a 1D FSDP
    mesh (shard dim 0) and a 2D HSDP mesh (replicate dim 0, shard dim 1). See
    ``platform/*/fully_shard/scheduler.py``. So the last entry of
    ``mesh.mesh_shape`` is the shard count.

    Fallbacks, in order:
    - no mesh: the platform world size;
    - no usable ``mesh_shape``: ``mesh.size()``;
    - neither: the platform world size.
    """
    if mesh is None:
        return get_platform().get_world_size()

    mesh_shape = getattr(mesh, "mesh_shape", None)
    if mesh_shape:
        return int(mesh_shape[-1])

    if hasattr(mesh, "size"):
        return mesh.size()
    return get_platform().get_world_size()
def _collect_replicate_params(model: nn.Module, shard_size: int) -> set:
    """Find parameters that cannot be evenly sharded along dim 0.

    HP ``fully_shard`` rejects such parameters with ``Uneven sharding on dim 0``.
    An example is a ``(1, hidden)`` ``shared_expert_gate.weight`` when
    ``shard_size > 1``. Passing them as ``replicate_params`` makes them
    DDP-replicated along the shard dimension instead.

    Scalar (0-dim) parameters are skipped. With ``shard_size <= 1`` nothing
    needs replicating and an empty set is returned.
    """
    if shard_size <= 1:
        return set()
    return {
        param
        for _, param in model.named_parameters()
        if param.dim() != 0 and param.size(0) % shard_size != 0
    }
def _build_fsdp2_kwargs(accelerator, model: nn.Module, hp_args, fsdp2_plugin) -> dict:
    """Assemble the keyword arguments shared by every ``fully_shard`` call.

    Returns a dict with ``reshard_after_forward``, ``offload_policy``,
    ``mp_policy``, ``mesh``, ``ignored_params`` and ``comm_fusion=True``.
    ``replicate_params`` is added only when some parameter's dim 0 is not
    divisible by the mesh's shard size.

    ``hp_args.reshard_after_forward``, when not None, overrides the plugin's
    ``reshard_after_forward``.
    """
    mesh = _build_device_mesh(accelerator, hp_args)

    reshard_after_forward = fsdp2_plugin.reshard_after_forward
    if hp_args.reshard_after_forward is not None:
        reshard_after_forward = hp_args.reshard_after_forward

    kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "offload_policy": _resolve_offload_policy(fsdp2_plugin),
        "mp_policy": _resolve_mp_policy(fsdp2_plugin, hp_args),
        "mesh": mesh,
        "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
        "comm_fusion": True,
    }

    # Params HP cannot shard evenly are replicated rather than failing fully_shard.
    replicate_params = _collect_replicate_params(model, _resolve_shard_size(mesh))
    if replicate_params:
        kwargs["replicate_params"] = replicate_params
    return kwargs
def _model_has_4bit_params(model: nn.Module) -> bool:
    """Detect bitsandbytes 4-bit weights by parameter class name.

    Returns True as soon as any parameter's class is named ``Params4bit``.
    """
    for param in model.parameters():
        if param.__class__.__name__ == "Params4bit":
            return True
    return False
def _prepare_cpu_ram_efficient_loading(model: nn.Module, enabled: bool) -> dict[str, torch.Tensor]:
    """Snapshot non-persistent buffers before the model is moved to meta.

    When ``enabled`` is True, the result is a deep copy of every non-persistent
    buffer, keyed by fully qualified name. When ``enabled`` is False, it is an
    empty dict.
    """
    if enabled:
        wanted = _get_non_persistent_buffers(model, recurse=True, fqns=True)
        snapshot = {fqn: buf for fqn, buf in model.named_buffers() if fqn in wanted}
        return copy.deepcopy(snapshot)
    return {}
def _apply_auto_wrap_policy(model: nn.Module, fsdp2_plugin, fsdp2_kwargs: dict) -> None:
    """Shard every descendant the auto-wrap predicate selects, deepest first.

    The root is the last entry of the bottom-up list, so it is excluded here
    and left for the caller to wrap. Modules that are already ``HSDPModule``
    instances are skipped. Nothing happens when no auto-wrap policy applies.
    """
    should_wrap = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    if should_wrap is None:
        return

    candidates = _get_module_children_bottom_up(model)[:-1]
    for submodule in candidates:
        if not should_wrap(submodule):
            continue
        if isinstance(submodule, HSDPModule):
            continue
        fully_shard(submodule, **fsdp2_kwargs)
def _setup_prefetch(model: nn.Module) -> None:
    """Chain prefetching between the HSDP-wrapped submodules of ``model``.

    Wrapped submodules (excluding ``model`` itself) are taken in
    ``model.modules()`` order. During forward, each one prefetches the next
    ``forward_depth`` modules. During backward, which runs last-to-first, each
    one prefetches the ``backward_depth`` modules that precede it. The aim is
    to overlap all-gather communication with computation.
    """
    layers = [m for m in model.modules() if isinstance(m, HSDPModule) and m is not model]
    forward_depth = 1
    backward_depth = 1

    # Forward: layer idx prefetches layers idx+1 .. idx+forward_depth.
    for idx, layer in enumerate(layers):
        targets = layers[idx + 1: idx + 1 + forward_depth]
        if targets:
            layer.set_modules_to_forward_prefetch(targets)

    # Backward: same scheme over the reversed order.
    backward_order = layers[::-1]
    for idx, layer in enumerate(backward_order):
        targets = backward_order[idx + 1: idx + 1 + backward_depth]
        if targets:
            layer.set_modules_to_backward_prefetch(targets)
def _restore_non_persistent_buffers(model: nn.Module, buffers: dict[str, torch.Tensor], device) -> None:
    """Re-register saved non-persistent buffers on ``model`` and re-tie weights.

    Each ``fqn -> tensor`` entry is moved to ``device`` and registered with
    ``persistent=False`` on the submodule named by the part of ``fqn`` before
    its last dot (or on ``model`` itself when there is no dot). Afterwards
    ``model.tie_weights()`` is called if available. When ``buffers`` is empty
    this does nothing at all, including no re-tie.
    """
    if buffers:
        for fqn, saved in buffers.items():
            owner_path, _, attr_name = fqn.rpartition(".")
            owner = model.get_submodule(owner_path) if owner_path else model
            owner.register_buffer(attr_name, saved.to(device), persistent=False)

        if hasattr(model, "tie_weights"):
            model.tie_weights()
def _maybe_upcast_trainable_params(accelerator, model: nn.Module) -> None:
    """Cast the model to fp32 when mixed precision is on, as Accelerate does.

    Nothing happens if ``accelerator.mixed_precision == "no"`` or if
    ``model.dtype`` is already ``torch.float32``. Otherwise the steps are:

    1. ``model.to(torch.float32)`` creates new fp32 parameters in the module
       tree.
    2. Every ``HSDPModule``'s cached sharded-parameter references are reset,
       and its param group's mixed-precision dtypes are re-initialised, so
       comm_fusion sees the new fp32 dtype.
    3. The main process emits a warning about checkpoint precision.
    """
    current_dtype = getattr(model, "dtype", None)
    if accelerator.mixed_precision == "no":
        return
    if current_dtype is not None and current_dtype == torch.float32:
        return

    model.to(torch.float32)

    hsdp_modules = (m for m in model.modules() if isinstance(m, HSDPModule))
    for hsdp_module in hsdp_modules:
        state = hsdp_module.hsdp_scheduler.hsdp_state  # pylint: disable=protected-access
        for hsdp_param in state.hsdp_params:
            if hsdp_param.is_sharded:
                hsdp_param.reset_sharded_param()
        param_group = getattr(state, "param_group", None)
        if param_group is not None:
            param_group._init_mp_dtypes()  # pylint: disable=protected-access

    if accelerator.is_main_process:
        warnings.warn(
            "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') "
            "may affect the precision of model checkpoints."
        )
def fsdp2_prepare_model(accelerator, model: nn.Module, hp_args) -> nn.Module:
    """
    Prepare model following Accelerate FSDP2 flow, using HyperParallel fully_shard.

    This function is designed to be called with the runtime `accelerator`
    instance already created by `transformers.Trainer` / `accelerate`.

    Required accelerator attributes:
        state.fsdp_plugin: FSDP plugin configuration used to derive wrapping and
            state-dict behaviour.
        torch_device_mesh: Optional device mesh prepared by Accelerate.
        parallelism_config.fsdp_dim_names: Optional FSDP mesh dimension names
            used when `torch_device_mesh` is available.
        device: Current process device, used for ignored module parameter
            materialization and buffer restoration.
        is_main_process: Whether the current rank is the main process during
            full state-dict distribution.
        mixed_precision: Mixed precision mode string, used for the final
            parameter upcast behavior.

    Args:
        accelerator: The runtime accelerator described above.
        model: Model to shard. It is returned unchanged if already HSDP-wrapped.
        hp_args: HyperParallel options. The helpers read `device_type`,
            `param_dtype`, `reduce_dtype` and `reshard_after_forward`.

    Returns:
        The sharded model.
    """
    if _is_fsdp2_wrapped_model(model):
        return model

    fsdp2_plugin = accelerator.state.fsdp_plugin
    fsdp2_plugin.set_auto_wrap_policy(model)

    model_has_params4bit = _model_has_4bit_params(model)
    # Snapshot taken before any move to meta; on the main process it is the
    # source of real weights for fsdp2_load_full_state_dict below.
    original_sd = model.state_dict()
    should_restore_non_persistent_buffers = fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit
    original_non_persistent_buffers = _prepare_cpu_ram_efficient_loading(model, should_restore_non_persistent_buffers)
    if should_restore_non_persistent_buffers:
        # Shard on meta so fully_shard does not materialise a second full copy.
        model = _move_model_to_meta(model)

    fsdp2_kwargs = _build_fsdp2_kwargs(accelerator, model, hp_args, fsdp2_plugin)

    _apply_auto_wrap_policy(model, fsdp2_plugin, fsdp2_kwargs)
    if not isinstance(model, HSDPModule):
        fully_shard(model, **fsdp2_kwargs)

    _setup_prefetch(model)

    if fsdp2_plugin.cpu_ram_efficient_loading:
        fsdp2_load_full_state_dict(
            accelerator,
            model,
            original_sd,
            cpu_offload=_is_cpu_offload_enabled(fsdp2_plugin.cpu_offload),
        )

    if should_restore_non_persistent_buffers:
        _restore_non_persistent_buffers(model, original_non_persistent_buffers, accelerator.device)
        # fsdp2_load_full_state_dict clones every entry and loads with assign=True,
        # so tied weights end up as independent tensors and must be re-tied here
        # (Accelerate always re-ties on this path). _restore_non_persistent_buffers
        # returns early without re-tying when there were no buffers, so cover that case.
        if not original_non_persistent_buffers and hasattr(model, "tie_weights"):
            model.tie_weights()

    _maybe_upcast_trainable_params(accelerator, model)
    return model