Shortcuts

PyTorch를 Lightning으로 구성하기

아래와 같이 PyTorch를 Lightning(라이트닝)으로 구성할 수 있습니다.


1. 연산 코드 가져오기

일반적인 nn.Module 구조를 가져옵니다

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

2. 학습 로직 구성하기

LightningModule의 training_step에 학습 데이터를 묶음(batch)으로 가져와 학습하는 과정을 구성합니다:

class LitModel(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

참고

기존 프로젝트가 복잡해서 기존의 학습 루프를 직접 구성해야 하면 Own your loop 를 참조하세요.


3. Move Optimizer(s) and LR Scheduler(s)

옵티마이저(들)를 configure_optimizers() 훅(hook)으로 이동합니다.

class LitModel(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

4. (선택사항) 검증 로직 구성하기

검증(validation) 루프가 필요하면, 검증 데이터를 묶음(batch)으로 가져와 검증하는 과정을 구성합니다:

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", val_loss)

학습(fit) 중 체크포인트 기능이 켜진 경우 trainer.validate() 가 자동으로 최적의 체크포인트를 불러옵니다.


5. (선택사항) 테스트 로직 구성하기

테스트(test) 루프가 필요하면, 테스트 데이터를 묶음(batch)으로 가져와 테스트하는 과정을 구성합니다:

class LitModel(pl.LightningModule):
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", test_loss)

6. (선택사항) 예측 로직 구성하기

예측(prediction) 루프가 필요하면, 테스트 데이터를 묶음(batch)으로 가져와 예측하는 과정을 구성합니다:

class LitModel(LightningModule):
    def predict_step(self, batch, batch_idx):
        x, y = batch
        pred = self.encoder(x)
        return pred

7. .cuda() 또는 .to(device) 호출 제거하기

LightningModule.__init__ 내에서 초기화된 Module 인스턴스들과 DataLoader 에서 가져온 데이터는 Lightning이 자동으로 해당 장치로 이동해서 실행하므로, 기존에 명시적으로 .cuda() 또는 .to(device) 을 호출하는 부분은 제거해도 됩니다.

그럼에도 장치(device)에 직접 접근해야 할 필요가 있다면, LightningModule 내부에서 (__init__setup 메소드를 제외하고) 아무데서나 self.device 를 사용하면 됩니다.

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        z = torch.randn(4, 5, device=self.device)
        ...

Hint: LightningModule.__init__ 메소드 내에서 Tensor 를 초기화하면서 자동으로 장치(device)로 이동하려면 register_buffer() 를 호출하여 매개변수로 등록하면 됩니다.

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(num_features))

8. 기존 데이터 사용하기

일반적인 PyTorch DataLoader는 Lightning에서 동작합니다. 더 모듈화되고 확장 가능한 데이터셋들은 LightningDataModule 를 참고하세요. —-

더 알아두기

추가로, validate() 메소드를 사용하면 검증(validation) 루프만 실행할 수 있습니다.

model = LitModel()
trainer.validate(model)

참고

model.eval()torch.no_grad() 는 검증 시에 자동으로 호출됩니다.

테스트 루프(test loop)는 fit() 에서 사용되지 않으므로, 필요 시 명시적으로 test() 을 호출해야 합니다.

model = LitModel()
trainer.test(model)

참고

model.eval()torch.no_grad() 는 테스트 시에 자동으로 호출됩니다.

체크포인트 기능이 켜진 경우, trainer.test() 는 자동으로 최적의 체크포인트(best checkpoint)를 불러옵니다.

예측 루프(prediction look)는 predict() 을 호출하기 전에는 사용되지 않습니다.

model = LitModel()
trainer.predict(model)

참고

model.eval()torch.no_grad() 는 예측 시에 자동으로 호출됩니다.

체크포인트 기능이 켜진 경우, trainer.predict() 는 자동으로 최적의 체크포인트를 불러옵니다.