대부분의 상황에서 분류기를 학습할 때, 객체 클래스에 비해 매우 많은 배경 클래스가 존재하는 클래스 불균형(Class Imbalance) 문제를 만나게 된다. 큰 차이의 클래스 불균형은 학습시 교차 엔트로피 손실(Cross Entropy Loss)에 영향을 줘 분류기가 쉽게 다수 클래스를 선택하게 만든다. Focal Loss는 교차 엔트로피 손실 함수를 다수 클래스의 가중치를 줄이도록 재구성하여 훈련 중 소수 클래스에 더 집중하게 만든다.

 

$$ FL(p_{t})=-(1-p_{t})^{\gamma}\log(p_{t}) $$

 

 여기서 pt 는 신경망에서 softmax activation을 통과 후 예측된 확률값이고, 조정가능한 γ≥0 focusing 파라미터이다. Focal Loss는 기존 교차 엔트로피 손실에 Modulating factor (1-pt)^γ 를 추가하여 교차 엔트로피에 대한 가중치를 조정한다.

 

 신경망이 잘못 분류하여 pt가 작으면, modulating factor 1에 가까워지고 Focal Loss는 영향을 받지 않는다. 그러나 신경망이 잘 분류하여 pt1에 가까워지면 modulating factor0에 가까워지고 잘 분류한 클래스에 대한 손실의 가중치는 줄어들게 된다. 한편, Focusing parameter 는 다수 클래스의 가중치를 줄이는 비율을 부드럽게 조정한다. γ=0이면 FL는 교차 엔트로피 손실과 동일하고, γ가 증가하면 modulating factor에 대한 영향도 마찬가지로 증가하게 된다. 결론적으로 moduling factor는 쉬운 샘플에 대한 손실 기여도를 줄이게 된다.

 

pytorch 이용한 Focal Loss 구현

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True, device='cpu'):
        super(FocalLoss, self).__init__()
        """
        gamma(int) : focusing parameter.
        alpha(list) : alpha-balanced term.
        size_average(bool) : whether to apply reduction to the output.
        """
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average
        self.device = device

    def forward(self, input, target):
        # input : N * C (btach_size, num_class)
        # target : N (batch_size)

        CE = F.cross_entropy(input, target, reduction='none')  # -log(pt)
        pt = torch.exp(-CE)  # pt
        loss = (1 - pt) ** self.gamma * CE  # -(1-pt)^rlog(pt)

        if self.alpha is not None:
            alpha = torch.tensor(self.alpha, dtype=torch.float).to(self.device)
            # in case that a minority class is not selected when mini-batch sampling
            if len(self.alpha) != len(torch.unique(target)):
                temp = torch.zeros(len(self.alpha)).to(self.device)
                temp[torch.unique(target)] = alpha.index_select(0, torch.unique(target))
                alpha_t = temp.gather(0, target)
                loss = alpha_t * loss
            else:
                alpha_t = alpha.gather(0, target)
                loss = alpha_t * loss

        if self.size_average:
            loss = torch.mean(loss)

        return loss

gamma = 0이면 Focal Loss는 CrossEntropyLoss와 동일

 

gamma = 5일때 Focal Loss
alpha-balanced term을 적용한 Focal Loss

 

 

+ Recent posts