Suppose you have a set of training images in a NumPy array with shape (num_imgs, height, width, channels), and you want your model to take as input not a batch of images but a batch of their indices. The model then fetches the images itself using those indices. What might that code look like?
First, your input (the index) is a scalar, but Keras doesn't let you use scalar inputs. The easiest workaround is to give it shape (1,):
inp = Input(shape=(1,))
Next, you might try to use a Lambda layer to extract the images from the input:
def fetch_img(x):
    return x_train[x.flatten()]
fetched_imgs = Lambda(fetch_img)(inp)
If x = [[1], [2], [3]] (i.e., a batch of arrays of shape (1,)), we want to turn it into [1, 2, 3], so we flatten it. Recall that x_train[[1, 2, 3]] is the same as x_train[[1, 2, 3], :, :, :], which selects the images at indices 1, 2, and 3.
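Here is a small NumPy-only check of that indexing claim. The tiny x_train below is a made-up stand-in (10 images of shape 4×5×3), and the last line shows what happens if you skip the flatten: indexing with the (3, 1) array keeps the extra axis.

```python
import numpy as np

# Hypothetical stand-in for x_train: 10 tiny "images" of shape (height=4, width=5, channels=3).
x_train = np.arange(10 * 4 * 5 * 3, dtype=np.float32).reshape(10, 4, 5, 3)

# A batch of indices shaped the way the Input layer delivers it: (batch, 1).
x = np.array([[1], [2], [3]])

flat = x.flatten()                      # shape (3,): [1, 2, 3]
picked = x_train[flat]                  # integer-array indexing on the first axis
explicit = x_train[[1, 2, 3], :, :, :]  # the same selection, spelled out

# Without flattening, the (3, 1) index array yields shape (3, 1, 4, 5, 3).
unflattened = x_train[x]
```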
Next, we don't want to hardcode the use of x_train:
def fetch_img(x_train):
    def _fetch_img(x):
        return x_train[x.flatten()]
    return _fetch_img
fetched_imgs = Lambda(fetch_img(x_train))(inp)
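The outer function is just a closure factory: each call captures a different array, and the returned function only needs the indices. A quick sketch on plain NumPy arrays (not Keras tensors), with two hypothetical datasets of different image shapes:

```python
import numpy as np

def fetch_img(x_train):
    # The outer call captures x_train; the returned function only needs the indices.
    def _fetch_img(x):
        return x_train[x.flatten()]
    return _fetch_img

# Two hypothetical datasets with different image shapes.
small = np.zeros((5, 2, 2, 1), dtype=np.float32)
large = np.ones((8, 6, 6, 3), dtype=np.float32)

fetch_small = fetch_img(small)
fetch_large = fetch_img(large)

idx = np.array([[0], [4]])
small_batch = fetch_small(idx)
large_batch = fetch_large(idx)
```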
But we still have a problem: x is a Tensor, not a NumPy array, so it has no flatten() method and can't be used to index x_train, and the function given to Lambda must also return a Tensor. We're trying to operate on a NumPy array (x_train). Often the function you see passed to Lambda will appear to be a mathematical operation, but is really an overloaded TF op (e.g., x + y calls tf.Tensor.__add__(x, y)). If you want to run arbitrary Python code, you have to invoke tf.py_func, which runs a Python function on the NumPy values of its input tensors:
def fetch_img(x_train):
    def _fetch_img(x):
        return tf.py_func(lambda x: x_train[x.flatten()],
                          [x], tf.float32)
    return _fetch_img
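Seems like it works! Note that the Input below gets an integer dtype (tf.uint16), since NumPy won't index with the default float32 values. The following assert passes: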
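import numpy as np
import tensorflow as tf
from keras.layers import Input, Lambda
from keras.models import Model
# x_train: float32 NumPy array of shape (num_imgs, height, width, channels)
def fetch_img(x_train):
    def _fetch_img(x):
        # tf.py_func runs the Python lambda on the NumPy value of x
        return tf.py_func(lambda x: x_train[x.flatten()],
                          [x], tf.float32)
    return _fetch_img
inp = Input(shape=(1,), dtype=tf.uint16)  # integer dtype so the values can index x_train
fetched_imgs = Lambda(fetch_img(x_train))(inp)
model = Model(inp, fetched_imgs)
x_in = np.array([1, 5, 10])
res = model.predict(x_in)
exp = x_train[x_in]
assert np.array_equal(res, exp)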
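Because the lambda inside tf.py_func only ever sees NumPy arrays, it can be exercised on its own, without TensorFlow. This sketch (with a hypothetical float32 x_train) checks the integer-index case and shows the IndexError NumPy raises for float indices, which is what you would hit with the default float32 Input. The returned array should also match the Tout dtype declared to py_func (tf.float32 here).

```python
import numpy as np

# Hypothetical float32 image array standing in for x_train.
x_train = np.arange(20 * 8 * 8 * 3, dtype=np.float32).reshape(20, 8, 8, 3)

def callback(x):
    # Same expression the lambda handed to tf.py_func evaluates.
    return x_train[x.flatten()]

int_idx = np.array([[1], [5], [10]], dtype=np.uint16)
out = callback(int_idx)

# Float indices (what a default float32 Input would pass in) are rejected by NumPy.
float_idx = int_idx.astype(np.float32)
try:
    callback(float_idx)
except IndexError as err:
    float_error = err
else:
    float_error = None
```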
Still, one problem remains. What happens if we evaluate fetched_imgs.shape? tf.py_func can't know what shape the Python function will return, so the static shape is unknown, and later layers that need the input dimensions will crash when they try to use it. You have to set the shape explicitly:
def fetch_img(x_train):
    def _fetch_img(x):
        res = tf.py_func(lambda x: x_train[x.flatten()], [x], tf.float32)
        # py_func can't infer its output shape: batch size comes from x, image shape from x_train
        res.set_shape((x.shape[0],) + x_train[0].shape)
        return res
    return _fetch_img
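The shape passed to set_shape is just tuple concatenation: the batch dimension of x (usually unknown while the graph is built, modeled here as None) followed by the shape of a single image. The sketch below, using a hypothetical 100×28×28×1 dataset and a hypothetical helper compatible() that roughly mirrors what TensorShape.is_compatible_with checks, confirms the declared shape holds for any batch size.

```python
import numpy as np

# Hypothetical dataset: 100 images of shape (28, 28, 1).
x_train = np.zeros((100, 28, 28, 1), dtype=np.float32)

# Unknown batch dimension (None) followed by the per-image dimensions.
static_shape = (None,) + x_train[0].shape

def compatible(static, actual):
    # Hypothetical helper: None matches any size; other dimensions must agree exactly.
    return len(static) == len(actual) and all(
        s is None or s == a for s, a in zip(static, actual)
    )

batch_sizes = [1, 7, 32]
results = [compatible(static_shape, x_train[np.arange(n)].shape) for n in batch_sizes]
```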
Whew, that should do it.