
StochasticWeightAveraging

class lightning.pytorch.callbacks.StochasticWeightAveraging(swa_lrs, swa_epoch_start=0.8, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))[source]

Bases: lightning.pytorch.callbacks.callback.Callback

Implements the Stochastic Weight Averaging (SWA) callback, which averages the model weights during training.

Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).

This documentation draws heavily on PyTorch’s documentation of SWA. The callback arguments follow the scheme defined in PyTorch’s swa_utils package.

For an explanation of SWA, see the PyTorch blog post on Stochastic Weight Averaging.

Warning

This is an experimental feature.

Warning

StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.

Warning

StochasticWeightAveraging currently only supports updating the averaged model once per epoch; other update intervals are not supported.

See also how to enable SWA directly on the Trainer.

Parameters
  • swa_lrs (Union[float, List[float]]) –

    The SWA learning rate to use:

    • float. Use this value for all parameter groups of the optimizer.

    • List[float]. A list of values, one for each parameter group of the optimizer.

  • swa_epoch_start (Union[int, float]) – If provided as an int, the procedure starts from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure starts from epoch int(swa_epoch_start * max_epochs). For example, with swa_epoch_start=0.8 and max_epochs=20, SWA begins at epoch 16.

  • annealing_epochs (int) – number of epochs in the annealing phase (default: 10)

  • annealing_strategy (str) –

    Specifies the annealing strategy (default: “cos”):

    • "cos". For cosine annealing.

    • "linear" For linear annealing

  • avg_fn (Optional[Callable[[Tensor, Tensor, Tensor], Tensor]]) – the averaging function used to update the parameters. The function must take the current value of the AveragedModel parameter, the current value of the model parameter, and the number of models already averaged. If None, an equally weighted average is used (default: None).

  • device (Union[str, device, None]) – if provided, the averaged model is stored on that device. If None, the device is inferred from pl_module (default: "cpu").
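
A minimal usage sketch follows, using the arguments documented above. MyModel and train_loader are hypothetical placeholders for your own LightningModule and dataloader, not part of the API.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

# `MyModel` and `train_loader` are hypothetical placeholders for your
# own LightningModule and dataloader.
swa = StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.75)
trainer = Trainer(max_epochs=20, callbacks=[swa])
trainer.fit(MyModel(), train_loader)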

static avg_fn(averaged_model_parameter, model_parameter, num_averaged)[source]

Computes the default equally weighted running average of the parameters. Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.

Return type

Tensor
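
As a sketch of this contract, a custom function with the same signature can be supplied via the avg_fn argument. The exponential-moving-average variant and its decay constant below are hypothetical choices for illustration, not part of the API.

import torch
from torch import Tensor


def ema_avg_fn(averaged_param: Tensor, model_param: Tensor, num_averaged: Tensor) -> Tensor:
    # The default (avg_fn=None) is an equal-weight running average:
    #   averaged_param + (model_param - averaged_param) / (num_averaged + 1)
    # This hypothetical variant uses an exponential moving average instead.
    decay = 0.99
    return decay * averaged_param + (1.0 - decay) * model_param


# Usage: StochasticWeightAveraging(swa_lrs=1e-2, avg_fn=ema_avg_fn)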

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload callback state given the callback’s state_dict.

Parameters

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type

None

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type

None

on_train_end(trainer, pl_module)[source]

Called when the train ends.

Return type

None

on_train_epoch_end(trainer, *args)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.LightningModule and access them in this hook:

import torch
import lightning as L


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type

None

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

Return type

None

reset_batch_norm_and_save_state(pl_module)[source]

Resets the running statistics of the batch-normalization layers in pl_module and saves their momenta so they can be restored later by reset_momenta(). Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.

Return type

None

reset_momenta()[source]

Restores the batch-normalization momenta saved by reset_batch_norm_and_save_state(). Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.

Return type

None

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type

None

state_dict()[source]

Called when saving a checkpoint. Implement to generate the callback’s state_dict.

Return type

Dict[str, Any]

Returns

A dictionary containing callback state.
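
To illustrate the state_dict/load_state_dict contract described above, here is a minimal sketch of a custom callback that persists its own state across checkpoints. CountingCallback and its batches_seen attribute are hypothetical names for illustration, unrelated to the SWA callback’s internal state.

from lightning.pytorch.callbacks import Callback


class CountingCallback(Callback):
    # Hypothetical callback illustrating the state persistence contract.
    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def state_dict(self):
        # The returned dict is saved in the checkpoint under this callback's key.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # Receives the dict produced by state_dict() when the checkpoint loads.
        self.batches_seen = state_dict["batches_seen"]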

static update_parameters(average_model, model, n_averaged, avg_fn)[source]

Updates the parameters of average_model in place using avg_fn. Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.

Return type

None
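
For intuition, the sketch below shows roughly what this update does. It is a simplified illustration under the assumption that n_averaged is a plain integer, not the actual implementation (which tracks the count as a tensor and handles device placement).

import torch


@torch.no_grad()
def sketch_update_parameters(average_model, model, n_averaged, avg_fn):
    # Simplified illustration: copy the weights on the first update,
    # fold them into the running average on every later update.
    for p_avg, p in zip(average_model.parameters(), model.parameters()):
        p_new = p.detach().to(p_avg.device)
        if n_averaged == 0:
            p_avg.copy_(p_new)
        else:
            p_avg.copy_(avg_fn(p_avg, p_new, n_averaged))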