41 changes: 38 additions & 3 deletions docs/reference/Configuration.md
@@ -335,7 +335,8 @@ which are common to all optimizers (and most have to do with learning rate sched
| Parameter | Description | Default |
|-----------------|-------------------------------------------------------------------------------|----------|
| `weight_decay` | The weight decay. | `0.0` |
| `learning_rate` | The learning rate. | `1e-4` |
| `learning_rate` | Global learning rate or mapping of tag to rate. | `1e-4` |
| `param_tags` | Patterns assigning tags to parameters. See below. | `None` |
| `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` |
| `min_lr_ratio` | The minimum learning rate ratio. | `0.1` |
| `warmup` | Warmup fraction or number of steps | `0.01` |
@@ -344,11 +345,45 @@ which are common to all optimizers (and most have to do with learning rate sched
| `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` |
| `cycle_length` | How long the cycles should be (as an int, fraction), or list of cycle lengths | `None` |

By default, Levanter uses a cosine learning rate decay with warmup. The learning rate is decayed to
`min_lr_ratio * learning_rate` over the course of the training run. This is a fairly standard default for LLM training.
#### Parameter Tags

Parameters can be grouped using `param_tags`. Each entry is a `TagPattern`
consisting of a `pattern` (or list of patterns) and a `tag`. Patterns are
matched (as substrings) against either the full dotted parameter path or the
module class name, and are evaluated in order: the first matching pattern
assigns its tag, and parameters that match no pattern remain untagged.

Tags can then be referenced in `weight_decay_modules` and in the dictionary
form of `learning_rate`. For tagged parameters, the tag takes precedence over
pattern matches in those fields. A small example:

```yaml
param_tags:
  - pattern: lm_head.weight
    tag: output
  - pattern: bias
    tag: bias
  - pattern: Embedding.weight
    tag: input
  - pattern: Linear.weight
    tag: hidden

learning_rate:
  default: 6e-4
  bias: 3e-4

weight_decay_modules:
  - hidden
  - input
  - output
```
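
The same configuration can also be constructed in Python. This is a minimal sketch mirroring the YAML above, using the `TagPattern` and `AdamConfig` classes exported from `levanter.optim` in this change:

```python
from levanter.optim import AdamConfig, TagPattern

config = AdamConfig(
    # Patterns are checked in order; the first match assigns the tag.
    param_tags=[
        TagPattern(pattern="lm_head.weight", tag="output"),
        TagPattern(pattern="bias", tag="bias"),
        TagPattern(pattern="Embedding.weight", tag="input"),
        TagPattern(pattern="Linear.weight", tag="hidden"),
    ],
    # Biases get a lower peak learning rate; everything else uses "default".
    learning_rate={"default": 6e-4, "bias": 3e-4},
    # Only the tagged weight matrices receive weight decay.
    weight_decay_modules=["hidden", "input", "output"],
)
```

Untagged parameters fall back to the `default` learning rate.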


#### Learning Rate Schedules

By default, Levanter uses a cosine learning rate decay with warmup. The learning rate is decayed to
`min_lr_ratio * learning_rate` over the course of the training run. This is a fairly standard default for LLM training.
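
For example (illustrative values, assuming the usual `optimizer:` section of a training config), decaying to `min_lr_ratio * learning_rate` means:

```yaml
optimizer:
  learning_rate: 6e-4
  lr_schedule: cosine
  warmup: 0.01        # ramp up over the first 1% of steps
  min_lr_ratio: 0.1   # decay to 6e-5 (= 0.1 * 6e-4) by the end of training
```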

The `lr_schedule` parameter specifies the learning rate schedule. The following schedules are supported:

* `constant`: Constant learning rate.
2 changes: 1 addition & 1 deletion src/levanter/optim/__init__.py
@@ -1,7 +1,7 @@
from .adam_mini import MiniConfig, ScaleByMiniState
from .adopt import AdoptConfig, ScaleByAdoptState
from .cautious import CautiousConfig
from .config import AdamConfig, LionConfig, OptimizerConfig
from .config import AdamConfig, LionConfig, OptimizerConfig, TagPattern
from .kron import KronConfig
from .mars import MarsConfig, ScaleByMarsState
from .muon import MuonConfig, ScaleByMuonState
Expand Down
169 changes: 147 additions & 22 deletions src/levanter/optim/config.py
@@ -29,6 +29,22 @@ class LrScheduleContext:
    min_lr: float


@dataclass(frozen=True)
class TagPattern:
    """Pattern used to tag parameters.

    Parameters
    ----------
    pattern : str | list[str]
        Name or names to match against parameter paths or module class names.
    tag : str
        Tag to apply to matched parameters.
    """

    pattern: list[str] | str
    tag: str


class LrSchedule(draccus.ChoiceRegistry, abc.ABC):
    @abc.abstractmethod
    def build(self, ctx: LrScheduleContext) -> Callable:
@@ -105,7 +121,6 @@ def schedule(step):

@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    learning_rate: float = 6e-4
    weight_decay: float = 0.1

    min_lr_ratio: float = 0.1
@@ -129,6 +144,19 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    weight_decay_modules: Optional[list[str] | str] = None
    """A regex or a list of strings to identify where to mask weight.
    For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`"""
    param_tags: Optional[list[TagPattern]] = None
    """Assign tags to parameters for grouping.

    The first matching pattern wins. Tags can then be used by
    :pyattr:`weight_decay_modules` and :pyattr:`learning_rate`.
    """
    learning_rate: float | dict[str, float] = 6e-4
    """Peak learning rate.

    If a mapping is provided, the keys correspond to parameter tags and the
    values are the peak learning rates for those groups. The special key
    ``"default"`` specifies the learning rate for any untagged parameters.
    """
    default_weight_decay_mask: Optional[bool] = None
    """Whether to apply a default reasonable weight decay to modules not explicitly masked. None means it will if
    no weight_decay_modules are set. False means it will not. True means it will regardless of weight_decay_modules."""
@@ -141,6 +169,66 @@ def default_choice_name(cls) -> Optional[str]:
    def build(self, num_train_steps: int):
        raise NotImplementedError

    def _match_param_tag(self, root_key: Optional[str], class_key: Optional[str]) -> Optional[str]:
        """Return the tag matching a parameter path or class name."""

        if not self.param_tags:
            return None
        for tp in self.param_tags:
            pats = tp.pattern if isinstance(tp.pattern, list) else [tp.pattern]
            for p in pats:
                if (root_key and p in root_key) or (class_key and p in class_key):
                    return tp.tag
        return None

    def build_learning_rate_tree(self):
        """Create a tree of learning rates based on parameter tags."""

        if not isinstance(self.learning_rate, dict):
            return None

        lr_map = self.learning_rate

        def is_leaf(x):
            return eqx.is_array(x) or isinstance(x, eqx.Module) or haliax.is_named_array(x) or isinstance(x, str)

        def _apply_on(x, from_root_key_path, from_class_keypath):
            if isinstance(x, eqx.Module):
                is_leaf_here = lambda y: x is not y and is_leaf(y)  # noqa: E731
                class_name = x.__class__.__name__
                from_root_key_paths = leaf_key_paths(x, is_leaf=is_leaf_here, prefix=from_root_key_path)
                from_class_key_paths = leaf_key_paths(x, is_leaf=is_leaf_here, prefix=class_name)
                return jax.tree_util.tree_map(
                    _apply_on,
                    x,
                    from_root_key_paths,
                    from_class_key_paths,
                    is_leaf=is_leaf_here,
                )
            elif not haliax.util.is_jax_or_hax_array_like(x):
                return x

            tag = self._match_param_tag(from_root_key_path, from_class_keypath)
            if tag is not None and tag in lr_map:
                return lr_map[tag]
            else:
                return lr_map.get("default", 1.0)

        def lr_fn(model):
            return jax.tree_util.tree_map(
                _apply_on,
                model,
                leaf_key_paths(model, is_leaf=is_leaf),
                leaf_key_paths(
                    model,
                    is_leaf=is_leaf,
                    prefix=model.__class__.__name__ if isinstance(model, eqx.Module) else "",
                ),
                is_leaf=is_leaf,
            )

        return lr_fn

    def build_weight_decay_mask(self):
        def reasonable_default(module, path):
            # TODO: gross
@@ -164,7 +252,7 @@ def reasonable_default(module, path):
            )

            def is_leaf(x):
                return eqx.is_array(x) or isinstance(x, eqx.Module) or haliax.is_named_array(x)
                return eqx.is_array(x) or isinstance(x, eqx.Module) or haliax.is_named_array(x) or isinstance(x, str)

            # mask based on regex or module path
            def _apply_on(decayed_paths, x, from_root_key_path, from_class_keypath):
@@ -186,24 +274,28 @@ def _apply_on(decayed_paths, x, from_root_key_path, from_class_keypath):
                elif not haliax.util.is_jax_or_hax_array_like(x):
                    return x

                should_decay = None
                for key_path in [from_root_key_path, from_class_keypath]:
                    if key_path is None:
                        continue

                    if isinstance(self.weight_decay_modules, str):
                        compiled_regex = re.compile(self.weight_decay_modules)
                        should_decay = should_decay or compiled_regex.match(key_path) is not None
                    elif isinstance(self.weight_decay_modules, list):
                        should_decay = should_decay or any(
                            key_path.__contains__(target) for target in self.weight_decay_modules
                        )

                    if should_use_default and not should_decay:
                        should_decay = reasonable_default(x, key_path)

                    if should_decay:
                        break
                tag = self._match_param_tag(from_root_key_path, from_class_keypath)
                if tag is not None:
                    should_decay = isinstance(self.weight_decay_modules, list) and tag in self.weight_decay_modules
                else:
                    should_decay = None
                    for key_path in [from_root_key_path, from_class_keypath]:
                        if key_path is None:
                            continue

                        if isinstance(self.weight_decay_modules, str):
                            compiled_regex = re.compile(self.weight_decay_modules)
                            should_decay = should_decay or compiled_regex.match(key_path) is not None
                        elif isinstance(self.weight_decay_modules, list):
                            should_decay = should_decay or any(
                                key_path.__contains__(target) for target in self.weight_decay_modules
                            )

                        if should_use_default and not should_decay:
                            should_decay = reasonable_default(x, key_path)

                        if should_decay:
                            break

                if should_decay is None:
                    if should_use_default:
@@ -242,9 +334,13 @@ def lr_scheduler(self, num_train_steps, override_lr=None):
        total_main_steps = num_train_steps - cooldown_steps
        cooldown_points = self._get_cycle_minima(total_main_steps)

        learning_rate = self.learning_rate
        base_lr = self.learning_rate
        if override_lr is not None:
            learning_rate = override_lr
            base_lr = override_lr
        if isinstance(base_lr, dict):
            learning_rate = 1.0
        else:
            learning_rate = base_lr

        min_lr = learning_rate * self.min_lr_ratio

@@ -389,6 +485,27 @@ def _convert_frac_or_steps(frac_or_steps: float | int, num_train_steps: int):
    return int(frac_or_steps)


def scale_by_tagged_learning_rates(lr_map, tag_tree_fn) -> optax.GradientTransformation:
"""Scale updates by per-parameter learning rates.

Parameters
----------
lr_map : dict[str, float]
Mapping from tag to learning rate multiplier. ``"default"`` specifies the fallback.
tag_tree_fn : Callable
Function that produces a tree of learning rates for a given parameter tree.
"""

def init_fn(params):
return tag_tree_fn(params)

def update_fn(updates, state, params=None):
scaled = jax.tree_util.tree_map(lambda g, lr: g * lr, updates, state)
return scaled, state

return optax.GradientTransformation(init_fn, update_fn)


@dataclass
class HessianOptConfig(OptimizerConfig, abc.ABC):
    update_interval: int = 10
@@ -446,6 +563,10 @@ def _optimizer(learning_rate):
                components.append(scan_aware_clip_by_block_rms(self.update_rms_clipping))
                components.append(log_norm_passthrough("optim/post_clip_update_norm"))

            lr_tree_fn = self.build_learning_rate_tree()
            if lr_tree_fn is not None:
                components.append(scale_by_tagged_learning_rates(self.learning_rate, lr_tree_fn))

            # - learning rate for descent
            components.append(optax.scale(-learning_rate))

@@ -488,6 +609,10 @@ def _optimizer(learning_rate):
            if self.weight_decay > 0:
                components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))

            lr_tree_fn = self.build_learning_rate_tree()
            if lr_tree_fn is not None:
                components.append(scale_by_tagged_learning_rates(self.learning_rate, lr_tree_fn))

            # - learning rate for descent
            components.append(optax.scale(-learning_rate))
