Keep the shared buffer read-only outside of explicit writes; add a
large-buffer test script.

Summary of this patch:
  * shmem.py: add _lock_memory() / _unlock_memory() and lock the numpy
    arrays once they are wrapped; __setitem__() and set() unlock only
    for the duration of the copy and always re-lock afterwards.
  * shmem.py: drop the "is not None" guards around the node communicator
    calls touched by these hunks.
  * test.py: test_large checks the buffer contents directly on every process.
  * test_scripts/large_buffer.py: new script that fills and checks a
    buffer of np.iinfo(np.int32).max - 10 bytes.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch.  Confirm context-line indentation
and the spacing inside string literals against the target tree before
applying.  The index lines still carry the original blob ids.

diff --git a/src/pshmem/shmem.py b/src/pshmem/shmem.py
index 5de4e7c..0a924de 100644
--- a/src/pshmem/shmem.py
+++ b/src/pshmem/shmem.py
@@ -175,6 +175,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
                 self._n,
                 dtype=self._dtype,
             )
             # Wrap
             self.data = self._flat.reshape(self._shape)
+            # Lock after the wrap: _lock_memory() returns early when data is None
+            self._lock_memory()
         else:
@@ -198,8 +200,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
                     print(msg, flush=True)
                     mem_err = 1
             # All ranks check for error
-            if self._nodecomm is not None:
-                mem_err = self._nodecomm.bcast(mem_err, root=0)
+            mem_err = self._nodecomm.bcast(mem_err, root=0)
             if mem_err != 0:
                 raise RuntimeError("Failed to allocate shared memory")
 
@@ -222,8 +223,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
                 msg += ": {}".format(e)
                 print(msg, flush=True)
                 mem_err = 1
-            if self._nodecomm is not None:
-                mem_err = self._nodecomm.allreduce(mem_err, op=MPI.SUM)
+            mem_err = self._nodecomm.allreduce(mem_err, op=MPI.SUM)
             if mem_err != 0:
                 raise RuntimeError("Failed to attach to shared memory")
 
@@ -242,12 +242,35 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
             if self._noderank == 0:
                 self._flat[:] = 0
 
             # Wrap
             self.data = self._flat.reshape(self._shape)
 
+            # Lock after the wrap: _lock_memory() returns early when data is None
+            self._lock_memory()
+
             # Wait for other processes to attach and wrap
-            if self._nodecomm is not None:
-                self._nodecomm.barrier()
+            self._nodecomm.barrier()
+
+    def _lock_memory(self):
+        """Ensure that the underlying memory has writeable=False.
+
+        Does nothing when no data array has been wrapped.
+        """
+        if self.data is None:
+            return
+        self._flat.flags.writeable = False
+        self.data.flags.writeable = False
+
+    def _unlock_memory(self):
+        """Temporarily set writeable to True.
+
+        The flat array is unlocked before the reshaped array derived from it.
+        Callers should restore the lock with _lock_memory() after writing.
+        """
+        if self.data is None:
+            return
+        self._flat.flags.writeable = True
+        self.data.flags.writeable = True
 
     def __del__(self):
         self.close()
@@ -276,7 +299,12 @@ def __setitem__(self, key, value):
             raise RuntimeError("Data size is zero- cannot assign elements")
         if self._comm is None:
             # shortcut for the serial case
-            self.data[key] = value
+            self._unlock_memory()
+            try:
+                self.data[key] = value
+            finally:
+                # Re-lock even if the assignment raises
+                self._lock_memory()
             return
         # WARNING: Using this function will have a performance penalty over using
         # the explicit 'set()' method, since this function must first communicate to
@@ -566,7 +594,12 @@ def set(self, data, offset=None, fromrank=0):
                 slc = tuple(dslice)
 
                 # Copy data slice
-                self.data[slc] = nodedata
+                self._unlock_memory()
+                try:
+                    self.data[slc] = nodedata
+                finally:
+                    # Re-lock even if the copy raises
+                    self._lock_memory()
 
                 # Delete the temporary copy
                 del nodedata
@@ -577,7 +610,12 @@ def set(self, data, offset=None, fromrank=0):
             for d in range(ndims):
                 dslice.append(slice(offset[d], offset[d] + data.shape[d], 1))
             slc = tuple(dslice)
-            self.data[slc] = data
+            self._unlock_memory()
+            try:
+                self.data[slc] = data
+            finally:
+                # Re-lock even if the copy raises
+                self._lock_memory()
 
         # Explicit barrier here, to ensure that other processes do not try
         # reading data before the writing processes have finished.
diff --git a/src/pshmem/test.py b/src/pshmem/test.py
index 5586fe1..42534b7 100644
--- a/src/pshmem/test.py
+++ b/src/pshmem/test.py
@@ -247,17 +247,12 @@ def test_large(self):
             shm.set(local, fromrank=0)
             del local
 
-            # Check results on all processes. Take turns since this is a
-            # large memory buffer.
-            for proc in range(nproc):
-                if proc == rank:
-                    check = np.zeros(datadims, dtype=datatype)
-                    check[:] = shm[:]
-                    count = np.count_nonzero(check)
-                    self.assertTrue(count == n_elem)
-                    del check
-                if self.comm is not None:
-                    self.comm.barrier()
+            # Check results on all processes.
+            count = np.count_nonzero(shm[:])
+            self.assertTrue(count == n_elem)
+
+            if self.comm is not None:
+                self.comm.barrier()
 
     def test_separated(self):
         if self.comm is None:
diff --git a/test_scripts/large_buffer.py b/test_scripts/large_buffer.py
new file mode 100755
index 0000000..badea75
--- /dev/null
+++ b/test_scripts/large_buffer.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""Manual check of MPIShared using a buffer just below the int32 size limit."""
+
+import time
+from mpi4py import MPI
+
+import subprocess as sp
+import numpy as np
+
+from pshmem import MPIShared
+
+
+def shmem_stat(msg, limits=False):
+    """Print msg together with the segment count derived from `ipcs -m`.
+
+    Args:
+        msg (str): Text that prefixes the report.
+        limits (bool): If True, also append lines from `ipcs -lm`.
+    """
+    spalloc = sp.check_output(["ipcs", "-m"], universal_newlines=True)
+    # NOTE(review): the 5 presumably discounts ipcs header/footer lines; confirm
+    n_segment = len(spalloc.split("\n")) - 5
+    if limits:
+        spout = sp.check_output(["ipcs", "-lm"], universal_newlines=True)
+        msg += ":\n"
+        for line in spout.split("\n")[2:-2]:
+            msg += f" {line}\n"
+        msg += f" {n_segment} allocated segments"
+        print(f"{msg}")
+    else:
+        print(f"{msg}: {n_segment} allocated segments")
+
+
+def main():
+    """Fill a large shared array from rank 0 and verify it on every process."""
+    comm = MPI.COMM_WORLD
+    rank = comm.rank
+
+    # Dimensions / type of our shared memory array.  Swap in
+    # np.iinfo(np.int32).max + 10 to go just past the int32 limit instead.
+    n_elem = np.iinfo(np.int32).max - 10
+    datadims = (n_elem,)
+    datatype = np.dtype(np.uint8)
+
+    shmem_stat(f"Proc {rank} Starting state", limits=True)
+
+    # Create local data on one process
+    if rank == 0:
+        print(f"Creating large array of {n_elem} bytes", flush=True)
+        local = np.ones(datadims, dtype=datatype)
+    else:
+        local = None
+
+    with MPIShared(datadims, datatype, comm) as shm:
+        sptr = id(shm._shmem.buf.obj)
+        print(f"Proc {rank} address = {sptr}", flush=True)
+
+        shm.set(local, fromrank=0)
+        del local
+
+        sptr = id(shm._shmem.buf.obj)
+        print(f"Proc {rank} after set, address = {sptr}", flush=True)
+
+        shmem_stat(f"Proc {rank} inside context", limits=True)
+        time.sleep(10)
+
+        # Check results on all processes.
+        count = np.count_nonzero(shm[:])
+        if count != n_elem:
+            print(f"Rank {rank} got {count} non-zeros, not {n_elem}", flush=True)
+
+        comm.barrier()
+
+    shmem_stat(f"Proc {rank} Ending state", limits=True)
+
+
+if __name__ == "__main__":
+    main()