
I need a custom weighted MSE loss function, which I defined in keras.backend:

from keras import backend as K

def weighted_loss(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true) *
                  K.exp(-K.log(1.7) * (K.log(1. + K.exp((y_true - 3) / 5)))),
                  axis=-1)
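For reference, here is the loss above written as a formula. This is my own transcription of the code, with n the size of the last axis (the one K.mean averages over). The term K.log(1. + K.exp(x)) is the softplus of x, and exp(-ln(1.7)·s) can be rewritten as a power:

```latex
\mathcal{L}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 \, w(y_i),
\qquad
w(y) = \exp\!\bigl(-\ln(1.7)\,\ln(1 + e^{(y-3)/5})\bigr) = \bigl(1 + e^{(y-3)/5}\bigr)^{-\ln 1.7}
```

So the weight tends to 1 for very negative y and shrinks monotonically as y grows; at y = 3 it is 2^(-ln 1.7) ≈ 0.69.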

However, calling it on plain numbers fails:

    weighted_loss(1, 2)
    ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: 'Tensor("Exp_37:0", shape=(), dtype=float32)'

or

    weighted_loss(1., 2.)
    ZeroDivisionError: integer division or modulo by zero

What mistake am I making here?

1 Answer


Whether you are using TensorFlow or Theano as the backend is irrelevant to your question. What matters is that the backend functions operate on tensors; look up the meaning of 'tensor' if the term is unfamiliar.

Take a look at how Keras's own tests for its loss and metric functions are implemented:

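import numpy as np
from keras import backend as K

# Excerpt from the Keras test suite; all_metrics is a list of metric
# functions defined in that test module.
def test_metrics():
    y_a = K.variable(np.random.random((6, 7)))
    y_b = K.variable(np.random.random((6, 7)))
    for metric in all_metrics:
        output = metric(y_a, y_b)
        print(metric.__name__)
        # each metric reduces the last axis, leaving one value per sample
        assert K.eval(output).shape == (6,)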

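You can't simply feed a Python float or int into these tensor operations. They expect tensors (created here with K.variable), and with integer inputs the resulting int32 values clash with the float32 results of K.exp and K.log, which is exactly what your first error message reports. Note also the use of K.eval to obtain the actual numeric result.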

So try something similar with your function:

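from keras import backend as K
import numpy as np

# weighted_loss is the function from the question
y_a = K.variable(np.random.random((6, 7)))
y_b = K.variable(np.random.random((6, 7)))
output = weighted_loss(y_a, y_b)
result = K.eval(output)  # NumPy array with one loss value per row, shape (6,)
print(result)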
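If you also want to check the numbers K.eval returns against plain Python values, one option is a NumPy mirror of the same formula. This is a sketch of my own, not part of Keras: the name weighted_loss_np is hypothetical, and NumPy computes in float64 while the Keras backend typically uses float32, so compare with a tolerance rather than exact equality.

```python
import numpy as np


def weighted_loss_np(y_true, y_pred):
    """Hypothetical NumPy mirror of the Keras weighted_loss (not a Keras API).

    Accepts Python scalars or arrays and averages over the last axis,
    like K.mean(..., axis=-1).
    """
    # atleast_1d so that plain scalars still have a last axis to average over
    y_true = np.atleast_1d(np.asarray(y_true, dtype=np.float64))
    y_pred = np.atleast_1d(np.asarray(y_pred, dtype=np.float64))
    # weight = exp(-log(1.7) * log(1 + exp((y_true - 3) / 5)))
    weight = np.exp(-np.log(1.7) * np.log1p(np.exp((y_true - 3.0) / 5.0)))
    return np.mean(np.square(y_pred - y_true) * weight, axis=-1)


y_a = np.random.random((6, 7))
y_b = np.random.random((6, 7))
print(weighted_loss_np(y_a, y_b).shape)  # (6,)
print(weighted_loss_np(1, 2))            # works on plain numbers
```

With the Keras tensors from the snippet above, np.allclose(result, weighted_loss_np(K.eval(y_a), K.eval(y_b))) should hold up to float32 rounding.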

There is also no need to define your custom function inside keras.backend: any later update of Keras would overwrite your change.

Instead, define the loss function in your own code:

from keras import backend as K

def weighted_loss(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true) *
                  K.exp(-K.log(1.7) * K.log(1. + K.exp((y_true - 3) / 5))),
                  axis=-1)

Then, when you compile your model, pass the function itself (without calling it):

# 'adam' is just an example optimizer
model.compile(optimizer='adam', loss=weighted_loss)

If you want a more general loss function whose weighting depends on some parameter, you need to wrap it: write an outer function that takes the parameter and returns the actual loss function. For example:

def get_weighted_loss(my_input):
    # my_input replaces the constant 5 of the original loss
    def weighted_loss(y_true, y_pred):
        return K.mean(K.square(y_pred - y_true) *
                      K.exp(-K.log(1.7) * K.log(1. + K.exp((y_true - 3) / my_input))),
                      axis=-1)
    return weighted_loss

Then call the wrapper when compiling, so that Keras receives the inner loss function:

model.compile(optimizer='adam', loss=get_weighted_loss(5))
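The wrapping is an ordinary Python closure: the outer function captures my_input and returns the inner function, and that inner function is what compile receives. A NumPy analogue (hypothetical names, not Keras code) makes it easy to see that get_weighted_loss(5) is itself a function and that a scale of 5 gives the same weighting as the hard-coded 5:

```python
import numpy as np


def get_weighted_loss_np(scale):
    """Hypothetical NumPy analogue of get_weighted_loss(my_input)."""
    def weighted_loss(y_true, y_pred):
        y_true = np.atleast_1d(np.asarray(y_true, dtype=np.float64))
        y_pred = np.atleast_1d(np.asarray(y_pred, dtype=np.float64))
        # same weight as before, with the constant 5 replaced by `scale`,
        # which the inner function remembers from the enclosing call
        weight = np.exp(-np.log(1.7) * np.log1p(np.exp((y_true - 3.0) / scale)))
        return np.mean(np.square(y_pred - y_true) * weight, axis=-1)
    return weighted_loss


loss_5 = get_weighted_loss_np(5)    # a function, like get_weighted_loss(5)
loss_10 = get_weighted_loss_np(10)  # a second loss with a different weighting
print(callable(loss_5))             # True
print(loss_5(1.0, 2.0), loss_10(1.0, 2.0))
```

In both compile calls above, Keras receives a function; the wrapper is only needed when you want to pass a parameter into the loss.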
  • Thanks a lot for the answer. Your code evaluates the weighted_loss function successfully. However, when I try to use this weighted_loss as a loss function, I get a function not found error. Is it reasonable to assume that the formulation of my function is OK, and that Keras is just not finding my function somehow?
    – axiom
    Commented Jul 25, 2017 at 5:43
  • That's entirely reasonable, yes. I've added some more info on how to add and use the loss function without having to mess around in the Keras source.
    – 5Ke
    Commented Jul 25, 2017 at 6:40
  • It works flawlessly. Thanks a lot! But why did I need to wrap the loss function in another function? I didn't do that, hence the code failed on me. I apologize for the newbie question; this happens to be my first Python code as well.
    – axiom
    Commented Jul 25, 2017 at 7:22
  • What Keras wants is for loss to be set to the loss function itself, not to a particular loss value. If you don't wrap your function but provide it directly, you're not providing the function; you're providing the function's output for a specific input, in this case a specific loss for a given y_true and y_pred. The wrapper instead returns a function that Keras itself can provide input to, and obtain output from, during training.
    – 5Ke
    Commented Jul 25, 2017 at 8:15
  • @5Ke I see no reason why you needed to wrap the loss function in another function; you can just pass it by name without calling it. The real issue in the original post was that the OP was trying to evaluate the loss function with float inputs rather than tensor inputs.
    – Ed Bordin
    Commented Jan 25, 2018 at 4:19
