自前画像を使ったセマンティックセグメンテーションの実装_学習編(FusionNet)
FusionNetというネットワークを使って、自前の画像でセマンティックセグメンテーションの学習を行うコードを作成してみたので紹介します。
素人の私が分からないなりに作成したので、無茶苦茶な作りをしていますが、深層学習初学者でセマンティックセグメンテーションの学習をしてみたいという人は活用してください!
お手元に高性能なPCがないという方も、google colabで解析できますので、ぜひ!
このコードはgoogle driveにも公開しています。
google colabで実行する際は、ディレクトリのマウント等を忘れないでください。
やり方はこちらです。
作成したコード
import numpy as np import matplotlib.pyplot as plt import pandas as pd from PIL import Image, ImageOps import os import random import glob from keras.preprocessing.image import load_img, save_img, img_to_array, array_to_img import torch import torch.nn as nn import torch.utils as utils import torch.nn.init as init import torch.utils.data as data from torch.utils.data import DataLoader, random_split import torchvision.utils as v_utils import torchvision.datasets as dset import torchvision.transforms as transforms from torch.autograd import Variable from torchsummary import summary device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Assume that we are on a CUDA machine, then this should print a CUDA device: print(device) batch_size = 1 img_size = 256 input_ch = 1 target_class = 4 lr = 0.0002 epoch = 500 org_path = sorted(glob.glob("test/original/*.png")) lbl_path = sorted(glob.glob("test/annotation/*.png")) val_org_path = sorted(glob.glob("test/original/*.png")) val_lbl_path = sorted(glob.glob("test/annotation/*.png")) def get_palette(): palette = [[255, 160, 93], [93, 255, 231], [255, 0, 229], [0, 185, 53]] return np.array(palette)/255 if input_ch ==3 : color="rgb" else: color="grayscale" #modelの構築 #inputは256*256 def conv_block(in_dim,out_dim,act_fn): model = nn.Sequential( nn.Conv2d(in_dim,out_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_dim), act_fn, ) return model def conv_trans_block(in_dim,out_dim,act_fn): model = nn.Sequential( nn.ConvTranspose2d(in_dim,out_dim, kernel_size=3, stride=2, padding=1,output_padding=1), nn.BatchNorm2d(out_dim), act_fn, ) return model def maxpool(): pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) return pool def conv_block_3(in_dim,out_dim,act_fn): model = nn.Sequential( conv_block(in_dim,out_dim,act_fn), conv_block(out_dim,out_dim,act_fn), nn.Conv2d(out_dim,out_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_dim), ) return model class Conv_residual_conv(nn.Module): def __init__(self,in_dim,out_dim,act_fn): super(Conv_residual_conv,self).__init__() self.in_dim = in_dim self.out_dim = out_dim act_fn = act_fn self.conv_1 = conv_block(self.in_dim,self.out_dim,act_fn) self.conv_2 = conv_block_3(self.out_dim,self.out_dim,act_fn) self.conv_3 = conv_block(self.out_dim,self.out_dim,act_fn) def forward(self,input): conv_1 = self.conv_1(input) conv_2 = self.conv_2(conv_1) res = conv_1 + conv_2 conv_3 = self.conv_3(res) return conv_3 class FusionGenerator(nn.Module): def __init__(self,input_nc, output_nc, ngf): super(FusionGenerator,self).__init__() self.in_dim = input_nc self.out_dim = ngf self.final_out_dim = output_nc act_fn = nn.LeakyReLU(0.2, inplace=True) act_fn_2 = nn.ReLU() # encoder self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn) self.pool_1 = maxpool() self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn) self.pool_2 = maxpool() self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn) self.pool_3 = maxpool() self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn) self.pool_4 = maxpool() # bridge self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn) # decoder self.deconv_1 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2) self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2) self.deconv_2 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2) self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2) self.deconv_3 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2) self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2) self.deconv_4 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2) self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2) # output self.out = nn.Conv2d(self.out_dim,self.final_out_dim, kernel_size=3, stride=1, padding=1) self.out_2 = nn.Tanh() # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): m.weight.data.normal_(0.0, 0.02) m.bias.data.fill_(0) elif isinstance(m, nn.BatchNorm2d): m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) def forward(self,input): down_1 = self.down_1(input) pool_1 = self.pool_1(down_1) down_2 = self.down_2(pool_1) pool_2 = self.pool_2(down_2) down_3 = self.down_3(pool_2) pool_3 = self.pool_3(down_3) down_4 = self.down_4(pool_3) pool_4 = self.pool_4(down_4) bridge = self.bridge(pool_4) deconv_1 = self.deconv_1(bridge) skip_1 = (deconv_1 + down_4)/2 up_1 = self.up_1(skip_1) deconv_2 = self.deconv_2(up_1) skip_2 = (deconv_2 + down_3)/2 up_2 = self.up_2(skip_2) deconv_3 = self.deconv_3(up_2) skip_3 = (deconv_3 + down_2)/2 up_3 = self.up_3(skip_3) deconv_4 = self.deconv_4(up_3) skip_4 = (deconv_4 + down_1)/2 up_4 = self.up_4(skip_4) out = self.out(up_4) out = self.out_2(out) #out = torch.clamp(out, min=-1, max=1) return out fusion = nn.DataParallel(FusionGenerator(input_ch,target_class,16)).cuda() summary(fusion,(input_ch,img_size,img_size)) #summary(fusion,(3,256,256)) #loss関数, 最適化 loss_func = nn.SmoothL1Loss() optimizer = torch.optim.Adam(fusion.parameters(),lr=lr) #modelのロード try: checkpoint = torch.load('model/fusion_9.pkl') fusion.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] print("\n--------model restored--------\n") except: print("\n--------model not restored--------\n") pass df = pd.DataFrame(index=[], columns=['loss','val_loss']) def data_loader(org_path,lbl_path,color): org_data = [] lbl_data = [] for m in range(len(org_path)): org_img = load_img(org_path[m],color_mode=color) lbl_img = load_img(lbl_path[m],color_mode="rgb") if np.random.rand() >0.5: n = random.randint(1,3) org_img = org_img.rotate(90*n) lbl_img = lbl_img.rotate(90*n) if np.random.rand() >0.5: org_img = ImageOps.flip(org_img) lbl_img = ImageOps.flip(lbl_img) if np.random.rand() >0.5: org_img = ImageOps.mirror(org_img) lbl_img = ImageOps.mirror(lbl_img) org_img = img_to_array(org_img)/255 lbl_img = img_to_array(lbl_img)/255 org_data.append(org_img) lbl_data.append(lbl_img) org_data = np.array(org_data) lbl_data = np.array(lbl_data) return org_data,lbl_data def make_onehot(data): palette = np.array(get_palette(),dtype=np.float32) onehot = np.zeros((data.shape[0], img_size, img_size,len(palette) ), dtype=np.uint8) for i in range(len(palette)): cat_color = palette[i] # 画像が現在カテゴリ色と一致する画素に1を立てた(256, 256)のndarrayを作る temp = np.where((data[:, :, :, 0] == cat_color[0]) & (data[:, :, :, 1] == cat_color[1]) & (data[:, :, :, 2] == cat_color[2]), 1, 0) onehot[:, :, :, i] = temp return onehot random.seed(0) for i in range(epoch): c = list(zip(org_path, lbl_path)) random.shuffle(c) org_path, lbl_path = zip(*c) for m in range(len(org_path)//batch_size): #print(org_path[m*batch_size:m*batch_size+batch_size]) x_train,y_train = data_loader(org_path[m*batch_size:m*batch_size+batch_size],lbl_path[m*batch_size:m*batch_size+batch_size],color=color) y_train = make_onehot(y_train) x_train = x_train.transpose(0,3,1,2) y_train = y_train.transpose(0,3,1,2) x_train = torch.from_numpy(x_train.astype(np.float32)).clone().cuda() y_train = torch.from_numpy(y_train.astype(np.float32)).clone().cuda() optimizer.zero_grad() x = Variable(x_train).cuda() y_ = Variable(y_train).cuda() y = fusion.forward(x) loss = loss_func(y,y_) loss.backward() optimizer.step() for m in range(len(org_path)//batch_size): x_test,y_test = data_loader(val_org_path[m*batch_size:m*batch_size+batch_size],val_lbl_path[m*batch_size:m*batch_size+batch_size],color=color) y_test = make_onehot(y_test) x_test = x_test.transpose(0,3,1,2) y_test = y_test.transpose(0,3,1,2) x_test = torch.from_numpy(x_test.astype(np.float32)).clone().cuda() y_test = torch.from_numpy(y_test.astype(np.float32)).clone().cuda() x = Variable(x_test).cuda() y_ = Variable(y_test).cuda() y = fusion.forward(x) val_loss = loss_func(y,y_) print("epoch"+str(i)+":",loss,val_loss) df.loc[i] = [loss.item(),val_loss.item()] if len(df) > 2: if df["val_loss"][len(df)-1] < min(df["val_loss"][0:len(df)-2]): %rm model/*.pkl torch.save({ 'epoch': i, 'model_state_dict': fusion.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, "model/fusion_"+str(i)+".pkl") df.to_csv('model/loss.csv')
コードの使用方法の解説
データセットの作成方法が分からず、kerasとpytorchが混ぜこぜになっている情けない仕様ですが、説明していきます。皆様が使用される場合は、コードを保存してるディレクトリにmodelフォルダを作成してください。
さらに、以下の部分は最低限変更する必要があります。
batch_size = 1 img_size = 256 input_ch = 1 target_class = 4 lr = 0.0002 epoch = 500 #trainigデータのpath org_path = sorted(glob.glob("test/original/*.png")) lbl_path = sorted(glob.glob("test/annotation/*.png")) #validationデータのpath val_org_path = sorted(glob.glob("test/original/*.png")) val_lbl_path = sorted(glob.glob("test/annotation/*.png")) def get_palette(): palette = [[255, 160, 93], [93, 255, 231], [255, 0, 229], [0, 185, 53]] return np.array(palette)/255
input_chはグレースケール画像もしくはカラー画像ということで、1か3にしてください。
このサンプルでは面倒だったので、trainingとvalidationを同じにしていますが、モデルを作る際は必ず違うものにしてください。
また、特に重要となるのはtarget_classとpaletteのところです!
オリジナルの画像に対して、分類先が4種になる場合はtarget_classを4にしてください。
さらにアノテーション画像の色(RGB)に合わせてpaletteの中身の値を変えてください。
上記のコードは、下図のような画像を教師画像とした場合の例となります。
さらに、以下のコードでデータの拡張を行っているので、各自で不要なものは削除してください。
def data_loader(org_path,lbl_path,color): org_data = [] lbl_data = [] for m in range(len(org_path)): org_img = load_img(org_path[m],color_mode=color) lbl_img = load_img(lbl_path[m],color_mode="rgb") if np.random.rand() >0.5: n = random.randint(1,3) org_img = org_img.rotate(90*n) lbl_img = lbl_img.rotate(90*n) if np.random.rand() >0.5: org_img = ImageOps.flip(org_img) lbl_img = ImageOps.flip(lbl_img) if np.random.rand() >0.5: org_img = ImageOps.mirror(org_img) lbl_img = ImageOps.mirror(lbl_img) org_img = img_to_array(org_img)/255 lbl_img = img_to_array(lbl_img)/255 org_data.append(org_img) lbl_data.append(lbl_img) org_data = np.array(org_data) lbl_data = np.array(lbl_data) return org_data,lbl_data
学習を実行するとmodelフォルダ内にval_lossが最も低くなったモデルと、エポックごとのlossを示したcsvファイルが保存されます。
モデルは推論時に活用してください。
学習結果について
教師画像を推論すると下図のようになりました。
入力画像にある3つの図形が異なるクラスに分類されていることが分かります。
(教師が1枚の画像だけなので、できて当たり前ですが・・・)
私自身、このコードでどれくらい精度よくセグメンテーションが可能か分からないので、試行された方はご一報くださると嬉しいです。
推論編はこちらです
コメント
コメントを投稿