1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Backward pass: generate RuntimeConfig binary files and tiling binary files.
17
18Usage:
19 python gen_runtime_data.py [--tp 4] [--ep 4] [--seq_size 8192]
20 [--all_expert_num 32] [--top_k 8]
21 [--output_dir multicore_moe_ffn_grad_tp4_ep4_910b]
22
23Outputs (rank-independent):
24 <output_dir>/act_grad_tiling.bin (act_grad, pos 20)
25 <output_dir>/gate_grad_tiling.bin (gate_grad, pos 21)
26 <output_dir>/w2_grad_tiling.bin (w2_grad, pos 22)
27 <output_dir>/w1_grad_tiling.bin (w1_grad, pos 23)
28 <output_dir>/swiglu_grad_tiling.bin (SwiGLU-grad, pos 24)
29 <output_dir>/all_event_counters.bin 1024×int32 zeros (4 KB)
30 <output_dir>/gmm_workspace.bin 256 MiB zeros
31
32Outputs (per rank):
33 <output_dir>/runtime_config_input_rank_<i>.bin
34"""
35import argparse
36import os
37import numpy as np
38
39from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import (
40 RuntimeConfigC, QUEUE_CAPACITY,
41)
42from hyper_parallel.core.multicore.modules.moe_ffn.common.compute_graph import TaskSplitValue, init_task_split_value
43from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builder_utils import revise_task_queue
44from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builders import (
45 add_terminate, add_dynamic_data, revise_gmm_task_queue_bwd,
46)
47from .tiling_tables import (
48 get_act_grad_tiling_bytes,
49 get_gate_grad_tiling_bytes,
50 get_w2_grad_tiling_bytes,
51 get_w1_grad_tiling_bytes,
52 get_swiglu_grad_tiling_bytes,
53)
54from .backward_graph import build_backward_graph
55
56
def parse_args():
    """Define the CLI for backward data generation and parse the process arguments.

    Returns:
        argparse.Namespace carrying the integer options ``tp``, ``ep``,
        ``seq_size``, ``all_expert_num``, ``top_k``, ``hidden_size``,
        ``intermediate_size``, ``dtype_size`` and ``num_cube_cores``, plus the
        string option ``output_dir``.
    """
    # (flag, default, help) for every integer-valued option, listed in the
    # order they are registered with the parser.
    int_options = (
        ('--tp', 4, None),
        ('--ep', 4, None),
        ('--seq_size', 8192, None),
        ('--all_expert_num', 32, None),
        ('--top_k', 8, None),
        ('--hidden_size', 7168,
         'Model hidden dimension (e.g. 7168 for Qwen3-235B)'),
        ('--intermediate_size', 2048,
         'FFN intermediate dimension after SwiGLU halving (half of up-proj output, e.g. 2048)'),
        ('--dtype_size', 2,
         'Activation bytes per element: bf16=2, fp32=4'),
        ('--num_cube_cores', 24,
         'Number of AI Cube cores on target hardware (910B=24)'),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    parser.add_argument('--output_dir', type=str,
                        default='multicore_moe_ffn_grad_tp4_ep4_910b')
    return parser.parse_args()
76
77
def build_config_for_rank(graph, tsv: TaskSplitValue, rank_id: int,
                          num_cube_cores: int = 24) -> RuntimeConfigC:
    """Assemble the backward RuntimeConfig for one rank.

    Args:
        graph: Backward compute graph; this function uses its
            ``topological_sort()`` and ``get_op(name)`` methods.
        tsv: Task-split state. It is modified in place: it is passed to
            ``init_task_split_value`` and its ``rank_id`` is overwritten.
        rank_id: Rank index stored into ``tsv.rank_id`` before tasks are filled.
        num_cube_cores: Cube core count; ``num_workers`` is set to twice this
            value and it is forwarded to ``revise_gmm_task_queue_bwd``.

    Returns:
        The populated ``RuntimeConfigC`` instance.
    """
    config = RuntimeConfigC()
    # NUM_WORKERS_VECTOR = 2 × NUM_WORKERS_CUBE
    config.num_workers = 2 * num_cube_cores
    config.queue_capacity = QUEUE_CAPACITY

    # Start from fresh counters for this rank.
    init_task_split_value(tsv)
    tsv.rank_id = rank_id

    # Let every operator write its tasks into the config, in topological order
    # (dispatch, act_grad, w1_grad, swiglu_grad, gate_grad, combine, w2_grad).
    for node in graph.topological_sort():
        node.fill_config.fill(config, node, tsv)

    # Total task count: one terminate task plus every operator's tasks.
    total_tasks = 1
    for node in graph.topological_sort():
        total_tasks += node.task_num

    ops = {
        op_name: graph.get_op(op_name)
        for op_name in ("dispatch", "swiglu_grad", "act_grad",
                        "w2_grad", "w1_grad", "combine")
    }

    # The combine contribution is rounded down to a multiple of tsv.ep.
    terminate_count = (ops["w1_grad"].task_num + ops["w2_grad"].task_num
                       + ops["combine"].task_num // tsv.ep * tsv.ep)
    add_terminate(config, tsv, terminate_count)
    revise_task_queue(config, tsv,
                      ops["dispatch"].task_num, ops["swiglu_grad"].task_num)
    revise_gmm_task_queue_bwd(config, tsv, ops["act_grad"].task_num,
                              num_cube_cores=num_cube_cores)
    add_dynamic_data(config, tsv, dynamic_input_position=19)

    config.task_num = total_tasks
    config.atomic_add_values[0] = 1
    return config
113
114
def write_bin(path: str, data: bytes) -> None:
    """Write *data* to *path* in binary mode and print a one-line size report.

    The directory part of *path* is created first if it does not exist yet;
    when *path* has no directory part, the current directory is used.

    Args:
        path: Destination file path.
        data: Raw bytes to store.
    """
    parent_dir = os.path.dirname(path)
    if not parent_dir:
        # A bare file name has an empty dirname; fall back to the current directory.
        parent_dir = '.'
    os.makedirs(parent_dir, exist_ok=True)

    with open(path, 'wb') as out_file:
        out_file.write(data)
    print(f" wrote {len(data):>10,} bytes → {path}")
120
121
def main():
    """Generate every backward-pass binary input file under ``--output_dir``.

    Steps, in order:
      1. Parse the command line and build the ``TaskSplitValue``.
      2. Build the backward graph and propagate splits so each operator has a
         ``task_num``.
      3. Print a summary of the per-operator task counts.
      4. Compute the five tiling blobs, then write them.
      5. Write the zero-filled event-counter and GMM workspace files.
      6. Write one ``runtime_config_input_rank_<i>.bin`` per rank in
         ``range(args.ep)``.
    """
    args = parse_args()
    out_dir = args.output_dir

    tsv = TaskSplitValue(
        tp=args.tp,
        ep=args.ep,
        seq_size=args.seq_size,
        all_expert_num=args.all_expert_num,
        top_k=args.top_k,
    )
    num_groups = tsv.single_rank_expert_num

    # Fixed split values handed to the graph builder for each operator.
    split_values = dict(
        dispatch_sv=128, act_grad_sv=4096, w1_grad_sv=4096,
        swiglu_sv=128, gate_grad_sv=4096, w2_grad_sv=4096,
        combine_sv=128,
    )
    graph = build_backward_graph(tsv,
                                 **split_values,
                                 hidden_size=args.hidden_size,
                                 intermediate_size=args.intermediate_size,
                                 dtype_size=args.dtype_size,
                                 num_cube_cores=args.num_cube_cores)
    # Compute task_num for each operator via split-axis propagation.
    graph.propagate_splits(tsv)

    dispatch = graph.get_op("dispatch")
    act_grad = graph.get_op("act_grad")
    w1_grad = graph.get_op("w1_grad")
    swiglu_grad = graph.get_op("swiglu_grad")
    gate_grad = graph.get_op("gate_grad")
    w2_grad = graph.get_op("w2_grad")
    combine = graph.get_op("combine")

    print(f"[bwd] tp={args.tp} ep={args.ep} seq={args.seq_size} "
          f"E={args.all_expert_num} topk={args.top_k}")
    print(f" dispatch={dispatch.task_num} act_grad={act_grad.task_num} "
          f"w1_grad={w1_grad.task_num} swiglu_grad={swiglu_grad.task_num} "
          f"gate_grad={gate_grad.task_num} w2_grad={w2_grad.task_num} "
          f"combine={combine.task_num}")

    # ── Tiling files (rank-independent) ──────────────────────────────────────
    # The four grouped-matmul tiling generators take the same keyword arguments.
    gmm_kwargs = dict(
        hidden_size=args.hidden_size,
        intermediate_size=args.intermediate_size,
        num_groups=num_groups,
        num_cube_cores=args.num_cube_cores,
    )
    # Every blob is computed before any file is written.
    tiling_blobs = [
        ('act_grad_tiling.bin',
         get_act_grad_tiling_bytes(act_grad.split_value, **gmm_kwargs)),
        ('gate_grad_tiling.bin',
         get_gate_grad_tiling_bytes(gate_grad.split_value, **gmm_kwargs)),
        ('w2_grad_tiling.bin',
         get_w2_grad_tiling_bytes(w2_grad.split_value, **gmm_kwargs)),
        ('w1_grad_tiling.bin',
         get_w1_grad_tiling_bytes(w1_grad.split_value, **gmm_kwargs)),
        ('swiglu_grad_tiling.bin',
         get_swiglu_grad_tiling_bytes(swiglu_grad.split_value,
                                      intermediate_size=args.intermediate_size)),
    ]
    for file_name, blob in tiling_blobs:
        write_bin(os.path.join(out_dir, file_name), blob)

    # ── Event counters + workspace (rank-independent) ─────────────────────────
    # all_event_counters: 1024×int32_t zeros — matches C++ reference gen_data
    event_counters_path = os.path.join(out_dir, 'all_event_counters.bin')
    write_bin(event_counters_path, np.zeros(1024, dtype=np.int32).tobytes())
    # gmm_workspace: 256 MiB zeros — kernel-internal scratch buffer
    workspace_path = os.path.join(out_dir, 'gmm_workspace.bin')
    write_bin(workspace_path, bytes(256 * 1024 * 1024))

    # ── RuntimeConfig files (one per rank) ───────────────────────────────────
    for rank_id in range(args.ep):
        rank_cfg = build_config_for_rank(graph, tsv, rank_id,
                                         num_cube_cores=args.num_cube_cores)
        serialized = bytes(rank_cfg)
        write_bin(os.path.join(out_dir, f'runtime_config_input_rank_{rank_id}.bin'),
                  serialized)

    print("[bwd] done.")
205
206
# Run the generator only when this module is executed as the entry point.
if "__main__" == __name__:
    main()