Add per-rank disk checkpointing for adjoint tape#4891
Conversation
Enable each MPI rank to write its adjoint checkpoint data to its own HDF5 file using PETSc Vec I/O on COMM_SELF. This avoids parallel HDF5 overhead and enables use of fast node-local storage (NVMe/SSD) on HPC systems, where shared filesystem I/O is a major bottleneck for large-scale time-dependent adjoint computations. New parameter `per_rank_dirname` on `enable_disk_checkpointing()`. When set, function data is checkpointed per-rank while mesh data (via `checkpointable_mesh`) remains on shared storage. Requires same number of ranks on restore (inherent in adjoint workflows).
|
In PyOP2 we have the notion of a 'compilation comm' which is a communicator defined over each node (https://github.com/firedrakeproject/firedrake/blob/main/pyop2/mpi.py#L450). Might something like this be appropriate/more general here? |
|
Thanks Connor, that's a great idea. I hadn't considered using the compilation comm pattern here. I did look into what a node-local comm approach would involve. From my perspective on Gadi, I have 48 cores per node writing to node-local SSDs, and per-rank I/O is completely manageable. My thinking here is that this is specifically an adjoint solution: disk checkpointing for the adjoint tape is extremely I/O heavy, so the less communicator overhead involved, the better. That said, this is ultimately a decision for the Firedrake folks on what works best for general users. If a node-local comm approach is preferred, it's doable, and I'm happy to refactor if that's the direction you'd like to go. |
|
@connorjward Following up on your suggestion about using the compilation communicator. Angus and I had a discussion about this, and we tried to see if we can simply pass a sub-communicator to `CheckpointFile` via its `comm` argument. Saving works fine, but loading deadlocks. Here's a minimal reproducer:

```python
"""Run with: mpiexec -n 4 python test_subcomm_checkpoint.py"""
import os
import tempfile

from firedrake import *
from firedrake.checkpointing import CheckpointFile

comm = COMM_WORLD
mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
f = Function(V, name="f")
f.interpolate(SpatialCoordinate(mesh)[0])

# Pair ranks up two-by-two to emulate a node-local communicator.
node_comm = comm.Split(color=comm.rank // 2, key=comm.rank)

if comm.rank == 0:
    tmpdir = tempfile.mkdtemp()
else:
    tmpdir = None
tmpdir = comm.bcast(tmpdir, root=0)

fname = os.path.join(tmpdir, f"node{comm.rank // 2}.h5")
with CheckpointFile(fname, 'w', comm=node_comm) as out:
    out.save_mesh(mesh)
    out.save_function(f)

with CheckpointFile(fname, 'r', comm=node_comm) as inp:
    mesh2 = inp.load_mesh()
    f2 = inp.load_function(mesh2, "f")  # deadlocks here
```

The issue is that the load path goes through calls that are collective on `COMM_WORLD` rather than on the communicator the file was opened with, so ranks in different sub-communicators never meet in the collective. That said, I might be missing something. Is there a way to make this work that I'm not seeing? |
I don't think so. I'd have been surprised had that worked.
This is what I'm suggesting. It doesn't seem like a lot of work to change your API from a per-rank directory name to a general communicator argument. |
Sorry, that was It kind of seems like a |
Refactors the per_rank_dirname parameter into a more general checkpoint_comm + checkpoint_dir interface, following reviewer feedback. Instead of hardcoding COMM_SELF, users now pass any MPI communicator (COMM_SELF for per-rank files, a node-local comm for per-node files, etc.). The PETSc Vec I/O uses createMPI on the supplied communicator rather than createSeq on COMM_SELF, making the approach work for arbitrary communicator topologies. Removes three serial-only checkpoint_comm tests that are fully covered by their parallel counterparts and adds node_comm tests that exercise the multi-rank-per-file path using COMM_TYPE_SHARED.
|
Thanks Connor, done. I've refactored the API from `per_rank_dirname` to the general `checkpoint_comm` + `checkpoint_dir` interface. The main PETSc-level change this required was switching from `Vec.createSeq` on `COMM_SELF` to `Vec.createMPI` on the supplied communicator, so the approach works for arbitrary communicator topologies. I've also added tests with a node-local communicator (via `COMM_TYPE_SHARED`) that exercise the multi-rank-per-file path. |
connorjward
left a comment
There was a problem hiding this comment.
Seems alright to me. It would definitely be good to get some feedback from @JHopeCollins, who has done similar comm wrangling for ensemble.
It's definitely an interesting question. Conceptually I think it should be possible to checkpoint a DMPlex to multiple files, but it's far from trivial. An added complication is that we would have to preserve the N-to-M checkpointing behaviour (i.e. reading and writing with different numbers of ranks). |
Add FutureWarning to deprecated new_checkpoint_file method. Use isinstance(mesh, ufl.MeshSequence) instead of hasattr check. Replace COMM_TYPE_SHARED tests with comm.Split(rank // 2) to guarantee a communicator whose size is strictly between 1 and the COMM_WORLD size in the 3-rank test.
Extract _generate_function_space_name from CheckpointFile into a module-level function in firedrake/checkpointing.py and reuse it for the checkpoint_comm Vec naming instead of maintaining a separate _generate_checkpoint_vec_name. The free function also handles MeshSequenceGeometry defensively. CheckpointFile method delegates to it.
The multi-mesh tests chained two PDE solves via assemble(u_a * dx), a global reduction whose floating-point result can vary across parallel runs due to reduction ordering. This made the J == Jnew assertion flaky at the np.allclose tolerance boundary. Make the mesh_b solve independent and drop the redundant memory baseline comparison.
|
Simplified the multi-mesh tests to fix intermittent CI failures. The original design chained two solves via a global reduction (assemble(u_a * dx)), which amplified parallel floating-point non-determinism across tape replays. The two solves are now independent while still doing multi-mesh checkpointing. |
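For readers unfamiliar with why a parallel reduction is non-deterministic: floating-point addition is not associative, so the result of summing per-rank contributions depends on the order in which they are combined. A minimal pure-Python illustration (the three values are arbitrary stand-ins for per-rank partial sums):

```python
# The same three "per-rank contributions" summed in two different orders.
contributions = [0.1, 0.2, 0.3]

left_to_right = (contributions[0] + contributions[1]) + contributions[2]
right_to_left = contributions[0] + (contributions[1] + contributions[2])

print(left_to_right)                   # 0.6000000000000001
print(right_to_left)                   # 0.6
print(left_to_right == right_to_left)  # False
```

With MPI reductions the combination order can differ between runs and between tape replays, so a chained functional can differ by a rounding-level amount, which is what made the assertion flaky at the tolerance boundary.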
JHopeCollins
left a comment
There was a problem hiding this comment.
My gut feeling is that this should be part of CheckpointFile rather than hidden in the adjoint utils.
I don't think I fully understand the problem with checkpointing the mesh using the global comm but checkpointing the function data using a subcomm. Why does that mean you have to split the DM into multiple files?
You can't split the DM into multiple files. That's why this only works for functions. |
|
The main thing is that I would like to see the local save/load logic abstracted out of the adjoint code somehow, because the adjoint code really shouldn't be thinking about actual concrete data. I can think of four potential ways to do this. Happy to hear arguments for/against each one. In each case we'd obviously need to be very explicit about the restrictions of saving/loading locally, i.e. you must have exactly the same partition for saving and loading, so it's basically only for saving/loading during the same programme (and for option 3 you also can't save/load the mesh with this class).
|
This is a fair point.
I've read the implementation and I'm still confused about the difference between these. In PETSc the term 'local' applies to lots of different things, so this may not do quite what you expect. I wonder if this is basically reimplementing something that already exists. We should discuss this in today's meeting. |
|
@sghelichkhani from today's meeting we decided that we want this functionality exposed as a new checkpoint-file class rather than hidden in the adjoint utils. I quite like a name along the lines of `TemporaryFunctionCheckpointFile`. |
How about
@sghelichkhani is it possible to use |
|
Note on _checkpoint_indices: this class-level dict on CheckpointFunction is never pruned. File entries persist after the HDF5 file is deleted. The local path (TemporaryFunctionCheckpointFile._indices) does not have this issue since remove_file handles cleanup. Not fixing this to avoid the risk of pruning entries before restore() reads them. The memory cost is negligible. |
…ject#4865) * add log event markers * build spatial index using CreateWithArray
…ates (firedrakeproject#4900) use numpy directly for non-extruded
Move PETSc Vec I/O into TemporaryFunctionCheckpointFile in checkpointing.py. Rename save/restore methods, fix deprecation warning, remove redundant fixture and forwarding method, clean up imports.
There was a problem hiding this comment.
Hi Siya,
Thanks for your effort on this! The current TemporaryFunctionCheckpointFile is a really good start on separating out the "save to local not shared" storage abstraction from the adjoint code, I think we're converging on the right solution.
I've left comments on specific parts, but they are predominantly about two main suggestions which I think will complete the separation of concerns between the checkpointing and the adjoint, and hopefully make the code simpler by removing a bit of repetition.
1. The most important point is that, although the logic of actually saving/loading from local storage is now totally out of the adjoint code, it has taken some of the "how to checkpoint along the tape" logic with it (i.e. the `_indices` stuff). This calculation of the stored name/index looks identical for both checkpoint types, so it should be done just once by the adjoint code, and then passed to whatever checkpoint file type is being used. The same is true for resetting the name/count of the `Function` after loading - this is the responsibility of the adjoint, not the checkpoint.
2. The second is that it looks like `TemporaryFunctionCheckpointFile` is actually `AccessAnyTemporaryFunctionCheckpointFile`, because you have to pass the file name every time you want to save/load, rather than having a 1-to-1 relationship between an instance of this class and an HDF5 file, like we have for `CheckpointFile`. If the `TemporaryFunctionCheckpointFile` holds onto one file like the `CheckpointFile` does, then I think you can make the `save_function` signature identical between the two, and you can make the `load_function` signature almost identical - `TemporaryFunctionCheckpointFile` will take a `FunctionSpace` rather than a `Mesh` as the first argument, but the rest will be the same.
|
Just to point out a nice-to-have-but-not-essential extra if you do both suggestions above. It would be great to allow the `TemporaryFunctionCheckpointFile` to be used as a context manager. You'd just need to add these methods to the class, like

```python
def __enter__(self):
    return self

def __exit__(self, *args):
    self.close()
```

Then in the adjoint code you'd save both the checkpoint class type and the comm to make the checkpoint over on the file reference (`self.file`). To save the function:

```python
# work out stored_index and stored_name
...
with self.file.CheckpointClass(self.file.name, 'a', comm=self.file.checkpoint_comm) as outfile:
    outfile.save_function(function, name=stored_name, idx=stored_index)
```

And for restoring the function:

```python
def restore(self):
    # CheckpointClass is a type, so compare with `is` rather than isinstance
    if self.file.CheckpointClass is TemporaryFunctionCheckpointFile:
        base = self.function_space
    else:
        base = self.mesh
    with self.file.CheckpointClass(self.file.name, 'r', comm=self.file.checkpoint_comm) as infile:
        function = infile.load_function(base, self.stored_name, self.stored_index)
    return type(function)(function.function_space(),
                          val=function.dat, name=self.name, count=self.count)
```

Potentially we could allow `load_function` to take a `FunctionSpace` in both cases, which would simplify this to:

```python
def restore(self):
    with self.file.CheckpointClass(self.file.name, 'r', comm=self.file.checkpoint_comm) as infile:
        function = infile.load_function(self.function_space, self.stored_name, self.stored_index)
    return type(function)(function.function_space(),
                          val=function.dat, name=self.name, count=self.count)
```
|
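The 1-to-1, context-managed file class being suggested can be sketched with stdlib pieces only. The class and method names below mirror the suggestion, but the dict-backed pickle storage is a toy stand-in for the real PETSc HDF5 viewer, and the signatures are simplified (no `FunctionSpace` argument):

```python
import os
import pickle
import tempfile

class LocalFunctionCheckpointFile:
    """Toy stand-in for the suggested 1-to-1 file class: one instance
    owns one file for its whole lifetime and is usable as a context
    manager, like CheckpointFile."""

    def __init__(self, name, mode):
        self.name = name
        self._store = {}
        # 'r' and appending to an existing file start from its contents.
        if mode == 'r' or (mode == 'a' and os.path.exists(name)):
            with open(name, 'rb') as f:
                self._store = pickle.load(f)
        self._writable = mode in ('w', 'a')

    def save_function(self, value, name, idx):
        if not self._writable:
            raise IOError("file opened read-only")
        # The caller supplies stored_name/stored_index, as suggested.
        self._store[(name, idx)] = value

    def load_function(self, name, idx):
        return self._store[(name, idx)]

    def close(self):
        if self._writable:
            with open(self.name, 'wb') as f:
                pickle.dump(self._store, f)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()

# Usage: open a file, do one operation, close - identical shape for
# saving and restoring.
path = os.path.join(tempfile.mkdtemp(), "ckpt.h5")
with LocalFunctionCheckpointFile(path, 'w') as out:
    out.save_function([1.0, 2.0, 3.0], name="f", idx=0)
with LocalFunctionCheckpointFile(path, 'r') as inp:
    print(inp.load_function("f", 0))  # [1.0, 2.0, 3.0]
```

The point of the sketch is the lifetime contract: one instance per file, with `__enter__`/`__exit__` bounding the open viewer, so the adjoint code never tracks file names itself.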
Make it 1-to-1 with a single file/mode (context manager, viewer held open). Move index tracking into CheckpointFunction.__init__. Move directory management into DiskCheckpointer. Prune _checkpoint_indices in CheckPointFileReference.__del__. Add pruning test.
|
This ended up being quite a bit of work :), but I think the result is much cleaner. The main change is that `TemporaryFunctionCheckpointFile` is now 1-to-1 with a single HDF5 file and a single mode, rather than a manager that creates and tracks multiple files. The constructor now takes `(comm, filepath, mode)` and opens the PETSc HDF5 viewer immediately, holding it for the object's lifetime. The class is now a proper context manager with `__enter__`/`__exit__`/`__del__`, so save and restore use it exactly like `CheckpointFile`: open a file, do one operation, close. The internal `_indices` dict and `new_file()`/`remove_file()` are gone.

Index tracking -- computing `stored_name` and `stored_index` -- now happens once in `CheckpointFunction.__init__` and is shared by both the shared and local checkpoint paths. Previously this logic was duplicated in `_save_shared` and inside `TemporaryFunctionCheckpointFile.save_function`. Now `_save_shared` and `_save_local_checkpoint` are structurally parallel: both open their respective file class as a context manager, call `save_function` with the pre-computed name and index, and close.

Directory management (creating the temp directory for local checkpoint files) has moved out of `TemporaryFunctionCheckpointFile` and into `DiskCheckpointer`. One thing worth noting there: the bcast for the local directory uses `checkpoint_comm`, not `comm`. Only ranks within the same `checkpoint_comm` group share a local filesystem, so broadcasting on `COMM_WORLD` would be wrong -- each group creates its own subdirectory independently.

`CheckPointFileReference` loses the `temp_ckpt` reference and gains `checkpoint_comm` instead. The `__del__` method picks the right communicator for file cleanup automatically: `checkpoint_comm.rank == 0` for local files, `comm.rank == 0` for shared ones. |
|
I also took the opportunity to fix the long-standing issue with `_checkpoint_indices` never being pruned. This is a class-level dict on `CheckpointFunction` that maps filepath -> {stored_name -> count}, and previously it accumulated entries for every checkpoint file ever used without ever removing them. Over a long adjoint run with many tape advances this is a slow but real memory leak.

The fix is one line in `CheckPointFileReference.__del__`: pop the entry for `self.name` from `CheckpointFunction._checkpoint_indices`. The reason this is safe comes down to one invariant: `CheckpointFunction` holds `self.file` as a direct strong reference (not a weak one), so `__del__` on the `CheckPointFileReference` can only fire after every `CheckpointFunction` that wrote to that filepath has already been garbage-collected. And `restore()` never reads `_checkpoint_indices`; it uses the `stored_name` and `stored_index` baked into the `CheckpointFunction` instance at save time, so there is no way a restore call can race with the pruning.

The one case that required careful thought was revolve-style schedules, where the tape can snapshot its state and later restore it, causing new `CheckpointFunction` objects to be written to an old file. This is safe because the tape's checkpoint store holds the `CheckPointFileReference` alive by reference until re-execution is complete, keeping the reference count above zero and preventing `__del__` from firing prematurely. I have added a test (`test_checkpoint_indices_pruning`) that verifies the entries are actually removed once both reference chains, the `DiskCheckpointer`'s and the tape's, are released. |
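The pruning invariant described above can be reproduced with plain Python objects. The class names mirror the PR's but the bodies are toy stand-ins, and CPython's reference counting is what makes the `__del__` ordering deterministic here:

```python
class CheckpointFunction:
    # Class-level registry: filepath -> {stored_name: count}
    _checkpoint_indices = {}

    def __init__(self, file_ref, stored_name):
        self.file = file_ref  # strong reference: keeps the file ref alive
        indices = CheckpointFunction._checkpoint_indices.setdefault(file_ref.name, {})
        self.stored_index = indices.get(stored_name, 0)
        indices[stored_name] = self.stored_index + 1
        self.stored_name = stored_name

class CheckPointFileReference:
    def __init__(self, name):
        self.name = name

    def __del__(self):
        # Prune the registry entry. Safe because every CheckpointFunction
        # holds a strong reference to this object, so __del__ can only
        # run after all of them are gone.
        CheckpointFunction._checkpoint_indices.pop(self.name, None)

ref = CheckPointFileReference("/tmp/ckpt_0.h5")
a = CheckpointFunction(ref, "f")
b = CheckpointFunction(ref, "f")
assert b.stored_index == 1  # second write of "f" gets the next index
assert "/tmp/ckpt_0.h5" in CheckpointFunction._checkpoint_indices

del a, b   # writers collected first...
del ref    # ...then the file reference; __del__ prunes the entry
print("/tmp/ckpt_0.h5" in CheckpointFunction._checkpoint_indices)  # False
```

The real implementation has the same shape: the registry entry lives exactly as long as the file reference, and the file reference lives at least as long as its last writer.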
|
On the question of whether this should live inside `CheckpointFile`: I think the answer is no, and the docstring tries to make the reason explicit. The two classes have fundamentally different contracts. `CheckpointFile` is collective on `COMM_WORLD`, understands mesh topology, and loads functions by reconstructing them from the DM, which is why `load_function` takes a mesh. `TemporaryFunctionCheckpointFile` is collective only on the sub-communicator, knows nothing about topology, and loads functions by filling in a Vec given a `FunctionSpace` the caller already holds. It takes a `FunctionSpace` precisely because it has no way to reconstruct one from scratch.

The reason these contracts cannot be unified is the deadlock I showed earlier in the PR. The `sectionLoad` and `globalVectorLoad` calls inside `CheckpointFile.load_function` are collective on the mesh DM's communicator, which is `COMM_WORLD`. If only a subset of ranks (those in a node-local sub-comm) enter that call, the collective never completes. This is not a bug in `CheckpointFile`; it is a consequence of the mesh DM living on `COMM_WORLD` by design, because the parallel solves need it there. Making `CheckpointFile` work with a sub-communicator for function data while keeping the mesh on `COMM_WORLD` would require cross-communicator DM operations that PETSc does not support. So the separation is not just cosmetic; it reflects a real constraint. |
…__del__ Store _cleanup_comm at construction time (checkpoint_comm if set, else comm) so __del__ uses it directly without branching.
Collapses the TemporaryDirectory vs mkdtemp branch into a single call. Also fixes the bcast for the local dirname to a single unconditional call.
|
@JHopeCollins I think I have addressed all your points now. Let me know if anything needs more work. |
JHopeCollins
left a comment
There was a problem hiding this comment.
Great, thanks so much for thoroughly addressing our reviews! It's looking good now and I think it's basically ready to go.
I think the result is much cleaner.
I'm glad you agree. The internal implementation hasn't changed but I think that the separation of concerns is clear now between how we interact with storage and how we do book-keeping for the tape.
I spotted one thing that absolutely needs to be fixed. It was a bug before this PR, but it's related and it's only a few lines: the `CheckpointFile` must be passed the right comm! Currently it assumes `COMM_WORLD`, which may not always be correct.
Other than that, there's one comment about not having to modify CheckpointFileReference that I think would reduce the diff, but you can take that or leave it.
Once the CheckpointFile is passed the right comm I'm very happy to approve this!
|
@sghelichkhani The CI is failing on the gusto smoke tests and an IO timeout.
|
Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
…m:sghelichkhani/firedrake into sghelichkhani/per-rank-disk-checkpointing
|
Applied your suggestion on `CheckPointFileReference`. Also caught two more instances of the missing `comm` bug in `checkpointable_mesh`. |
|
Store self.comm on CheckPointFileReference so that the review-suggested CheckpointFile comm fixes work correctly. Drops the now-redundant _cleanup_comm. Also fixes two missing comm= arguments in checkpointable_mesh.
JHopeCollins
left a comment
There was a problem hiding this comment.
Applied your suggestion on CheckPointFileReference
Thanks!
Also caught two more instances of the missing comm bug in checkpointable_mesh
Great, good spot on those, it looks right to me.
This all looks good to me now. Thanks for working with us through multiple reviews, I'm very happy with where this has ended up.
firedrakeproject#4891 introduced the use of tempfile.TemporaryDirectory with the `delete` parameter. This was introduced in Python 3.12, and therefore this commit breaks compatibility with Python 3.10 and 3.11. pyproject.toml states that Python >= 3.10 is supported. Due to OS constraints, we'll be stuck on Python 3.11 on our HPC systems for some time.
Motivation

We run time-dependent adjoint Stokes simulations with close to a billion degrees of freedom per timestep. Recomputation-based checkpointing schedules (revolve/binomial) are infeasible due to the cost of recomputing the Stokes solve, so disk checkpointing (`SingleDiskStorageSchedule`) is the only viable option.

Currently, `CheckpointFile` writes all ranks to a single shared HDF5 file via parallel HDF5 (`PETSc.ViewerHDF5` on `COMM_WORLD`). On HPC systems, this means all checkpoint I/O goes through the shared parallel filesystem (Lustre/GPFS), which becomes a severe bottleneck. Under 24-hour job time limits, the disk I/O overhead makes simulations that comfortably fit within memory-checkpointed wall time infeasible when switching to disk checkpointing.

HPC nodes typically have fast node-local NVMe/SSD storage that is orders of magnitude faster than the shared filesystem. However, the current collective I/O approach in `CheckpointFile` cannot use node-local storage, because all ranks must access the same file path.

Approach

Following @connorjward's suggestion in #4891 (comment), the implementation uses a general `checkpoint_comm` parameter rather than a hardcoded per-rank approach. Users pass any MPI communicator to control how function data is checkpointed.

The function data is written using `PETSc.Vec.createMPI` + `ViewerHDF5` on the supplied communicator, bypassing `CheckpointFile` and its collective `globalVectorView`/`globalVectorLoad` on `COMM_WORLD`. We tried using `CheckpointFile` directly with a sub-communicator (see #4891 (comment)), but loading deadlocks because the mesh DM's `sectionLoad`/`globalVectorLoad` are collective on `COMM_WORLD`.

The mesh checkpoint via `checkpointable_mesh` still uses shared storage through `CheckpointFile`, since that is a one-time operation and not performance-critical. Fully backwards compatible: without `checkpoint_comm`, behaviour is unchanged.

Multi-mesh considerations

Functions on different meshes with different partitioning work correctly because Vec dataset names include the mesh name and element info (`ckpt_mesh_a_CG2` vs `ckpt_mesh_b_DG1`), and `checkpointable_mesh` ensures deterministic partitioning per mesh independently. The supermesh projection across two different meshes still fails in parallel, but that's a pre-existing limitation unrelated to this PR.

Testing

11 tests total covering three checkpointing modes:

- Existing shared-mode tests (5): serial and parallel basic checkpointing, successive writes, timestepper with taylor_test, and boundary conditions.
- `checkpoint_comm` with `COMM_SELF` (3): parallel basic checkpointing, successive writes (serial), and multi-mesh parallel. These exercise the per-rank file path where each rank writes independently.
- `checkpoint_comm` with a node communicator (3): parallel basic checkpointing, multi-mesh parallel, and timestepper with taylor_test. These exercise the multi-rank-per-file path using `COMM_TYPE_SHARED`.
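To make the three modes concrete, here is a small stdlib sketch of how the choice of `checkpoint_comm` maps ranks to checkpoint files, assuming 4 ranks and 2 ranks per node (the file-naming scheme is hypothetical, purely for illustration):

```python
def checkpoint_files(world_size, ranks_per_node, mode):
    """Map each world rank to the file its checkpoint_comm would write.

    mode:
      'shared'    - no checkpoint_comm: one collective file (current behaviour)
      'comm_self' - checkpoint_comm=COMM_SELF: one file per rank
      'node'      - node-local comm (COMM_TYPE_SHARED-style): one file per node
    """
    files = {}
    for rank in range(world_size):
        if mode == 'shared':
            files[rank] = "checkpoint.h5"
        elif mode == 'comm_self':
            files[rank] = f"rank{rank}.h5"
        elif mode == 'node':
            # Ranks on the same node share one file, as with comm.Split(rank // ranks_per_node)
            files[rank] = f"node{rank // ranks_per_node}.h5"
    return files

print(checkpoint_files(4, 2, 'comm_self'))
# {0: 'rank0.h5', 1: 'rank1.h5', 2: 'rank2.h5', 3: 'rank3.h5'}
print(checkpoint_files(4, 2, 'node'))
# {0: 'node0.h5', 1: 'node0.h5', 2: 'node1.h5', 3: 'node1.h5'}
```

The restore-side constraint follows directly from this mapping: the same rank-to-file layout must be reproduced on load, which is why the same number of ranks is required on restore.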