
[PyTorch] Give Different Loss Weights for Different Classification results to Solve the Problem of Data Imbalance

Today I want to record some of my own experience in dealing with the data imbalance problem when training a model. It may not offer much to experts, but I hope it can help those who are still unclear on how to handle it.

Imbalanced data is a troublesome thing, and yet most data in the real world falls into this category.

Take binary classification as an example: suppose I have 1000 samples of Label_A but only 10 samples of Label_B. If we build a classifier on this data, what kind of training result will we get?

Yes! Our model may simply predict Label_A for everything! After all, from the loss function's point of view the error is already very low, so the model weights are barely pushed to change.

This is of course not the result we want. I ran into this situation when I actually trained a model, and after searching the Internet, I found roughly two common methods to improve it.

I record my test experience below.

By the way, the methods mentioned in this article are all used under the PyTorch framework, but other machine learning and deep learning frameworks offer the same kind of handling.


Method One: Data Sampling

Sampling the data is probably the most intuitive approach. In simple terms: we have 1000 Label_A samples and 10 Label_B samples, and if we train the model on them directly, the model tends to guess Label_A all the time.

So why don't we just randomly take 10 samples from Label_A? That way, the information on both sides is balanced.
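
A minimal sketch of this kind of random undersampling might look like the following (the data here is made up purely for illustration):

```python
import random

# Made-up example data: (sample, label) pairs for the two classes.
data_a = [(f"sample_a_{i}", 0) for i in range(1000)]  # 1000 Label_A samples
data_b = [(f"sample_b_{i}", 1) for i in range(10)]    # 10 Label_B samples

# Randomly keep only as many Label_A samples as we have Label_B samples.
random.seed(42)
balanced_data = random.sample(data_a, len(data_b)) + data_b
random.shuffle(balanced_data)

print(len(balanced_data))  # 20 samples, 10 per class
```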

But the shortcomings of this are also obvious: if we end up with too little data, the model will have a hard time converging and may just guess randomly; in that case, it may even be worse than guessing Label_A all the way.

In addition, even if the model scores very high on the current test set, it may lose the ability to distinguish certain features (after all, part of the training data has been thrown away).

But if you have a large amount of data, I think it is worth a try.


Method Two: Give Different Labels Different Weights

This is the main theme I want to record today: giving the loss a different weight for different labels, so that the model learns that it should also guess the other kind of label.

Here is an example: imagine the model as a pair of parents with two children, an elder sister who does well in school and a younger brother whose grades are poor.

One day, both siblings got a C on the same exam. The parents were very angry with the sister, but only sighed at the younger brother.

That is what different weights mean.

Here, we keep the loss as it is when the model guesses Label_A wrong, but when the model guesses Label_B wrong, the loss is multiplied by a coefficient.

In this way, when the model back-propagates to update its weights, it will automatically correct itself toward "guessing Label_B right" in order to reduce the loss, and this is exactly what we want to achieve.

By the way, I also want to record the weighting method of Binary Cross Entropy (BCELoss) in PyTorch:
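
In PyTorch, nn.BCELoss takes an optional weight argument that rescales the loss of each batch element, so it should have the same shape as the targets. A minimal sketch (the numbers here are made up purely for illustration):

```python
import torch
import torch.nn as nn

# `weight` rescales the loss of each element in the batch,
# so it must match the shape of the targets.
weight = torch.tensor([1.0, 4.0, 1.0, 4.0])
criterion = nn.BCELoss(weight=weight)

outputs = torch.tensor([0.9, 0.3, 0.2, 0.8])  # probabilities after sigmoid
labels = torch.tensor([0.0, 1.0, 0.0, 1.0])

loss = criterion(outputs, labels)
print(loss.item())
```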


As you can see, we can set the weight directly and pass it into BCELoss.


For example, I set the weight directly during training. Here, I set the weight to 4 when label == 1, and to 1 when label == 0.
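
A rough sketch of what that can look like in the training loop is below; `model`, `inputs`, and `labels` are hypothetical names, and the model is assumed to output probabilities (i.e. after a sigmoid):

```python
import torch
import torch.nn as nn

def make_weight(labels, pos_weight=4.0, neg_weight=1.0):
    # weight = 4 where label == 1, weight = 1 where label == 0
    return torch.where(labels == 1,
                       torch.full_like(labels, pos_weight),
                       torch.full_like(labels, neg_weight))

# Inside the training loop (model / inputs / labels are placeholders):
# outputs = model(inputs)                 # probabilities after sigmoid
# labels = labels.float()
# weight = make_weight(labels)
# criterion = nn.BCELoss(weight=weight)   # rebuilt each batch, since the weights depend on the labels
# loss = criterion(outputs, labels)
# loss.backward()

# Quick standalone check with fake values:
labels = torch.tensor([0.0, 1.0, 1.0, 0.0])
outputs = torch.tensor([0.1, 0.7, 0.4, 0.2])
loss = nn.BCELoss(weight=make_weight(labels))(outputs, labels)
print(loss.item())
```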

Pass it directly into BCELoss() and it will work normally.

But be careful: how much to adjust the weight depends heavily on the proportion of labels in the data set, and it can even affect the real classification results. I don't have a good answer on this point; I can only determine the weight by trial and error.


Retrospect

In fact, there are many methods that can be tried when faced with imbalanced data: for example, using a GAN to generate data (which may not be appropriate for NLP), or using ensemble methods to improve the classification results…

In general, imbalanced data can be a very troublesome dilemma when training a model, and sometimes we even have to collect more features in order to classify well.

Many opinions on the Internet treat this as a "defect detection" task, but I think that framing is better suited to image processing, and the concept may not fit text data with imbalanced labels so well. If there is any fallacy in my thinking, please feel free to point it out.

I hope everyone can train their model well, just as I hope my own model can get good results.

