Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / multicore / modules / common / compute_graph.py: 0%

141 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-20 07:18 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16ComputeGraph, OperatorNode, TensorSpec, SplitSpec, and TaskSplitValue for Multicore MoE-FFN scheduling. 

17""" 

18from dataclasses import dataclass, field 

19from enum import Enum 

20from typing import Any, Callable, List, Optional, Tuple 

21from collections import deque 

22 

23 

class OpType(Enum):
    """Closed set of operator kinds an ``OperatorNode`` can be tagged with.

    Each member's value is the lowercase string form of its name.
    """

    ALLTOALL = "alltoall"
    GMM = "gmm"
    SWIGLU = "swiglu"
    SWIGLU_GRAD = "swiglu_grad"

29 

30 

@dataclass
class TensorSpec:
    """Description of one tensor in the graph, i.e. a graph edge.

    The first seven fields are supplied by the caller.  ``split_dim`` and
    ``split_num`` are excluded from ``__init__``; they are written by
    ``ComputeGraph.propagate_splits``.
    """

    # --- caller-supplied -----------------------------------------------------
    name: str
    shape: list
    param_position: int
    # Element size code: bf16=2, int64=8, int32=4 (presumably bytes per element).
    dtype_size: int = 2
    # 0 = tensor, 1 = tensorlist.
    tensor_type: int = 0
    transpose: bool = False
    is_dynamic: bool = False

    # --- filled in by propagate_splits (not constructor arguments) -----------
    # Axis the tensor is split along; -1 means "not split".
    split_dim: int = field(init=False, default=-1)
    # Number of pieces; 1 means "not split".
    split_num: int = field(init=False, default=1)

45 

46 

@dataclass
class SplitSpec:
    """Declarative description of how an OperatorNode is split into tasks.

    Attributes:
        split_inputs: ``None`` marks a source operator, which is always split.
            Otherwise a list of ``(input_idx, split_dim)`` pairs; the operator
            is split only when every listed input currently has exactly that
            ``split_dim``.
        task_num_fn: Callable taking the TaskSplitValue and returning an int;
            it supplies the operator's task count when the split condition
            holds.
        split_output_dims: Split axis for each output, by position.  ``-1``
            leaves that output un-split.  Defaults to ``[0]`` (a fresh list
            per instance).
    """

    split_inputs: Optional[List[Tuple[int, int]]]
    task_num_fn: Callable
    # default_factory so that instances never share one mutable default list.
    split_output_dims: List[int] = field(default_factory=lambda: [0])

65 

66 

@dataclass
class OperatorNode:
    """A single operator node in the compute graph.

    Graph edges are stored on the nodes themselves: ``ComputeGraph.add_edge``
    appends to ``successors`` of the source and ``predecessors`` of the
    destination.  Because those two lists reference each other, they are
    excluded from the generated ``__eq__`` (``compare=False``); otherwise
    comparing two distinct but structurally identical linked nodes recurses
    without bound and raises ``RecursionError``.  Equality therefore covers
    the node's own fields only, not its position in the graph.
    """

    name: str                     # key under which ComputeGraph stores the node
    op_type: OpType
    inputs: List[TensorSpec]
    outputs: List[TensorSpec]
    param_positions: List[int]
    split_value: int
    split_spec: SplitSpec         # consulted by ComputeGraph.propagate_splits
    tiling_position: int
    fill_config: Any              # FillConfig subclass instance
    kernel_spec: Any = None       # KernelSpec; None for manual graphs, required for @MultiCore path

    # Graph links; mutated by ComputeGraph.add_edge, never compared (see above).
    predecessors: List['OperatorNode'] = field(default_factory=list, compare=False)
    successors: List['OperatorNode'] = field(default_factory=list, compare=False)
    # Task count; written by ComputeGraph.propagate_splits, not a constructor argument.
    task_num: int = field(default=0, init=False)

84 

85 

@dataclass
class TaskSplitValue:
    """Topology parameters plus per-rank runtime counters.

    Holds the user-supplied topology values, integer quantities derived from
    them (exposed as read-only properties), and running counters.  Split
    values and task counts are kept on OperatorNode, not on this object.
    All derived quantities use floor division.
    """

    # ── User inputs ───────────────────────────────────────────────────────────
    tp: int = 4
    ep: int = 4
    seq_size: int = 8192
    all_expert_num: int = 32
    top_k: int = 8

    # ── Runtime counters (the pre_* ones are reset by init_task_split_value) ──
    rank_id: int = 0
    pre_pre_event_num: int = 0
    pre_event_num: int = 0
    pre_task_num: int = 0
    pre_cube_task_num: int = 0
    pre_vector_task_num: int = 0
    pre_mix_task_num: int = 0

    # ── Derived properties ────────────────────────────────────────────────────
    @property
    def single_rank_expert_num(self) -> int:
        """``all_expert_num // ep``."""
        return self.all_expert_num // self.ep

    @property
    def seq_all(self) -> int:
        """``(seq_size * ep * top_k) // tp``."""
        numerator = self.seq_size * self.ep * self.top_k
        return numerator // self.tp

    @property
    def per_expert_seq(self) -> int:
        """``seq_all // top_k``."""
        return self.seq_all // self.top_k

    @property
    def per_rank_seq(self) -> int:
        """``seq_all // ep``."""
        return self.seq_all // self.ep

    @property
    def per_expert_seq_to_other(self) -> int:
        """``seq_all // (ep * top_k)``."""
        divisor = self.ep * self.top_k
        return self.seq_all // divisor

    @property
    def all_event_num(self) -> int:
        """``1 + all_expert_num + 3 * single_rank_expert_num``."""
        return 1 + self.all_expert_num + 3 * self.single_rank_expert_num

134 

135 

def init_task_split_value(tsv: TaskSplitValue) -> None:
    """Zero the per-rank runtime counters of ``tsv`` in place.

    Only the six ``pre_*`` counters are touched; ``rank_id`` and the topology
    fields are left as they are.
    """
    for counter in (
        "pre_pre_event_num",
        "pre_event_num",
        "pre_task_num",
        "pre_cube_task_num",
        "pre_vector_task_num",
        "pre_mix_task_num",
    ):
        setattr(tsv, counter, 0)

144 

145 

class ComputeGraph:
    """Directed acyclic graph describing operator execution order.

    Nodes are stored by ``OperatorNode.name``.  Edges live on the nodes
    themselves, in their ``predecessors`` / ``successors`` lists.
    """

    def __init__(self):
        self._nodes: dict = {}            # name -> OperatorNode
        self._insertion_order: list = []  # names, in first-registration order

    def add_op(self, op: OperatorNode) -> 'ComputeGraph':
        """Register ``op`` under ``op.name`` and return ``self`` for chaining.

        Re-registering an existing name replaces the stored node but keeps the
        name's original position in the insertion order.  Recording the name a
        second time would make ``topological_sort`` enqueue it twice and
        report a cycle that does not exist.
        """
        if op.name not in self._nodes:
            self._insertion_order.append(op.name)
        self._nodes[op.name] = op
        return self

    def add_edge(self, src, dst) -> 'ComputeGraph':
        """Add a directed edge ``src -> dst`` and return ``self`` for chaining.

        ``src`` / ``dst`` can each be an OperatorNode object or a name string.

        Raises:
            KeyError: if a name is given that has not been registered.
        """
        source = src if isinstance(src, OperatorNode) else self._nodes[src]
        target = dst if isinstance(dst, OperatorNode) else self._nodes[dst]
        source.successors.append(target)
        target.predecessors.append(source)
        return self

    def get_op(self, name: str) -> OperatorNode:
        """Look up an operator by name.

        Raises:
            KeyError: if no operator with that name is registered.
        """
        return self._nodes[name]

    def topological_sort(self) -> List[OperatorNode]:
        """Return the nodes in dependency order using Kahn's algorithm.

        Nodes without predecessors are seeded in ``add_op`` insertion order;
        every other node is emitted once its last predecessor has been
        emitted.

        Raises:
            ValueError: if not every node could be ordered (the graph has a
                cycle).
        """
        in_degree = {name: len(op.predecessors) for name, op in self._nodes.items()}
        ready = deque(name for name in self._insertion_order if in_degree[name] == 0)
        order = []
        while ready:
            node = self._nodes[ready.popleft()]
            order.append(node)
            for succ in node.successors:
                in_degree[succ.name] -= 1
                if in_degree[succ.name] == 0:
                    ready.append(succ.name)
        if len(order) != len(self._nodes):
            raise ValueError("ComputeGraph has a cycle")
        return order

    def propagate_splits(self, tsv: TaskSplitValue) -> None:
        """Compute ``task_num`` for every operator and mark its output tensors.

        Split information travels between operators through shared TensorSpec
        objects: an output marked here is seen by a later operator that lists
        the same object among its inputs.

        Must be called once after graph construction, before any fill loop.
        """
        # Reset all tensor split info so a repeated call starts clean.
        for op in self._nodes.values():
            for tensor in op.inputs + op.outputs:
                tensor.split_dim = -1
                tensor.split_num = 1

        for op in self.topological_sort():
            spec = op.split_spec

            # A source operator (split_inputs is None) always splits; any other
            # operator splits only if every listed input has the expected axis.
            should_split = spec.split_inputs is None or all(
                op.inputs[idx].split_dim == dim for idx, dim in spec.split_inputs
            )
            task_num = spec.task_num_fn(tsv) if should_split else 1
            op.task_num = task_num

            for i, out in enumerate(op.outputs):
                # Outputs beyond the declared list are treated as un-split.
                dim = spec.split_output_dims[i] if i < len(spec.split_output_dims) else -1
                out.split_dim = dim if (task_num > 1 and dim >= 0) else -1
                out.split_num = task_num

    def build_runtime_config(self, tsv: TaskSplitValue, rank_id: int = 0,
                             num_cube_cores: int = 24):
        """
        Generic RuntimeConfig builder for the framework (@MultiCore) path.

        Resets the counters of ``tsv`` via init_task_split_value, stores
        ``rank_id`` on it, runs every operator's ``fill_config.fill`` in
        topological order, then sets cfg.task_num to the sum of the operators'
        ``task_num`` and cfg.atomic_add_values[0] to 1.  Graph-specific
        post-processing (add_terminate, add_dynamic_data, revise_task_queue)
        is intentionally omitted — call build_config_for_rank() in
        gen_runtime_data.py for MoE FFN.

        Returns:
            The populated RuntimeConfigC instance.
        """
        from hyper_parallel.core.multicore.modules.common.runtime_structs import (  # pylint: disable=import-outside-toplevel
            RuntimeConfigC, QUEUE_CAPACITY)

        cfg = RuntimeConfigC()
        cfg.num_workers = 2 * num_cube_cores
        cfg.queue_capacity = QUEUE_CAPACITY

        init_task_split_value(tsv)
        tsv.rank_id = rank_id

        # Sort once; the same ordering serves both the fill loop and the total.
        ordered_ops = self.topological_sort()
        for op in ordered_ops:
            op.fill_config.fill(cfg, op, tsv)

        cfg.task_num = sum(op.task_num for op in ordered_ops)
        cfg.atomic_add_values[0] = 1
        return cfg