werry-chanの日記.料理とエンジニアリング

料理!コーディング!研究!日常!飯!うんち!睡眠!人間の全て!

validation(test)がtrainデータよりaccuracy高くなる問題, dropout見直してみませんか?

機械学習ある程度やっている人が稀に遭遇する現象
あれ, testデータの方がtrainデータよりaccuracy高くね?
なんか変じゃね?すごい不安, このままリリースして問題ないの?

この現象, もしかするとdropoutが原因かもしれません。

この記事の内容を要約すると,
「modelのtrain時とeval時でbatch normalizationやdropoutの挙動が異なることから, testの方がtrainよりaccuracyが高くなる現象が発生している場合があるため, 評価方法を見直してみましょう」です。


では実際に, この現象を再現してみましょう。
まずはデータセットを作ります。

# make easy dataset
import numpy as np
import torch


def make_easy_dataset(X_length):
    theta = np.linspace(0,50*np.pi, 5000)
    sin   = np.sin(theta)
    X, y  = [], []
    for i in range(len(theta) - X_length - 1):
        X.append([sin[i : i + X_length]])
        label = [0,1]
        if sin[i + X_length + 1] > sin[i + X_length]:
            label = [1,0]
        y.append(label)
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)


class MyEasyDataset(torch.utils.data.Dataset):

    def __init__(self, data, label, transform=None):
        self.transform = transform
        self.data_num  = data.shape[0]
        self.data      = torch.tensor(data)
        self.label     = torch.tensor(label)
    
    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data  = self.data [idx]
        out_label = self.label[idx]
        if self.transform:
            out_data = self.transform(out_data)
        return out_data, out_label

inputはsin波, prediction対象は sin波の値が次に上昇するか否かです。

次はmodelを定義します。
簡単な1次元Convolutionを任意層重ねたもの, 簡単なFullConnect層を任意層重ねたもの, 複数のモデルを束ねるものです。

# DNN_module.py
import torch
from   torch import nn


class SimpleCNN1d(nn.Module):
    def __init__(self, in_channel:int, out_channel:int, num_block:int, kernel_size:int, stride:int=1, dropout:float=0.0):
        super(SimpleCNN1d, self).__init__()
        self.in_channel  = in_channel
        self.out_channel = out_channel
        self.num_block   = num_block
        self.kernel_size = kernel_size
        self.dropout     = dropout
        layers = []
        for i in range(num_block):
            if i != 0:
                layers += [nn.Conv1d(out_channel, out_channel, kernel_size, stride), nn.ReLU()]
            else:
                layers += [nn.Conv1d( in_channel, out_channel, kernel_size, stride), nn.ReLU()]
            if dropout > 0:
                layers += [nn.Dropout(dropout)]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)


class FC_classifier(nn.Module):
    def __init__(self, in_channel:int, in_length:int, out_channel:int, num_block:int, dropout:float=0.0):
        super(FC_classifier, self).__init__()
        self.in_dim      = in_channel * in_length
        self.out_channel = out_channel
        self.num_block   = num_block
        self.flatten     = nn.Flatten()
        layers = []
        for i in range(num_block):
            in_dim_  = self.in_dim // (i + 1)
            if in_dim_ < 1:
                in_dim_ = 1
            out_dim_ = in_dim_ // 2
            if out_dim_ < out_channel or i == num_block-1:
                out_dim_ = out_channel
            layers += [nn.Linear(self.in_dim, out_dim_)]
            if dropout > 0:
                layers += [nn.Dropout(dropout)]

        layers += [nn.Hardsigmoid()]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        x = self.flatten(x)
        out = self.net(x)
        return out


class CombinedNet(nn.Module):
    def __init__(self, models):
        super(CombinedNet, self).__init__()
        layers = []
        for model in models:
            layers += [model]
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)

このネットワークを組み合わせて, 実際に学習してみましょう。

dropoutを有効にして, 実際にvalidation(test) accuracyの方がtrain accuracyより大きくなる現象を再現しましょう。

import torch
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

import make_easy_dataset as med
import DNN_module


def accuracy(logits, correct_logits):
    indices_max     = logits.max(dim=1).indices
    indices_correct = correct_logits.max(dim=1).indices
    return (indices_max == indices_correct).sum() / logits.shape[0]

def train(dataloader, model, loss_func, optimizer, eval_mode=False):
    if eval_mode:
        model.eval()
    else:
        model.train()
    ave_loss, num_batch, ave_acc = 0, 0, 0
    for data in dataloader:
        X, y = data
        out  = model(X.to(device))
        loss = loss_func(out, y.to(device))
        acc  = accuracy(out, y.to(device))
        
        if not eval_mode:
            loss.backward()
            optimizer.step()

        ave_loss  += loss
        ave_acc   += acc
        num_batch += 1
    ave_loss /= num_batch
    ave_acc  /= num_batch
    if eval_mode:
        model.train()
    return float(ave_loss), float(ave_acc)


if __name__ == "__main__":
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    data_len    = 150
    X, y        = med.make_easy_dataset(data_len)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, stratify=y, train_size=0.5)
    dataset_train = med.MyEasyDataset(X_train, y_train)
    dataset_valid = med.MyEasyDataset(X_valid, y_valid)
    dataloader_train  = torch.utils.data.DataLoader(dataset_train, batch_size=256)
    dataloader_valid  = torch.utils.data.DataLoader(dataset_valid, batch_size=256)


    num_block   = 2
    kernel_size = 10
    base_model        = DNN_module.SimpleCNN1d(  in_channel = 1,
                                                out_channel = 1,
                                                num_block   = num_block,
                                                kernel_size = kernel_size,
                                                dropout     = 0.5,
                                            )

    base_out_data_len = data_len - (kernel_size - 1)*num_block
    classifier        = DNN_module.FC_classifier( in_channel  = 1,
                                                in_length   = base_out_data_len,
                                                out_channel = 2,
                                                num_block   = 1,
                                                dropout     = 0.5,
                                                )

    combined_model    = DNN_module.CombinedNet([base_model, classifier]).to(device)

    optimizer = torch.optim.Adam(combined_model.parameters(), lr=1e-4)

    loss_func = torch.nn.BCELoss().to(device)


    histry_train_loss, histry_train_eval_loss, histry_valid_loss = [], [], []
    histry_train_acc , histry_train_eval_acc , histry_valid_acc  = [], [], []

    epochs = 100
    for t in range(epochs):
        loss_train     , acc_train      = train(dataloader_train, combined_model, loss_func, optimizer)
        loss_train_eval, acc_train_eval = train(dataloader_train, combined_model, loss_func, optimizer, eval_mode=True)
        loss_valid     , acc_valid      = train(dataloader_valid, combined_model, loss_func, optimizer, eval_mode=True)
        histry_train_acc      .append(acc_train)
        histry_train_eval_acc .append(acc_train_eval)
        histry_valid_acc      .append(acc_valid)
        histry_train_loss     .append(loss_train)
        histry_train_eval_loss.append(loss_train_eval)
        histry_valid_loss     .append(loss_valid)
    plt.plot(histry_train_acc     , color='blue'   , label='train acc')
    plt.plot(histry_valid_acc     , color='orange' , label='valid acc')
    plt.plot(histry_train_eval_acc, color='skyblue', label='train eval acc')
    plt.legend(); plt.ylim(0,1)
    plt.savefig('acc_histry.png'); plt.clf(); plt.close()

    plt.plot(histry_train_loss     , color='blue'   , label='train loss')
    plt.plot(histry_valid_loss     , color='orange' , label='valid loss')
    plt.plot(histry_train_eval_loss, color='skyblue', label='train eval loss')
    plt.legend()
    plt.savefig('loss_histry.png'); plt.clf(); plt.close()

学習結果を以下に示しました。
通常のtrain(model.train()でmodel(train_data)) : 青色
valid(model.eval()でmodel(valid_data)):オレンジ色
eval modeでtrain(model.eval()でmodel(train_data)):水色

epoch毎のaccuracy, train:青, valid: オレンジ, eval modeのtrain: 水色

見た目から簡単に分かることとして, 青線だけは水色・オレンジ色とは明らかに異なっています。

青線のみはmodel.train()のtrain modeに入力された結果, 水色・オレンジ色はmodel.eval()のeval modeに入力された結果です。

model.eval()モードという同じ基準下で, 水色(train data)とオレンジ色(valid data)を比較すると, 乖離ない学習曲線が見えます。

同様の現象がLossにも確認できます。

epoch毎のloss, train:青, valid: オレンジ, eval modeのtrain: 水色

さて, この現象はどうして発生したのか?


理由は, model.train()においてはdropoutが有効で, model.eval()ではdropoutが無効化されているということが原因として挙げられます。

dropoutは, networkの重みを確率的に無視するという機能を持ちます。
dropout値が大きすぎる, あるいは重要なlayerの重みを無視してしまった場合, model.eval()と出力が大きく異なるという現象が発生します。

このような現象に遭遇した場合には, 学習時間が倍程度になりますが, model.train()で学習した後に, model.eval()でdropout無効化してtrainデータを再評価することが必要です。

通常は, 学習時間が2倍になることから, このような評価方法は実施しません。
また, dropoutが有効であっても, 他のnetworkや残った重みがカバーしてくれるため, 一般的なdropout値であれば問題になることは少ないです。


どうしてもtestデータ validationデータがtrain accuracyよりも高いことが不安で, この問題を修正しないと製品リリース許可出なそう, みたいな時には, 是非ともお試しください。