Source code for qinfer.parallel

#!/usr/bin/python
# -*- coding: utf-8 -*-
##
# parallel.py: Tools for distributing computation.
##
# © 2014 Chris Ferrie (csferrie@gmail.com) and
#        Christopher E. Granade (cgranade@gmail.com)
#     
# This file is a part of the Qinfer project.
# Licensed under the AGPL version 3.
##
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
##

## FEATURES ##################################################################

from __future__ import absolute_import
from __future__ import division # Ensures that a/b is always a float.

## EXPORTS ###################################################################

__all__ = ['DirectViewParallelizedModel']

## IMPORTS ###################################################################

import numpy as np
from qinfer.derived_models import DerivedModel

import warnings

# Optional dependency: locate an IPython parallel implementation and bind
# ``ipp`` (the module) and ``interactive`` (its decorator). Try the
# ``ipyparallel`` package first, then fall back to the ``IPython.parallel``
# import path.
try:
    import ipyparallel as ipp
    interactive = ipp.interactive
except ImportError:
    # Only ImportError is handled at this level; an AttributeError from a
    # successfully imported ``ipyparallel`` lacking ``interactive`` propagates.
    try:
        import IPython.parallel as ipp
        interactive = ipp.interactive
    except (ImportError, AttributeError):
        # Neither import path is usable: warn once at import time and leave
        # sentinels behind so that this module can still be imported.
        import warnings
        warnings.warn(
            "Could not import IPython parallel. "
            "Parallelization support will be disabled."
        )
        # ``ipp is None`` is what DirectViewParallelizedModel.__init__ checks
        # before raising its RuntimeError.
        ipp = None
        # Identity decorator stand-in, so ``@interactive`` remains valid syntax
        # wherever it is applied.
        interactive = lambda fn: fn

## LOGGING ###################################################################

import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Attach a no-op handler so that this module never configures log output on
# its own; records are only shown once the application sets up logging.
logger.addHandler(logging.NullHandler())
    
## CLASSES ###################################################################

class DirectViewParallelizedModel(DerivedModel):
    r"""
    Given an instance of a :class:`Model`, parallelizes execution of that
    model's likelihood by breaking the ``modelparams`` array into segments and
    executing a segment on each member of a :class:`~ipyparallel.DirectView`.

    This :class:`Model` assumes that it has ownership over the DirectView, such
    that no other processes will send tasks during the lifetime of the Model.

    If you are having trouble pickling your model, consider switching to
    ``dill`` by calling ``direct_view.use_dill()``. This mode gives more
    support for closures.

    :param qinfer.Model serial_model: Model to be parallelized. This
        model will be distributed to the engines in the direct view, such that
        the model must support pickling.
    :param ipyparallel.DirectView direct_view: Direct view onto the engines
        that will be used to parallelize evaluation of the model's likelihood
        function.
    :param bool purge_client: If ``True``, then this model will purge results
        and metadata from the IPython client whenever the model cache is
        cleared. This is useful for solving memory leaks caused by very large
        numbers of calls to ``likelihood``. By default, this is disabled, since
        enabling this option can cause data loss if the client is being sent
        other tasks during the operation of this model.
    :param int serial_threshold: Sets the number of model vectors below which
        the serial model is to be preferred. By default, this is set to
        ``10 * n_engines``, where ``n_engines`` is the number of engines
        exposed by ``direct_view``.
    :raises RuntimeError: If no IPython parallel implementation could be
        imported when this module was loaded.
    """

    ## INITIALIZER ##

    def __init__(self, serial_model, direct_view, purge_client=False, serial_threshold=None):
        if ipp is None:
            raise RuntimeError(
                "This model requires IPython parallelization support, "
                "but an error was raised importing IPython.parallel."
            )

        # _dv must be assigned before n_engines is read below.
        self._dv = direct_view
        self._purge_client = purge_client
        self._serial_threshold = (
            10 * self.n_engines
            if serial_threshold is None else int(serial_threshold)
        )

        super(DirectViewParallelizedModel, self).__init__(serial_model)

    ## SPECIAL METHODS ##

    def __getstate__(self):
        """
        Returns the picklable state of this model.

        The direct view is deliberately replaced by ``None``: instances of
        this class may be pickled as they are passed to remote engines, and
        the view itself must not travel with them.
        """
        return {
            '_underlying_model': self._underlying_model,
            '_dv': None,
            '_call_count': self._call_count,
            '_sim_count': self._sim_count,
            '_serial_threshold': self._serial_threshold,
            # Kept so that a restored instance still has the attribute;
            # read defensively in case this instance predates the key.
            '_purge_client': getattr(self, '_purge_client', False),
        }

    ## PRIVATE HELPERS ##

    def _require_direct_view(self):
        """
        Raises :class:`RuntimeError` if this model has no direct view.
        """
        if self._dv is None:
            raise RuntimeError(
                "No direct view provided; this may be because the instance was "
                "loaded from a pickle or NumPy saved array without providing a "
                "new direct view."
            )

    ## PROPERTIES ##

    # Provide _serial_model as a back-compat.
    @property
    def _serial_model(self):
        # stacklevel=2 attributes the warning to the code using the
        # deprecated name rather than to this property.
        warnings.warn(
            "_serial_model is deprecated in favor of _underlying_model.",
            DeprecationWarning,
            stacklevel=2
        )
        return self._underlying_model

    @_serial_model.setter
    def _serial_model(self, value):
        warnings.warn(
            "_serial_model is deprecated in favor of _underlying_model.",
            DeprecationWarning,
            stacklevel=2
        )
        self._underlying_model = value

    @property
    def n_engines(self):
        """
        The number of engines seen by the direct view owned by this
        parallelized model, or ``0`` if there is no direct view.

        :rtype: int
        """
        return len(self._dv) if self._dv is not None else 0

    ## METHODS ##

    def clear_cache(self):
        """
        Clears any cache associated with the serial model and the engines
        seen by the direct view.

        Purging the direct view is best-effort: failures are logged and
        otherwise ignored, so that clearing the cache never raises because
        of the parallel backend.
        """
        self.underlying_model.clear_cache()

        # Nothing to purge without a view (e.g. after unpickling).
        if self._dv is None:
            return

        try:
            logger.info('DirectView results has {} items. Clearing.'.format(
                len(self._dv.results)
            ))
            self._dv.purge_results('all')
            if getattr(self, '_purge_client', False):
                self._dv.client.purge_everything()
        except Exception:
            # Deliberately not a bare ``except``: KeyboardInterrupt and
            # SystemExit must still propagate.
            logger.warning(
                'Could not purge results from the direct view.',
                exc_info=True
            )

    def likelihood(self, outcomes, modelparams, expparams):
        """
        Returns the likelihood for the underlying (serial) model, distributing
        the model parameter array across the engines controlled by this
        parallelized model. Returns what the serial model would return, see
        :attr:`~Model.likelihood`

        :raises RuntimeError: If the number of model vectors exceeds the
            serial threshold and this model has no direct view.
        """
        # By calling the superclass implementation, we can consolidate
        # call counting there.
        super(DirectViewParallelizedModel, self).likelihood(outcomes, modelparams, expparams)

        # If there's less models than some threshold, just use the serial model.
        # By default, we'll set that threshold to be the number of engines * 10.
        if modelparams.shape[0] <= self._serial_threshold:
            return self.underlying_model.likelihood(outcomes, modelparams, expparams)

        self._require_direct_view()

        # Need to decorate with interactive to overcome namespace issues with
        # remote engines.
        @interactive
        def serial_likelihood(mps, sm, os, eps):
            return sm.likelihood(os, mps, eps)

        # TODO: check whether there's a better way to pass the extra parameters
        # that doesn't use so much memory.
        # The trick is that serial_likelihood will be pickled, so we need to be
        # careful about closures.
        n_engines = self.n_engines
        chunks = self._dv.map_sync(
            serial_likelihood,
            np.array_split(modelparams, n_engines, axis=0),
            [self.underlying_model] * n_engines,
            [outcomes] * n_engines,
            [expparams] * n_engines
        )

        # Each engine evaluated a slice of the model vectors; the slices are
        # joined back together along axis 1.
        return np.concatenate(chunks, axis=1)

    def simulate_experiment(self, modelparams, expparams, repeat=1, split_by_modelparams=True):
        """
        Simulates the underlying (serial) model using the parallel
        engines. Returns what the serial model would return, see
        :attr:`~Simulatable.simulate_experiment`

        :param bool split_by_modelparams: If ``True``, splits up
            ``modelparams`` into `n_engines` chunks and distributes across
            engines. If ``False``, splits up ``expparams``.
        :raises RuntimeError: If the array being split has more rows than the
            serial threshold and this model has no direct view.
        """
        # By calling the superclass implementation, we can consolidate
        # simulation counting there.
        super(DirectViewParallelizedModel, self).simulate_experiment(modelparams, expparams, repeat=repeat)

        # If there's less rows to distribute than some threshold, just use the
        # serial model. By default, we'll set that threshold to be the number
        # of engines * 10. As in likelihood(), this comes before the direct
        # view check, so small workloads still run without a view.
        split_array = modelparams if split_by_modelparams else expparams
        if split_array.shape[0] <= self._serial_threshold:
            return self.underlying_model.simulate_experiment(modelparams, expparams, repeat=repeat)

        self._require_direct_view()

        # Need to decorate with interactive to overcome namespace issues with
        # remote engines.
        @interactive
        def serial_simulator(sm, mps, eps, r):
            return sm.simulate_experiment(mps, eps, repeat=r)

        n_engines = self.n_engines

        # NOTE(review): the concatenation axes below are preserved from the
        # original implementation. If the serial simulator returns arrays
        # indexed as (repeat, n_models, n_experiments), they look off by
        # one -- confirm against the serial model before changing.
        if split_by_modelparams:
            mps_chunks = np.array_split(modelparams, n_engines, axis=0)
            eps_chunks = [expparams] * n_engines
            concat_axis = 0
        else:
            mps_chunks = [modelparams] * n_engines
            eps_chunks = np.array_split(expparams, n_engines, axis=0)
            concat_axis = 1

        # The trick is that serial_simulator will be pickled, so we need to be
        # careful about closures.
        simulated = self._dv.map_sync(
            serial_simulator,
            [self.underlying_model] * n_engines,
            mps_chunks,
            eps_chunks,
            [repeat] * n_engines
        )

        return np.concatenate(simulated, axis=concat_axis)