StochasticWeightAveraging
- class lightning.pytorch.callbacks.StochasticWeightAveraging(swa_lrs, swa_epoch_start=0.8, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))[source]
Bases: lightning.pytorch.callbacks.callback.Callback

Implements the Stochastic Weight Averaging (SWA) callback to average a model.

Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).

This documentation is highly inspired by PyTorch's work on SWA. The callback arguments follow the scheme defined in PyTorch's swa_utils package.

For an explanation of SWA, please take a look here.

Warning: This is an experimental feature.

Warning: StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.

Warning: StochasticWeightAveraging is currently only supported on every epoch.

See also how to enable it directly on the Trainer. A minimal usage sketch is shown after the parameter list below.

Parameters
- swa_lrs (Union[float, List[float]]) – The SWA learning rate to use:
  - float. Use this value for all parameter groups of the optimizer.
  - List[float]. A list of values, one for each parameter group of the optimizer.
 
- swa_epoch_start (Union[int, float]) – If provided as an int, the procedure will start from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure will start from the int(swa_epoch_start * max_epochs)-th epoch.
- annealing_epochs (int) – Number of epochs in the annealing phase (default: 10).
- annealing_strategy (str) – Specifies the annealing strategy (default: "cos"):
  - "cos". For cosine annealing.
  - "linear". For linear annealing.
 
- avg_fn (Optional[Callable[[Tensor, Tensor, Tensor], Tensor]]) – The averaging function used to update the parameters; the function must take in the current value of the AveragedModel parameter, the current value of the model parameter, and the number of models already averaged; if None, an equally weighted average is used (default: None).
- device (Union[str, device, None]) – If provided, the averaged model will be stored on the device. When None is provided, it will infer the device from pl_module (default: "cpu").
 
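A minimal usage sketch, assuming a user-defined LightningModule and DataLoader (the names LitModel and train_loader below are placeholders, not part of the API):

    import lightning as L
    from lightning.pytorch.callbacks import StochasticWeightAveraging

    # `LitModel` and `train_loader` are placeholder names for this sketch.
    model = LitModel()
    trainer = L.Trainer(
        max_epochs=20,
        callbacks=[
            StochasticWeightAveraging(
                swa_lrs=1e-2,          # one SWA learning rate for every parameter group
                swa_epoch_start=0.75,  # start averaging after 75% of max_epochs (epoch 15 here)
                annealing_epochs=5,    # anneal towards the SWA learning rate over 5 epochs
                annealing_strategy="cos",
            )
        ],
    )
    trainer.fit(model, train_loader)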
static avg_fn(averaged_model_parameter, model_parameter, num_averaged)[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.

Return type: Tensor
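A sketch of the equally weighted running mean computed by the referenced swa_utils lines (an illustration, not the verbatim implementation); a callable with the same signature can also be passed as the avg_fn argument above:

    import torch

    def running_mean_avg_fn(averaged_param: torch.Tensor,
                            model_param: torch.Tensor,
                            num_averaged: torch.Tensor) -> torch.Tensor:
        # Equal-weight running mean: after n averaged models, each contributes 1/n.
        return averaged_param + (model_param - averaged_param) / (num_averaged + 1)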
 
load_state_dict(state_dict)[source]

Called when loading a checkpoint; implement to reload callback state given the callback's state_dict.
on_train_epoch_end(trainer, *args)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.LightningModule and access them in this hook:

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()

Return type: None
 
reset_batch_norm_and_save_state(pl_module)[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.

Return type: None
 
reset_momenta()[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.

Return type: None
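Taken together, these two hooks mirror the batch-norm handling in PyTorch's update_bn: the first resets the running statistics and stashes each layer's momentum, the second restores the stashed values afterwards. A rough sketch of that idea (illustrative only, not the callback's exact code):

    import torch
    from torch.nn.modules.batchnorm import _BatchNorm

    def save_and_reset_bn(model: torch.nn.Module) -> dict:
        # Zero the running statistics and remember each BatchNorm layer's momentum.
        momenta = {}
        for module in model.modules():
            if isinstance(module, _BatchNorm) and module.running_mean is not None:
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)
                momenta[module] = module.momentum
                module.momentum = None          # None => cumulative moving average
                module.num_batches_tracked *= 0
        return momenta

    def restore_bn_momenta(momenta: dict) -> None:
        # Put the saved momentum values back after the statistics have been recomputed.
        for bn_module, momentum in momenta.items():
            bn_module.momentum = momentum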
 
setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type: None
 
static update_parameters(average_model, model, n_averaged, avg_fn)[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.

Return type: None
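A sketch of the parameter-averaging loop in the referenced swa_utils lines, assuming avg_fn follows the signature documented above (illustrative, not the exact source):

    import torch

    def update_parameters_sketch(average_model: torch.nn.Module,
                                 model: torch.nn.Module,
                                 n_averaged: torch.Tensor,
                                 avg_fn) -> None:
        # Copy the weights on the first call, then blend them with avg_fn afterwards.
        for p_avg, p_model in zip(average_model.parameters(), model.parameters()):
            p_model_ = p_model.detach().to(p_avg.device)
            if n_averaged == 0:
                p_avg.detach().copy_(p_model_)
            else:
                p_avg.detach().copy_(avg_fn(p_avg.detach(), p_model_,
                                            n_averaged.to(p_avg.device)))
        n_averaged += 1  # in-place increment of the counter tensor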