Focal Loss
TensorFlow implementation of focal loss: a loss function that generalizes binary and multiclass cross-entropy loss by down-weighting well-classified examples, so that training concentrates on hard-to-classify ones.
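For a single binary example, the standard focal loss multiplies the cross-entropy term -log(p_t) by a modulating factor (1 - p_t)^γ, where p_t is the predicted probability of the true class and γ ≥ 0 is the focusing parameter. The pure-Python sketch below only illustrates that formula. It is not the package's TensorFlow implementation, and the helper names `binary_focal_loss_sketch` and `binary_cross_entropy` are hypothetical. With γ = 0 the loss reduces to ordinary binary cross-entropy. With γ = 2, a confident correct prediction (p_t = 0.9) has its loss scaled by 0.01, while a confident wrong one (p_t = 0.1) keeps 81% of its loss.

```python
import math


def binary_focal_loss_sketch(y_true: int, p: float, gamma: float = 2.0) -> float:
    """Focal loss for one example with label y_true in {0, 1} and
    predicted probability p = P(y = 1). Hypothetical illustrative helper."""
    # p_t is the probability the model assigns to the true class.
    p_t = p if y_true == 1 else 1.0 - p
    # The modulating factor (1 - p_t) ** gamma shrinks the loss of easy examples.
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def binary_cross_entropy(y_true: int, p: float) -> float:
    """Plain binary cross-entropy for comparison (hypothetical helper)."""
    p_t = p if y_true == 1 else 1.0 - p
    return -math.log(p_t)


# An easy example (confident and correct) and a hard one (confident and wrong).
easy = (1, 0.9)
hard = (1, 0.1)
for gamma in (0.0, 2.0):
    print(gamma,
          round(binary_focal_loss_sketch(*easy, gamma), 4),
          round(binary_focal_loss_sketch(*hard, gamma), 4))
# 0.0 0.1054 2.3026
# 2.0 0.0011 1.8651
```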
The focal_loss package provides functions and classes that can be used as off-the-shelf replacements for tf.keras.losses functions and classes, respectively.
```python
# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
```
The focal_loss package includes loss functions as well as the following wrapper classes:

- BinaryFocalLoss (use like tf.keras.losses.BinaryCrossentropy)
- SparseCategoricalFocalLoss (use like tf.keras.losses.SparseCategoricalCrossentropy)
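The sparse categorical variant follows the same pattern for integer class labels: p_t is the predicted probability of the labeled class, and the per-example loss is -(1 - p_t)^γ log(p_t). As with tf.keras.losses.SparseCategoricalCrossentropy, the label is an integer index rather than a one-hot vector. The pure-Python sketch below illustrates the formula for a single example, assuming the model outputs normalized class probabilities rather than logits. The helper names are hypothetical and are not part of the package's API.

```python
import math
from typing import Sequence


def sparse_categorical_focal_loss_sketch(
    y_true: int, probs: Sequence[float], gamma: float = 2.0
) -> float:
    """Focal loss for one example whose integer label y_true indexes into
    the predicted class probabilities `probs` (assumed to sum to 1).
    Hypothetical illustrative helper."""
    p_t = probs[y_true]  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def sparse_categorical_cross_entropy(y_true: int, probs: Sequence[float]) -> float:
    """Plain sparse categorical cross-entropy for comparison (hypothetical helper)."""
    return -math.log(probs[y_true])


probs = [0.7, 0.2, 0.1]
# With gamma = 0 the focal loss equals sparse categorical cross-entropy.
print(sparse_categorical_focal_loss_sketch(0, probs, gamma=0.0)
      == sparse_categorical_cross_entropy(0, probs))  # True
# With gamma = 2, a poorly predicted class (p_t = 0.1) dominates the loss.
print(round(sparse_categorical_focal_loss_sketch(0, probs), 4))  # 0.0321
print(round(sparse_categorical_focal_loss_sketch(2, probs), 4))  # 1.8651
```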