The PyTorch source provides wrapper implementations for transformer modules.
SRC: https://github.com/pytorch/pytorch/blob/v2.9.1/torch/nn/modules/transformer.py#L966
I want to add a similar implementation to Flax for ease of use / UX. I will create a new file, flax/flax/nnx/nn/transformer.py, which will contain the following modules:
TransformerEncoderLayer
TransformerEncoder
TransformerDecoderLayer
TransformerDecoder
Transformer
I will keep the implementation consistent with the existing nnx.Linear and nnx.MultiHeadAttention modules and update the docs as needed. Based on review, I can also implement separate custom attention variants such as MHSA (full attention) and GQA (with KV cache). Can I open a PR? @cgarciae @vfdev-5