Generating predictions from a neural network
For any given observation, the model assigns a probability of membership to each of a number of classes (for example, an observation may have a 40% chance of being a 5, a 20% chance of being a 6, and so on). To evaluate the performance of the model, some choice has to be made about how to go from these probabilities of class membership to a discrete classification. In this section, we will explore a few of these options in more detail.
As long as there are no perfect ties, the simplest method is to classify observations based on the highest predicted probability. Another approach, which the RSNNS package calls the winner takes all (WTA) method, chooses the class with the highest probability, provided the following conditions are met:
- There is no tie for the highest probability
- The highest probability is above a user-defined threshold (the threshold could be zero)
- The remaining classes all have a predicted probability under the maximum minus another user-defined threshold
Otherwise, observations are classified as unknown. If both thresholds are zero (the default), this equates to requiring a single, unique maximum. The advantage of such an approach is that it provides some quality control; a minimal sketch of this logic follows.
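To make the rule concrete, here is a small sketch of the WTA logic in plain R. It illustrates the conditions listed above rather than reproducing the RSNNS implementation, and the helper name wta_class is ours; the arguments l and h mirror those of encodeClassLabels(), where h is the minimum the winning probability must exceed and l is the required gap between the winner and every other class.
# Illustrative sketch of the WTA rule (not the RSNNS source code).
# probs: numeric vector of predicted class probabilities for one observation.
wta_class <- function(probs, l = 0, h = 0) {
  winner <- which(probs == max(probs))
  if (length(winner) > 1) return(0L)                # tie for the maximum: unknown
  if (probs[winner] <= h) return(0L)                # maximum not above h: unknown
  if (any(probs[-winner] >= probs[winner] - l)) {
    return(0L)                                      # a runner-up is within l of the maximum
  }
  winner                                            # index of the winning class
}
wta_class(c(0.00, 0.98, 0.02))                      # 2: clear winner
wta_class(c(0.46, 0.44, 0.10), l = 0.2, h = 0)      # 0: runner-up too close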
In the digit-classification example we have been exploring, there are 10 possible classes. Suppose 9 of the digits had a predicted probability of 0.099, and the remaining class had a predicted probability of 0.109. Although one class is technically more likely than the others, the difference is fairly trivial, and we may conclude that the model cannot classify that observation with any certainty. A final method, called 402040, classifies an observation only if exactly one value is above a user-defined threshold and all other values are below another user-defined threshold; if multiple values are above the first threshold, or any value is not below the second threshold, the observation is treated as unknown. Again, the goal is to provide some quality control.
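A similar sketch, again with an illustrative helper name rather than RSNNS code, shows the 402040 logic. The name reflects the default cutoffs of 0.4 and 0.6 from the SNNS manual, which split the probability range into a lower 40%, an ambiguous middle 20%, and an upper 40%.
# Illustrative sketch of the 402040 rule: exactly one probability must be
# above h and every other one below l, otherwise the observation is
# coded 0 (unknown).
rule_402040 <- function(probs, l = 0.4, h = 0.6) {
  above <- which(probs > h)
  if (length(above) == 1 && all(probs[-above] < l)) above else 0L
}
rule_402040(c(0.05, 0.80, 0.10))                    # 2: one clear winner
rule_402040(c(rep(0.099, 9), 0.109))                # 0: no value above h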
It may seem like this is unnecessary, because uncertainty in predictions should show up in the model's performance. However, it can be helpful to distinguish the cases where the model was highly certain in its prediction (whether right or wrong) from those where it was uncertain (whether right or wrong).
Finally, in some cases, not all classes are equally important. For example, in a medical context where a variety of biomarkers and genes are collected from patients and used to classify whether they are at risk of cancer or at risk of heart disease, a 40% chance of having cancer may be enough to warrant further investigation, even if the patient has a 60% chance of being healthy. This relates to the performance measures we saw earlier: beyond overall accuracy, we can assess aspects such as sensitivity, specificity, and positive and negative predictive values. There are cases where overall accuracy is less important than making sure no one is missed.
The following code shows the raw probabilities for the in-sample data, and the impact these different choices have on the predicted values:
# predicted probabilities for each of the 10 digit classes
digits.yhat4_b <- predict(digits.m4, newdata = digits.test.X)
head(round(digits.yhat4_b, 2))
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
18986 0.00 0.00 0.00 0.98 0.00 0.02 0.00 0.00 0.00 0.00
41494 0.00 0.00 0.03 0.00 0.13 0.01 0.95 0.00 0.00 0.00
21738 0.00 0.00 0.02 0.03 0.00 0.46 0.01 0.00 0.74 0.00
37086 0.00 0.01 0.00 0.63 0.02 0.01 0.00 0.00 0.03 0.00
35532 0.00 0.00 0.00 0.00 0.01 0.00 0.00 0.99 0.00 0.00
17889 0.03 0.00 0.00 0.00 0.00 0.34 0.01 0.00 0.00 0.00
table(encodeClassLabels(digits.yhat4_b,method = "WTA", l = 0, h = 0))
1 2 3 4 5 6 7 8 9 10
102 116 104 117 93 66 93 127 89 93
table(encodeClassLabels(digits.yhat4_b,method = "WTA", l = 0, h = .5))
0 1 2 3 4 5 6 7 8 9 10
141 95 113 86 93 67 53 89 116 73 74
table(encodeClassLabels(digits.yhat4_b,method = "WTA", l = .2, h = .5))
0 1 2 3 4 5 6 7 8 9 10
177 91 113 77 91 59 50 88 116 70 68
table(encodeClassLabels(digits.yhat4_b,method = "402040", l = .4, h = .6))
0 1 2 3 4 5 6 7 8 9 10
254 89 110 71 82 46 41 79 109 65 54
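Note how the number of unknowns (the 0 category) grows as the thresholds tighten, from none under the default WTA settings to 254 under 402040. If the true labels are available, this trade-off can be measured directly. The following is a hypothetical follow-up: digits.test.y is an assumed name for the true labels (coded 1 to 10 to match the column indices) and does not appear in the output above.
# Hypothetical follow-up, assuming digits.test.y holds the true labels
# coded 1-10: measure how often the stricter rule abstains and how
# accurate it is on the observations it does classify.
yhat.enc <- encodeClassLabels(digits.yhat4_b, method = "402040", l = .4, h = .6)
classified <- yhat.enc != 0
mean(classified)                                         # proportion classified at all
mean(yhat.enc[classified] == digits.test.y[classified])  # accuracy when classified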
Next, we examine problems related to overfitting the data and their impact on evaluating a model's performance.