Post

Focal loss에 관하여

Focal loss에 관하여

🟣 ‘Focal Loss’ for Semantic Segmentation

⚪ Focal Loss란?

딥러닝에서 가장 기본적인 loss 함수 중 하나는 Cross Entropy Loss다. 하지만 다음과 같은 상황에서는 성능이 떨어질 수 있다:

  • Class imbalance가 심한 경우 (예: 배경 vs. 소수 클래스)
  • 모델이 쉬운 예시(easy examples)에 너무 집중하고, 어려운 예(hard examples)를 무시할 때

이럴 때 등장하는 게 바로 Focal Loss이다.

Focal Loss는 어려운 예제(hard examples)에 더 집중하도록 설계된 Cross Entropy의 확장판.


⚪ 수식 정리

기본적인 Cross Entropy Loss는 다음과 같다:

\[\text{CE}(p_t) = -\log(p_t)\]

여기서:
$p_t$는 정답 클래스에 대한 예측 확률 (0~1)

Focal Loss는 여기에 조절 계수 modulating factor를 곱해서, 쉬운 예제의 loss를 줄이고, 어려운 예제의 loss는 유지한다:

\(\text{FL}(p_t) = -(1 - p_t)^{\gamma} \cdot \log(p_t)\) Focal

여기서:

  • $\gamma$는 focusing parameter (보통 2 사용)
  • $(1 - p_t)^\gamma$는 예측이 쉬운 경우(= $p_t$가 클수록) loss를 줄이는 역할

⚪ 해석 예시

상황$ p_t $Modulating Factor $ (1 - p_t)^\gamma $결과
예측이 매우 정확함0.95$ (1 - 0.95)^2 = 0.0025 $거의 무시됨
예측이 매우 틀림0.1$(1 - 0.1)^2 = 0.81 $강하게 페널티 부여

즉, 쉬운 샘플은 무시하고, 어려운 샘플에 집중한다!


⚪ PyTorch 코드 예시

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none')  # pixel별 loss 계산

    def forward(self, logits, targets):
        ce_loss = self.ce(logits, targets)  # shape: (B, H, W)
        pt = torch.exp(-ce_loss)            # pt = softmax score of true class
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()

⚪ 언제 사용하면 좋을까?

  • Semantic segmentation에서 foreground 클래스가 매우 적을 때
  • Object detection (ex: RetinaNet)에서도 Focal Loss가 공식적으로 사용됨
  • 클래스 불균형이 심할수록 효과가 크다

🟣 정리

항목내용
이름Focal Loss
공식$FL(p_t) = -(1 - p_t)^\gamma \log(p_t)$
하이퍼파라미터$\gamma \in [0, 5]$ (보통 2)
장점어려운 샘플에 집중, 클래스 불균형 완화
단점일반적인 문제에서는 오히려 과적합을 유발할 수도 있음

Reference

This post is licensed under CC BY 4.0 by the author.

© 2025 Soohyun Jeon ⭐

🌱 Mostly to remember, sometimes to understand.