Binary Cross-Entropy Loss

Introduction

Binary Cross-Entropy (BCE) loss is a cornerstone of binary classification in machine learning. However, a naive implementation, which applies the sigmoid to the logits and then takes logarithms, becomes numerically unstable when the logits have large magnitude. This post walks through a numerically stable BCE loss function in PyTorch that stays robust during model training.

Mathematical Background

The standard BCE loss for a single sample is defined as:

$$\text{BCE}(z, y) = - \left[ y \cdot \log(p) + (1 - y) \cdot \log(1 - p) \right]$$

where \( p = \sigma(z) = \frac{1}{1 + e^{-z}} \) is the sigmoid of the logit \( z \), and \( y \) is the true label (0 or 1).
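To see why a literal implementation of this definition is fragile, here is a minimal sketch, assuming PyTorch with default float32 tensors. `naive_bce` is a hypothetical helper for illustration, not part of the post's code: it computes \( p \) first and then takes logarithms.

```python
import torch


def naive_bce(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Literal translation of the definition: compute p = sigmoid(z), then take logs.
    p = torch.sigmoid(logits)
    return -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p))


# Moderate logit: the naive formula works (loss = log(1 + e^-2), about 0.1269).
print(naive_bce(torch.tensor([2.0]), torch.tensor([1.0])))

# Large logit in float32: sigmoid(100.) rounds to exactly 1.0, so 1 - p == 0.
#   y = 0: the (1 - y) * log(0) term gives -inf, so the loss is inf (true value is about 100).
#   y = 1: the (1 - y) * log(0) term is 0 * -inf = nan, so the loss is nan (true value is about 0).
print(naive_bce(torch.tensor([100.0, 100.0]), torch.tensor([0.0, 1.0])))
```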

Numerical Stability

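Computing \( p \) first and then taking logarithms breaks down for large \( |z| \). In floating point, \( \sigma(z) \) rounds to exactly 1 for large positive \( z \), so \( \log(1 - p) = \log(0) = -\infty \); for large negative \( z \), \( e^{-z} \) overflows, \( p \) becomes 0, and \( \log(p) = -\infty \). A numerically stable alternative is: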

$$\text{BCE}(z, y) = \max(z, 0) - y \cdot z + \log(1 + e^{-|z|})$$

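This formulation never computes \( \sigma(z) \) explicitly. The exponent \( -|z| \) is never positive, so \( e^{-|z|} \le 1 \) cannot overflow (underflowing to 0 for huge \( |z| \) is harmless), and the logarithm's argument \( 1 + e^{-|z|} \) is always at least 1, so \( \log(0) \) never occurs.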
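One way to see where the stable form comes from: substitute \( \log \sigma(z) = -\log(1 + e^{-z}) \) and \( \log(1 - \sigma(z)) = -z - \log(1 + e^{-z}) \) into the definition and collect terms:

$$
\begin{aligned}
\text{BCE}(z, y) &= y \log(1 + e^{-z}) + (1 - y)\left[ z + \log(1 + e^{-z}) \right] \\
&= (1 - y)\, z + \log(1 + e^{-z})
\end{aligned}
$$

For \( z \ge 0 \) this is already the stable formula, since \( (1 - y)\, z = \max(z, 0) - y \cdot z \) and \( e^{-z} = e^{-|z|} \). For \( z < 0 \), use \( \log(1 + e^{-z}) = -z + \log(1 + e^{z}) = -z + \log(1 + e^{-|z|}) \); the \( z \) terms collapse to \( -y \cdot z \), which again matches because \( \max(z, 0) = 0 \). Differentiating \( (1 - y)\, z + \log(1 + e^{-z}) \) gives the gradient used during backpropagation:

$$\frac{\partial\, \text{BCE}(z, y)}{\partial z} = \sigma(z) - y$$

Its magnitude is below 1 for every \( z \), so the stable loss also produces bounded gradients.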

Code

```python
import torch


def bce_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable binary cross-entropy loss.
    :param logits: Raw model outputs (logits), shape (batch_size,) or (batch_size, 1)
    :param targets: Ground truth labels (0 or 1), same shape as logits
    :return: Mean loss over the batch
    """
    # Squeeze BOTH tensors so their shapes stay aligned; squeezing only the logits
    # would broadcast (batch_size,) * (batch_size, 1) into a (batch_size, batch_size) matrix.
    logits = logits.squeeze()
    targets = targets.squeeze().to(logits.dtype)
    # Stable BCE: max(z, 0) - z * y + log(1 + exp(-|z|))
    loss = (
        torch.maximum(logits, torch.zeros_like(logits))
        - logits * targets
        + torch.log1p(torch.exp(-torch.abs(logits)))
    )
    return loss.mean()
```
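As a sanity check, here is a minimal sketch (assuming a recent PyTorch) that compares the same loss written through softplus against PyTorch's built-in `F.binary_cross_entropy_with_logits`, which also takes raw logits. The rewrite relies on the identity \( \max(z, 0) + \log(1 + e^{-|z|}) = \log(1 + e^{z}) \), so the per-sample loss is softplus(z) - y·z. The tensor values are arbitrary illustrations.

```python
import torch
import torch.nn.functional as F

# Logits spanning moderate to extreme magnitudes, with binary targets.
logits = torch.tensor([-100.0, -3.0, 0.5, 4.0, 100.0])
targets = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0])

# Equivalent rewrite: max(z, 0) + log(1 + exp(-|z|)) == log(1 + exp(z)) == softplus(z),
# so the per-sample loss is softplus(z) - y * z, averaged over the batch.
softplus_form = (F.softplus(logits) - targets * logits).mean()

# PyTorch's built-in BCE on raw logits (reduction="mean" by default).
builtin = F.binary_cross_entropy_with_logits(logits, targets)

# Both values are finite and agree up to float32 rounding.
print(softplus_form, builtin)
```

Calling `bce_loss(logits, targets)` from the listing above on the same tensors should agree as well. For training code, `torch.nn.BCEWithLogitsLoss` wraps the same computation as a module and adds options such as `pos_weight` for imbalanced classes.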

A few notes on the implementation and how it keeps training numerically stable:

  • Input Handling: squeeze() removes singleton dimensions, so logits of shape (batch_size, 1) become a 1D tensor of shape (batch_size,).

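  • Stable Computation: torch.maximum supplies the max(z, 0) term, torch.exp(-torch.abs(logits)) never exceeds 1 and so cannot overflow, and torch.log1p computes log(1 + x) accurately even when x is tiny, which keeps the last term precise for large |z|.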

  • Batch Averaging: The function returns the mean loss over the batch, a scalar ready for backpropagation. This matters most for models whose logits vary widely in magnitude, where the stable formula keeps both the loss and its gradients finite.

  • Applications: Binary classification tasks, including emotion detection, fraud detection, and medical diagnosis.

  • Performance Considerations: Numerical stability does not solve optimization problems; batch size and the distribution of the inputs (for example, class imbalance) can still significantly affect training dynamics.
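Because the docstring allows inputs of shape (batch_size, 1), it is worth checking that logits and targets are squeezed the same way: if only the logits are squeezed, PyTorch broadcasting silently turns the elementwise product into a matrix instead of raising an error. A minimal sketch with random tensors (the batch size of 4 is an arbitrary assumption):

```python
import torch

logits = torch.randn(4, 1)                     # shape (batch_size, 1)
targets = torch.randint(0, 2, (4, 1)).float()  # same shape as logits

# Squeezing only the logits leaves targets with shape (4, 1):
# (4,) * (4, 1) broadcasts to (4, 4), i.e. 16 cross terms instead of 4 per-sample products.
mismatched = logits.squeeze() * targets
print(mismatched.shape)

# Squeezing both tensors the same way keeps exactly one term per sample.
matched = logits.squeeze() * targets.squeeze()
print(matched.shape)
```

Taking the mean of the mismatched (4, 4) tensor still returns a scalar, which is why this bug is easy to miss.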

Siddanth Emani
Author
Data Scientist with 4+ years of experience
