
PrecisionPlugin

class lightning.pytorch.plugins.precision.PrecisionPlugin

Bases: lightning.fabric.plugins.precision.precision.Precision, lightning.pytorch.core.hooks.CheckpointHooks

Base class for all plugins handling the precision-specific parts of the training.

The class attribute precision must be overridden in child classes. The default value reflects fp32 training.
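The override rule can be illustrated with a minimal, stdlib-only sketch. PrecisionBase and HalfPrecisionSketch are hypothetical stand-ins, not Lightning classes, and the strings "32-true" and "16-mixed" are assumed example values for the precision attribute.

```python
class PrecisionBase:
    """Hypothetical stand-in for the plugin base class (not the Lightning class)."""

    # Assumed spelling of the fp32 default; the real value may differ.
    precision = "32-true"


class HalfPrecisionSketch(PrecisionBase):
    """Hypothetical child class that overrides the class attribute."""

    precision = "16-mixed"  # assumed example value


# The attribute is defined on the class, so it is readable without an instance.
print(PrecisionBase.precision)         # 32-true
print(HalfPrecisionSketch.precision)   # 16-mixed
print(HalfPrecisionSketch().precision) # 16-mixed
```

Because precision is a class attribute rather than an instance attribute, code can inspect which precision a plugin type represents before constructing it.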

backward(tensor, model, optimizer, *args, **kwargs)

Performs the actual backpropagation.

Parameters
  • tensor (Tensor) – the loss value obtained from the closure

  • model (LightningModule) – the model to be optimized

  • optimizer (Optional[Steppable]) – the optimizer currently in use, or None when using manual optimization

  • *args – Positional arguments intended for the actual function that performs the backward pass, such as torch.Tensor.backward().

  • **kwargs – Keyword arguments for the same purpose as *args.

Return type

None
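The sketch below shows how backward() can hand its extra positional and keyword arguments through to the module, assuming the plugin simply delegates to the module's own backward hook. RecordingModel and PrecisionSketch are hypothetical stand-ins, and a float takes the place of the loss tensor so the example runs without PyTorch.

```python
from typing import Any, List, Optional


class RecordingModel:
    """Hypothetical stand-in for a LightningModule that records backward calls."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def backward(self, loss: float, *args: Any, **kwargs: Any) -> None:
        # A real module would call loss.backward(*args, **kwargs) here.
        self.calls.append((loss, args, kwargs))


class PrecisionSketch:
    """Hypothetical plugin showing how backward() forwards its extra arguments."""

    def backward(self, tensor: float, model: RecordingModel,
                 optimizer: Optional[object], *args: Any, **kwargs: Any) -> None:
        # optimizer is None under manual optimization; this sketch does not use it.
        model.backward(tensor, *args, **kwargs)


model = RecordingModel()
PrecisionSketch().backward(0.5, model, None, retain_graph=True)
print(model.calls)  # [(0.5, (), {'retain_graph': True})]
```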

clip_grad_by_norm(optimizer, clip_val)

Clips the gradients of the optimizer's parameters so that their total norm does not exceed clip_val.

Return type

None
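The arithmetic behind norm clipping fits in a few lines of plain Python. This is a sketch of the math only, on a hypothetical flat list of gradient values; it is not the plugin's implementation, and PyTorch's own torch.nn.utils.clip_grad_norm_ works on tensors and adds a small epsilon to the denominator.

```python
import math
from typing import List


def clip_grad_by_norm(grads: List[float], clip_val: float) -> List[float]:
    """Scale all gradients together so their global L2 norm is at most clip_val."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip_val:
        return list(grads)  # already within the limit, leave unchanged
    scale = clip_val / total_norm
    return [g * scale for g in grads]


clipped = clip_grad_by_norm([3.0, 4.0], clip_val=1.0)  # original norm is 5.0
print(clipped)  # approximately [0.6, 0.8]
```

All elements are scaled by the same factor, so the direction of the gradient vector is preserved and only its length shrinks.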

clip_grad_by_value(optimizer, clip_val)

Clips every gradient element of the optimizer's parameters to the range [-clip_val, clip_val].

Return type

None
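Value clipping is an element-wise clamp. The plain-Python sketch below uses a hypothetical flat list of gradient values in place of real parameter gradients.

```python
from typing import List


def clip_grad_by_value(grads: List[float], clip_val: float) -> List[float]:
    """Clamp every gradient element independently into [-clip_val, clip_val]."""
    return [max(-clip_val, min(clip_val, g)) for g in grads]


print(clip_grad_by_value([-2.0, 0.25, 3.0], clip_val=0.5))  # [-0.5, 0.25, 0.5]
```

Unlike norm clipping, each element is clamped on its own, so large components can be cut while small ones are untouched, which can change the direction of the overall gradient.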

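clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)

Clips the gradients of the optimizer's parameters with threshold clip_val, using the algorithm selected by gradient_clip_algorithm (by norm or by value).

Return type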

None
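A dispatcher like clip_gradients can be sketched as a switch on the algorithm type. ClipAlgorithm is a hypothetical stand-in for GradClipAlgorithmType, the by_norm and by_value callbacks stand in for the two clipping methods, and treating a non-positive clip_val as "no clipping" is an assumption of this sketch.

```python
from enum import Enum
from typing import Callable, List


class ClipAlgorithm(str, Enum):
    """Hypothetical stand-in for GradClipAlgorithmType."""
    NORM = "norm"
    VALUE = "value"


def clip_gradients(clip_val: float = 0.0,
                   algorithm: ClipAlgorithm = ClipAlgorithm.NORM,
                   by_norm: Callable[[float], None] = lambda v: None,
                   by_value: Callable[[float], None] = lambda v: None) -> None:
    # Assumption: a non-positive clip_val disables clipping entirely.
    if clip_val <= 0:
        return
    if algorithm == ClipAlgorithm.VALUE:
        by_value(clip_val)
    elif algorithm == ClipAlgorithm.NORM:
        by_norm(clip_val)


calls: List[str] = []
clip_gradients(0.0, by_norm=lambda v: calls.append(f"norm {v}"))  # skipped
clip_gradients(1.0, by_norm=lambda v: calls.append(f"norm {v}"))
clip_gradients(0.5, ClipAlgorithm.VALUE, by_value=lambda v: calls.append(f"value {v}"))
print(calls)  # ['norm 1.0', 'value 0.5']
```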

connect(model, optimizers, lr_schedulers)

Connects this plugin to the accelerator and the training process.

Return type

Tuple[Module, List[Optimizer], List[Any]]
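Because connect returns a (model, optimizers, lr_schedulers) tuple, a plugin can hand back wrapped versions of what it received. The sketch below assumes a pass-through base behaviour; PassThroughPrecision, WrappingPrecision and the tuple "wrapper" are hypothetical and only illustrate the contract.

```python
from typing import Any, List, Tuple


class PassThroughPrecision:
    """Hypothetical plugin whose connect() returns its inputs unchanged."""

    def connect(self, model: Any, optimizers: List[Any],
                lr_schedulers: List[Any]) -> Tuple[Any, List[Any], List[Any]]:
        return model, optimizers, lr_schedulers


class WrappingPrecision(PassThroughPrecision):
    """Hypothetical subclass that wraps each optimizer before training starts."""

    def connect(self, model: Any, optimizers: List[Any],
                lr_schedulers: List[Any]) -> Tuple[Any, List[Any], List[Any]]:
        wrapped = [("wrapped", opt) for opt in optimizers]  # stand-in for a real wrapper
        return model, wrapped, lr_schedulers


# The caller must keep using the returned objects, not the originals.
model, optimizers, schedulers = WrappingPrecision().connect("net", ["sgd"], [])
print(model, optimizers, schedulers)  # net [('wrapped', 'sgd')] []
```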

optimizer_step(optimizer, model, closure, **kwargs)

Hook to run the optimizer step.

Return type

Any
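torch.optim optimizers accept an optional closure in step(), which gives a precision plugin a place to act after gradients are computed but before parameters change. The sketch below assumes the plugin wraps the closure this way; FakeOptimizer, PrecisionSketch and _wrap_closure are hypothetical names, and the log list only records the order of events.

```python
from functools import partial
from typing import Any, Callable, List


class FakeOptimizer:
    """Hypothetical optimizer; like torch.optim optimizers, step() accepts a closure."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def step(self, closure: Callable[[], Any]) -> Any:
        loss = closure()  # forward and backward happen inside the closure
        self.log.append("update parameters")
        return loss


class PrecisionSketch:
    """Hypothetical plugin that wraps the closure before handing it to the optimizer."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def _wrap_closure(self, closure: Callable[[], Any]) -> Any:
        # Hypothetical helper: run the user closure, then act on fresh gradients.
        result = closure()
        # Gradients exist here but parameters are not yet updated, so this is
        # where a subclass could, for example, clip gradients.
        self.log.append("after closure")
        return result

    def optimizer_step(self, optimizer: FakeOptimizer, model: Any,
                       closure: Callable[[], Any], **kwargs: Any) -> Any:
        wrapped = partial(self._wrap_closure, closure)
        return optimizer.step(closure=wrapped, **kwargs)


log: List[str] = []


def training_closure() -> float:
    log.append("forward + backward")
    return 0.25


result = PrecisionSketch(log).optimizer_step(FakeOptimizer(log), None, training_closure)
print(result, log)
# 0.25 ['forward + backward', 'after closure', 'update parameters']
```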

post_backward(tensor, module)

Runs after the precision plugin executes backward.

Parameters
  • tensor (Tensor) – The tensor that was used for backpropagation

  • module (LightningModule) – The module that was involved in producing the tensor and whose parameters need the gradients

Return type

Tensor

pre_backward(tensor, module)

Runs before the precision plugin executes backward.

Parameters
  • tensor (Tensor) – The tensor that will be used for backpropagation

  • module (LightningModule) – The module that was involved in producing the tensor and whose parameters need the gradients

Return type

Tensor
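Since pre_backward and post_backward both return a Tensor, a caller can thread the value through all three hooks. The sketch below assumes the order pre_backward, backward, post_backward; HookOrderSketch and run_backward are hypothetical, and a float stands in for the loss tensor.

```python
from typing import List


class HookOrderSketch:
    """Hypothetical plugin that records the order of the backward hooks."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def pre_backward(self, tensor: float, module: object) -> float:
        self.events.append("pre_backward")
        return tensor  # a subclass could return a modified tensor instead

    def backward(self, tensor: float, module: object) -> None:
        self.events.append(f"backward({tensor})")

    def post_backward(self, tensor: float, module: object) -> float:
        self.events.append("post_backward")
        return tensor


def run_backward(plugin: HookOrderSketch, loss: float, module: object) -> float:
    # Assumed calling order: each hook receives the value returned by the previous one.
    loss = plugin.pre_backward(loss, module)
    plugin.backward(loss, module)
    return plugin.post_backward(loss, module)


plugin = HookOrderSketch()
run_backward(plugin, 2.0, module=None)
print(plugin.events)  # ['pre_backward', 'backward(2.0)', 'post_backward']
```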

predict_step_context()

A context manager for the predict step.

Return type

Generator[None, None, None]

test_step_context()

A context manager for the test step.

Return type

Generator[None, None, None]

train_step_context()

A context manager for the training step.

Return type

Generator[None, None, None]

val_step_context()

A context manager for the validation step.

Return type

Generator[None, None, None]
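The Generator[None, None, None] return type matches a generator function decorated with contextlib.contextmanager. The sketch below assumes the step contexts all reuse one shared context; ContextSketch and _shared_context are hypothetical names, and the events list only records when the context is entered and exited.

```python
from contextlib import contextmanager
from typing import Generator, List


class ContextSketch:
    """Hypothetical plugin whose step contexts all reuse one shared context."""

    def __init__(self) -> None:
        self.events: List[str] = []

    @contextmanager
    def _shared_context(self) -> Generator[None, None, None]:
        # Hypothetical helper: a mixed-precision subclass might enable autocast here.
        self.events.append("enter")
        try:
            yield
        finally:
            self.events.append("exit")

    @contextmanager
    def train_step_context(self) -> Generator[None, None, None]:
        with self._shared_context():
            yield

    @contextmanager
    def val_step_context(self) -> Generator[None, None, None]:
        with self._shared_context():
            yield

    # test_step_context and predict_step_context would follow the same pattern.


plugin = ContextSketch()
with plugin.train_step_context():
    plugin.events.append("training_step")
print(plugin.events)  # ['enter', 'training_step', 'exit']
```

The try/finally ensures the exit bookkeeping runs even if the step raises an exception.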