Timer

class lightning.pytorch.callbacks.Timer(duration=None, interval=Interval.step, verbose=True)[source]

Bases: lightning.pytorch.callbacks.callback.Callback

The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.

Parameters
  • duration (Union[str, timedelta, Dict[str, int], None]) – A string in the format DD:HH:MM:SS (days, hours, minutes, seconds), a datetime.timedelta, or a dict containing key-value pairs compatible with datetime.timedelta.

  • interval (str) – Determines whether the interruption happens at the epoch level or mid-epoch. Can be either "epoch" or "step".

  • verbose (bool) – Set this to False to suppress logging messages.

Raises

MisconfigurationException – If interval is not one of the supported choices.

Example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Timer

# stop training after 12 hours
timer = Timer(duration="00:12:00:00")

# or provide a datetime.timedelta
from datetime import timedelta
timer = Timer(duration=timedelta(weeks=1))

# or provide a dictionary
timer = Timer(duration=dict(weeks=4, days=2))

# force training to stop after given time limit
trainer = Trainer(callbacks=[timer])

# query training/validation/test time (in seconds)
timer.time_elapsed("train")
timer.start_time("validate")
timer.end_time("test")
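
For coarser-grained checks, the time limit can be evaluated only at epoch boundaries rather than after every step. A minimal sketch, assuming an illustrative six-hour budget and epoch cap:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Timer

# check the time limit only at the end of each epoch, not mid-epoch
timer = Timer(duration="00:06:00:00", interval="epoch")
trainer = Trainer(callbacks=[timer], max_epochs=100)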
end_time(stage=RunningStage.TRAINING)[source]

Return the end time of a particular stage (in seconds).

Return type

Optional[float]

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload callback state given the callback's state_dict.

Parameters

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type

None
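
In normal use, the Trainer saves and restores this state automatically through checkpoints. As a minimal manual sketch of how state_dict() and load_state_dict() pair up (the duration value is illustrative):

from lightning.pytorch.callbacks import Timer

timer = Timer(duration="00:12:00:00")
# ... after (part of) a run, capture the callback state,
# which records the time elapsed so far
state = timer.state_dict()

# a fresh Timer, e.g. in a new process, can resume from that state
resumed = Timer(duration="00:12:00:00")
resumed.load_state_dict(state)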

on_fit_start(trainer, *args, **kwargs)[source]

Called when fit begins.

Return type

None

on_test_end(trainer, pl_module)[source]

Called when the test ends.

Return type

None

on_test_start(trainer, pl_module)[source]

Called when the test begins.

Return type

None

on_train_batch_end(trainer, *args, **kwargs)[source]

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized w.r.t. accumulate_grad_batches.

Return type

None

on_train_end(trainer, pl_module)[source]

Called when the train ends.

Return type

None

on_train_epoch_end(trainer, *args, **kwargs)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type

None

on_train_start(trainer, pl_module)[source]

Called when the train begins.

Return type

None

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

Return type

None

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

Return type

None

start_time(stage=RunningStage.TRAINING)[source]

Return the start time of a particular stage (in seconds).

Return type

Optional[float]

state_dict()[source]

Called when saving a checkpoint. Implement to generate the callback's state_dict.

Return type

Dict[str, Any]

Returns

A dictionary containing callback state.

time_elapsed(stage=RunningStage.TRAINING)[source]

Return the time elapsed for a particular stage (in seconds).

Return type

float

time_remaining(stage=RunningStage.TRAINING)[source]

Return the time remaining for a particular stage (in seconds).

Return type

Optional[float]
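
A minimal sketch of putting time_remaining() to use, with a hypothetical helper callback (TimeBudgetLogger is not part of Lightning) that reports the remaining training budget after each epoch:

import lightning as L
from lightning.pytorch.callbacks import Timer


class TimeBudgetLogger(L.Callback):
    # hypothetical helper: prints the Timer's remaining training budget
    def __init__(self, timer):
        self.timer = timer

    def on_train_epoch_end(self, trainer, pl_module):
        remaining = self.timer.time_remaining("train")  # None if no duration was set
        if remaining is not None:
            print(f"~{remaining:.0f}s left in the training time budget")


timer = Timer(duration="00:02:00:00")
trainer = L.Trainer(callbacks=[timer, TimeBudgetLogger(timer)])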