Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/multicore/modules/common/compute_graph.py: 0%
141 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-20 07:18 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16ComputeGraph, OperatorNode, TensorSpec, SplitSpec, and TaskSplitValue for Multicore MoE-FFN scheduling.
17"""
18from dataclasses import dataclass, field
19from enum import Enum
20from typing import Any, Callable, List, Optional, Tuple
21from collections import deque
class OpType(Enum):
    """Closed set of operator kinds a graph node can be tagged with.

    Each member's value is the lowercase string form of its name.
    """

    ALLTOALL = "alltoall"
    GMM = "gmm"
    SWIGLU = "swiglu"
    SWIGLU_GRAD = "swiglu_grad"
@dataclass
class TensorSpec:
    """One tensor flowing through the graph, i.e. a graph edge.

    The first seven fields are supplied by the caller.  ``split_dim`` and
    ``split_num`` are excluded from ``__init__``: they start at ``-1`` and
    ``1`` and are overwritten by ``propagate_splits``.
    """

    # -- caller-supplied description ----------------------------------------
    name: str
    shape: list
    param_position: int
    dtype_size: int = 2       # bf16=2, int64=8, int32=4
    tensor_type: int = 0      # 0=tensor, 1=tensorlist
    transpose: bool = False
    is_dynamic: bool = False

    # -- written by propagate_splits; never passed by the caller -------------
    split_dim: int = field(init=False, default=-1)
    split_num: int = field(init=False, default=1)
@dataclass
class SplitSpec:
    """Declarative description of how an OperatorNode is split into tasks.

    Fields:
        split_inputs: ``None`` marks a source operator, which always splits.
            Otherwise a list of ``(input_idx, split_dim)`` pairs; every pair
            must match for the operator to split.
        task_num_fn: ``Callable(tsv) -> int`` that computes ``task_num`` when
            the split condition holds.
        split_output_dims: split axis for each output; ``-1`` leaves that
            output un-split.  A fresh ``[0]`` is created per instance when
            the argument is omitted.
    """

    split_inputs: Optional[List[Tuple[int, int]]]
    task_num_fn: Callable
    split_output_dims: List[int] = field(default_factory=lambda: [0])
@dataclass
class OperatorNode:
    """A single operator node in the compute graph.

    ``predecessors`` and ``successors`` hold the graph edges and are appended
    to by ``ComputeGraph.add_edge``.  ``task_num`` is not an ``__init__``
    argument; it starts at 0 and is assigned by
    ``ComputeGraph.propagate_splits``.

    The two link lists are excluded from the generated ``__eq__`` and
    ``__repr__``.  Each node points at its neighbours and they point back, so
    including them makes ``==`` between two distinct-but-equal linked nodes
    recurse until ``RecursionError`` and makes ``repr`` drag in the whole
    neighbourhood.  Equality is therefore decided by the node's own fields.
    """

    name: str
    op_type: OpType
    inputs: List[TensorSpec]
    outputs: List[TensorSpec]
    param_positions: List[int]
    split_value: int
    split_spec: SplitSpec
    tiling_position: int
    fill_config: Any  # FillConfig subclass instance
    kernel_spec: Any = None  # KernelSpec; None for manual graphs, required for @MultiCore path

    # Graph links: kept out of repr/compare because they are mutually referential.
    predecessors: List['OperatorNode'] = field(default_factory=list, repr=False, compare=False)
    successors: List['OperatorNode'] = field(default_factory=list, repr=False, compare=False)

    # Assigned by ComputeGraph.propagate_splits; not user-supplied.
    task_num: int = field(default=0, init=False)
@dataclass
class TaskSplitValue:
    """Topology parameters plus per-rank runtime counters.

    Holds only the user-supplied topology values, quantities derived from
    them (exposed as read-only properties), and the running counters.
    Split values and task counts are stored on OperatorNode, not here.
    """

    # ── User inputs ─────────────────────────────────────────────────────────
    tp: int = 4
    ep: int = 4
    seq_size: int = 8192
    all_expert_num: int = 32
    top_k: int = 8

    # ── Runtime counters (reset by init_task_split_value per rank) ──────────
    rank_id: int = 0
    pre_pre_event_num: int = 0
    pre_event_num: int = 0
    pre_task_num: int = 0
    pre_cube_task_num: int = 0
    pre_vector_task_num: int = 0
    pre_mix_task_num: int = 0

    # ── Derived properties ──────────────────────────────────────────────────
    @property
    def single_rank_expert_num(self) -> int:
        """``all_expert_num`` floor-divided by ``ep``."""
        return self.all_expert_num // self.ep

    @property
    def seq_all(self) -> int:
        """``seq_size * ep * top_k`` floor-divided by ``tp``."""
        numerator = self.seq_size * self.ep * self.top_k
        return numerator // self.tp

    @property
    def per_expert_seq(self) -> int:
        """``seq_all`` floor-divided by ``top_k``."""
        return self.seq_all // self.top_k

    @property
    def per_rank_seq(self) -> int:
        """``seq_all`` floor-divided by ``ep``."""
        return self.seq_all // self.ep

    @property
    def per_expert_seq_to_other(self) -> int:
        """``seq_all`` floor-divided by ``ep * top_k``."""
        divisor = self.ep * self.top_k
        return self.seq_all // divisor

    @property
    def all_event_num(self) -> int:
        """``1 + all_expert_num`` plus three times ``single_rank_expert_num``."""
        return 1 + self.all_expert_num + 3 * self.single_rank_expert_num
def init_task_split_value(tsv: TaskSplitValue) -> None:
    """Zero the per-rank runtime counters on ``tsv`` in place.

    Only the six ``pre_*`` counters are written; ``rank_id`` and the
    topology fields keep whatever values they already hold.
    """
    for counter_name in (
        "pre_pre_event_num",
        "pre_event_num",
        "pre_task_num",
        "pre_cube_task_num",
        "pre_vector_task_num",
        "pre_mix_task_num",
    ):
        setattr(tsv, counter_name, 0)
class ComputeGraph:
    """Directed acyclic graph describing operator execution order.

    Nodes are OperatorNode objects keyed by ``name``.  Edges are not stored
    in the graph itself; they live on the nodes as ``predecessors`` and
    ``successors`` lists.
    """

    def __init__(self):
        self._nodes: dict = {}            # name -> OperatorNode
        self._insertion_order: list = []  # names, in order of first add_op

    def add_op(self, op: OperatorNode) -> 'ComputeGraph':
        """Register ``op`` under ``op.name`` and return ``self`` for chaining.

        Adding a name that is already present replaces the stored node but
        keeps the name's original position in the insertion order.  The name
        is recorded only once: a duplicated entry would be enqueued twice by
        ``topological_sort`` and surface as a spurious "has a cycle" error.
        Edges attached to the replaced node object are not transferred.
        """
        if op.name not in self._nodes:
            self._insertion_order.append(op.name)
        self._nodes[op.name] = op
        return self

    def add_edge(self, src, dst) -> 'ComputeGraph':
        """Add a directed edge ``src -> dst`` and return ``self``.

        ``src`` / ``dst`` can each be an OperatorNode object or a name
        string.  A name that was never registered raises ``KeyError``.
        """
        s = src if isinstance(src, OperatorNode) else self._nodes[src]
        d = dst if isinstance(dst, OperatorNode) else self._nodes[dst]
        s.successors.append(d)
        d.predecessors.append(s)
        return self

    def get_op(self, name: str) -> OperatorNode:
        """Look up an operator by name; raises ``KeyError`` if absent."""
        return self._nodes[name]

    def topological_sort(self) -> List[OperatorNode]:
        """Return the nodes in dependency order.

        Kahn's algorithm; respects add_op insertion order for tie-breaking
        among the initial zero-in-degree nodes.

        Raises:
            ValueError: if not every node could be ordered (reported as a
                cycle).
            KeyError: if a successor's name was never registered via add_op.
        """
        in_deg = {n: len(op.predecessors) for n, op in self._nodes.items()}
        queue = deque(n for n in self._insertion_order if in_deg[n] == 0)
        order = []
        while queue:
            n = queue.popleft()
            node = self._nodes[n]
            order.append(node)
            for succ in node.successors:
                in_deg[succ.name] -= 1
                if in_deg[succ.name] == 0:
                    queue.append(succ.name)
        if len(order) != len(self._nodes):
            raise ValueError("ComputeGraph has a cycle")
        return order

    def propagate_splits(self, tsv: TaskSplitValue) -> None:
        """
        Compute task_num for each operator from SplitSpec and propagate
        split_dim through shared TensorSpec objects.

        Must be called once after graph construction, before any fill loop.
        """
        # Reset all tensor split info so a repeated call starts clean.
        for op in self._nodes.values():
            for t in op.inputs + op.outputs:
                t.split_dim = -1
                t.split_num = 1

        # Topological order guarantees a producer has written split_dim on a
        # shared TensorSpec before any consumer inspects it.
        for op in self.topological_sort():
            ss = op.split_spec

            # Source operators (split_inputs is None) always split; every
            # other operator splits only if all listed inputs are already
            # split on the expected dimension.
            splits = ss.split_inputs is None or all(
                op.inputs[idx].split_dim == dim for idx, dim in ss.split_inputs)
            task_num = ss.task_num_fn(tsv) if splits else 1

            op.task_num = task_num

            for i, out in enumerate(op.outputs):
                if i < len(ss.split_output_dims):
                    d = ss.split_output_dims[i]
                    out.split_dim = d if (task_num > 1 and d >= 0) else -1
                else:
                    # No split axis declared for this output.
                    out.split_dim = -1
                out.split_num = task_num

    def build_runtime_config(self, tsv: 'TaskSplitValue', rank_id: int = 0,
                             num_cube_cores: int = 24):
        """
        Generic RuntimeConfig builder for the framework (@MultiCore) path.

        Runs init_task_split_value, the topological fill loop, and sets
        cfg.task_num / cfg.atomic_add_values[0]. Graph-specific post-processing
        (add_terminate, add_dynamic_data, revise_task_queue) is intentionally
        omitted — call build_config_for_rank() in gen_runtime_data.py for MoE FFN.

        This method does not call propagate_splits; cfg.task_num is the sum
        of whatever task_num values the nodes hold after the fill loop.
        """
        from hyper_parallel.core.multicore.modules.common.runtime_structs import (  # pylint: disable=import-outside-toplevel
            RuntimeConfigC, QUEUE_CAPACITY)

        cfg = RuntimeConfigC()
        cfg.num_workers = 2 * num_cube_cores
        cfg.queue_capacity = QUEUE_CAPACITY

        init_task_split_value(tsv)
        tsv.rank_id = rank_id

        # Sort once and reuse the order for both the fill loop and the total.
        ordered_ops = self.topological_sort()
        for op in ordered_ops:
            op.fill_config.fill(cfg, op, tsv)

        cfg.task_num = sum(op.task_num for op in ordered_ops)
        cfg.atomic_add_values[0] = 1
        return cfg