Source code for hippylib.scheduling.collective

# Copyright (c) 2016-2018, The University of Texas at Austin 
# & University of California, Merced.
# Copyright (c) 2019-2020, The University of Texas at Austin 
# University of California--Merced, Washington University in St. Louis.
# All Rights reserved.
# See file COPYRIGHT for details.
# This file is part of the hIPPYlib library. For more information and source code
# availability see https://hippylib.github.io.
# hIPPYlib is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License (as published by the Free
# Software Foundation) version 2.0 dated June 1991.

import dolfin as dl
import numpy as np
from mpi4py import MPI

class NullCollective:
    """
    No-overhead "Parallel" reduction utilities when a serial system of PDEs is solved on 1 process.

    Every method is a local no-op: nothing is communicated and the input
    object is handed back unchanged.
    """

    def bcast(self, v, root=0):
        """
        Return :code:`v` unchanged; :code:`root` is accepted but not used.
        """
        return v

    def size(self):
        """
        Return the number of processes, which is always 1.
        """
        return 1

    def rank(self):
        """
        Return the rank of this process, which is always 0.
        """
        return 0

    def allReduce(self, v, op):
        """
        Return :code:`v` unchanged after checking the requested operation.

        Operation: :code:`op = "Sum"` or :code:`"Avg"` (case insensitive).
        Any other operation raises :code:`NotImplementedError`.
        """
        if op.lower() not in ["sum", "avg"]:
            err_msg = "Unknown operation *{0}* in NullCollective.allReduce".format(op)
            raise NotImplementedError(err_msg)

        # With a single process, both the sum and the average equal the input.
        return v
class MultipleSamePartitioningPDEsCollective:
    """
    Parallel reduction utilities when several serial systems of PDEs (one per process) are solved concurrently.
    """

    def __init__(self, comm, is_serial_check=False):
        """
        :code:`comm` is :code:`mpi4py.MPI` comm

        :code:`is_serial_check`: when true, :code:`allReduce` and :code:`bcast`
        assert that a vector-like argument reports a communicator of size 1,
        and error messages use the name :code:`MultipleSerialPDEsCollective`.
        """
        self.comm = comm
        self.is_serial_check = is_serial_check

    def size(self):
        """
        Return the number of processes in :code:`comm`.
        """
        return self.comm.Get_size()

    def rank(self):
        """
        Return the rank of the calling process in :code:`comm`.
        """
        return self.comm.Get_rank()

    @staticmethod
    def _is_float_scalar(v):
        """
        True for Python floats and :code:`np.float64` scalars.
        """
        return isinstance(v, (float, np.float64))

    @staticmethod
    def _is_int_scalar(v):
        """
        True for Python ints and :code:`np.int32` scalars.

        :code:`bool` is a subclass of :code:`int` but is deliberately excluded,
        so booleans keep being reported as an unsupported type.
        """
        return isinstance(v, (int, np.int32)) and not isinstance(v, bool)

    def _unsupported_msg(self, method, v):
        """
        Build the error message used when :code:`method` cannot handle :code:`v`.
        """
        if self.is_serial_check:
            owner = "MultipleSerialPDEsCollective"
        else:
            owner = "MultipleSamePartitioningPDEsCollective"
        return "{0}.{1} not implement for v of type {2}".format(owner, method, type(v))

    def _allReduce_array(self, v, op):
        """
        Reduce the numpy array :code:`v` in place across :code:`comm` and return it.

        :code:`op` must already be lower case: :code:`"sum"` or :code:`"avg"`.
        For :code:`"avg"` the summed values are scaled by :code:`1/size()`; the
        result is stored back into :code:`v`, so it is cast to the dtype of :code:`v`.
        """
        # Validate before communicating so that an unsupported operation
        # fails without paying for a reduction first.
        if op not in ("sum", "avg"):
            err_msg = "Unknown operation *{0}* in MultipleSerialPDEsCollective.allReduce".format(op)
            raise NotImplementedError(err_msg)

        receive = np.zeros_like(v)
        self.comm.Allreduce(v, receive, op=MPI.SUM)
        if op == "sum":
            v[:] = receive
        else:
            v[:] = (1./float(self.size()))*receive
        return v

    def allReduce(self, v, op):
        """
        Case handled:

        - :code:`v` is a scalar (:code:`float`, :code:`int`);
        - :code:`v` is a numpy array (NOTE: :code:`v` will be overwritten)
        - :code:`v` is a :code:`dolfin.Vector` (NOTE: :code:`v` will be overwritten)
        - :code:`v` exposes :code:`nvec()` and indexing; each entry is reduced in turn.

        Operation: :code:`op = "Sum"` or :code:`"Avg"` (case insensitive).

        Scalars are returned as numpy scalars (:code:`np.float64` or :code:`np.int32`);
        every other supported input is returned as the same, modified, object.
        Unsupported types or operations raise :code:`NotImplementedError`.
        """
        op = op.lower()

        if self._is_float_scalar(v):
            v_array = np.array([v], dtype=np.float64)
            self._allReduce_array(v_array, op)
            return v_array[0]

        if self._is_int_scalar(v):
            v_array = np.array([v], dtype=np.int32)
            self._allReduce_array(v_array, op)
            return v_array[0]

        if isinstance(v, np.ndarray):
            return self._allReduce_array(v, op)

        if hasattr(v, "mpi_comm") and hasattr(v, "get_local"):
            # v is most likely a dl.Vector
            if self.is_serial_check:
                assert v.mpi_comm().Get_size() == 1
            v_array = v.get_local()
            self._allReduce_array(v_array, op)
            v.set_local(v_array)
            v.apply("")
            return v

        if hasattr(v, 'nvec'):
            for i in range(v.nvec()):
                self.allReduce(v[i], op)
            return v

        raise NotImplementedError(self._unsupported_msg("allReduce", v))

    def bcast(self, v, root=0):
        """
        Case handled:

        - :code:`v` is a scalar (:code:`float`, :code:`int`);
        - :code:`v` is a numpy array (NOTE: :code:`v` will be overwritten)
        - :code:`v` is a :code:`dolfin.Vector` (NOTE: :code:`v` will be overwritten)
        - :code:`v` exposes :code:`nvec()` and indexing; each entry is broadcast in turn.
        - :code:`root` refers to the process rank within the communicator for which the data to be broadcasted lives.

        Scalars are returned as numpy scalars; every other supported input is
        returned as the same, modified, object.
        Unsupported types raise :code:`NotImplementedError`.
        """
        if self._is_float_scalar(v) or self._is_int_scalar(v):
            v_array = np.array([v])
            self.comm.Bcast(v_array, root=root)
            return v_array[0]

        if isinstance(v, np.ndarray):
            self.comm.Bcast(v, root=root)
            return v

        if hasattr(v, "mpi_comm") and hasattr(v, "get_local"):
            # v is most likely a dl.Vector
            if self.is_serial_check:
                assert v.mpi_comm().Get_size() == 1
            v_local = v.get_local()
            self.comm.Bcast(v_local, root=root)
            v.set_local(v_local)
            v.apply("")
            return v

        if hasattr(v, 'nvec'):
            for i in range(v.nvec()):
                self.bcast(v[i], root=root)
            return v

        raise NotImplementedError(self._unsupported_msg("bcast", v))
def MultipleSerialPDEsCollective(comm):
    """
    Build a :code:`MultipleSamePartitioningPDEsCollective` on :code:`comm`
    with :code:`is_serial_check=True`.
    """
    return MultipleSamePartitioningPDEsCollective(comm, is_serial_check=True)