1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Forward pass: generate RuntimeConfig binary files and tiling binary files.
17
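Usage (this file uses package-relative imports, so run it as a module from
within its package rather than as a standalone script):
    python -m <package>.gen_runtime_data [--tp 4] [--ep 4] [--seq_size 8192]
                                         [--all_expert_num 32] [--top_k 8]
                                         [--hidden_size 7168] [--intermediate_size 2048]
                                         [--dtype_size 2] [--num_cube_cores 24]
                                         [--output_dir multicore_moe_ffn_tp4_ep4_910b]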
22
23Outputs (rank-independent):
24 <output_dir>/up_proj_tiling.bin
25 <output_dir>/swiglu_tiling.bin
26 <output_dir>/down_proj_tiling.bin
27 <output_dir>/all_event_counters.bin 1024×int32 zeros (4 KB)
28 <output_dir>/gmm_workspace.bin 256 MiB zeros
29
30Outputs (per rank):
31 <output_dir>/runtime_config_input_rank_<i>.bin
32"""
33import argparse
34import os
35import numpy as np
36
37from hyper_parallel.core.multicore.modules.moe_ffn.common.runtime_structs import (
38 RuntimeConfigC, QUEUE_CAPACITY,
39)
40from hyper_parallel.core.multicore.modules.moe_ffn.common.compute_graph import TaskSplitValue, init_task_split_value
41from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builder_utils import revise_task_queue
42from hyper_parallel.core.multicore.modules.moe_ffn.common.task_builders import add_terminate, add_dynamic_data
43from .tiling_tables import (
44 get_up_proj_tiling_bytes,
45 get_down_proj_tiling_bytes,
46 get_swiglu_tiling_bytes,
47)
48from .forward_graph import build_forward_graph
49
50
def _positive_int(text):
    """argparse ``type`` callable: parse *text* as an int and require it to be > 0."""
    # A ValueError from int() is reported by argparse as an invalid value.
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}")
    return value


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments for forward data generation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default,
            matching the previous behaviour) argparse reads ``sys.argv[1:]``.
            Passing an explicit list lets callers and tests drive the parser
            without touching process-global state.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on unknown options or on values that
            are not positive integers.
    """
    p = argparse.ArgumentParser(
        description='Forward pass: generate RuntimeConfig binary files '
                    'and tiling binary files.')
    # Every numeric option is a size, count or degree; zero or negative values
    # would only fail later (e.g. integer division by ``ep``) or silently
    # produce no per-rank output, so reject them up front.
    p.add_argument('--tp', type=_positive_int, default=4)
    p.add_argument('--ep', type=_positive_int, default=4)
    p.add_argument('--seq_size', type=_positive_int, default=8192)
    p.add_argument('--all_expert_num', type=_positive_int, default=32)
    p.add_argument('--top_k', type=_positive_int, default=8)
    p.add_argument('--hidden_size', type=_positive_int, default=7168,
                   help='Model hidden dimension (e.g. 7168 for Qwen3-235B)')
    p.add_argument('--intermediate_size', type=_positive_int, default=2048,
                   help='FFN intermediate dimension after SwiGLU halving (half of up-proj output, e.g. 2048)')
    p.add_argument('--dtype_size', type=_positive_int, default=2,
                   help='Activation bytes per element: bf16=2, fp32=4')
    p.add_argument('--num_cube_cores', type=_positive_int, default=24,
                   help='Number of AI Cube cores on target hardware (910B=24)')
    p.add_argument('--output_dir', type=str,
                   default='multicore_moe_ffn_tp4_ep4_910b')
    return p.parse_args(argv)
70
71
def build_config_for_rank(graph, tsv: TaskSplitValue, rank_id: int,
                          num_cube_cores: int = 24) -> RuntimeConfigC:
    """Assemble the RuntimeConfig structure for one rank.

    Args:
        graph: Operator graph. It must provide ``topological_sort()`` and
            ``get_op(name)`` and contain the "dispatch", "swiglu" and
            "combine" operators.
        tsv: Task-split state. It is re-initialised through
            ``init_task_split_value`` and its ``rank_id`` is overwritten, so
            the caller's object is mutated.
        rank_id: Rank identifier stored into ``tsv.rank_id``.
        num_cube_cores: Cube core count; the worker count is set to twice
            this value.

    Returns:
        The populated ``RuntimeConfigC``.
    """
    config = RuntimeConfigC()
    # NUM_WORKERS_VECTOR = 2 × NUM_WORKERS_CUBE
    config.num_workers = num_cube_cores * 2
    config.queue_capacity = QUEUE_CAPACITY

    # Reset the split state before every rank, then stamp the rank id.
    # NOTE(review): the earlier comment here said the C++ forward path never
    # sets rank_id (all ranks use the default 0) — confirm that per-rank ids
    # are the intended behaviour for this generator.
    init_task_split_value(tsv)
    tsv.rank_id = rank_id

    # Each operator contributes its tasks, visited in topological order.
    for node in graph.topological_sort():
        node.fill_config.fill(config, node, tsv)

    # Total over operator tasks only; the terminate entry is not counted.
    total_tasks = 0
    for node in graph.topological_sort():
        total_tasks += node.task_num

    dispatch = graph.get_op("dispatch")
    swiglu = graph.get_op("swiglu")
    combine = graph.get_op("combine")

    # combine's task count rounded down to a multiple of ep.
    rounded_combine_tasks = (combine.task_num // tsv.ep) * tsv.ep
    add_terminate(config, tsv, rounded_combine_tasks)
    revise_task_queue(config, tsv, dispatch.task_num, swiglu.task_num)
    add_dynamic_data(config, tsv, dynamic_input_position=6)

    config.task_num = total_tasks
    config.atomic_add_values[0] = 1
    return config
99
100
def write_bin(path: str, data: bytes) -> None:
    """Write ``data`` to ``path`` as a binary file and print a size summary.

    The parent directory of ``path`` is created when it does not exist; a
    bare filename is written into the current directory.
    """
    parent_dir = os.path.dirname(path)
    if not parent_dir:
        # No directory component: fall back to the current directory.
        parent_dir = '.'
    os.makedirs(parent_dir, exist_ok=True)

    with open(path, 'wb') as sink:
        sink.write(data)
    print(f" wrote {len(data):>10,} bytes → {path}")
106
107
def main():
    """Generate every forward-pass binary artifact into the output directory.

    Steps, in order: parse the CLI arguments, build the forward graph and
    propagate split values, write the three tiling files, write the zeroed
    event-counter and workspace buffers, then write one RuntimeConfig file
    for each rank id in ``range(ep)``.
    """
    args = parse_args()
    out_dir = args.output_dir

    tsv = TaskSplitValue(
        tp=args.tp,
        ep=args.ep,
        seq_size=args.seq_size,
        all_expert_num=args.all_expert_num,
        top_k=args.top_k,
    )
    num_groups = tsv.single_rank_expert_num
    graph = build_forward_graph(
        tsv,
        dispatch_sv=128,
        up_proj_sv=4096,
        swiglu_sv=128,
        down_proj_sv=4096,
        combine_sv=128,
        hidden_size=args.hidden_size,
        intermediate_size=args.intermediate_size,
        dtype_size=args.dtype_size,
        num_cube_cores=args.num_cube_cores,
    )
    # Split-axis propagation assigns task_num to each operator.
    graph.propagate_splits(tsv)

    dispatch_op, up_proj_op, swiglu_op, down_proj_op, combine_op = (
        graph.get_op(op_name)
        for op_name in ("dispatch", "up_proj", "swiglu", "down_proj", "combine")
    )

    print(f"[fwd] tp={args.tp} ep={args.ep} seq={args.seq_size} "
          f"E={args.all_expert_num} topk={args.top_k}")
    print(f" dispatch={dispatch_op.task_num} up_proj={up_proj_op.task_num} "
          f"swiglu={swiglu_op.task_num} down_proj={down_proj_op.task_num} "
          f"combine={combine_op.task_num}")

    # ── Tiling files (rank-independent) ──────────────────────────────────────
    # The two projection tilings take the same keyword arguments.
    proj_kwargs = dict(
        hidden_size=args.hidden_size,
        intermediate_size=args.intermediate_size,
        num_groups=num_groups,
        num_cube_cores=args.num_cube_cores,
    )
    up_proj_bytes = get_up_proj_tiling_bytes(up_proj_op.split_value, **proj_kwargs)
    down_proj_bytes = get_down_proj_tiling_bytes(down_proj_op.split_value, **proj_kwargs)
    swiglu_bytes = get_swiglu_tiling_bytes(
        swiglu_op.split_value, intermediate_size=args.intermediate_size)

    tiling_outputs = (
        ('up_proj_tiling.bin', up_proj_bytes),
        ('swiglu_tiling.bin', swiglu_bytes),
        ('down_proj_tiling.bin', down_proj_bytes),
    )
    for filename, payload in tiling_outputs:
        write_bin(os.path.join(out_dir, filename), payload)

    # ── Event counters + workspace (rank-independent) ─────────────────────────
    # all_event_counters: 1024×int32_t zeros — matches C++ reference gen_data
    event_counters = np.zeros(1024, dtype=np.int32)
    write_bin(os.path.join(out_dir, 'all_event_counters.bin'),
              event_counters.tobytes())
    # gmm_workspace: 256 MiB zeros — kernel-internal scratch buffer
    workspace_size = 256 * 1024 * 1024
    write_bin(os.path.join(out_dir, 'gmm_workspace.bin'), bytes(workspace_size))

    # ── RuntimeConfig files (one per rank) ───────────────────────────────────
    for rank_id in range(args.ep):
        cfg = build_config_for_rank(graph, tsv, rank_id,
                                    num_cube_cores=args.num_cube_cores)
        rank_path = os.path.join(out_dir, f'runtime_config_input_rank_{rank_id}.bin')
        write_bin(rank_path, bytes(cfg))

    print("[fwd] done.")
176
177
# Script entry point: run the generator only when this module is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()