【スコア0.98】鋳造製品の欠陥検出 CNNでシンプルに解く【SIGNATE)

eye catch cnn tyuzo 機械学習
この記事は約9分で読めます。
※記事内には広告を含む場合がございます

概要

鋳造製品の画像データを使って、欠陥を検出するモデルを構築していただきます。

正常品と不良品の画像、正誤のラベルが与えられ、画像を元に0 か 1を判定する問題。

【練習問題】鋳造製品の欠陥検出 | SIGNATE - Data Science Competition

ノートブックURL

ノートブックは以下にて閲覧できます。

解説

データの前処理

import pandas as pd
import cv2
import numpy as np

df = pd.read_csv('train.csv')

id_list = df['id'].to_list()
target_list = df['target'].to_list()

images = []
target_label = []

pandasでデータを読み込みます。
この後処理するため、一旦listに変換します。

for (path,t) in zip(id_list,target_list):
    img = cv2.imread('./train_data/'+path)
    if img is not None:
        dst = cv2.resize(img, dsize=(200, 200))
        im_rgb = cv2.cvtColor(dst, cv2.COLOR_BGR2RGB)
        
        images.append(im_rgb.astype(np.float32))
        target_label.append(t)
        
        rot90_r = cv2.rotate(im_rgb, cv2.ROTATE_90_CLOCKWISE)
        images.append(rot90_r.astype(np.float32))
        target_label.append(t)
        
        rot90_l = cv2.rotate(im_rgb, cv2.ROTATE_90_COUNTERCLOCKWISE)
        images.append(rot90_l.astype(np.float32))
        target_label.append(t)
        
        rot180 = cv2.rotate(im_rgb, cv2.ROTATE_180)
        images.append(rot180.astype(np.float32))
        target_label.append(t)
        
        img_flip_lr = cv2.flip(im_rgb, 1)
        images.append(img_flip_lr.astype(np.float32))
        target_label.append(t)
        
    else:
        print(path+' : Failed to load.')

np_image = np.array(images)

# 画像とラベルがちゃんと入っているか?
if len(np_image) == len(target_label):
    print('Successfully loaded.')
else:
    print('Different array lengths.')
for (path,t) in zip(id_list,target_list):

2つのリストをfor文で回します。

img = cv2.imread('./train_data/'+path)

データを水増しする

OpenCVで読み込みます。

dst = cv2.resize(img, dsize=(200, 200))

200×200に画像をリサイズします。
元データは300×300ですが、そのままだとメモリに乗らなかったため、リサイズしています。

rot90_r = cv2.rotate(im_rgb, cv2.ROTATE_90_CLOCKWISE)
rot90_l = cv2.rotate(im_rgb, cv2.ROTATE_90_COUNTERCLOCKWISE)
rot180 = cv2.rotate(im_rgb, cv2.ROTATE_180)
img_flip_lr = cv2.flip(im_rgb, 1)

データを水増しします。

  • 右に90度回転
  • 左に90度回転
  • 180度回転
  • 左右反転

を今回は実施しています。

通常のCNNでは90度区切りではなく、15度区切りで回転させることもありますが、
正解データでも、正方形のデータのみで構成されている事が予想できるため、
処理は90度区切りの回転のみです。

水増し用画像処理

下ごしらえ

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten,BatchNormalization
from keras.layers import Conv2D, MaxPooling2D

num_classes = 2
im_rows = 200
im_cols = 200
in_shape = (im_rows, im_cols, 3)

# 評価用に分ける
X_train, X_test,y_train, y_test = train_test_split(np_image, target_label, train_size = 0.85)

print(len(X_train))
print(len(y_train))
print(len(X_test))
print(len(y_test))

# 正規化
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# One-Hotに変換
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

CNNに投入するために正規化、One-Hot化などの処理を行います。

学習

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=in_shape))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# モデルをコンパイル
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'])

# 学習を実行
hist = model.fit(X_train, y_train,
    batch_size=64, epochs=75,
    verbose=1,
    validation_data=(X_test, y_test))

# モデルを評価
score = model.evaluate(X_test, y_test, verbose=1)

典型的なCNNのモデルです。
過去にCIFAR-10を学習した時のものを流用しています。
epochsは最終的に75で設定しました。
以下のグラフを見る限りは、もっと少なくても大丈夫かもしれません。

推論と提出用データ作製

import cv2
import numpy as np
import glob
import re

# glob.globの結果をソートする。
# https://teshi-learn.com/2021-04/python-glob-glob-sorted/
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split(r'(\d+)', text) ]

files = sorted(glob.glob("test_data/*jpeg"), key=natural_keys)

ans_csv = ""
for file in files:
    im = cv2.imread(file)
    im = cv2.resize(im, (200, 200))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = im.reshape(in_shape).astype('float32') / 255

    r = model.predict(np.array([im]), batch_size=32,verbose=0)
    res = r[0]

    filename = file.replace('test_data/','')
    # 回答csvを作る
    ans_csv+='{},{}\n'.format(filename,str(res.argmax()))
    
# print(ans_csv)
                              
with open('./myans.csv', mode='w') as f:
    f.write(ans_csv)
    print('myans.csv output completed')

実際にSIGNATEから提供されている、テスト用データで推論します。
glob.globでフォルダ内のjpegを集めて推論、最後に正解用のCSVに加工します。
提出用のサンプルも提供されるので、それを見ながらチマチマ実装します。

試行錯誤の記録

ans1 (1)は単純なCNNのみで、データの水増しはしていません。

ans4では、右90 180 左90度にデータを処理し、追加してデータの水増しを行いました。

ans5 (当記事)ではさらに左右反転させたデータを追加し、スコア0.98を記録しました。


~サイト支援のお願い~

以下のリンクはアフェリエイトリンクです。
この本は、私自身が実際に購入し、学習に活用した本です。

コメント

タイトルとURLをコピーしました