A Keras layer that fetches data by index

Suppose you have a set of training images in a NumPy array of shape (num_imgs, height, width, channels), and you want your model to take as input not a batch of images but a batch of their indices, fetching the images itself from those indices. What might that code look like?

First, each input (an index) is a scalar, but Keras doesn't let you declare scalar inputs. The easiest workaround is to give it shape (1,):
 inp = Input(shape=(1,))  
Next, you might try to use a Lambda layer to extract the image from the input:
 # naive first attempt: doesn't work yet
 def fetch_img(x):
   return x_train[x.flatten()]

 fetched_imgs = Lambda(fetch_img)(inp)
If x = [[1], [2], [3]] (i.e., a batch of arrays of shape (1,)), we want to turn it into [1, 2, 3], so we flatten it. Recall that x_train[[1, 2, 3]] is the same as x_train[[1, 2, 3], :, :, :], which selects the images at indices 1, 2, and 3.
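The indexing claim is easy to check in plain NumPy. In this small sketch, `x_train` is a hypothetical stand-in holding 10 tiny 2x2 single-channel images, and the batch of indices has shape (3, 1), as it would coming out of the `Input(shape=(1,))` layer:

```python
import numpy as np

# Hypothetical stand-in for the training set: 10 images of shape (2, 2, 1).
x_train = np.arange(10 * 2 * 2 * 1, dtype=np.float32).reshape(10, 2, 2, 1)

# A batch of three indices, each of shape (1,).
x = np.array([[1], [2], [3]])

idx = x.flatten()      # [[1], [2], [3]] -> [1, 2, 3]
batch = x_train[idx]   # integer-array indexing on the first axis picks whole images

# Indexing only the first axis is the same as slicing every remaining axis.
assert np.array_equal(batch, x_train[idx, :, :, :])
print(idx, batch.shape)  # [1 2 3] (3, 2, 2, 1)
```

Skipping the flatten would not raise an error, but x_train[x] would have shape (3, 1, 2, 2, 1), with an extra axis inherited from the shape of the index array.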

Next, we don't want to hardcode x_train, so we wrap the function in a closure that takes the array as an argument:
 def fetch_img(x_train):    
   def _fetch_img(x):      
     return x_train[x.flatten()]     
   return _fetch_img    
   
 fetched_imgs = Lambda(fetch_img(x_train))(inp)       
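Because fetch_img(x_train) returns a closure, the inner function remembers whichever array it was built with. Here is a NumPy-only sketch of that pattern, called on plain arrays rather than Tensors; the datasets x_a and x_b are hypothetical:

```python
import numpy as np

def fetch_img(x_train):
    # The returned function closes over x_train, so nothing
    # in it refers to a global training array.
    def _fetch_img(x):
        return x_train[x.flatten()]
    return _fetch_img

# Two hypothetical datasets with different image shapes.
x_a = np.zeros((5, 2, 2, 1), dtype=np.float32)
x_b = np.ones((8, 3, 3, 3), dtype=np.float32)

fetch_a = fetch_img(x_a)
fetch_b = fetch_img(x_b)

idx = np.array([[0], [4]])  # a batch of two indices, shape (2, 1)
print(fetch_a(idx).shape)   # (2, 2, 2, 1)
print(fetch_b(idx).shape)   # (2, 3, 3, 3)
```

On NumPy input this works fine; the trouble only starts when Keras hands the function something that is not a NumPy array.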
But we still have a problem: x is a symbolic Tensor, not a NumPy array, so neither x.flatten() nor x_train[...] works on it, and the function given to Lambda must return a Tensor, not a NumPy array.
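Why doesn't x.flatten() work? A symbolic tensor describes a computation instead of holding values. The toy class below is hypothetical (it is not how TensorFlow is implemented), but it mimics that behavior: its `__add__` returns a new graph-like node rather than a number, and it has no NumPy methods such as `flatten`.

```python
class Symbolic:
    """Hypothetical stand-in for a symbolic tensor: it records how a value
    would be computed instead of computing it."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = inputs

    def __add__(self, other):
        # Like an overloaded tensor op, build a new node instead of a number.
        return Symbolic("add", (self, other))

    def __repr__(self):
        if not self.inputs:
            return self.op
        return f"{self.op}({', '.join(map(repr, self.inputs))})"


x = Symbolic("x")
y = Symbolic("y")
z = x + y   # calls Symbolic.__add__(x, y)
print(z)    # add(x, y)

# No NumPy-style method is defined, so this raises AttributeError,
# much as x.flatten() fails on a symbolic Keras input.
try:
    x.flatten()
except AttributeError as err:
    print("AttributeError:", err)
```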

The functions you usually see passed to Lambda look like plain arithmetic, but they are really overloaded TF ops (e.g., "x + y" calls tf.Tensor.__add__(x, y), which adds a node to the graph). To run arbitrary Python code on a tensor's values, you have to wrap it in tf.py_func, which calls your function on NumPy arrays when the graph runs:
 def fetch_img(x_train):  
   def _fetch_img(x):  
     return tf.py_func(lambda x: x_train[x.flatten()],  
                       [x], tf.float32)    
   return _fetch_img   

Seems like it works! The following assert passes:
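  import numpy as np
  import tensorflow as tf  # TF 1.x; in TF 2 this API is tf.compat.v1.py_func
  from keras.layers import Input, Lambda
  from keras.models import Model

  # x_train: the image array from above; it must be float32
  # to match the tf.float32 output dtype declared below
  def fetch_img(x_train):
    def _fetch_img(x):
      return tf.py_func(lambda x: x_train[x.flatten()],
                        [x], tf.float32)
    return _fetch_img

  inp = Input(shape=(1,), dtype=tf.uint16)
  fetched_imgs = Lambda(fetch_img(x_train))(inp)
  model = Model(inp, fetched_imgs)

  x_in = np.array([1, 5, 10])
  res = model.predict(x_in)
  exp = x_train[x_in]
  assert np.array_equal(res, exp)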
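There's still one problem, though. What happens if we evaluate fetched_imgs.shape? It's unknown, because tf.py_func can't infer the shape of whatever your Python function returns, so later layers that need a static shape (a Conv2D, for example, must know the number of channels) will fail. You have to set the shape explicitly, from the batch dimension of x and the shape of a single image: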
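 def fetch_img(x_train):
   def _fetch_img(x):
     # py_func output has unknown static shape
     res = tf.py_func(lambda idx: x_train[idx.flatten()], [x], tf.float32)
     # (batch, height, width, channels): batch size from x, rest from one image
     res.set_shape((x.shape[0],) + x_train[0].shape)
     return res
   return _fetch_img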
Whew, that should do it.
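tf.py_func is a TensorFlow 1.x API; in TensorFlow 2 it is only available as tf.compat.v1.py_func, and tf.numpy_function is the replacement that keeps its semantics (the wrapped function receives and returns NumPy arrays). The sketch below assumes TensorFlow 2.x and uses a hypothetical toy x_train. It wraps the closure in a tf.function instead of a Keras model, so the static shape that set_shape provides can be inspected directly:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical toy training set: 10 float32 images of shape (2, 2, 1).
x_train = np.arange(10 * 2 * 2 * 1, dtype=np.float32).reshape(10, 2, 2, 1)

def fetch_img(x_train):
    def _fetch_img(idx):
        # The lambda receives a NumPy array and must return an array
        # whose dtype matches Tout (float32 here).
        res = tf.numpy_function(lambda i: x_train[i.flatten()], [idx], tf.float32)
        # The op can't infer its output shape, so state it explicitly:
        # batch size from idx, then the shape of one image.
        res.set_shape((idx.shape[0],) + x_train.shape[1:])
        return res
    return _fetch_img

# Trace with a batch of (1,)-shaped int32 indices of unknown batch size.
fetch = tf.function(fetch_img(x_train),
                    input_signature=[tf.TensorSpec([None, 1], tf.int32)])

out = fetch(tf.constant([[1], [5], [8]], dtype=tf.int32))
print(out.shape)  # (3, 2, 2, 1)
print(fetch.get_concrete_function().structured_outputs.shape)  # (None, 2, 2, 1)
```

In a model, the same fetch_img(x_train) would be passed to Lambda exactly as in the post.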
