LYNX: ENABLING EFFICIENT MOE INFERENCE THROUGH DYNAMIC BATCH-AWARE EXPERT SELECTION
Mike’s Daily Deep Learning Paper Review - 05.03.25
I noticed that it's been a while since I reviewed a paper on MoE - Mixture of Experts in language models. Recall that MoE is a method designed to reduce the computational load of inference (i.e., fewer calculations per token). The model is trained to activate only part of the network (specific experts) for each token, where each expert is (usually) a sub-network of the FFN within the transformer block (in practice, usually a sub-matrix of the FFN's weight matrices). This reduces the amount of computation per token, which makes it possible to serve LLMs of enormous size, since only part of the model is active at any given time. In addition (according to several studies), this approach can learn "more complex functions", because each token may be processed differently (by a different subset of experts).
The experts are selected by a routing network, which is trained to compute a non-negative score for each expert. The scores act as "probabilities" of selecting each expert (there is a softmax at the end). Usually, in each layer, the k experts with the highest scores are selected for each token out of N experts, where k < N. The model is trained to balance expert utilization, with the goal that each expert is used roughly equally across the training dataset (at the aggregate level). Usually there is also a regularization term on the routing network's outputs, for example negative entropy or a sum-of-squares penalty.
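To make the routing mechanics concrete, here is a minimal sketch of standard top-k MoE routing (a generic illustration, not the paper's or any specific model's implementation); the names `router`, `experts`, and `k` are assumptions of mine:

```python
# Minimal sketch of standard top-k MoE routing (generic illustration).
import torch
import torch.nn.functional as F

def topk_moe_forward(x, router, experts, k=2):
    """x: [tokens, d_model]; router: nn.Linear(d_model, n_experts);
    experts: list of FFN sub-networks, one per expert."""
    logits = router(x)                            # [tokens, n_experts]
    probs = F.softmax(logits, dim=-1)             # non-negative scores that sum to 1
    topk_probs, topk_idx = probs.topk(k, dim=-1)  # keep only the k best experts per token
    topk_probs = topk_probs / topk_probs.sum(-1, keepdim=True)  # renormalize over the kept experts

    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        token_ids, slot_ids = (topk_idx == e).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue                              # this expert is idle for the whole batch
        # Each selected token is processed by expert e, weighted by its routing score.
        out[token_ids] += topk_probs[token_ids, slot_ids].unsqueeze(-1) * expert(x[token_ids])
    return out
```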
The paper proposes a method for optimizing memory consumption during inference of MoE transformer models when they serve batches of queries (several inputs at once). The proposed approach is based on several empirical observations made by the authors:
- The distribution of expert-activation frequency within a batch is not uniform: some experts are activated much more often than others.
- The computational density (arithmetic intensity), i.e., the ratio between the number of FLOPs and the amount of memory accessed, decreases as the number of activated experts grows during the decode phase (i.e., token-by-token generation). This makes the phase memory-bound and increases latency (a back-of-the-envelope sketch follows this list).
- Tokens are not very sensitive to which experts they use beyond the few highest-scoring ones (from the top-k). That is, it is possible to activate only the highest-scoring experts without significant damage to performance.
- Not all tokens are equal: some tokens are more sensitive to losing some of their experts and some are less. The authors claim that a token's sensitivity can be inferred from the routing network's scores for it.
- The prefill phase (prompt processing) is more sensitive to expert substitution than the decode phase (generation).
- Sensitivity to expert substitution also varies across the model's layers, with the middle layers being the most sensitive.
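To give a feel for the arithmetic-intensity observation, here is a back-of-the-envelope estimate for one MoE layer during decode. The dimensions, batch size, experts-per-token, and bytes-per-parameter are illustrative assumptions of mine, not numbers from the paper:

```python
# Rough FLOPs-per-byte estimate for one MoE layer in the decode phase.
# All concrete numbers below are illustrative assumptions.
def decode_arithmetic_intensity(batch_tokens, experts_touched,
                                d_model=4096, d_ff=14336,
                                k=2, bytes_per_param=2):
    # FLOPs scale with the number of tokens (each runs k experts),
    # while weight traffic scales with how many *distinct* experts
    # the batch touches (each must be read from memory once per step).
    flops_per_token_per_expert = 2 * 2 * d_model * d_ff   # two projections, 2 FLOPs per MAC
    total_flops = batch_tokens * k * flops_per_token_per_expert
    weight_bytes = experts_touched * 2 * d_model * d_ff * bytes_per_param
    return total_flops / weight_bytes

# The more distinct experts a batch activates, the lower the FLOPs per byte:
print(decode_arithmetic_intensity(batch_tokens=32, experts_touched=8))   # 8.0
print(decode_arithmetic_intensity(batch_tokens=32, experts_touched=64))  # 1.0
```

With few distinct experts, each loaded expert is reused across many tokens; when the batch's expert requests are spread thinly, every loaded expert serves only a handful of tokens, which is exactly the memory-bound regime the paper targets.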
The authors propose to take advantage of these observations in the following way (there are several variations; I will describe the main one):
- All experts are used in the prefill phase (which is compute-bound).
- Sensitive tokens (low router confidence) and less sensitive tokens (high router confidence) are identified in the batch. The expert requests of the less sensitive tokens are then filtered.
- The experts that are most used across the batch are kept, and the rest are filtered out.
- Only the remaining experts are activated for all tokens (via top-k over the kept set). A second option (less damaging to quality) is to activate all experts for the sensitive tokens and only the remaining experts for the less sensitive ones.
This method increases the arithmetic intensity of the decode phase and makes it less memory-bound, without significant damage to model quality (a rough code sketch of the selection logic follows).
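Here is a rough, hypothetical sketch of how the batch-level selection could look for one decode step; the confidence proxy, thresholds, expert budget, and variable names are my assumptions, not the paper's code:

```python
# Hypothetical sketch of batch-aware expert filtering for one MoE layer
# during decode; thresholds and the confidence proxy are assumptions.
import torch
import torch.nn.functional as F

def select_batch_experts(router_logits, k=2, conf_threshold=0.5, expert_budget=4):
    """router_logits: [batch_tokens, n_experts] for one layer at one decode step."""
    n_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_idx = probs.topk(k, dim=-1)

    # Sensitivity proxy (assumption): a peaked router distribution
    # (high top-1 probability) marks a token as less sensitive to pruning.
    less_sensitive = topk_probs[:, 0] > conf_threshold

    # Drop the extra expert requests of less sensitive tokens (keep only their
    # top-1), then count how often each expert is still requested in the batch.
    requests = torch.cat([topk_idx[~less_sensitive].flatten(),
                          topk_idx[less_sensitive, 0]])
    counts = torch.bincount(requests, minlength=n_experts)

    # Keep only the most-used experts for this step; the rest are not loaded
    # at all, which is what reduces the memory traffic.
    kept = counts.topk(min(expert_budget, n_experts)).indices

    # Main variant: every token is re-routed to its best k experts within the
    # kept set (the gentler variant keeps all experts for sensitive tokens).
    allowed = torch.zeros(n_experts, dtype=torch.bool).index_fill_(0, kept, True)
    masked = probs.masked_fill(~allowed, float('-inf'))
    _, new_topk_idx = masked.topk(k, dim=-1)
    return kept, new_topk_idx
```

The key point is that only the kept experts' weights need to be read for this decode step, which is how the phase becomes less memory-bound.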
https://arxiv.org/abs/2411.08982