TQDMProgressBar

class lightning.pytorch.callbacks.TQDMProgressBar(refresh_rate=1, process_position=0)[source]

Bases: lightning.pytorch.callbacks.progress.progress_bar.ProgressBar

This is the default progress bar used by Lightning. It prints to stdout using the tqdm package and shows up to four different bars:

  • sanity check progress: the progress during the sanity check run

  • train progress: shows the training progress. It will pause if validation starts and will resume when it ends, and also accounts for multiple validation runs during training when val_check_interval is used.

  • validation progress: only visible during validation; shows total progress over all validation datasets.

  • test progress: only active when testing; shows total progress over all test datasets.

For infinite datasets, the progress bar never ends.

If you want to customize the default tqdm progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer.

Example

>>> from lightning.pytorch.callbacks import TQDMProgressBar
>>> class LitProgressBar(TQDMProgressBar):
...     def init_validation_tqdm(self):
...         bar = super().init_validation_tqdm()
...         bar.set_description('running validation ...')
...         return bar
...
>>> bar = LitProgressBar()
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[bar])

Parameters
  • refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display.

  • process_position (int) – Set this to a value greater than 0 to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together. This corresponds to process_position in the Trainer.
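
For illustration, a minimal sketch of passing these options to the Trainer (the refresh_rate value of 10 is an arbitrary choice, not a recommendation):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

# update the bars every 10 batches instead of every batch;
# refresh_rate=0 would disable the display entirely
trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])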

disable()[source]

You should provide a way to disable the progress bar.

Return type

None

enable()[source]

You should provide a way to enable the progress bar.

The Trainer will call this in e.g. pre-training routines like the learning rate finder to temporarily enable and disable the training progress bar.

Return type

None
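
As a hedged sketch of how enable() and disable() pair up (the surrounding routine is hypothetical; only the two methods documented here are assumed):

from lightning.pytorch.callbacks import TQDMProgressBar

bar = TQDMProgressBar()

# silence the bar around output-heavy work, then restore it
bar.disable()
...  # e.g. a custom routine that prints its own diagnostics
bar.enable()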

init_predict_tqdm()[source]

Override this to customize the tqdm bar for predicting.

Return type

Tqdm

init_sanity_tqdm()[source]

Override this to customize the tqdm bar for the validation sanity run.

Return type

Tqdm

init_test_tqdm()[source]

Override this to customize the tqdm bar for testing.

Return type

Tqdm

init_train_tqdm()[source]

Override this to customize the tqdm bar for training.

Return type

Tqdm
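
In the same spirit as the validation example above, a sketch that relabels the training bar (set_description is a standard tqdm method):

from lightning.pytorch.callbacks import TQDMProgressBar


class LitProgressBar(TQDMProgressBar):
    def init_train_tqdm(self):
        # reuse the default bar and only change its label
        bar = super().init_train_tqdm()
        bar.set_description('training')
        return bar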

init_validation_tqdm()[source]

Override this to customize the tqdm bar for validation.

Return type

Tqdm

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the predict batch ends.

Return type

None

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]

Called when the predict batch begins.

Return type

None

on_predict_end(trainer, pl_module)[source]

Called when predict ends.

Return type

None

on_predict_start(trainer, pl_module)[source]

Called when predict begins.

Return type

None

on_sanity_check_end(*_)[source]

Called when the validation sanity check ends.

Return type

None

on_sanity_check_start(*_)[source]

Called when the validation sanity check starts.

Return type

None

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the test batch ends.

Return type

None

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]

Called when the test batch begins.

Return type

None

on_test_end(trainer, pl_module)[source]

Called when the test ends.

Return type

None

on_test_start(trainer, pl_module)[source]

Called when the test begins.

Return type

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Called when the train batch ends.

Note

The value of outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type

None
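
To make the normalization concrete, a hedged sketch of a separate callback that undoes it (the LossInspector name is hypothetical, and trainer.accumulate_grad_batches is assumed to expose the configured value):

import lightning as L


class LossInspector(L.Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # outputs["loss"] has already been divided by accumulate_grad_batches,
        # so multiplying recovers the raw loss returned by training_step
        self.last_raw_loss = outputs["loss"] * trainer.accumulate_grad_batches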

on_train_end(*_)[source]

Called when training ends.

Return type

None

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type

None

on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type

None

on_train_start(*_)[source]

Called when training begins.

Return type

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the validation batch ends.

Return type

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]

Called when the validation batch begins.

Return type

None

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

Return type

None

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

Return type

None

print(*args, sep=' ', **kwargs)[source]

You should provide a way to print without breaking the progress bar.

Return type

None
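
For example, a sketch of calling print from a subclass so a message lands cleanly between bar refreshes (the epoch announcement is purely illustrative):

from lightning.pytorch.callbacks import TQDMProgressBar


class VerboseProgressBar(TQDMProgressBar):
    def on_train_epoch_start(self, trainer, *_):
        super().on_train_epoch_start(trainer, *_)
        # self.print writes without corrupting the open progress bar
        self.print(f'starting epoch {trainer.current_epoch}')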