While there are already excellent posts on scaling, I wanted to share my own understanding and the things I've learned over the past few months, and hopefully spark some discussion. I hope this post can shed light for anyone navigating the challenges of scaling up neural networks. There may be mistakes or inaccuracies, so if you'd like to correct me or discuss further, please feel free to DM me on X or leave a comment.
Before we dive in, I'd like to acknowledge that this post is heavily inspired by Simo Ryu's "What to do to scale up?", and the post theme is based on diffusionflow.github.io by GDM.
"how to scale"
in this post means âhow we should set the initialization standard deviation (init std), learning rate (lr), and batch size (bsz) and other hyperparameters (HPs)
as model size (including both width and depth) and dataset size growâ.
It is well established that as you scale up the compute budget, \(C=6ND\) (where \(N\) is model size and \(D\) is dataset size), your model tends to perform better.
Fig. Scaling Law is Universal Behavior. For Every Task, It Works. Source from Scaling Laws for Neural Language Models
However, scaling law papers never tell us how to set the lr or bsz for a given computing budget \(C\).
This is a non-trivial issue, because "more compute always leads to better performance" is not guaranteed unless you're using near-optimal HPs for each \(C\).
If you fail to find the right HPs, you might conclude:
"WTF? larger isn't better? okay, let's quit scaling."
and you're not gonna make it.
That's exactly why we need to understand "how to scale."
So, the main questions we should answer are:
From the perspective of model size, there is a theoretically grounded method for ensuring optimal lr transfer across different scales. It's called Maximal Update Parameterization (muP).
You've probably heard about it on X (formerly Twitter).
Compared to muP, Standard Parameterization (SP) (e.g. He, LeCun, and other default PyTorch initializations) is not designed for scaling.
These methods are well-defined only at initialization, not during training.
That means they don't account for how weight updates affect model behavior, which eventually leads to what's known as "left-shifted lr curves".
But we don't want this shifting trend.
Our goal is to define a parameterization in which optimal behavior transfers as model size scales. That means we need to properly set per-layer:
And by analyzing optimization behavior carefully, muP finally gives us the right parameterization to scale up model size properly.
Fig. Source from Greg Yang's Video
What really matters isn't something like:
"For the 8B model, optimal lr is 3e-4; for 405B it's 8e-5."
Rather, it's:
"The optimal lr at 40M is 0.00195, so we should halve the lr when we double the width (hidden size)."
If we define this scaling rule properly, we can efficiently tune larger models, and match scaling laws, at relatively low cost.
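To make that rule concrete, here is a tiny sketch (the base width and base lr below are made-up numbers, not values from any real run): under the muTransfer rule for hidden matrices with Adam, the lr scales like one over the width.

```python
# Hypothetical base point for illustration: a small proxy with width 512 tuned to lr 1.95e-3.
base_width, base_lr = 512, 1.95e-3

def mu_transferred_lr(width: int) -> float:
    """muTransfer-style rule for hidden matrices under Adam: lr scales like 1/width."""
    return base_lr * base_width / width

for width in (512, 1024, 2048, 4096, 8192):
    print(f"width={width:5d} -> hidden lr ~ {mu_transferred_lr(width):.2e}")
```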
The authors of Tensor Program (TP)-V, aka muTransfer (Greg Yang and Edward Hu), were part of the early OpenAI-Microsoft collaboration. Andrew Carr, formerly at OpenAI, confirmed that muP was likely used in training GPT. (The GPT-4 technical report also refers to TP-V.)
Fig. Source tweet
Note that even this elegant muP framework does not consider dataset scaling. We'll discuss this point later in the post.
It's important to note that muP literally stands for Maximal Update.
Many people misunderstand muP as being only about hyperparameter (HP) transfer, but that's not true.
Of course, HP transfer is a nice property.
It allows us to avoid extensive grid searches over HPs for a given compute budget \(C\) when predicting scaling laws or training very large models.
But muP is fundamentally about ensuring that every layer learns features maximally at each optimization step, by assigning per-parameter HPs (lr, init std, etc.), even as the network's width goes to infinity.
Fig. OpenReview of TP-5. Greg Yang explains "muP is not only for HP transfer"
So, muP is designed to enable maximal feature learning. But why is SP not enough?
In SP, we often find that some weights receive disproportionately large gradients, while others receive gradients that are too small.
If we reduce the lr to stabilize the weights that receive large gradients, the others can become stuck, meaning they don't learn enough features, which leads to inefficient training.
On the other hand, if we increase the lr too much, the model may diverge.
So we are stuck in a dilemma.
Fig. In TP-V paper, it is clearly mentioned.
That's why we need to carefully analyze per-layer behavior and adopt per-layer parameterization (i.e., per-layer lr, init std, and multiplier), as done in muP.
Of course, there may be other viable approaches.
Normalization techniques such as BatchNorm and LayerNorm can help correct imbalances and improve optimization.
Adaptive optimizers may also help, but normalization and Adam alone are insufficient.
Recently proposed advanced optimizers like Muon (with proper scaling factors) and SCION show that lr can transfer across model widths.
(Not sure whether they guarantee maximal feature learning, though.)
Fig. Source from Training Deep Learning Models with Norm-Constrained LMOs. lr can be transferred with SCION optimizer
IMO, most optimization and training techniques share the same motivation when seen from a unified perspective: to stabilize, balance, and improve training dynamics.
Anyway, muP is based on three core desiderata:
By solving for these desiderata, you get a parameterization that not only encourages maximal feature learning but also theoretically guarantees HP transfer across width. (Again, I'd like to stress that HP transfer was never the primary goal; it follows.)
Why does training stability matter?
Because muP is a method for feature learning in the infinite-width regime, where hidden sizes grow larger and larger (e.g., 1024, 4096, ..., 40k).
This means muP must ensure maximal learning regardless of width, so we don't want pre-activations to scale with the width \(n\).
That's why muP enables training dynamics to be transferred across different model scales.
Of course, there are many other parameterizations, such as Standard Parameterization (SP), Neural Tangent Kernel (NTK), Mean Field Theory (MFT) and muP.
One might ask, "Why is muP unique for maximal feature learning?"
I won't go into full detail here (check the original paper), but consider this:
In the kernel regime (e.g., NTK), models are effectively frozen in feature space.
So even if your NTK-parameterized BERT is pretrained successfully, its learned representation is weak.
Assume you concatenate a randomly initialized linear layer to its hidden features and fine-tune.
If the model was trained with NTK parameterization, this wouldn't work well, because the hidden layers never learned meaningful features even though its pretraining performance is not bad.
That's the problem, and muP does not allow this.
Fig. The caricature provided in the paper might not be intuitive, but what the authors want to say is that among the parameterizations, only muP allows stable, non-trivial feature learning. Some (NTK) are stuck in the kernel regime and others (SP with large lr) diverge.
For more motivation (the example below isn't strictly about parameterization, but I'd like to raise the research question): transformers with pre-norm often show redundancy in deeper layers. The post-norm transformer, originally proposed in Attention Is All You Need, outperforms the pre-norm variant, but post-norm does not preserve upstream gradients (identity mapping), so it requires an lr warmup stage or other tactics to improve training stability. Pre-norm, on the other hand, makes it harder for residual features in deeper layers to contribute to the model's main residual stream, a side effect known as "representation collapse".
So in this case, feature learning across layers can become uneven, which leads to wasted compute.
Many researchers have studied normalization modules or parameterizations like residual post-norm (sandwich norm), mix-ln, deep-norm, or depth-scaled sandwich norm to achieve both training stability and effective feature learning.
Fig. depth scaled sandwich norm from Pangu Ultra
IMO, these examples are related to parameterization too.
In TP-5, the authors show that muP not only transfers optimal lr across width,
but also achieves better performance overall in Language Modeling (LM) task (GPT-3 setup).
Fig. muP not only transfers the optimal lr but also shows better performance.
"wider is always better"
That said, in real-world scenarios, it may not be true that muP always outperforms SP.
In my own experience, muP has shown stronger benchmark results at small to medium scales (e.g., 200-300B tokens),
but the returns seem to diminish as scale increases.
My guess is that even though other parameterizations such as SP are in the lazy training regime,
the embedding and output layers eventually start to learn something later on.
(or maybe... it's just my skill issue)
Now we're gonna derive the unique scaling rule for maximal update and HP transfer.
Before diving into muP, let's briefly review Standard Parameterization (SP). SP focuses on initialization. But what defines "good" at initialization?
Fig. Source from Sergey Levine's ML Lecture (CS182)
Assume weight matrix is i.i.d sampled from normal distribution and input feature of this layer is also i.i.d. In forward propagation, the output pre-activations of this layer, \(z=Wx \in \mathbb{R}^{\text{fan-out} \times 1}\)âs elements have \(\text{fan-in} \cdot \color{red}{\sigma_W^2}\) variance if \(x \in \mathbb{R}^{\text{fan-in} \times 1} \sim \mathcal{N}(0,1)\) (it is fair assumption because we typically standardize inputs of each modules), and \(W \in \mathbb{R}^{\text{fan-out} \times \text{fan-in}} \sim \mathcal{N}(0,\sigma_W^2)\).
\[\begin{aligned} & z_i = \sum_j^{\text{fan-in}} W_{ij} x_j & \text{no bias }\\ & \mathbb{E}[z_i^2] = \sum_j^{\text{fan-in}} \mathbb{E}[W_{ij}^2] \mathbb{E}[x_j^2] = \text{fan-in} \cdot \color{red}{\sigma_W^2} \cdot \sigma_x^2 & \\ \end{aligned}\]To keep this around 1, we should counter this value by \(\text{fan-in}\).
And this is why SP is called fan-in (input feature dim) variance initialization.
It means every element (coordinate) of the weight matrix has variance roughly \(1/\text{fan-in}\) (i.e., a typical size of \(1/\sqrt{\text{fan-in}}\)).
Also, there are many init methods like Xavier init, He init, and so on. Xavier init, in particular, also considers backpropagation at the initialization point. Because \(dL/dx = W^T (dL/dz)\), we can derive the backward propagation similarly to the forward pass: \(dL/dx \sim \mathcal{N}(0, \text{fan-out}\cdot\sigma_W^2 \cdot \sigma_{dL/dz}^2)\). If the matrix shape is \(n \times n\), Xavier init is the same as LeCun init.
Fig.
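As a quick numeric sanity check of the fan-in argument (a sketch; the widths, the small fan-out, and the fixed 0.02 std are arbitrary choices of mine): with std \(=1/\sqrt{\text{fan-in}}\) the pre-activation variance stays near 1 at every width, while a width-independent std makes it grow with fan-in.

```python
import torch

torch.manual_seed(0)
fan_out = 256                                   # kept small; only fan-in matters for the variance
for fan_in in (256, 1024, 4096, 16384):
    x = torch.randn(1024, fan_in)               # i.i.d. standard-normal inputs
    W_fanin = torch.randn(fan_in, fan_out) / fan_in**0.5   # std = 1/sqrt(fan-in), like SP
    W_fixed = torch.randn(fan_in, fan_out) * 0.02          # width-independent std
    print(f"fan_in={fan_in:6d}  var(xW) fan-in init: {(x @ W_fanin).var():.3f}"
          f"   fixed 0.02 init: {(x @ W_fixed).var():.3f}")
```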
However, as I mentioned, SP is only about initialization.
What we want is for every (pre-)activation to have constant scale (\(\Theta(1)\)) at any point during training,
regardless of the hidden size of the neural network.
In Tensor Programs (TP) (muP comes from TP-4 and TP-5),
the framework cares not only about initialization but also about training dynamics.
This is why muP is called mu + Parameterization.
A parameterization includes all three things per parameter: 1) initialization (init std), 2) multiplier, and 3) learning rate (lr),
whereas SP only describes the initialization standard deviation (init std).
Again, for deriving muP, you should remember muP's three desiderata. We just want our model to act like this.
Fig.
However, once you try to derive muP, it could be overwhelming because of too many mathematical symbols.
But don't worry: in this post, we'll derive muP as if we had no brain, keeping things as simple as possible.
All you need is the Law of Large Numbers (LLN), the Central Limit Theorem (CLT), a bit of intuition about SGD and Adam, and the important fact that Neural Network (NN) training is indeed just a bunch of dot products (forward pass) and outer products (backward pass).
(I'm not going to derive muP very strictly in this post because it's mathematically heavy.)
Fig.
Typically, the dot product of two n-dimensional vectors gets larger as the width n goes to infinity (we should counter this behavior). The key rule of muP is as follows.
Fig. TP-V
To understand muP, let's say the weight matrix at step \(t\) is \(W_{t} \in \mathbb{R}^{n \times n}\), the input is \(x \in \mathbb{R}^{n \times 1}\), and its output pre-activation is \(z_t \in \mathbb{R}^{n \times 1}\). The output pre-activation at initialization (\(t = 0\)) is as follows
\[z_0 = \underbrace{W_{0} x}_{\Theta(1)}, W_0 \sim \underbrace{\mathcal{N}(0, (\frac{1}{\sqrt{n}})^2)}_{\text{like SP}}\]For the forward pass, we don't want the output to blow up with width.
That means we want to preserve pre-activation output as \(\Theta(1)\).
Here \(\Theta(1)\) means the output pre-activation's scale does not depend on the model width (embedding dim), \(n\).
That is, muP is scale invariant (good behavior for HP transfer).
The symbols \(\Theta(\cdot)\) and \(O(\cdot)\) are known as asymptotic notations, commonly referred to as Big O notation. While widely used in CS to describe algorithmic complexity, they actually originate from mathematics (and Greg Yang comes from a mathematical background, as far as I know).
There are three asymptotic notations, and you'll notice both \(O\) and \(\Theta\) are used in the muP derivation.
Fig. Source from here
Fig. Source from A Spectral Condition for Feature Learning
Anyway, we can just use fan-in variance (where the fan-in (feature-in) dim is n), like SP, for forward stability at initialization.
But after a single optimization step (SGD or Adam; let's assume SGD first),
the weight is updated as:
\[W_{t+1} = W_{t} + \eta \nabla_{W_{t}}L, \quad \nabla_{W_{t}}L = g_t x^T\]
where \(\eta\) is the lr and \(g_t\) is the upstream gradient. Now step \(t+1\)'s pre-activation can be described as
\[\begin{aligned} & z_{t+1} = W_{t+1} x' & \\ & = (W_{t} + \eta \nabla_{W_{t}}L)x' & \\ & = W_t x' + \eta g_t (x^T x') & \\ \end{aligned}\]where \(x'\) is the new input feature at step \(t+1\); SP does not consider how \(\Delta W\) contributes to step \(t+1\)'s pre-activation output at all. The dot product between the weight matrix's columns and the input features is uncorrelated at the init point (they are i.i.d. sampled), but it starts to be correlated after just one optimization step.
In the formula above, we can reparameterize the \((x^T x')\) term as \(n \cdot (x^T x')/n\) or \(\sqrt{n} \cdot (x^T x')/\sqrt{n}\), and these quantities, \((x^T x')/n\) or \((x^T x')/\sqrt{n}\), will converge to some deterministic scalar by the LLN or CLT as \(n\) goes to infinity. If \(x\) and \(x'\) are not correlated, it follows the CLT; if they are correlated, it follows the LLN.
Why?
Let's recap the LLN and CLT briefly and think about the dot product of two vectors under two different conditions: correlated or not. First, the Law of Large Numbers (LLN) is a basic concept in probability and statistics stating that the sample mean of \(n\) random samples converges to the true mean as \(n\) becomes large (goes to infinity):
\[\frac{1}{n} \sum_{i=1}^n x_i \;\rightarrow\; \mathbb{E}[X], \quad\text{as } n \rightarrow \infty\]Here, each sample must be drawn independently from the same distribution, the "independent and identically distributed (i.i.d.)" assumption. For example, if you flip a fair coin 1,000 times, the outcome of the 550th flip does not depend on the results of the first 549 flips, and each flip follows the same \(\mathrm{Bernoulli}(1/2)\) distribution.
The Central Limit Theorem (CLT) is another convergence theorem for i.i.d. samples. It states that, for sufficiently large \(n\), the distribution of the sample mean of \(n\) draws approaches a normal distribution, regardless of the original distributionâs shape. Concretely, if you draw \(n\) samples from \(\mathcal{N}(\mu,\sigma^2)\), then the distribution of their mean,
\[\frac{1}{n}\sum_{i=1}^n x_i,\]converges to \(\mathcal{N}\bigl(\mu,\sigma^2/n\bigr)\). A common misconception is to think that CLT says "if you draw many samples, those individual samples become normally distributed," but in fact CLT refers to the distribution of the sample mean, not the raw samples.
Fig. Source: Wikipedia
If we assume \(\mathbb{E}[x_i]=0\), CLT can be written more simply as
\[\begin{aligned} &\frac{1}{\sqrt{n}} \sum_{i=1}^n x_i \;\rightarrow\; \mathcal{N}(0,\sigma^2), \quad \text{as } n \rightarrow \infty,\\ &\text{where } \mathbb{E}[x_i^2] = \sigma^2. \end{aligned}\]More generally, after centering each sample by its mean,
\[\frac{1}{\sqrt{n}} \sum_{i=1}^n \bigl(x_i - \mathbb{E}[X]\bigr) \;\rightarrow\; \mathcal{N}(0,\sigma(X)^2), \quad \text{as } n \rightarrow \infty.\]And finally, we can write the sum of i.i.d. samples as follows
\[\begin{aligned} S_n =\sum_{i=1}^n x_i &= \underbrace{n\mu}_{\text{by LLN}} + \underbrace{\sqrt{n}\,\sigma\,\mathcal{N}(0,1)}_{\text{by CLT}} + \text{lower-order terms},\\ &\text{where } \mathbb{E}[x_i]=\mu,\;\mathbb{E}[x_i^2]=\sigma^2 \end{aligned}\]This is quite intuitive because one sample from \(\mathcal{N}(\mu,\sigma^2)\) will lie near \(\mu + \sigma\,\mathcal{N}(0,1)\), and \(\mathrm{Var}(X+Y) = \mathrm{Var}(X) + \mathrm{Var}(Y)\) if X, Y are independent and \(\mathrm{Var}(cX) = c^2\,\mathrm{Var}(X)\). In this form, if the mean is non-zero, the \(n\mu\) term (from the LLN) becomes dominant, and if it is zero, the next-largest term, \(\sqrt{n}\sigma\,\mathcal{N}(0,1)\) (from the CLT), becomes dominant.
Fig. TP-V
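Here is a tiny numeric check of that decomposition (a sketch; the Gaussian with \(\mu=0.5,\ \sigma=2\) is an arbitrary choice of mine): the normalized sum hugs the LLN term, while the centered, \(\sqrt{n}\)-scaled remainder keeps fluctuating on the scale of \(\sigma\).

```python
import torch

torch.manual_seed(0)
mu, sigma = 0.5, 2.0
for n in (1_000, 100_000, 10_000_000):
    x = mu + sigma * torch.randn(n)
    s = x.sum()
    print(f"n={n:>10,}  S_n/(n*mu) = {s / (n * mu):.4f}"            # -> 1, the LLN term dominates
          f"   (S_n - n*mu)/sqrt(n) = {(s - n * mu) / n**0.5:.3f}")  # O(sigma) fluctuation (CLT)
```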
We are discussing the dot product of two vectors, which looks like
\[x^T x' = \sum_{i=1}^n x_i x'_i = (x_1x'_1 + \cdots + x_nx'_n)\]Suppose the two vectors are correlated. In other words, if they are aligned, the mean of each elementwise product becomes non-zero, so it follows the LLN. And intuitively, in a large enough model, each coordinate of \(x\) is typically i.i.d., and \(x\) and \(x'\) are correlated even though they are mini-batch sampled. So \((x^T x')/n\) converges to some deterministic scalar by the LLN.
\[\begin{aligned} & z_{t+1} = W_{t+1} x' & \\ & = (W_{t} + \eta \nabla_{W_{t}}L)x' & \\ & = W_t x' + (\color{red}{n} \eta) g_t \underbrace{ \frac{(x^T x')}{\color{red}{n}}}_{\text{deterministic scalar}=c} & \\ \end{aligned}\]and here the deterministic scalar is \(c=\mathbb{E}Z^{x}Z^{x'}\), where \(Z^{x}, Z^{x'}\) are the random variables corresponding to the coordinates of each vector.
Hold on, what just happened?
By introducing the \(n/n\) term,
we now have an \(n \times \eta\) factor, which means "if the model width goes to infinity, the update quantity will blow up",
and we don't want to allow this, for stability.
So, we should counter this behavior by \(n\).
However, the upstream gradient \(g_t\) already consists of \(\Theta(1/n)\)-scale entries, so we don't need to counter this in the SGD setup. (Simply put, the gradient with respect to the output logit is \(\Theta(1)\) and the weight matrix has \(\Theta(1/\sqrt{n})\) coordinates, but in muP we divide the output logit by \(\sqrt{n}\) or use a \(\Theta(1/n)\) std, so we can think of the upstream gradient \(g_t = W \otimes dL/dy\) as having a typical size of \(\Theta(1/n)\). We will discuss this later in this post.)
\[\begin{aligned} & z_{t+1} = W_{t+1} x' & \\ & = (W_{t} + \eta \nabla_{W_{t}}L)x' & \\ & = W_t x' + (\color{red}{n} \eta) \underbrace{g_t}_{\Theta(1/n)} \underbrace{ \frac{(x^T x')}{\color{red}{n}}}_{\text{c}} & \\ \end{aligned}\]But if we use an adaptive optimizer like Adam,
the gradient is rescaled elementwise,
so we should scale the lr by \(1/n\) to counter this explicitly.
That's why we can simply say "muP = the \(1/n\) lr rule for hidden layers (under Adam)".
Actually, it's more than the CLT, LLN, and dot products: we should consider the full optimization trajectory with momentum and other factors. And bsz is greater than one in real-world scenarios, where the gradient is averaged over multiple rank-1 gradients, and things go wild. If you want to see the full derivation, I recommend reading TP-IVb, TP-5, or A Spectral Condition for Feature Learning.
Anyway, intuitively it's all about the CLT, the LLN, and dot products. Choose the LLN or CLT based on whether the vectors are correlated or not. And this scaling logic holds for any dot product. That's it.
Fig. SP is only well defined at the initialization point (comfort zone), not during training. Typically, if there is correlation, the dot product becomes \(\sqrt{n}\) times larger than when there isn't. Source from Greg Yang's Blog
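To make the "LLN vs. CLT" distinction concrete, here is a minimal sketch (the correlation is induced by simply mixing in part of the same vector; any correlated pair would behave similarly): the correlated dot product grows like \(n\), the independent one only like \(\sqrt{n}\).

```python
import torch

torch.manual_seed(0)
for n in (1_000, 10_000, 100_000, 1_000_000):
    x = torch.randn(n)
    x_corr = 0.5 * x + 0.5 * torch.randn(n)   # correlated with x
    x_ind = torch.randn(n)                    # independent of x
    print(f"n={n:>9,}  (x . x_corr)/n = {torch.dot(x, x_corr) / n:.3f}"          # -> 0.5 by LLN
          f"   (x . x_ind)/sqrt(n) = {torch.dot(x, x_ind) / n**0.5:.3f}")         # O(1) by CLT
```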
So, compared to SP, muP ensures that every transformer module's pre-activation doesn't blow up at any time during training as \(n\) tends to infinity. The figure below is called a coordinate check, where a coordinate means each element of a vector.
And now, one can accept that it ensures maximal feature learning for every layer without instability, and since all init stds and update quantities are invariant to the width \(n\), the optimal training behavior (optimal lr) will transfer.
Note, however, that not all parameters are the same:
an \(n \times n\) hidden matrix (we call this a matrix-like tensor) has two infinite dimensions,
but a vector-like tensor has one.
Fig. Vector-like vs Matrix-like. Source from Greg Yang's Blog
For example, in a transformer, the embedding and unembedding (lm head or readout) matrices have shape \(W_{emb} \in \mathbb{R}^{V \times n}\), where \(V\) is the vocab size,
and there is only one infinite dimension.
So they behave differently, which is why we have to apply a separate rule, and why the muP implementation table has three categories (hidden, embedding, unembedding).
Fig. OpenReview of TP-5. Greg Yang explains "muP is not only for HP transfer"
Again, to achieve "training stability", "maximal feature learning", and "non-triviality", we should keep the three desiderata below in mind.
Fig.
Then, we can derive the scaling rules for the embedding and unembedding matrices, and the bias terms, in a similar way.
Actually, what muP wants to say is simply as follows
Fig.
Actually, the TP authors define an abc-parameterization for muP, where a stands for the multiplier, b for the init std, and c for the per-parameter lr, and we should define a, b, c so that the model's behavior does not change as the width n goes to infinity.

- a: we parameterize each weight as \(W^{l} = n^{-a_l} w^l\) for the actual trainable parameter \(w^l\)
- b: we initialize each \(w^l \sim \mathcal{N}(0, n^{-2b_l})\)
- c: the SGD lr is \(\eta n^{-c}\) for some width-independent \(\eta\)

(Here, the multiplier does not exist in conventional initialization methods like LeCun init; it is a scaling factor applied after the linear transform.)
(For simplicity, I didn't mention abc-parameterization earlier, but it's crucial for further understanding muP.)
And the main question: "How do we correctly set the per-layer a, b, c so that every layer's activation does not blow up (training stability) and every layer is trained equally (maximal feature learning) as the Neural Network (NN)'s width goes to infinity?"
Fig.
And the mathematically derived answer to this question is TP-4 and TP-5, which ensure maximal feature learning and training stability in the infinite-width regime (like we roughly derived above; see the papers for a stricter and more beautiful mathematical derivation).
Fig. Maximal Update Parameterization Table
And because muP is well defined for any NN building block, it is also valid for Noam architecture where the model consists of Rotary Positional Embedding (RoPE), RMSNorm, and Gated Linear Unit (GLU).
Fig. my coordinate check on LLaMa-3 architecture.
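A coordinate check can be hacked together in a few lines; below is a rough sketch (not the official mup package, and not my LLaMA-3 setup), with a toy MLP and made-up widths, lrs, and step counts: train a few steps at several widths and watch whether the mean absolute activation stays roughly constant as width grows.

```python
import torch
from torch import nn

def coord_check(width: int, steps: int = 5, base_width: int = 256, base_lr: float = 1e-2):
    torch.manual_seed(0)
    d_in, d_out = 32, 1
    f_in = nn.Linear(d_in, width)                    # "embedding-like" input layer
    hidden = nn.Linear(width, width, bias=False)     # matrix-like hidden layer
    f_out = nn.Linear(width, d_out)                  # readout
    # muP-style per-layer settings (sketch): hidden std ~ 1/sqrt(width), hidden Adam lr ~ 1/width,
    # readout std ~ 1/width; input/output lrs kept width-independent.
    nn.init.normal_(hidden.weight, std=width ** -0.5)
    nn.init.normal_(f_out.weight, std=1.0 / width)
    opt = torch.optim.Adam([
        {"params": f_in.parameters(), "lr": base_lr},
        {"params": hidden.parameters(), "lr": base_lr * base_width / width},
        {"params": f_out.parameters(), "lr": base_lr},
    ])
    for _ in range(steps):
        x = torch.randn(64, d_in)
        h = torch.relu(hidden(torch.relu(f_in(x))))
        loss = (f_out(h) ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
    print(f"width={width:5d}  mean |hidden activation| after {steps} steps = {h.abs().mean():.3f}")

for w in (256, 1024, 4096):
    coord_check(w)
```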
Now, to further understand and implement muP with flexibility,
we're going to discuss abc-parameterization symmetry.
abc-parameterization symmetry means that if you adjust a, b, c (multiplier, std, base lr, respectively) properly,
the NN's forward and backward passes stay the same.
So, this is the reason why there are three different (but equivalent) tables in TP-5.
Fig.
But 'why do we need alternative forms to implement muP?'
This is because one might want to use a tied-embedding strategy to save memory or get better performance.
Fig.
Simply put, you can accept this rule like: "Oh, if the multiplier is scaled by a factor of \(\theta\), then the init std should also be reduced by \(\theta\). And since the model's init std becomes smaller, the magnitude of its updates should also be scaled down accordingly!"
Let's think about the l-th layer's weight parameter, \(W\).
\[\begin{aligned} & W^l = A W^l & \\ & W^l \sim \mathcal{N}(0, B) & \\ & \eta_{eff} = \eta C & \\ \end{aligned}\]It is easy to prove that the l-th layer's output (pre-activation) stays the same in the forward pass if we scale the multiplier by \(\theta\) and the init std by \(1/\theta\).
\[\begin{aligned} & A \leftarrow A\theta, B \leftarrow B/\theta, C \leftarrow C/\theta^2 & \\ & z_t^l = A \cdot W_t^l x, W_t^l \sim \mathcal{N}(0,B^2) & \\ & = (A \color{red}{\theta}) \cdot \frac{W_t}{\color{red}{\theta}} x & \\ \end{aligned}\]And for backward, if we update weight parameter using SGD (no momentum), it is also easy to implement.
\[\begin{aligned} & z_0^l = A \cdot W_0^l x, W_0^l \sim \mathcal{N}(0,B^2) & \\ & W_{1}^l = W_0^l - \underbrace{C \cdot \eta \cdot(\nabla_{W_0^l}L)}_{\text{SGD Update}} & \\ & = W_0^l - C \cdot \eta \cdot (A \frac{dL}{dz^l_0} x^T) & \\ & z_{1}^l = A \cdot W_{1}^l x' & \\ & = A \cdot (W_0^l - C \cdot \eta (A \frac{dL}{dz^l_0} x^T)) x' & \\ \end{aligned}\] \[\begin{aligned} & A^{\ast}=A\theta, B^{\ast}=B/\theta, C^{\ast}=C/\theta^2 & \\ & {z'}_0^l = A' \cdot W_0^{l\ast} x, W_0^{l\ast} \sim \mathcal{N}(0, {B'}^2) & \\ & = (A \color{red}{\theta}) \cdot \frac{W_0^l}{\color{red}{\theta}} x & \\ & = A \cdot W_0^l x & \\ & z_{1}^{l\ast} = A' \cdot W_{1}^{l\ast} x' & \\ & = A' \cdot (W_{0}^{l\ast} + C' \cdot \eta \cdot \nabla_{W_{0}^{l\ast}}L ) x' & \\ & = (A \color{red}{\theta}) \cdot ( \frac{W_0^l}{\color{red}{\theta}} - \frac{C}{\color{red}{\theta^2}} \cdot \eta \cdot ((A \theta ) \frac{dL}{dz^{l\ast}_0} x^T)) x' & \\ & = z_1^l & \\ \end{aligned}\]For Adam(W) optimizer, there is a difference that lr is scaled by \(1/\theta\), not \(1/\theta^2\). You can easily derive this too because adaptive optimizer already do âper parameter lr scalingâ.
\[\begin{aligned} & A \leftarrow A\theta, B \leftarrow B/\theta, C \leftarrow C/\theta & \\ \end{aligned}\]Adam update is as follows (adam is scale invariant). (you can check Tensor Programs IVb: Adaptive Optimization in the Infinite-Width Limit to further understanding.)
\[\begin{aligned} & z_0^l = A \cdot W_0^l x, W_0^l \sim \mathcal{N}(0,B^2) & \\ & G_0 = \nabla_{W_0^l} L = A \frac{dL}{dz_0^l} x^T & \\ & m_0 = \beta_1 m_{init} + (1-\beta_1) G_0 & \\ & v_0 = \beta_2 v_{init} + (1-\beta_2) (G_0)^2 & \\ & \hat{m_0} = m_0/(1-b_1^0) & \\ & \hat{v_0} = v_0/(1-b_2^0) & \\ & \Delta {W_0^l} = \underbrace{\frac{\hat{m_0}}{\sqrt{\hat{v_0} + \epsilon}}}_{\text{Adam Update}} & \\ & z_{1}^l = A \cdot W_{1}^l x' & \\ & = A \cdot (W_0^l - C \cdot \eta \Delta {W_0^l}) x' & \\ \end{aligned}\] \[\begin{aligned} & A^{\ast}=A\theta, B^{\ast}=B/\theta, C^{\ast}=C/\theta & \\ & {z}_0^{l\ast} = A' \cdot {W}_0^{l\ast} x, {W}_0^{l\ast} \sim \mathcal{N}(0,{B'}^2) & \\ & = (A \color{red}{\theta}) \cdot \frac{W_0^l}{\color{red}{\theta}} x & \\ & = A \cdot W_0^l x & \\ & G'_0 = \nabla_{W_0^{l'}} L = (A\color{red}{\theta}) \frac{dL}{dz_0^l} x^T & \\ & \Delta W_0^{l\ast} = \frac{\hat{m^{\ast}_0}}{\sqrt{\hat{v^{\ast}_0} + \epsilon}} = \frac{\theta \hat{m_0}}{\sqrt{\theta^2 \hat{v_0} + \epsilon}} & \\ & \approx \nabla_{W_0^l}L & \\ & z_1^{l\ast} = (A \color{red}{\theta}) \cdot (\frac{W_0^l}{\color{red}{\theta}} - \frac{C}{\color{red}{\theta}} \cdot \eta \Delta W_0^{l\ast}) x' & \\ & = z_1^l & \\ \end{aligned}\]# https://arxiv.org/abs/2310.02244
# https://x.com/thecharlieblake/status/1799029085827649930
import torch
from torch import manual_seed, nn, optim, randn
def get_second_fwd_output(mult, init_std, lr, opt):
    # Train a single linear layer for one step and return its forward output after the update.
    assert opt in [optim.Adam, optim.SGD]
    manual_seed(1234)
    l = nn.Linear(1024, 2048, bias=False)
    nn.init.normal_(l.weight, std=init_std)
    model = lambda x: l(x) * mult
    args = {'lr': lr, 'eps': 0} if opt == optim.Adam else {'lr': lr}
    opt = opt(l.parameters(), **args)
    x = randn(512, 1024).requires_grad_()
    y1 = model(x).mean()
    y1.backward(); opt.step()
    y2 = model(x).mean()
    print(y1, y2)
    return y2

def adjust_sgd(mult, init_std, lr, theta):
    # abc-symmetry under SGD: multiplier * theta, init std / theta, lr / theta^2
    return mult*theta, init_std*theta**-1, lr*theta**-2

def adjust_adam(mult, init_std, lr, theta):
    # abc-symmetry under Adam: multiplier * theta, init std / theta, lr / theta
    return mult*theta, init_std*theta**-1, lr*theta**-1

theta = 2.5
mult = 1; init_std = 0.02; lr = 1

assert torch.allclose(
    get_second_fwd_output(mult, init_std, lr, opt=optim.Adam),
    get_second_fwd_output(*adjust_adam(mult, init_std, lr, theta), opt=optim.Adam),
)
assert torch.allclose(
    get_second_fwd_output(mult, init_std, lr, opt=optim.SGD),
    get_second_fwd_output(*adjust_sgd(mult, init_std, lr, theta), opt=optim.SGD),
)
tensor(2.5274e-05, grad_fn=<MeanBackward0>) tensor(-35.6423, grad_fn=<MeanBackward0>)
tensor(2.5274e-05, grad_fn=<MeanBackward0>) tensor(-35.6423, grad_fn=<MeanBackward0>)
tensor(2.5274e-05, grad_fn=<MeanBackward0>) tensor(-0.0009, grad_fn=<MeanBackward0>)
tensor(2.5274e-05, grad_fn=<MeanBackward0>) tensor(-0.0009, grad_fn=<MeanBackward0>)
Now we know how to scale up model size: we know how to transfer optimal HPs from small-scale proxy experiments.
However, even though muP is theoretically well-defined,
it does not guarantee HP transfer across training tokens or batch size.
In particular, the fact that "muP does not ensure HP transfer across the training horizon" is not widely known.
Fig. It may not be a bug. Source from an interview with the OLMo team
Letâs consider increasing the training horizon.
For example, suppose we train a small-scale proxy model (40M parameters) with 10B tokens,
and we want to transfer the optimal lr to a 70B model trained on 10T tokens.
Will the optimal lr remain the same?
Maybe not.
Intuitively, the model will remain at its peak lr for longer if the training horizon is scaled up.
To counter this, we should decrease the lr if we want to transfer it.
This leads to a left-shift trend in the optimal lr curve even if you use muP.
Fig. left shift trends. Source from Scaling Optimal LR Across Token Horizons
Now consider bsz.
Itâs common to increase the bsz when training larger models to achieve better training efficiency (throughput).
(As long as you don't cross the critical bsz; we'll discuss that later.)
But increasing bsz reduces training steps.
So we can think:
"Hmm, gradients are more accurate but training steps are fewer, so let's raise the lr to compensate."
In other words, to counter the shortened training horizon,
we must increase the lr,
and this leads to a right-shift trend
to reach the same validation loss as smaller-batch training.
And the conventional bsz-lr scaling rule is the square-root rule: scale the lr by the square root of the bsz increase.
(you can see this post from Sadhika Malladi for more)
Fig. Source from Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler
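As a worked example of that square-root heuristic (a sketch; the base bsz and lr below are arbitrary numbers of mine):

```python
# Hypothetical base point: lr 3e-4 tuned at a 1M-token batch.
base_bsz, base_lr = 1 * 2**20, 3e-4

def sqrt_scaled_lr(bsz: int) -> float:
    """Conventional heuristic for Adam-style optimizers: lr scales like sqrt(bsz)."""
    return base_lr * (bsz / base_bsz) ** 0.5

for bsz_m in (1, 2, 4, 8, 16):
    print(f"bsz={bsz_m:2d}M tokens -> lr ~ {sqrt_scaled_lr(bsz_m * 2**20):.2e}")
```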
In the TP-5 paper, the authors show that lr can be transferred well across bsz,
Fig.
but you should not overlook that they use the "same training steps" for this experiment. That means they used more FLOPs for the increased-bsz setups, and they didn't need to care about the bsz vs. training-steps tradeoff.
Fig. Source from TP-5
So, if we want to scale both model size and training horizon (tokens, bsz),
we need to understand optimal lr scaling for all three dimensions.
muP doesn't tell us how to adjust the lr with respect to bsz or training tokens.
This table is primarily derived from Tensor Program (TP) 4 and 5.
It assumes that model size growth is based only on width (hidden size or embedding size), not depth (number of layers), and that you're using an adaptive optimizer like Adam.
Also, as discussed earlier, some scaling rules (like LR vs. bsz and training horizon) are based on my interpretation and findings from several papers.
This table is heavily inspired by "What to do to scale up?" from Simo Ryu.
hparams | embedding | hidden | residual_out | unembedding (readout) |
---|---|---|---|---|
init_std (b) | \(\sigma_\text{embed}\) | \(\sigma_\text{hidden} \cdot (\color{red}{\tilde{n}})^{-0.5}\) | \(\sigma_\text{res-out} \cdot (\color{red}{\tilde{n}})^{-0.5} \cdot (2 n_\text{layers})^{-0.5}\) | \(\sigma_\text{un-embed}\) |
multiplier (a) | \(\alpha_{\text{embed}} \cdot 1\) | \(\alpha_{\text{hidden}} \cdot 1\) | \(\alpha_{\text{res-out}} \cdot 1\) | \(\alpha_{\text{un-embed}} \cdot (\color{red}{\tilde{n}})^{-1}\) |
adamw lr (c) | \(\eta_{\text{embed}} \cdot (\color{green}{\tilde{b}})^{0.5} \cdot {(\color{blue}{\tilde{d}})^{\alpha_{\text{data}}}}\) | \(\eta_{\text{hidden}} \cdot (\color{red}{\tilde{n}})^{-1} \cdot (\color{green}{\tilde{b}})^{0.5} \cdot {(\color{blue}{\tilde{d}})^{\alpha_{\text{data}}}}\) | \(\eta_{\text{res-out}} \cdot (\color{red}{\tilde{n}})^{-1} \cdot (\color{green}{\tilde{b}})^{0.5} {(\color{blue}{\tilde{d}})^{\alpha_{\text{data}}}}\) | \(\eta_{\text{un-embed}} \cdot (\color{green}{\tilde{b}})^{0.5} {(\color{blue}{\tilde{d}})^{\alpha_{\text{data}}}}\) |
adamw moment | \((1-\color{green}{\tilde{b}}(1-\beta_1),\\1-\color{green}{\tilde{b}}(1-\beta_2))\) | \((1-\color{green}{\tilde{b}}(1-\beta_1),\\1-\color{green}{\tilde{b}}(1-\beta_2))\) | \((1-\color{green}{\tilde{b}}(1-\beta_1),\\1-\color{green}{\tilde{b}}(1-\beta_2))\) | \((1-\color{green}{\tilde{b}}(1-\beta_1),\\1-\color{green}{\tilde{b}}(1-\beta_2))\) |
adamw epsilon | \(\epsilon \cdot (\color{green}{\tilde{b}})^{-0.5}\) | \(\epsilon \cdot (\color{green}{\tilde{b}})^{-0.5}\) | \(\epsilon \cdot (\color{green}{\tilde{b}})^{-0.5}\) | \(\epsilon \cdot (\color{green}{\tilde{b}})^{-0.5}\) |
adamw weight_decay | \(\lambda\) | \(\lambda\) | \(\lambda\) | \(\lambda\) |
Fig. Table 8 from TP-V
Fig. Table 2 from unit-muP. it is based on Table 8 from TP-V but also reflects depth scaling from TP-VI (see residual branchâs multiplier)
- width: Width refers to the hidden size (or head dimension) of a neural network (e.g., in Transformers). For small-scale proxy (base) models, the shape of a specific layer's weight matrix is given by \(W_l \in \mathbb{R}^{\text{fan-in}_\text{base} \times \text{fan-in}_\text{base}}\). In TP-5, Tables 3, 8, and 9 describe parameterization in terms of fan_in and fan_out, corresponding to the input and output feature dimensions. In this table, we define \(\tilde{n} = \text{fan-in} \cdot \frac{1}{\text{fan-in}_\text{base}}\). If \(\text{fan-in}_\text{base} = 1\), it recovers Table 8. For example, if \(\sigma = 1/\sqrt{1024} \approx 0.031\) (with \(\text{fan-in}_\text{base} = 1024\)), then the hidden init std becomes \(1/\sqrt{\text{fan-in}}\), i.e., fan-in variance.
- multiplier: e.g. `x = hparam_multiplier × width_scaling_multiplier × embedding_layer(x)`, where `width_scaling_multiplier` remains constant as width increases, and `hparam_multiplier` refers to tunable HPs like the lr.
- residual branch's output layers
- attention logit scaling: `attn_multiplier`, but we typically set it to 1.0.
Fig. I also overlooked this note from the Attention Is All You Need paper. They designed the transformer modules from literally every point of view (stability, parallelizability, ...).
- across bsz
- across dataset size
Fig. contribution of increased dataset size in compute optimal setup
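To make the table concrete, here is a rough sketch of building AdamW param groups that follow its width-scaling part (the module-name heuristics, base values, and the omission of the output multiplier, depth, bsz, and data-horizon terms are all simplifications of mine, not a reference implementation):

```python
import torch
from torch import nn

def mup_param_groups(model: nn.Module, base_width: int, width: int,
                     base_lr: float = 1e-2, base_std: float = 0.02):
    """Rough sketch of per-parameter init/lr following the muP table (width scaling only)."""
    n_tilde = width / base_width                      # the ñ multiplier in the table
    groups = []
    for name, p in model.named_parameters():
        is_vector_like = ("embed" in name) or ("lm_head" in name) or p.ndim == 1
        if is_vector_like:
            # embedding / unembedding / biases: width-independent init std and lr
            if p.ndim > 1:
                nn.init.normal_(p, std=base_std)
            groups.append({"params": [p], "lr": base_lr})
        else:
            # hidden (matrix-like) weights: std scales like 1/sqrt(ñ), Adam lr like 1/ñ
            nn.init.normal_(p, std=base_std * n_tilde ** -0.5)
            groups.append({"params": [p], "lr": base_lr / n_tilde})
    return groups

# usage sketch, assuming `model` is a Transformer-like nn.Module with "embed"/"lm_head" in its names:
# opt = torch.optim.AdamW(mup_param_groups(model, base_width=256, width=4096), betas=(0.9, 0.95))
```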
What HPs should we tune?
- Zero-variance initialization
- `tensorflow adamw` or truly decoupled adamw: the default is 1e-4, because the PyTorch default multiplies the WD value by the lr; to decouple it, pass `weight_decay` as `weight_decay / group['lr']`.

For WD, it is noteworthy that the default PyTorch AdamW is actually coupled with the lr.
In my experiments with over 200B tokens and 8B+ model sizes,
I experienced loss spikes and eventual divergence.
This is because in large-scale NN training, the mu-transferred model's effective lr for the hidden matrices becomes very small compared to the small-scale proxy. And because WD is coupled with the lr, the discounted per-parameter lr for the hidden matrices makes the WD extremely small, and the parameter norm growth never recovers even as the lr keeps decreasing under the scheduler.
So I strongly recommend using truly independent weight decay, not the PyTorch default.
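One minimal way to get lr-independent weight decay with stock PyTorch AdamW is to rescale each group's weight_decay by its current lr before every step; this is only a sketch of that workaround (the 1e-4 decay rate is an arbitrary example), and a custom optimizer is an equally valid choice:

```python
import torch

def step_with_decoupled_wd(optimizer: torch.optim.AdamW, wd: float = 1e-4):
    """Apply a truly lr-independent weight decay with stock AdamW.

    PyTorch AdamW decays weights by lr * weight_decay, so we divide the desired
    per-step decay rate by each group's current lr (which may change under a scheduler).
    """
    for group in optimizer.param_groups:
        group["weight_decay"] = wd / group["lr"] if group["lr"] > 0 else 0.0
    optimizer.step()
```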
Although muP is theoretically well-defined and Greg Yang claims it's a unique solution,
it may offer only marginal performance improvements over other parameterizations or muP variants in practice.
In fact, even Standard Parameterization (SP) with a \(1/n\) lr scale can exhibit HP transferability.
Fig.
According to Small-scale proxies for large-scale Transformer training instabilities,
a muP variant called muP (simple)
adopts only the \(1/n\) LR scaling from muP while still using fan-in variance,
and yet it successfully transfers optimal lr across width.
Fig. Small-scale proxies for large-scale Transformer training instabilities
Moreover, GDM shows that every parameterization can admit HP transfer with the right adjustments,
and even reports that SP, with a novel per-layer lr schedule, outperforms muP in some cases.
(Though whether SP + per-layer LR can still be considered "SP" is debatable.)
Fig. Scaling Exponents Across Parameterizations and Optimizers
How did earlier papers approach parameterization?
Surprisingly, papers like Attention Is All You Need and the Pathways Language Model (PaLM) share a few common traits.
First, in the original Transformer paper, it is mentioned that the embedding matrix is multiplied by \(\sqrt{d_{\text{model}}}\).
However, there is no detailed explanation of the initialization method or the reasoning behind this scaling.
The only specific note is that the embedding and language modeling (LM) head weights are tied (i.e., shared).
Fig.
There are some discussions about this, but it is unclear to me. The argument goes like "the sinusoidal embedding's scale is bounded by (-1, 1), so the embedding matrix should be scaled by the square root of the hidden size", but I don't know what the init std is.
Fig.
In PaLM, the researchers use a different parameterization. (I call this a parameterization, not just an initialization, because they adopt Adafactor instead of Adam, allowing per-layer lr, which is quite similar in spirit to muP.) They keep tied embeddings, set the init std to 1, and scale the pre-softmax logits by \(1/\sqrt{n}\), where \(n\) is equivalent to \(d_{\text{model}}\).
Fig.
As the text states, LayerNorm modules normalize their inputs, which means the variance of the outputs is 1.
So it's reasonable to initialize the embedding matrix with standard deviation 1; this is consistent with muP.
However, the \(1/\sqrt{n}\) logit scaling does not match the \(1/n\) scaling used in muP,
but I believe Adafactor perhaps compensates for this mismatch by adapting the update scale.
Finally, in Gemma-2 and Gemma-3, we can check that GDM continues to try to apply proper parameterization.
Fig.
Assume you train a Mixture of Experts (MoE) model. What init std and effective lr should we use for the sparsely activated Feed-Forward Network (FFN)? Indeed, an MoE is a large FFN, but only part of the weight matrix is activated dynamically.
Sometimes we treat the MoE's FFN as very big, like num_experts × activated FFN size,
and sometimes as just the activated FFN size.
It's kind of complicated because, as far as I know, when we measure "how efficiently we train compared to dense", we use the latter,
but for fitting HP scaling laws or measuring the cbsz, we use the former.
Then what fan-in number should we use?
Also, the MoE consumes a different number of tokens than the attention module. In MoE training, the effective bsz differs between the attention and FFN modules because the MoE is sparsely activated according to the tokens. Assume the global batch size (gbsz) is \(N\); then the attention module will consume \(N\) tokens, but each FFN consumes \(\frac{N}{E}\times K \times 1.0\) if topk=K, num_experts=E, capacity=1.0, and Expert Parallelism (EP)=1. So the gradient noise scale of each module may be different. Then should we scale down the lr compared to attention? What per-layer lr, init std, and bsz should we use for the FFNs compared to a dense model? Can muP be compatible with MoE? (In my experiments, yes it can, but the init std of the FFN affects performance a lot.)
In Switch Transformers, they discount the init std, and Efficient Large Scale Language Modeling with Mixtures of Experts reduces the lr of the FFN for training stability. However, I don't know whether this is a good solution, especially reducing the lr: in Skywork-MoE, they reduce the lr for the FFN modules, and it seems to work, but the baseline (without adjusting the FFN lr) beats the lr-adjusted version in the end.
Why? I guess the adjusted-lr method can cause an imbalanced update quantity between the FFNs and attention, so it breaks maximal update.
And there is another problem: the MoE module's output (Root Mean Square (RMS) or L2 norm) is relatively small compared to the attention module, because the MoE module's output is typically multiplied by its gating score (so that the gate receives gradients during backpropagation).
\[y = \sum_i^E \underbrace{G_i(x)}_{\text{gating score} \in [0, 1]} \cdot FFN_i(x)\]So, DeepSeek tries to multiply by some factor to match the output scale with the attention module,
Fig. from DeepSeek-V2
and I guess this value depends on how many routed experts we use and which gating function we use (sigmoid or softmax).
Fig.
Finally, though muP seems to allow HP transfer well across various tasks and model architectures in the computer vision and language modeling domains,
Fig. From my experiments, muP can be applied to almost every architecture, including MoE, because everything is built on top of dot products and additions
Fig. muP on various tasks and architectures such as Diffusion Transformer (DiT), Diffusion Language Modeling (LLaDa) and MoE. all credits: cloneofsimo
Fig. muP works successfully on Mamba architectures. Source
I'd like to say there is still much room for work on advanced architectures like MoE (even though MoE for deep neural networks was popularized in 2017 by Noam Shazeer) or other newly proposed models. (From Simo's work and my own experience, MoE seems compatible with muP, but in my experience MoE is a bit sensitive to the init std and the output scale of the routed experts at larger scales.)
+Updated) Jingyuan Liu, the first author of Muon is Scalable for LLM Training, mentioned that it is very important to match the Root Mean Square (RMS) scale of each module's output (especially the routed experts and the shared expert) to prevent expert collapse (see the appendix of the Moonlight paper).
Fig. Source tweet
Fig. Source tweet
As discussed above, even muP does not guarantee hyperparameter (HP) transfer across training tokens or bsz.
To the best of my knowledge, there is no theory that ensures an optimal lr scaling rule for both the training horizon and bsz.
It's very complicated to predict, because all the factors (bsz, number of tokens, the adaptive optimizer's HPs, ...) are correlated.
So, even though relying on empirical scaling laws may not feel mathematically beautiful,
fitting power laws for HPs like lr or bsz given the compute budget or training horizon seems reasonable in practice,
because the scaling law is a universal behavior.
To the best of my knowledge, DeepSeek was the first to empirically fit HP scaling laws for large-scale neural networks.
In DeepSeek LLM (DS-V1), they fit lr and bsz with respect to the compute budget \(C\):
How can we find this Scaling Law?
Source | attn type | FLOPs per token formula | calculated (FLOPs per â?â token) | changes | C vs LR | C vs Bsz (tokens per iter) |
---|---|---|---|---|---|---|
Kaplan et al | MHA | \(6N = 72 n_{layer} d_{model}^2\) | | baseline | | |
Hoffmann et al | MHA | \(72 n_{layer} d_{model}^2 + \underbrace{6 n_{vocab} d_{model}}_{logit}\) | | +logit | | |
DeepSeek V1 | MHA | \(72 n_{layer} d_{model}^2 + \underbrace{12 n_{layer} d_{model} l_{seq}}_{SDPA}\) | 4.3T per 2K tokens | +attn but only non-embedding | \(0.3318 C^{-0.1250}\) | \(0.2920 C^{0.3271}\) |
DeepSeek-MoE 2B | MHA | (maybe same as DSV1) | 4.3T per 2K tokens | | | |
DeepSeek-MoE 16B | MHA | (maybe same as DSV1) | 74.4T per 4K tokens | | | |
Megatron-LM | MHA | \(\begin{aligned} & 72 n_{layer} d_{model}^2(1+\frac{l_{seq}}{6d_{model}}+\frac{n_{vocab}}{12 d_{model} l_{seq}}) & \\ & = 72 n_{layer} d_{model}^2 + \underbrace{12 n_{layer} d_{model} l_{seq}}_{SDPA} + \underbrace{6 n_{vocab} d_{model}}_{logit} & \\ \end{aligned}\) | | +attn +logit | | |
Minimax | MHA | \(72 n_{layer} d_{model}^2(1+\frac{l_{seq}}{6d_{model}}+\frac{5}{18{d_{model}}})\) | | | | |
 | Linear (Lightning) | \(72 n_{layer} d_{model}^2(1+\frac{1}{2 n_{head}}+\frac{5}{18{d_{model}}})\) | | | | |
 | Hybrid Linear (Lightning) | \(72 n_{layer} d_{model}^2(1+\frac{l_{seq}}{48d_{model}}+\frac{7}{16 n_{head}}+\frac{5}{18{d_{model}}})\) | | | | |
Moonlight | MLA | \(6N\) | | | \(0.0127 C^{-0.057}\) | \(0.0065 C^{0.4138}\) |
Stepfun | GQA | \(6N\) | | | \(1.79 N^{-0.713} D^{0.307}\) | \(0.58 D^{0.571}\) |
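As a usage sketch of the DeepSeek-V1 row above (the coefficients are copied from the table; the 7B/2T example and the use of \(C=6ND\) as a rough stand-in for DSV1's own FLOPs/token formula are my simplifications):

```python
def dsv1_optimal_hparams(C: float):
    """DeepSeek-V1 fitted HP scaling laws: optimal lr and bsz (tokens/step) vs. compute C (FLOPs)."""
    lr = 0.3318 * C ** -0.1250
    bsz = 0.2920 * C ** 0.3271
    return lr, bsz

# e.g. a hypothetical 7B model (N ~ 7e9 params) trained on 2T tokens, using C = 6*N*D as an estimate
C = 6 * 7e9 * 2e12
lr, bsz = dsv1_optimal_hparams(C)
print(f"C = {C:.2e} FLOPs -> lr ~ {lr:.2e}, bsz ~ {bsz:,.0f} tokens")
```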
If youâre unfamiliar with how FLOPs/token is computed, check Kaplan et al. (2020).
Briefly, for a matrix multiply with \(X \in \mathbb{R}^{m \times n}, W \in \mathbb{R}^{n \times k}\), the forward pass costs \(2mnk\) FLOPs (fused multiply-add; FMA), and the backward pass roughly doubles this.
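For example, here is a minimal sketch of that FLOPs-per-token bookkeeping, combining the dense, SDPA, and logit terms from the table above (the LLaMA-7B-ish config numbers are only illustrative):

```python
def flops_per_token(n_layer: int, d_model: int, l_seq: int = 0, n_vocab: int = 0) -> float:
    """Kaplan-style 72 * n_layer * d_model^2, plus optional SDPA and logit terms."""
    dense = 72 * n_layer * d_model ** 2
    sdpa = 12 * n_layer * d_model * l_seq        # quadratic-in-seqlen attention term
    logit = 6 * n_vocab * d_model                # unembedding / logit term
    return dense + sdpa + logit

# hypothetical LLaMA-7B-ish config
print(f"{flops_per_token(32, 4096, l_seq=4096, n_vocab=32000):.3e} FLOPs per token")
```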
Note that Kaplan et al. considered that attention didn't contribute much to C back in 2020 (if I remember correctly),
but in 2025, with long contexts, the attention term becomes significant because the quadratic sequence-length term is no longer negligible.
Fig. Source from Transformer FLOPs by Adam Casson
Even so, for GPT-4++-scale models (trillions of parameters), the attention cost might again be negligible, but I'm not sure what formula they used.
Fig. Source from Transformer FLOPs by Adam Casson
Among these, recently proposed Stepfun Law is especially interesting.
Finally, it is noteworthy that if you want to apply these scaling laws, you must match the training setup:
the same lr scheduler and the same FLOPs/token formula.
Otherwise, the predicted optimal values may diverge significantly.
Fig. Stepfun Law shows how HP landscape shifts based on min lr and lr schedule.
Someone might think, "wtf is an Optimal Batch Size?".
Indeed, bsz should not be a tunable parameter for validation performance if you properly set the optimizer parameters like lr and Adam beta1, beta2, and eps.
It should only be a factor for training efficiency (throughput).
Fig. Source from Deep Learning Tuning Playbook
But bsz can't be scaled up forever, because there is a certain point called the Critical Batch Size (cbsz).
cbsz simply means "the point at which more parallelization (i.e., increasing bsz) becomes ineffective, because convergence still requires a minimum number of parameter updates (e.g., at least 1,000), since the gradient signal no longer improves."
(Of course, you can check the original paper An Empirical Model of Large-Batch Training, but for simple intuition, here's my thread about cbsz.)
Fig. Less noisy gradient estimates allow SGD-type optimizers to take larger steps, leading to convergence in a smaller number of iterations.
Fig. After a certain point, the compute cost of reaching the same validation performance starts to increase. Before that point, if you use 2x GPUs, the bsz is doubled and the training time is halved, so the compute cost remains the same.
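For intuition, the cbsz in An Empirical Model of Large-Batch Training is tied to the gradient noise scale \(B_{\text{simple}} = \mathrm{tr}(\Sigma)/\lVert G\rVert^2\); below is a rough sketch of the paper's two-batch-size estimator as I understand it (the gradient tensors and batch sizes are placeholders you would measure yourself, ideally averaged over many steps):

```python
import torch

def noise_scale_estimate(g_small: torch.Tensor, g_big: torch.Tensor,
                         b_small: int, b_big: int) -> float:
    """Estimate B_simple = tr(Sigma) / |G|^2 from gradients measured at two batch sizes.

    g_small / g_big are flattened gradient estimates from batches of size b_small / b_big.
    A single pair of measurements is very noisy; average the two intermediate quantities
    over many steps in practice.
    """
    s_small = g_small.pow(2).sum().item()
    s_big = g_big.pow(2).sum().item()
    g_sq = (b_big * s_big - b_small * s_small) / (b_big - b_small)       # estimate of |G|^2
    tr_sigma = (s_small - s_big) / (1.0 / b_small - 1.0 / b_big)         # estimate of tr(Sigma)
    return tr_sigma / max(g_sq, 1e-12)
```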
So, as long as you don't exceed the cbsz, model performance should not depend on bsz, and thus an "optimal bsz" shouldn't really exist.
Additionally, when considering the generalization capacity of neural networks, things become a bit more complicated.
Fig. Source tweet
So in most cases, it's safe to use a small enough bsz unless it hurts MFU, because the cbsz exists (of course, you should tune the optimizer HPs well). (In the Large Language Model (LLM) era, the concept of overfitting barely exists: we rarely apply dropout or similar regularization, and the generalization gap is often not a major concern.)
However, in real-world settings, there does exist an "optimal bsz".
Fig. MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies
Why?
I guess it's because in real-world settings, we use stochastic optimizers, where the model is trained using sampled mini-batches, and we usually keep HPs like Adam(W)'s betas, epsilon, and weight decay fixed across different compute budgets, tuning only the lr.
In other words, if all HPs were tuned jointly (including Adam(W)âs betas, epsilon, weight decay, and lr),
then the idea of an "optimal bsz" might be meaningless (see Simo's note),
But in practice, as long as we only tune the lr while leaving the rest fixed,
the notion of an optimal bsz still holds some practical value.
Anyway, we can fit scaling laws for HPs.
Below are actual lr scheduler examples derived from each paperâs HP scaling law.
The scheduler includes both the estimated peak lr and training steps (which relate to bsz),
so itâs helpful to compare how different their scaling laws are.
I tested all these methods using the DSV1 paperâs 7B and 67B model configurations with 2T tokens.
Most methods worked well, except for the StepFun law.
According to StepFun, the model should be trained for much longer with a 2-3x larger peak lr, which doesn't make sense to me.
It makes sense that Moonlightâs peak LR is larger than DSV1âs,
because Moonlight used a cosine lr scheduler.
I also found that the rough estimate \(C = 6ND = (72 \cdot n_{\text{layer}} \cdot d_{\text{model}}^2) \cdot D\) (Kaplan-style compute formula) and DSV1's FLOPs per token do not match modern LLMs well, especially wide and shallow models like Gemma-2 or MoE models.
For MoE models, I assume \(N\) to be the total number of parameters, not just activated ones.
Thatâs because MoEs are essentially sparsely-activated versions of large dense models.
You can think of a 16B-parameter MoE with 2.3B activated as a 16B dense model where most weights are zero.
This is still valid even if you trained the full 16B model.
However, in this MoE model, the attention modules contribute very few parameters compared to FFNs.
Thatâs because MoEs typically scale up the FFN part.
For context: the FFN/attention parameter ratio in models like LLaMA is around 3-4,
but Moonlight's model has a ratio of 44,
and DSV1's FLOPs-per-token calculation doesn't account for this at all.
(I'm not sure what scaling law they use for MoE models.)
So, if you want to use any existing HP scaling law,
you should at least match the original setup they were derived from (e.g., wide/shallow ratio, LR scheduler, etc.).
Otherwise, you should fit your own law.
Originally, cbsz does not depend on model size, but rather on achievable loss.
So it makes sense to double the bsz once the model reaches a certain loss threshold in order to train faster.
This is called bsz warmup (or ramp-up).
In fact, MiniMax-01 was trained using this strategy.
They fit the cbsz scaling law, \(\text{cbsz} = f(L)\), and double the bsz following the fitted power law.
Fig. MiniMax-01: Scaling Foundation Models with Lightning Attention
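A toy sketch of such a ramp-up policy (the power-law coefficients, loss values, and base bsz below are entirely made up; MiniMax-01's actual fit is in the paper):

```python
def critical_bsz(loss: float, a: float = 1e9, alpha: float = 4.0) -> float:
    """Hypothetical fitted law cbsz = a * loss^(-alpha), in tokens (illustrative numbers only)."""
    return a * loss ** -alpha

def ramped_bsz(loss: float, base_bsz: int = 4 * 2**20) -> int:
    """Double the batch size whenever the current loss's cbsz allows it."""
    bsz = base_bsz
    while 2 * bsz <= critical_bsz(loss):
        bsz *= 2
    return bsz

for loss in (3.0, 2.5, 2.2, 1.9):
    print(f"loss={loss:.1f} -> bsz ~ {ramped_bsz(loss) / 2**20:.0f}M tokens")
```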
However, achievable loss is also related to the compute budget \(C\), and \(C\) depends on both model size \(N\) and dataset size \(D\).
So one might intuitively think:
"Oh, if the model size increases, the loss decreases (i.e., larger models are more sample-efficient),
so we can use a larger bsz for faster training."
But that's not true.
It has been shown that cbsz barely depends on model size, and that training tokens contribute much more.
Fig. cbsz scaling behavior from How Does Critical Batch Size Scale in Pre-training?
(If you want to dive deeper, check out the full paper.)
It is noteworthy that cbsz is also sensitive to the optimizer.
Fig. Source tweet
When we use Adam(W) for all experiments, the cbsz is usually determined by data quality and the training horizon. But if you use a better optimizer, e.g. a second-order-style optimizer like MomentUm Orthogonalized by Newton-Schulz (Muon), its cbsz is much larger, allowing us to train large transformers more efficiently. Here, "more efficiently" assumes an infinite-GPU scenario. For example, suppose that when using Adam(W) we can't scale the bsz beyond 10M tokens. Then even with 50,000 GPUs, we can't finish training faster, because we can't update the model parameters with 20M or 30M tokens per step. But if your cbsz threshold is 20M, you can reach the same validation loss twice as fast as a 10M Adam(W) baseline.
Fig. Practical Efficiency of Muon for Pretraining
In Dion paper, they show that their new optimizer can be scaled even better than Muon for large-batch training.
Fig. Dion: A Communication-Efficient Optimizer for Large Models
Here are some reference bsz and lr used for LLM pre-training.
Note that I excluded papers that donât explore or discuss scaling rules (at least in the paper).
For example, LLaMA 1, 2, and 3 all use the same bsz and lr, even though the number of training tokens varies from 2T to 15T,
and to me, that doesn't make much sense.
So I decided to exclude them from this list.
model | model type | activated param | total param | num. training tokens | bsz (tokens) | seqlen | lr | init std | optim | method used for predicting bsz, lr | notes |
---|---|---|---|---|---|---|---|---|---|---|---|
GPT-4 (Leaked, not sure) | MoE | 280B (E16K2S0) | 1.8T | 13T | 60M | 8K | ? | ? | ? | ? | (maybe muP is used) |
DeepSeek-V1 7B | Dense | 7B | 7B | 2T | 9.4M | 4K | peak 4.2x10-4 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | Multi Head Attention (MHA) |
DeepSeek-V1 67B | Dense | 67B | 67B | 2T | 18.9M | 4K | peak 3.2x10-4 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | Group Query Attention (GQA) |
DeepSeek-MoE 2B | MoE | 0.24B (E64K6S2) | 1.89B | 2T | 4M | 2K | peak 1.08x10-3 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | MHA |
DeepSeek-MoE 16B | MoE | 2.8B (E64K6S2) | 16B | 2T | 18.4M | 4K | peak 4.2x10-4 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | MHA |
DeepSeek-V2 Lite 16B | MoE | 2.4B (E64K6S2) | 15.7B | 5.7T | 18.9M | 4K | peak 4.2x10-4 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | Multi Head Latent Attention (MLA) |
DeepSeek-V2 | MoE | 21B (E160K6S2) | 236B | 8.1T | 9.4M (init) -> 37.7M (at 225B) | 4K | peak 2.4x10-4 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | MLA |
DeepSeek-V3 | MoE | 37B (E256K8S1) | 671B | 14.8T | 12.6M (init) -> 62.9M (at 469B) | 4K | peak 2.2x10-4 (Warmup Constant) | 0.006 (SP) | adamw (0.9,0.95,0.1) | Scaling Law (from DSV1) | MLA, Multi Token Prediction (MTP) |
MiniMax | MoE | 45.9B (E32K2S0) | 456B | 11.4T | 16M -> 32M (at 69B) -> 64M (at 790B) -> 128M (at 4.7T) | 8K | peak 2.4x10-4 (Warmup Stable Decay) | ? | adamw | Scaling Law for cbsz (it could be not optimal but efficient) | Hybrid Linear +Softmax Attention |
Moonlight | MoE | 2.24B | 15.29B | 5.7T | 16.7M (33B) -> 33.5M (5.2T) | 8K | peak 4.2x10-4 (Warmup Cosine) | ? | adamw and muon | Scaling Law (from DSV1) | DSV3 style |
MiniCPM 1.2B | Dense | 1.2B | 1.2B | 1.1T | 2M -> 4M | 4K | peak 0.01 (mu-transferred, Warmup Stable Decay) | 0.1 (muP) | ? | muP and Scaling Law for bsz | |
MiniCPM 2.4B | Dense | 2.4B | 2.4B | 1.1T | 4M | 4K | peak 0.01 (mu-transferred, Warmup Stable Decay) | 0.1 (muP) | ? | muP and Scaling Law for bsz | |
Granite 3.0 | MoE | 800M | 3B | 10T | 4M | 4K | peak 0.02 (mu-transferred, Power Scheduler) | ? | ? | muP | |
Hunyuan-Large | MoE | 52B (E16K1S1) | 389B | 7T | ? | ? | ? | ? | ? | ? | Cross Linear Attention (CLA) |
Recently, many researchers have started placing LayerNorms everywhere, including:
Fig. Source from Methods of improving LLM training stability
They seem to work because they keep the range of pre-activations and gradients within a reasonable bound. (However, they may still fail to address extremely large activations as expected. For example, in Gemma-3, residual post-norm and softcapping fail to suppress excessively large activation norms)
One of the earliest and most prominent "put-everywhere norm" techniques, QK-LayerNorm, comes from Scaling Vision Transformers to 22 Billion Parameters, where it works well.
The key observation of this paper is that the L2 norm of the logits before softmax (before Scaled Dot Product Attention (SDPA) and the output head) is critical to preventing loss divergence.
So they normalize some tensors to ensure the logits don't blow up, and it worked.
This idea was further studied in Small-scale Proxies for Large-scale Transformer Training Instabilities,
where the authors report that QK-LayerNorm enlarges the lr basin.
This means lr sensitivity is reduced, making it easier to find near-optimal performance across a range of model widths.
And, this is compatible with muP. Of course, some may ask:
"Why should we use QK-LayerNorm when muP already enables hyperparameter transfer? And it could reduce throughput due to increased kernel launch overhead."
That's a valid point, but as we discussed earlier, muP doesn't guarantee HP transfer across training horizon or bsz.
So using QK-LayerNorm can be a good practical choice because it can increase the probability of choosing near-optimal lr.
In addition, it has been shown to outperform the vanilla Transformer even in settings where the HPs are optimal.
And it seems that all these "put-everywhere norms" empirically enlarge the lr basin.
Fig. Source from Methods of improving LLM training stability
+Updated) However, even QK-LayerNorm can be problematic, as it may reduce the model's ability to handle long contexts by suppressing the sharpness of the attention logits. This means the model may fail to assign sufficiently high probabilities to target tokens, especially as the number of classes increases. (We might need to consider proper parameterization or architectural improvements instead of relying on naive QK-LayerNorm.)
Fig.
(This subsection might be slightly outdated as it is based on my personal notes from early 2024.)
- Use bfloat16 (bf16) instead of float16 (fp16): bf16 covers (nearly) the same dynamic range as float32 (fp32) and doesn't require dynamic loss scaling (no overhead). Use fp8 if you have state-of-the-art accelerators and no skill issues.
- Consider using Maximal Update Parameterization (muP) instead of Standard Parameterization (SP).
- "Don't get too caught up in finding completely optimal HPs". IMO, it's okay to use near-optimal HPs, so you'd better spend your resources on data quality.

Let me talk a little about why we should care about MFU.
It is noteworthy that one of the most important things when scaling up is "whether your method (new arch, new learning algorithm) is scalable or not".
If your method improves convergence speed by 10%,
but reduces throughput by 20%, is it really an improvement?
Fig. How to Scale Your Model from Jacob et al.
There are some examples like:
A good training algorithm (and infra design) should scale well with both model size and hardware.
It is well known that Adam(W) with beta1, beta2, eps = (0.9, 0.95, 1e-8) works well. But it might not be optimal. I didn't explore this topic much, but How Does Critical Batch Size Scale in Pre-training? presents some analysis, so I recommend reading that paper.
Fig. Adam beta 1, 2 ablation from How Does Critical Batch Size Scale in Pre-training?
Fig. Adam beta 1, 2 ablation from How Does Critical Batch Size Scale in Pre-training?
It's noteworthy that some open-source framework model configs hardcode std=0.02, a value inherited from GPT-2 (up to 1.5B scale).
However, it's not suitable for larger models like 30B, 60B, etc., because 0.02 roughly corresponds to \(\sqrt{1/1536}\), the standard deviation derived from SP's fan-in variance.
It's also questionable to use the same 0.02 for FFN modules, where the inner dimension is typically 3-4x larger than the attention embedding dimension.
>>> for d in range(512,8192+512,512):
... print(f"d_model (width): {d}, 1/sqrt(width): {1/math.sqrt(d):.4f}")
...
d_model (width): 512, 1/sqrt(width): 0.0442
d_model (width): 1024, 1/sqrt(width): 0.0312
d_model (width): 1536, 1/sqrt(width): 0.0255
d_model (width): 2048, 1/sqrt(width): 0.0221
d_model (width): 2560, 1/sqrt(width): 0.0198
d_model (width): 3072, 1/sqrt(width): 0.0180
d_model (width): 3584, 1/sqrt(width): 0.0167
d_model (width): 4096, 1/sqrt(width): 0.0156
d_model (width): 4608, 1/sqrt(width): 0.0147
d_model (width): 5120, 1/sqrt(width): 0.0140
d_model (width): 5632, 1/sqrt(width): 0.0133
d_model (width): 6144, 1/sqrt(width): 0.0128
d_model (width): 6656, 1/sqrt(width): 0.0123
d_model (width): 7168, 1/sqrt(width): 0.0118
d_model (width): 7680, 1/sqrt(width): 0.0114
d_model (width): 8192, 1/sqrt(width): 0.0110
But you can still train your model successfully without deep mathematical justification.
According to Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B,
they use lower init std than SP, like \(1/\sqrt{3 \cdot \text{fan-in}}\).
And DeepSeek consistently uses 0.006 as the init std from v1 (7B dense) to v3 (671B total MoE).
I honestly don't know exactly why this works so well.
My guess is that it's due to the combination of many normalization modules, residual connections, and adaptive optimizers, but who really knows?
This is inspired by Rethinking Conventional Wisdom in Machine Learning: From Generalization to Scaling. In the conventional regime, a small bsz and an lr at the Edge of Stability (EoS) are known to be good choices, but this does not seem to hold in the large-scale training and overparameterized regime. So I'd like to recommend doubting conventional wisdom when scaling up NNs.
Fig. lr at EoS is not optimal in Large Scale Regime
Fig. smaller batch size is not optimal in Large Scale Regime
Is pre-training scaling done?
I guess frontier labs have already used up the entire internet.
However, I guess there is still room for scaling.
The Transformer was a revolution-level improvement, but it seems not to be enough today. I believe we can keep pushing the limit by improving architectures and optimizers.
Fig. from my past slide on Scaling Law and MoE
Fig. from my past slide on Scaling Law and MoE
Fig. from my past slide on Scaling Law and MoE
Western frontier labs have been studying MoE for a long time,
and Chinese labs are following the trend.
Recently, State Space Models (SSMs), hybrid attention, and sparse attention models have been actively studied,
and we should know how to scale these models (of course, these models also consist of a bunch of matmuls, but still).
Fig. from Jeff Deanâs Talk
Fig. from Jeff Deanâs Talk
There is not much public discussion on how to properly parameterize them for scaling.
Fig. Muon vs Adam. Muon keeps outperforming Adam, but we don't know how it can be well integrated with muP or similar. Source from Muon is Scalable for LLM Training.
Fig. Dualized training not only outperforms muP but also transfers lr. Source from Jeremy Bernstein's Blog
I'd like to thank everyone for taking the time to read this post and for sharing insightful feedback. Special thanks to Simo Ryu for many thoughtful discussions, Jingyuan Liu for sharing the reasoning behind MoE scaling factors, and Alexandre for the experimental results on Mamba with muP. I'd also like to thank Seonghyeon Kim, Charlie Blake, and Donghoon Ham for their warm encouragement and helpful feedback.