自前画像を使ったセマンティックセグメンテーションの実装_学習編(FusionNet)
FusionNetというネットワークを使って、自前の画像でセマンティックセグメンテーションの学習を行うコードを作成してみたので紹介します。
素人の私が分からないなりに作成したので、無茶苦茶な作りをしていますが、深層学習初学者でセマンティックセグメンテーションの学習をしてみたいという人は活用してください!
お手元に高性能なPCがないという方も、google colabで解析できますので、ぜひ!
このコードはgoogle driveにも公開しています。
google colabで実行する際は、ディレクトリのマウント等を忘れないでください。
やり方はこちらです。
作成したコード
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageOps
import os
import random
import glob
from keras.preprocessing.image import load_img, save_img, img_to_array, array_to_img
import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init
import torch.utils.data as data
from torch.utils.data import DataLoader, random_split
import torchvision.utils as v_utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchsummary import summary
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print(device)
batch_size = 1
img_size = 256
input_ch = 1
target_class = 4
lr = 0.0002
epoch = 500
org_path = sorted(glob.glob("test/original/*.png"))
lbl_path = sorted(glob.glob("test/annotation/*.png"))
val_org_path = sorted(glob.glob("test/original/*.png"))
val_lbl_path = sorted(glob.glob("test/annotation/*.png"))
def get_palette():
palette = [[255, 160, 93],
[93, 255, 231],
[255, 0, 229],
[0, 185, 53]]
return np.array(palette)/255
if input_ch ==3 :
color="rgb"
else:
color="grayscale"
#modelの構築
#inputは256*256
def conv_block(in_dim,out_dim,act_fn):
model = nn.Sequential(
nn.Conv2d(in_dim,out_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_dim),
act_fn,
)
return model
def conv_trans_block(in_dim,out_dim,act_fn):
model = nn.Sequential(
nn.ConvTranspose2d(in_dim,out_dim, kernel_size=3, stride=2, padding=1,output_padding=1),
nn.BatchNorm2d(out_dim),
act_fn,
)
return model
def maxpool():
pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
return pool
def conv_block_3(in_dim,out_dim,act_fn):
model = nn.Sequential(
conv_block(in_dim,out_dim,act_fn),
conv_block(out_dim,out_dim,act_fn),
nn.Conv2d(out_dim,out_dim, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_dim),
)
return model
class Conv_residual_conv(nn.Module):
def __init__(self,in_dim,out_dim,act_fn):
super(Conv_residual_conv,self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
act_fn = act_fn
self.conv_1 = conv_block(self.in_dim,self.out_dim,act_fn)
self.conv_2 = conv_block_3(self.out_dim,self.out_dim,act_fn)
self.conv_3 = conv_block(self.out_dim,self.out_dim,act_fn)
def forward(self,input):
conv_1 = self.conv_1(input)
conv_2 = self.conv_2(conv_1)
res = conv_1 + conv_2
conv_3 = self.conv_3(res)
return conv_3
class FusionGenerator(nn.Module):
def __init__(self,input_nc, output_nc, ngf):
super(FusionGenerator,self).__init__()
self.in_dim = input_nc
self.out_dim = ngf
self.final_out_dim = output_nc
act_fn = nn.LeakyReLU(0.2, inplace=True)
act_fn_2 = nn.ReLU()
# encoder
self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn)
self.pool_1 = maxpool()
self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn)
self.pool_2 = maxpool()
self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn)
self.pool_3 = maxpool()
self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn)
self.pool_4 = maxpool()
# bridge
self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn)
# decoder
self.deconv_1 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2)
self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2)
self.deconv_2 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2)
self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2)
self.deconv_3 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2)
self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2)
self.deconv_4 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2)
self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2)
# output
self.out = nn.Conv2d(self.out_dim,self.final_out_dim, kernel_size=3, stride=1, padding=1)
self.out_2 = nn.Tanh()
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def forward(self,input):
down_1 = self.down_1(input)
pool_1 = self.pool_1(down_1)
down_2 = self.down_2(pool_1)
pool_2 = self.pool_2(down_2)
down_3 = self.down_3(pool_2)
pool_3 = self.pool_3(down_3)
down_4 = self.down_4(pool_3)
pool_4 = self.pool_4(down_4)
bridge = self.bridge(pool_4)
deconv_1 = self.deconv_1(bridge)
skip_1 = (deconv_1 + down_4)/2
up_1 = self.up_1(skip_1)
deconv_2 = self.deconv_2(up_1)
skip_2 = (deconv_2 + down_3)/2
up_2 = self.up_2(skip_2)
deconv_3 = self.deconv_3(up_2)
skip_3 = (deconv_3 + down_2)/2
up_3 = self.up_3(skip_3)
deconv_4 = self.deconv_4(up_3)
skip_4 = (deconv_4 + down_1)/2
up_4 = self.up_4(skip_4)
out = self.out(up_4)
out = self.out_2(out)
#out = torch.clamp(out, min=-1, max=1)
return out
fusion = nn.DataParallel(FusionGenerator(input_ch,target_class,16)).cuda()
summary(fusion,(input_ch,img_size,img_size))
#summary(fusion,(3,256,256))
#loss関数, 最適化
loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(fusion.parameters(),lr=lr)
#modelのロード
try:
checkpoint = torch.load('model/fusion_9.pkl')
fusion.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print("\n--------model restored--------\n")
except:
print("\n--------model not restored--------\n")
pass
df = pd.DataFrame(index=[], columns=['loss','val_loss'])
def data_loader(org_path,lbl_path,color):
org_data = []
lbl_data = []
for m in range(len(org_path)):
org_img = load_img(org_path[m],color_mode=color)
lbl_img = load_img(lbl_path[m],color_mode="rgb")
if np.random.rand() >0.5:
n = random.randint(1,3)
org_img = org_img.rotate(90*n)
lbl_img = lbl_img.rotate(90*n)
if np.random.rand() >0.5:
org_img = ImageOps.flip(org_img)
lbl_img = ImageOps.flip(lbl_img)
if np.random.rand() >0.5:
org_img = ImageOps.mirror(org_img)
lbl_img = ImageOps.mirror(lbl_img)
org_img = img_to_array(org_img)/255
lbl_img = img_to_array(lbl_img)/255
org_data.append(org_img)
lbl_data.append(lbl_img)
org_data = np.array(org_data)
lbl_data = np.array(lbl_data)
return org_data,lbl_data
def make_onehot(data):
palette = np.array(get_palette(),dtype=np.float32)
onehot = np.zeros((data.shape[0], img_size, img_size,len(palette) ), dtype=np.uint8)
for i in range(len(palette)):
cat_color = palette[i]
# 画像が現在カテゴリ色と一致する画素に1を立てた(256, 256)のndarrayを作る
temp = np.where((data[:, :, :, 0] == cat_color[0]) &
(data[:, :, :, 1] == cat_color[1]) &
(data[:, :, :, 2] == cat_color[2]), 1, 0)
onehot[:, :, :, i] = temp
return onehot
random.seed(0)
for i in range(epoch):
c = list(zip(org_path, lbl_path))
random.shuffle(c)
org_path, lbl_path = zip(*c)
for m in range(len(org_path)//batch_size):
#print(org_path[m*batch_size:m*batch_size+batch_size])
x_train,y_train = data_loader(org_path[m*batch_size:m*batch_size+batch_size],lbl_path[m*batch_size:m*batch_size+batch_size],color=color)
y_train = make_onehot(y_train)
x_train = x_train.transpose(0,3,1,2)
y_train = y_train.transpose(0,3,1,2)
x_train = torch.from_numpy(x_train.astype(np.float32)).clone().cuda()
y_train = torch.from_numpy(y_train.astype(np.float32)).clone().cuda()
optimizer.zero_grad()
x = Variable(x_train).cuda()
y_ = Variable(y_train).cuda()
y = fusion.forward(x)
loss = loss_func(y,y_)
loss.backward()
optimizer.step()
for m in range(len(org_path)//batch_size):
x_test,y_test = data_loader(val_org_path[m*batch_size:m*batch_size+batch_size],val_lbl_path[m*batch_size:m*batch_size+batch_size],color=color)
y_test = make_onehot(y_test)
x_test = x_test.transpose(0,3,1,2)
y_test = y_test.transpose(0,3,1,2)
x_test = torch.from_numpy(x_test.astype(np.float32)).clone().cuda()
y_test = torch.from_numpy(y_test.astype(np.float32)).clone().cuda()
x = Variable(x_test).cuda()
y_ = Variable(y_test).cuda()
y = fusion.forward(x)
val_loss = loss_func(y,y_)
print("epoch"+str(i)+":",loss,val_loss)
df.loc[i] = [loss.item(),val_loss.item()]
if len(df) > 2:
if df["val_loss"][len(df)-1] < min(df["val_loss"][0:len(df)-2]):
%rm model/*.pkl
torch.save({
'epoch': i,
'model_state_dict': fusion.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, "model/fusion_"+str(i)+".pkl")
df.to_csv('model/loss.csv')
コードの使用方法の解説
データセットの作成方法が分からず、kerasとpytorchが混ぜこぜになっている情けない仕様ですが、説明していきます。皆様が使用される場合は、コードを保存してるディレクトリにmodelフォルダを作成してください。
さらに、以下の部分は最低限変更する必要があります。
batch_size = 1
img_size = 256
input_ch = 1
target_class = 4
lr = 0.0002
epoch = 500
#trainigデータのpath
org_path = sorted(glob.glob("test/original/*.png"))
lbl_path = sorted(glob.glob("test/annotation/*.png"))
#validationデータのpath
val_org_path = sorted(glob.glob("test/original/*.png"))
val_lbl_path = sorted(glob.glob("test/annotation/*.png"))
def get_palette():
palette = [[255, 160, 93],
[93, 255, 231],
[255, 0, 229],
[0, 185, 53]]
return np.array(palette)/255input_chはグレースケール画像もしくはカラー画像ということで、1か3にしてください。
このサンプルでは面倒だったので、trainingとvalidationを同じにしていますが、モデルを作る際は必ず違うものにしてください。
また、特に重要となるのはtarget_classとpaletteのところです!
オリジナルの画像に対して、分類先が4種になる場合はtarget_classを4にしてください。
さらにアノテーション画像の色(RGB)に合わせてpaletteの中身の値を変えてください。
上記のコードは、下図のような画像を教師画像とした場合の例となります。
さらに、以下のコードでデータの拡張を行っているので、各自で不要なものは削除してください。
def data_loader(org_path,lbl_path,color):
org_data = []
lbl_data = []
for m in range(len(org_path)):
org_img = load_img(org_path[m],color_mode=color)
lbl_img = load_img(lbl_path[m],color_mode="rgb")
if np.random.rand() >0.5:
n = random.randint(1,3)
org_img = org_img.rotate(90*n)
lbl_img = lbl_img.rotate(90*n)
if np.random.rand() >0.5:
org_img = ImageOps.flip(org_img)
lbl_img = ImageOps.flip(lbl_img)
if np.random.rand() >0.5:
org_img = ImageOps.mirror(org_img)
lbl_img = ImageOps.mirror(lbl_img)
org_img = img_to_array(org_img)/255
lbl_img = img_to_array(lbl_img)/255
org_data.append(org_img)
lbl_data.append(lbl_img)
org_data = np.array(org_data)
lbl_data = np.array(lbl_data)
return org_data,lbl_data学習を実行するとmodelフォルダ内にval_lossが最も低くなったモデルと、エポックごとのlossを示したcsvファイルが保存されます。
モデルは推論時に活用してください。
学習結果について
教師画像を推論すると下図のようになりました。
入力画像にある3つの図形が異なるクラスに分類されていることが分かります。
(教師が1枚の画像だけなので、できて当たり前ですが・・・)
私自身、このコードでどれくらい精度よくセグメンテーションが可能か分からないので、試行された方はご一報くださると嬉しいです。
推論編はこちらです


コメント
コメントを投稿