Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/state_dict_utils.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""State dict utilities for fully_shard (torch-specific).""" 

16from typing import Any 

17 

18import torch 

19import torch.distributed as dist 

20from torch import nn 

21from torch.distributed.checkpoint.state_dict import StateDictOptions 

22 

23from hyper_parallel.core.dtensor.dtensor import DTensor 

24 

25 

26def _gather_full_state_dict( 

27 state_dict: dict[str, Any], cpu_offload: bool 

28) -> dict[str, Any]: 

29 """All-gather every DTensor shard into a full tensor. 

30 

31 Args: 

32 state_dict: Model state dict with DTensor or plain tensor values. 

33 cpu_offload: If True, only rank-0 keeps the result on CPU; 

34 other ranks return an empty dict to save memory. 

35 """ 

36 is_rank0 = (not dist.is_initialized()) or (dist.get_rank() == 0) 

37 

38 gathered: dict[str, Any] = {} 

39 for key, val in state_dict.items(): 

40 if isinstance(val, DTensor): 

41 val = val.full_tensor() 

42 if cpu_offload: 

43 if not is_rank0: 

44 del val 

45 continue 

46 if isinstance(val, torch.Tensor): 

47 val = val.cpu() 

48 gathered[key] = val 

49 

50 if cpu_offload and not is_rank0: 

51 return {} 

52 return gathered 

53 

54 

55def _offload_sharded_state_dict( 

56 state_dict: dict[str, Any], 

57) -> dict[str, Any]: 

58 """Move each shard to CPU without all-gathering. 

59 

60 Args: 

61 state_dict: Model state dict with DTensor or plain tensor values. 

62 """ 

63 offloaded: dict[str, Any] = {} 

64 for key, val in state_dict.items(): 

65 if isinstance(val, DTensor): 

66 val = DTensor.from_local( 

67 val.to_local().cpu(), val.device_mesh, val.layout.alias_placements, 

68 ) 

69 elif isinstance(val, torch.Tensor): 

70 val = val.cpu() 

71 offloaded[key] = val 

72 return offloaded 

73 

74 

75def get_model_state_dict( 

76 model: nn.Module, 

77 *, 

78 options: StateDictOptions | None = None, 

79) -> dict[str, Any]: 

80 """Return the model state dict with configurable gathering and offloading. 

81 

82 Behaviour matrix: 

83 

84 +-----------------+-------------+--------------------------------------+ 

85 | full_state_dict | cpu_offload | result | 

86 +=================+=============+======================================+ 

87 | False | False | DTensor (sharded, as-is) | 

88 +-----------------+-------------+--------------------------------------+ 

89 | False | True | DTensor local shard offloaded to CPU | 

90 +-----------------+-------------+--------------------------------------+ 

91 | True | False | full Tensor on **every** rank | 

92 +-----------------+-------------+--------------------------------------+ 

93 | True | True | full Tensor on CPU, **rank 0 only** | 

94 +-----------------+-------------+--------------------------------------+ 

95 

96 Args: 

97 model: The model whose state dict to retrieve. 

98 options: Controls full_state_dict, cpu_offload, 

99 ignore_frozen_params, and broadcast_from_rank0 flags. 

100 """ 

101 options = options or StateDictOptions() 

102 

103 if options.broadcast_from_rank0 and not options.full_state_dict: 

104 raise ValueError( 

105 "full_state_dict must be True when broadcast_from_rank0 is True." 

106 ) 

107 

108 state_dict: dict[str, Any] = model.state_dict() 

109 

110 if options.ignore_frozen_params: 

111 frozen_keys = { 

112 name for name, p in model.named_parameters() 

113 if not p.requires_grad 

114 } 

115 for key in frozen_keys: 

116 state_dict.pop(key, None) 

117 

118 if options.full_state_dict: 

119 return _gather_full_state_dict(state_dict, options.cpu_offload) 

120 

121 if options.cpu_offload: 

122 return _offload_sharded_state_dict(state_dict) 

123 

124 return state_dict