
I'm new to machine learning and deep learning, and I'd like to clear up a doubt about train_test_split before training.

I have a dataset of shape (302, 100, 5), where

(207, 100, 5) belongs to class 0

(95, 100, 5) belongs to class 1.

I want to use an LSTM for classification (since this is sequence data).

How should I split my dataset for training, given that the classes are not equally distributed?

Option 1: take the whole dataset [(302, 100, 5), both classes (0 & 1)], shuffle it, apply train_test_split, and proceed with training.

Option 2: take an equal number of samples from each class [(95, 100, 5) of class 0 & (95, 100, 5) of class 1], shuffle them, apply train_test_split, and proceed with training.
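A minimal sketch of the two options, assuming scikit-learn and NumPy and using random stand-in data with the same shapes and class counts as above (the data and variable names are hypothetical). Option 2 is implemented here as random undersampling of class 0 down to 95 sequences.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

# Hypothetical stand-in data: 207 class-0 and 95 class-1 sequences,
# each with 100 time steps and 5 features
X = rng.random((302, 100, 5))
y = np.array([0] * 207 + [1] * 95)

# Option 1: keep every sample; train_test_split shuffles by default
X_tr1, X_te1, y_tr1, y_te1 = train_test_split(X, y, test_size=0.2, random_state=0)

# Option 2: randomly keep only 95 class-0 samples, then split
idx0 = rng.choice(np.flatnonzero(y == 0), size=95, replace=False)
idx1 = np.flatnonzero(y == 1)
keep = np.concatenate([idx0, idx1])
X_tr2, X_te2, y_tr2, y_te2 = train_test_split(
    X[keep], y[keep], test_size=0.2, random_state=0)

print("option 1:", len(y_tr1), "train /", len(y_te1), "test")
print("option 2:", len(y_tr2), "train /", len(y_te2), "test")
```

With `test_size=0.2`, Option 1 keeps all 302 sequences (241 for training, 61 for testing), while Option 2 discards 112 class-0 sequences before the split, so the model sees less data overall.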

Which way of splitting before training is better, in order to get better results in terms of loss, accuracy and predictions?

If there are other options besides these two, please suggest them.

As requested in the comments, here is part of my data:

X_train: shape (241, 100, 5)

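In each (100, 5) sample, every row corresponds to one time step, so the 100 rows correspond to 100 time steps; the time unit is milliseconds (ms).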
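To make the layout concrete, here is a minimal NumPy sketch, assuming the array is indexed as (samples, time steps, features), which is the default input layout of a Keras `LSTM` layer. The zero-filled array is a hypothetical placeholder for the real data.

```python
import numpy as np

# Hypothetical placeholder with the question's training shape:
# (samples, time steps, features) = (241, 100, 5)
X_train = np.zeros((241, 100, 5))

n_samples, n_timesteps, n_features = X_train.shape
segment = X_train[0]               # one sequence: shape (100, 5)
step = X_train[0, 10]              # the 5 features at time step 10 of that sequence
feature_series = X_train[0, :, 3]  # feature 3 across all 100 time steps
```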

array([[[0.98620635, 0.        , 0.12752912, 0.60897341, 0.46903766],
        [0.97345112, 0.        , 0.12752912, 0.49205995, 0.38709902],
        [0.9566397 , 0.        , 0.12752912, 0.45728718, 0.42154812],
        ...,
        [0.28669754, 0.8852459 , 0.12752912, 0.8786213 , 0.80125523],
        [0.31559784, 0.8852459 , 0.20968731, 0.89087803, 0.79476987],
        [0.34368841, 0.8852459 , 0.12752912, 0.89087803, 0.71066946]],

       [[0.97957188, 0.14909194, 0.04159147, 0.50548561, 0.34209531],
        [0.9687237 , 0.13964397, 0.04159147, 0.55926067, 0.64613533],
        [0.96596236, 0.13553813, 0.04159147, 0.55903796, 0.85299319],
        ...,
        [0.49309139, 0.72396527, 0.04159147, 0.81998825, 0.12362443],
        [0.52072591, 0.70872926, 0.04159147, 0.82361951, 0.89639432],
        [0.54441507, 0.71835207, 0.04159147, 0.84964602, 1.        ]],

       [[0.48151381, 0.875     , 0.16666667, 0.90637286, 0.62737926],
        [0.53325374, 0.8625    , 0.33333333, 0.87881677, 0.5321154 ],
        [0.57506452, 0.81859091, 0.16666667, 0.84915758, 0.3552661 ],
        ...,
        [0.34456041, 0.92993213, 0.33333333, 0.92953899, 0.78782408],
        [0.39496018, 0.90523485, 0.33333333, 0.9117954 , 0.54579383],
        [0.44187985, 0.8625    , 0.33333333, 0.84163194, 0.25789356]],

       ...,

       [[0.16368355, 0.        , 0.15313225, 0.40101906, 0.36784741],
        [0.15679684, 0.        , 0.15313225, 0.4435126 , 0.67351994],
        [0.15544309, 0.06132052, 0.15313225, 0.40101906, 0.36611345],
        ...,
        [0.43936628, 0.68292683, 0.15313225, 0.82305329, 0.36784741],
        [0.49751546, 0.68292683, 0.07764888, 0.84141109, 0.42828833],
        [0.53288488, 0.68292683, 0.15313225, 0.85959823, 0.36784741]],

       [[0.9418247 , 0.30821318, 0.03072816, 0.744977  , 0.93769733],
        [0.9537216 , 0.28989357, 0.03072816, 0.74576381, 0.98468743],
        [0.96455286, 0.21736423, 0.03072816, 0.74182977, 1.        ],
        ...,
        [0.36273884, 0.60113245, 0.06145633, 0.85409181, 0.32277415],
        [0.38774614, 0.57789971, 0.05844559, 0.82937631, 0.        ],
        [0.41546859, 0.57789971, 0.03072816, 0.79315883, 0.31256578]],

       [[0.97868688, 0.06451613, 0.00411829, 0.64705259, 0.69827586],
        [0.97999663, 0.06451613, 0.02256676, 0.66812232, 0.75195925],
        [0.97143037, 0.02476377, 0.02256676, 0.66317859, 0.78487461],
        ...,
        [0.50336862, 0.73867709, 0.02256676, 0.84921606, 0.1226489 ],
        [0.54003486, 0.72043011, 0.02256676, 0.82679269, 0.20297806],
        [0.57594039, 0.70967742, 0.02256676, 0.83350205, 0.        ]]])

Y_train: shape (241,)

[1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0.
 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 1.
 0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0.
 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.
 1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1.
 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 0. 1.
 0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1.
 1.]

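For reference: as you can see above, the X_train data is large, so I can't include the complete set here. Instead, I'm providing a single segment of my data to show what one segment looks like (i.e. X_train[0], shape (100, 5)). The remaining 240 segments look more or less like this: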

array([[9.86206354e-01, 0.00000000e+00, 1.27529123e-01, 2.29139335e-02,
        6.08973407e-01, 4.69037657e-01],
       [9.73451120e-01, 0.00000000e+00, 1.27529123e-01, 2.60807671e-02,
        4.92059955e-01, 3.87099024e-01],
       [9.56639704e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
        4.57287179e-01, 4.21548117e-01],
       [9.34897700e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
        4.84177685e-01, 4.69037657e-01],
       [9.18030989e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
        4.86406180e-01, 4.08577406e-01],
       [9.02168015e-01, 0.00000000e+00, 1.27529123e-01, 2.64020795e-02,
        4.84920517e-01, 4.04184100e-01],
       [8.82551572e-01, 0.00000000e+00, 1.27529123e-01, 2.56783096e-02,
        4.51195959e-01, 3.78661088e-01],
       [8.69975342e-01, 0.00000000e+00, 1.27529123e-01, 2.40477851e-02,
        4.70286733e-01, 4.23640167e-01],
       [8.41027241e-01, 0.00000000e+00, 1.27529123e-01, 1.75387576e-02,
        5.04754123e-01, 4.34728033e-01],
       [8.28189535e-01, 5.28763040e-01, 1.27529123e-01, 6.89133486e-03,
        4.98662903e-01, 4.58368201e-01],
       [8.21784739e-01, 8.21162444e-01, 1.27529123e-01, 1.06196483e-02,
        5.87431288e-01, 5.72594142e-01],
       [8.26651597e-01, 9.96721311e-01, 1.27529123e-01, 1.75044480e-02,
        6.89050661e-01, 5.40376569e-01],
       [8.42115326e-01, 1.00000000e+00, 1.27529123e-01, 1.71205069e-02,
        8.35388501e-01, 4.69037657e-01],
       [8.64071009e-01, 9.26875310e-01, 1.27529123e-01, 1.34068975e-02,
        1.00000000e+00, 4.65062762e-01],
       [8.79579724e-01, 7.60158967e-01, 1.27529123e-01, 4.65303975e-03,
        9.61744169e-01, 3.65481172e-01],
       [9.03630040e-01, 7.61549925e-01, 1.27529123e-01, 4.21518348e-03,
        9.22076957e-01, 3.78033473e-01],
       [9.18435858e-01, 6.72429210e-01, 1.27529123e-01, 2.70229205e-03,
        9.39979201e-01, 5.03138075e-01],
       [9.29983046e-01, 6.85345256e-01, 1.27529123e-01, 9.05120794e-04,
        8.53736443e-01, 5.52510460e-01],
       [9.48081232e-01, 5.78539493e-01, 1.27529123e-01, 6.96485550e-03,
        8.84415391e-01, 3.04602510e-01],
       [9.48112160e-01, 5.55091903e-01, 1.27529123e-01, 1.10493356e-02,
        8.19046204e-01, 4.78661088e-01],
       [9.61281634e-01, 5.08693492e-01, 1.27529123e-01, 9.36162843e-03,
        8.23651761e-01, 3.21548117e-01],
       [9.72179346e-01, 4.91803279e-01, 1.27529123e-01, 9.82725917e-03,
        7.57391175e-01, 4.96025105e-01],
       [9.84752763e-01, 4.91803279e-01, 1.27529123e-01, 7.04491131e-03,
        7.59322538e-01, 3.95397490e-01],
       [9.90300024e-01, 4.91803279e-01, 1.27529123e-01, 8.19346712e-03,
        7.64819492e-01, 4.69037657e-01],
       [9.88306609e-01, 3.77049180e-01, 1.27529123e-01, 8.62642201e-03,
        7.93492795e-01, 4.16945607e-01],
       [9.91084457e-01, 3.93442623e-01, 1.27529123e-01, 9.16557339e-03,
        7.10741346e-01, 4.72175732e-01],
       [1.00000000e+00, 3.78936910e-01, 1.27529123e-01, 1.16538387e-02,
        6.93359085e-01, 4.76987448e-01],
       [9.98925974e-01, 3.93442623e-01, 1.27529123e-01, 1.21309060e-02,
        7.16609716e-01, 3.46025105e-01],
       [9.92838888e-01, 3.32141083e-01, 1.27529123e-01, 1.19315833e-02,
        7.31540633e-01, 4.16527197e-01],
       [9.90637415e-01, 3.36910084e-01, 1.27529123e-01, 9.95632874e-03,
        7.12524142e-01, 4.15481172e-01],
       [9.90761125e-01, 3.38301043e-01, 1.27529123e-01, 6.59235091e-03,
        6.86970732e-01, 4.37656904e-01],
       [9.90274720e-01, 3.27868852e-01, 2.10913550e-01, 5.68396253e-03,
        7.09181399e-01, 4.99372385e-01],
       [9.83015202e-01, 3.27868852e-01, 1.27529123e-01, 2.14974358e-02,
        7.31392067e-01, 6.41631799e-01],
       [9.77392028e-01, 2.85245902e-01, 1.47762109e-01, 2.52861995e-02,
        7.09478532e-01, 6.07112971e-01],
       [9.75300207e-01, 2.78688525e-01, 1.27529123e-01, 2.91468501e-02,
        6.70257020e-01, 6.28242678e-01],
       [9.74917831e-01, 2.71733731e-01, 1.27529123e-01, 3.58780734e-02,
        6.70257020e-01, 5.72594142e-01],
       [9.64950755e-01, 2.62295082e-01, 1.27529123e-01, 3.92992339e-02,
        6.36383895e-01, 6.67991632e-01],
       [9.63159774e-01, 2.62295082e-01, 1.27529123e-01, 4.82932591e-02,
        6.93581934e-01, 5.46443515e-01],
       [9.54983679e-01, 2.90511674e-01, 1.27529123e-01, 4.90627752e-02,
        6.59708810e-01, 7.40376569e-01],
       [9.57595643e-01, 3.11475410e-01, 1.27529123e-01, 4.72492660e-02,
        6.49977715e-01, 5.61297071e-01],
       [9.51511369e-01, 2.95081967e-01, 1.27529123e-01, 1.82576261e-02,
        6.64314366e-01, 5.22384937e-01],
       [9.48528275e-01, 2.95081967e-01, 1.27529123e-01, 3.89659403e-03,
        6.29846977e-01, 3.20711297e-01],
       [9.47085931e-01, 2.95081967e-01, 1.27529123e-01, 6.86682798e-03,
        6.48417769e-01, 4.38284519e-01],
       [9.38153518e-01, 2.95081967e-01, 1.27529123e-01, 5.73951146e-03,
        7.04130144e-01, 5.32635983e-01],
       [9.38114156e-01, 2.95081967e-01, 1.27529123e-01, 2.05955826e-02,
        6.85782202e-01, 5.47280335e-01],
       [9.35597786e-01, 2.95081967e-01, 1.27529123e-01, 2.91141743e-02,
        6.69142772e-01, 7.13807531e-01],
       [9.29311077e-01, 2.72826627e-01, 1.27529123e-01, 2.91141743e-02,
        6.81622344e-01, 5.72594142e-01],
       [9.25495753e-01, 2.23646299e-01, 1.27529123e-01, 2.65507546e-02,
        6.35566781e-01, 6.41004184e-01],
       [9.18525829e-01, 2.08643815e-03, 1.27529123e-01, 2.37618715e-02,
        6.09641955e-01, 5.02928870e-01],
       [8.91801693e-01, 0.00000000e+00, 1.27529123e-01, 9.27013608e-03,
        5.26073392e-01, 4.21338912e-01],
       [8.77693149e-01, 0.00000000e+00, 1.27529123e-01, 8.13628440e-03,
        4.22522656e-01, 3.44560669e-01],
       [8.61894841e-01, 0.00000000e+00, 1.27529123e-01, 1.49639014e-02,
        4.52755906e-01, 3.65481172e-01],
       [8.44254943e-01, 0.00000000e+00, 1.27529123e-01, 2.29515107e-02,
        4.59069975e-01, 3.76150628e-01],
       [8.21183060e-01, 0.00000000e+00, 1.27529123e-01, 3.97583295e-02,
        4.60852771e-01, 2.60460251e-01],
       [8.04116726e-01, 0.00000000e+00, 1.27529123e-01, 5.89292454e-02,
        4.26905363e-01, 1.97907950e-01],
       [7.81311943e-01, 0.00000000e+00, 1.27529123e-01, 8.53656345e-02,
        4.37379290e-01, 1.00836820e-01],
       [7.60863270e-01, 0.00000000e+00, 1.27529123e-01, 1.03087377e-01,
        4.37379290e-01, 6.98744770e-02],
       [7.41227145e-01, 0.00000000e+00, 1.27529123e-01, 1.14206966e-01,
        4.27128213e-01, 1.58368201e-01],
       [7.26694052e-01, 0.00000000e+00, 1.27529123e-01, 1.17776801e-01,
        4.37379290e-01, 0.00000000e+00],
       [7.08716764e-01, 0.00000000e+00, 1.27529123e-01, 1.17288297e-01,
        4.48596048e-01, 2.18619247e-01],
       [6.90483621e-01, 0.00000000e+00, 1.27529123e-01, 1.08491961e-01,
        4.58549993e-01, 1.26987448e-01],
       [6.67451099e-01, 0.00000000e+00, 1.27529123e-01, 8.38217010e-02,
        4.99628584e-01, 3.55020921e-01],
       [6.51610618e-01, 0.00000000e+00, 1.27529123e-01, 4.32889541e-02,
        5.10919626e-01, 4.83054393e-01],
       [6.31195684e-01, 0.00000000e+00, 1.27529123e-01, 1.29200275e-02,
        5.21170703e-01, 4.97907950e-01],
       [6.14317726e-01, 0.00000000e+00, 2.26241570e-01, 9.32895259e-04,
        4.98960036e-01, 4.69037657e-01],
       [5.98165158e-01, 0.00000000e+00, 5.90435316e-01, 0.00000000e+00,
        4.61892735e-01, 5.03556485e-01],
       [5.68221755e-01, 0.00000000e+00, 6.33353771e-01, 1.61745413e-03,
        4.25122567e-01, 4.69037657e-01],
       [5.35292447e-01, 0.00000000e+00, 1.00000000e+00, 8.99402522e-03,
        3.58490566e-01, 5.10041841e-01],
       [5.10766973e-01, 0.00000000e+00, 3.93010423e-01, 3.39894098e-02,
        3.27068786e-01, 6.15690377e-01],
       [4.78939807e-01, 0.00000000e+00, 5.32188841e-01, 5.98114931e-02,
        3.27068786e-01, 6.22175732e-01],
       [4.47053597e-01, 0.00000000e+00, 4.31023912e-01, 8.44245703e-02,
        3.24023176e-01, 6.76150628e-01],
       [4.13654754e-01, 0.00000000e+00, 5.32188841e-01, 1.07209434e-01,
        2.90298618e-01, 7.08577406e-01],
       [3.80151882e-01, 0.00000000e+00, 7.97057020e-01, 1.21122807e-01,
        1.19150201e-01, 4.95397490e-01],
       [3.28235926e-01, 0.00000000e+00, 3.56223176e-01, 1.23820198e-01,
        0.00000000e+00, 6.65271967e-01],
       [2.83452966e-01, 0.00000000e+00, 2.28694053e-01, 1.22658572e-01,
        2.65933739e-02, 5.55648536e-01],
       [2.38616587e-01, 0.00000000e+00, 2.28694053e-01, 1.22990232e-01,
        9.41910563e-02, 4.92887029e-01],
       [1.82964031e-01, 0.00000000e+00, 5.19926426e-01, 1.30564491e-01,
        8.97340663e-02, 4.94142259e-01],
       [1.43835174e-01, 0.00000000e+00, 5.25444513e-01, 1.64135650e-01,
        1.14618927e-01, 7.40585774e-01],
       [1.04402664e-01, 0.00000000e+00, 1.55119559e-01, 2.41378071e-01,
        1.98261774e-01, 6.50418410e-01],
       [7.96438281e-02, 0.00000000e+00, 7.11220110e-02, 3.27145618e-01,
        2.89110088e-01, 7.45188285e-01],
       [6.36065353e-02, 0.00000000e+00, 0.00000000e+00, 4.11129065e-01,
        4.05140395e-01, 6.88912134e-01],
       [4.11672585e-02, 0.00000000e+00, 2.52605763e-01, 5.62182942e-01,
        4.54315852e-01, 1.00000000e+00],
       [2.87063044e-02, 0.00000000e+00, 1.27529123e-01, 6.81786323e-01,
        4.59515674e-01, 9.32217573e-01],
       [1.70269716e-02, 1.58966716e-03, 1.27529123e-01, 7.33474602e-01,
        4.37453573e-01, 6.07322176e-01],
       [3.30361486e-03, 6.37853949e-01, 1.27529123e-01, 8.06276376e-01,
        4.69692468e-01, 7.54602510e-01],
       [0.00000000e+00, 7.89369101e-01, 1.27529123e-01, 8.85843682e-01,
        5.10919626e-01, 8.70502092e-01],
       [5.13114648e-03, 8.19672131e-01, 1.27529123e-01, 9.60932765e-01,
        5.99316595e-01, 8.79288703e-01],
       [2.16829598e-02, 8.36065574e-01, 1.27529123e-01, 9.99121020e-01,
        7.28866439e-01, 8.56903766e-01],
       [4.27951674e-02, 8.36065574e-01, 1.27529123e-01, 1.00000000e+00,
        8.67181697e-01, 7.88912134e-01],
       [7.02334461e-02, 8.36065574e-01, 1.27529123e-01, 9.93500775e-01,
        8.46308127e-01, 9.78451883e-01],
       [9.73680733e-02, 8.36065574e-01, 1.27529123e-01, 9.87896869e-01,
        8.66364582e-01, 8.59414226e-01],
       [1.23611427e-01, 8.36065574e-01, 1.27529123e-01, 9.69613102e-01,
        8.35685634e-01, 9.17991632e-01],
       [1.52157471e-01, 8.68852459e-01, 1.27529123e-01, 9.22226597e-01,
        7.96686971e-01, 9.65062762e-01],
       [1.77979087e-01, 8.68852459e-01, 1.27529123e-01, 8.61132577e-01,
        8.29594414e-01, 8.14225941e-01],
       [2.03010647e-01, 8.84252360e-01, 1.27529123e-01, 8.13277174e-01,
        8.29594414e-01, 9.11506276e-01],
       [2.32490138e-01, 8.85245902e-01, 1.27529123e-01, 7.59549923e-01,
        8.41851137e-01, 9.52301255e-01],
       [2.58952796e-01, 8.85245902e-01, 1.27529123e-01, 6.97804020e-01,
        8.55667806e-01, 8.68200837e-01],
       [2.86697538e-01, 8.85245902e-01, 1.27529123e-01, 6.25149288e-01,
        8.78621304e-01, 8.01255230e-01],
       [3.15597842e-01, 8.85245902e-01, 2.09687308e-01, 5.51940700e-01,
        8.90878027e-01, 7.94769874e-01],
       [3.43688409e-01, 8.85245902e-01, 1.27529123e-01, 4.75801089e-01,
        8.90878027e-01, 7.10669456e-01]])


3 Answers


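I'm working on a project where I'm trying to do fraud detection on a credit dataset (an imbalanced dataset with a 1% minority class and a 99% majority class) using different sampling methods, and I found that SMOTE gave better results on the imbalanced dataset.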

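SMOTE (Synthetic Minority Oversampling Technique) is a powerful sampling method that goes beyond simple undersampling or oversampling. The algorithm creates new minority-class instances by forming convex combinations of neighboring minority instances.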
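The convex-combination idea can be sketched in a few lines of NumPy. This is a simplified illustration, not imblearn's implementation; the function name `smote_sketch` and its parameters are hypothetical. It flattens each sample into one vector, which is also how 3-D sequence data such as the question's (n, 100, 5) array would have to be presented to a tabular sampler.

```python
import numpy as np

def smote_sketch(X_min, n_new, k=5, seed=0):
    """Hypothetical helper: create n_new synthetic minority samples by
    interpolating between a minority sample and one of its k nearest
    minority neighbours (requires k < len(X_min))."""
    rng = np.random.default_rng(seed)
    sample_shape = X_min.shape[1:]
    flat = X_min.reshape(len(X_min), -1)  # treat each sample as one flat vector
    # Pairwise Euclidean distances between minority samples
    d = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)  # a sample is not its own neighbour
    neighbours = np.argsort(d, axis=1)[:, :k]  # k nearest minority neighbours
    base = rng.integers(0, len(flat), size=n_new)  # random base samples
    nbr = neighbours[base, rng.integers(0, k, size=n_new)]  # one neighbour each
    lam = rng.random((n_new, 1))  # interpolation weight in [0, 1)
    synth = flat[base] + lam * (flat[nbr] - flat[base])  # convex combination
    return synth.reshape((n_new,) + sample_shape)

# Usage with hypothetical class-1 sequences shaped like the question's data
X_min = np.random.default_rng(1).random((95, 100, 5))
X_syn = smote_sketch(X_min, n_new=207 - 95)
print(X_syn.shape)  # (112, 100, 5)
```

Each synthetic point lies on the segment between a real minority sample and one of its neighbours. For sequences this interpolates two series time step by time step, which may or may not produce realistic sequences for a given dataset.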

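I used the SMOTE sampling method together with K-Fold cross-validation. Cross-validation checks that the model has learned patterns that hold across different subsets of the data, rather than fitting noise in one particular split.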
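One detail matters when combining resampling with cross-validation: resample only the training part of each fold, so the validation fold contains only real samples. A minimal sketch with scikit-learn's `StratifiedKFold`, using random duplication of minority samples as a simple stand-in for SMOTE (the helper `oversample_minority` and the data are hypothetical):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def oversample_minority(X, y, seed=0):
    """Hypothetical stand-in for SMOTE: randomly duplicate minority
    samples until both classes have the same count."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    minority = classes[np.argmin(counts)]
    extra = rng.choice(np.flatnonzero(y == minority),
                       size=counts.max() - counts.min(), replace=True)
    keep = np.concatenate([np.arange(len(y)), extra])
    return X[keep], y[keep]

rng = np.random.default_rng(0)
X = rng.random((302, 100, 5))  # hypothetical data with the question's shapes
y = np.array([0] * 207 + [1] * 95)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
fold_counts = []
# y alone determines the splits, so a zeros placeholder is passed for X
for train_idx, val_idx in skf.split(np.zeros(len(y)), y):
    X_tr, y_tr = oversample_minority(X[train_idx], y[train_idx])  # training fold only
    X_val, y_val = X[val_idx], y[val_idx]  # validation fold stays untouched
    # fit a model on (X_tr, y_tr) and score it on (X_val, y_val) here
    fold_counts.append((np.bincount(y_tr), np.bincount(y_val)))
```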

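On an imbalanced dataset, an accuracy score of 99% looks impressive, but a model can reach it while ignoring the minority class completely: with a 1% minority, always predicting the majority class already gives 99% accuracy. Therefore, besides accuracy, I also used the Matthews correlation coefficient (MCC) and the F1 score to measure performance on the imbalanced dataset.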
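The accuracy trap is easy to reproduce. A minimal sketch with scikit-learn's metrics, assuming hypothetical 99:1 labels and a "model" that always predicts the majority class:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef

# Hypothetical labels with a 1% minority class
y_true = np.array([0] * 990 + [1] * 10)
# A useless "model" that always predicts the majority class
y_pred = np.zeros_like(y_true)

acc = accuracy_score(y_true, y_pred)            # 0.99
f1 = f1_score(y_true, y_pred, zero_division=0)  # 0.0: no minority sample is found
mcc = matthews_corrcoef(y_true, y_pred)         # 0.0: no better than chance
print(acc, f1, mcc)
```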

Code:

from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split

# Split first, so the test set contains only real (non-synthetic) samples
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                    random_state=0)

# Oversample the minority class in the training set only
sm = SMOTE(random_state=2)
X_train_res, y_train_res = sm.fit_resample(X_train, y_train.ravel())

Reference:

https://www.kaggle.com/qianchao/smote-with-imbalance-data

answered 2019-07-23T21:29:14.973

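You can use the stratify option of train_test_split, which splits each class according to the given test size, so the training and test sets keep the original class proportions.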

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
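A quick way to confirm what `stratify=y` does, using random stand-in data with the question's class counts (207 vs 95; the data itself is hypothetical): the fraction of class 1 stays at roughly 0.31 in the full set, the training set and the test set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.random.default_rng(0).random((302, 100, 5))  # hypothetical data
y = np.array([0] * 207 + [1] * 95)

x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# Fraction of class-1 samples in each part
for name, labels in [("all", y), ("train", y_train), ("test", y_test)]:
    print(name, len(labels), round(labels.mean(), 3))
```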
answered 2019-07-22T10:08:45.600

TL;DR: try both!


I faced a similar situation before, where my dataset was imbalanced. I used train_test_split or KFold to get through it.

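However, I later ran into the problem of handling imbalanced datasets and came across oversampling and undersampling techniques. For this, I recommend the library imblearn.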

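There you'll find various techniques for handling the case where one class outnumbers the other. I personally use SMOTE often and have had relatively good success with it in such cases.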


Other references:

https://www.analyticsvidhya.com/blog/2017/03/imbalanced-classification-problem/

https://towardsdatascience.com/handling-imbalanced-datasets-in-machine-learning-7a0e84220f28

answered 2019-07-22T09:30:29.487