Перейти к содержимому
Compvision.ru

Recommended Posts

День добрый. Возникло желание прикрутить к caffe Jaccard index: Jaccard

Тренирую сеть для задачи семантической сегментации, используя датасет с 5 классами.

Непременно нужно прикрутить jaccard в python layer. Для этих целей я переделал код, который реализует dice loss: https://github.com/NVIDIA/DIGITS/tree/digits-5.0/examples/medical-imaging

Получается, что я не совсем понимаю как устроены блобы в Caffe. Есть подозрение, что необходимо по-другому передавать axis, при суммировании всех предсказаний и масок:

def forward(self, bottom, top):
        smooth = 1e-12
        label = bottom[1].data[:,0,:,:]
        # compute prediction
        prediction = np.argmax(bottom[0].data, axis=1)
        # area of predicted contour
        a_p = np.sum(prediction, axis=self.sum_axes)
        # area of contour in label
        a_l = np.sum(label, axis=self.sum_axes)
        # area of intersection
        a_pl = np.sum(prediction * label, axis=self.sum_axes)
        # dice coefficient
        dice_coeff = np.mean((2.*a_pl+smooth)/(a_p + a_l+smooth))
        top[0].data[...] = dice_coeff

Получилась следующая реализация:

import random
import numpy as np
import caffe
import math

class Jaccard(caffe.Layer):
    """
    A layer that calculates the Jaccard coefficient
    """
    def setup(self, bottom, top):
        if len(bottom) != 2:
            raise Exception("Need two inputs to compute Jaccard coefficient.")
        # compute sum over all axes but the batch and channel axes
        self.sum_axes = tuple(range(1, bottom[0].data.ndim - 1))

    def reshape(self, bottom, top):
        # check input dimensions match
        if bottom[0].count != 2*bottom[1].count:
            raise Exception("Prediction must have twice the number of elements of the input.")
        # loss output is scalar
        top[0].reshape(1)

    def forward(self, bottom, top):
        smooth = 1e-12
        label = bottom[1].data[:,0,:,:]
        # compute prediction
        prediction = np.argmax(bottom[0].data, axis=1)
        # area of predicted contour
        a_p = np.sum(prediction, axis=self.sum_axes)
        # area of contour in label
        a_l = np.sum(label, axis=self.sum_axes)
        # area of intersection
        a_pl = np.sum(prediction * label, axis=self.sum_axes)
        # jaccard coef
        jaccard_coef = np.mean((a_pl+smooth) / (a_p + a_l - a_pl + smooth))
        jaccard_coef=0
        top[0].data[...] = jaccard_coef

    def backward(self, top, propagate_down, bottom):
        if propagate_down[1]:
            raise Exception("label not diff")
        elif propagate_down[0]:

            #crossentropy
            loss_crossentropy = 0
            for i in range(len(prediction)):
                loss_crossentropy = loss_crossentropy - prediction[i] * (label[i] - (prediction[i] >= 0)) - math.log(1 + math.exp(prediction[i] - 2 * prediction[i] * (prediction[i] >= 0)))
 
            bottom[0].diff[...] = -(math.log(jaccard_coef) + loss_crossentropy)
            
        else:
            raise Exception("no diff")

Код ошибки error code 1. Если установить кол-во классов равное 2 и использовать код dice loss из вышеприведенной ссылки, все прекрасно заходит. Ошибка в индексах. Есть какие-нибудь соображения?

Поделиться сообщением


Ссылка на сообщение
Поделиться на других сайтах

Так, ну вроде бы похоже на правду. Установил проверку на выход предсказания равному кол-ву классов, то есть 5 и пока работает вроде:

if bottom[0].count != 5*bottom[1].count

Не гуд вот, что:

DIGITS не визуализирует отрицательную ошибку на графике:

image.png.c895b1ce31e13543c7b86640fdab3ebf.png

Поделиться сообщением


Ссылка на сообщение
Поделиться на других сайтах

Создайте учётную запись или войдите для комментирования

Вы должны быть пользователем, чтобы оставить комментарий

Создать учётную запись

Зарегистрируйтесь для создания учётной записи. Это просто!

Зарегистрировать учётную запись

Войти

Уже зарегистрированы? Войдите здесь.

Войти сейчас


  • Сейчас на странице   0 пользователей

    Нет пользователей, просматривающих эту страницу

×