Mistake 17: Wasting unlabeled data
In many applications, labeling data is difficult while collecting unlabeled data is relatively easy. For example, a huge amount of internet network traffic data is generated every second, but labeling it as legitimate or malicious may require manual inspection by humans. This is where semi-supervised learning (SSL) methods come into play. SSL is ideal when only a small proportion of the data is labeled but many unlabeled examples are available. Classic supervised learning methods rely only on labeled data, so the remaining unlabeled data are discarded (wasted).
SSL algorithms try to learn not only from labeled data but also from unlabeled examples. While there exist many SSL methods (Van Engelen and Hoos 2020), here I will focus on one of the simplest: self-learning. This algorithm starts with a base classifier and trains it with the available labeled data as usual. Then, the base classifier is used to predict the labels of the unlabeled instances. These predicted labels are typically called pseudo-labels. Based on some criterion, the best instances along with their pseudo-labels are added to the original training set and the base model is retrained. This process is repeated until some stop condition is met, for example, when the unlabeled data is exhausted or no new instances were added in the last iteration. One criterion for choosing the best instances is to select the top-\(k\) based on their prediction score. Another is to select all instances whose prediction score is above a given threshold.
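The self-learning loop described above can be sketched by hand. This is only an illustration under stated assumptions: the helper name `self_train` and its parameters are hypothetical (not part of any library), the threshold criterion is used to pick instances, and the first 100 rows of the digits data are arbitrarily treated as the labeled set.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier


def self_train(base_model, X_lab, y_lab, X_unlab, threshold=0.95, max_iter=10):
    """Hypothetical helper illustrating self-learning with a score threshold."""
    X_lab, y_lab, X_unlab = X_lab.copy(), y_lab.copy(), X_unlab.copy()
    for _ in range(max_iter):
        if len(X_unlab) == 0:
            break  # Stop condition: unlabeled data exhausted.
        base_model.fit(X_lab, y_lab)
        proba = base_model.predict_proba(X_unlab)
        scores = proba.max(axis=1)  # Prediction score of each instance.
        pseudo = base_model.classes_[proba.argmax(axis=1)]  # Pseudo-labels.
        chosen = scores >= threshold
        if not chosen.any():
            break  # Stop condition: no new instances were added.
        # Move the confident instances, with their pseudo-labels, to the train set.
        X_lab = np.vstack([X_lab, X_unlab[chosen]])
        y_lab = np.concatenate([y_lab, pseudo[chosen]])
        X_unlab = X_unlab[~chosen]
    base_model.fit(X_lab, y_lab)  # Final fit on everything collected so far.
    return base_model, len(y_lab)


# Assumption for illustration: first 100 rows labeled, next 800 unlabeled.
X, y = load_digits(return_X_y=True)
model, n_labeled = self_train(
    RandomForestClassifier(n_estimators=50, random_state=0),
    X[:100], y[:100], X[100:900])
print(n_labeled)  # Original labels plus accepted pseudo-labels.
```

Selecting the top-\(k\) instances instead would only change how `chosen` is computed, for example by ranking `scores` and keeping the \(k\) largest.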
SSL classifiers depend on a set of assumptions (Van Engelen and Hoos 2020). One of them is that the unlabeled instances correspond to examples with the same labels and distribution as those contained in the training set. The smoothness assumption states that if two instances are close in the feature space, they should belong to the same class. The low-density assumption says that the decision boundaries of a model should pass through low-density regions. Finally, the manifold assumption states that instances lying on the same lower-dimensional manifold should belong to the same class.
To demonstrate the use of self-learning, we will use the DIGITS dataset. First, we load it into memory and split it into train and test sets (50/50). If we print the shape of the train set, we see that it has \(898\) instances.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

data = load_digits()
X_train, X_test, y_train, y_test = train_test_split(data.data,
                                                    data.target,
                                                    test_size = 0.5,
                                                    random_state = 1234)
print(X_train.shape)
We will simulate that only \(100\) instances are labeled and the rest (\(798\)) will be unlabeled. To instruct the semi-supervised methods in scikit-learn that an instance is unlabeled, we need to set its label to \(-1\). The following code sets the label of the first \(798\) instances to \(-1\). The last instances will be the labeled ones.
unlabeled = 798 # Number of unlabeled instances.
# Set first n instances as unlabeled.
y_train[:unlabeled] = -1
# Check that the first n instances are unlabeled.
print(y_train)
In this example we will use a RandomForest as the base model. Then, we instantiate a SelfTrainingClassifier object that implements the self-learning approach. The first parameter is the base model (a random forest in this case). The threshold specifies the minimum prediction score for an instance to be added to the train set. The verbose parameter specifies if additional information is printed to the console when fitting the model. In this case we set it to True because we want to analyze that information.
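from sklearn.ensemble import RandomForestClassifier
from sklearn.semi_supervised import SelfTrainingClassifier

rf = RandomForestClassifier(n_estimators = 50, random_state = 123)
ss_model = SelfTrainingClassifier(rf, threshold = 0.95, verbose = True)
ss_model.fit(X_train, y_train)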
#>> End of iteration 1, added 1 new labels.
#>> End of iteration 2, added 1 new labels.
#>> End of iteration 3, added 2 new labels.
#>> End of iteration 4, added 3 new labels.
#>> End of iteration 5, added 5 new labels.
By looking at the output, we can see that in the first iteration \(1\) new instance was added to the train set. This instance was the only one that met the threshold criterion specified as an argument. In the fourth iteration, \(3\) instances were added. By default, the maximum number of iterations is \(10\), but this can be changed with the max_iter parameter. The algorithm stopped after the fifth iteration because no further predictions met the \(0.95\) threshold.
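The stopping behavior can also be inspected after fitting. The sketch below rebuilds the setup from the text so that it runs on its own and then reads the fitted attributes `n_iter_`, `termination_condition_` and `labeled_iter_` of `SelfTrainingClassifier`. The exact values printed depend on the library version, so none are shown here.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import SelfTrainingClassifier

# Same split and masking as in the text.
data = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.5, random_state=1234)
y_train[:798] = -1  # Mark the first 798 instances as unlabeled.

rf = RandomForestClassifier(n_estimators=50, random_state=123)
ss_model = SelfTrainingClassifier(rf, threshold=0.95, max_iter=10)
ss_model.fit(X_train, y_train)

# Number of self-training rounds that were run.
print(ss_model.n_iter_)
# Why fitting stopped: 'max_iter', 'no_change' or 'all_labeled'.
print(ss_model.termination_condition_)
# labeled_iter_ is 0 for original labels, -1 for instances never labeled,
# and the iteration number for instances that received a pseudo-label.
n_pseudo = int(np.sum(ss_model.labeled_iter_ > 0))
print(n_pseudo)
```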
Now that the self-training procedure is complete, a performance report can be generated on the test set.
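One way to produce such a report is with `classification_report` from `sklearn.metrics`. The following is a sketch that assumes the same split, masking and model settings described in the text; it repeats them so that the block runs on its own.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import SelfTrainingClassifier

# Same split and masking as in the text.
data = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.5, random_state=1234)
y_train[:798] = -1  # Mark the first 798 instances as unlabeled.

rf = RandomForestClassifier(n_estimators=50, random_state=123)
ss_model = SelfTrainingClassifier(rf, threshold=0.95)
ss_model.fit(X_train, y_train)

# Predict on the held-out test set and summarize per-class performance.
y_pred = ss_model.predict(X_test)
report = classification_report(y_test, y_pred)
print(report)
```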
#>>               precision    recall  f1-score   support
#>>
#>>            0       0.93      0.97      0.95        79
#>>            1       0.78      0.91      0.84        93
#>>            2       0.85      0.95      0.90        86
#>>            3       0.90      0.90      0.90        94
#>>            4       0.97      0.92      0.94       107
#>>            5       0.87      0.79      0.83       102
#>>            6       0.95      0.94      0.95        88
#>>            7       0.89      0.93      0.91        88
#>>            8       0.98      0.56      0.71        82
#>>            9       0.70      0.85      0.77        80
#>>     accuracy                           0.88       899
#>>    macro avg       0.88      0.87      0.87       899
#>> weighted avg       0.89      0.88      0.87       899
The overall accuracy using only \(100\) labeled instances in the train set was \(0.88\), which is not bad. If we follow a traditional supervised learning approach (see accompanying code), which is only able to use the labeled instances, the accuracy drops to \(0.85\).
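The supervised baseline could look roughly like the sketch below. It is not the accompanying code itself. It assumes the same split as in the text and simply trains the random forest on the \(100\) labeled instances, ignoring the rest.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Same split as in the text.
data = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.5, random_state=1234)

# Keep only the last 100 instances, the ones treated as labeled in the text.
unlabeled = 798
X_labeled, y_labeled = X_train[unlabeled:], y_train[unlabeled:]

baseline = RandomForestClassifier(n_estimators=50, random_state=123)
baseline.fit(X_labeled, y_labeled)

y_pred = baseline.predict(X_test)
baseline_acc = accuracy_score(y_test, y_pred)
print(classification_report(y_test, y_pred))
print(baseline_acc)
```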
While in this example there was a performance gain when using a semi-supervised approach, this also depends on the chosen parameters. For example, if the decision threshold is decreased, erroneous labels are more likely to make their way into the train set and, as a consequence, the performance of the model can decrease.
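The effect of the threshold can be explored with a small sweep. The threshold values below are arbitrary choices for illustration, and the resulting accuracies depend on the data split and library version, so no specific outcome is claimed here.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import SelfTrainingClassifier

# Same split as in the text.
data = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.5, random_state=1234)

thresholds = [0.5, 0.75, 0.95]  # Illustrative values only.
results = {}
for t in thresholds:
    # Mask the labels on a copy so every run starts from the same state.
    y_semi = y_train.copy()
    y_semi[:798] = -1
    rf = RandomForestClassifier(n_estimators=50, random_state=123)
    model = SelfTrainingClassifier(rf, threshold=t)
    model.fit(X_train, y_semi)
    results[t] = accuracy_score(y_test, model.predict(X_test))

for t, acc in results.items():
    print(f"threshold={t:.2f} accuracy={acc:.3f}")
```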
scikit-learn provides SelfTrainingClassifier and LabelPropagation for SSL classification tasks.