PrecisionPlugin
- class lightning.pytorch.plugins.precision.PrecisionPlugin[source]
  Bases: lightning.fabric.plugins.precision.precision.Precision, lightning.pytorch.core.hooks.CheckpointHooks
  Base class for all plugins handling the precision-specific parts of the training.
  The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
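The "overwritten in child classes" requirement amounts to shadowing a class attribute. A minimal pure-Python sketch (the class names and the `"16-true"` value here are illustrative; consult the library for the precision strings it actually accepts):

```python
class Precision:
    """Stand-in for the base plugin: the default reflects fp32 training."""
    precision = "32-true"


class HalfPrecisionSketch(Precision):
    """Hypothetical child class that overwrites the class attribute."""
    precision = "16-true"


print(Precision.precision)        # → 32-true
print(HalfPrecisionSketch.precision)  # → 16-true
```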
- backward(tensor, model, optimizer, *args, **kwargs)[source]
  Performs the actual backpropagation.
  - Parameters
    - model (LightningModule) – the model to be optimized
    - optimizer (Optional[Steppable]) – current optimizer being used. None if using manual optimization
    - *args – Positional arguments intended for the actual function that performs the backward, like backward().
    - **kwargs – Keyword arguments for the same purpose as *args.
  - Return type
- clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)[source]
  Clips the gradients.
  - Return type
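With the default gradient_clip_algorithm=GradClipAlgorithmType.NORM, clipping rescales gradients so their global L2 norm does not exceed clip_val. A hedged sketch on a flat list of floats (the real method operates on the optimizer's parameter tensors; `clip_grad_norm` is a hypothetical helper):

```python
import math


def clip_grad_norm(grads, clip_val):
    """Rescale a flat list of gradient values to a maximum global L2 norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if clip_val <= 0 or total_norm <= clip_val:
        return grads  # clip_val=0.0 (the default) disables clipping
    scale = clip_val / total_norm
    return [g * scale for g in grads]


# Gradient vector [3, 4] has norm 5; clipping to norm 1 rescales by 1/5.
print(clip_grad_norm([3.0, 4.0], clip_val=1.0))
```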
- connect(model, optimizers, lr_schedulers)[source]
  Connects this plugin to the accelerator and the training process.
- post_backward(tensor, module)[source]
  Runs after the precision plugin executes backward.
  - Parameters
    - tensor (Tensor) – The tensor that will be used for backpropagation
    - module (LightningModule) – The module that was involved in producing the tensor and whose parameters need the gradients
  - Return type
- pre_backward(tensor, module)[source]
  Runs before the precision plugin executes backward.
  - Parameters
    - tensor (Tensor) – The tensor that will be used for backpropagation
    - module (LightningModule) – The module that was involved in producing the tensor and whose parameters need the gradients
  - Return type
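A subclass can use these hooks to transform the loss around the backward pass, for example loss scaling as a mixed-precision plugin might do. A hypothetical sketch on plain floats (the scale value and class name are assumptions, not the library's API):

```python
class ScaledBackwardSketch:
    """Hypothetical plugin sketch: scale the loss before backpropagation."""
    scale = 1024.0  # illustrative loss-scale factor

    def pre_backward(self, tensor, module):
        # Runs before backward: enlarge the loss to keep small fp16
        # gradients from underflowing (the motivating use case).
        return tensor * self.scale

    def post_backward(self, tensor, module):
        # Runs after backward: gradients would be unscaled at this point.
        return tensor


plugin = ScaledBackwardSketch()
print(plugin.pre_backward(2.0, module=None))  # → 2048.0
```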