
Adding transformer encoder and decoder layers to flax source as in pytorch #5176

@coder0143

Description


The PyTorch source includes implementations of wrapper modules for transformers.

SRC: https://github.com/pytorch/pytorch/blob/v2.9.1/torch/nn/modules/transformer.py#L966

I want to add such an implementation for ease of use / UX. I will create a new file, flax/flax/nnx/nn/transformer.py, containing the following modules:

  • TransformerEncoderLayer
  • TransformerEncoder
  • TransformerDecoderLayer
  • TransformerDecoder
  • Transformer

I will keep it consistent with the nnx.Linear and nnx.MultiHeadAttention modules and update the docs where needed. If required, I can also implement separate custom attention variants such as MHSA (full attention) and GQA (with KV cache) based on review feedback. Can I open a PR? @cgarciae @vfdev-5
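For illustration, here is a minimal sketch of what the encoder layer could look like when composed from the existing nnx.MultiHeadAttention, nnx.Linear, nnx.LayerNorm, and nnx.Dropout modules. This is only a rough draft assuming a post-LayerNorm layout; the names and constructor arguments (d_model, num_heads, ffn_dim, dropout_rate) are illustrative, not a final API.

```python
from flax import nnx
import jax
import jax.numpy as jnp


class TransformerEncoderLayer(nnx.Module):
  """Self-attention + position-wise MLP block (post-LayerNorm variant). Sketch only."""

  def __init__(self, d_model: int, num_heads: int, ffn_dim: int,
               dropout_rate: float = 0.1, *, rngs: nnx.Rngs):
    self.self_attn = nnx.MultiHeadAttention(
        num_heads=num_heads, in_features=d_model, decode=False, rngs=rngs)
    self.linear1 = nnx.Linear(d_model, ffn_dim, rngs=rngs)
    self.linear2 = nnx.Linear(ffn_dim, d_model, rngs=rngs)
    self.norm1 = nnx.LayerNorm(d_model, rngs=rngs)
    self.norm2 = nnx.LayerNorm(d_model, rngs=rngs)
    self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

  def __call__(self, x, mask=None):
    # Self-attention sub-block with residual connection.
    x = self.norm1(x + self.dropout(self.self_attn(x, mask=mask)))
    # Position-wise feed-forward sub-block with residual connection.
    x = self.norm2(x + self.dropout(self.linear2(jax.nn.relu(self.linear1(x)))))
    return x


# Illustrative usage: batch of 4 sequences, length 16, model width 64.
layer = TransformerEncoderLayer(d_model=64, num_heads=4, ffn_dim=256, rngs=nnx.Rngs(0))
y = layer(jnp.ones((4, 16, 64)))
```

TransformerEncoder / TransformerDecoder would then just stack N such layers, and Transformer would wire the encoder and decoder together, mirroring the structure of the PyTorch wrappers linked above. Details (pre- vs post-norm, activation, mask conventions) would follow reviewer guidance.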
