from fastai.vision.all import *
import numpy as np
from torch.nn.modules.loss import _Loss
import segmentation_models_pytorch as smp
from steel_segmentation.utils import get_train_df
from steel_segmentation.transforms import SteelDataBlock, SteelDataLoaders
# Build the dataloaders from the training dataframe (requested with pivot=True).
path = Path("../data")
train_pivot = get_train_df(path=path, pivot=True)
block = SteelDataBlock(path)
dls = SteelDataLoaders(block, train_pivot, bs=8)

# Grab a single batch and show the shape and device of inputs and targets.
xb, yb = dls.one_batch()
print(xb.shape, xb.device)
print(yb.shape, yb.device)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# U-Net built on "resnet18" with 4 output classes, placed on `device`.
model = smp.Unet("resnet18", classes=4).to(device)
# The batch comes from the dataloaders' device, which is not guaranteed to be
# `device`; move it explicitly so the forward pass cannot hit a device mismatch.
logits = model(xb.to(device))
# Per-pixel probabilities, then binary masks with a 0.5 threshold.
probs = torch.sigmoid(logits)
preds = (probs > 0.5).float()
Kaggle Dice metric
The competition evaluation metric is defined as:
This competition is evaluated on the mean Dice coefficient. The Dice coefficient can be used to compare the pixel-wise agreement between a predicted segmentation and its corresponding ground truth. The formula is given by: $$\text{Dice}(X, Y) = \frac{2 \cdot |X \cap Y|}{|X| + |Y|}$$
where X is the predicted set of pixels and Y is the ground truth. The Dice coefficient is defined to be 1 when both X and Y are empty. The leaderboard score is the mean of the Dice coefficients for each <ImageId, ClassId> pair in the test set.
This section collects all the metrics that can be used to evaluate the performance of the trained segmentation models.
Simulated training with `compute_val` and a test Learner with `TstLearner`.
@delegates()
class TstLearner(Learner):
    """Minimal `Learner` stand-in for feeding metrics by hand.

    The parent initialiser is not called; only the attributes assigned in
    `__init__` are set on the instance.
    """

    def __init__(self, dls=None, model=None, **kwargs):
        # `dls`, `model` and `kwargs` are accepted but not used here.
        self.pred = None
        self.xb = None
        self.yb = None
        self.loss_func = BCEWithLogitsLossFlat()
def compute_val(met, pred, y, splits=(0, 6, 15, 20)):
    """Go through a fake cycle with various batch sizes and compute the value of `met`.

    Args:
        met: metric object exposing `reset()`, `accumulate(learn)` and `value`.
        pred: predictions, sliced along the first axis.
        y: targets, sliced with the same boundaries as `pred`.
        splits: increasing boundaries; each consecutive pair `(start, stop)`
            defines one fake batch. The default reproduces the original
            chunks of 6, 9 and 5 samples for a 20-sample input.

    Returns:
        `met.value` after every non-empty fake batch has been accumulated.
    """
    met.reset()
    learn = TstLearner()
    for start, stop in zip(splits, splits[1:]):
        pred_chunk = pred[start:stop]
        # Boundaries beyond the end of `pred` produce empty slices (e.g. an
        # 8-sample batch with the default splits); skip them instead of
        # handing an empty batch to the metric.
        if len(pred_chunk) == 0:
            continue
        learn.pred = pred_chunk
        # Targets are stored as a one-element tuple in `yb`.
        learn.yb = (y[start:stop],)
        met.accumulate(learn)
    return met.value
The fastai library comes with a Dice metric for multi-channel masks. As a segmentation metric in this framework, it expects a flattened mask (class indices rather than one channel per class) as the target.
multidice_obj = DiceMulti()
# Collapse the channel axis of the target with argmax, and move it to the CPU
# so it lives on the same device as the (CPU) predictions.
compute_val(multidice_obj, pred=preds.detach().cpu(), y=yb.argmax(1).cpu())
Here we slightly modify `DiceMulti` to accept a 4-channel mask as the target.
# Metric fed with raw logits. The target is moved to the CPU so that it is on
# the same device as the (CPU) predictions.
dice_obj = ModDiceMulti(with_logits=True)
compute_val(dice_obj, pred=logits.detach().cpu(), y=yb.cpu())
# Same metric fed with the already-thresholded binary predictions.
dice_obj = ModDiceMulti()
compute_val(dice_obj, pred=preds.detach().cpu(), y=yb.cpu())