DeepSpeedPrecisionPlugin
- class lightning.pytorch.plugins.precision.DeepSpeedPrecisionPlugin(precision)
Bases: lightning.pytorch.plugins.precision.precision_plugin.PrecisionPlugin
Precision plugin for DeepSpeed integration.
Warning
This is an experimental feature.
- Parameters
precision (Literal['32-true', '16-mixed', 'bf16-mixed']) – Full precision ('32-true'), 16-bit mixed precision ('16-mixed'), or bfloat16 mixed precision ('bf16-mixed').
- Raises
ValueError – If an unsupported precision is provided.
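The documented contract for precision is small: exactly three string values are accepted, and anything else raises ValueError. The sketch below mirrors that contract in plain Python so it can be run without DeepSpeed installed. It is not the plugin's implementation; the alias DeepSpeedPrecision and the helper validate_precision are hypothetical names used only for illustration.

```python
from typing import Literal, get_args

# Hypothetical alias mirroring the documented Literal type of ``precision``.
DeepSpeedPrecision = Literal["32-true", "16-mixed", "bf16-mixed"]

# get_args() on a Literal returns its allowed values as a tuple.
SUPPORTED = get_args(DeepSpeedPrecision)


def validate_precision(precision: str) -> str:
    """Hypothetical helper: return ``precision`` if supported, else raise ValueError."""
    if precision not in SUPPORTED:
        raise ValueError(
            f"precision={precision!r} is not supported. Choose one of {SUPPORTED}."
        )
    return precision


if __name__ == "__main__":
    print(validate_precision("bf16-mixed"))  # bf16-mixed
    try:
        validate_precision("64-true")  # not one of the three documented values
    except ValueError as err:
        print(err)
```

Deriving SUPPORTED from the Literal keeps the type annotation and the runtime check from drifting apart.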
- backward(tensor, model, optimizer, *args, **kwargs)
Performs back-propagation using DeepSpeed’s engine.
- Parameters
model (LightningModule) – the model to be optimized
*args – additional positional arguments for the deepspeed.DeepSpeedEngine.backward() call
**kwargs – additional keyword arguments for the deepspeed.DeepSpeedEngine.backward() call
- Return type
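The key point of backward is delegation: the loss and any extra *args and **kwargs are forwarded to DeepSpeed's engine rather than handled by the plugin itself. The sketch below illustrates that forwarding pattern with a stand-in engine so it runs without DeepSpeed. FakeEngine and deepspeed_backward are hypothetical; how the real plugin obtains the engine from model is not described here, so the sketch simply passes the engine in directly.

```python
from typing import Any, List, Tuple


class FakeEngine:
    """Hypothetical stand-in for deepspeed.DeepSpeedEngine that records backward calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, tuple, dict]] = []

    def backward(self, loss: Any, *args: Any, **kwargs: Any) -> None:
        # A real engine would run back-propagation here; this one only records the call.
        self.calls.append((loss, args, kwargs))


def deepspeed_backward(engine: FakeEngine, tensor: Any, *args: Any, **kwargs: Any) -> None:
    """Hypothetical helper: forward the loss and all extra arguments unchanged."""
    engine.backward(tensor, *args, **kwargs)


if __name__ == "__main__":
    engine = FakeEngine()
    # retain_graph is only an illustrative keyword; check the
    # DeepSpeedEngine.backward signature of your installed DeepSpeed version.
    deepspeed_backward(engine, "loss", retain_graph=True)
    print(engine.calls)  # [('loss', (), {'retain_graph': True})]
```

Forwarding arguments untouched leaves the engine as the single owner of back-propagation details.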