Created a function to undersample imbalanced data (for binary classification)

I adapted the code from the undersampling section of the article referenced in the comments below so that it runs.

python


import pandas as pd
from sklearn.cluster import KMeans

# Undersample the majority class so that the train data has roughly the same
# number of positive and negative rows
# Reference: https://qiita.com/ryouta0506/items/619d9ac0d80f8c0aed92
# Cluster the majority class with KMeans and sample at a constant rate from each cluster

# X :pandas DataFrame (numeric columns only, because KMeans is fit on it directly)
# target_column_name :name of the class column. "Onset flag" etc.
# minority_label :label value of the minority class. For example, 1.

def under_sampling(X, target_column_name, minority_label):

    # Divide into majority and minority
    # .copy() avoids SettingWithCopyWarning when the Cluster column is added below
    is_minority = X[target_column_name] == minority_label
    X_majority = X[~is_minority].copy()
    X_minority = X[is_minority]

    # Clustering with KMeans (default number of clusters)
    # n_init=10 is set explicitly; it was the default in older scikit-learn releases
    km = KMeans(n_init=10, random_state=43)
    X_majority['Cluster'] = km.fit_predict(X_majority)

    # Calculate how many samples to extract for each cluster
    # Truncating to int means the total can fall slightly short of the minority count
    ratio = X_majority['Cluster'].value_counts() / X_majority.shape[0]
    n_sample_ary = (ratio * X_minority.shape[0]).astype('int64').sort_index()

    # Extract samples for each cluster (iterate by cluster label, not by position)
    # random_state is fixed here so that repeated runs return the same rows
    dfs = []
    for cluster, n_sample in n_sample_ary.items():
        in_cluster = X_majority[X_majority['Cluster'] == cluster]
        dfs.append(in_cluster.sample(n_sample, random_state=43))

    # Make sure to combine minority data as well
    dfs.append(X_minority)

    # Create data after undersampling; the Cluster column is no longer needed
    X_new = pd.concat(dfs, sort=False).drop('Cluster', axis=1)

    return X_new
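
The per-cluster allocation step in the function above can be checked in isolation. The sketch below uses made-up numbers (1,000 majority rows split across three hypothetical clusters, and an assumed minority class of 100 rows) to show how the proportional counts are computed, and that truncating to integers can leave the majority sample a little smaller than the minority class.

```python
import pandas as pd

# Hypothetical cluster labels for 1,000 majority rows (numbers chosen for illustration)
clusters = pd.Series([0] * 555 + [1] * 333 + [2] * 112, name='Cluster')
n_minority = 100  # assumed size of the minority class

# Share of the majority class that falls into each cluster
ratio = clusters.value_counts() / len(clusters)

# Proportional allocation, truncated to integers as in the function above
n_sample_ary = (ratio * n_minority).astype('int64').sort_index()

print(n_sample_ary.to_dict())  # {0: 55, 1: 33, 2: 11}
print(n_sample_ary.sum())      # 99, one short of the 100 minority rows
```

If an exact 1:1 balance matters, the shortfall could be made up by drawing the remaining rows from the largest cluster, but that is a design choice the original code does not make.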
