Diff: origin/master...HEAD, staged and unstaged changes
80 81 82 83 84 85 86 87 88 89
for k, v in adamw_raw.items() } allowed_keys_adamw = inspect.signature(AdamW.__init__).parameters.keys() - {'self', 'params'} filtered_adamw_config = {k: v for k, v in adamw_config.items() if k in allowed_keys_adamw} excluded_adamw_keys = adamw_config.keys() - allowed_keys_adamw if excluded_adamw_keys: logger.info_rank0("Excluded adamw config: %s", list(excluded_adamw_keys)) # 1.2 muon muon_raw = muon_kwargs or {}
92 93 94 95 96 97 98 99 100 101
for k, v in muon_raw.items() } allowed_keys_muon = inspect.signature(Muon.__init__).parameters.keys() - {'self', 'params'} filtered_muon_config = {k: v for k, v in muon_config.items() if k in allowed_keys_muon} excluded_muon_keys = muon_config.keys() - allowed_keys_muon if excluded_muon_keys: logger.info_rank0("Excluded muon config: %s", list(excluded_muon_keys)) # 2. Optimizer Creation optimizers = {}
323 324 325 326 327 328 329 330 331
# Cache Hit Check if cache_key in self._split_sub_pg_cache: sub_pg_map = self._split_sub_pg_cache[cache_key] for _, sub_pg in sub_pg_map.items(): if sub_pg is not None: try: dist.get_rank(group=sub_pg) return sub_pg