I am currently trying to modify how the error is computed for one of the variables my network is trying to predict. I still want to use MSE, but I would like to modify the "difference" part of the equation, because the variable represents an angle in degrees.
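For an angle, the "difference" that matters is the shortest distance around the circle: fold the absolute difference into [0, 360) and compare it with the complementary arc. A minimal plain-Python sketch, assuming angles in degrees (`wrapped_angle_error` is a hypothetical helper, not something from the question):

```python
def wrapped_angle_error(a_deg: float, b_deg: float) -> float:
    """Shortest angular distance between two angles given in degrees.

    Hypothetical helper for illustration; the result lies in [0, 180].
    """
    raw = abs(a_deg - b_deg) % 360   # absolute difference folded into [0, 360)
    return min(raw, 360 - raw)       # going the other way round may be shorter


print(wrapped_angle_error(350, 10))  # 20: 350 and 10 degrees are only 20 apart
print(wrapped_angle_error(90, 45))   # 45
```

Squaring this value instead of the raw difference is the kind of modified MSE the question is after.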

I have tried a few things, but none have worked yet:

I first tried a naive iterative approach:

import tensorflow as tf

def custom_mean_squared_loss(y_true, y_pred):
    diff = y_true - y_pred
    data_shape = y_pred.get_shape()
    for sample in range(35):
        for timestep in range(data_shape[1]):
            error1 = tf.abs(diff[sample][timestep][6])
            error2 = 360 - error1
            corrected_err = tf.minimum(error1, error2)
            test = tf.gather_nd(diff, [[sample, timestep, 6]])
            test.assign(corrected_err)  # does not work: a Tensor has no assign()

But as far as I understand, TensorFlow needs the computation to be expressed as tensor operations in order to evaluate it and compute the gradient of the loss function, so I tried removing the loops and letting it do the job:

def custom_mean_squared_loss(y_true, y_pred):
    diff = y_true - y_pred
    error1 = tf.abs(diff[:, :, 6])  # angle column for every sample and timestep
    error2 = 360 - error1
    corrected_err = tf.minimum(error1, error2)
    diff[:, :, 6].assign(corrected_err)  # this is the line that raises the error
    return tf.reduce_mean(tf.square(diff), axis=-1)

However, I can't get the assignment line to work; it raises:

ValueError: Sliced assignment is only supported for variables
  • What is the shape of y_true and y_pred?
    – gorjan
    Commented Aug 7, 2019 at 22:01
  • y_true shape=(?, ?, ?) y_pred shape=(?, 72, 7)
    – linSESH
    Commented Aug 7, 2019 at 22:03
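The ValueError comes from the fact that a tf.Tensor is immutable: you can read a slice of it, but only a tf.Variable supports sliced assignment. The same in-place idea does work on a NumPy array, which makes a handy reference for checking whatever TensorFlow version of the loss you end up with. A minimal sketch, assuming a (batch, timesteps, 7) layout with the angle in column 6 as described above (the toy input uses 2 timesteps instead of 72; `reference_loss` is a hypothetical name):

```python
import numpy as np


def reference_loss(y_true: np.ndarray, y_pred: np.ndarray, angle_col: int = 6) -> np.ndarray:
    """NumPy version of the intended loss: per-timestep MSE with a wrapped angle column."""
    diff = np.abs(y_true - y_pred)
    # NumPy arrays are mutable, so assigning into a slice is allowed here.
    diff[:, :, angle_col] = np.minimum(diff[:, :, angle_col], 360 - diff[:, :, angle_col])
    return np.mean(diff ** 2, axis=-1)  # shape (batch, timesteps)


y_true = np.zeros((1, 2, 7))
y_pred = np.zeros((1, 2, 7))
y_true[0, 0, 6], y_pred[0, 0, 6] = 10.0, 350.0  # angle error wraps to 20
y_pred[0, 1, 0] = 7.0                           # ordinary error of 7 in column 0
print(reference_loss(y_true, y_pred))           # first timestep 400 / 7, second 49 / 7
```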

1 Answer

There are numerous ways to do it. I would go for concatenating the unchanged columns, diff[:, :, :6], with the tf.minimum output for the angle column, diff[:, :, 6:]:

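import tensorflow as tf

def custom_mean_squared_loss(y_true, y_pred):
    diff = tf.abs(y_true - y_pred)
    # Wrapped angular error for the angle column (index 6, the last of 7 features).
    angle_diff = tf.minimum(diff[:, :, 6:], 360 - diff[:, :, 6:])
    # Keep the first six columns unchanged and append the corrected angle column.
    error = tf.concat([diff[:, :, :6], angle_diff], axis=-1)
    return tf.reduce_mean(error ** 2, axis=-1)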
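Another loop-free variant, if you would rather not split and concatenate, is to compute the wrapped error for every column and keep it only in the angle column via tf.where with a broadcast column mask. A sketch assuming TensorFlow 2.x (where tf.where broadcasts its arguments) and the angle in column 6; `masked_angle_mse` and `angle_col` are hypothetical names:

```python
import tensorflow as tf


def masked_angle_mse(y_true, y_pred, angle_col=6):
    """MSE over the last axis, using the wrapped angular difference for one column."""
    diff = tf.abs(y_true - y_pred)
    wrapped = tf.minimum(diff, 360.0 - diff)
    num_features = tf.shape(diff)[-1]
    # Boolean mask of shape (num_features,), True only for the angle column.
    is_angle = tf.equal(tf.range(num_features), angle_col)
    # tf.where broadcasts the mask over the batch and time axes.
    error = tf.where(is_angle, wrapped, diff)
    return tf.reduce_mean(tf.square(error), axis=-1)


y_true = tf.constant([[[0.0] * 6 + [10.0]]])   # shape (1, 1, 7)
y_pred = tf.constant([[[0.0] * 6 + [350.0]]])
print(masked_angle_mse(y_true, y_pred).numpy())  # 400 / 7: the angle error wraps to 20
```

It yields the same per-timestep values as the concat version; the difference is only that the angle column becomes a parameter instead of a hard-coded 6: slice.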

If you really need a loop, tf.while_loop can pass modified tensors from one iteration to the next. In your case, though, loops aren't required: they can be replaced with vectorized tensor operations like the ones above.
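To check that the gradient really flows through these tensor operations without any assignment, here is a small tf.GradientTape sketch for a single angle pair, assuming TensorFlow 2.x eager execution (the original question predates TF 2, so this is an illustration rather than the asker's setup):

```python
import tensorflow as tf

y_true = tf.constant(10.0)
y_pred = tf.Variable(350.0)

with tf.GradientTape() as tape:
    raw = tf.abs(y_true - y_pred)           # 340 degrees the long way round
    wrapped = tf.minimum(raw, 360.0 - raw)  # 20 degrees the short way
    loss = tf.square(wrapped)               # 400

grad = tape.gradient(loss, y_pred)
print(float(loss), float(grad))  # 400.0 -40.0
```

The negative gradient means gradient descent increases y_pred toward 360, i.e. across the wrap toward 10; plain MSE would instead push it down the long way (its gradient here would be +680).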
