Introduction#
Mathematical Background#
The standard BCE loss for a single sample is defined as:
$$\text{BCE}(z, y) = - \left[ y \cdot \log(p) + (1 - y) \cdot \log(1 - p) \right]$$
where \( p = \sigma(z) = \frac{1}{1 + e^{-z}} \) is the sigmoid of the logit \( z \), and \( y \) is the true label (0 or 1).
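For reference, a minimal, direct translation of this definition into plain Python is sketched below (the name naive_bce is illustrative, not from any library); this is the straightforward version whose numerical limits are discussed next.

import math

def naive_bce(z: float, y: float) -> float:
    # Direct, non-stable translation of the definition above (illustration only)
    p = 1.0 / (1.0 + math.exp(-z))  # p = sigmoid(z)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

print(naive_bce(2.0, 1.0))  # expected: about 0.1269 for a moderate logit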
Numerical Stability#
Direct computation of \( p \) can overflow (via \( e^{-z} \) for very negative \( z \)) or saturate to exactly 0 or 1 for large \( |z| \), sending \( \log(p) \) or \( \log(1 - p) \) to \( -\infty \). A numerically stable alternative is:
$$\text{BCE}(z, y) = \max(z, 0) - y \cdot z + \log(1 + e^{-|z|})$$
This formulation avoids computing \( \sigma(z) \) directly, mitigating numerical issues.
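For completeness, the stable form can be derived from the definition using the identities \( -\log \sigma(z) = \log(1 + e^{-z}) \) and \( -\log(1 - \sigma(z)) = z + \log(1 + e^{-z}) \):

$$\begin{aligned} \text{BCE}(z, y) &= y \cdot \log\left(1 + e^{-z}\right) + (1 - y) \left( z + \log\left(1 + e^{-z}\right) \right) \\ &= z - y \cdot z + \log\left(1 + e^{-z}\right) \\ &= \max(z, 0) - y \cdot z + \log\left(1 + e^{-|z|}\right) \end{aligned}$$

The last step splits the cases \( z \ge 0 \) and \( z < 0 \) so that the exponent inside the logarithm is never positive.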
Code#
import torch


def bce_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable binary cross-entropy loss.

    :param logits: Raw model outputs (logits), shape (batch_size,) or (batch_size, 1)
    :param targets: Ground truth labels (0 or 1), same shape as logits
    :return: Mean loss over the batch
    """
    # Flatten singleton dimensions so the element-wise product broadcasts correctly
    logits = logits.squeeze()
    targets = targets.squeeze()
    # Stable BCE: max(logits, 0) - logits * targets + log(1 + exp(-abs(logits)))
    loss = (
        torch.maximum(logits, torch.zeros_like(logits))
        - logits * targets
        + torch.log1p(torch.exp(-torch.abs(logits)))
    )
    return loss.mean()
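As a quick sanity check, the function can be compared against PyTorch's built-in torch.nn.functional.binary_cross_entropy_with_logits, which uses the same stable formulation; the comment below states the expected result rather than captured output.

import torch
import torch.nn.functional as F

logits = torch.randn(8)                       # random logits of moderate magnitude
targets = torch.randint(0, 2, (8,)).float()   # random 0/1 labels

print(torch.allclose(bce_loss(logits, targets),
                     F.binary_cross_entropy_with_logits(logits, targets)))  # expected: True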
The implementation above highlights several points about handling numerical stability during training:

- Input Handling: The squeeze() calls ensure logits and targets are 1D tensors, accommodating varying input shapes.
- Stable Computation: The formula leverages torch.maximum and torch.log1p (an accurate log(1 + x) for small x) to prevent overflow and underflow; the sketch after this list contrasts it with the naive computation.
- Batch Averaging: The mean loss over the batch is returned, suitable for optimization. This makes the implementation reliable for training models where logits may vary widely in magnitude.
- Applications: Binary classification tasks, including emotion detection, fraud detection, and medical diagnosis.
- Performance Considerations: While this implementation is numerically stable, batch size and input distribution can still significantly affect training dynamics.
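To make the stability claim concrete, here is a minimal sketch contrasting bce_loss with a direct sigmoid-then-log computation on extreme logits; the helper direct_bce is illustrative, and the comments state expected results rather than recorded output.

import torch

def direct_bce(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Naive version: for |z| this large, sigmoid saturates to exactly 0 or 1
    # in float32, so one of the log() terms becomes -inf.
    p = torch.sigmoid(logits)
    return -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

logits = torch.tensor([1000.0, -1000.0])
targets = torch.tensor([0.0, 1.0])

print(direct_bce(logits, targets))  # expected: tensor(inf)
print(bce_loss(logits, targets))    # expected: tensor(1000.), finite and exact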
