Context
To use fit_generator()
in Keras, I use a generator function like this pseudocode one:
def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""
    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]
In Keras' fit_generator()
I want to use workers=4
and use_multiprocessing=True
; hence, I need a thread-safe generator.
In answers on Stack Overflow (like here or here) and in the Keras docs, I read about creating a class inheriting from keras.utils.Sequence
, like this:
class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array(batch_x), np.array(batch_y)
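For illustration, the indexing logic of this pattern can be exercised without TensorFlow installed. The sketch below uses a plain class (hypothetical name GeneratorClassSketch) in place of keras.utils.Sequence, since only the __len__/__getitem__ protocol matters here; the real version must inherit from Sequence so Keras can distribute batch indices across workers:

```python
import numpy as np

class GeneratorClassSketch:  # stand-in; the real class inherits keras.utils.Sequence
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # each call is independent: there is no shared cursor, so workers
        # can request arbitrary batch indices concurrently
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array(batch_x), np.array(batch_y)

seq = GeneratorClassSketch(np.arange(10), np.arange(10) * 2, batch_size=4)
print(len(seq))    # 3 batches: 4 + 4 + 2 samples
print(seq[2][0])   # last (partial) batch of inputs: [8 9]
```

Because a batch is addressed by index rather than pulled from a single running generator, no locking is needed at all.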
When using a Sequence,
Keras does not throw any warning with multiple workers and multiprocessing; the generator is supposed to be thread-safe.
Anyhow, since I am using my custom function, I stumbled upon Omer Zohar's code on GitHub, which makes my generator()
thread-safe by adding a decorator.
The code looks like this:
import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()

def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g
Now I can do:
@threadsafe_generator
def generator(data):
    ...
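To see the decorator doing its job, here is a small self-contained sketch (standard library only, with a toy numbers() generator as the stand-in for my real one): several threads pull from one shared decorated generator, and the lock ensures every item is handed out exactly once:

```python
import threading

class threadsafe_iter:
    """Wraps an iterator and serializes calls to __next__ with a lock."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:  # lock is released even when StopIteration propagates
            return next(self.it)

def threadsafe_generator(f):
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

@threadsafe_generator
def numbers(n):
    for i in range(n):
        yield i

shared = numbers(1000)
seen, seen_lock = [], threading.Lock()

def consume():
    for item in shared:  # StopIteration ends each thread cleanly
        with seen_lock:
            seen.append(item)

threads = [threading.Thread(target=consume) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(seen), len(set(seen)))  # 1000 1000: each item yielded exactly once
```

Without the lock, two threads could resume the same generator frame simultaneously, which raises ValueError ("generator already executing") or silently skips items.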
The thing is: using this version of a thread-safe generator, Keras still emits a warning that the generator must be thread-safe when using workers > 1
and use_multiprocessing=True
, and that this can be avoided by using a Sequence.
My questions now are:

- Does Keras emit this warning only because the generator does not inherit from Sequence, or does Keras also check whether a generator is thread-safe in general?
- Is the approach I chose as thread-safe as the generatorClass(Sequence) version from the Keras docs?
- Are there any other approaches leading to a thread-safe generator Keras can deal with, different from these two examples?
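One detail that seems relevant to the first question: with use_multiprocessing=True, worker processes need to receive the data source, and a plain Python generator cannot even be pickled, which is presumably part of why Keras steers toward Sequence (whose instances carry only arrays and a batch size). A quick check:

```python
import pickle

def gen():
    yield 1

g = gen()
try:
    pickle.dumps(g)       # generator objects cannot be serialized
except TypeError as e:
    print("not picklable:", e)
```

A threading.Lock is likewise not picklable, so the decorator approach cannot help across process boundaries either.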
Edit:
In newer tensorflow
/keras
versions (tf
> 2), fit_generator()
is deprecated; instead, it is recommended to use fit()
with the generator. However, the question still applies to fit()
with a generator as well.