基于mmdetection进行实例分割(不修改源码)
生成数据集
!pip install -Uqqq pycocotools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import GroupKFold
from pathlib import Path
from pycocotools.coco import COCO
from PIL import Image
FOLD = 0  # index of the GroupKFold split that is held out for validation
dataDir = Path('/kaggle/input/sartorius-cell-instance-segmentation/train')
df = (
    pd.read_csv('/kaggle/input/sartorius-cell-instance-segmentation/train.csv')
    .reset_index(drop=True)
)
df['fold'] = -1
# Group by image id so that all annotations of one image land in the same fold.
gkf = GroupKFold(n_splits=5)
for fold_no, (_, val_rows) in enumerate(
        gkf.split(X=df, y=df['cell_type'], groups=df['id'])):
    df.loc[val_rows, 'fold'] = fold_no
def rle2mask(rle, img_w, img_h):
    """Decode a competition RLE string into a binary mask.

    The RLE is a space-separated list of ``start length`` pairs; starts are
    1-based offsets into the row-major (C order) flattened image.

    Args:
        rle: RLE string, e.g. ``"1 3 10 2"``. An empty string gives an
            all-zero mask.
        img_w: image width in pixels.
        img_h: image height in pixels.

    Returns:
        ``uint8`` array of shape ``(img_h, img_w)`` with 1 on masked pixels,
        stored in Fortran (column-major) memory order so that
        ``pycocotools.mask.encode`` accepts it.
    """
    # Parse explicitly to signed ints so "start - 1" can never wrap around.
    values = np.fromiter(map(int, rle.split()), dtype=np.int64)
    pairs = values.reshape((-1, 2)).T  # (2, N): row 0 = starts, row 1 = lengths
    starts = pairs[0] - 1  # 1-based -> 0-based offsets
    lengths = pairs[1]

    msk_img = np.zeros(img_w * img_h, dtype=np.uint8)
    # Fill each run with a slice; unlike concatenating per-run aranges this
    # also works when there are no runs at all.
    for start, length in zip(starts, lengths):
        msk_img[start:start + length] = 1
    msk_img = msk_img.reshape((img_h, img_w))
    # pycocotools requires a Fortran-contiguous array.
    return np.asfortranarray(msk_img)
from tqdm.notebook import tqdm
from pycocotools import mask as maskUtils
from joblib import Parallel, delayed
import json, itertools


def annotate(idx, row, cat_ids):
    """Convert one row of ``train.csv`` into a COCO annotation dict.

    Args:
        idx: row index label of ``row``; the annotation id is ``idx + 1``.
        row: row with ``annotation`` (RLE string), ``width``, ``height`` and
            ``id`` (image id) fields.
        cat_ids: category-name -> id mapping. Currently unused: every cell is
            assigned the single category id 1.

    Returns:
        COCO annotation dict whose ``segmentation`` is a compressed RLE.
    """
    mask = rle2mask(row['annotation'], row['width'], row['height'])  # binary mask
    c_rle = maskUtils.encode(mask)  # re-encode as COCO compressed RLE
    # pycocotools returns bytes; JSON needs str.
    c_rle['counts'] = c_rle['counts'].decode('utf-8')
    area = maskUtils.area(c_rle).item()
    bbox = maskUtils.toBbox(c_rle).astype(int).tolist()
    return {
        'segmentation': c_rle,
        'bbox': bbox,
        'area': area,
        'image_id': row['id'],
        'category_id': 1,  # cat_ids[row['cell_type']] for per-cell-type classes
        'iscrowd': 0,
        # COCOeval stores the matched gt id and treats 0 as "unmatched", so
        # annotation ids must start at 1: shift the 0-based row index.
        'id': idx + 1,
    }


def coco_structure(df, workers=4):
    """Build a COCO-format dict from rows of ``train.csv``.

    Args:
        df: one row per cell annotation.
        workers: number of joblib workers used to encode the annotations.

    Returns:
        dict with ``categories``, ``images`` and ``annotations`` keys.
    """
    cat_ids = {"cell": 1}
    cats = [{'name': name, 'id': cat_id} for name, cat_id in cat_ids.items()]
    # One image entry per unique image id; width/height come from the first
    # annotation row of that image.
    images = [
        {'id': img_id, 'width': row.width, 'height': row.height,
         'file_name': f'{img_id}.png'}
        for img_id, row in df.groupby('id').agg('first').iterrows()
    ]
    annotations = Parallel(n_jobs=workers)(
        delayed(annotate)(idx, row, cat_ids)
        for idx, row in tqdm(df.iterrows(), total=len(df))
    )
    return {'categories': cats, 'images': images, 'annotations': annotations}
from tqdm.notebook import tqdm
from pycocotools import mask as maskUtils
from joblib import Parallel, delayed
import json, itertools


def annotate(idx, row, cat_ids):
    """Convert one row of ``train.csv`` into a COCO annotation dict.

    Args:
        idx: row index label of ``row``; the annotation id is ``idx + 1``.
        row: row with ``annotation`` (RLE string), ``width``, ``height`` and
            ``id`` (image id) fields.
        cat_ids: category-name -> id mapping. Currently unused: every cell is
            assigned the single category id 1.

    Returns:
        COCO annotation dict whose ``segmentation`` is a compressed RLE.
    """
    mask = rle2mask(row['annotation'], row['width'], row['height'])  # binary mask
    c_rle = maskUtils.encode(mask)  # re-encode as COCO compressed RLE
    # pycocotools returns bytes; JSON needs str.
    c_rle['counts'] = c_rle['counts'].decode('utf-8')
    area = maskUtils.area(c_rle).item()
    bbox = maskUtils.toBbox(c_rle).astype(int).tolist()
    return {
        'segmentation': c_rle,
        'bbox': bbox,
        'area': area,
        'image_id': row['id'],
        'category_id': 1,  # cat_ids[row['cell_type']] for per-cell-type classes
        'iscrowd': 0,
        # COCOeval stores the matched gt id and treats 0 as "unmatched", so
        # annotation ids must start at 1: shift the 0-based row index.
        'id': idx + 1,
    }


def coco_structure(df, workers=4):
    """Build a COCO-format dict from rows of ``train.csv``.

    Args:
        df: one row per cell annotation.
        workers: number of joblib workers used to encode the annotations.

    Returns:
        dict with ``categories``, ``images`` and ``annotations`` keys.
    """
    cat_ids = {"cell": 1}
    cats = [{'name': name, 'id': cat_id} for name, cat_id in cat_ids.items()]
    # One image entry per unique image id; width/height come from the first
    # annotation row of that image.
    images = [
        {'id': img_id, 'width': row.width, 'height': row.height,
         'file_name': f'{img_id}.png'}
        for img_id, row in df.groupby('id').agg('first').iterrows()
    ]
    annotations = Parallel(n_jobs=workers)(
        delayed(annotate)(idx, row, cat_ids)
        for idx, row in tqdm(df.iterrows(), total=len(df))
    )
    return {'categories': cats, 'images': images, 'annotations': annotations}
# Every fold except FOLD is used for training; FOLD is the validation split.
train_df, valid_df = df.query("fold!=@FOLD"), df.query("fold==@FOLD")
train_json, valid_json = coco_structure(train_df), coco_structure(valid_df)
for payload, json_path in ((train_json, 'annotations_train.json'),
                           (valid_json, 'annotations_valid.json')):
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=True, indent=4)
!mkdir -p train2017
!mkdir -p valid2017
import shutil
def run_copy(row):img_path = dataDir/f'{row.id}.png'if row.fold!=FOLD:shutil.copy(img_path, './train2017/')else:shutil.copy(img_path, './valid2017/')
tmp_df = df.groupby('id').agg('first').reset_index()
_ = Parallel(n_jobs=-1,backend='threading')(delayed(run_copy)(row) for _, row in tqdm(tmp_df.iterrows(),total=len(tmp_df)))
数据集链接
https://www.kaggle.com/datasets/linheshen/sartorius-coco
数据集部分我已经弄好了,可以直接下载
训练部分代码
数据集链接
https://www.kaggle.com/datasets/linheshen/mmdet-main
https://www.kaggle.com/datasets/ammarnassanalhajali/mmdetectron-31-wheel
安装mmdetection
from IPython.display import clear_output
!pip install --no-index --no-deps /kaggle/input/mmdetectron-31-wheel/*
import os
import shutil
input_path = "/kaggle/input/mmdet-main/mmdetection-main"
output_path = "/kaggle/working/mmdetection"shutil.copytree(input_path, output_path)
%cd /kaggle/working/mmdetection
!pip install -v -e .
clear_output()
%cd /kaggle/working
直接在kaggle上面搜whl就有,挺多人发的
下载预训练权重
# Download the COCO-pretrained QueryInst R50-FPN (1x) checkpoint into the
# current directory (/kaggle/working after the preceding `%cd`); the training
# config's `load_from` points at this file.
!wget https://download.openmmlab.com/mmdetection/v2.0/queryinst/queryinst_r50_fpn_1x_coco/queryinst_r50_fpn_1x_coco_20210907_084916-5a8f1998.pth
预训练权重的下载地址可以在 mmdetection 仓库中对应模型的 README 里找到，仓库链接如下：
open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark（https://github.com/open-mmlab/mmdetection）
导入包
import cv2
from itertools import groupby
from pycocotools import mask as mutils
from pycocotools.coco import COCO
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import wandb
from PIL import Image
import gcfrom glob import glob
import matplotlib.pyplot as plt
# Target directory for custom_config.py; no error if it already exists.
os.makedirs(name='/kaggle/working/configs/sartorius/', exist_ok=True)
配置文件
%%writefile /kaggle/working/configs/sartorius/custom_config.py
_base_ = ['/kaggle/working/mmdetection/configs/_base_/schedules/schedule_2x.py', '/kaggle/working/mmdetection/configs/_base_/default_runtime.py'
]
num_stages = 6
num_proposals = 100
model = dict(type='QueryInst',data_preprocessor=dict(type='DetDataPreprocessor',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],bgr_to_rgb=True,pad_mask=True,pad_size_divisor=32),backbone=dict(type='ResNet',depth=50,num_stages=4,out_indices=(0, 1, 2, 3),frozen_stages=1,norm_cfg=dict(type='BN', requires_grad=True),norm_eval=True,style='pytorch',init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),neck=dict(type='FPN',in_channels=[256, 512, 1024, 2048],out_channels=256,start_level=0,add_extra_convs='on_input',num_outs=4),rpn_head=dict(type='EmbeddingRPNHead',num_proposals=num_proposals,proposal_feature_channel=256),roi_head=dict(type='SparseRoIHead',num_stages=num_stages,stage_loss_weights=[1] * num_stages,proposal_feature_channel=256,bbox_roi_extractor=dict(type='SingleRoIExtractor',roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=2),out_channels=256,featmap_strides=[4, 8, 16, 32]),mask_roi_extractor=dict(type='SingleRoIExtractor',roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=2),out_channels=256,featmap_strides=[4, 8, 16, 32]),bbox_head=[dict(type='DIIHead',num_classes=1,num_ffn_fcs=2,num_heads=8,num_cls_fcs=1,num_reg_fcs=3,feedforward_channels=2048,in_channels=256,dropout=0.0,ffn_act_cfg=dict(type='ReLU', inplace=True),dynamic_conv_cfg=dict(type='DynamicConv',in_channels=256,feat_channels=64,out_channels=256,input_feat_shape=7,act_cfg=dict(type='ReLU', inplace=True),norm_cfg=dict(type='LN')),loss_bbox=dict(type='L1Loss', loss_weight=5.0),loss_iou=dict(type='GIoULoss', loss_weight=2.0),loss_cls=dict(type='FocalLoss',use_sigmoid=True,gamma=2.0,alpha=0.25,loss_weight=2.0),bbox_coder=dict(type='DeltaXYWHBBoxCoder',clip_border=False,target_means=[0., 0., 0., 0.],target_stds=[0.5, 0.5, 1., 1.])) for _ in range(num_stages)],mask_head=[dict(type='DynamicMaskHead',dynamic_conv_cfg=dict(type='DynamicConv',in_channels=256,feat_channels=64,out_channels=256,input_feat_shape=14,with_proj=False,act_cfg=dict(type='ReLU', 
inplace=True),norm_cfg=dict(type='LN')),num_convs=4,num_classes=1,roi_feat_size=14,in_channels=256,conv_kernel_size=3,conv_out_channels=256,class_agnostic=False,norm_cfg=dict(type='BN'),upsample_cfg=dict(type='deconv', scale_factor=2),loss_mask=dict(type='DiceLoss',loss_weight=8.0,use_sigmoid=True,activate=False,eps=1e-5)) for _ in range(num_stages)]),# training and testing settingstrain_cfg=dict(rpn=None,rcnn=[dict(assigner=dict(type='HungarianAssigner',match_costs=[dict(type='FocalLossCost', weight=2.0),dict(type='BBoxL1Cost', weight=5.0, box_format='xyxy'),dict(type='IoUCost', iou_mode='giou', weight=2.0)]),sampler=dict(type='PseudoSampler'),pos_weight=1,mask_size=28,) for _ in range(num_stages)]),test_cfg=dict(rpn=None, rcnn=dict(max_per_img=num_proposals, mask_thr_binary=0.5)))# optimizer
optim_wrapper = dict(type='OptimWrapper',optimizer=dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001),paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),clip_grad=dict(max_norm=0.1, norm_type=2))# learning rate
param_scheduler = [dict(type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,end=1000),dict(type='MultiStepLR',begin=0,end=12,by_epoch=True,milestones=[8, 11],gamma=0.1)
]
classes = ('cell')
backend_args = None
data_root = '/kaggle/input/sartorius-coco/'
img_scale = (768, 768)
dataset_type = 'CocoDataset'
train_pipeline = [dict(type='LoadImageFromFile', backend_args=backend_args),dict(type='LoadAnnotations', with_bbox=True, with_mask=True),dict(type='Resize', scale=img_scale, keep_ratio=True),dict(type='RandomFlip', prob=0.5),dict(type='PackDetInputs')
]
test_pipeline = [dict(type='LoadImageFromFile', backend_args=backend_args),dict(type='Resize', scale=img_scale, keep_ratio=True),# If you don't have a gt annotation, delete the pipelinedict(type='LoadAnnotations', with_bbox=True, with_mask=True),dict(type='PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape','scale_factor'))
]
train_dataloader = dict(batch_size=2,num_workers=2,persistent_workers=True,sampler=dict(type='DefaultSampler', shuffle=True),batch_sampler=dict(type='AspectRatioBatchSampler'),dataset=dict(type=dataset_type,data_root=data_root,metainfo=dict(classes=classes),ann_file='annotations_train.json',data_prefix=dict(img='train2017/'),filter_cfg=dict(filter_empty_gt=True, min_size=32),pipeline=train_pipeline,backend_args=backend_args))
val_dataloader = dict(batch_size=1,num_workers=2,persistent_workers=True,drop_last=False,sampler=dict(type='DefaultSampler', shuffle=False),dataset=dict(type=dataset_type,data_root=data_root,metainfo=dict(classes=classes),ann_file='annotations_valid.json',data_prefix=dict(img='valid2017/'),test_mode=True,pipeline=test_pipeline,backend_args=backend_args))
test_dataloader = val_dataloader
val_evaluator = dict(type='CocoMetric',ann_file=data_root + 'annotations_valid.json', metric=['bbox','segm'],format_only=False,backend_args=backend_args)
load_from='/kaggle/working/queryinst_r50_fpn_1x_coco_20210907_084916-5a8f1998.pth'
test_evaluator = val_evaluator
resume_from = None
default_hooks = dict(timer=dict(type='IterTimerHook'),logger=dict(type='LoggerHook', interval=100),param_scheduler=dict(type='ParamSchedulerHook'),checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1,save_best='coco/segm_mAP'),sampler_seed=dict(type='DistSamplerSeedHook'),visualization=dict(type='DetVisualizationHook'))
work_dir = '/kaggle/working/model_output'
完整配置可以参考 mmdetection 中的 configs/queryinst/queryinst_r50_fpn_1x_coco.py，数据集部分对应的 base 是 configs/_base_/datasets/coco_instance.py（实例分割）。
训练模型
# Launch distributed training on 2 GPUs with the custom config, using the
# dist_train.sh script from the mmdetection copy in the input dataset.
# Checkpoints and logs go to the config's work_dir (/kaggle/working/model_output).
!bash /kaggle/input/mmdet-main/mmdetection-main/tools/dist_train.sh /kaggle/working/configs/sartorius/custom_config.py 2
训练结果的数据集链接
https://www.kaggle.com/datasets/linheshen/sartorious-train
推理部分代码
from IPython.display import clear_output
!pip install --no-index --no-deps /kaggle/input/mmdetectron-31-wheel/*
import os
import shutil
input_path = "/kaggle/input/mmdet-main/mmdetection-main"
output_path = "/kaggle/working/mmdetection"shutil.copytree(input_path, output_path)
%cd /kaggle/working/mmdetection
!pip install -v -e .
clear_output()
%cd /kaggle/working
import pandas as pd
import numpy as np
import cupy as cp
from glob import glob
import os
import cv2
from tqdm.notebook import tqdm
import pickle
from tqdm.notebook import tqdm
from pycocotools import mask as maskUtils
from joblib import Parallel, delayed
import json,itertools
from itertools import groupby
from pycocotools import mask as mutils
from pycocotools import _mask as coco_mask
import matplotlib.pyplot as plt
import os
import base64
import typing as t
import zlib
import random
random.seed(0)
ROOT = '/kaggle/input/sartorius-cell-instance-segmentation'
train_or_test = 'test'
THR = 0.50  # instance score threshold
# Test Data: one row per image file, plus its id (file name up to the first dot).
df = pd.DataFrame(glob(f'{ROOT}/{train_or_test}/*'), columns=['image_path'])
df['id'] = df.image_path.map(lambda path: os.path.basename(path).split('.')[0])
def read_img(image_id, train_or_test='train', image_size=None):
    """Read one competition image and return it as 8-bit.

    Args:
        image_id: file stem of the PNG under ``{ROOT}/{train_or_test}/``.
        train_or_test: sub-directory to read from.
        image_size: if given, the image is resized to
            ``(image_size, image_size)``.

    Returns:
        The image as loaded with ``cv2.IMREAD_UNCHANGED``; 16-bit images are
        scaled down to ``uint8``.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if OpenCV cannot decode the file.
    """
    filename = f'{ROOT}/{train_or_test}/{image_id}.png'
    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(filename):
        raise FileNotFoundError(f'not found {filename}')
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    if img is None:  # cv2.imread reports decode failures by returning None
        raise ValueError(f'could not decode {filename}')
    if image_size is not None:
        img = cv2.resize(img, (image_size, image_size))
    if img.dtype == 'uint16':
        img = (img / 256).astype('uint8')
    return img
def load_RGBY_image(image_id, train_or_test='train', image_size=None):
    """Read an image with ``read_img`` and repeat it 3 times along a new last axis.

    Despite the name, no RGBY channels are involved: the same image is
    duplicated so it can be fed to a 3-channel model.
    """
    single = read_img(image_id, train_or_test, image_size)
    return np.repeat(single[..., np.newaxis], 3, axis=-1)
out_image_dir = f'/kaggle/working/mmdet_{train_or_test}/'
!mkdir -p {out_image_dir}
images=[]
annotations=[]
for idx in tqdm(range(len(df))):image_id = df.iloc[idx]['id']img = load_RGBY_image(image_id, train_or_test)cv2.imwrite(f'{out_image_dir}/{image_id}.png', img)cat_ids = {"cell":1} cats =[{'name':name, 'id':z} for name,z in cat_ids.items()]image = {'id':image_id, 'width':img.shape[1], 'height':img.shape[0], 'file_name':image_id+'.png'}images.append(image)annotation = {'segmentation': None,'bbox': [0,0,0,0],'area': 0,'image_id':image_id, 'category_id':1, # cat_ids[row['cell_type']], 'iscrowd':0, 'id':idx}annotations.append(annotation)coco_data={'categories':cats, 'images':images, 'annotations':annotations}with open('/kaggle/working/instances_test2017.json', 'w') as f:json.dump(coco_data, f)
配置
%%writefile /kaggle/working/test_config.py
num_stages = 6
num_proposals = 100
_base_ = '/kaggle/input/sartorious-train/configs/sartorius/custom_config.py'
data_root = '/kaggle/working/'
val_dataloader = dict(batch_size=4,num_workers=4,persistent_workers=True,drop_last=False,dataset=dict(data_root=data_root,ann_file=data_root + 'instances_test2017.json',data_prefix=dict(img='mmdet_test/'),test_mode=True,)
)
test_dataloader = val_dataloader
val_evaluator = dict(ann_file=data_root + 'instances_test2017.json',format_only=True, outfile_prefix='/kaggle/working')
test_evaluator = val_evaluator
%cd /kaggle/working/mmdetection
# Run inference with the best training checkpoint; the raw predictions are
# dumped to result0.pkl in the current directory (/kaggle/working/mmdetection).
!python /kaggle/working/mmdetection/tools/test.py '/kaggle/working/test_config.py' '/kaggle/input/sartorious-train/model_output/best_coco_segm_mAP_epoch_21.pth' --out 'result0.pkl'
clear_output()
def mask2rle(msk):
    """Run-length encode a binary mask (1 = mask, 0 = background).

    The mask is flattened in row-major (C) order and runs are emitted as
    1-based ``start length`` pairs.

    Args:
        msk: 2-D numpy or cupy array.

    Returns:
        Space-separated RLE string, e.g. ``"1 3 10 2"``.
    """
    pixels = cp.array(msk).flatten()
    pad = cp.array([0])
    pixels = cp.concatenate([pad, pixels, pad])
    # 1-based positions where the value changes: alternating run starts/ends.
    runs = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]  # turn run end positions into run lengths
    # Copy to the host once; iterating a cupy array element by element
    # triggers a separate device->host transfer per element.
    return ' '.join(str(x) for x in cp.asnumpy(runs))
def print_masked_img(image_id, mask):
    """Plot the image, the mask, and the mask overlaid on the image side by side.

    The displayed image stacks the raw first channel with its CLAHE-equalized
    and histogram-equalized versions as the three color channels.
    """
    gray = load_RGBY_image(image_id, train_or_test)[..., 0]
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    composite = np.stack(
        [gray, clahe.apply(gray), cv2.equalizeHist(gray)], axis=-1)

    plt.figure(figsize=(15, 15))
    # Each panel: (title, [(array, imshow kwargs), ...]) drawn in order.
    panels = (
        ('Image', [(composite, {})]),
        ('Mask', [(mask, {'cmap': 'inferno'})]),
        ('Image + Mask', [(composite, {}),
                          (mask, {'alpha': 0.4, 'cmap': 'inferno'})]),
    )
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for layer, kwargs in layers:
            plt.imshow(layer, **kwargs)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
# Load the predictions dumped by tools/test.py --out.
# NOTE: pickle can execute code on load; only load files this notebook produced.
with open('/kaggle/working/mmdetection/result0.pkl', 'rb') as f:
    result = pickle.load(f)

# Show up to the first 3 test images, summing every predicted instance mask
# whose score exceeds 0.1.
for res in result[:3]:
    image_id = res['img_id']
    preds = res['pred_instances']
    # Reset per image: a mask from a previous image must never be reused, and
    # a low-scoring first instance must not leave `mask` undefined.
    mask = None
    for sg, sc in zip(preds['masks'], preds['scores']):
        if sc <= 0.1:
            continue
        decoded = mutils.decode(sg)  # RLE dict -> binary mask
        mask = decoded if mask is None else mask + decoded
    if mask is None:
        print(f'{image_id}: no instance with score > 0.1')
        continue
    print_masked_img(image_id, mask)
效果
提交
import cupy as cp
import gc


def one_hot(y, num_classes, dtype=cp.uint8):  # GPU
    """One-hot encode integer labels on the GPU (keras ``to_categorical`` style).

    Args:
        y: integer labels (anything ``cp.array`` accepts).
        num_classes: number of classes; if falsy it is inferred as max(y) + 1.
        dtype: dtype of the result.

    Returns:
        cupy array of shape ``y.shape + (num_classes,)``; a trailing singleton
        axis of a multi-dimensional ``y`` is dropped first.
    """
    labels = cp.array(y, dtype='int')
    lead_shape = labels.shape
    if len(lead_shape) > 1 and lead_shape[-1] == 1:
        lead_shape = tuple(lead_shape[:-1])
    labels = labels.ravel()
    if not num_classes:
        num_classes = cp.max(labels) + 1
    total = labels.shape[0]
    encoded = cp.zeros((total, num_classes), dtype=dtype)
    encoded[cp.arange(total), labels] = 1
    return cp.reshape(encoded, lead_shape + (num_classes,))


def fix_overlap(msk):  # GPU
    """Make instance masks mutually exclusive.

    Args:
        msk: multi-channel mask, one cell instance per channel,
            e.g. shape (520, 704, n).

    Returns:
        cupy array of shape (H, W, m) with m <= n where every pixel belongs to
        at most one channel (the argmax winner); channels left empty are
        removed.
    """
    # Prepend an all-zero channel 0 acting as background.
    padded = cp.pad(cp.array(msk), [[0, 0], [0, 0], [1, 0]])
    depth = padded.shape[-1]
    # Single winning channel per pixel; background where no instance is set.
    owner = cp.argmax(padded, axis=-1)
    exclusive = one_hot(owner, num_classes=depth)[..., 1:]  # drop background
    return exclusive[..., cp.any(exclusive, axis=(0, 1))]  # drop empty channels


def check_overlap(msk):
    """Return whether any pixel is set in more than one channel (cupy bool)."""
    coverage = msk.astype(cp.bool_).astype(cp.uint8).sum(axis=-1)
    return cp.any(coverage > 1)
import gc
import pandas as pd
import cupy as cp
from tqdm import tqdm

# Rows of (image_id, rle), one per kept instance.
data = []
for res in tqdm(result):
    image_id = res['img_id']
    preds = res['pred_instances']
    instance_masks = []
    for idx, (sg, sc) in enumerate(zip(preds['masks'], preds['scores'])):
        if sc < THR:  # drop low-confidence instances
            continue
        try:
            instance_masks.append(cp.array(mutils.decode(sg)))  # binary mask on GPU
        except Exception as e:
            print(f"Error decoding mask for instance {idx} in image {image_id}: {e}")
            continue
    if not instance_masks:
        # Nothing above the threshold: no rows for this image. The original
        # code crashed here on `del rle` because `rle` was never assigned.
        continue
    mask = cp.stack(instance_masks, axis=-1)  # (H, W, n_instances)
    del instance_masks
    # A pixel may belong to only one cell; resolve overlaps before encoding.
    if check_overlap(mask):
        mask = fix_overlap(mask)
    for k in range(mask.shape[-1]):
        data.append([image_id, mask2rle(mask[..., k])])
    # Drop this image's masks and collect garbage before the next image.
    del mask
    gc.collect()

# Convert to DataFrame for further use or export
pred_df = pd.DataFrame(data, columns=['id', 'predicted'])
%cd /kaggle/working
!rm -rf *
sub_df = pd.read_csv('/kaggle/input/sartorius-cell-instance-segmentation/sample_submission.csv')
del sub_df['predicted']
sub_df = sub_df.merge(pred_df, on='id', how='left')
sub_df.to_csv('submission.csv',index=False)
sub_df.head()
结果一般,有需要自己慢慢调