Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/pshmem/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
self._n,
dtype=self._dtype,
)
self._lock_memory()
# Wrap
self.data = self._flat.reshape(self._shape)
else:
Expand All @@ -198,8 +199,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
print(msg, flush=True)
mem_err = 1
# All ranks check for error
if self._nodecomm is not None:
mem_err = self._nodecomm.bcast(mem_err, root=0)
mem_err = self._nodecomm.bcast(mem_err, root=0)
if mem_err != 0:
raise RuntimeError("Failed to allocate shared memory")

Expand All @@ -222,8 +222,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
msg += ": {}".format(e)
print(msg, flush=True)
mem_err = 1
if self._nodecomm is not None:
mem_err = self._nodecomm.allreduce(mem_err, op=MPI.SUM)
mem_err = self._nodecomm.allreduce(mem_err, op=MPI.SUM)
if mem_err != 0:
raise RuntimeError("Failed to attach to shared memory")

Expand All @@ -242,12 +241,28 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._noderank == 0:
self._flat[:] = 0

# Lock
self._lock_memory()

# Wrap
self.data = self._flat.reshape(self._shape)

# Wait for other processes to attach and wrap
if self._nodecomm is not None:
self._nodecomm.barrier()
self._nodecomm.barrier()

def _lock_memory(self):
"""Ensure that the underlying memory has writeable=False"""
if self.data is None:
return
self._flat.flags.writeable = False
self.data.flags.writeable = False

def _unlock_memory(self):
"""Temporarily set writeable to True"""
if self.data is None:
return
self._flat.flags.writeable = True
self.data.flags.writeable = True

def __del__(self):
self.close()
Expand Down Expand Up @@ -276,7 +291,9 @@ def __setitem__(self, key, value):
raise RuntimeError("Data size is zero- cannot assign elements")
if self._comm is None:
# shortcut for the serial case
self._unlock_memory()
self.data[key] = value
self._lock_memory()
return
# WARNING: Using this function will have a performance penalty over using
# the explicit 'set()' method, since this function must first communicate to
Expand Down Expand Up @@ -566,7 +583,9 @@ def set(self, data, offset=None, fromrank=0):
slc = tuple(dslice)

# Copy data slice
self._unlock_memory()
self.data[slc] = nodedata
self._lock_memory()

# Delete the temporary copy
del nodedata
Expand All @@ -577,7 +596,9 @@ def set(self, data, offset=None, fromrank=0):
for d in range(ndims):
dslice.append(slice(offset[d], offset[d] + data.shape[d], 1))
slc = tuple(dslice)
self._unlock_memory()
self.data[slc] = data
self._lock_memory()

# Explicit barrier here, to ensure that other processes do not try
# reading data before the writing processes have finished.
Expand Down
17 changes: 6 additions & 11 deletions src/pshmem/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,12 @@ def test_large(self):
shm.set(local, fromrank=0)
del local

# Check results on all processes. Take turns since this is a
# large memory buffer.
for proc in range(nproc):
if proc == rank:
check = np.zeros(datadims, dtype=datatype)
check[:] = shm[:]
count = np.count_nonzero(check)
self.assertTrue(count == n_elem)
del check
if self.comm is not None:
self.comm.barrier()
# Check results on all processes.
count = np.count_nonzero(shm[:])
self.assertTrue(count == n_elem)

if self.comm is not None:
self.comm.barrier()

def test_separated(self):
if self.comm is None:
Expand Down
69 changes: 69 additions & 0 deletions test_scripts/large_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import time
from mpi4py import MPI

import subprocess as sp
import numpy as np

from pshmem import MPIShared


def shmem_stat(msg, limits=False):
    """Print the number of allocated SysV shared memory segments.

    Runs ``ipcs -m`` and reports the segment count.  With
    ``limits=True``, the kernel shared-memory limits from
    ``ipcs -lm`` are printed before the count.

    Args:
        msg (str):  Prefix for the printed report.
        limits (bool):  If True, also print the shared memory limits.

    Returns:
        None

    """
    spalloc = sp.check_output(["ipcs", "-m"], universal_newlines=True)
    # The "ipcs -m" output has header lines and a trailing blank line;
    # subtracting 5 leaves just the number of data rows (segments).
    # NOTE(review): this offset assumes the standard util-linux ipcs
    # layout — verify on other platforms.
    n_segment = len(spalloc.split("\n")) - 5
    if limits:
        spout = sp.check_output(["ipcs", "-lm"], universal_newlines=True)
        # Build the multi-line report with a join rather than
        # repeated string concatenation in the loop.
        report = [f"{msg}:"]
        report.extend(f"  {line}" for line in spout.split("\n")[2:-2])
        report.append(f"  {n_segment} allocated segments")
        print("\n".join(report))
    else:
        print(f"{msg}: {n_segment} allocated segments")


def main():
    # World communicator for the test.
    comm = MPI.COMM_WORLD
    procs = comm.size
    rank = comm.rank

    # Shape / dtype of the shared memory array:  one byte per element,
    # sized just below the 32-bit signed integer limit.  The commented
    # value sits just above the limit, for exercising the overflow case.
    n_elem = np.iinfo(np.int32).max - 10
    # n_elem = np.iinfo(np.int32).max + 10
    datadims = (n_elem,)
    datatype = np.dtype(np.uint8)

    shmem_stat(f"Proc {rank} Starting state", limits=True)

    # Only rank zero builds the source buffer; all other ranks pass None.
    local = None
    if rank == 0:
        print(f"Creating large array of {n_elem} bytes", flush=True)
        local = np.ones(datadims, dtype=datatype)

    with MPIShared(datadims, datatype, comm) as shm:
        # Identity of the underlying shared memory buffer object.
        sptr = id(shm._shmem.buf.obj)
        print(f"Proc {rank} address = {sptr}", flush=True)

        shm.set(local, fromrank=0)
        del local

        # Print the buffer identity again after set(), for comparison.
        sptr = id(shm._shmem.buf.obj)
        print(f"Proc {rank} after set, address = {sptr}", flush=True)

        shmem_stat(f"Proc {rank} inside context", limits=True)
        # Pause so segment allocation can be inspected externally.
        time.sleep(10)

        # Every process verifies the shared contents.
        count = np.count_nonzero(shm[:])
        if count != n_elem:
            print(f"Rank {rank} got {count} non-zeros, not {n_elem}", flush=True)

        comm.barrier()

    shmem_stat(f"Proc {rank} Ending state", limits=True)


# Run only when executed directly (e.g. under mpirun), not on import.
if __name__ == "__main__":
    main()