GradientAccumulationScheduler

class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling)[source]

Bases: lightning.pytorch.callbacks.callback.Callback

Change gradient accumulation factor according to scheduling.

Parameters

scheduling (Dict[int, int]) – scheduling in format {epoch: accumulation_factor}

Note

The argument scheduling is a dictionary: each key is an epoch and its associated value is the accumulation factor to apply from that epoch on. Warning: epochs are zero-indexed, i.e. if you want to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more info check the example below.

Raises
  • TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.

  • IndexError – If minimal_epoch is less than 0.
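
As a rough illustration of these exceptions (a sketch; the exact error messages depend on the Lightning version):

from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Invalid schedules fail at construction time:
# GradientAccumulationScheduler(scheduling={})        # TypeError: empty dict
# GradientAccumulationScheduler(scheduling={-1: 2})   # IndexError: negative epoch key
# GradientAccumulationScheduler(scheduling={"4": 2})  # TypeError: non-integer key

# A valid schedule: accumulate 4 batches starting from the first epoch.
accumulator = GradientAccumulationScheduler(scheduling={0: 4})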

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# From epoch 5 onward, accumulate every 2 batches. We pass 4 instead of 5
# because epoch keys are zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])
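
A schedule can also have several milestones; each milestone's factor stays in effect until the next one is reached. A sketch with illustrative values ({0: 8, 4: 4, 8: 1} and max_epochs=12):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# epochs 0-3: accumulate 8 batches, epochs 4-7: 4 batches,
# epochs 8+: update on every batch (keys are zero-indexed).
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = Trainer(max_epochs=12, callbacks=[accumulator])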
on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type

None
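
If you want to see which factor is in effect, the trainer exposes it as trainer.accumulate_grad_batches, which this hook updates at the start of every training epoch. A minimal sketch (MyModel is a hypothetical LightningModule; only the relevant hook is shown):

import lightning.pytorch as pl

class MyModel(pl.LightningModule):
    def on_train_epoch_start(self):
        # Reflects the factor applied by GradientAccumulationScheduler
        # for the epoch that is about to start (sketch/assumption).
        print(f"epoch {self.current_epoch}: accumulating "
              f"{self.trainer.accumulate_grad_batches} batches")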

on_train_start(trainer, pl_module)[source]

Performs configuration validation before training starts and raises errors for incompatible settings.

Return type

None
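
As a hedged example of an incompatible setting (an assumption on my part; the exact checks and messages depend on the Lightning version), scheduled accumulation relies on automatic optimization, so a module that disables it is expected to be rejected when training starts:

import lightning.pytorch as pl
from lightning.pytorch.callbacks import GradientAccumulationScheduler

class ManualOptModel(pl.LightningModule):  # hypothetical module
    def __init__(self):
        super().__init__()
        # Manual optimization bypasses the automatic backward/step loop
        # that scheduled gradient accumulation hooks into.
        self.automatic_optimization = False

# trainer = pl.Trainer(callbacks=[GradientAccumulationScheduler({4: 2})])
# trainer.fit(ManualOptModel())  # expected to raise during on_train_start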