
Maixtchup: Make Your Own Mixture of Experts with Mergekit


The rise of the MoEs

Image by the author, generated with DALL-E

Since Mistral AI released Mixtral-8x7B, there has been renewed interest in mixture-of-experts (MoE) models. This architecture relies on expert sub-networks, of which only a few are selected and activated by a router network during inference.
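To make the routing idea concrete, here is a minimal sketch in plain Python. It assumes a top-k router (k=2 here) that keeps the highest-scoring experts and normalizes their scores with a softmax. The toy experts, the logits, and the function names are all hypothetical and chosen only for illustration. Real MoE layers operate on tensors, token by token, with a learned router.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(router_logits, k=2):
    """Return (expert_index, weight) pairs for the k highest-scoring experts."""
    ranked = sorted(range(len(router_logits)),
                    key=lambda i: router_logits[i], reverse=True)
    top = ranked[:k]
    # Assumption: weights are a softmax over the selected logits only.
    weights = softmax([router_logits[i] for i in top])
    return list(zip(top, weights))

def moe_forward(x, experts, router_logits, k=2):
    """Run only the selected experts and mix their outputs by router weight."""
    return sum(w * experts[i](x) for i, w in route_top_k(router_logits, k))

# Hypothetical toy experts: each "expert" is a scalar function.
experts = [lambda x: x + 1, lambda x: 2 * x, lambda x: -x, lambda x: x * x]
# Hypothetical router scores for one input; a real router computes these.
router_logits = [0.1, 2.0, -1.0, 1.0]

selected = route_top_k(router_logits, k=2)
output = moe_forward(3.0, experts, router_logits, k=2)
print(selected)
print(output)
```

With four experts and k=2, only two expert functions are evaluated for this input. This is why an MoE can hold many parameters while activating only a fraction of them at inference time.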

MoEs are so simple and flexible that it is easy to make a custom MoE. On the Hugging Face Hub, we can now find several trending LLMs that are custom MoEs, such as mlabonne/phixtral-4x2_8.

However, most of them are not traditional MoEs made from scratch. They simply combine already fine-tuned LLMs as experts. Mergekit (LGPL-3.0 license) made them easy to create. For instance, the Phixtral LLMs were made with mergekit by combining several Phi-2 models.

In this article, we will see how Phixtral was created. We will apply the same process to create our own mixture of experts, Maixtchup, using several Mistral 7B models.

To quickly understand the high-level architecture of a model, I like to print it. For instance, for mlabonne/phixtral-4x2_8 (MIT license):

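from transformers import AutoModelForCausalLM

# 4-bit loading requires the bitsandbytes package.
# trust_remote_code=True is needed because the model ships custom modeling code.
model = AutoModelForCausalLM.from_pretrained(
    "mlabonne/phixtral-4x2_8",
    torch_dtype="auto",
    load_in_4bit=True,
    trust_remote_code=True,
)
# Printing the model shows its module hierarchy.
print(model)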

It prints:

PhiForCausalLM(
  (transformer): PhiModel(
    (embd): Embedding(
      (wte): Embedding(51200, 2560)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (h): ModuleList(
      (0-31): 32 x ParallelBlock(
        (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (mixer): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear4bit(in_features=2560, out_features=7680, bias=True)
          (out_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)
          (inner_attn): SelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): CrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (moe): MoE(
          (mlp): ModuleList(…


