PyTorch를 Lightning으로 구성하기¶
아래와 같이 PyTorch를 Lightning(라이트닝)으로 구성할 수 있습니다.
1. 연산 코드 가져오기¶
일반적인 nn.Module 구조를 가져옵니다
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class LitModel(nn.Module):
def __init__(self):
super().__init__()
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
return x
2. 학습 로직 구성하기¶
LightningModule의 training_step에 학습 데이터를 묶음(batch)으로 가져와 학습하는 과정을 구성합니다:
class LitModel(pl.LightningModule):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.encoder(x)
loss = F.cross_entropy(y_hat, y)
return loss
참고
기존 프로젝트가 복잡해서 기존의 학습 루프를 직접 구성해야 하면 Own your loop 를 참조하세요.
3. Move Optimizer(s) and LR Scheduler(s)¶
옵티마이저(들)를 configure_optimizers()
훅(hook)으로 이동합니다.
class LitModel(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
4. (선택사항) 검증 로직 구성하기¶
검증(validation) 루프가 필요하면, 검증 데이터를 묶음(batch)으로 가져와 검증하는 과정을 구성합니다:
class LitModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.encoder(x)
val_loss = F.cross_entropy(y_hat, y)
self.log("val_loss", val_loss)
팁
학습(fit) 중 체크포인트 기능이 켜진 경우 trainer.validate()
가 자동으로 최적의 체크포인트를 불러옵니다.
5. (선택사항) 테스트 로직 구성하기¶
테스트(test) 루프가 필요하면, 테스트 데이터를 묶음(batch)으로 가져와 테스트하는 과정을 구성합니다:
class LitModel(pl.LightningModule):
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self.encoder(x)
test_loss = F.cross_entropy(y_hat, y)
self.log("test_loss", test_loss)
6. (선택사항) 예측 로직 구성하기¶
예측(prediction) 루프가 필요하면, 테스트 데이터를 묶음(batch)으로 가져와 예측하는 과정을 구성합니다:
class LitModel(LightningModule):
def predict_step(self, batch, batch_idx):
x, y = batch
pred = self.encoder(x)
return pred
7. .cuda() 또는 .to(device) 호출 제거하기¶
LightningModule.__init__
내에서 초기화된 Module
인스턴스들과 DataLoader
에서 가져온 데이터는
Lightning이 자동으로 해당 장치로 이동해서 실행하므로, 기존에 명시적으로 .cuda()
또는 .to(device)
을 호출하는 부분은 제거해도 됩니다.
그럼에도 장치(device)에 직접 접근해야 할 필요가 있다면, LightningModule
내부에서 (__init__
과 setup
메소드를 제외하고) 아무데서나
self.device
를 사용하면 됩니다.
class LitModel(LightningModule):
def training_step(self, batch, batch_idx):
z = torch.randn(4, 5, device=self.device)
...
Hint: LightningModule.__init__
메소드 내에서 Tensor
를 초기화하면서 자동으로 장치(device)로 이동하려면
register_buffer()
를 호출하여 매개변수로 등록하면 됩니다.
class LitModel(LightningModule):
def __init__(self):
super().__init__()
self.register_buffer("running_mean", torch.zeros(num_features))
8. 기존 데이터 사용하기¶
일반적인 PyTorch DataLoader는 Lightning에서 동작합니다. 더 모듈화되고 확장 가능한 데이터셋들은 LightningDataModule 를 참고하세요. —-
더 알아두기¶
추가로, validate()
메소드를 사용하면 검증(validation) 루프만 실행할 수 있습니다.
model = LitModel()
trainer.validate(model)
참고
model.eval()
와 torch.no_grad()
는 검증 시에 자동으로 호출됩니다.
테스트 루프(test loop)는 fit()
에서 사용되지 않으므로, 필요 시 명시적으로
test()
을 호출해야 합니다.
model = LitModel()
trainer.test(model)
참고
model.eval()
와 torch.no_grad()
는 테스트 시에 자동으로 호출됩니다.
팁
체크포인트 기능이 켜진 경우, trainer.test()
는 자동으로 최적의 체크포인트(best checkpoint)를 불러옵니다.
예측 루프(prediction look)는 predict()
을 호출하기 전에는 사용되지 않습니다.
model = LitModel()
trainer.predict(model)
참고
model.eval()
과 torch.no_grad()
는 예측 시에 자동으로 호출됩니다.
팁
체크포인트 기능이 켜진 경우, trainer.predict()
는 자동으로 최적의 체크포인트를 불러옵니다.