Breaking the Gradient: Supervised Learning with Non-Differentiable Loss Functions
Zian (Andy) Wang
Supervised learning is a fundamental machine learning task in which a model is trained to make predictions from labeled data. The process typically starts by feeding input data into the model to generate predictions. These predictions are then compared against the ground truth through a function that quantifies the prediction error, which we refer to as the loss (or objective) function. To improve the model's performance, we need to update its parameters: we take the gradient of the loss function with respect to the network's parameters, and stepping against that gradient "nudges" the model toward a lower loss, ideally a minimum (in practice usually a local rather than the global one). This cycle of prediction, evaluation, and update is repeated until the model converges. Note that this description applies mainly to deep learning approaches, since not all classical machine learning models follow the same training scheme.
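A minimal sketch of that predict, evaluate, update cycle, assuming a one-feature linear model trained with mean squared error by plain gradient descent. The toy data, learning rate, and step count are illustrative choices, not part of any particular setup.

```python
# Illustrative sketch: one-feature linear model y_hat = w * x + b,
# trained with mean squared error (MSE) and plain gradient descent.

def predict(w, b, xs):
    """Prediction step: apply the model to every input."""
    return [w * x + b for x in xs]

def mse_loss(preds, ys):
    """Evaluation step: mean squared error between predictions and labels."""
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)

def mse_grad(w, b, xs, ys):
    """Analytic gradient of the MSE loss with respect to (w, b)."""
    n = len(ys)
    preds = predict(w, b, xs)
    dw = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    db = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    return dw, db

def train(xs, ys, lr=0.05, steps=2000):
    w, b = 0.0, 0.0
    for _ in range(steps):
        dw, db = mse_grad(w, b, xs, ys)
        # Update step: move against the gradient, the direction of steepest local descent.
        w -= lr * dw
        b -= lr * db
    return w, b

# Toy labeled data generated from y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [2 * x + 1 for x in xs]
w, b = train(xs, ys)
print(w, b, mse_loss(predict(w, b, xs), ys))
```

Because this toy loss is convex, gradient descent recovers parameters very close to the generating values (w near 2, b near 1); on a deep network's non-convex loss, the same loop only promises a nearby minimum.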
Aside from the model, the next crucial component in the training process is the loss function. In supervised learning, the loss function plays the role of the "optimizee," while the optimizer, an algorithm specifying how parameter updates are executed, minimizes the loss function using its computed gradients. The gradient of the loss function can be seen as the tool that reveals the local slope of the loss landscape; without it, the model would be no better off than a blindfolded hiker lost in a vast mountain range, unable even to tell which way is up or down!
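The division of labor between optimizee and optimizer can be sketched in a few lines: the loss and its gradient are plain functions, and the optimizer only ever sees gradient values. `MomentumSGD` is a hypothetical minimal class implementing SGD with classical (heavy-ball) momentum; it is not any library's API, and the quadratic loss is an illustrative stand-in.

```python
class MomentumSGD:
    """Hypothetical minimal optimizer: SGD with classical (heavy-ball) momentum."""

    def __init__(self, lr=0.02, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.velocity = None

    def step(self, params, grads):
        """Return updated parameters; the optimizer never evaluates the loss itself."""
        if self.velocity is None:
            self.velocity = [0.0] * len(params)
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            # Accumulate a decaying running sum of past steps against the gradient.
            self.velocity[i] = self.momentum * self.velocity[i] - self.lr * g
            new_params.append(p + self.velocity[i])
        return new_params

# The "optimizee": an illustrative quadratic loss with its minimum at (3, -1).
def loss(params):
    x, y = params
    return (x - 3.0) ** 2 + 10.0 * (y + 1.0) ** 2

def grad(params):
    x, y = params
    return [2.0 * (x - 3.0), 20.0 * (y + 1.0)]

opt = MomentumSGD(lr=0.02, momentum=0.9)
params = [0.0, 0.0]
for _ in range(500):
    params = opt.step(params, grad(params))
print(params, loss(params))
```

Swapping in a different update rule only changes the `step` method; the loss and gradient functions stay the same, which is exactly why the gradient must exist for the optimizer to do anything.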
The chosen loss function must therefore be differentiable and accurately reflect the model's performance. However, in today's intricate real-world settings, not every problem can be described and optimized by a differentiable function, and not every carefully crafted objective function is differentiable.
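Classification error (the 0-1 loss) is a familiar example of a meaningful but non-differentiable objective. A minimal sketch with a hypothetical one-parameter threshold classifier and toy data shows the problem: the loss is piecewise constant, so a central finite-difference estimate of its slope is zero almost everywhere, even at thresholds where the loss is far from its best value.

```python
def zero_one_loss(threshold, xs, labels):
    """Fraction of points misclassified by the rule: predict 1 if x > threshold."""
    errors = sum((x > threshold) != bool(label) for x, label in zip(xs, labels))
    return errors / len(xs)

def central_difference(f, t, eps=1e-4):
    """Finite-difference estimate of df/dt at t."""
    return (f(t + eps) - f(t - eps)) / (2 * eps)

# Toy data: points below 2 are class 0, points above 2 are class 1.
xs = [0.5, 1.5, 2.5, 3.5, 4.5]
labels = [0, 0, 1, 1, 1]

def loss_at(t):
    return zero_one_loss(t, xs, labels)

for t in (0.0, 1.0, 3.0, 4.0):
    # The loss is nonzero at each of these thresholds, yet the slope estimate is 0.
    print(t, loss_at(t), central_difference(loss_at, t))
```

Any threshold between 1.5 and 2.5 achieves zero error, but a gradient-based optimizer starting at 0.0 or 4.0 receives no signal telling it to move there.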
The Problem With Most Loss Functions
To illustrate, in biomedical image segmentation, the usual Dice loss and Jaccard index (IoU) may not suffice. The model is often required to identify small objects in complex medical images that can be hundreds or even thousands of pixels across. Measuring error only by the amount of overlap between the predicted mask and the ground truth, as the Dice loss does, can be problematic: when the predicted and ground-truth pixels do not overlap at all, the loss gives the model no feedback about how close its prediction is to the location of the ground truth.

A loss function can also become non-differentiable when the model's predictions serve as an intermediate stage in a pipeline that generates the labels. For example, the model may predict specific characteristics of a 3D structure, but outputs comparable to the actual labels can only be produced once those predictions are fed into third-party rendering software. In such cases, the loss can still be computed, but its gradient cannot, because the network cannot backpropagate through the third-party software.
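The zero-overlap problem can be sketched on a hypothetical 1D "image" of 20 pixels, using the common soft Dice formulation with a small smoothing constant (the smoothing value, image size, and mask positions are illustrative assumptions). A near miss and a far miss receive exactly the same loss, and every pixel outside the target receives the same gradient.

```python
def soft_dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss: 1 - (2*sum(p*g) + s) / (sum(p) + sum(g) + s)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)

def soft_dice_grad(pred, target, smooth=1e-6):
    """Analytic d(loss)/d(pred_i) for soft_dice_loss."""
    num = 2.0 * sum(p * t for p, t in zip(pred, target)) + smooth
    den = sum(pred) + sum(target) + smooth
    # loss = 1 - num/den, with d(num)/d(pred_i) = 2 * t_i and d(den)/d(pred_i) = 1
    return [-(2.0 * t * den - num) / den ** 2 for t in target]

def mask(size, start, stop):
    """Binary 1D mask with ones on pixels [start, stop)."""
    return [1.0 if start <= i < stop else 0.0 for i in range(size)]

SIZE = 20
target = mask(SIZE, 10, 12)  # small ground-truth object on pixels 10-11
near = mask(SIZE, 7, 9)      # prediction that misses by one pixel
far = mask(SIZE, 0, 2)       # prediction that misses by eight pixels

loss_near = soft_dice_loss(near, target)
loss_far = soft_dice_loss(far, target)
print(loss_near == loss_far)  # both misses are penalized identically

grad_far = soft_dice_grad(far, target)
print(grad_far[9] == grad_far[19])  # adjacent and distant pixels get the same gradient
```

The gradient still pushes up the target pixels and pushes down every other pixel, but it does so uniformly, so nothing in the signal favors shifting the prediction toward pixel 9 over pixel 19.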
In truth, there are many more real-world situations in which nothing concretely prevents engineers from crafting a differentiable loss function; rather, the function needs to convey more information to the model for it to improve and converge.