自前画像を使ったセマンティックセグメンテーションの実装_学習編(FusionNet)

FusionNetというネットワークを使って、自前の画像でセマンティックセグメンテーションの学習を行うコードを作成してみたので紹介します。

素人の私が分からないなりに作成したので、無茶苦茶な作りをしていますが、深層学習初学者でセマンティックセグメンテーションの学習をしてみたいという人は活用してください!
お手元に高性能なPCがないという方も、google colabで解析できますので、ぜひ!

このコードはgoogle driveにも公開しています。
google colabで実行する際は、ディレクトリのマウント等を忘れないでください。
やり方はこちらです。

作成したコード

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 
from PIL import Image, ImageOps
import os
import random
import glob

from keras.preprocessing.image import load_img, save_img, img_to_array, array_to_img

import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init
import torch.utils.data as data
from torch.utils.data import DataLoader, random_split
import torchvision.utils as v_utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchsummary import summary

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assume that we are on a CUDA machine, then this should print a CUDA device:
print(device)

batch_size = 1
img_size = 256

input_ch = 1
target_class = 4

lr = 0.0002
epoch = 500

org_path = sorted(glob.glob("test/original/*.png"))
lbl_path = sorted(glob.glob("test/annotation/*.png"))

val_org_path = sorted(glob.glob("test/original/*.png"))
val_lbl_path = sorted(glob.glob("test/annotation/*.png"))

def get_palette():
    palette = [[255, 160, 93],
               [93, 255, 231],
               [255, 0, 229],
               [0, 185, 53]]
    return np.array(palette)/255
    
if input_ch ==3 :
  color="rgb" 
else:
  color="grayscale"
  
#modelの構築
#inputは256*256
def conv_block(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        nn.Conv2d(in_dim,out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model


def conv_trans_block(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        nn.ConvTranspose2d(in_dim,out_dim, kernel_size=3, stride=2, padding=1,output_padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model


def maxpool():
    pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    return pool


def conv_block_3(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        conv_block(in_dim,out_dim,act_fn),
        conv_block(out_dim,out_dim,act_fn),
        nn.Conv2d(out_dim,out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
    )
    return model

class Conv_residual_conv(nn.Module):
    def __init__(self,in_dim,out_dim,act_fn):
        super(Conv_residual_conv,self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        act_fn = act_fn

        self.conv_1 = conv_block(self.in_dim,self.out_dim,act_fn)
        self.conv_2 = conv_block_3(self.out_dim,self.out_dim,act_fn)
        self.conv_3 = conv_block(self.out_dim,self.out_dim,act_fn)
        

    def forward(self,input):
        conv_1 = self.conv_1(input)
        conv_2 = self.conv_2(conv_1)
        res = conv_1 + conv_2
        conv_3 = self.conv_3(res)
        return conv_3

class FusionGenerator(nn.Module):

    def __init__(self,input_nc, output_nc, ngf): 
        super(FusionGenerator,self).__init__()
        self.in_dim = input_nc
        self.out_dim = ngf
        self.final_out_dim = output_nc
        act_fn = nn.LeakyReLU(0.2, inplace=True)
        act_fn_2 = nn.ReLU()
        
        # encoder
        self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn)
        self.pool_1 = maxpool()
        self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn)
        self.pool_2 = maxpool()
        self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn)
        self.pool_3 = maxpool()
        self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn)
        self.pool_4 = maxpool()

        # bridge

        self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn)

        # decoder
        self.deconv_1 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2)
        self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2)
        self.deconv_2 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2)
        self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2)
        self.deconv_3 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2)
        self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2)
        self.deconv_4 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2)
        self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2)

        # output
        self.out = nn.Conv2d(self.out_dim,self.final_out_dim, kernel_size=3, stride=1, padding=1)
        self.out_2 = nn.Tanh()

        # initialization

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
            
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self,input):

        down_1 = self.down_1(input)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)
        down_4 = self.down_4(pool_3)
        pool_4 = self.pool_4(down_4)

        bridge = self.bridge(pool_4)

        deconv_1 = self.deconv_1(bridge)
        skip_1 = (deconv_1 + down_4)/2
        up_1 = self.up_1(skip_1)
        deconv_2 = self.deconv_2(up_1)
        skip_2 = (deconv_2 + down_3)/2
        up_2 = self.up_2(skip_2)
        deconv_3 = self.deconv_3(up_2)
        skip_3 = (deconv_3 + down_2)/2
        up_3 = self.up_3(skip_3)
        deconv_4 = self.deconv_4(up_3)
        skip_4 = (deconv_4 + down_1)/2
        up_4 = self.up_4(skip_4)

        out = self.out(up_4)
        out = self.out_2(out)
        #out = torch.clamp(out, min=-1, max=1)

        return out

fusion = nn.DataParallel(FusionGenerator(input_ch,target_class,16)).cuda()

summary(fusion,(input_ch,img_size,img_size))
#summary(fusion,(3,256,256))

#loss関数, 最適化
loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(fusion.parameters(),lr=lr)

#modelのロード
try:
    checkpoint = torch.load('model/fusion_9.pkl')
    fusion.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print("\n--------model restored--------\n")
except:
    print("\n--------model not restored--------\n")
    pass
    
df = pd.DataFrame(index=[], columns=['loss','val_loss'])

def data_loader(org_path,lbl_path,color):
  org_data = []
  lbl_data = []
  for m in range(len(org_path)):
    org_img = load_img(org_path[m],color_mode=color)
    lbl_img = load_img(lbl_path[m],color_mode="rgb")
    
    if np.random.rand() >0.5:
      n = random.randint(1,3)
      org_img = org_img.rotate(90*n)
      lbl_img = lbl_img.rotate(90*n)
    if np.random.rand() >0.5:
      org_img = ImageOps.flip(org_img)
      lbl_img = ImageOps.flip(lbl_img)
    if np.random.rand() >0.5:
      org_img = ImageOps.mirror(org_img)
      lbl_img = ImageOps.mirror(lbl_img)
    
    org_img = img_to_array(org_img)/255
    lbl_img = img_to_array(lbl_img)/255
    org_data.append(org_img)
    lbl_data.append(lbl_img)
  org_data = np.array(org_data)
  lbl_data = np.array(lbl_data)
  return org_data,lbl_data

def make_onehot(data):
  palette = np.array(get_palette(),dtype=np.float32)
  onehot = np.zeros((data.shape[0], img_size, img_size,len(palette) ), dtype=np.uint8)
  for i in range(len(palette)):
    cat_color = palette[i]
    # 画像が現在カテゴリ色と一致する画素に1を立てた(256, 256)のndarrayを作る
    temp = np.where((data[:, :, :, 0] == cat_color[0]) &
                    (data[:, :, :, 1] == cat_color[1]) &
                    (data[:, :, :, 2] == cat_color[2]), 1, 0)
    
    onehot[:, :, :, i] = temp
  return onehot

random.seed(0)

for i in range(epoch):
  c = list(zip(org_path, lbl_path))
  random.shuffle(c)
  org_path, lbl_path = zip(*c)
  for m in range(len(org_path)//batch_size):
    #print(org_path[m*batch_size:m*batch_size+batch_size])
    x_train,y_train = data_loader(org_path[m*batch_size:m*batch_size+batch_size],lbl_path[m*batch_size:m*batch_size+batch_size],color=color)
    y_train = make_onehot(y_train)
    x_train = x_train.transpose(0,3,1,2)
    y_train = y_train.transpose(0,3,1,2)
    x_train = torch.from_numpy(x_train.astype(np.float32)).clone().cuda()
    y_train = torch.from_numpy(y_train.astype(np.float32)).clone().cuda()
    
    optimizer.zero_grad()

    x = Variable(x_train).cuda()
    y_ = Variable(y_train).cuda()
    y = fusion.forward(x)
            
    loss = loss_func(y,y_)
    loss.backward()
    optimizer.step()

  for m in range(len(org_path)//batch_size):
    x_test,y_test = data_loader(val_org_path[m*batch_size:m*batch_size+batch_size],val_lbl_path[m*batch_size:m*batch_size+batch_size],color=color)
    y_test = make_onehot(y_test)
    x_test = x_test.transpose(0,3,1,2)
    y_test = y_test.transpose(0,3,1,2)
    x_test = torch.from_numpy(x_test.astype(np.float32)).clone().cuda()
    y_test = torch.from_numpy(y_test.astype(np.float32)).clone().cuda()
    
    x = Variable(x_test).cuda()
    y_ = Variable(y_test).cuda()
    y = fusion.forward(x)
    val_loss = loss_func(y,y_)

  print("epoch"+str(i)+":",loss,val_loss)
    
  df.loc[i] = [loss.item(),val_loss.item()]
  if len(df) > 2:
    if df["val_loss"][len(df)-1] < min(df["val_loss"][0:len(df)-2]):
      %rm model/*.pkl 

      torch.save({
            'epoch': i,
            'model_state_dict': fusion.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, "model/fusion_"+str(i)+".pkl")

  df.to_csv('model/loss.csv')    

コードの使用方法の解説

データセットの作成方法が分からず、kerasとpytorchが混ぜこぜになっている情けない仕様ですが、説明していきます。
皆様が使用される場合は、コードを保存してるディレクトリにmodelフォルダを作成してください。

さらに、以下の部分は最低限変更する必要があります。
batch_size = 1
img_size = 256

input_ch = 1
target_class = 4

lr = 0.0002
epoch = 500

#trainigデータのpath
org_path = sorted(glob.glob("test/original/*.png"))
lbl_path = sorted(glob.glob("test/annotation/*.png"))
#validationデータのpath
val_org_path = sorted(glob.glob("test/original/*.png"))
val_lbl_path = sorted(glob.glob("test/annotation/*.png"))

def get_palette():
    palette = [[255, 160, 93],
               [93, 255, 231],
               [255, 0, 229],
               [0, 185, 53]]
    return np.array(palette)/255
input_chはグレースケール画像もしくはカラー画像ということで、1か3にしてください。

このサンプルでは面倒だったので、trainingとvalidationを同じにしていますが、モデルを作る際は必ず違うものにしてください。

また、特に重要となるのはtarget_classとpaletteのところです!
オリジナルの画像に対して、分類先が4種になる場合はtarget_classを4にしてください。
さらにアノテーション画像の色(RGB)に合わせてpaletteの中身の値を変えてください。

上記のコードは、下図のような画像を教師画像とした場合の例となります。


さらに、以下のコードでデータの拡張を行っているので、各自で不要なものは削除してください。
def data_loader(org_path,lbl_path,color):
  org_data = []
  lbl_data = []
  for m in range(len(org_path)):
    org_img = load_img(org_path[m],color_mode=color)
    lbl_img = load_img(lbl_path[m],color_mode="rgb")
    
    if np.random.rand() >0.5:
      n = random.randint(1,3)
      org_img = org_img.rotate(90*n)
      lbl_img = lbl_img.rotate(90*n)
    if np.random.rand() >0.5:
      org_img = ImageOps.flip(org_img)
      lbl_img = ImageOps.flip(lbl_img)
    if np.random.rand() >0.5:
      org_img = ImageOps.mirror(org_img)
      lbl_img = ImageOps.mirror(lbl_img)
    
    org_img = img_to_array(org_img)/255
    lbl_img = img_to_array(lbl_img)/255
    org_data.append(org_img)
    lbl_data.append(lbl_img)
  org_data = np.array(org_data)
  lbl_data = np.array(lbl_data)
  return org_data,lbl_data
学習を実行するとmodelフォルダ内にval_lossが最も低くなったモデルと、エポックごとのlossを示したcsvファイルが保存されます。
モデルは推論時に活用してください。

学習結果について

教師画像を推論すると下図のようになりました。
入力画像にある3つの図形が異なるクラスに分類されていることが分かります。
(教師が1枚の画像だけなので、できて当たり前ですが・・・)


私自身、このコードでどれくらい精度よくセグメンテーションが可能か分からないので、試行された方はご一報くださると嬉しいです。
推論編はこちらです

コメント