
Add per-rank disk checkpointing for adjoint tape #4891

Merged
JHopeCollins merged 36 commits into firedrakeproject:main from sghelichkhani:sghelichkhani/per-rank-disk-checkpointing
Mar 5, 2026

Conversation

@sghelichkhani
Contributor

sghelichkhani commented Feb 15, 2026

Motivation

We run time-dependent adjoint Stokes simulations with close to a billion degrees of freedom per timestep. Recomputation-based checkpointing schedules (revolve/binomial) are infeasible due to the cost of recomputing the Stokes solve, so disk checkpointing (SingleDiskStorageSchedule) is the only viable option.

Currently, CheckpointFile writes all ranks to a single shared HDF5 file via parallel HDF5 (PETSc.ViewerHDF5 on COMM_WORLD). On HPC systems, this means all checkpoint I/O goes through the shared parallel filesystem (Lustre/GPFS), which becomes a severe bottleneck. Under 24-hour job time limits, the I/O overhead makes simulations that comfortably fit the wall-time budget with memory checkpointing infeasible once they switch to disk checkpointing.

HPC nodes typically have fast node-local NVMe/SSD storage that is orders of magnitude faster than the shared filesystem. However, the current collective I/O approach in CheckpointFile cannot use node-local storage because all ranks must access the same file path.

Approach

Following @connorjward's suggestion in #4891 (comment), the implementation uses a general checkpoint_comm parameter rather than a hardcoded per-rank approach. Users pass any MPI communicator to control how function data is checkpointed:

# Per-rank files (each rank writes independently to node-local storage)
enable_disk_checkpointing(checkpoint_comm=MPI.COMM_SELF,
                          checkpoint_dir="/local/scratch")

# Per-node files (ranks on the same node share a file)
node_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
enable_disk_checkpointing(checkpoint_comm=node_comm,
                          checkpoint_dir="/local/scratch")

The function data is written using PETSc.Vec.createMPI + ViewerHDF5 on the supplied communicator, bypassing CheckpointFile and its collective globalVectorView/globalVectorLoad on COMM_WORLD. We tried using CheckpointFile directly with a sub-communicator (see #4891 (comment)), but loading deadlocks because the mesh DM's sectionLoad/globalVectorLoad are collective on COMM_WORLD.
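
A minimal sketch of that write path, assuming petsc4py with HDF5 support; the function name save_function_data and the dataset naming are illustrative, not the PR's actual code:

from firedrake.petsc import PETSc

def save_function_data(f, filepath, checkpoint_comm, dataset_name):
    # Mirror this rank's local dof array into a Vec living on the
    # sub-communicator, then dump it through an HDF5 viewer that is
    # collective only over checkpoint_comm, not COMM_WORLD.
    with f.dat.vec_ro as gvec:
        vec = PETSc.Vec().createMPI((gvec.getLocalSize(), PETSc.DECIDE),
                                    comm=checkpoint_comm)
        vec.setName(dataset_name)
        vec.setArray(gvec.array_r)
        viewer = PETSc.ViewerHDF5().create(filepath, mode="a",
                                           comm=checkpoint_comm)
        vec.view(viewer)
        viewer.destroy()
        vec.destroy()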

The mesh checkpoint via checkpointable_mesh still uses shared storage through CheckpointFile since that's a one-time operation and not performance-critical. Fully backwards compatible: without checkpoint_comm, behaviour is unchanged.

Multi-mesh considerations

Functions on different meshes with different partitioning work correctly because Vec dataset names include the mesh name and element info (ckpt_mesh_a_CG2 vs ckpt_mesh_b_DG1), and checkpointable_mesh ensures deterministic partitioning per mesh independently.
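
For illustration, a hypothetical helper matching the naming pattern above (the real construction lives in firedrake/checkpointing.py and may differ):

def checkpoint_vec_name(mesh_name, family, degree):
    # Illustrative only: encodes mesh identity and element info so that
    # functions from differently partitioned meshes never collide.
    return f"ckpt_{mesh_name}_{family}{degree}"

assert checkpoint_vec_name("mesh_a", "CG", 2) == "ckpt_mesh_a_CG2"
assert checkpoint_vec_name("mesh_b", "DG", 1) == "ckpt_mesh_b_DG1"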

The supermesh projection across two different meshes still fails in parallel, but that's a pre-existing limitation unrelated to this PR.

Testing

11 tests total covering three checkpointing modes:

Existing shared-mode tests (5): serial and parallel basic checkpointing, successive writes, timestepper with taylor_test, and boundary conditions.

checkpoint_comm with COMM_SELF (3): parallel basic checkpointing, successive writes (serial), and multi-mesh parallel. These exercise the per-rank file path where each rank writes independently.

checkpoint_comm with node communicator (3): parallel basic checkpointing, multi-mesh parallel, and timestepper with taylor_test. These exercise the multi-rank-per-file path using COMM_TYPE_SHARED.

Enable each MPI rank to write its adjoint checkpoint data to its own
HDF5 file using PETSc Vec I/O on COMM_SELF. This avoids parallel HDF5
overhead and enables use of fast node-local storage (NVMe/SSD) on HPC
systems, where shared filesystem I/O is a major bottleneck for
large-scale time-dependent adjoint computations.

New parameter `per_rank_dirname` on `enable_disk_checkpointing()`.
When set, function data is checkpointed per-rank while mesh data
(via `checkpointable_mesh`) remains on shared storage. Requires same
number of ranks on restore (inherent in adjoint workflows).
@connorjward
Contributor

In PyOP2 we have the notion of a 'compilation comm' which is a communicator defined over each node (https://github.com/firedrakeproject/firedrake/blob/main/pyop2/mpi.py#L450). Might something like this be appropriate/more general here?

@sghelichkhani
Contributor Author

Thanks Connor, that's a great idea. I hadn't considered using the compilation comm pattern here.

I did look into what a node-local comm approach would involve. The main challenge is that CheckpointFile.save_function uses topology_dm.globalVectorView, which is collective over the mesh's communicator (COMM_WORLD). So you can't simply pass a node-local sub-comm to CheckpointFile. The mesh DM lives on COMM_WORLD and you'd get a comm mismatch. That means even with a node-local comm, you'd still need to bypass CheckpointFile and use raw PETSc Vec I/O, just with a parallel Vec on the node comm instead of a sequential Vec on COMM_SELF. The overall complexity ends up being similar, with extra overhead for sub-communicator lifecycle management and intra-node coordination. The benefit is fewer files (N_nodes vs N_ranks).

From my perspective on Gadi, I have 48 cores per node writing to node-local SSDs, and per-rank I/O is completely manageable. My thinking here is that this is specifically an adjoint solution. Disk checkpointing for the adjoint tape is extremely I/O heavy, so the less communicator overhead involved, the better. I'd expect the COMM_SELF approach to actually be faster in practice since every rank operates completely independently with zero coordination or collective operations, even within a node.

That said, this is ultimately a decision for the Firedrake folks on what works best for general users. If a node-local comm approach is preferred, it's doable. Happy to refactor if that's the direction you'd like to go.

@sghelichkhani
Contributor Author

@connorjward Following up on your suggestion about using the compilation communicator: Angus and I discussed this and tried simply doing CheckpointFile(fname, mode, comm=compilation_comm) and using the standard save_function/load_function path with a node-level communicator.

Saving works fine, but loading deadlocks. Here's a minimal reproducer:

"""mpiexec -n 4 python test_subcomm_checkpoint.py"""
import os
import tempfile
from firedrake import *
from firedrake.checkpointing import CheckpointFile

comm = COMM_WORLD
mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
f = Function(V, name="f")
f.interpolate(SpatialCoordinate(mesh)[0])

node_comm = comm.Split(color=comm.rank // 2, key=comm.rank)

if comm.rank == 0:
    tmpdir = tempfile.mkdtemp()
else:
    tmpdir = None
tmpdir = comm.bcast(tmpdir, root=0)
fname = os.path.join(tmpdir, f"node{comm.rank // 2}.h5")

with CheckpointFile(fname, 'w', comm=node_comm) as out:
    out.save_mesh(mesh)
    out.save_function(f)

with CheckpointFile(fname, 'r', comm=node_comm) as inp:
    mesh2 = inp.load_mesh()
    f2 = inp.load_function(mesh2, "f")  # deadlocks here

The issue is that sectionLoad and globalVectorLoad are collective operations on the topology_dm, which lives on COMM_WORLD. When only a subset of ranks (the node sub-comm) enter the load call, the collective never completes. I think this is by design rather than a bug in CheckpointFile, since the contract is that the mesh lives on COMM_WORLD for the parallel solves, and the load path assumes the viewer and the mesh DM share a compatible communicator. Making this work with a sub-communicator would require changes both in PETSc (cross-comm DM operations) and in Firedrake's checkpointing internals.

That said, I might be missing something. Is there a way to make this work that I'm not seeing?

@connorjward
Contributor

That said, I might be missing something. Is there a way to make this work that I'm not seeing?

I don't think so. I'd have been surprised had that worked.

That means even with a node-local comm, you'd still need to bypass CheckpointFile and use raw PETSc Vec I/O, just with a parallel Vec on the node comm instead of a sequential Vec on COMM_SELF.

This is what I'm suggesting. It doesn't seem like a lot of work to change your API from per_rank-type options to a more general filesystem_comm=some_comm type of thing. Your use case could still do filesystem_comm=COMM_SELF if desired.

@angus-g
Contributor

angus-g commented Feb 17, 2026

That said, I might be missing something. Is there a way to make this work that I'm not seeing?

I don't think so. I'd have been surprised had that worked.

Sorry, that was probably my misleading advice. My reasoning was just that it'd be nice to reduce complexity by leveraging the existing code for viewer set up etc. (and allow pyop2 to set up the comm in the first place).

It kind of seems like a CheckpointFile should be able to take a different comm and work, or at least it should be documented in which cases it can be something other than COMM_WORLD... I get that there are complexities around mesh topology so maybe it's a silly way of thinking in the first place.

Refactors the per_rank_dirname parameter into a more general
checkpoint_comm + checkpoint_dir interface, following reviewer
feedback. Instead of hardcoding COMM_SELF, users now pass any
MPI communicator (COMM_SELF for per-rank files, a node-local
comm for per-node files, etc.). The PETSc Vec I/O uses createMPI
on the supplied communicator rather than createSeq on COMM_SELF,
making the approach work for arbitrary communicator topologies.

Removes three serial-only checkpoint_comm tests that are fully
covered by their parallel counterparts and adds node_comm tests
that exercise the multi-rank-per-file path using COMM_TYPE_SHARED.
@sghelichkhani
Contributor Author

Thanks Connor, done. I've refactored the API from per_rank_dirname to a general checkpoint_comm + checkpoint_dir interface. Users can now pass any communicator, so COMM_SELF for per-rank files or a node-local comm via COMM_TYPE_SHARED for per-node files.

The main PETSc-level change this required was switching from Vec.createSeq (which only works on a single process) to Vec.createMPI((local_size, PETSc.DECIDE), comm=checkpoint_comm). The DECIDE for the global size is important here because when checkpoint_comm groups multiple ranks (e.g. a node communicator), each rank contributes a different number of local DOFs depending on how the mesh was partitioned. The global size of the Vec on the checkpoint communicator is not something we know upfront since it's the sum of local DOFs across ranks in that sub-comm, which is a different grouping than the mesh's COMM_WORLD partitioning. Letting PETSc compute the global size from the local sizes avoids us having to gather that information ourselves.
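
A small runnable illustration of that sizing behaviour, with made-up local sizes standing in for per-rank dof counts:

from mpi4py import MPI
from petsc4py import PETSc

world = MPI.COMM_WORLD
node_comm = world.Split_type(MPI.COMM_TYPE_SHARED)

local_size = 30 + world.rank  # stand-in for this rank's dof count
vec = PETSc.Vec().createMPI((local_size, PETSc.DECIDE), comm=node_comm)

# PETSc derived the sub-comm Vec's global size as the sum of the local
# sizes within node_comm, which we never had to compute ourselves.
assert vec.getSize() == node_comm.allreduce(local_size, op=MPI.SUM)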

I've also added tests with COMM_TYPE_SHARED to exercise the multi-rank-per-file path and trimmed redundant serial tests that were fully covered by their parallel counterparts.

Contributor

connorjward left a comment


Seems alright to me. It would definitely be good to get some feedback from @JHopeCollins, who has done similar comm wrangling for ensemble.

@connorjward
Contributor

It kind of seems like a CheckpointFile should be able to take a different comm and work, or at least it should be documented in which cases it can be something other than COMM_WORLD... I get that there are complexities around mesh topology so maybe it's a silly way of thinking in the first place.

It's definitely an interesting question. Conceptually I think it should be possible to checkpoint a DMPlex to multiple files but its far from trivial. An added complication is that we would have to preserve the N-to-M checkpointing behaviour (i.e. reading and writing with different numbers of ranks).

Add FutureWarning to deprecated new_checkpoint_file method. Use
isinstance(mesh, ufl.MeshSequence) instead of hasattr check. Replace
COMM_TYPE_SHARED tests with comm.Split(rank // 2) to guarantee a
communicator with 1 < size < the COMM_WORLD size in the 3-rank test.
Extract _generate_function_space_name from CheckpointFile into a
module-level function in firedrake/checkpointing.py and reuse it
for the checkpoint_comm Vec naming instead of maintaining a separate
_generate_checkpoint_vec_name. The free function also handles
MeshSequenceGeometry defensively. CheckpointFile method delegates
to it.
The multi-mesh tests chained two PDE solves via assemble(u_a * dx), a
global reduction whose floating-point result can vary across parallel
runs due to reduction ordering. This made the J == Jnew assertion flaky
at the np.allclose tolerance boundary. Make the mesh_b solve independent
and drop the redundant memory baseline comparison.
@sghelichkhani
Contributor Author

sghelichkhani commented Feb 18, 2026

Simplified the multi-mesh tests to fix intermittent CI failures. The original design chained two solves via a global reduction (assemble(u_a * dx)), which amplified parallel floating-point non-determinism across tape replays. The two solves are now independent while still doing multi-mesh checkpointing.

Member

JHopeCollins left a comment


My gut feeling is that this should be part of CheckpointFile rather than hidden in the adjoint utils.

I don't think I fully understand the problem with checkpointing the mesh using the global comm but checkpointing the function data using a subcomm. Why does that mean you have to split the DM into multiple files?

@connorjward
Contributor

I don't think I fully understand the problem with checkpointing the mesh using the global comm but checkpointing the function data using a subcomm. Why does that mean you have to split the DM into multiple files?

You can't split the DM into multiple files. That's why this only works for functions.

@JHopeCollins
Member

JHopeCollins commented Feb 19, 2026

The main thing is that I would like to see the local save/load logic abstracted out of the adjoint code somehow, because the adjoint code really shouldn't be thinking about actual concrete data.

I can think of four potential ways to do this. Happy to hear arguments for/against each one.

In each case we'd obviously need to be very explicit about the restrictions of saving/loading locally, i.e. you must have exactly the same partition for saving and loading, so it's basically only for saving/loading during the same programme (and for option 3 you also can't save/load the mesh with this class).

  1. What about dm.localVectorView and dm.localVectorLoad? It looks like they will handle splitting the data apart, which seems to do what we need "through the front door". I don't know if they can save the local data to local files though.
    https://petsc.org/release/manualpages/DMPlex/DMPlexLocalVectorView/
    https://petsc.org/release/manualpages/DMPlex/DMPlexLocalVectorLoad/

  2. In CheckpointFile, only use dm.globalVector{View,Load} if we are checkpointing to a global file, otherwise ignore those functions and save the Vec data to local files with your new logic. This is essentially the same as 3, but we add the switches in CheckpointFile.

  3. A new class, something like LocalFunctionCheckpointFile (happy to bikeshed this) that has the logic for saving/loading Functions to local files on subcomms. If we really don't trust users to use the local saving/loading properly then we could do this one and just not advertise the new class in the public API.

  4. Writing a python-type viewer context for CheckpointFile.viewer that internally creates the current global HDF5 viewer and delegates to that one except if it's being asked to view/load a Vec, in which case it uses the new logic in the adjoint_utils for saving/viewing locally. So dm.globalVector{View,Load} would end up being diverted to the comm-local implementation.
    Hopefully this would just be implementing __init__, view and load methods in the python context (and maybe a blank setUp). But this option is probably trying to be too clever and overkill for what we need.

@connorjward
Contributor

The main thing is that I would like to see the local save/load logic abstracted out of the adjoint code somehow, because the adjoint code really shouldn't be thinking about actual concrete data.

This is a fair point.

What about dm.localVectorView and dm.localVectorLoad?

I've read the implementation and I'm still confused about the difference between these. In PETSc the term 'local' applies to lots of different things so this may not do quite what you expect.

I wonder if this is basically reimplementing DumbCheckpoint again.

We should discuss this in today's meeting.

@connorjward
Contributor

@sghelichkhani from today's meeting we decided that we want this functionality exposed as a new CheckpointFile-ish type living in checkpointing.py. This is basically point 3 from @JHopeCollins above.

I quite like something like EphemeralFunctionCheckpointFile or similar. It would be cool if by default they could be destroyed at program exit to avoid misuse.

@JHopeCollins
Member

JHopeCollins commented Feb 19, 2026

I quite like something like EphemeralFunctionCheckpointFile or similar.

How about TemporaryCheckpointFile, or TemporaryFunctionCheckpointFile, to mirror the tempfile naming and the common naming of a /tmp working directory.

It would be cool if by default they could be destroyed at program exit to avoid misuse.

@sghelichkhani is it possible to use tempfile.TemporaryFile and tempfile.TemporaryDirectory rather than tempfile.mkstemp and tempfile.mkdtemp so that the cleanup is automated?
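
For reference, a minimal standard-library sketch of the difference:

import tempfile

# mkdtemp leaves the directory behind; the caller must remove it later.
path = tempfile.mkdtemp()

# TemporaryDirectory removes the directory and its contents when the
# with-block exits (or when .cleanup() is called / at finalization).
with tempfile.TemporaryDirectory() as tmpdir:
    pass  # write checkpoint files under tmpdir here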

@sghelichkhani
Contributor Author

Note on _checkpoint_indices: this class-level dict on CheckpointFunction is never pruned. File entries persist after the HDF5 file is deleted. The local path (TemporaryFunctionCheckpointFile._indices) does not have this issue since remove_file handles cleanup. Not fixing this to avoid the risk of pruning entries before restore() reads them. The memory cost is negligible.

leo-collins and others added 4 commits February 21, 2026 22:37
…ject#4865)

* add log event markers

* build spatial index using CreateWithArray
Move PETSc Vec I/O into TemporaryFunctionCheckpointFile in checkpointing.py.
Rename save/restore methods, fix deprecation warning, remove redundant fixture
and forwarding method, clean up imports.
Member

JHopeCollins left a comment


Hi Siya,

Thanks for your effort on this! The current TemporaryFunctionCheckpointFile is a really good start on separating out the "save to local not shared" storage abstraction from the adjoint code, I think we're converging on the right solution.

I've left comments on specific parts, but they are predominantly about two main suggestions which I think will complete the separation of concerns between the checkpointing and the adjoint, and hopefully make the code simpler by removing a bit of repetition.

  1. The most important point is that, although the logic of actually saving/loading from local storage is now totally out of the adjoint code, it has taken some of the "how to checkpoint along the tape" logic with it (i.e. the _indices stuff).
    This calculation of the stored name/index looks identical for both checkpoint types, so it should be done just once by the adjoint code, and then passed to whatever checkpoint file type is being used.
    The same is true for resetting the name/count of the Function after loading - this is the responsibility of the adjoint not the checkpoint.

  2. The second is that it looks like TemporaryFunctionCheckpointFile is actually AccessAnyTemporaryFunctionCheckpointFile because you have to pass the file name every time you want to save/load, rather than having a 1-to-1 between an instance of this class and an hdf5 file, like we have for CheckpointFile.
    If the TemporaryFunctionCheckpointFile holds onto one file like the CheckpointFile does then I think you can make the save_function signature identical between the two, and you can make the load_function signature almost identical - TemporaryFunctionCheckpointFile will take a FunctionSpace rather than a Mesh as the first argument but the rest will be the same.

@JHopeCollins
Member

JHopeCollins commented Feb 26, 2026

Just to point out a nice-to-have-but-not-essential extra if you do both suggestions above.
If you just want this PR finished and merged that's absolutely fine, I can put this in an issue. I just want to write it down while it's in my head!

It would be great to allow the TemporaryFunctionCheckpointFile to be used as a context manager like the CheckpointFile. I think that would mean that you'd need barely any changes to the adjoint code, it would almost all be in the firedrake/checkpointing.py.

You'd just need to add these methods to the class, like CheckpointFile (see comment further down for the close method):

	def __enter__(self):
		return self

	def __exit__(self, *args):
		self.close()

Then in the adjoint code you'd store both the checkpoint class type and the comm to make the checkpoint over on the CheckpointFileReference

To save the Function in CheckpointFunction.__init__ you would have:

	# work out stored_index and stored_name
	...
	with self.file.CheckpointClass(self.file.name, 'a', comm=self.file.checkpoint_comm) as outfile:
		outfile.save_function(function, name=stored_name, idx=stored_index)

And for restoring the Function you could have:

def restore(self):
    # CheckpointClass is a class, so dispatch with issubclass rather
    # than isinstance
    if issubclass(self.file.CheckpointClass, TemporaryFunctionCheckpointFile):
        base = self.function_space
    else:
        base = self.mesh

    with self.file.CheckpointClass(self.file.name, 'r', comm=self.file.checkpoint_comm) as infile:
        function = infile.load_function(base, self.stored_name, self.stored_index)
    return type(function)(function.function_space(),
                          val=function.dat, name=self.name, count=self.count)

Potentially we could allow CheckpointFile.load_function to take a FunctionSpace as the first argument rather than a Mesh, and then check that the loaded Function has a matching space. Then the CheckpointFunction.restore method would simply be:

def restore(self):
	with self.file.CheckpointClass(self.file.name, 'r', comm=self.file.checkpoint_comm) as infile:
		function = infile.load_function(self.function_space, self.stored_name, self.stored_index)
	return type(function)(function.function_space(),
						  val=function.dat, name=self.name, count=self.count)

Make it 1-to-1 with a single file/mode (context manager, viewer held open).
Move index tracking into CheckpointFunction.__init__. Move directory
management into DiskCheckpointer. Prune _checkpoint_indices in
CheckPointFileReference.__del__. Add pruning test.
@sghelichkhani
Contributor Author

sghelichkhani commented Mar 1, 2026

This ended up being quite a bit of work :), but I think the result is much cleaner.

The main change is that TemporaryFunctionCheckpointFile is now 1-to-1 with a single HDF5 file and a single mode, rather than a manager that creates and tracks multiple files. The constructor now takes (comm, filepath, mode) and opens the PETSc HDF5 viewer immediately, holding it for the object's lifetime. The class is now a proper context manager with __enter__/__exit__/__del__, so save and restore use it exactly like CheckpointFile: open a file, do one operation, close.

The internal _indices dict and new_file()/remove_file() are gone. Index tracking -- computing stored_name and stored_index -- now happens once in CheckpointFunction.__init__ and is shared by both the shared and local checkpoint paths. Previously this logic was duplicated in _save_shared and inside TemporaryFunctionCheckpointFile.save_function. Now _save_shared and _save_local_checkpoint are structurally parallel: both open their respective file class as a context manager, call save_function with the pre-computed name and index, and close.

Directory management (creating the temp directory for local checkpoint files) has moved out of TemporaryFunctionCheckpointFile and into DiskCheckpointer. One thing worth noting there: the bcast for the local directory uses checkpoint_comm, not comm. Only ranks within the same checkpoint_comm group share a local filesystem, so broadcasting on COMM_WORLD would be wrong -- each group creates its own subdirectory independently.
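
A sketch of that broadcast; the names make_local_checkpoint_dir, checkpoint_dir and checkpoint_comm are illustrative stand-ins for the actual attributes:

import tempfile

def make_local_checkpoint_dir(checkpoint_comm, checkpoint_dir):
    # Each checkpoint_comm group shares a filesystem, so rank 0 of the
    # *group* creates the directory and broadcasts the path within the
    # group only; a COMM_WORLD bcast would hand ranks on other nodes a
    # path that does not exist on their local disk.
    if checkpoint_comm.rank == 0:
        path = tempfile.mkdtemp(dir=checkpoint_dir)
    else:
        path = None
    return checkpoint_comm.bcast(path, root=0)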

CheckPointFileReference loses the temp_ckpt reference and gains checkpoint_comm instead. The __del__ method picks the right communicator for file cleanup automatically: checkpoint_comm.rank == 0 for local files, comm.rank == 0 for shared ones.

@sghelichkhani
Contributor Author

sghelichkhani commented Mar 1, 2026

I also took the opportunity to fix the long-standing issue with _checkpoint_indices never being pruned. This is a class-level dict on CheckpointFunction that maps filepath -> {stored_name -> count}, and previously it accumulated entries for every checkpoint file ever used without ever removing them. Over a long adjoint run with many tape advances this is a slow but real memory leak.

The fix is one line in CheckPointFileReference.__del__: pop the entry for self.name from CheckpointFunction._checkpoint_indices. The reason this is safe comes down to one invariant: CheckpointFunction holds self.file as a direct strong reference (not a weak one), so __del__ on the CheckPointFileReference can only fire after every CheckpointFunction that wrote to that filepath has already been garbage-collected. And restore() never reads _checkpoint_indices; it uses stored_name and stored_index baked into the CheckpointFunction instance at save time, so there is no way a restore call can race with the pruning.
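
A runnable sketch of the one-line fix, with a stub standing in for the real CheckpointFunction:

class CheckpointFunction:          # stub standing in for the real class
    _checkpoint_indices = {}       # filepath -> {stored_name: count}

class CheckPointFileReference:
    def __init__(self, name):
        self.name = name

    def __del__(self):
        # Safe by the invariant above: every CheckpointFunction holds a
        # strong reference to its CheckPointFileReference, so this
        # finalizer only runs after all writers of self.name are gone.
        CheckpointFunction._checkpoint_indices.pop(self.name, None)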

The one case that required careful thought was revolve-style schedules, where the tape can snapshot its state and later restore it, causing new CheckpointFunction objects to be written to an old file. This is safe because the tape's checkpoint store holds the CheckPointFileReference alive by reference until re-execution is complete, keeping the reference count above zero and preventing __del__ from firing prematurely. I have added a test (test_checkpoint_indices_pruning) that verifies the entries are actually removed once both reference chains, the DiskCheckpointer's and the tape's, are released.

@sghelichkhani
Contributor Author

On the question of whether this should live inside CheckpointFile: I think the answer is no, and the docstring tries to make the reason explicit. The two classes have fundamentally different contracts.

CheckpointFile is collective on COMM_WORLD, understands mesh topology, and loads functions by reconstructing them from the DM, which is why load_function takes a mesh. TemporaryFunctionCheckpointFile is collective only on the sub-communicator, knows nothing about topology, and loads functions by filling in a Vec given a FunctionSpace the caller already holds. It takes a FunctionSpace precisely because it has no way to reconstruct one from scratch.

The reason these contracts cannot be unified is the deadlock I showed earlier in the PR. The sectionLoad and globalVectorLoad calls inside CheckpointFile.load_function are collective on the mesh DM's communicator, which is COMM_WORLD. If only a subset of ranks (those in a node-local sub-comm) enter that call, the collective never completes. This is not a bug in CheckpointFile, it is a consequence of the mesh DM living on COMM_WORLD by design, because the parallel solves need it there. Making CheckpointFile work with a sub-communicator for function data while keeping the mesh on COMM_WORLD would require cross-communicator DM operations that PETSc does not support. So the separation is not just cosmetic, it reflects a real constraint.

@sghelichkhani
Contributor Author

@JHopeCollins I think I have addressed all your points now. Let me know if anything needs more work.

Member

JHopeCollins left a comment


Great, thanks so much for thoroughly addressing our reviews! It's looking good now and I think it's basically ready to go.

I think the result is much cleaner.

I'm glad you agree. The internal implementation hasn't changed but I think that the separation of concerns is clear now between how we interact with storage and how we do book-keeping for the tape.

I spotted one thing that absolutely needs to be fixed. It was a bug before this PR, but it's related and it's only a few lines: the CheckpointFile must be passed the right comm! Currently it assumes COMM_WORLD, which may not always be correct.

Other than that, there's one comment about not having to modify CheckpointFileReference that I think would reduce the diff, but you can take that or leave it.

Once the CheckpointFile is passed the right comm I'm very happy to approve this!

@JHopeCollins
Member

JHopeCollins commented Mar 3, 2026

@sghelichkhani The CI is failing on the gusto smoke tests and an IO timeout.

  1. The gusto failures are known and unrelated so you can ignore them. There's a gusto PR to fix it but it shouldn't hold up this PR.
  2. The IO timeout has been fixed now so if you merge main it should go away. The timeout prevents the rest of the tests being run so it's important to make sure this passes.

sghelichkhani and others added 3 commits March 3, 2026 22:35
Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
…m:sghelichkhani/firedrake into sghelichkhani/per-rank-disk-checkpointing
@sghelichkhani
Contributor Author

Applied your suggestion on CheckPointFileReference: dropped _cleanup_comm and now pass checkpoint_comm as comm directly in _new_checkpoint_comm_file, so comm always means "the comm whose rank 0 owns the file" without any derived attribute.

Also caught two more instances of the missing comm bug in checkpointable_mesh (the 'a' and 'r' opens on the init checkpoint file, around line 377 before this patch). Both now pass comm=checkpoint_file.comm. @JHopeCollins, those were covered by your suggestion on _new_shared_checkpoint_file in spirit but not literally, in case you want to double-check.

@sghelichkhani
Contributor Author

sghelichkhani commented Mar 3, 2026

The CI is failing on the gusto smoke tests and an IO timeout.
1. The gusto failures are known and unrelated so you can ignore them. There's a gusto PR to fix it but it shouldn't hold up this PR.
2. The IO timeout has been fixed now so if you merge main it should go away. The timeout prevents the rest of the tests being run so it's important to make sure this passes.
@JHopeCollins Let's hope the CI is fixed and this goes green (other than gusto)

Store self.comm on CheckPointFileReference so that the CheckpointFile
comm fixes suggested in review work correctly. Drops the
now-redundant _cleanup_comm. Also fixes two missing comm= arguments in
checkpointable_mesh.
Member

JHopeCollins left a comment


Applied your suggestion on CheckPointFileReference

Thanks!

Also caught two more instances of the missing comm bug in checkpointable_mesh

Great, good spot on those, it looks right to me.

This all looks good to me now. Thanks for working with us through multiple reviews, I'm very happy with where this has ended up.

JHopeCollins merged commit 2d3aa95 into firedrakeproject:main on Mar 5, 2026
7 checks passed
dsroberts added a commit to dsroberts/firedrake that referenced this pull request Apr 7, 2026
firedrakeproject#4891 introduced the use of tempfile.TemporaryDirectory with the
`delete` parameter. This was introduced in Python 3.12, and therefore
this commit breaks compatibility with Python 3.10 and 3.11.
pyproject.toml states that Python >= 3.10 is supported. Due to OS
constraints, we'll be stuck on Python 3.11 on our HPC systems for some
time.
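
A version-guarded sketch of the incompatibility (the PR's actual call site may differ):

import sys
import tempfile

if sys.version_info >= (3, 12):
    # The `delete` keyword argument only exists from Python 3.12.
    tmp = tempfile.TemporaryDirectory(delete=False)
    path = tmp.name
else:
    # Pre-3.12 equivalent: create the directory without auto-cleanup.
    path = tempfile.mkdtemp()
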
connorjward pushed a commit that referenced this pull request Apr 9, 2026
#4891 introduced the use of tempfile.TemporaryDirectory with the
`delete` parameter. This was introduced in Python 3.12, and therefore
this commit breaks compatibility with Python 3.10 and 3.11.
pyproject.toml states that Python >= 3.10 is supported. Due to OS
constraints, we'll be stuck on Python 3.11 on our HPC systems for some
time.