대부분의 상황에서 분류기를 학습할 때, 객체 클래스에 비해 매우 많은 배경 클래스가 존재하는 클래스 불균형(Class Imbalance) 문제를 만나게 된다. 큰 차이의 클래스 불균형은 학습시 교차 엔트로피 손실(Cross Entropy Loss)에 영향을 줘 분류기가 쉽게 다수 클래스를 선택하게 만든다. Focal Loss는 교차 엔트로피 손실 함수를 다수 클래스의 가중치를 줄이도록 재구성하여 훈련 중 소수 클래스에 더 집중하게 만든다.
$$ FL(p_{t})=-(1-p_{t})^{\gamma}\log(p_{t}) $$
여기서 pt 는 신경망에서 softmax activation을 통과 후 예측된 확률값이고, 조정가능한 γ≥0 는 focusing 파라미터이다. Focal Loss는 기존 교차 엔트로피 손실에 Modulating factor (1-pt)^γ 를 추가하여 교차 엔트로피에 대한 가중치를 조정한다.
신경망이 잘못 분류하여 pt가 작으면, modulating factor는 1에 가까워지고 Focal Loss는 영향을 받지 않는다. 그러나 신경망이 잘 분류하여 pt가 1에 가까워지면 modulating factor는 0에 가까워지고 잘 분류한 클래스에 대한 손실의 가중치는 줄어들게 된다. 한편, Focusing parameter 는 다수 클래스의 가중치를 줄이는 비율을 부드럽게 조정한다. γ=0이면 FL는 교차 엔트로피 손실과 동일하고, γ가 증가하면 modulating factor에 대한 영향도 마찬가지로 증가하게 된다. 결론적으로 moduling factor는 쉬운 샘플에 대한 손실 기여도를 줄이게 된다.
pytorch 이용한 Focal Loss 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, gamma=0, alpha=None, size_average=True, device='cpu'):
super(FocalLoss, self).__init__()
"""
gamma(int) : focusing parameter.
alpha(list) : alpha-balanced term.
size_average(bool) : whether to apply reduction to the output.
"""
self.gamma = gamma
self.alpha = alpha
self.size_average = size_average
self.device = device
def forward(self, input, target):
# input : N * C (btach_size, num_class)
# target : N (batch_size)
CE = F.cross_entropy(input, target, reduction='none') # -log(pt)
pt = torch.exp(-CE) # pt
loss = (1 - pt) ** self.gamma * CE # -(1-pt)^rlog(pt)
if self.alpha is not None:
alpha = torch.tensor(self.alpha, dtype=torch.float).to(self.device)
# in case that a minority class is not selected when mini-batch sampling
if len(self.alpha) != len(torch.unique(target)):
temp = torch.zeros(len(self.alpha)).to(self.device)
temp[torch.unique(target)] = alpha.index_select(0, torch.unique(target))
alpha_t = temp.gather(0, target)
loss = alpha_t * loss
else:
alpha_t = alpha.gather(0, target)
loss = alpha_t * loss
if self.size_average:
loss = torch.mean(loss)
return loss
'PyTorch' 카테고리의 다른 글
[PyTorch] nn.LSTM input, output shape (0) | 2020.06.30 |
---|---|
[PyTorch] Training 헬퍼 함수 만들기 (0) | 2020.06.17 |
[PyTorch] last fully connected layer에 regularization 추가하기 (0) | 2020.06.17 |
[PyTorch] PyTorch 1.5.0 설치하기 (0) | 2020.06.16 |