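I am attempting to write a custom loss function in Keras from this paper. Namely, the loss I want to create is this:

$$
E = \sum_{i} \frac{1}{|Y_i|\,|\bar{Y}_i|} \sum_{(k,l) \in Y_i \times \bar{Y}_i} \exp\left(-\left(c_k^i - c_l^i\right)\right)
$$

This is a type of ranking loss for multi-class, multi-label problems. Here are the details: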
- $Y_i$ = set of positive labels for sample $i$
- $\bar{Y}_i$ = set of negative labels for sample $i$ (complement of $Y_i$)
- $c_j^i$ = prediction for the $i$-th sample at label $j$
In what follows, both `y_true` and `y_pred` are of dimension 18 (18 labels per sample).
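To make the definition concrete, here is a minimal loop-based reference in plain Python that follows the sums literally. It is a sketch for checking results on small inputs, not a Keras loss; the name `ranking_loss_reference` is hypothetical, and it assumes every sample has at least one positive and one negative label (otherwise the normalizer is zero and Python raises `ZeroDivisionError`).

```python
import math

def ranking_loss_reference(y_true, y_pred):
    """Hypothetical helper: direct evaluation of the ranking loss.

    y_true: list of rows of 0/1 labels; y_pred: list of rows of scores c_j^i.
    """
    total = 0.0
    for labels, scores in zip(y_true, y_pred):
        pos = [j for j, t in enumerate(labels) if t == 1]  # Y_i
        neg = [j for j, t in enumerate(labels) if t == 0]  # complement of Y_i
        # inner sum over all (k, l) in Y_i x complement of Y_i
        pair_sum = sum(math.exp(-(scores[k] - scores[l])) for k in pos for l in neg)
        total += pair_sum / (len(pos) * len(neg))
    return total
```

The row length is not fixed, so the same function works for 18 labels or for a toy example with 3.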
```python
import tensorflow as tf
from keras import backend as K

def multilabel_loss(y_true, y_pred):
    """ Multi-label loss function.
    More complete description here...
    """
    zero = K.tf.constant(0, dtype=tf.float32)
    where_one = K.tf.not_equal(y_true, zero)
    where_zero = K.tf.equal(y_true, zero)
    Y_p = K.tf.where(where_one)
    Y_n = K.tf.where(where_zero)
    n = K.tf.shape(y_true)[0]
    loss = 0
    for i in range(n):
        # Here i is the ith sample; for a specific i, I find all locations
        # where Y_p, Y_n belong to the ith sample; axis 0 denotes
        # the sample index space
        Y_p_i = K.tf.equal(Y_p[:,0], K.tf.constant(i, dtype=tf.int64))
        Y_n_i = K.tf.equal(Y_n[:,0], K.tf.constant(i, dtype=tf.int64))
        # Here I plug in those locations to get the values
        Y_p_i = K.tf.where(Y_p_i)
        Y_n_i = K.tf.where(Y_n_i)
        # Here I get the indices of the values above
        Y_p_ind = K.tf.gather(Y_p[:,1], Y_p_i)
        Y_n_ind = K.tf.gather(Y_n[:,1], Y_n_i)
        # Here I compute the sizes of Y_i and its complement
        yi = K.tf.shape(Y_p_ind)[0]
        yi_not = K.tf.shape(Y_n_ind)[0]
        # The value to normalize the inner summation
        normalizer = K.tf.divide(1, K.tf.multiply(yi, yi_not))
        # This creates a matrix of all combinations of indices k, l from the
        # above equation; then it is reshaped
        prod = K.tf.map_fn(lambda x: K.tf.map_fn(lambda y: K.tf.stack([x, y]), Y_n_ind), Y_p_ind)
        prod = K.tf.reshape(prod, [-1, 2, 1])
        prod = K.tf.squeeze(prod)
        # Next, the indices are fed into the prediction vector of sample i,
        # where the values are then exponentiated and summed.
        # Column 0 holds k (positive), column 1 holds l (negative), so
        # c_l - c_k = -(c_k - c_l), matching the exponent in the loss
        y_pred_gather = K.tf.gather(y_pred[i,:], prod)
        s = K.tf.cast(K.sum(K.tf.exp(K.tf.subtract(y_pred_gather[:,1], y_pred_gather[:,0]))), tf.float64)
        loss = loss + K.tf.multiply(normalizer, s)
    return loss
```
My questions are the following:

- When I compile my graph, I get an error involving `n`, namely `TypeError: 'Tensor' object cannot be interpreted as an integer`. I've looked around, but I can't find a way to get rid of it. My hunch is that I need to avoid a `for` loop altogether, which brings me to the second question.
- How can I write this loss without `for` loops? I'm fairly new to Keras and have spent a solid few hours writing this custom loss myself. I'd love to write it more concisely. What's blocking me from using matrices throughout is the fact that $Y_i$ and its complement can have different sizes for each $i$.
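On the first question: `K.tf.shape(y_true)[0]` returns a symbolic tensor, not a Python `int`, so `range(n)` cannot consume it while the graph is being built. On the second: the varying sizes of $Y_i$ and its complement can be sidestepped by computing all label pairs $(k, l)$ for every sample and multiplying by a 0/1 mask built from `y_true`, so no per-sample indexing is needed. The sketch below shows that masking idea in NumPy, chosen only because it is easy to run and check; the name `ranking_loss_masked` is hypothetical, and it assumes each row has at least one positive and one negative label. The same broadcasting (`[:, :, None]`, summing over axes 1 and 2) should carry over to `tf.exp` and `tf.reduce_sum` on tensors.

```python
import numpy as np

def ranking_loss_masked(y_true, y_pred):
    """Hypothetical loop-free version: compute every (k, l) pair, then mask."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    pos_mask = y_true          # 1 where label j is in Y_i
    neg_mask = 1.0 - y_true    # 1 where label j is in the complement of Y_i
    # diff[i, k, l] = c_k^i - c_l^i for every pair of labels
    diff = y_pred[:, :, None] - y_pred[:, None, :]
    # pair_mask[i, k, l] = 1 only when k is positive and l is negative
    pair_mask = pos_mask[:, :, None] * neg_mask[:, None, :]
    pair_sum = np.sum(pair_mask * np.exp(-diff), axis=(1, 2))
    # per-sample normalizer |Y_i| * |complement of Y_i|
    normalizer = pos_mask.sum(axis=1) * neg_mask.sum(axis=1)
    return np.sum(pair_sum / normalizer)
```

Rows with different numbers of positive labels are handled by the mask alone: every row always has an 18 × 18 (or L × L) grid, and only the valid pairs contribute.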
Please let me know if you'd like me to elaborate more on my code. Happy to do so.
UPDATE 3
As per @Parag S. Chandakkar's suggestions, I have the following:
```python
import numpy as np
import tensorflow as tf
from keras import backend as K

def multi_label_loss(y_true, y_pred):
    # set consistent casting
    y_true = tf.cast(y_true, dtype=tf.float64)
    y_pred = tf.cast(y_pred, dtype=tf.float64)
    # PT holds exp(-c_k) for labels in Y_i (zero elsewhere);
    # PT_complement holds exp(c_l) for labels in the complement (zero elsewhere)
    PT = K.tf.multiply(y_true, tf.exp(-y_pred))
    PT_complement = K.tf.multiply((1-y_true), tf.exp(y_pred))
    # this step gets the weight vector that we'll normalize by
    m = K.shape(y_true)[0]
    W = K.tf.multiply(K.sum(y_true, axis=1), K.sum(1-y_true, axis=1))
    W_inv = 1./W
    W_inv = K.reshape(W_inv, (m,1))
    # this step computes the outer product of two tensors
    def outer_product(inputs):
        """
        inputs: list of two tensors (of equal dimensions)
        for which you need to compute the outer product
        """
        x, y = inputs
        batchSize = K.shape(x)[0]
        outerProduct = x[:,:, np.newaxis] * y[:,np.newaxis,:]
        outerProduct = K.reshape(outerProduct, (batchSize, -1))
        # returns a flattened batch-wise set of tensors
        return outerProduct
    # set up inputs to outer product
    inputs = [PT, PT_complement]
    # compute final loss
    loss = K.sum(K.tf.multiply(W_inv, outer_product(inputs)))
    return loss
```
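This version works because $\exp(-c_k^i)\exp(c_l^i) = \exp(-(c_k^i - c_l^i))$: each entry of the flattened outer product is one term of the inner sum, and pairs with $k \notin Y_i$ or $l \notin \bar{Y}_i$ are zeroed by the two masks. One thing to watch: if a sample has all labels positive (or none), `W` is 0 for that row while its outer product is all zeros, so `W_inv * outer_product(...)` gives `inf * 0 = NaN` and the whole loss becomes NaN. The NumPy mirror below reproduces this. The name `update3_loss` and its `guard` flag, which clamps `W` to at least 1 so such rows contribute 0, are hypothetical additions, not part of the code above; in Keras a similar clamp could be written with `K.maximum(W, 1.)`.

```python
import numpy as np

def update3_loss(y_true, y_pred, guard=False):
    """Hypothetical NumPy mirror of multi_label_loss; guard=True clamps W."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    PT = y_true * np.exp(-y_pred)                  # exp(-c_k) where k is in Y_i
    PT_complement = (1 - y_true) * np.exp(y_pred)  # exp(c_l) where l is in the complement
    W = y_true.sum(axis=1) * (1 - y_true).sum(axis=1)
    if guard:
        # Assumed fix: a row with no valid (k, l) pairs has an all-zero outer
        # product, so clamping W to 1 makes it contribute 0 instead of NaN.
        W = np.maximum(W, 1.0)
    # silence the 1/0 and inf*0 warnings so the NaN case can be shown
    with np.errstate(divide="ignore", invalid="ignore"):
        W_inv = (1.0 / W).reshape(-1, 1)
        outer = (PT[:, :, None] * PT_complement[:, None, :]).reshape(len(y_true), -1)
        return np.sum(W_inv * outer)
```

With a batch whose second row has no positive labels, the unguarded call returns NaN, while `guard=True` returns only the first row's contribution.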