Relative Performance of .shard(), .skip() and .batch() in TensorFlow

Sat, Jul 11, 2020 6-minute read
Source: Eric Nielson

Ever have the wonderful experience of a multi-day/week processing job crashing on you at 99%, only to have to re-process everything again? Read on to find out how I managed to split a dataset into smaller chunks, and the relative performance of the various approaches I tried.


In general, prefer using .batch() over .shard() or .skip() to split a dataset into chunks for processing, saving of progress, etc.


As part of my day job at the National Neuroscience Institute, I work with medical images. Properties of such images are two-fold:

  • First, each image is 3D;

  • Secondly, each image may be of very high resolution.

As a result, even with ~400 samples, I would have data in the hundreds of gigabytes.

Initial Googl-ing

There must be a better way to handle the data, I thought to myself, picturing the stereotypical TV infomercial actor/actress scratching his/her head.

My initial Gut Feel™ was that the data should be split into smaller files, and a quick search confirmed that splitting the .tfrecord files may be a good idea.

Next, I needed a way to actually split my data files, and a brief Google session landed me on this solution which suggested the use of .shard().

  • Naturally, I copied and pasted the code into my notebook, selected Restart kernel and ran all cells, and lived happily ever after.

  • Of course that is not the case. My code ended up taking way longer to run, and I had to kill the job.

Having no other obvious choice, I decided to do some investigation, and below are the totally non-rigorous and unscientific results:

Performance Testing Pre-requisite

I first coded up a timer to keep track of the time:

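  import contextlib

  @contextlib.contextmanager
  def timed():
      import datetime as dt
      start = dt.datetime.now()  # wall-clock time on entering the with-block
      try:
          yield None
      finally:
          end = dt.datetime.now()  # wall-clock time on leaving the with-block
          print(f'Duration: {end - start}')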
  • Some Notes on the above fragment:

    • The @contextlib.contextmanager decorator automagically converts a generator function into a resource manager, perfect for a use-case like ours. I would strongly recommend reading up on it and keeping it in your toolbelt.

    • The import datetime as dt is put within the function so that the function may be part of a bigger utilities module, without needing the module to import everything at initialization. Also, since Python caches its imports, the datetime module will only be imported once in any case.

The timer function may be used as so:

  with timed():
      pass # The code to be timed goes here
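To make the mechanics in the notes above concrete, here is a minimal sketch of what the decorator does. The names traced and events are hypothetical and exist only for this illustration: code before the yield runs on entering the with block, and code under finally runs on leaving it, even if the body raises.

```python
import contextlib

events = []  # hypothetical log, used only to show the order of execution


@contextlib.contextmanager
def traced(label):
    events.append(f"enter {label}")      # runs when the with-block is entered
    try:
        yield
    finally:
        events.append(f"exit {label}")   # runs when the with-block is left, even on error


with traced("ok"):
    events.append("body")

try:
    with traced("failing"):
        raise ValueError("simulated crash")
except ValueError:
    pass

print(events)
```

The second with block raises, yet its exit step is still recorded, which is why a timer written with try/finally would still report a duration for a job that crashes.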

Performance Baseline

Next, we need a baseline measure.

  • This baseline measure will be compared against the performance of each of the various approaches to splitting the dataset.

  • The comparison will allow us to evaluate whether the act of splitting the dataset itself results in slower processing of the data.

The baseline is obtained by iterating through the whole dataset without any splitting or doing anything else:

  test_ds = loadMyDataset()

  for i in range(NUM_TRIALS):
      with timed():
          for _ in test_ds:
              pass # In real application, actual processing is done here


First, we test the performance of .skip().

  • We do so by iterating through the dataset after using .skip() to skip all but one element:

      for i in range(NUM_TRIALS):
          with timed():
              for _ in test_ds.skip(NUM_DATA-1):
                  pass # In real application, actual processing is done here
  • The hypothesis is that since we are skipping all elements but one, it should be much faster than the baseline of iterating through every record.

  • If the hypothesis is true, then we will proceed to use .skip() together with .take() to split the data.

    • For example, assuming each split has 10 records, ds.skip(0).take(10) will give the first split, and ds.skip(10).take(10) will give the second split, and so on.

    • The example code would look like this:

        import math

        NUM_TRIALS = 5
        NUM_SHARDS = 5
        items_per_shard = math.ceil(NUM_DATA / NUM_SHARDS)
        for _ in range(NUM_TRIALS):
            with timed():
                for i in range(NUM_SHARDS):
                    for _ in test_ds.skip(items_per_shard*i).take(items_per_shard):
                        pass # In real application, actual processing is done here


Next, we test the performance of .shard().

  • We do so by iterating through the dataset after sharding into five shards:

      NUM_TRIALS = 5
      NUM_SHARDS = 5
      for _ in range(NUM_TRIALS):
          with timed():
              for i in range(NUM_SHARDS):
                  for _ in test_ds.shard(NUM_SHARDS, i):
                      pass # In real application, actual processing is done here
  • The expectation is that, since we still go through the whole dataset once, the time taken should be comparable to the baseline, which also iterates through every record.


Finally, we test the performance of .batch().

  • We do so by iterating through the dataset after manually splitting it into five batches:

      NUM_TRIALS = 5
      NUM_SHARDS = 5
      items_per_shard = math.ceil(NUM_DATA / NUM_SHARDS)
      for _ in range(NUM_TRIALS):    
          with timed():
              for shard in test_ds.batch(items_per_shard):
                  for _ in shard:
                      pass # In real application, actual processing is done here

    Note that in this approach we do need to know the total number of records (i.e., NUM_DATA) in order to split into a specific number of files.

  • The expectation is that, since we still go through the whole dataset once, the time taken should be comparable to the baseline, which also iterates through every record.
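As a small worked example of the items_per_shard arithmetic used above, the sketch below computes the batch sizes you would get when batching with a size of ceil(NUM_DATA / NUM_SHARDS) and keeping the smaller final batch. The helper name batch_sizes and the sample numbers are my own, chosen purely for illustration.

```python
import math


def batch_sizes(num_data, num_shards):
    """Sizes of the batches produced when batching by ceil(num_data / num_shards)."""
    items_per_shard = math.ceil(num_data / num_shards)
    # Every batch is full except possibly the last one, which takes what is left.
    return [min(items_per_shard, num_data - start)
            for start in range(0, num_data, items_per_shard)]


print(batch_sizes(400, 5))  # divides evenly
print(batch_sizes(403, 5))  # last batch is smaller
print(batch_sizes(11, 5))   # rounding up yields fewer batches than requested
```

Note the last case: with 11 records and 5 requested splits, rounding up to 3 records per batch gives only 4 batches, so the number of output chunks is not always exactly NUM_SHARDS.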


The results are as follows:



  • Comparing .skip() and the baseline, we can see that even though we skipped all the records but the last one, it still took essentially the same amount of time as the baseline, which iterates through each record.

    • This suggests that skipping a record still incurs the costs of the earlier processing steps in the pipeline. I.e., TensorFlow is not doing anything clever to avoid the skipped item right from the start.

  • Comparing .shard() and the baseline, we can see that .shard() is taking 5 times as long as the baseline.

    • Recall that we are splitting the data into 5 shards, which corresponds to the 5 times longer duration taken by .shard().

    • This suggests that for each shard, the .shard() operation is nonetheless processing the whole pipeline.

    • This makes sense because TensorFlow's documentation on .shard() states that sharding is done by splitting records into groups by taking the modulo of each record's index. To get each record's index, it must naturally iterate through the whole dataset.

  • Finally, comparing .batch() and the baseline, we can see that both take about the same amount of time.

    • This suggests that using .batch() to split the dataset will incur little to no overhead.
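The reasoning above can be mimicked with plain Python iterators. To be clear, this is a toy model under my own assumptions and not TensorFlow's implementation: CountingSource, skip, shard and batch are hypothetical stand-ins that only reproduce the behaviour the timings suggest, counting how many records the "expensive" upstream pipeline has to produce in each case.

```python
import itertools

NUM_DATA = 20
NUM_SHARDS = 5


class CountingSource:
    """Stand-in for an expensive pipeline; counts how many records it produces."""

    def __init__(self, num_data):
        self.num_data = num_data
        self.calls = 0

    def __iter__(self):
        # Every new iteration restarts the pipeline from the first record.
        for i in range(self.num_data):
            self.calls += 1  # simulates loading and preprocessing one record
            yield i


def skip(ds, n):
    # Skipped records are still produced upstream, then thrown away.
    return itertools.islice(ds, n, None)


def shard(ds, num_shards, index):
    # Keep records whose position modulo num_shards equals index.
    return (x for i, x in enumerate(ds) if i % num_shards == index)


def batch(ds, size):
    it = iter(ds)
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk


def calls_for(consume):
    src = CountingSource(NUM_DATA)
    consume(src)
    return src.calls


baseline = calls_for(lambda ds: list(ds))
skipped = calls_for(lambda ds: list(skip(ds, NUM_DATA - 1)))
sharded = calls_for(lambda ds: [list(shard(ds, NUM_SHARDS, i)) for i in range(NUM_SHARDS)])
batched = calls_for(lambda ds: list(batch(ds, NUM_DATA // NUM_SHARDS)))

print(baseline, skipped, sharded, batched)
```

In this model, skipping all but one record costs as much as the baseline, iterating over all five shards costs five times as much, and batching costs the same as the baseline, which matches the pattern described in the bullets above.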

!!! Conclusion: Generally prefer .batch() when splitting your dataset to avoid unnecessary overheads.


  • In the test above, test_ds represents a pipeline that involves processing steps like loading images from file, scaling and normalization.

    • These processing steps are likely what is being duplicated in the .skip() and .shard() cases, resulting in poorer performance.

  • If the .skip() and .shard() methods were applied at an earlier stage (say, immediately after globbing a list of file names), then the performance difference would be much less pronounced.

    • That said, it would generally be difficult to plan in advance, at an early stage of the processing pipeline, how we want to split the data. And in some cases where we get the dataset from another source, it might not be possible to easily insert the .skip() or .shard() early in the pipeline.
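To illustrate why the position of the split in the pipeline matters, here is another plain-Python sketch (again a toy model, not TensorFlow code). The names expensive, shard and files are hypothetical: files stands in for a list of globbed file names, and expensive for the heavy loading and preprocessing step.

```python
NUM_SHARDS = 5
files = list(range(20))  # stand-in for a list of globbed file names


def expensive(item, counter):
    counter[0] += 1  # simulates loading and preprocessing one file
    return item * 2


def shard(items, num_shards, index):
    # Keep items whose position modulo num_shards equals index.
    return (x for i, x in enumerate(items) if i % num_shards == index)


# Sharding late: the heavy step runs on every file, for every shard.
late_calls = [0]
for i in range(NUM_SHARDS):
    list(shard((expensive(f, late_calls) for f in files), NUM_SHARDS, i))

# Sharding early: the heavy step runs only on the files kept by each shard.
early_calls = [0]
early_results = []
for i in range(NUM_SHARDS):
    early_results.extend(expensive(f, early_calls) for f in shard(files, NUM_SHARDS, i))

print(late_calls[0], early_calls[0])
```

Both orderings produce the same set of processed records, but in this model sharding after the heavy step does NUM_SHARDS times as much work as sharding before it.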

Lessons Learnt

Don't just copy and paste code from StackOverflow. Easy to say; hard to follow.

Use .batch() to split your datasets. Or use .skip() / .shard() earlier on in the pipeline, before any heavy processing.