
I have a neural network written in PyTorch that outputs a tensor a on the GPU. I would like to continue processing a with a highly efficient TensorFlow layer.

As far as I know, the only way to do this is to move a from GPU memory to CPU memory, convert to numpy, and then feed that into TensorFlow. A simplified example:

import torch
import tensorflow as tf

# output of some neural network written in PyTorch
a = torch.ones((10, 10), dtype=torch.float32).cuda()

# move to CPU / pinned memory
c = a.to('cpu', non_blocking=True)

# setup TensorFlow stuff (only needs to happen once)
sess = tf.Session()
c_ph = tf.placeholder(tf.float32, shape=c.shape)
c_mean = tf.reduce_mean(c_ph)

# run TensorFlow
print(sess.run(c_mean, feed_dict={c_ph: c.numpy()}))

This may be a bit far-fetched, but is there a way to make it so that either

  1. a never leaves GPU memory, or
  2. a goes from GPU memory to pinned memory to GPU memory.

I attempted 2. in the code snippet above using non_blocking=True, but I am not sure it does what I expect (i.e., moves the tensor to pinned memory).

Ideally, my TensorFlow graph would operate directly on the memory occupied by the PyTorch tensor, but I suppose that is not possible?

1 Answer


I am not familiar with TensorFlow, but you may use PyTorch to expose the "internals" of a tensor.
You can access the underlying storage of a tensor:

a.storage()

Once you have the storage, you can get a pointer to the memory (either CPU or GPU):

a.storage().data_ptr()
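The pointer comes back as a plain Python integer holding the raw address. As a toy check for a CPU tensor (not something TensorFlow would consume directly), you can reinterpret it through ctypes:

```python
import ctypes
import torch

t = torch.arange(4, dtype=torch.float32)  # CPU tensor: [0., 1., 2., 3.]
ptr = t.data_ptr()                        # raw address as a Python int

# Reinterpret the address as a float* and read the first element back.
first = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_float)).contents.value
print(first)  # 0.0
```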

You can check whether it is pinned:

a.storage().is_pinned()

And you can pin it:

a.storage().pin_memory()

I am not familiar with the interfaces between PyTorch and TensorFlow, but I came across an example of a package (FAISS) directly accessing PyTorch tensors on the GPU.

  • On the PyTorch side, the x.storage().data_ptr() method returns the data pointer as a Python integer. Not sure how to use this on the TF side. Also, one thing to check is whether computations are performed in the same CUDA stream; otherwise they may not be synchronized. Commented Feb 27, 2019 at 12:19
  • @MatthijsDouze FAISS recommends using res.syncDefaultStreamCurrentDevice(), where res is a handle to the current computing resource.
    – Shai
    Commented Feb 27, 2019 at 12:21
