Rewrite of fsverity.py (renamed to fsveritysetup.py) and mkfsverity.sh

Signed-off-by: Eric Biggers <ebiggers@google.com>
diff --git a/fsverity.py b/fsverity.py
deleted file mode 100755
index 6289ba3..0000000
--- a/fsverity.py
+++ /dev/null
@@ -1,250 +0,0 @@
-#!/usr/bin/python
-
-import argparse
-import binascii
-import os
-import shutil
-import subprocess
-import sys
-import tempfile
-from ctypes import *
-
-class FSVHeader(Structure):
-  _fields_ = [('magic', c_uint8 * 8),
-              ('maj_version', c_uint8),
-              ('min_version', c_uint8),
-              ('log_blocksize', c_uint8),
-	      ('log_arity', c_uint8),
-	      ('meta_algorithm', c_uint16),
-	      ('data_algorithm', c_uint16),
-	      ('reserved1', c_uint32),
-              ('size', c_uint64),
-              ('auth_blk_offset', c_uint8),
-              ('extension_count', c_uint8),
-              ('salt', c_char * 8),
-              ('reserved2', c_char * 22)]
-
-HEADER_SIZE = 64
-
-class FSVExt(Structure):
-   _fields_ = [('length', c_uint16),
-               ('type', c_uint8),
-               ('reserved', c_char * 5)]
-
-class PatchExt(Structure):
-  _fields_ = [('offset', c_uint64),
-              ('length', c_uint8),
-              ('reserved', c_char * 7)]
-# Append databytes at the end of the buffer this gets serialized into
-
-class ElideExt(Structure):
-  _fields_ = [('offset', c_uint64),
-              ('length', c_uint64)]
-    
-def parse_args():
-  parser = argparse.ArgumentParser(description='Build file-based integrity metadata')
-  parser.add_argument('--salt', metavar='<hex_string>', type=binascii.unhexlify,
-                      help='Hex string, e.g. 01ab')
-  parser.add_argument('--tree-file', metavar='<filename>', type=str,
-                      help='Filename for tree file (optional)')
-  parser.add_argument('input_file', metavar='<file>', type=str,
-                      help='Original content input file')
-  parser.add_argument('output_file', metavar='<file>', type=str,
-                      help='Output file formatted for fs-verity')
-  parser.add_argument('--patch_file', metavar='<file>', type=str,
-                      help='File containing patch content')
-  parser.add_argument('--patch_offset', metavar='<offset>', type=str,
-                      help='Offset to which to apply patch')
-  parser.add_argument('--elide_offset', metavar='<offset>', type=str,
-                      help='Offset of segment to elide')
-  parser.add_argument('--elide_length', metavar='<length>', type=str,
-                      help='Length of segment to elide')
-  return parser.parse_args()
-
-def generate_merkle_tree(args, elided_file):
-  if args.tree_file is not None:
-    tree_file_name = args.tree_file
-  else:
-    tree_file = tempfile.NamedTemporaryFile()
-    tree_file_name = tree_file.name
-    tree_file.close()
-  if elided_file is not None:
-    file_to_verity = elided_file.name
-  else:
-    file_to_verity = args.output_file
-  cmd = ['veritysetup', 'format', file_to_verity, tree_file_name, '-s', binascii.hexlify(args.salt), '--no-superblock']
-  print ' '.join(cmd)
-  output = subprocess.check_output(cmd)
-  root_hash = ''
-  for line in output.split('\n'):
-    if line.startswith('Root hash'):
-      root_hash = line.split(':')[1].strip()
-      break
-  else:
-    sys.exit('FATAL: root hash is not found')
-  with file(tree_file_name, 'r') as tree_file:
-    tree_file.seek(0, os.SEEK_SET)
-    merkle_tree = tree_file.read()
-  return root_hash, merkle_tree
-
-def copy_src_to_dst(args):
-  with file(args.output_file, 'w') as dst:
-    with file (args.input_file, 'r') as src:
-      shutil.copyfileobj(src, dst)
-
-def pad_dst(args):
-  with file (args.output_file, 'a') as dst:
-    dst.write('\0' * ((4096 - dst.tell()) % 4096))
-
-def append_merkle_tree_to_dst(args, tree):
-  with file (args.output_file, 'a') as dst:
-    dst.write(tree)
-
-def append_header_to_dst(args, header):
-  with file (args.output_file, 'a') as dst:
-    dst.write(string_at(pointer(header), sizeof(header)))
-
-class HeaderOffset(Structure):
-  _fields_ = [('hdr_offset', c_uint32)]
-  
-def append_header_reverse_offset_to_dst(args, extensions):
-  hdr_offset = HeaderOffset()
-  hdr_offset.hdr_offset = HEADER_SIZE + len(extensions) + sizeof(hdr_offset)
-  with file (args.output_file, 'a') as dst:
-    dst.write(string_at(pointer(hdr_offset), sizeof(hdr_offset)))
-
-def append_extensions_to_dst(args, extensions):
-  with file (args.output_file, 'a') as dst:
-    dst.write(extensions)
-
-def fill_header_struct(args):
-  statinfo = os.stat(args.input_file)
-  header = FSVHeader()
-  assert sizeof(header) == HEADER_SIZE
-  memset(addressof(header), 0, sizeof(header))
-  memmove(addressof(header) + FSVHeader.magic.offset, b'TrueBrew', 8)
-  header.maj_version = 1
-  header.min_version = 0
-  header.log_blocksize = 12
-  header.log_arity = 7
-  header.meta_algorithm = 1  # sha256
-  header.data_algorithm = 1  # sha256
-  header.reserved1 = 0
-  header.size = statinfo.st_size
-  header.auth_blk_offset = 0
-  header.extension_count = 0
-  if args.patch_file is not None and args.patch_offset is not None:
-    header.extension_count += 1
-  header.salt = args.salt
-  return header
-
-def apply_patch(args):
-  if args.patch_file is not None and args.patch_offset is not None:
-    statinfo = os.stat(args.patch_file)
-    patch_file_size = statinfo.st_size
-    if patch_file_size > 256:
-      print "Invalid patch file size; must be <= 256 bytes: [", patch_file_size, "]"
-      return None
-    statinfo = os.stat(args.output_file)
-    if statinfo.st_size < (int(args.patch_offset) + patch_file_size):
-      print "Invalid output file size for patch offset and size"
-      return None
-    with file (args.patch_file, 'r') as patch_file:
-      patch_buf = ""
-      original_content = ""
-      with file (args.output_file, 'r') as dst:
-        dst.seek(int(args.patch_offset), os.SEEK_SET)
-        original_content = dst.read(patch_file_size)
-        dst.seek(int(args.patch_offset), os.SEEK_SET)
-        patch_buf = patch_file.read(patch_file_size)
-        dst.close()
-      with file (args.output_file, 'w') as dst:
-        dst.seek(int(args.patch_offset), os.SEEK_SET)
-        dst.write(patch_buf)
-        dst.close()
-      return original_content
-  else:
-    return None
-                 
-def serialize_extensions(args):
-  patch_ext_buf = None
-  elide_ext_buf = None
-  if args.patch_file is not None and args.patch_offset is not None:
-    statinfo = os.stat(args.patch_file)
-    patch_file_size = statinfo.st_size
-    exthdr = FSVExt()
-    memset(addressof(exthdr), 0, sizeof(exthdr))
-    patch_ext = PatchExt()
-    memset(addressof(patch_ext), 0, sizeof(patch_ext))
-    aligned_patch_size = ((int(patch_file_size) + int(8 - 1)) / int(8)) * int(8)
-    exthdr.length = sizeof(exthdr) + sizeof(patch_ext) + aligned_patch_size;
-    exthdr.type = 1  # 1 == patch extension
-    patch_ext.offset = int(args.patch_offset)
-    print "Patch offset: ", patch_ext.offset
-    patch_ext.length = patch_file_size
-    print "Patch length: ", patch_ext.length
-    patch_ext_buf = create_string_buffer(exthdr.length)
-    memset(addressof(patch_ext_buf), 0, sizeof(patch_ext_buf))  # Includes the zero-pad
-    memmove(addressof(patch_ext_buf), addressof(exthdr), sizeof(exthdr))
-    memmove(addressof(patch_ext_buf) + sizeof(exthdr), addressof(patch_ext), sizeof(patch_ext))
-    with file (args.patch_file, 'r') as patch_file:
-      memmove(addressof(patch_ext_buf) + sizeof(exthdr) + sizeof(patch_ext), patch_file.read(patch_file_size), patch_file_size)
-  if args.elide_offset is not None and args.elide_length is not None:
-    exthdr = FSVExt()
-    memset(addressof(exthdr), 0, sizeof(exthdr))
-    elide_ext = ElideExt()
-    memset(addressof(elide_ext), 0, sizeof(elide_ext))
-    exthdr.length = sizeof(exthdr) + sizeof(elide_ext)
-    exthdr.type = 0  # 0 == elide extension
-    elide_ext.offset = int(args.elide_offset)
-    print "Elide offset: ", elide_ext.offset
-    elide_ext.length = int(args.elide_length)
-    print "Elide length: ", elide_ext.length
-    elide_ext_buf = create_string_buffer(exthdr.length)
-    memset(addressof(elide_ext_buf), 0, sizeof(elide_ext_buf))
-    memmove(addressof(elide_ext_buf), addressof(exthdr), sizeof(exthdr))
-    memmove(addressof(elide_ext_buf) + sizeof(exthdr), addressof(elide_ext), sizeof(elide_ext))
-  return (string_at(patch_ext_buf) if (patch_ext_buf is not None) else "") + (string_at(elide_ext_buf) if (elide_ext_buf is not None) else "")
-
-def restore_patched_content(args, original_content):
-  if original_content is not None:
-    with file (args.output_file, 'w') as dst:
-      dst.seek(int(args.patch_offset), os.SEEK_SET)
-      dst.write(original_content)
-
-def elide_dst(args):
-  if args.elide_offset is not None and args.elide_length is not None:
-    statinfo = os.stat(args.output_file)
-    dst_size = statinfo.st_size
-    if dst_size < (int(args.elide_offset) + elide_length):
-      print "dst_size >= elide region offet+length"
-      return None
-    elided_file = tempfile.NamedTemporaryFile()
-    with file (args.output_file, 'r') as dst:
-      elided_file.write(dst.read(int(args.elide_offset)))
-      end_of_elided_segment = int(args.elide_offset) + int(args.elide_length)
-      dst.seek(end_of_elided_segment, os.SEEK_SET)
-      elided_file.write(dst.read(dst_size - end_of_elided_segment))
-    return elided_file
-  else:
-    return None
-
-def main():
-  args = parse_args()
-
-  copy_src_to_dst(args)
-  pad_dst(args)
-  original_content = apply_patch(args)
-  elided_file = elide_dst(args)
-  root_hash, merkle_tree = generate_merkle_tree(args, elided_file)
-  append_merkle_tree_to_dst(args, merkle_tree)
-  header = fill_header_struct(args)
-  append_header_to_dst(args, header)
-  extensions = serialize_extensions(args)
-  append_extensions_to_dst(args, extensions)
-  restore_patched_content(args, original_content)
-  append_header_reverse_offset_to_dst(args, extensions)
-  print 'Merkle root hash: [', root_hash, "]"
-
-if __name__ == '__main__':
-  main()
diff --git a/fsveritysetup.py b/fsveritysetup.py
new file mode 100755
index 0000000..cc90526
--- /dev/null
+++ b/fsveritysetup.py
@@ -0,0 +1,531 @@
+#!/usr/bin/python
+"""Sets up a file for fs-verity."""
+
+from __future__ import print_function
+
+import argparse
+import binascii
+import ctypes
+import hashlib
+import io
+import math
+import os
+import subprocess
+import sys
+import tempfile
+
+DATA_BLOCK_SIZE = 4096
+HASH_BLOCK_SIZE = 4096
+HASH_ALGORITHM = 'sha256'
+FS_VERITY_MAGIC = b'TrueBrew'
+FS_VERITY_SALT_SIZE = 8
+FS_VERITY_EXT_ELIDE = 0
+FS_VERITY_EXT_PATCH = 1
+
+
+class fsverity_header(ctypes.LittleEndianStructure):
+  _fields_ = [
+      ('magic', ctypes.c_char * 8),  #
+      ('maj_version', ctypes.c_uint8),
+      ('min_version', ctypes.c_uint8),
+      ('log_blocksize', ctypes.c_uint8),
+      ('log_arity', ctypes.c_uint8),
+      ('meta_algorithm', ctypes.c_uint16),
+      ('data_algorithm', ctypes.c_uint16),
+      ('flags', ctypes.c_uint32),
+      ('reserved1', ctypes.c_uint32),
+      ('size', ctypes.c_uint64),
+      ('auth_blk_offset', ctypes.c_uint8),
+      ('extension_count', ctypes.c_uint8),
+      ('salt', ctypes.c_char * FS_VERITY_SALT_SIZE),
+      ('reserved2', ctypes.c_char * 22)
+  ]
+
+
+class fsverity_extension(ctypes.LittleEndianStructure):
+  _fields_ = [
+      ('length', ctypes.c_uint16),  #
+      ('type', ctypes.c_uint8),
+      ('reserved', ctypes.c_char * 5)
+  ]
+
+
+class fsverity_extension_patch(ctypes.LittleEndianStructure):
+  _fields_ = [
+      ('offset', ctypes.c_uint64),  #
+      ('length', ctypes.c_uint8),
+      ('reserved', ctypes.c_char * 7)
+      # followed by variable-length 'databytes'
+  ]
+
+
+class fsverity_extension_elide(ctypes.LittleEndianStructure):
+  _fields_ = [
+      ('offset', ctypes.c_uint64),  #
+      ('length', ctypes.c_uint64)
+  ]
+
+
+class HeaderOffset(ctypes.LittleEndianStructure):
+  _fields_ = [('hdr_offset', ctypes.c_uint32)]
+
+
+def copy_bytes(src, dst, n):
+  """Copies 'n' bytes from the 'src' file to the 'dst' file."""
+  if n < 0:
+    raise ValueError('Negative copy count: {}'.format(n))
+  while n > 0:
+    buf = src.read(min(n, io.DEFAULT_BUFFER_SIZE))
+    if not buf:
+      raise EOFError('Unexpected end of src file')
+    dst.write(buf)
+    n -= len(buf)
+
+
+def copy(src, dst):
+  """Copies from the 'src' file to the 'dst' file until EOF on 'src'."""
+  buf = src.read(io.DEFAULT_BUFFER_SIZE)
+  while buf:
+    dst.write(buf)
+    buf = src.read(io.DEFAULT_BUFFER_SIZE)
+
+
+def pad_to_block_boundary(f):
+  """Pads the file with zeroes to data block boundary."""
+  f.write(b'\0' * (-f.tell() % DATA_BLOCK_SIZE))
+
+
+def ilog2(n):
+  l = int(math.log(n, 2))
+  if n != 1 << l:
+    raise ValueError('{} is not a power of 2'.format(n))
+  return l
+
+
+def serialize_struct(struct):
+  """Serializes a ctypes.Structure to a byte array."""
+  return bytes(ctypes.string_at(ctypes.pointer(struct), ctypes.sizeof(struct)))
+
+
+def veritysetup(data_filename, tree_filename, salt):
+  """Built-in Merkle tree generation algorithm."""
+  salted_hash = hashlib.new(HASH_ALGORITHM)
+  salted_hash.update(salt)
+  hashes_per_block = HASH_BLOCK_SIZE // salted_hash.digest_size
+  level_blocks = [os.stat(data_filename).st_size // DATA_BLOCK_SIZE]
+  while level_blocks[-1] > 1:
+    level_blocks.append(
+        (level_blocks[-1] + hashes_per_block - 1) // hashes_per_block)
+  hash_block_offset = sum(level_blocks) - level_blocks[0]
+  with open(data_filename, 'rb') as datafile:
+    with open(tree_filename, 'r+b') as hashfile:
+      for level, blockcount in enumerate(level_blocks):
+        (i, pending) = (0, bytearray())
+        for j in range(blockcount):
+          h = salted_hash.copy()
+          if level == 0:
+            datafile.seek(j * DATA_BLOCK_SIZE)
+            h.update(datafile.read(DATA_BLOCK_SIZE))
+          else:
+            hashfile.seek((hash_block_offset + j) * HASH_BLOCK_SIZE)
+            h.update(hashfile.read(HASH_BLOCK_SIZE))
+          pending += h.digest()
+          if level + 1 == len(level_blocks):
+            assert len(pending) == salted_hash.digest_size
+            return binascii.hexlify(pending).decode('ascii')
+          if len(pending) == HASH_BLOCK_SIZE or j + 1 == blockcount:
+            pending += b'\0' * (HASH_BLOCK_SIZE - len(pending))
+            hashfile.seek((hash_block_offset - level_blocks[level + 1] + i) *
+                          HASH_BLOCK_SIZE)
+            hashfile.write(pending)
+            (i, pending) = (i + 1, bytearray())
+        hash_block_offset -= level_blocks[level + 1]
+
+
+class Extension(object):
+  """An fs-verity patch or elide extension."""
+
+  def __init__(self, offset, length):
+    self.offset = offset
+    self.length = length
+    if self.length < self.MIN_LENGTH:
+      raise ValueError('length too small (got {}, need >= {})'.format(
+          self.length, self.MIN_LENGTH))
+    if self.length > self.MAX_LENGTH:
+      raise ValueError('length too large (got {}, need <= {})'.format(
+          self.length, self.MAX_LENGTH))
+    if self.offset < 0:
+      raise ValueError('offset cannot be negative (got {})'.format(self.offset))
+
+  def serialize(self):
+    type_buf = self._serialize_impl()
+    hdr = fsverity_extension()
+    pad = -len(type_buf) % 8
+    hdr.length = ctypes.sizeof(hdr) + len(type_buf) + pad
+    hdr.type = self.TYPE_CODE
+    return serialize_struct(hdr) + type_buf + (b'\0' * pad)
+
+  def __str__(self):
+    return '{}(offset {}, length {})'.format(self.__class__.__name__,
+                                             self.offset, self.length)
+
+
+class ElideExtension(Extension):
+  """An fs-verity elide extension."""
+
+  TYPE_CODE = FS_VERITY_EXT_ELIDE
+  MIN_LENGTH = 1
+  MAX_LENGTH = (1 << 64) - 1
+
+  def __init__(self, offset, length):
+    Extension.__init__(self, offset, length)
+
+  def apply(self, out_file):
+    pass
+
+  def _serialize_impl(self):
+    ext = fsverity_extension_elide()
+    ext.offset = self.offset
+    ext.length = self.length
+    return serialize_struct(ext)
+
+
+class PatchExtension(Extension):
+  """An fs-verity patch extension."""
+
+  TYPE_CODE = FS_VERITY_EXT_PATCH
+  MIN_LENGTH = 1
+  MAX_LENGTH = 255
+
+  def __init__(self, offset, data):
+    Extension.__init__(self, offset, len(data))
+    self.data = data
+
+  def apply(self, dst):
+    dst.write(self.data)
+
+  def _serialize_impl(self):
+    ext = fsverity_extension_patch()
+    ext.offset = self.offset
+    ext.length = self.length
+    return serialize_struct(ext) + self.data
+
+
+class BadExtensionListError(Exception):
+  pass
+
+
+class FSVerityGenerator(object):
+  """Sets up a file for fs-verity."""
+
+  def __init__(self, in_filename, out_filename, salt, **kwargs):
+    self.in_filename = in_filename
+    self.original_size = os.stat(in_filename).st_size
+    self.out_filename = out_filename
+    self.salt = salt
+    assert len(salt) == FS_VERITY_SALT_SIZE
+
+    self.tree_filename = kwargs.get('tree_filename')
+
+    self.extensions = kwargs.get('extensions')
+    if self.extensions is None:
+      self.extensions = []
+
+    self.builtin_veritysetup = kwargs.get('builtin_veritysetup')
+    if self.builtin_veritysetup is None:
+      self.builtin_veritysetup = False
+
+    self.tmp_filenames = []
+
+    # Patches and elisions must be within the file size and must not overlap.
+    self.extensions = sorted(self.extensions, key=lambda ext: ext.offset)
+    for i, ext in enumerate(self.extensions):
+      ext_end = ext.offset + ext.length
+      if ext_end > self.original_size:
+        raise BadExtensionListError(
+            '{} extends beyond end of file!'.format(ext))
+      if i + 1 < len(
+          self.extensions) and ext_end > self.extensions[i + 1].offset:
+        raise BadExtensionListError('{} overlaps {}!'.format(
+            ext, self.extensions[i + 1]))
+
+  def _open_tmpfile(self, mode):
+    f = tempfile.NamedTemporaryFile(mode, delete=False)
+    self.tmp_filenames.append(f.name)
+    return f
+
+  def _delete_tmpfiles(self):
+    for filename in self.tmp_filenames:
+      os.unlink(filename)
+
+  def _apply_extensions(self, data_filename):
+    with open(data_filename, 'rb') as src:
+      with self._open_tmpfile('wb') as dst:
+        src_pos = 0
+        for ext in self.extensions:
+          print('Applying {}'.format(ext))
+          copy_bytes(src, dst, ext.offset - src_pos)
+          ext.apply(dst)
+          src_pos = ext.offset + ext.length
+          src.seek(src_pos)
+        copy(src, dst)
+        return dst.name
+
+  def _generate_merkle_tree(self, data_filename):
+    """Generates a file's Merkle tree for fs-verity.
+
+    Args:
+       data_filename: file for which to generate the tree.  Patches and/or
+           elisions may need to be applied on top of it.
+
+    Returns:
+        (root hash as hex, name of the file containing the Merkle tree).
+
+    Raises:
+        OSError: A problem occurred when executing the 'veritysetup'
+            program to generate the Merkle tree.
+    """
+
+    # If there are any patch or elide extensions, apply them to a temporary file
+    # and use that to build the Merkle tree instead of the original data.
+    if self.extensions:
+      data_filename = self._apply_extensions(data_filename)
+
+    # Pad to a data block boundary before building the Merkle tree.
+    # Note: elisions may result in padding being needed, even if the original
+    # file was block-aligned!
+    with open(data_filename, 'ab') as f:
+      pad_to_block_boundary(f)
+
+    # Choose the file to which we'll output the Merkle tree: either an
+    # explicitly specified one or a temporary one.
+    if self.tree_filename is not None:
+      tree_filename = self.tree_filename
+    else:
+      with self._open_tmpfile('wb') as f:
+        tree_filename = f.name
+
+    if self.builtin_veritysetup:
+      root_hash = veritysetup(data_filename, tree_filename, self.salt)
+    else:
+      # Delegate to 'veritysetup' to actually build the Merkle tree.
+      cmd = [
+          'veritysetup',
+          'format',
+          data_filename,
+          tree_filename,
+          '--salt=' + binascii.hexlify(self.salt).decode('ascii'),
+          '--no-superblock',
+          '--hash={}'.format(HASH_ALGORITHM),
+          '--data-block-size={}'.format(DATA_BLOCK_SIZE),
+          '--hash-block-size={}'.format(HASH_BLOCK_SIZE),
+      ]
+      print(' '.join(cmd))
+      output = subprocess.check_output(cmd, universal_newlines=True)
+
+      # Extract the root hash from veritysetup's output.
+      root_hash = None
+      for line in output.splitlines():
+        if line.startswith('Root hash'):
+          root_hash = line.split(':')[1].strip()
+          break
+      if root_hash is None:
+        raise OSError('Root hash not found in veritysetup output!')
+    return root_hash, tree_filename
+
+  def _generate_header(self):
+    """Generates the fs-verity header."""
+    header = fsverity_header()
+    assert ctypes.sizeof(header) == 64
+    header.magic = FS_VERITY_MAGIC
+    header.maj_version = 1
+    header.min_version = 0
+    header.log_blocksize = ilog2(DATA_BLOCK_SIZE)
+    assert HASH_ALGORITHM == 'sha256'
+    header.log_arity = ilog2(DATA_BLOCK_SIZE // 32)  # sha256 hash size
+    header.meta_algorithm = 1  # sha256
+    header.data_algorithm = 1  # sha256
+    header.size = self.original_size
+    header.extension_count = len(self.extensions)
+    header.salt = self.salt
+    return serialize_struct(header)
+
+  def generate(self):
+    """Sets up a file for fs-verity.
+
+    The input file will be copied to the output file, then have the fs-verity
+    metadata appended to it.
+
+    Returns:
+       (fs-verity measurement, Merkle tree root hash), both as hex.
+
+    Raises:
+       IOError: Problem reading/writing the files.
+    """
+
+    # Copy the input file to the output file.
+    with open(self.in_filename, 'rb') as infile:
+      with open(self.out_filename, 'wb') as outfile:
+        copy(infile, outfile)
+        if outfile.tell() != self.original_size:
+          raise IOError('{}: size changed!'.format(self.in_filename))
+
+    try:
+      # Generate the file's Merkle tree and calculate its root hash.
+      (root_hash, tree_filename) = self._generate_merkle_tree(self.out_filename)
+
+      with open(self.out_filename, 'ab') as outfile:
+
+        # Pad to a block boundary and append the Merkle tree.
+        pad_to_block_boundary(outfile)
+        with open(tree_filename, 'rb') as treefile:
+          copy(treefile, outfile)
+
+        # Append the fs-verity header.
+        header = self._generate_header()
+        outfile.write(header)
+
+        # Append extension items, if any.
+        extensions = bytearray()
+        for ext in self.extensions:
+          extensions += ext.serialize()
+        outfile.write(extensions)
+
+        # Finish the output file by writing the header offset field.
+        hdr_offset = HeaderOffset()
+        hdr_offset.hdr_offset = len(header) + len(extensions) + ctypes.sizeof(
+            HeaderOffset)
+        outfile.write(serialize_struct(hdr_offset))
+
+        # Compute the fs-verity measurement.
+        measurement = hashlib.new(HASH_ALGORITHM)
+        measurement.update(header)
+        measurement.update(extensions)
+        measurement.update(binascii.unhexlify(root_hash))
+        measurement = measurement.hexdigest()
+    finally:
+      self._delete_tmpfiles()
+
+    return (measurement, root_hash)
+
+
+def convert_salt_argument(argstring):
+  try:
+    b = binascii.unhexlify(argstring)
+    if len(b) != FS_VERITY_SALT_SIZE:
+      raise ValueError
+    return b
+  except (ValueError, TypeError):
+    raise argparse.ArgumentTypeError(
+        'Must be a 16-character hex string.  (Got "{}")'.format(argstring))
+
+
+def convert_patch_argument(argstring):
+  try:
+    (offset, patchfile) = argstring.split(',')
+    offset = int(offset)
+  except ValueError:
+    raise argparse.ArgumentTypeError(
+        'Must be formatted as <offset,patchfile>.  (Got "{}")'.format(
+            argstring))
+  try:
+    with open(patchfile, 'rb') as f:
+      data = f.read()
+    return PatchExtension(int(offset), data)
+  except (IOError, ValueError) as e:
+    raise argparse.ArgumentTypeError(e)
+
+
+def convert_elide_argument(argstring):
+  try:
+    (offset, length) = argstring.split(',')
+    offset = int(offset)
+    length = int(length)
+  except ValueError:
+    raise argparse.ArgumentTypeError(
+        'Must be formatted as <offset,length>.  (Got "{}")'.format(argstring))
+  try:
+    return ElideExtension(offset, length)
+  except ValueError as e:
+    raise argparse.ArgumentTypeError(e)
+
+
+def parse_args():
+  """Parses the command-line arguments."""
+  parser = argparse.ArgumentParser(
+      description='Sets up a file for fs-verity (file-based integrity)')
+  parser.add_argument(
+      'in_filename',
+      metavar='<input_file>',
+      type=str,
+      help='Original content input file')
+  parser.add_argument(
+      'out_filename',
+      metavar='<output_file>',
+      type=str,
+      help='Output file formatted for fs-verity')
+  parser.add_argument(
+      '--salt',
+      metavar='<hex_string>',
+      type=convert_salt_argument,
+      default='00' * FS_VERITY_SALT_SIZE,
+      help='{}-byte salt, given as a {}-character hex string'.format(
+          FS_VERITY_SALT_SIZE, FS_VERITY_SALT_SIZE * 2))
+  parser.add_argument(
+      '--tree-file',
+      metavar='<tree_file>',
+      dest='tree_filename',
+      type=str,
+      help='File to which to output the raw Merkle tree (optional)')
+  parser.add_argument(
+      '--patch',
+      metavar='<offset,patchfile>',
+      type=convert_patch_argument,
+      action='append',
+      dest='extensions',
+      help="""Add a patch extension (not recommended).  Data in the region
+      beginning at <offset> in the original file and continuing for
+      filesize(<patchfile>) bytes will be replaced with the contents of
+      <patchfile> for verification purposes, but reads will return the original
+      data.""")
+  parser.add_argument(
+      '--elide',
+      metavar='<offset,length>',
+      type=convert_elide_argument,
+      action='append',
+      dest='extensions',
+      help="""Add an elide extension (not recommended).  Data in the region
+      beginning at <offset> in the original file and continuing for <length>
+      bytes will not be verified.""")
+  parser.add_argument(
+      '--builtin-veritysetup',
+      action='store_const',
+      const=True,
+      help="""Use the built-in Merkle tree generation algorithm rather than
+      invoking the external veritysetup program.  They should produce the same
+      result.""")
+  return parser.parse_args()
+
+
+def main():
+  args = parse_args()
+  try:
+    generator = FSVerityGenerator(
+        args.in_filename,
+        args.out_filename,
+        args.salt,
+        tree_filename=args.tree_filename,
+        extensions=args.extensions,
+        builtin_veritysetup=args.builtin_veritysetup)
+  except BadExtensionListError as e:
+    sys.stderr.write('ERROR: {}\n'.format(e))
+    sys.exit(1)
+
+  (measurement, root_hash) = generator.generate()
+
+  print('Merkle root hash: {}'.format(root_hash))
+  print('fs-verity measurement: {}'.format(measurement))
+
+
+if __name__ == '__main__':
+  main()
diff --git a/mkfsverity.sh b/mkfsverity.sh
index bb88865..85dc55c 100755
--- a/mkfsverity.sh
+++ b/mkfsverity.sh
@@ -1,48 +1,92 @@
-#!/bin/sh
+#!/bin/bash
 
-set -x
+set -eu
 
-OPTIND=1
-patch=0
-patch_offset=28672
-patch_length=128
-elide=0
-elide_offset=12288
-elid_length=8192
-size=36864
-keep_input=0
-while getopts "poeskf:" opt; do
-    case "$opt" in
-	p)  patch=1
-	    ;;
-	o)  patch_offset=$OPTARG
-	    ;;
-	e)  elide=1
-	    ;;
-	f)  elide_offset=$OPTARG
-	    ;;
-	s)  size=$OPTARG
-	    ;;
-	k)  keep_input=1
-	    ;;
-    esac
-done
-shift $((OPTIND-1))
-[ "$1" = "--" ] && shift
-filename="input-$size.apk"
-backup_filename="input-$size-backup.apk"
-patch_filename="output-$size-patch"
-echo "size=$size, filename='$filename', patch_filename='$patch_filename', patch=$patch, patch_offset=$patch_offset, elide=$elide, unparsed: $@"
-num_blks=$(($size / 4096))
-blk_aligned_sz=$(($num_blks*4096))
-echo "Number of blocks: $num_blks"
-if [ $keep_input -eq 0 ]; then
-    remainder=$(($size % 4096))
-    echo "Remainder: $remainder"
-    dd if=/dev/urandom of=$filename bs=4096 count=$num_blks
-    dd if=/dev/urandom of=$filename bs=1 count=$remainder seek=$blk_aligned_sz
+SIZE=36864
+KEEP_INPUT=false
+PATCHES=()
+ELISIONS=()
+
+usage() {
+	cat << EOF
+Usage: $0 [OPTIONS]
+
+Test formatting a randomly generated file for fs-verity.
+
+Options:
+  -s, --size=SIZE
+  -k, --keep-input
+  -p, --patch=OFFSET,LENGTH [can be repeated]
+  -e, --elide=OFFSET,LENGTH [can be repeated]
+  -h, --help
+EOF
+}
+
+if ! options=$(getopt -o s:kp:e:h \
+	-l size:,keep-input,patch:,elide:,help -- "$@"); then
+	usage 1>&2
+	exit 2
 fi
-dd if=/dev/urandom of=$patch_filename bs=1 count=$patch_length
-if [ $elide -eq 1 ]; then ELIDE_ARGS=" --elide_offset=${elide_offset} --elide_length=${elide_length}"; fi
-if [ $patch -eq 1 ]; then PATCH_ARGS=" --patch_offset=${patch_offset} --patch_file=${patch_filename}"; fi
-./fsverity.py $filename "output-$size.apk" --salt=deadbeef00000000 ${PATCH_ARGS} ${ELIDE_ARGS}
+
+eval set -- "$options"
+
+while (( $# > 0 )); do
+	case "$1" in
+	-s|--size)
+		SIZE="$2"
+		shift
+		;;
+	-k|--keep-input)
+		KEEP_INPUT=true
+		;;
+	-p|--patch)
+		PATCHES+=("$2")
+		shift
+		;;
+	-e|--elide)
+		ELISIONS+=("$2")
+		shift
+		;;
+	-h|--help)
+		usage
+		exit 0
+		;;
+	--)
+		shift
+		break
+		;;
+	*)
+		echo 1>&2 "Invalid option \"$1\""
+		usage 1>&2
+		exit 2
+		;;
+	esac
+	shift
+done
+
+if (( $# != 0 )); then
+	usage 1>&2
+	exit 2
+fi
+
+filename="input-$SIZE.apk"
+
+if ! $KEEP_INPUT; then
+    head -c "$SIZE" /dev/urandom > "$filename"
+fi
+
+cmd=(./fsveritysetup.py "$filename" "output-$SIZE.apk")
+cmd+=("--salt=deadbeef00000000")
+
+for ((i = 0; i < ${#PATCHES[@]}; i++)); do
+	patch_offset=$(echo "${PATCHES[$i]}" | cut -d, -f1)
+	patch_length=$(echo "${PATCHES[$i]}" | cut -d, -f2)
+	patch_filename="output-$SIZE-patch_$i"
+	head -c "$patch_length" /dev/urandom > "$patch_filename"
+	cmd+=("--patch=$patch_offset,$patch_filename")
+done
+
+if (( ${#ELISIONS[@]} > 0 )); then cmd+=("${ELISIONS[@]/#/--elide=}"); fi
+
+echo "${cmd[@]}"
+"${cmd[@]}"