Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/state_dict_utils.py: 0%
45 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""State dict utilities for fully_shard (torch-specific)."""
16from typing import Any
18import torch
19import torch.distributed as dist
20from torch import nn
21from torch.distributed.checkpoint.state_dict import StateDictOptions
23from hyper_parallel.core.dtensor.dtensor import DTensor
26def _gather_full_state_dict(
27 state_dict: dict[str, Any], cpu_offload: bool
28) -> dict[str, Any]:
29 """All-gather every DTensor shard into a full tensor.
31 Args:
32 state_dict: Model state dict with DTensor or plain tensor values.
33 cpu_offload: If True, only rank-0 keeps the result on CPU;
34 other ranks return an empty dict to save memory.
35 """
36 is_rank0 = (not dist.is_initialized()) or (dist.get_rank() == 0)
38 gathered: dict[str, Any] = {}
39 for key, val in state_dict.items():
40 if isinstance(val, DTensor):
41 val = val.full_tensor()
42 if cpu_offload:
43 if not is_rank0:
44 del val
45 continue
46 if isinstance(val, torch.Tensor):
47 val = val.cpu()
48 gathered[key] = val
50 if cpu_offload and not is_rank0:
51 return {}
52 return gathered
55def _offload_sharded_state_dict(
56 state_dict: dict[str, Any],
57) -> dict[str, Any]:
58 """Move each shard to CPU without all-gathering.
60 Args:
61 state_dict: Model state dict with DTensor or plain tensor values.
62 """
63 offloaded: dict[str, Any] = {}
64 for key, val in state_dict.items():
65 if isinstance(val, DTensor):
66 val = DTensor.from_local(
67 val.to_local().cpu(), val.device_mesh, val.layout.alias_placements,
68 )
69 elif isinstance(val, torch.Tensor):
70 val = val.cpu()
71 offloaded[key] = val
72 return offloaded
75def get_model_state_dict(
76 model: nn.Module,
77 *,
78 options: StateDictOptions | None = None,
79) -> dict[str, Any]:
80 """Return the model state dict with configurable gathering and offloading.
82 Behaviour matrix:
84 +-----------------+-------------+--------------------------------------+
85 | full_state_dict | cpu_offload | result |
86 +=================+=============+======================================+
87 | False | False | DTensor (sharded, as-is) |
88 +-----------------+-------------+--------------------------------------+
89 | False | True | DTensor local shard offloaded to CPU |
90 +-----------------+-------------+--------------------------------------+
91 | True | False | full Tensor on **every** rank |
92 +-----------------+-------------+--------------------------------------+
93 | True | True | full Tensor on CPU, **rank 0 only** |
94 +-----------------+-------------+--------------------------------------+
96 Args:
97 model: The model whose state dict to retrieve.
98 options: Controls full_state_dict, cpu_offload,
99 ignore_frozen_params, and broadcast_from_rank0 flags.
100 """
101 options = options or StateDictOptions()
103 if options.broadcast_from_rank0 and not options.full_state_dict:
104 raise ValueError(
105 "full_state_dict must be True when broadcast_from_rank0 is True."
106 )
108 state_dict: dict[str, Any] = model.state_dict()
110 if options.ignore_frozen_params:
111 frozen_keys = {
112 name for name, p in model.named_parameters()
113 if not p.requires_grad
114 }
115 for key in frozen_keys:
116 state_dict.pop(key, None)
118 if options.full_state_dict:
119 return _gather_full_state_dict(state_dict, options.cpu_offload)
121 if options.cpu_offload:
122 return _offload_sharded_state_dict(state_dict)
124 return state_dict