Sophie

Sophie

distrib > Mageia > 4 > x86_64 > by-pkgid > 6aebace0cb967559ef14fc69ac903937 > files > 157

python-mpi4py-doc-1.3.1-4.mga4.noarch.rpm

from mpi4py import MPI

# -----------------------------------------------------------------------------

import struct as _struct
try:
    from numpy import empty as _empty
    def _array_new(size, typecode, init=0):
        a = _empty(size, typecode)
        a.fill(init)
        return a
    def _array_set(ary, value):
        ary.fill(value)
    def _array_sum(ary):
        return ary.sum()
except ImportError:
    from array import array as _array
    def _array_new(size, typecode, init=0):
        return _array(typecode, [init]) * size
    def _array_set(ary, value):
        for i, _ in enumerate(ary):
            ary[i] = value
    def _array_sum(ary):
        return sum(ary, 0)

# -----------------------------------------------------------------------------

class Counter(object):
    """Shared counter implemented with MPI one-sided communication.

    Rank 0 of ``comm`` exposes an array of ints laid out as a binary tree
    over process ranks (stored in pre-order).  Each call to :meth:`next`
    adds the caller's increment into the tree nodes on its own path, and
    reads the sibling nodes along that path.  The sum of those siblings is
    the total contributed so far by every *other* process.  Adding the
    caller's own running total gives the global counter value.

    Construction calls ``MPI.Win.Create``, which is collective, so every
    rank of ``comm`` must construct the counter.
    """

    def __init__(self, comm, init=0):
        """Create the counter over ``comm``.

        Parameters
        ----------
        comm : MPI.Comm
            Communicator whose ranks share the counter.
        init : int
            Value returned by the very first :meth:`next` call across all
            ranks (default 0).
        """
        # Descend a binary tree of depth ceil(log2(size)), following this
        # rank's bits from most to least significant.  At every level the
        # child we descend into is where our increments are accumulated
        # (acc_idx).  The other child holds the running total of the ranks
        # on the opposite side (get_idx).  Children of the node at ``idx``
        # live at idx+1 (left) and idx+2*mask (right): pre-order layout.
        size = comm.Get_size()
        rank = comm.Get_rank()
        mask = 1
        while mask < size:
            mask <<= 1
        mask >>= 1
        idx = 0
        get_idx = []
        acc_idx = []
        while mask >= 1:
            left  = idx + 1
            right = idx + (mask<<1)
            if rank < mask:
                acc_idx.append( left  )
                get_idx.append( right )
                idx = left
            else:
                acc_idx.append( right )
                get_idx.append( left  )
                idx = right
            rank = rank % mask
            mask >>= 1
        #
        typecode = 'i'
        datatype = MPI.INT
        itemsize = datatype.Get_size()
        # Only the root rank allocates window memory; the others pass None.
        root = 0
        rank = comm.Get_rank()
        if rank == root:
            nlevels = len(get_idx) + 1
            nentries = (1<<nlevels) - 1
            self.mem = MPI.Alloc_mem(nentries*itemsize, MPI.INFO_NULL)
            # Tree nodes hold accumulated *increments* only and must start
            # at zero.  Seeding every node with ``init`` made next() return
            # len(get_idx)*init (the sum of the sibling nodes) instead of
            # init.  The starting offset is carried in ``myval`` instead.
            self.mem[:] = _struct.pack(typecode, 0) * nentries
        else:
            self.mem = None
        #
        self.win = MPI.Win.Create(self.mem, itemsize, MPI.INFO_NULL, comm)
        # Indexed datatypes let a single Get / Accumulate touch all of the
        # scattered tree nodes this rank reads from / adds into.
        self.acc_type = datatype.Create_indexed_block(1, acc_idx).Commit()
        self.get_type = datatype.Create_indexed_block(1, get_idx).Commit()
        self.acc_buf = _array_new(len(acc_idx), typecode)
        self.get_buf = _array_new(len(get_idx), typecode)
        # Running total of this rank's own increments, offset by ``init``,
        # so that myval + (others' increments) is the global counter value.
        self.myval = init

    def free(self):
        """Release the window, the root's window memory and the datatypes."""
        # The window is freed before the memory it exposes.
        if self.win:
            self.win.Free()
        if self.mem:
            MPI.Free_mem(self.mem)
            self.mem = None
        if self.get_type:
            self.get_type.Free()
        if self.acc_type:
            self.acc_type.Free()

    def next(self, increment=1):
        """Add ``increment`` to the counter and return its previous value.

        The read of the other ranks' totals and the accumulation of this
        rank's increment happen inside one exclusive-lock epoch on the root.
        """
        _array_set(self.acc_buf, increment)
        root = 0
        self.win.Lock(MPI.LOCK_EXCLUSIVE, root, 0)
        # acc_idx and get_idx are disjoint (one child of each level goes to
        # each list), so the Get and Accumulate never target the same
        # location within the epoch.
        self.win.Get(self.get_buf, root, [0, 1, self.get_type])
        self.win.Accumulate(self.acc_buf, root, [0, 1, self.acc_type], MPI.SUM)
        self.win.Unlock(root)
        nxtval = self.myval + _array_sum(self.get_buf)
        self.myval += increment
        return nxtval

# -----------------------------------------------------------------------------

class Mutex(object):
    """Busy-waiting mutual-exclusion lock layered on a :class:`Counter`.

    A process holds the lock when its increment finds the shared counter at
    zero; otherwise it retracts the increment and retries.  Supports the
    context-manager protocol.
    """

    def __init__(self, comm):
        self.counter = Counter(comm)

    def __enter__(self):
        self.lock()
        return self

    def __exit__(self, *exc):
        self.unlock()
        # Falsy result: exceptions raised inside the ``with`` body propagate.
        return None

    def free(self):
        """Release the MPI resources held by the underlying counter."""
        self.counter.free()

    def lock(self):
        """Spin until this process's increment observes the counter at zero."""
        # A nonzero prior value means another process got there first:
        # undo our increment and try again.
        while self.counter.next(+1) != 0:
            self.counter.next(-1)

    def unlock(self):
        """Release the lock by retracting this process's increment."""
        self.counter.next(-1)

# -----------------------------------------------------------------------------

def test_counter():
    """Each rank draws five values; together they must cover 0..total-1 once."""
    counter = Counter(MPI.COMM_WORLD)
    vals = [counter.next() for _ in range(5)]
    counter.free()

    # allreduce's default SUM op concatenates the per-rank lists.
    vals = MPI.COMM_WORLD.allreduce(vals)
    assert sorted(vals) == list(range(len(vals)))

def test_mutex():
    """Acquire and release the mutex once, then free it."""
    mutex = Mutex(MPI.COMM_WORLD)
    with mutex:  # __enter__ -> lock(), __exit__ -> unlock()
        pass
    mutex.free()

if __name__ == '__main__':
    # Run the self-checks in order; a failing assert aborts the script.
    for _check in (test_counter, test_mutex):
        _check()

# -----------------------------------------------------------------------------