[med-svn] [toil] 01/02: New upstream version 3.5.0~alpha1.321
Steffen Möller
moeller at moszumanska.debian.org
Sun Jan 15 08:04:16 UTC 2017
This is an automated email from the git hooks/post-receive script.
moeller pushed a commit to branch master
in repository toil.
commit bd1ef3f6b5d284ae8d1a96e807bb3701b5ad16dd
Author: Steffen Moeller <moeller at debian.org>
Date: Sun Jan 15 08:35:49 2017 +0100
New upstream version 3.5.0~alpha1.321
---
PKG-INFO | 10 +
README.rst | 15 +
setup.cfg | 9 +
setup.py | 109 ++
src/toil/__init__.py | 149 ++
src/toil/batchSystems/__init__.py | 53 +
src/toil/batchSystems/abstractBatchSystem.py | 373 +++++
src/toil/batchSystems/gridengine.py | 346 +++++
src/toil/batchSystems/lsf.py | 285 ++++
src/toil/batchSystems/mesos/__init__.py | 56 +
src/toil/batchSystems/mesos/batchSystem.py | 586 ++++++++
src/toil/batchSystems/mesos/conftest.py | 23 +
src/toil/batchSystems/mesos/executor.py | 198 +++
src/toil/batchSystems/mesos/test/__init__.py | 84 ++
src/toil/batchSystems/parasol.py | 372 +++++
src/toil/batchSystems/parasolTestSupport.py | 104 ++
src/toil/batchSystems/singleMachine.py | 348 +++++
src/toil/batchSystems/slurm.py | 434 ++++++
src/toil/common.py | 1094 +++++++++++++++
src/toil/cwl/__init__.py | 1 +
src/toil/cwl/conftest.py | 22 +
src/toil/cwl/cwltoil.py | 740 ++++++++++
src/toil/fileStore.py | 1884 ++++++++++++++++++++++++++
src/toil/job.py | 1727 +++++++++++++++++++++++
src/toil/jobGraph.py | 148 ++
src/toil/jobStores/__init__.py | 14 +
src/toil/jobStores/abstractJobStore.py | 967 +++++++++++++
src/toil/jobStores/aws/__init__.py | 0
src/toil/jobStores/aws/jobStore.py | 1363 +++++++++++++++++++
src/toil/jobStores/aws/utils.py | 267 ++++
src/toil/jobStores/azureJobStore.py | 827 +++++++++++
src/toil/jobStores/conftest.py | 27 +
src/toil/jobStores/fileJobStore.py | 423 ++++++
src/toil/jobStores/googleJobStore.py | 476 +++++++
src/toil/jobStores/utils.py | 236 ++++
src/toil/leader.py | 909 +++++++++++++
src/toil/lib/__init__.py | 14 +
src/toil/lib/bioio.py | 309 +++++
src/toil/lib/encryption/__init__.py | 18 +
src/toil/lib/encryption/_dummy.py | 32 +
src/toil/lib/encryption/_nacl.py | 89 ++
src/toil/lib/encryption/conftest.py | 8 +
src/toil/provisioners/__init__.py | 76 ++
src/toil/provisioners/abstractProvisioner.py | 262 ++++
src/toil/provisioners/aws/__init__.py | 289 ++++
src/toil/provisioners/aws/awsProvisioner.py | 628 +++++++++
src/toil/provisioners/cgcloud/__init__.py | 0
src/toil/provisioners/cgcloud/provisioner.py | 338 +++++
src/toil/provisioners/clusterScaler.py | 423 ++++++
src/toil/realtimeLogger.py | 246 ++++
src/toil/resource.py | 573 ++++++++
src/toil/serviceManager.py | 197 +++
src/toil/statsAndLogging.py | 154 +++
src/toil/test/__init__.py | 863 ++++++++++++
src/toil/toilState.py | 174 +++
src/toil/utils/__init__.py | 27 +
src/toil/utils/toilClean.py | 37 +
src/toil/utils/toilDestroyCluster.py | 32 +
src/toil/utils/toilKill.py | 44 +
src/toil/utils/toilLaunchCluster.py | 53 +
src/toil/utils/toilMain.py | 52 +
src/toil/utils/toilRsyncCluster.py | 40 +
src/toil/utils/toilSSHCluster.py | 33 +
src/toil/utils/toilStats.py | 605 +++++++++
src/toil/utils/toilStatus.py | 112 ++
src/toil/version.py | 13 +
src/toil/worker.py | 560 ++++++++
67 files changed, 20980 insertions(+)
diff --git a/PKG-INFO b/PKG-INFO
new file mode 100644
index 0000000..0e978a0
--- /dev/null
+++ b/PKG-INFO
@@ -0,0 +1,10 @@
+Metadata-Version: 1.0
+Name: toil
+Version: 3.5.0a1.dev321
+Summary: Pipeline management software for clusters.
+Home-page: https://github.com/BD2KGenomics/toil
+Author: Benedict Paten
+Author-email: benedict at soe.usc.edu
+License: UNKNOWN
+Description: UNKNOWN
+Platform: UNKNOWN
diff --git a/README.rst b/README.rst
new file mode 100644
index 0000000..5ba0a30
--- /dev/null
+++ b/README.rst
@@ -0,0 +1,15 @@
+.. image:: https://badge.waffle.io/BD2KGenomics/toil.svg?label=ready&title=Ready
+ :target: https://waffle.io/BD2KGenomics/toil
+ :alt: 'Stories in Ready'
+
+Toil is a scalable, efficient, cross-platform pipeline management system,
+written entirely in Python, and designed around the principles of functional
+programming. Full documentation for the latest stable release can be found at
+`Read the Docs`_.
+
+.. _Read the Docs: http://toil.readthedocs.org/
+
+
+.. image:: https://badges.gitter.im/bd2k-genomics-toil/Lobby.svg
+ :alt: Join the chat at https://gitter.im/bd2k-genomics-toil/Lobby
+ :target: https://gitter.im/bd2k-genomics-toil/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..4f67384
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,9 @@
+[pytest]
+python_files = *.py
+addopts = --doctest-modules --tb=native --assert=plain
+
+[egg_info]
+tag_build =
+tag_date = 0
+tag_svn_revision = 0
+
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..92a60bf
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,109 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from setuptools import find_packages, setup
+
+botoRequirement = 'boto==2.38.0'
+
+
+def runSetup():
+ """
+ Calls setup(). This function exists so that the setup() invocation can be preceded by more
+ internal functionality. The `version` module is imported dynamically by importVersion() below.
+ """
+ setup(
+ name='toil',
+ version=version.distVersion,
+ description='Pipeline management software for clusters.',
+ author='Benedict Paten',
+ author_email='benedict at soe.usc.edu',
+ url="https://github.com/BD2KGenomics/toil",
+ install_requires=[
+ 'bd2k-python-lib>=1.14a1.dev35',
+ 'dill==0.2.5',
+ 'six>=1.10.0'],
+ extras_require={
+ 'mesos': [
+ 'psutil==3.0.1'],
+ 'aws': [
+ botoRequirement,
+ 'cgcloud-lib==' + version.cgcloudVersion,
+ 'futures==3.0.5'],
+ 'azure': [
+ 'azure==1.0.3'],
+ 'encryption': [
+ 'pynacl==0.3.0'],
+ 'google': [
+ 'gcs_oauth2_boto_plugin==1.9',
+ botoRequirement],
+ 'cwl': [
+ 'cwltool==1.0.20161221171240']},
+ package_dir={'': 'src'},
+ packages=find_packages(where='src',
+ # Note that we intentionally include the top-level `test` package for
+ # functionality like the @experimental and @integrative decorators:
+ exclude=['*.test.*']),
+ # Unfortunately, the names of the entry points are hard-coded elsewhere in the code base so
+ # you can't just change them here. Luckily, most of them are pretty unique strings, and thus
+ # easy to search for.
+ entry_points={
+ 'console_scripts': [
+ 'toil = toil.utils.toilMain:main',
+ '_toil_worker = toil.worker:main',
+ 'cwltoil = toil.cwl.cwltoil:main [cwl]',
+ 'cwl-runner = toil.cwl.cwltoil:main [cwl]',
+ '_toil_mesos_executor = toil.batchSystems.mesos.executor:main [mesos]']})
+
+
+def importVersion():
+ """
+ Load and return the module object for src/toil/version.py, generating it from the template if
+ required.
+ """
+ import imp
+ try:
+ # Attempt to load the template first. It only exists in a working copy cloned via git.
+ import version_template
+ except ImportError:
+ # If loading the template fails we must be in a unpacked source distribution and
+ # src/toil/version.py will already exist.
+ pass
+ else:
+ # Use the template to generate src/toil/version.py
+ import os
+ import errno
+ from tempfile import NamedTemporaryFile
+
+ new = version_template.expand_()
+ try:
+ with open('src/toil/version.py') as f:
+ old = f.read()
+ except IOError as e:
+ if e.errno == errno.ENOENT:
+ old = None
+ else:
+ raise
+
+ if old != new:
+ with NamedTemporaryFile(dir='src/toil', prefix='version.py.', delete=False) as f:
+ f.write(new)
+ os.rename(f.name, 'src/toil/version.py')
+ # Unfortunately, we can't use a straight import here because that would also load the stuff
+ # defined in src/toil/__init__.py which imports modules from external dependencies that may
+ # not yet be installed when setup.py is invoked.
+ return imp.load_source('toil.version', 'src/toil/version.py')
+
+
+version = importVersion()
+runSetup()
diff --git a/src/toil/__init__.py b/src/toil/__init__.py
new file mode 100644
index 0000000..f6e04fd
--- /dev/null
+++ b/src/toil/__init__.py
@@ -0,0 +1,149 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import logging
+import os
+import sys
+
+from subprocess import check_output
+
+from bd2k.util import memoize
+
+log = logging.getLogger(__name__)
+
+
+def toilPackageDirPath():
+ """
+ Returns the absolute path of the directory that corresponds to the top-level toil package.
+ The return value is guaranteed to end in '/toil'.
+ """
+ result = os.path.dirname(os.path.realpath(__file__))
+ assert result.endswith('/toil')
+ return result
+
+
+def inVirtualEnv():
+ return hasattr(sys, 'real_prefix')
+
+
+def resolveEntryPoint(entryPoint):
+ """
+ Returns the path to the given entry point (see setup.py) that *should* work on a worker. The
+ return value may be an absolute or a relative path.
+ """
+ if inVirtualEnv():
+ path = os.path.join(os.path.dirname(sys.executable), entryPoint)
+ # Inside a virtualenv we try to use absolute paths to the entrypoints.
+ if os.path.isfile(path):
+ # If the entrypoint is present, Toil must have been installed into the virtualenv (as
+ # opposed to being included via --system-site-packages). For clusters this means that
+ # if Toil is installed in a virtualenv on the leader, it must be installed in
+ # a virtualenv located at the same path on each worker as well.
+ assert os.access(path, os.X_OK)
+ return path
+ else:
+ # For virtualenv's that have the toil package directory on their sys.path but whose
+ # bin directory lacks the Toil entrypoints, i.e. where Toil is included via
+ # --system-site-packages, we rely on PATH just as if we weren't in a virtualenv.
+ return entryPoint
+ else:
+ # Outside a virtualenv it is hard to predict where the entry points got installed. It is
+ # the responsibility of the user to ensure that they are present on PATH and point to the
+ # correct version of Toil. This is still better than an absolute path because it gives
+ # the user control over Toil's location on both leader and workers.
+ return entryPoint
+
+
+@memoize
+def physicalMemory():
+ """
+ >>> n = physicalMemory()
+ >>> n > 0
+ True
+ >>> n == physicalMemory()
+ True
+ """
+ try:
+ return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
+ except ValueError:
+ return int(check_output(['sysctl', '-n', 'hw.memsize']).strip())
+
+
+def physicalDisk(config, toilWorkflowDir=None):
+ if toilWorkflowDir is None:
+ from toil.common import Toil
+ toilWorkflowDir = Toil.getWorkflowDir(config.workflowID, config.workDir)
+ diskStats = os.statvfs(toilWorkflowDir)
+ return diskStats.f_frsize * diskStats.f_bavail
+
+
+def applianceSelf():
+ """
+ Returns the fully qualified name of the Docker image to start Toil appliance containers from.
+ The result is determined by the current version of Toil and three environment variables:
+ ``TOIL_DOCKER_REGISTRY``, ``TOIL_DOCKER_NAME`` and ``TOIL_APPLIANCE_SELF``.
+
+ ``TOIL_DOCKER_REGISTRY`` specifies an account on a publicly hosted docker registry like Quay
+ or Docker Hub. The default is UCSC's CGL account on Quay.io where the Toil team publishes the
+ official appliance images. ``TOIL_DOCKER_NAME`` specifies the base name of the image. The
+ default of `toil` will be adequate in most cases. ``TOIL_APPLIANCE_SELF`` fully qualifies the
+ appliance image, complete with registry, image name and version tag, overriding both
+ ``TOIL_DOCKER_NAME`` and ``TOIL_DOCKER_REGISTRY`` as well as the version tag of the image.
+ Setting ``TOIL_APPLIANCE_SELF`` will not be necessary in most cases.
+
+ :rtype: str
+ """
+ import toil.version
+ registry = lookupEnvVar(name='docker registry',
+ envName='TOIL_DOCKER_REGISTRY',
+ defaultValue=toil.version.dockerRegistry)
+ name = lookupEnvVar(name='docker name',
+ envName='TOIL_DOCKER_NAME',
+ defaultValue=toil.version.dockerName)
+ appliance = lookupEnvVar(name='docker appliance',
+ envName='TOIL_APPLIANCE_SELF',
+ defaultValue=registry + '/' + name + ':' + toil.version.dockerTag)
+ return appliance
+
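# --- Editor's sketch, not part of the upstream commit: how applianceSelf() composes
# the image name from the variables documented above. The concrete registry, name
# and tag values below are hypothetical stand-ins for toil.version defaults.
registry = 'quay.io/ucsc_cgl'        # would come from toil.version.dockerRegistry
name = 'toil'                        # would come from toil.version.dockerName
tag = '3.5.0a1.dev321'               # would come from toil.version.dockerTag
appliance = registry + '/' + name + ':' + tag
assert appliance == 'quay.io/ucsc_cgl/toil:3.5.0a1.dev321'
# Exporting TOIL_APPLIANCE_SELF replaces the composed value wholesale.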
+
+def lookupEnvVar(name, envName, defaultValue):
+ """
+ Use this for looking up environment variables that control Toil and are important enough to
+ log the result of that lookup.
+
+ :param str name: the human readable name of the variable
+ :param str envName: the name of the environment variable to lookup
+ :param str defaultValue: the fall-back value
+ :return: the value of the environment variable, or the default value if the variable is not set
+ :rtype: str
+ """
+ try:
+ value = os.environ[envName]
+ except KeyError:
+ log.info('Using default %s of %s as %s is not set.', name, defaultValue, envName)
+ return defaultValue
+ else:
+ log.info('Overriding %s of %s with %s from %s.', name, defaultValue, value, envName)
+ return value
+
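# --- Editor's sketch, not part of the upstream commit: lookupEnvVar() falls back to
# the default and logs either way. TOIL_EXAMPLE_VAR is a hypothetical variable and
# the example assumes this module is importable as `toil`.
import os
from toil import lookupEnvVar

os.environ.pop('TOIL_EXAMPLE_VAR', None)
assert lookupEnvVar('example setting', 'TOIL_EXAMPLE_VAR', 'fallback') == 'fallback'

os.environ['TOIL_EXAMPLE_VAR'] = 'override'
assert lookupEnvVar('example setting', 'TOIL_EXAMPLE_VAR', 'fallback') == 'override'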
+
+def logProcessContext(config):
+ # toil.version.version (string) cannot be imported at top level because it conflicts with
+ # toil.version (module) and Sphinx doesn't like that.
+ from toil.version import version
+ log.info("Running Toil version %s.", version)
+ log.debug("Configuration: %s", config.__dict__)
+
diff --git a/src/toil/batchSystems/__init__.py b/src/toil/batchSystems/__init__.py
new file mode 100644
index 0000000..f2ff0d4
--- /dev/null
+++ b/src/toil/batchSystems/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import sys
+
+if sys.version_info >= (3, 0):
+
+ # https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons
+ def cmp(a, b):
+ return (a > b) - (a < b)
+
+class MemoryString:
+ def __init__(self, string):
+ if string[-1] == 'K' or string[-1] == 'M' or string[-1] == 'G' or string[-1] == 'T':
+ self.unit = string[-1]
+ self.val = float(string[:-1])
+ else:
+ self.unit = 'B'
+ self.val = float(string)
+ self.bytes = self.byteVal()
+
+ def __str__(self):
+ if self.unit != 'B':
+ return str(self.val) + self.unit
+ else:
+ return str(self.val)
+
+ def byteVal(self):
+ if self.unit == 'B':
+ return self.val
+ elif self.unit == 'K':
+ return self.val * 1024
+ elif self.unit == 'M':
+ return self.val * 1048576
+ elif self.unit == 'G':
+ return self.val * 1073741824
+ elif self.unit == 'T':
+ return self.val * 1099511627776
+
+ def __cmp__(self, other):
+ return cmp(self.bytes, other.bytes)
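# --- Editor's sketch, not part of the upstream commit: MemoryString parses the
# unit-suffixed sizes reported by qhost/lshosts and compares them by byte value.
small = MemoryString('512M')
large = MemoryString('2G')
assert small.bytes == 512 * 1048576
assert large.bytes == 2 * 1073741824
assert small.__cmp__(large) < 0      # __cmp__ delegates to the byte values
assert str(large) == '2.0G'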
diff --git a/src/toil/batchSystems/abstractBatchSystem.py b/src/toil/batchSystems/abstractBatchSystem.py
new file mode 100644
index 0000000..689e62f
--- /dev/null
+++ b/src/toil/batchSystems/abstractBatchSystem.py
@@ -0,0 +1,373 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import absolute_import
+
+import os
+import shutil
+from abc import ABCMeta, abstractmethod
+from collections import namedtuple
+
+from bd2k.util.objects import abstractclassmethod
+
+from toil.common import Toil, cacheDirName
+from toil.fileStore import shutdownFileStore
+
+# A class containing the information required for worker cleanup on shutdown of the batch system.
+WorkerCleanupInfo = namedtuple('WorkerCleanupInfo', (
+ # A path to the value of config.workDir (where the cache would go)
+ 'workDir',
+ # The value of config.workflowID (used to identify files specific to this workflow)
+ 'workflowID',
+ # The value of the cleanWorkDir flag
+ 'cleanWorkDir'))
+
+
+class AbstractBatchSystem(object):
+ """
+ An abstract (as far as Python currently allows) base class to represent the interface the batch
+ system must provide to Toil.
+ """
+
+ __metaclass__ = ABCMeta
+
+ # noinspection PyMethodParameters
+ @abstractclassmethod
+ def supportsHotDeployment(cls):
+ """
+ Whether this batch system supports hot deployment of the user script itself. If it does,
+ the :meth:`setUserScript` can be invoked to set the resource object representing the user
+ script.
+
+ Note to implementors: If your implementation returns True here, it should also override
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ # noinspection PyMethodParameters
+ @abstractclassmethod
+ def supportsWorkerCleanup(cls):
+ """
+ Indicates whether this batch system invokes :meth:`workerCleanup` after the last job for
+ a particular workflow invocation finishes. Note that the term *worker* refers to an
+ entire node, not just a worker process. A worker process may run more than one job
+ sequentially, and more than one concurrent worker process may exist on a worker node,
+ for the same workflow. The batch system is said to *shut down* after the last worker
+ process terminates.
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ def setUserScript(self, userScript):
+ """
+ Set the user script for this workflow. This method must be called before the first job is
+ issued to this batch system, and only if :meth:`supportsHotDeployment` returns True,
+ otherwise it will raise an exception.
+
+ :param toil.resource.Resource userScript: the resource object representing the user script
+ or module and the modules it depends on.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def issueBatchJob(self, jobNode):
+ """
+ Issues a job with the specified command to the batch system and returns a unique jobID.
+
+ :param jobNode: the job to issue. Its attributes give the command string to run
+ (`command`), the number of bytes of memory (`memory`) and disk space (`disk`)
+ the job needs, the number of cores needed (`cores`), and whether the job can be
+ run on a preemptable node (`preemptable`).
+
+ :return: a unique jobID that can be used to reference the newly issued job
+ :rtype: int
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def killBatchJobs(self, jobIDs):
+ """
+ Kills the given job IDs.
+
+ :param list[int] jobIDs: list of IDs of jobs to kill
+ """
+ raise NotImplementedError()
+
+ # FIXME: Return value should be a set (then also fix the tests)
+
+ @abstractmethod
+ def getIssuedBatchJobIDs(self):
+ """
+ Gets all currently issued jobs
+
+ :return: A list of jobs (as jobIDs) currently issued (may be running, or may be
+ waiting to be run). Despite the result being a list, the ordering should not
+ be depended upon.
+ :rtype: list[str]
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def getRunningBatchJobIDs(self):
+ """
+ Gets a map of jobs as jobIDs that are currently running (not just waiting)
+ and how long they have been running, in seconds.
+
+ :return: dictionary with currently running jobID keys and how many seconds they have
+ been running as the value
+ :rtype: dict[str,float]
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def getUpdatedBatchJob(self, maxWait):
+ """
+ Returns a job that has updated its status.
+
+ :param float maxWait: the number of seconds to block, waiting for a result
+
+ :rtype: (str, int, float)|None
+ :return: If a result is available, returns a tuple (jobID, exitValue, wallTime).
+ Otherwise it returns None. wallTime is the number of seconds (a float) in
+ wall-clock time the job ran for or None if this batch system does not support
+ tracking wall time. Returns None for jobs that were killed.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def shutdown(self):
+ """
+ Called at the completion of a toil invocation.
+ Should cleanly terminate all worker threads.
+ """
+ raise NotImplementedError()
+
+ def setEnv(self, name, value=None):
+ """
+ Set an environment variable for the worker process before it is launched. The worker
+ process will typically inherit the environment of the machine it is running on but this
+ method makes it possible to override specific variables in that inherited environment
+ before the worker is launched. Note that this mechanism is different to the one used by
+ the worker internally to set up the environment of a job. A call to this method affects
+ all jobs issued after this method returns. Note to implementors: This means that you
+ would typically need to copy the variables before enqueuing a job.
+
+ If no value is provided it will be looked up from the current environment.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ """
+ Gets the period of time to wait (floating point, in seconds) between checking for
+ missing/overlong jobs.
+ """
+ raise NotImplementedError()
+
+
+class BatchSystemSupport(AbstractBatchSystem):
+ """
+ Partial implementation of AbstractBatchSystem, support methods.
+ """
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk):
+ """
+ Initializes initial state of the object
+
+ :param toil.common.Config config: object that is set up by the toilSetup script and
+ has configuration parameters for the jobtree. You can add code
+ to that script to get parameters for your batch system.
+
+ :param float maxCores: the maximum number of cores the batch system can
+ request for any one job
+
+ :param int maxMemory: the maximum amount of memory the batch system can
+ request for any one job, in bytes
+
+ :param int maxDisk: the maximum amount of disk space the batch system can
+ request for any one job, in bytes
+ """
+ super(BatchSystemSupport, self).__init__()
+ self.config = config
+ self.maxCores = maxCores
+ self.maxMemory = maxMemory
+ self.maxDisk = maxDisk
+ self.environment = {}
+ """
+ :type: dict[str,str]
+ """
+ self.workerCleanupInfo = WorkerCleanupInfo(workDir=self.config.workDir,
+ workflowID=self.config.workflowID,
+ cleanWorkDir=self.config.cleanWorkDir)
+
+ def checkResourceRequest(self, memory, cores, disk):
+ """
+ Check resource request is not greater than that available or allowed.
+
+ :param int memory: amount of memory being requested, in bytes
+
+ :param float cores: number of cores being requested
+
+ :param int disk: amount of disk space being requested, in bytes
+
+ :raise InsufficientSystemResources: raised when a resource is requested in an amount
+ greater than allowed
+ """
+ assert memory is not None
+ assert disk is not None
+ assert cores is not None
+ if cores > self.maxCores:
+ raise InsufficientSystemResources('cores', cores, self.maxCores)
+ if memory > self.maxMemory:
+ raise InsufficientSystemResources('memory', memory, self.maxMemory)
+ if disk > self.maxDisk:
+ raise InsufficientSystemResources('disk', disk, self.maxDisk)
+
+
+ def setEnv(self, name, value=None):
+ """
+ Set an environment variable for the worker process before it is launched. The worker
+ process will typically inherit the environment of the machine it is running on but this
+ method makes it possible to override specific variables in that inherited environment
+ before the worker is launched. Note that this mechanism is different to the one used by
+ the worker internally to set up the environment of a job. A call to this method affects
+ all jobs issued after this method returns. Note to implementors: This means that you
+ would typically need to copy the variables before enqueuing a job.
+
+ If no value is provided it will be looked up from the current environment.
+
+ NB: Only the Mesos and single-machine batch systems support passing environment
+ variables. On other batch systems, this method has no effect. See
+ https://github.com/BD2KGenomics/toil/issues/547.
+
+ :param str name: the environment variable to be set on the worker.
+
+ :param str value: if given, the environment variable given by name will be set to this value.
+ if None, the variable's current value will be used as the value on the worker
+
+ :raise RuntimeError: if value is None and the name cannot be found in the environment
+ """
+ if value is None:
+ try:
+ value = os.environ[name]
+ except KeyError:
+ raise RuntimeError("%s does not exist in current environment", name)
+ self.environment[name] = value
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ """
+ Gets the period of time to wait (floating point, in seconds) between checking for
+ missing/overlong jobs.
+
+ :return: time in seconds to wait in between checking for lost jobs
+ :rtype: float
+ """
+ raise NotImplementedError()
+
+ def _getResultsFileName(self, toilPath):
+ """
+ Get a path for the batch systems to store results. GridEngine, Slurm,
+ and LSF currently use this and only work if the job store locator is of type 'file'.
+ """
+ # Use parser to extract the path and type
+ locator, filePath = Toil.parseLocator(toilPath)
+ assert locator == "file"
+ return os.path.join(filePath, "results.txt")
+
+ @staticmethod
+ def workerCleanup(info):
+ """
+ Cleans up the worker node on batch system shutdown. Also see :meth:`supportsWorkerCleanup`.
+
+ :param WorkerCleanupInfo info: A named tuple consisting of all the relevant information
+ for cleaning up the worker.
+ """
+ assert isinstance(info, WorkerCleanupInfo)
+ workflowDir = Toil.getWorkflowDir(info.workflowID, info.workDir)
+ workflowDirContents = os.listdir(workflowDir)
+ shutdownFileStore(workflowDir, info.workflowID)
+ if (info.cleanWorkDir == 'always'
+ or info.cleanWorkDir in ('onSuccess', 'onError')
+ and workflowDirContents in ([], [cacheDirName(info.workflowID)])):
+ shutil.rmtree(workflowDir)
+
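# --- Editor's sketch, not part of the upstream commit: constructing the named
# tuple that workerCleanup() receives. The values shown are hypothetical.
info = WorkerCleanupInfo(workDir='/tmp/toil', workflowID='toil-wf-1234',
                         cleanWorkDir='onSuccess')
assert info.workflowID == 'toil-wf-1234'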
+
+class NodeInfo(namedtuple("_NodeInfo", "cores memory workers")):
+ """
+ The cores attribute is a floating point value between 0 (all cores idle) and 1 (all cores
+ busy), reflecting the CPU load of the node.
+
+ The memory attribute is a floating point value between 0 (no memory used) and 1 (all memory
+ used), reflecting the memory pressure on the node.
+
+ The workers attribute is an integer reflecting the number of workers currently active
+ on the node.
+ """
+
+
+class AbstractScalableBatchSystem(AbstractBatchSystem):
+ """
+ A batch system that supports a variable number of worker nodes. Used by :class:`toil.
+ provisioners.clusterScaler.ClusterScaler` to scale the number of worker nodes in the cluster
+ up or down depending on overall load.
+ """
+
+ @abstractmethod
+ def getNodes(self, preemptable=None):
+ """
+ Returns a dictionary mapping node identifiers of preemptable or non-preemptable nodes to
+ NodeInfo objects, one for each node.
+
+ :param bool preemptable: If True (False) only (non-)preemptable nodes will be returned.
+ If None, all nodes will be returned.
+
+ :rtype: dict[str,NodeInfo]
+ """
+ raise NotImplementedError()
+
+
+class InsufficientSystemResources(Exception):
+ """
+ To be raised when a job requests more of a particular resource than is either currently allowed
+ or available
+ """
+ def __init__(self, resource, requested, available):
+ """
+ Creates an instance of this exception that indicates which resource is insufficient for current
+ demands, as well as the amount requested and amount actually available.
+
+ :param str resource: string representing the resource type
+
+ :param int|float requested: the amount of the particular resource requested that resulted
+ in this exception
+
+ :param int|float available: amount of the particular resource actually available
+ """
+ self.requested = requested
+ self.available = available
+ self.resource = resource
+
+ def __str__(self):
+ return 'Requesting more {} than either physically available, or enforced by --max{}. ' \
+ 'Requested: {}, Available: {}'.format(self.resource, self.resource.capitalize(),
+ self.requested, self.available)
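# --- Editor's sketch, not part of the upstream commit: what callers of
# checkResourceRequest() see when a request exceeds the configured maxima.
# The byte figures below are arbitrary illustration values.
try:
    raise InsufficientSystemResources('memory',
                                      requested=64 * 1024 ** 3,
                                      available=32 * 1024 ** 3)
except InsufficientSystemResources as e:
    assert 'Requesting more memory' in str(e)
    assert '--maxMemory' in str(e)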
diff --git a/src/toil/batchSystems/gridengine.py b/src/toil/batchSystems/gridengine.py
new file mode 100644
index 0000000..fa88cc8
--- /dev/null
+++ b/src/toil/batchSystems/gridengine.py
@@ -0,0 +1,346 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+import os
+from pipes import quote
+import subprocess
+import time
+import math
+from threading import Thread
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+from six import iteritems
+
+from toil.batchSystems import MemoryString
+from toil.batchSystems.abstractBatchSystem import BatchSystemSupport
+
+logger = logging.getLogger(__name__)
+
+sleepSeconds = 1
+
+
+class Worker(Thread):
+ def __init__(self, newJobsQueue, updatedJobsQueue, killQueue, killedJobsQueue, boss):
+ Thread.__init__(self)
+ self.newJobsQueue = newJobsQueue
+ self.updatedJobsQueue = updatedJobsQueue
+ self.killQueue = killQueue
+ self.killedJobsQueue = killedJobsQueue
+ self.waitingJobs = list()
+ self.runningJobs = set()
+ self.boss = boss
+ self.allocatedCpus = dict()
+ self.sgeJobIDs = dict()
+
+ def getRunningJobIDs(self):
+ times = {}
+ currentjobs = dict((str(self.sgeJobIDs[x][0]), x) for x in self.runningJobs)
+ process = subprocess.Popen(["qstat"], stdout=subprocess.PIPE)
+ stdout, stderr = process.communicate()
+
+ for currline in stdout.split('\n'):
+ items = currline.strip().split()
+ if items:
+ if items[0] in currentjobs and items[4] == 'r':
+ jobstart = " ".join(items[5:7])
+ jobstart = time.mktime(time.strptime(jobstart, "%m/%d/%Y %H:%M:%S"))
+ times[currentjobs[items[0]]] = time.time() - jobstart
+
+ return times
+
+ def getSgeID(self, jobID):
+ if jobID not in self.sgeJobIDs:
+ raise RuntimeError("Unknown jobID, could not be converted")
+
+ (job, task) = self.sgeJobIDs[jobID]
+ if task is None:
+ return str(job)
+ else:
+ return str(job) + "." + str(task)
+
+ def forgetJob(self, jobID):
+ self.runningJobs.remove(jobID)
+ del self.allocatedCpus[jobID]
+ del self.sgeJobIDs[jobID]
+
+ def killJobs(self):
+ # Load hit list:
+ killList = list()
+ while True:
+ try:
+ jobId = self.killQueue.get(block=False)
+ except Empty:
+ break
+ else:
+ killList.append(jobId)
+
+ if not killList:
+ return False
+
+ # Do the dirty job
+ for jobID in list(killList):
+ if jobID in self.runningJobs:
+ logger.debug('Killing job: %s', jobID)
+ subprocess.check_call(['qdel', self.getSgeID(jobID)])
+ else:
+ if jobID in self.waitingJobs:
+ self.waitingJobs.remove(jobID)
+ self.killedJobsQueue.put(jobID)
+ killList.remove(jobID)
+
+ # Wait to confirm the kill
+ while killList:
+ for jobID in list(killList):
+ if self.getJobExitCode(self.sgeJobIDs[jobID]) is not None:
+ logger.debug('Adding jobID %s to killedJobsQueue', jobID)
+ self.killedJobsQueue.put(jobID)
+ killList.remove(jobID)
+ self.forgetJob(jobID)
+ if len(killList) > 0:
+ logger.warn("Some jobs weren't killed, trying again in %is.", sleepSeconds)
+ time.sleep(sleepSeconds)
+
+ return True
+
+ def createJobs(self, newJob):
+ activity = False
+ # Load new job id if present:
+ if newJob is not None:
+ self.waitingJobs.append(newJob)
+ # Launch jobs as necessary:
+ while (len(self.waitingJobs) > 0
+ and sum(self.allocatedCpus.values()) < int(self.boss.maxCores)):
+ activity = True
+ jobID, cpu, memory, command = self.waitingJobs.pop(0)
+ qsubline = self.prepareQsub(cpu, memory, jobID) + [command]
+ sgeJobID = self.qsub(qsubline)
+ self.sgeJobIDs[jobID] = (sgeJobID, None)
+ self.runningJobs.add(jobID)
+ self.allocatedCpus[jobID] = cpu
+ return activity
+
+ def checkOnJobs(self):
+ activity = False
+ logger.debug('List of running jobs: %r', self.runningJobs)
+ for jobID in list(self.runningJobs):
+ status = self.getJobExitCode(self.sgeJobIDs[jobID])
+ if status is not None:
+ activity = True
+ self.updatedJobsQueue.put((jobID, status))
+ self.forgetJob(jobID)
+ return activity
+
+ def run(self):
+ while True:
+ activity = False
+ newJob = None
+ if not self.newJobsQueue.empty():
+ activity = True
+ newJob = self.newJobsQueue.get()
+ if newJob is None:
+ logger.debug('Received queue sentinel.')
+ break
+ activity |= self.killJobs()
+ activity |= self.createJobs(newJob)
+ activity |= self.checkOnJobs()
+ if not activity:
+ logger.debug('No activity, sleeping for %is', sleepSeconds)
+ time.sleep(sleepSeconds)
+
+ def prepareQsub(self, cpu, mem, jobID):
+ qsubline = ['qsub', '-V', '-b', 'y', '-terse', '-j', 'y', '-cwd',
+ '-N', 'toil_job_' + str(jobID)]
+
+ if self.boss.environment:
+ qsubline.append('-v')
+ qsubline.append(','.join(k + '=' + quote(os.environ[k] if v is None else v)
+ for k, v in iteritems(self.boss.environment)))
+
+ reqline = list()
+ if mem is not None:
+ memStr = str(mem / 1024) + 'K'
+ reqline += ['vf=' + memStr, 'h_vmem=' + memStr]
+ if len(reqline) > 0:
+ qsubline.extend(['-hard', '-l', ','.join(reqline)])
+ sgeArgs = os.getenv('TOIL_GRIDENGINE_ARGS')
+ if sgeArgs:
+ sgeArgs = sgeArgs.split()
+ for arg in sgeArgs:
+ if arg.startswith(("vf=", "hvmem=", "-pe")):
+ raise ValueError("Unexpected CPU, memory or pe specifications in TOIL_GRIDGENGINE_ARGs: %s" % arg)
+ qsubline.extend(sgeArgs)
+ if cpu is not None and math.ceil(cpu) > 1:
+ peConfig = os.getenv('TOIL_GRIDENGINE_PE') or 'shm'
+ qsubline.extend(['-pe', peConfig, str(int(math.ceil(cpu)))])
+ return qsubline
+
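# --- Editor's sketch, not part of the upstream commit: roughly the command list
# prepareQsub() builds for a 2-core, 4 GiB job with id 1, assuming
# TOIL_GRIDENGINE_PE=smp, no TOIL_GRIDENGINE_ARGS and no extra environment variables.
expected = ['qsub', '-V', '-b', 'y', '-terse', '-j', 'y', '-cwd',
            '-N', 'toil_job_1',
            '-hard', '-l', 'vf=4194304K,h_vmem=4194304K',
            '-pe', 'smp', '2']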
+ def qsub(self, qsubline):
+ logger.debug("Running %r", " ".join(qsubline))
+ process = subprocess.Popen(qsubline, stdout=subprocess.PIPE)
+ result = int(process.stdout.readline().strip().split('.')[0])
+ return result
+
+ def getJobExitCode(self, sgeJobID):
+ job, task = sgeJobID
+ args = ["qacct", "-j", str(job)]
+ if task is not None:
+ args.extend(["-t", str(task)])
+ logger.debug("Running %r", args)
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ for line in process.stdout:
+ if line.startswith("failed") and int(line.split()[1]) == 1:
+ return 1
+ elif line.startswith("exit_status"):
+ logger.debug('Exit Status: %r', line.split()[1])
+ return int(line.split()[1])
+ return None
+
+
+class GridengineBatchSystem(BatchSystemSupport):
+ """
+ The interface for SGE aka Sun GridEngine.
+ """
+
+ @classmethod
+ def supportsWorkerCleanup(cls):
+ return False
+
+ @classmethod
+ def supportsHotDeployment(cls):
+ return False
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk):
+ super(GridengineBatchSystem, self).__init__(config, maxCores, maxMemory, maxDisk)
+ self.gridengineResultsFile = self._getResultsFileName(config.jobStore)
+ # Reset the job queue and results (initially, we do this again once we've killed the jobs)
+ self.gridengineResultsFileHandle = open(self.gridengineResultsFile, 'w')
+ # We lose any previous state in this file, and ensure the file's existence
+ self.gridengineResultsFileHandle.close()
+ self.currentJobs = set()
+ self.maxCPU, self.maxMEM = self.obtainSystemConstants()
+ self.nextJobID = 0
+ self.newJobsQueue = Queue()
+ self.updatedJobsQueue = Queue()
+ self.killQueue = Queue()
+ self.killedJobsQueue = Queue()
+ self.worker = Worker(self.newJobsQueue, self.updatedJobsQueue, self.killQueue,
+ self.killedJobsQueue, self)
+ self.worker.start()
+
+ def __del__(self):
+ # Closes the file handle associated with the results file.
+ self.gridengineResultsFileHandle.close()
+
+ def issueBatchJob(self, jobNode):
+ self.checkResourceRequest(jobNode.memory, jobNode.cores, jobNode.disk)
+ jobID = self.nextJobID
+ self.nextJobID += 1
+ self.currentJobs.add(jobID)
+ self.newJobsQueue.put((jobID, jobNode.cores, jobNode.memory, jobNode.command))
+ logger.debug("Issued the job command: %s with job id: %s ", jobNode.command, str(jobID))
+ return jobID
+
+ def killBatchJobs(self, jobIDs):
+ """
+ Kills the given jobs, represented as Job ids, then checks they are dead by checking
+ they are not in the list of issued jobs.
+ """
+ jobIDs = set(jobIDs)
+ logger.debug('Jobs to be killed: %r', jobIDs)
+ for jobID in jobIDs:
+ self.killQueue.put(jobID)
+ while jobIDs:
+ killedJobId = self.killedJobsQueue.get()
+ if killedJobId is None:
+ break
+ jobIDs.remove(killedJobId)
+ if killedJobId in self.currentJobs:
+ self.currentJobs.remove(killedJobId)
+ if jobIDs:
+ logger.debug('Some kills (%s) still pending, sleeping %is', len(jobIDs),
+ sleepSeconds)
+ time.sleep(sleepSeconds)
+
+ def getIssuedBatchJobIDs(self):
+ """
+ Gets the list of jobs issued to SGE.
+ """
+ return list(self.currentJobs)
+
+ def getRunningBatchJobIDs(self):
+ return self.worker.getRunningJobIDs()
+
+ def getUpdatedBatchJob(self, maxWait):
+ try:
+ item = self.updatedJobsQueue.get(timeout=maxWait)
+ except Empty:
+ return None
+ logger.debug('UpdatedJobsQueue Item: %s', item)
+ jobID, retcode = item
+ self.currentJobs.remove(jobID)
+ return jobID, retcode, None
+
+ def shutdown(self):
+ """
+ Signals worker to shutdown (via sentinel) then cleanly joins the thread
+ """
+ newJobsQueue = self.newJobsQueue
+ self.newJobsQueue = None
+
+ newJobsQueue.put(None)
+ self.worker.join()
+
+ def getWaitDuration(self):
+ return 0.0
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ return 30 * 60 # Half an hour
+
+ @staticmethod
+ def obtainSystemConstants():
+ lines = filter(None, map(str.strip, subprocess.check_output(["qhost"]).split('\n')))
+ line = lines[0]
+ items = line.strip().split()
+ num_columns = len(items)
+ cpu_index = None
+ mem_index = None
+ for i in range(num_columns):
+ if items[i] == 'NCPU':
+ cpu_index = i
+ elif items[i] == 'MEMTOT':
+ mem_index = i
+ if cpu_index is None or mem_index is None:
+ raise RuntimeError('qhost command does not return NCPU or MEMTOT columns')
+ maxCPU = 0
+ maxMEM = MemoryString("0")
+ for line in lines[2:]:
+ items = line.strip().split()
+ if len(items) < num_columns:
+ raise RuntimeError('qhost output has a varying number of columns')
+ if items[cpu_index] != '-' and items[cpu_index] > maxCPU:
+ maxCPU = items[cpu_index]
+ if items[mem_index] != '-' and MemoryString(items[mem_index]) > maxMEM:
+ maxMEM = MemoryString(items[mem_index])
+ if maxCPU == 0 or maxMEM.bytes == 0:
+ raise RuntimeError('qhost returned null NCPU or MEMTOT info')
+ return maxCPU, maxMEM
+
+ def setEnv(self, name, value=None):
+ if value and ',' in value:
+ raise ValueError("GridEngine does not support commata in environment variable values")
+ return super(GridengineBatchSystem,self).setEnv(name, value)
diff --git a/src/toil/batchSystems/lsf.py b/src/toil/batchSystems/lsf.py
new file mode 100644
index 0000000..73ef671
--- /dev/null
+++ b/src/toil/batchSystems/lsf.py
@@ -0,0 +1,285 @@
+#Copyright (C) 2013 by Thomas Keane (tk2 at sanger.ac.uk)
+#
+#Permission is hereby granted, free of charge, to any person obtaining a copy
+#of this software and associated documentation files (the "Software"), to deal
+#in the Software without restriction, including without limitation the rights
+#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+#copies of the Software, and to permit persons to whom the Software is
+#furnished to do so, subject to the following conditions:
+#
+#The above copyright notice and this permission notice shall be included in
+#all copies or substantial portions of the Software.
+#
+#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+#THE SOFTWARE.
+from __future__ import absolute_import
+import logging
+import subprocess
+import time
+from threading import Thread
+from datetime import date
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+
+from toil.batchSystems import MemoryString
+from toil.batchSystems.abstractBatchSystem import BatchSystemSupport
+
+logger = logging.getLogger( __name__ )
+
+
+
+def prepareBsub(cpu, mem):
+ mem = '' if mem is None else '-R "select[type==X86_64 && mem > ' + str(int(mem/ 1000000)) + '] rusage[mem=' + str(int(mem/ 1000000)) + ']" -M' + str(int(mem/ 1000000)) + '000'
+ cpu = '' if cpu is None else '-n ' + str(int(cpu))
+ bsubline = ["bsub", mem, cpu,"-cwd", ".", "-o", "/dev/null", "-e", "/dev/null"]
+ return bsubline
+
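# --- Editor's sketch, not part of the upstream commit: prepareBsub() expresses the
# memory limit in units of bytes // 1,000,000 inside an LSF resource string.
# Assumes toil.batchSystems.lsf is importable.
from toil.batchSystems.lsf import prepareBsub

line = prepareBsub(cpu=4, mem=8 * 1024 ** 3)
assert line[0] == 'bsub'
assert line[1] == '-R "select[type==X86_64 && mem > 8589] rusage[mem=8589]" -M8589000'
assert line[2] == '-n 4'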
+def bsub(bsubline):
+ process = subprocess.Popen(" ".join(bsubline), shell=True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
+ liney = process.stdout.readline()
+ logger.debug("BSUB: " + liney)
+ result = int(liney.strip().split()[1].strip('<>'))
+ logger.debug("Got the job id: %s" % (str(result)))
+ return result
+
+def getjobexitcode(lsfJobID):
+ job, task = lsfJobID
+
+ #first try bjobs to find out job state
+ args = ["bjobs", "-l", str(job)]
+ logger.debug("Checking job exit code for job via bjobs: " + str(job))
+ process = subprocess.Popen(" ".join(args), shell=True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
+ started = 0
+ for line in process.stdout:
+ if line.find("Done successfully") > -1:
+ logger.debug("bjobs detected job completed for job: " + str(job))
+ return 0
+ elif line.find("Completed <exit>") > -1:
+ logger.debug("bjobs detected job failed for job: " + str(job))
+ return 1
+ elif line.find("New job is waiting for scheduling") > -1:
+ logger.debug("bjobs detected job pending scheduling for job: " + str(job))
+ return None
+ elif line.find("PENDING REASONS") > -1:
+ logger.debug("bjobs detected job pending for job: " + str(job))
+ return None
+ elif line.find("Started on ") > -1:
+ started = 1
+
+ if started == 1:
+ logger.debug("bjobs detected job started but not completed: " + str(job))
+ return None
+
+ #if not found in bjobs, then try bacct (slower than bjobs)
+ logger.debug("bjobs failed to detect job - trying bacct: " + str(job))
+
+ args = ["bacct", "-l", str(job)]
+ logger.debug("Checking job exit code for job via bacct:" + str(job))
+ process = subprocess.Popen(" ".join(args), shell=True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
+ for line in process.stdout:
+ if line.find("Completed <done>") > -1:
+ logger.debug("Detected job completed for job: " + str(job))
+ return 0
+ elif line.find("Completed <exit>") > -1:
+ logger.debug("Detected job failed for job: " + str(job))
+ return 1
+ logger.debug("Cant determine exit code for job or job still running: " + str(job))
+ return None
+
+class Worker(Thread):
+ def __init__(self, newJobsQueue, updatedJobsQueue, boss):
+ Thread.__init__(self)
+ self.newJobsQueue = newJobsQueue
+ self.updatedJobsQueue = updatedJobsQueue
+ self.currentjobs = list()
+ self.runningjobs = set()
+ self.boss = boss
+
+ def run(self):
+ while True:
+ # Load new job ids:
+ while not self.newJobsQueue.empty():
+ self.currentjobs.append(self.newJobsQueue.get())
+
+ # Launch jobs as necessary:
+ while len(self.currentjobs) > 0:
+ jobID, bsubline = self.currentjobs.pop()
+ lsfJobID = bsub(bsubline)
+ self.boss.jobIDs[(lsfJobID, None)] = jobID
+ self.boss.lsfJobIDs[jobID] = (lsfJobID, None)
+ self.runningjobs.add((lsfJobID, None))
+
+ # Test known job list
+ for lsfJobID in list(self.runningjobs):
+ exit = getjobexitcode(lsfJobID)
+ if exit is not None:
+ self.updatedJobsQueue.put((lsfJobID, exit))
+ self.runningjobs.remove(lsfJobID)
+
+ time.sleep(10)
+
+class LSFBatchSystem(BatchSystemSupport):
+ """
+ The interface for running jobs on lsf, runs all the jobs you give it as they come in,
+ but in parallel.
+ """
+ @classmethod
+ def supportsWorkerCleanup(cls):
+ return False
+
+ @classmethod
+ def supportsHotDeployment(cls):
+ return False
+
+ def shutdown(self):
+ pass
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk):
+ super(LSFBatchSystem, self).__init__(config, maxCores, maxMemory, maxDisk)
+ self.lsfResultsFile = self._getResultsFileName(config.jobStore)
+ #Reset the job queue and results (initially, we do this again once we've killed the jobs)
+ self.lsfResultsFileHandle = open(self.lsfResultsFile, 'w')
+ self.lsfResultsFileHandle.close() # We lose any previous state in this file, and ensure the file's existence
+ self.currentjobs = set()
+ self.obtainSystemConstants()
+ self.jobIDs = dict()
+ self.lsfJobIDs = dict()
+ self.nextJobID = 0
+
+ self.newJobsQueue = Queue()
+ self.updatedJobsQueue = Queue()
+ self.worker = Worker(self.newJobsQueue, self.updatedJobsQueue, self)
+ self.worker.setDaemon(True)
+ self.worker.start()
+
+ def __del__(self):
+ #Closes the file handle associated with the results file.
+ self.lsfResultsFileHandle.close() #Close the results file, cos were done.
+
+ def issueBatchJob(self, jobNode):
+ jobID = self.nextJobID
+ self.nextJobID += 1
+ self.currentjobs.add(jobID)
+ bsubline = prepareBsub(jobNode.cores, jobNode.memory) + [jobNode.command]
+ self.newJobsQueue.put((jobID, bsubline))
+ logger.debug("Issued the job command: %s with job id: %s " % (jobNode.command, str(jobID)))
+ return jobID
+
+ def getLsfID(self, jobID):
+ if jobID not in self.lsfJobIDs:
+ raise RuntimeError("Unknown jobID, could not be converted")
+
+ (job,task) = self.lsfJobIDs[jobID]
+ if task is None:
+ return str(job)
+ else:
+ return str(job) + "." + str(task)
+
+ def killBatchJobs(self, jobIDs):
+ """Kills the given job IDs.
+ """
+ for jobID in jobIDs:
+ logger.debug("DEL: " + str(self.getLsfID(jobID)))
+ self.currentjobs.remove(jobID)
+ process = subprocess.Popen(["bkill", self.getLsfID(jobID)])
+ del self.jobIDs[self.lsfJobIDs[jobID]]
+ del self.lsfJobIDs[jobID]
+
+ toKill = set(jobIDs)
+ while len(toKill) > 0:
+ for jobID in list(toKill):
+ if getjobexitcode(self.lsfJobIDs[jobID]) is not None:
+ toKill.remove(jobID)
+
+ if len(toKill) > 0:
+ logger.warn("Tried to kill some jobs, but something happened and they are still going, "
+ "so I'll try again")
+ time.sleep(5)
+
+ def getIssuedBatchJobIDs(self):
+ """A list of jobs (as jobIDs) currently issued (may be running, or maybe
+ just waiting).
+ """
+ return self.currentjobs
+
+ def getRunningBatchJobIDs(self):
+ """Gets a map of jobs (as jobIDs) currently running (not just waiting)
+ and a how long they have been running for (in seconds).
+ """
+ times = {}
+ currentjobs = set(self.lsfJobIDs[x] for x in self.getIssuedBatchJobIDs())
+ process = subprocess.Popen(["bjobs"], stdout = subprocess.PIPE)
+
+ for curline in process.stdout:
+ items = curline.strip().split()
+ if (len(items) > 9 and (items[0]) in currentjobs) and items[2] == 'RUN':
+ jobstart = "/".join(items[7:9]) + '/' + str(date.today().year)
+ jobstart = jobstart + ' ' + items[9]
+ jobstart = time.mktime(time.strptime(jobstart,"%b/%d/%Y %H:%M"))
+ jobstart = time.mktime(time.strptime(jobstart,"%m/%d/%Y %H:%M:%S"))
+ times[self.jobIDs[(items[0])]] = time.time() - jobstart
+ return times
+
+ def getUpdatedBatchJob(self, maxWait):
+ try:
+ sgeJobID, retcode = self.updatedJobsQueue.get(timeout=maxWait)
+ self.updatedJobsQueue.task_done()
+ jobID, retcode = (self.jobIDs[sgeJobID], retcode)
+ self.currentjobs -= {self.jobIDs[sgeJobID]}
+ except Empty:
+ pass
+ else:
+ return jobID, retcode, None
+
+ def getWaitDuration(self):
+ """We give parasol a second to catch its breath (in seconds)
+ """
+ #return 0.0
+ return 15
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ """Parasol leaks jobs, but rescuing jobs involves calls to parasol list jobs and pstat2,
+ making it expensive. We allow this every 30 minutes.
+ """
+ return 1800
+
+ def obtainSystemConstants(self):
+ p = subprocess.Popen(["lshosts"], stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
+
+ line = p.stdout.readline()
+ items = line.strip().split()
+ num_columns = len(items)
+ cpu_index = None
+ mem_index = None
+ for i in range(num_columns):
+ if items[i] == 'ncpus':
+ cpu_index = i
+ elif items[i] == 'maxmem':
+ mem_index = i
+
+ if cpu_index is None or mem_index is None:
+ raise RuntimeError("lshosts command does not return ncpus or maxmem columns")
+
+ p.stdout.readline()
+
+ self.maxCPU = 0
+ self.maxMEM = MemoryString("0")
+ for line in p.stdout:
+ items = line.strip().split()
+ if len(items) < num_columns:
+ raise RuntimeError("lshosts output has a varying number of columns")
+ if items[cpu_index] != '-' and items[cpu_index] > self.maxCPU:
+ self.maxCPU = items[cpu_index]
+ if items[mem_index] != '-' and MemoryString(items[mem_index]) > self.maxMEM:
+ self.maxMEM = MemoryString(items[mem_index])
+
+ if self.maxCPU == 0 or self.maxMEM.bytes == 0:
+ raise RuntimeError("lshosts returns null ncpus or maxmem info")
+ logger.debug("Got the maxCPU: %s and maxMEM: %s" % (self.maxCPU, self.maxMEM))
diff --git a/src/toil/batchSystems/mesos/__init__.py b/src/toil/batchSystems/mesos/__init__.py
new file mode 100644
index 0000000..4cd4707
--- /dev/null
+++ b/src/toil/batchSystems/mesos/__init__.py
@@ -0,0 +1,56 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from collections import namedtuple
+
+TaskData = namedtuple('TaskData', (
+ # Time when the task was started
+ 'startTime',
+ # Mesos' ID of the slave where task is being run
+ 'slaveID',
+ # Mesos' ID of the executor running the task
+ 'executorID'))
+
+
+class ResourceRequirement( namedtuple('_ResourceRequirement', (
+ # Number of bytes (!) needed for a task
+ 'memory',
+ # Number of CPU cores needed for a task
+ 'cores',
+ # Number of bytes (!) needed for task on disk
+ 'disk',
+ # True, if job can be run on a preemptable node, False otherwise
+ 'preemptable'))):
+ def size(self):
+ """
+ The scalar size of an offer. Can be used to compare offers.
+ """
+ return self.cores
+
+
+ToilJob = namedtuple('ToilJob', (
+ # A job ID specific to this batch system implementation
+ 'jobID',
+ # What string to display in the mesos UI
+ 'name',
+ # A ResourceRequirement tuple describing the resources needed by this job
+ 'resources',
+ # The command to be run on the worker node
+ 'command',
+ # The resource object representing the user script
+ 'userScript',
+ # A dictionary with additional environment variables to be set on the worker process
+ 'environment',
+ # A named tuple containing all the required info for cleaning up the worker node
+ 'workerCleanupInfo'))
diff --git a/src/toil/batchSystems/mesos/batchSystem.py b/src/toil/batchSystems/mesos/batchSystem.py
new file mode 100644
index 0000000..38902e3
--- /dev/null
+++ b/src/toil/batchSystems/mesos/batchSystem.py
@@ -0,0 +1,586 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import ast
+import logging
+import os
+import pickle
+import pwd
+import socket
+import time
+from collections import defaultdict
+from operator import attrgetter
+from struct import unpack
+
+import itertools
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+from six import iteritems
+
+import mesos.interface
+import mesos.native
+from bd2k.util import strict_bool
+from mesos.interface import mesos_pb2
+
+from toil import resolveEntryPoint
+from toil.batchSystems.abstractBatchSystem import (AbstractScalableBatchSystem,
+ BatchSystemSupport,
+ NodeInfo)
+from toil.batchSystems.mesos import ToilJob, ResourceRequirement, TaskData
+
+log = logging.getLogger(__name__)
+
+
+class MesosBatchSystem(BatchSystemSupport,
+ AbstractScalableBatchSystem,
+ mesos.interface.Scheduler):
+ """
+ A Toil batch system implementation that uses Apache Mesos to distribute toil jobs as Mesos
+ tasks over a cluster of slave nodes. A Mesos framework consists of a scheduler and an
+ executor. This class acts as the scheduler and is typically run on the master node that also
+ runs the Mesos master process with which the scheduler communicates via a driver component.
+ The executor is implemented in a separate class. It is run on each slave node and
+ communicates with the Mesos slave process via another driver object. The scheduler may also
+ be run on a separate node from the master, which we then call somewhat ambiguously the driver
+ node.
+ """
+
+ @classmethod
+ def supportsHotDeployment(cls):
+ return True
+
+ @classmethod
+ def supportsWorkerCleanup(cls):
+ return True
+
+ class ExecutorInfo(object):
+ def __init__(self, nodeAddress, slaveId, nodeInfo, lastSeen):
+ super(MesosBatchSystem.ExecutorInfo, self).__init__()
+ self.nodeAddress = nodeAddress
+ self.slaveId = slaveId
+ self.nodeInfo = nodeInfo
+ self.lastSeen = lastSeen
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk, masterAddress):
+ super(MesosBatchSystem, self).__init__(config, maxCores, maxMemory, maxDisk)
+
+ # The hot-deployed resource representing the user script. Will be passed along in every
+ # Mesos task. Also see setUserScript().
+ self.userScript = None
+ """
+ :type: toil.resource.Resource
+ """
+
+ # Dictionary of queues, which toil assigns jobs to. Each queue represents a job type,
+ # defined by resource usage
+ self.jobQueues = defaultdict(list)
+
+ # Address of the Mesos master in the form host:port where host can be an IP or a hostname
+ self.masterAddress = masterAddress
+
+ # Written to when Mesos kills tasks, as directed by Toil
+ self.killedJobIds = set()
+
+ # The IDs of job to be killed
+ self.killJobIds = set()
+
+ # Contains jobs on which killBatchJobs were called, regardless of whether or not they
+ # actually were killed or ended by themselves
+ self.intendedKill = set()
+
+ # Dict of launched jobIDs to TaskData objects
+ self.runningJobMap = {}
+
+ # Queue of jobs whose status has been updated, according to Mesos
+ self.updatedJobsQueue = Queue()
+
+ # The Mesos driver used by this scheduler
+ self.driver = None
+
+ # A dictionary mapping a node's IP to an ExecutorInfo object describing important
+ # properties of our executor running on that node. Only an approximation of the truth.
+ self.executors = {}
+
+ # A set of Mesos slave IDs, one for each slave running on a non-preemptable node. Only an
+ # approximation of the truth. Recently launched nodes may be absent from this set for a
+ # while and a node's absence from this set does not imply its preemptability. But it is
+ # generally safer to assume a node is preemptable since non-preemptability is a stronger
+ # requirement. If we tracked the set of preemptable nodes instead, we'd have to use
+ # absence as an indicator of non-preemptability and could therefore be misled into
+ # believing that a recently launched preemptable node was non-preemptable.
+ self.nonPreemptibleNodes = set()
+
+ self.executor = self._buildExecutor()
+
+ self.unusedJobID = itertools.count()
+ self.lastReconciliation = time.time()
+ self.reconciliationPeriod = 120
+
+ # These control how frequently to log a message that would indicate if no jobs are
+ # currently able to run on the offers given. This can happen if the cluster is busy
+ # or if the nodes in the cluster simply don't have enough resources to run the jobs
+ self.lastTimeOfferLogged = 0
+ self.logPeriod = 30 # seconds
+
+ self._startDriver()
+
+ def setUserScript(self, userScript):
+ self.userScript = userScript
+
+ def issueBatchJob(self, jobNode):
+ """
+ Issues the given command, returning a unique jobID. The command is the string to run;
+ memory is an int giving the number of bytes the job needs to run in; cores is the number
+ of CPUs needed for the job; and error-file is the path of the file to place any
+ std-err/std-out in.
+ """
+ self.checkResourceRequest(jobNode.memory, jobNode.cores, jobNode.disk)
+ jobID = next(self.unusedJobID)
+ job = ToilJob(jobID=jobID,
+ name=str(jobNode),
+ resources=ResourceRequirement(**jobNode._requirements),
+ command=jobNode.command,
+ userScript=self.userScript,
+ environment=self.environment.copy(),
+ workerCleanupInfo=self.workerCleanupInfo)
+ jobType = job.resources
+ log.debug("Queueing the job command: %s with job id: %s ...", jobNode.command, str(jobID))
+ self.jobQueues[jobType].append(job)
+ log.debug("... queued")
+ return jobID
+
+ def killBatchJobs(self, jobIDs):
+ # FIXME: probably still racy
+ assert self.driver is not None
+ localSet = set()
+ for jobID in jobIDs:
+ self.killJobIds.add(jobID)
+ localSet.add(jobID)
+ self.intendedKill.add(jobID)
+ # FIXME: a bit too expensive for my taste
+ if jobID in self.getIssuedBatchJobIDs():
+ taskId = mesos_pb2.TaskID()
+ taskId.value = str(jobID)
+ self.driver.killTask(taskId)
+ else:
+ self.killJobIds.remove(jobID)
+ localSet.remove(jobID)
+ while localSet:
+ intersection = localSet.intersection(self.killedJobIds)
+ if intersection:
+ localSet -= intersection
+ self.killedJobIds -= intersection
+ else:
+ time.sleep(1)
+
+ def getIssuedBatchJobIDs(self):
+ jobIds = set()
+ for queue in self.jobQueues.values():
+ for job in queue:
+ jobIds.add(job.jobID)
+ jobIds.update(self.runningJobMap.keys())
+ return list(jobIds)
+
+ def getRunningBatchJobIDs(self):
+ currentTime = dict()
+ for jobID, data in self.runningJobMap.items():
+ currentTime[jobID] = time.time() - data.startTime
+ return currentTime
+
+ def getUpdatedBatchJob(self, maxWait):
+ while True:
+ try:
+ item = self.updatedJobsQueue.get(timeout=maxWait)
+ except Empty:
+ return None
+ jobId, exitValue, wallTime = item
+ try:
+ self.intendedKill.remove(jobId)
+ except KeyError:
+ log.debug('Job %s ended with status %i, took %s seconds.', jobId, exitValue,
+ '???' if wallTime is None else str(wallTime))
+ return item
+ else:
+ log.debug('Job %s ended naturally before it could be killed.', jobId)
+
+ def getWaitDuration(self):
+ """
+ Gets the period of time to wait (floating point, in seconds) between checking for
+ missing/overlong jobs.
+ """
+ return self.reconciliationPeriod
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ return 30 * 60 # Half an hour
+
+ def _buildExecutor(self):
+ """
+ Creates and returns an ExecutorInfo instance representing our executor implementation.
+ """
+ # The executor program is installed as a setuptools entry point by setup.py
+ info = mesos_pb2.ExecutorInfo()
+ info.name = "toil"
+ info.command.value = resolveEntryPoint('_toil_mesos_executor')
+ info.executor_id.value = "toil-%i" % os.getpid()
+ info.source = pwd.getpwuid(os.getuid()).pw_name
+ return info
+
+ def _startDriver(self):
+ """
+ Starts the Mesos driver, which handles the scheduler's communication with the Mesos master
+ """
+ framework = mesos_pb2.FrameworkInfo()
+ framework.user = "" # Have Mesos fill in the current user.
+ framework.name = "toil"
+ framework.principal = framework.name
+ self.driver = mesos.native.MesosSchedulerDriver(self,
+ framework,
+ self._resolveAddress(self.masterAddress),
+ True) # enable implicit acknowledgements
+ assert self.driver.start() == mesos_pb2.DRIVER_RUNNING
+
+ @staticmethod
+ def _resolveAddress(address):
+ """
+ Resolves the host in the given string. The input is of the form host[:port]. This method
+ is idempotent, i.e. the host may already be a dotted IP address.
+
+ >>> # noinspection PyProtectedMember
+ >>> f=MesosBatchSystem._resolveAddress
+ >>> f('localhost')
+ '127.0.0.1'
+ >>> f('127.0.0.1')
+ '127.0.0.1'
+ >>> f('localhost:123')
+ '127.0.0.1:123'
+ >>> f('127.0.0.1:123')
+ '127.0.0.1:123'
+ """
+ address = address.split(':')
+ assert len(address) in (1, 2)
+ address[0] = socket.gethostbyname(address[0])
+ return ':'.join(address)
+
+ def shutdown(self):
+ log.debug("Stopping Mesos driver")
+ self.driver.stop()
+ log.debug("Joining Mesos driver")
+ driver_result = self.driver.join()
+ log.debug("Joined Mesos driver")
+ if driver_result != mesos_pb2.DRIVER_STOPPED:
+ raise RuntimeError("Mesos driver failed with %i", driver_result)
+
+ def registered(self, driver, frameworkId, masterInfo):
+ """
+ Invoked when the scheduler successfully registers with a Mesos master
+ """
+ log.debug("Registered with framework ID %s", frameworkId.value)
+
+ def _sortJobsByResourceReq(self):
+ jobTypes = self.jobQueues.keys()
+ # The dominant criterion is the preemptability of jobs. Non-preemptable (NP) jobs should be
+ # considered first because they can only be run on NP nodes while P jobs can run on
+ # both. Without this prioritization of NP jobs, P jobs could steal NP cores from NP jobs,
+ # leaving subsequently offered P cores unused. Despite the prioritization of NP jobs,
+ # NP jobs cannot steal P cores from P jobs, simply because the offer-acceptance logic
+ # would not accept a P offer with an NP job.
+ jobTypes.sort(key=attrgetter('preemptable', 'size'))
+ jobTypes.reverse()
+ return jobTypes
+
+ def _declineAllOffers(self, driver, offers):
+ for offer in offers:
+ log.debug("Declining offer %s.", offer.id.value)
+ driver.declineOffer(offer.id)
+
+ def _parseOffer(self, offer):
+ cores = 0
+ memory = 0
+ disk = 0
+ preemptable = None
+ for attribute in offer.attributes:
+ if attribute.name == 'preemptable':
+ assert preemptable is None, "Attribute 'preemptable' occurs more than once."
+ preemptable = strict_bool(attribute.text.value)
+ if preemptable is None:
+ log.warn('Slave not marked as either preemptable or not. Assuming non-preemptable.')
+ preemptable = False
+ for resource in offer.resources:
+ if resource.name == "cpus":
+ cores += resource.scalar.value
+ elif resource.name == "mem":
+ memory += resource.scalar.value
+ elif resource.name == "disk":
+ disk += resource.scalar.value
+ return cores, memory, disk, preemptable
+
+ def _prepareToRun(self, jobType, offer, index):
+ # Get the first element to ensure FIFO ordering
+ job = self.jobQueues[jobType][index]
+ task = self._newMesosTask(job, offer)
+ return task
+
+ def _deleteByJobID(self, jobID):
+ # FIXME: Surely there must be a more efficient way to do this
+ for jobType in self.jobQueues.values():
+ for job in jobType:
+ if jobID == job.jobID:
+ jobType.remove(job)
+
+ def _updateStateToRunning(self, offer, task):
+ self.runningJobMap[int(task.task_id.value)] = TaskData(startTime=time.time(),
+ slaveID=offer.slave_id,
+ executorID=task.executor.executor_id)
+ self._deleteByJobID(int(task.task_id.value))
+
+ def resourceOffers(self, driver, offers):
+ """
+ Invoked when resources have been offered to this framework.
+ """
+ self._trackOfferedNodes(offers)
+
+ jobTypes = self._sortJobsByResourceReq()
+
+ # TODO: We may want to assert that numIssued >= numRunning
+ if not jobTypes or len(self.getIssuedBatchJobIDs()) == len(self.getRunningBatchJobIDs()):
+ log.debug('There are no queued tasks. Declining Mesos offers.')
+ # Without queued jobs, we would get stuck with no new offers until we decline the current ones.
+ self._declineAllOffers(driver, offers)
+ return
+
+ unableToRun = True
+ # Right now, gives priority to largest jobs
+ for offer in offers:
+ runnableTasks = []
+ # TODO: In an offer, can there ever be more than one resource with the same name?
+ offerCores, offerMemory, offerDisk, offerPreemptable = self._parseOffer(offer)
+ log.debug('Got offer %s for a %spreemptable slave with %.2f MiB memory, %.2f core(s) '
+ 'and %.2f MiB of disk.', offer.id.value, '' if offerPreemptable else 'non-',
+ offerMemory, offerCores, offerDisk)
+ remainingCores = offerCores
+ remainingMemory = offerMemory
+ remainingDisk = offerDisk
+
+ for jobType in jobTypes:
+ runnableTasksOfType = []
+ # Because we are not removing from the list until outside of the while loop, we
+ # must decrement the number of jobs left to run ourselves to avoid an infinite
+ # loop.
+ nextToLaunchIndex = 0
+ # Toil specifies disk and memory in bytes but Mesos uses MiB
+ while (len(self.jobQueues[jobType]) - nextToLaunchIndex > 0
+ # On a non-preemptable node we can run any job, on a preemptable node we
+ # can only run preemptable jobs:
+ and (not offerPreemptable or jobType.preemptable)
+ and remainingCores >= jobType.cores
+ and remainingDisk >= toMiB(jobType.disk)
+ and remainingMemory >= toMiB(jobType.memory)):
+ task = self._prepareToRun(jobType, offer, nextToLaunchIndex)
+ # TODO: this used to be a conditional but Hannes wanted it changed to an assert
+ # TODO: ... so we can understand why it exists.
+ assert int(task.task_id.value) not in self.runningJobMap
+ runnableTasksOfType.append(task)
+ log.debug("Preparing to launch Mesos task %s using offer %s ...",
+ task.task_id.value, offer.id.value)
+ remainingCores -= jobType.cores
+ remainingMemory -= toMiB(jobType.memory)
+ remainingDisk -= toMiB(jobType.disk)
+ nextToLaunchIndex += 1
+ if self.jobQueues[jobType] and not runnableTasksOfType:
+ log.debug('Offer %(offer)s not suitable to run the tasks with requirements '
+ '%(requirements)r. Mesos offered %(memory)s memory, %(cores)s cores '
+ 'and %(disk)s of disk on a %(non)spreemptable slave.',
+ dict(offer=offer.id.value,
+ requirements=jobType,
+ non='' if offerPreemptable else 'non-',
+ memory=fromMiB(offerMemory),
+ cores=offerCores,
+ disk=fromMiB(offerDisk)))
+ runnableTasks.extend(runnableTasksOfType)
+ # Launch all runnable tasks together so we only call launchTasks once per offer
+ if runnableTasks:
+ unableToRun = False
+ driver.launchTasks(offer.id, runnableTasks)
+ for task in runnableTasks:
+ self._updateStateToRunning(offer, task)
+ log.debug('Launched Mesos task %s.', task.task_id.value)
+ else:
+ log.debug('Although there are queued jobs, none of them could be run with offer %s '
+ 'extended to the framework.', offer.id)
+ driver.declineOffer(offer.id)
+
+ if unableToRun and time.time() > (self.lastTimeOfferLogged + self.logPeriod):
+ self.lastTimeOfferLogged = time.time()
+ log.debug('Although there are queued jobs, none of them were able to run in '
+ 'any of the offers extended to the framework. There are currently '
+ '%i jobs running. Enable debug level logging to see more details about '
+ 'job types and offers received.', len(self.runningJobMap))
+
+ def _trackOfferedNodes(self, offers):
+ for offer in offers:
+ nodeAddress = socket.gethostbyname(offer.hostname)
+ self._registerNode(nodeAddress, offer.slave_id.value)
+ preemptable = False
+ for attribute in offer.attributes:
+ if attribute.name == 'preemptable':
+ preemptable = strict_bool(attribute.text.value)
+ if preemptable:
+ try:
+ self.nonPreemptibleNodes.remove(offer.slave_id.value)
+ except KeyError:
+ pass
+ else:
+ self.nonPreemptibleNodes.add(offer.slave_id.value)
+
+ def _newMesosTask(self, job, offer):
+ """
+ Build the Mesos task object for the given Toil job and Mesos offer
+ """
+ task = mesos_pb2.TaskInfo()
+ task.task_id.value = str(job.jobID)
+ task.slave_id.value = offer.slave_id.value
+ # FIXME: what bout
+ task.name = job.name
+ task.data = pickle.dumps(job)
+ task.executor.MergeFrom(self.executor)
+
+ cpus = task.resources.add()
+ cpus.name = "cpus"
+ cpus.type = mesos_pb2.Value.SCALAR
+ cpus.scalar.value = job.resources.cores
+
+ disk = task.resources.add()
+ disk.name = "disk"
+ disk.type = mesos_pb2.Value.SCALAR
+ if toMiB(job.resources.disk) > 1:
+ disk.scalar.value = toMiB(job.resources.disk)
+ else:
+ log.warning("Job %s uses less disk than Mesos requires. Rounding %s up to 1 MiB.",
+ job.jobID, job.resources.disk)
+ disk.scalar.value = 1
+ mem = task.resources.add()
+ mem.name = "mem"
+ mem.type = mesos_pb2.Value.SCALAR
+ if toMiB(job.resources.memory) > 1:
+ mem.scalar.value = toMiB(job.resources.memory)
+ else:
+ log.warning("Job %s uses less memory than Mesos requires. Rounding %s up to 1 MiB.",
+ job.jobID, job.resources.memory)
+ mem.scalar.value = 1
+ return task
+
+ def statusUpdate(self, driver, update):
+ """
+ Invoked when the status of a task has changed (e.g., a slave is lost and so the task is
+ lost, a task finishes and an executor sends a status update saying so, etc). Note that
+ returning from this callback _acknowledges_ receipt of this status update! If for
+ whatever reason the scheduler aborts during this callback (or the process exits) another
+ status update will be delivered (note, however, that this is currently not true if the
+ slave sending the status update is lost/fails during that time).
+ """
+ jobID = int(update.task_id.value)
+ stateName = mesos_pb2.TaskState.Name(update.state)
+ log.debug("Job %i is in state '%s'.", jobID, stateName)
+
+ def jobEnded(_exitStatus, wallTime=None):
+ try:
+ self.killJobIds.remove(jobID)
+ except KeyError:
+ pass
+ else:
+ self.killedJobIds.add(jobID)
+ self.updatedJobsQueue.put((jobID, _exitStatus, wallTime))
+ try:
+ del self.runningJobMap[jobID]
+ except KeyError:
+ log.warning("Job %i returned exit code %i but isn't tracked as running.",
+ jobID, _exitStatus)
+
+ if update.state == mesos_pb2.TASK_FINISHED:
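+ # The executor packs the job's wall time into update.data as a double; see sendUpdate() in the executor.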
+ jobEnded(0, wallTime=unpack('d', update.data)[0])
+ elif update.state == mesos_pb2.TASK_FAILED:
+ try:
+ exitStatus = int(update.message)
+ except ValueError:
+ exitStatus = 255
+ log.warning("Job %i failed with message '%s'", jobID, update.message)
+ else:
+ log.warning('Job %i failed with exit status %i', jobID, exitStatus)
+ jobEnded(exitStatus)
+ elif update.state in (mesos_pb2.TASK_LOST, mesos_pb2.TASK_KILLED, mesos_pb2.TASK_ERROR):
+ log.warning("Job %i is in unexpected state %s with message '%s'.",
+ jobID, stateName, update.message)
+ jobEnded(255)
+
+ def frameworkMessage(self, driver, executorId, slaveId, message):
+ """
+ Invoked when an executor sends a message.
+ """
+ log.debug('Got framework message from executor %s running on slave %s: %s',
+ executorId.value, slaveId.value, message)
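+ # The executor sends repr() of a dict-like Expando (see _sendFrameworkMessage in the executor),
+ # so it can be parsed back safely with ast.literal_eval.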
+ message = ast.literal_eval(message)
+ assert isinstance(message, dict)
+ # Handle the mandatory fields of a message
+ nodeAddress = message.pop('address')
+ executor = self._registerNode(nodeAddress, slaveId.value)
+ # Handle optional message fields
+ for k, v in iteritems(message):
+ if k == 'nodeInfo':
+ assert isinstance(v, dict)
+ executor.nodeInfo = NodeInfo(**v)
+ self.executors[nodeAddress] = executor
+ else:
+ raise RuntimeError("Unknown message field '%s'." % k)
+
+ def _registerNode(self, nodeAddress, slaveId):
+ executor = self.executors.get(nodeAddress)
+ if executor is None or executor.slaveId != slaveId:
+ executor = self.ExecutorInfo(nodeAddress=nodeAddress,
+ slaveId=slaveId,
+ nodeInfo=None,
+ lastSeen=time.time())
+ self.executors[nodeAddress] = executor
+ else:
+ executor.lastSeen = time.time()
+ return executor
+
+ def getNodes(self, preemptable=None):
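+ # Executors not seen within the last 600 seconds (10 minutes) are treated as stale and excluded.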
+ return {nodeAddress: executor.nodeInfo
+ for nodeAddress, executor in iteritems(self.executors)
+ if time.time() - executor.lastSeen < 600
+ and (preemptable is None
+ or preemptable == (executor.slaveId not in self.nonPreemptibleNodes))}
+
+ def reregistered(self, driver, masterInfo):
+ """
+ Invoked when the scheduler re-registers with a newly elected Mesos master.
+ """
+ log.debug('Registered with new master')
+
+ def executorLost(self, driver, executorId, slaveId, status):
+ """
+ Invoked when an executor has exited/terminated.
+ """
+ log.warning("Executor '%s' lost.", executorId)
+
+
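+# Toil expresses memory and disk in bytes while Mesos expects MiB; these helpers convert between the two units.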
+def toMiB(n):
+ return n / 1024 / 1024
+
+
+def fromMiB(n):
+ return n * 1024 * 1024
diff --git a/src/toil/batchSystems/mesos/conftest.py b/src/toil/batchSystems/mesos/conftest.py
new file mode 100644
index 0000000..657c64f
--- /dev/null
+++ b/src/toil/batchSystems/mesos/conftest.py
@@ -0,0 +1,23 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://pytest.org/latest/example/pythoncollection.html
+
+collect_ignore = []
+
+try:
+ import mesos.interface
+except ImportError:
+ collect_ignore.append("batchSystem.py")
+ collect_ignore.append("executor.py")
diff --git a/src/toil/batchSystems/mesos/executor.py b/src/toil/batchSystems/mesos/executor.py
new file mode 100644
index 0000000..761531f
--- /dev/null
+++ b/src/toil/batchSystems/mesos/executor.py
@@ -0,0 +1,198 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os
+import random
+import socket
+import sys
+import threading
+import pickle
+import logging
+import subprocess
+import traceback
+from time import sleep, time
+
+import psutil
+import mesos.interface
+from bd2k.util.expando import Expando
+from mesos.interface import mesos_pb2
+import mesos.native
+from struct import pack
+from toil.batchSystems.abstractBatchSystem import BatchSystemSupport
+from toil.resource import Resource
+
+log = logging.getLogger(__name__)
+
+
+class MesosExecutor(mesos.interface.Executor):
+ """
+ Part of Toil's Mesos framework; it runs on a Mesos slave. A Toil job is passed to it via the
+ task.data field and launched via call(toil.command).
+ """
+
+ def __init__(self):
+ super(MesosExecutor, self).__init__()
+ self.popenLock = threading.Lock()
+ self.runningTasks = {}
+ self.workerCleanupInfo = None
+ Resource.prepareSystem()
+ self.address = None
+ # Setting this value at this point will ensure that the toil workflow directory will go to
+ # the mesos sandbox if the user hasn't specified --workDir on the command line.
+ if not os.getenv('TOIL_WORKDIR'):
+ os.environ['TOIL_WORKDIR'] = os.getcwd()
+
+ def registered(self, driver, executorInfo, frameworkInfo, slaveInfo):
+ """
+ Invoked once the executor driver has been able to successfully connect with Mesos.
+ """
+ log.debug("Registered with framework")
+ self.address = socket.gethostbyname(slaveInfo.hostname)
+ nodeInfoThread = threading.Thread(target=self._sendFrameworkMessage, args=[driver])
+ nodeInfoThread.daemon = True
+ nodeInfoThread.start()
+
+ def reregistered(self, driver, slaveInfo):
+ """
+ Invoked when the executor re-registers with a restarted slave.
+ """
+ log.debug("Re-registered")
+
+ def disconnected(self, driver):
+ """
+ Invoked when the executor becomes "disconnected" from the slave (e.g., the slave is being
+ restarted due to an upgrade).
+ """
+ log.critical("Disconnected from slave")
+
+ def killTask(self, driver, taskId):
+ try:
+ pid = self.runningTasks[taskId]
+ except KeyError:
+ pass
+ else:
+ os.kill(pid, 9)
+
+ def shutdown(self, driver):
+ log.critical('Shutting down executor ...')
+ for taskId in self.runningTasks.keys():
+ self.killTask(driver, taskId)
+ Resource.cleanSystem()
+ BatchSystemSupport.workerCleanup(self.workerCleanupInfo)
+ log.critical('... executor shut down.')
+
+ def error(self, driver, message):
+ """
+ Invoked when a fatal error has occurred with the executor and/or executor driver.
+ """
+ log.critical("FATAL ERROR: " + message)
+
+ def _sendFrameworkMessage(self, driver):
+ message = None
+ while True:
+ # The psutil documentation recommends that we ignore the value returned by the first
+ # invocation of cpu_percent(). However, we do want to send a sign of life early after
+ # starting (e.g. to unblock the provisioner waiting for an instance to come up) so
+ # the first message we send omits the load info.
+ if message is None:
+ message = Expando(address=self.address)
+ psutil.cpu_percent()
+ else:
+ message.nodeInfo = dict(cores=float(psutil.cpu_percent()) * .01,
+ memory=float(psutil.virtual_memory().percent) * .01,
+ workers=len(self.runningTasks))
+ driver.sendFrameworkMessage(repr(message))
+ # Prevent workers launched together from repeatedly hitting the leader at the same time
+ sleep(random.randint(45, 75))
+
+ def launchTask(self, driver, task):
+ """
+ Invoked by SchedulerDriver when a Mesos task should be launched by this executor
+ """
+
+ def runTask():
+ log.debug("Running task %s", task.task_id.value)
+ sendUpdate(mesos_pb2.TASK_RUNNING)
+ # This is where task.data is first used. Use this point to set up cleanupInfo.
+ taskData = pickle.loads(task.data)
+ if self.workerCleanupInfo is not None:
+ assert self.workerCleanupInfo == taskData.workerCleanupInfo
+ else:
+ self.workerCleanupInfo = taskData.workerCleanupInfo
+ startTime = time()
+ try:
+ popen = runJob(taskData)
+ self.runningTasks[task.task_id.value] = popen.pid
+ try:
+ exitStatus = popen.wait()
+ wallTime = time() - startTime
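+ # popen.wait() returns the negated signal number when the child is killed by a signal,
+ # so -9 indicates SIGKILL, as sent by killTask() above.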
+ if 0 == exitStatus:
+ sendUpdate(mesos_pb2.TASK_FINISHED, wallTime)
+ elif -9 == exitStatus:
+ sendUpdate(mesos_pb2.TASK_KILLED, wallTime)
+ else:
+ sendUpdate(mesos_pb2.TASK_FAILED, wallTime, message=str(exitStatus))
+ finally:
+ del self.runningTasks[task.task_id.value]
+ except:
+ wallTime = time() - startTime
+ exc_info = sys.exc_info()
+ log.error('Exception while running task:', exc_info=exc_info)
+ exc_type, exc_value, exc_trace = exc_info
+ sendUpdate(mesos_pb2.TASK_FAILED, wallTime,
+ message=''.join(traceback.format_exception_only(exc_type, exc_value)))
+
+ def runJob(job):
+ """
+ :type job: toil.batchSystems.mesos.ToilJob
+
+ :rtype: subprocess.Popen
+ """
+ if job.userScript:
+ job.userScript.register()
+ log.debug("Invoking command: '%s'", job.command)
+ with self.popenLock:
+ return subprocess.Popen(job.command,
+ shell=True, env=dict(os.environ, **job.environment))
+
+ def sendUpdate(taskState, wallTime=None, message=''):
+ log.debug('Sending task status update ...')
+ status = mesos_pb2.TaskStatus()
+ status.task_id.value = task.task_id.value
+ status.message = message
+ status.state = taskState
+ if wallTime is not None:
+ status.data = pack('d', wallTime)
+ driver.sendStatusUpdate(status)
+ log.debug('... done sending task status update.')
+
+ thread = threading.Thread(target=runTask)
+ thread.start()
+
+ def frameworkMessage(self, driver, message):
+ """
+ Invoked when a framework message has arrived for this executor.
+ """
+ log.debug("Received message from framework: {}".format(message))
+
+
+def main(executorClass=MesosExecutor):
+ logging.basicConfig(level=logging.DEBUG)
+ log.debug("Starting executor")
+ executor = executorClass()
+ driver = mesos.native.MesosExecutorDriver(executor)
+ exit_value = 0 if driver.run() == mesos_pb2.DRIVER_STOPPED else 1
+ assert len(executor.runningTasks) == 0
+ sys.exit(exit_value)
diff --git a/src/toil/batchSystems/mesos/test/__init__.py b/src/toil/batchSystems/mesos/test/__init__.py
new file mode 100644
index 0000000..f5b50fc
--- /dev/null
+++ b/src/toil/batchSystems/mesos/test/__init__.py
@@ -0,0 +1,84 @@
+from __future__ import absolute_import
+from abc import ABCMeta, abstractmethod
+import logging
+import shutil
+import threading
+import subprocess
+import multiprocessing
+
+from bd2k.util.processes import which
+from bd2k.util.threading import ExceptionalThread
+
+log = logging.getLogger(__name__)
+
+
+class MesosTestSupport(object):
+ """
+ A mixin for test cases that need a running Mesos master and slave on the local host
+ """
+
+ def _startMesos(self, numCores=None):
+ if numCores is None:
+ numCores = multiprocessing.cpu_count()
+ shutil.rmtree('/tmp/mesos', ignore_errors=True)
+ self.master = self.MesosMasterThread(numCores)
+ self.master.start()
+ self.slave = self.MesosSlaveThread(numCores)
+ self.slave.start()
+
+ def _stopMesos(self):
+ self.slave.popen.kill()
+ self.slave.join()
+ self.master.popen.kill()
+ self.master.join()
+
+ class MesosThread(ExceptionalThread):
+ __metaclass__ = ABCMeta
+
+ # Lock is used because subprocess is NOT thread safe: http://tinyurl.com/pkp5pgq
+ lock = threading.Lock()
+
+ def __init__(self, numCores):
+ threading.Thread.__init__(self)
+ self.numCores = numCores
+ with self.lock:
+ self.popen = subprocess.Popen(self.mesosCommand())
+
+ @abstractmethod
+ def mesosCommand(self):
+ raise NotImplementedError
+
+ def tryRun(self):
+ self.popen.wait()
+ log.info('Exiting %s', self.__class__.__name__)
+
+ def findMesosBinary(self, name):
+ try:
+ return next(which(name))
+ except StopIteration:
+ try:
+ # Special case for users of PyCharm on OS X. This is where Homebrew installs
+ # it. It's hard to set PATH for PyCharm (or any GUI app) on OS X so let's
+ # make it easy for those poor souls.
+ return next(which(name, path=['/usr/local/sbin']))
+ except StopIteration:
+ raise RuntimeError("Cannot find the '%s' binary. Make sure Mesos is installed "
+ "and it's 'bin' directory is present on the PATH." % name)
+
+ class MesosMasterThread(MesosThread):
+ def mesosCommand(self):
+ return [self.findMesosBinary('mesos-master'),
+ '--registry=in_memory',
+ '--ip=127.0.0.1',
+ '--port=5050',
+ '--allocation_interval=500ms']
+
+ class MesosSlaveThread(MesosThread):
+ def mesosCommand(self):
+ # NB: The --resources parameter forces this test to use a predictable number of
+ # cores, independent of how many cores the system running the test actually has.
+ return [self.findMesosBinary('mesos-slave'),
+ '--ip=127.0.0.1',
+ '--master=127.0.0.1:5050',
+ '--attributes=preemptable:False',
+ '--resources=cpus(*):%i' % self.numCores]
diff --git a/src/toil/batchSystems/parasol.py b/src/toil/batchSystems/parasol.py
new file mode 100644
index 0000000..d38e02f
--- /dev/null
+++ b/src/toil/batchSystems/parasol.py
@@ -0,0 +1,372 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+import os
+import re
+import sys
+import subprocess
+import tempfile
+import time
+from threading import Thread
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+from six import itervalues
+
+from bd2k.util.iterables import concat
+from bd2k.util.processes import which
+
+from toil.batchSystems.abstractBatchSystem import BatchSystemSupport
+from toil.lib.bioio import getTempFile
+
+logger = logging.getLogger(__name__)
+
+
+class ParasolBatchSystem(BatchSystemSupport):
+ """
+ The interface for Parasol.
+ """
+
+ @classmethod
+ def supportsWorkerCleanup(cls):
+ return False
+
+ @classmethod
+ def supportsHotDeployment(cls):
+ return False
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk):
+ super(ParasolBatchSystem, self).__init__(config, maxCores, maxMemory, maxDisk)
+ if maxMemory != sys.maxint:
+ logger.warn('The Parasol batch system does not support maxMemory.')
+ # Keep the name of the results file for the pstat2 command.
+ command = config.parasolCommand
+ if os.path.sep not in command:
+ try:
+ command = next(which(command))
+ except StopIteration:
+ raise RuntimeError("Can't find %s on PATH." % command)
+ logger.debug('Using Parasol at %s', command)
+ self.parasolCommand = command
+ self.parasolResultsDir = tempfile.mkdtemp(dir=config.jobStore)
+
+ # In Parasol, each results file corresponds to a separate batch, and all jobs in a batch
+ # have the same cpu and memory requirements. The keys to this dictionary are the (cpu,
+ # memory) tuples for each batch. A new batch is created whenever a job has a new unique
+ # combination of cpu and memory requirements.
+ self.resultsFiles = dict()
+ self.maxBatches = config.parasolMaxBatches
+
+ # Allows the worker process to send back the IDs of jobs that have finished, so the batch
+ # system can decrease its used cpus counter
+ self.cpuUsageQueue = Queue()
+
+ # Also stores finished job IDs, but is read by getUpdatedBatchJob().
+ self.updatedJobsQueue = Queue()
+
+ # Use this to stop the worker when shutting down
+ self.running = True
+
+ self.worker = Thread(target=self.updatedJobWorker, args=())
+ self.worker.start()
+ self.usedCpus = 0
+ self.jobIDsToCpu = {}
+
+ # Set of jobs that have been issued but aren't known to have finished or been killed yet.
+ # Jobs that end by themselves are removed in getUpdatedBatchJob, and jobs that are killed
+ # are removed in killBatchJobs.
+ self.runningJobs = set()
+
+ def _runParasol(self, command, autoRetry=True):
+ """
+ Issues a parasol command using popen to capture the output. If the command fails and
+ autoRetry is set, it is retried, with a delay between attempts, until it succeeds. The
+ returned exit value reflects the final attempt.
+ """
+ command = list(concat(self.parasolCommand, command))
+ while True:
+ logger.debug('Running %r', command)
+ process = subprocess.Popen(command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ bufsize=-1)
+ stdout, stderr = process.communicate()
+ status = process.wait()
+ for line in stderr.split('\n'):
+ if line: logger.warn(line)
+ if status == 0:
+ return 0, stdout.split('\n')
+ message = 'Command %r failed with exit status %i' % (command, status)
+ if autoRetry:
+ logger.warn(message)
+ else:
+ logger.error(message)
+ return status, None
+ logger.warn('Waiting 10s before trying again')
+ time.sleep(10)
+
+ parasolOutputPattern = re.compile("your job ([0-9]+).*")
+
+ def issueBatchJob(self, jobNode):
+ """
+ Issues a job command to parasol.
+ """
+ self.checkResourceRequest(jobNode.memory, jobNode.cores, jobNode.disk)
+
+ MiB = 1 << 20
+ truncatedMemory = (jobNode.memory / MiB) * MiB
+ # Look for a batch for jobs with these resource requirements, with
+ # the memory rounded down to the nearest megabyte. Rounding down
+ # means the new job can't ever decrease the memory requirements
+ # of jobs already in the batch.
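+ # For example, a request of 2,500,000 bytes is truncated to 2 MiB (2,097,152 bytes).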
+ if len(self.resultsFiles) >= self.maxBatches:
+ raise RuntimeError('Number of batches reached limit of %i' % self.maxBatches)
+ try:
+ results = self.resultsFiles[(truncatedMemory, jobNode.cores)]
+ except KeyError:
+ results = getTempFile(rootDir=self.parasolResultsDir)
+ self.resultsFiles[(truncatedMemory, jobNode.cores)] = results
+
+ # Prefix the command with environment overrides, optionally looking them up from the
+ # current environment if the value is None
+ command = ' '.join(concat('env', self.__environment(), jobNode.command))
+ parasolCommand = ['-verbose',
+ '-ram=%i' % jobNode.memory,
+ '-cpu=%i' % jobNode.cores,
+ '-results=' + results,
+ 'add', 'job', command]
+ # Deal with the cpus
+ self.usedCpus += jobNode.cores
+ while True: # Process finished results with no wait
+ try:
+ jobID = self.cpuUsageQueue.get_nowait()
+ except Empty:
+ break
+ if jobID in self.jobIDsToCpu.keys():
+ self.usedCpus -= self.jobIDsToCpu.pop(jobID)
+ assert self.usedCpus >= 0
+ while self.usedCpus > self.maxCores: # If we are still waiting
+ jobID = self.cpuUsageQueue.get()
+ if jobID in self.jobIDsToCpu.keys():
+ self.usedCpus -= self.jobIDsToCpu.pop(jobID)
+ assert self.usedCpus >= 0
+ # Now keep going
+ while True:
+ line = self._runParasol(parasolCommand)[1][0]
+ match = self.parasolOutputPattern.match(line)
+ if match is None:
+ # This is because parasol add job will return success, even if the job was not
+ # properly issued!
+ logger.debug('We failed to properly add the job; we will try again after 5s.')
+ time.sleep(5)
+ else:
+ jobID = int(match.group(1))
+ self.jobIDsToCpu[jobID] = jobNode.cores
+ self.runningJobs.add(jobID)
+ logger.debug("Got the parasol job id: %s from line: %s" % (jobID, line))
+ return jobID
+
+ def setEnv(self, name, value=None):
+ if value and ' ' in value:
+ raise ValueError('Parasol does not support spaces in environment variable values.')
+ return super(ParasolBatchSystem, self).setEnv(name, value)
+
+ def __environment(self):
+ return (k + '=' + (os.environ[k] if v is None else v) for k, v in self.environment.items())
+
+ def killBatchJobs(self, jobIDs):
+ """Kills the given jobs, represented as Job ids, then checks they are dead by checking
+ they are not in the list of issued jobs.
+ """
+ while True:
+ for jobID in jobIDs:
+ if jobID in self.runningJobs:
+ self.runningJobs.remove(jobID)
+ exitValue = self._runParasol(['remove', 'job', str(jobID)],
+ autoRetry=False)[0]
+ logger.debug("Tried to remove jobID: %i, with exit value: %i" % (jobID, exitValue))
+ runningJobs = self.getIssuedBatchJobIDs()
+ if set(jobIDs).difference(set(runningJobs)) == set(jobIDs):
+ break
+ logger.warn('Tried to kill some jobs, but something happened and they are still '
+ 'going; will try again in 5s.')
+ time.sleep(5)
+ # Update the CPU usage, because killed jobs aren't written to the results file.
+ for jobID in jobIDs:
+ if jobID in self.jobIDsToCpu.keys():
+ self.usedCpus -= self.jobIDsToCpu.pop(jobID)
+
+ queuePattern = re.compile(r'q\s+([0-9]+)')
+ runningPattern = re.compile(r'r\s+([0-9]+)\s+[\S]+\s+[\S]+\s+([0-9]+)\s+[\S]+')
+
+ def getJobIDsForResultsFile(self, resultsFile):
+ """
+ Get all queued and running jobs for a results file.
+ """
+ jobIDs = []
+ for line in self._runParasol(['-results=' + resultsFile, 'pstat2'])[1]:
+ runningJobMatch = self.runningPattern.match(line)
+ queuedJobMatch = self.queuePattern.match(line)
+ if runningJobMatch:
+ jobID = runningJobMatch.group(1)
+ elif queuedJobMatch:
+ jobID = queuedJobMatch.group(1)
+ else:
+ continue
+ jobIDs.append(int(jobID))
+ return set(jobIDs)
+
+ def getIssuedBatchJobIDs(self):
+ """
+ Gets the list of jobs issued to parasol in all results files, but not including jobs
+ created by other users.
+ """
+ issuedJobs = set()
+ for resultsFile in itervalues(self.resultsFiles):
+ issuedJobs.update(self.getJobIDsForResultsFile(resultsFile))
+
+ return list(issuedJobs)
+
+ def getRunningBatchJobIDs(self):
+ """
+ Returns map of running jobIDs and the time they have been running.
+ """
+ # Example lines:
+ # r 5410186 benedictpaten worker 1247029663 localhost
+ # r 5410324 benedictpaten worker 1247030076 localhost
+ runningJobs = {}
+ issuedJobs = self.getIssuedBatchJobIDs()
+ for line in self._runParasol(['pstat2'])[1]:
+ if line != '':
+ match = self.runningPattern.match(line)
+ if match is not None:
+ jobID = int(match.group(1))
+ startTime = int(match.group(2))
+ if jobID in issuedJobs: # It's one of our jobs
+ runningJobs[jobID] = time.time() - startTime
+ return runningJobs
+
+ def getUpdatedBatchJob(self, maxWait):
+ while True:
+ try:
+ jobID, status, wallTime = self.updatedJobsQueue.get(timeout=maxWait)
+ except Empty:
+ return None
+ try:
+ self.runningJobs.remove(jobID)
+ except KeyError:
+ # We tried to kill this job, but it ended by itself instead, so skip it.
+ pass
+ else:
+ return jobID, status, wallTime
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ """
+ Parasol leaks jobs, but rescuing jobs involves calls to parasol list jobs and pstat2,
+ making it expensive.
+ """
+ return 5400 # Once every 90 minutes
+
+ def updatedJobWorker(self):
+ """
+ We use the parasol results to update the status of jobs, adding them
+ to the list of updated jobs.
+
+ Results have the following structure (thanks Mark D!):
+
+ int status; /* Job status - wait() return format. 0 is good. */
+ char *host; /* Machine job ran on. */
+ char *jobId; /* Job queuing system job ID */
+ char *exe; /* Job executable file (no path) */
+ int usrTicks; /* 'User' CPU time in ticks. */
+ int sysTicks; /* 'System' CPU time in ticks. */
+ unsigned submitTime; /* Job submission time in seconds since 1/1/1970 */
+ unsigned startTime; /* Job start time in seconds since 1/1/1970 */
+ unsigned endTime; /* Job end time in seconds since 1/1/1970 */
+ char *user; /* User who ran job */
+ char *errFile; /* Location of stderr file on host */
+
+ Plus you finally have the command name.
+ """
+ resultsFiles = set()
+ resultsFileHandles = []
+ try:
+ while self.running:
+ # Look for any new results files that have been created, and open them
+ newResultsFiles = set(os.listdir(self.parasolResultsDir)).difference(resultsFiles)
+ for newFile in newResultsFiles:
+ newFilePath = os.path.join(self.parasolResultsDir, newFile)
+ resultsFileHandles.append(open(newFilePath, 'r'))
+ resultsFiles.add(newFile)
+ for fileHandle in resultsFileHandles:
+ while self.running:
+ line = fileHandle.readline()
+ if not line:
+ break
+ assert line[-1] == '\n'
+ (status, host, jobId, exe, usrTicks, sysTicks, submitTime, startTime,
+ endTime, user, errFile, command) = line[:-1].split(None, 11)
+ status = int(status)
+ jobId = int(jobId)
+ if os.WIFEXITED(status):
+ status = os.WEXITSTATUS(status)
+ else:
+ status = -status
+ self.cpuUsageQueue.put(jobId)
+ startTime = int(startTime)
+ endTime = int(endTime)
+ if endTime == startTime:
+ # Both start and end times are integers, so to get sub-second
+ # accuracy we use the ticks reported by Parasol as an approximation.
+ # This isn't documented, but what Parasol calls "ticks" is actually a
+ # hundredth of a second. Parasol does the unit conversion early on,
+ # after a job finishes. Search paraNode.c for ticksToHundreths. We
+ # also cheat a little by always reporting at least one hundredth of a
+ # second.
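+ # For example (illustrative values), usrTicks=120 and sysTicks=30 give wallTime = 150 * 0.01 = 1.5 seconds.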
+ usrTicks = int(usrTicks)
+ sysTicks = int(sysTicks)
+ wallTime = float( max( 1, usrTicks + sysTicks) ) * 0.01
+ else:
+ wallTime = float(endTime - startTime)
+ self.updatedJobsQueue.put((jobId, status, wallTime))
+ time.sleep(1)
+ except:
+ logger.warn("Error occurred while parsing parasol results files.")
+ raise
+ finally:
+ for fileHandle in resultsFileHandles:
+ fileHandle.close()
+
+ def shutdown(self):
+ self.killBatchJobs(self.getIssuedBatchJobIDs()) # cleanup jobs
+ for results in itervalues(self.resultsFiles):
+ exitValue = self._runParasol(['-results=' + results, 'clear', 'sick'],
+ autoRetry=False)[0]
+ if exitValue is not None:
+ logger.warn("Could not clear sick status of the parasol batch %s" % results)
+ exitValue = self._runParasol(['-results=' + results, 'flushResults'],
+ autoRetry=False)[0]
+ if exitValue is not None:
+ logger.warn("Could not flush the parasol batch %s" % results)
+ self.running = False
+ logger.debug('Joining worker thread...')
+ self.worker.join()
+ logger.debug('... joined worker thread.')
+ for results in self.resultsFiles.values():
+ os.remove(results)
+ os.rmdir(self.parasolResultsDir)
diff --git a/src/toil/batchSystems/parasolTestSupport.py b/src/toil/batchSystems/parasolTestSupport.py
new file mode 100644
index 0000000..d534b53
--- /dev/null
+++ b/src/toil/batchSystems/parasolTestSupport.py
@@ -0,0 +1,104 @@
+from __future__ import absolute_import
+import logging
+import tempfile
+import threading
+import time
+import subprocess
+import multiprocessing
+import os
+from bd2k.util.files import rm_f
+from bd2k.util.objects import InnerClass
+
+from toil import physicalMemory
+
+log = logging.getLogger(__name__)
+
+
+class ParasolTestSupport(object):
+ """
+ For test cases that need a running Parasol leader and worker on the local host
+ """
+
+ def _startParasol(self, numCores=None, memory=None):
+ if numCores is None:
+ numCores = multiprocessing.cpu_count()
+ if memory is None:
+ memory = physicalMemory()
+ self.numCores = numCores
+ self.memory = memory
+ self.leader = self.ParasolLeaderThread()
+ self.leader.start()
+ self.worker = self.ParasolWorkerThread()
+ self.worker.start()
+ while self.leader.popen is None or self.worker.popen is None:
+ log.info('Waiting for leader and worker processes')
+ time.sleep(.1)
+
+ def _stopParasol(self):
+ self.worker.popen.kill()
+ self.worker.join()
+ self.leader.popen.kill()
+ self.leader.join()
+ for path in ('para.results', 'parasol.jid'):
+ rm_f(path)
+
+ class ParasolThread(threading.Thread):
+
+ # Lock is used because subprocess is NOT thread safe: http://tinyurl.com/pkp5pgq
+ lock = threading.Lock()
+
+ def __init__(self):
+ threading.Thread.__init__(self)
+ self.popen = None
+
+ def parasolCommand(self):
+ raise NotImplementedError
+
+ def run(self):
+ command = self.parasolCommand()
+ with self.lock:
+ self.popen = subprocess.Popen(command)
+ status = self.popen.wait()
+ if status != 0:
+ log.error("Command '%s' failed with %i.", command, status)
+ raise subprocess.CalledProcessError(status, command)
+ log.info('Exiting %s', self.__class__.__name__)
+
+ @InnerClass
+ class ParasolLeaderThread(ParasolThread):
+
+ def __init__(self):
+ super(ParasolTestSupport.ParasolLeaderThread, self).__init__()
+ self.machineList = None
+
+ def run(self):
+ with tempfile.NamedTemporaryFile(prefix='machineList.txt', mode='w') as f:
+ self.machineList = f.name
+ # name - Network name
+ # cpus - Number of CPUs we can use
+ # ramSize - Megabytes of memory
+ # tempDir - Location of (local) temp dir
+ # localDir - Location of local data dir
+ # localSize - Megabytes of local disk
+ # switchName - Name of switch this is on
+ f.write('localhost {numCores} {ramSize} {tempDir} {tempDir} 1024 foo'.format(
+ numCores=self.outer.numCores,
+ tempDir=tempfile.gettempdir(),
+ ramSize=self.outer.memory / 1024 / 1024))
+ f.flush()
+ super(ParasolTestSupport.ParasolLeaderThread, self).run()
+
+ def parasolCommand(self):
+ return ['paraHub',
+ '-spokes=1',
+ '-debug',
+ self.machineList]
+
+ @InnerClass
+ class ParasolWorkerThread(ParasolThread):
+ def parasolCommand(self):
+ return ['paraNode',
+ '-cpu=%i' % self.outer.numCores,
+ '-randomDelay=0',
+ '-debug',
+ 'start']
diff --git a/src/toil/batchSystems/singleMachine.py b/src/toil/batchSystems/singleMachine.py
new file mode 100644
index 0000000..08ce30f
--- /dev/null
+++ b/src/toil/batchSystems/singleMachine.py
@@ -0,0 +1,348 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from contextlib import contextmanager
+import logging
+import multiprocessing
+import os
+import subprocess
+import time
+import math
+from threading import Thread
+from threading import Lock, Condition
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+from six.moves import xrange
+from six import iteritems
+
+import toil
+from toil.batchSystems.abstractBatchSystem import BatchSystemSupport, InsufficientSystemResources
+
+log = logging.getLogger(__name__)
+
+
+class SingleMachineBatchSystem(BatchSystemSupport):
+ """
+ The interface for running jobs on a single machine. It runs all the jobs you give it as
+ they come in, but in parallel.
+ """
+
+ @classmethod
+ def supportsHotDeployment(cls):
+ return False
+
+ @classmethod
+ def supportsWorkerCleanup(cls):
+ return True
+
+ numCores = multiprocessing.cpu_count()
+
+ minCores = 0.1
+ """
+ The minimal fractional CPU. Tasks with a smaller core requirement will be rounded up to this
+ value. One important invariant of this class is that each worker thread represents a CPU
+ requirement of minCores, meaning that we can never run more than numCores / minCores jobs
+ concurrently.
+ """
+ physicalMemory = toil.physicalMemory()
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk):
+ if maxCores > self.numCores:
+ log.warn('Limiting maxCores to CPU count of system (%i).', self.numCores)
+ maxCores = self.numCores
+ if maxMemory > self.physicalMemory:
+ log.warn('Limiting maxMemory to physically available memory (%i).', self.physicalMemory)
+ maxMemory = self.physicalMemory
+ self.physicalDisk = toil.physicalDisk(config)
+ if maxDisk > self.physicalDisk:
+ log.warn('Limiting maxDisk to physically available disk (%i).', self.physicalDisk)
+ maxDisk = self.physicalDisk
+ super(SingleMachineBatchSystem, self).__init__(config, maxCores, maxMemory, maxDisk)
+ assert self.maxCores >= self.minCores
+ assert self.maxMemory >= 1
+
+ # The scale allows the user to apply a factor to each task's cores requirement, thereby
+ # squeezing more tasks onto each core (scale < 1) or stretching tasks over more cores
+ # (scale > 1).
+ self.scale = config.scale
+ # Number of worker threads that will be started
+ self.numWorkers = int(self.maxCores / self.minCores)
+ # A counter to generate job IDs and a lock to guard it
+ self.jobIndex = 0
+ self.jobIndexLock = Lock()
+ # A dictionary mapping IDs of submitted jobs to the command line
+ self.jobs = {}
+ """
+ :type: dict[str,toil.job.JobNode]
+ """
+ # A queue of jobs waiting to be executed. Consumed by the workers.
+ self.inputQueue = Queue()
+ # A queue of finished jobs. Produced by the workers.
+ self.outputQueue = Queue()
+ # A dictionary mapping IDs of currently running jobs to their Info objects
+ self.runningJobs = {}
+ """
+ :type: dict[str,Info]
+ """
+ # The list of worker threads
+ self.workerThreads = []
+ """
+ :type list[Thread]
+ """
+ # Variables involved with non-blocking resource acquisition
+ self.acquisitionTimeout = 5
+ self.acquisitionRetryDelay = 10
+ self.aquisitionCondition = Condition()
+
+ # A pool representing available CPU in units of minCores
+ self.coreFractions = ResourcePool(self.numWorkers, 'cores', self.acquisitionTimeout)
+ # A lock to work around the lack of thread-safety in Python's subprocess module
+ self.popenLock = Lock()
+ # A pool representing available memory in bytes
+ self.memory = ResourcePool(self.maxMemory, 'memory', self.acquisitionTimeout)
+ # A pool representing the available space in bytes
+ self.disk = ResourcePool(self.maxDisk, 'disk', self.acquisitionTimeout)
+
+ log.debug('Setting up the thread pool with %i workers, '
+ 'given a minimum CPU fraction of %f '
+ 'and a maximum CPU value of %i.', self.numWorkers, self.minCores, maxCores)
+ for i in xrange(self.numWorkers):
+ worker = Thread(target=self.worker, args=(self.inputQueue,))
+ self.workerThreads.append(worker)
+ worker.start()
+
+ # Note: The input queue is passed as an argument because the corresponding attribute is reset
+ # to None in shutdown()
+
+ def worker(self, inputQueue):
+ while True:
+ args = inputQueue.get()
+ if args is None:
+ log.debug('Received queue sentinel.')
+ break
+ jobCommand, jobID, jobCores, jobMemory, jobDisk, environment = args
+ while True:
+ try:
+ coreFractions = int(jobCores / self.minCores)
+ log.debug('Acquiring %i bytes of memory from a pool of %s.', jobMemory,
+ self.memory)
+ with self.memory.acquisitionOf(jobMemory):
+ log.debug('Acquiring %i fractional cores from a pool of %s to satisfy a '
+ 'request of %f cores', coreFractions, self.coreFractions,
+ jobCores)
+ with self.coreFractions.acquisitionOf(coreFractions):
+ with self.disk.acquisitionOf(jobDisk):
+ startTime = time.time()  # Time the job was started
+ with self.popenLock:
+ popen = subprocess.Popen(jobCommand,
+ shell=True,
+ env=dict(os.environ, **environment))
+ statusCode = None
+ info = Info(time.time(), popen, killIntended=False)
+ try:
+ self.runningJobs[jobID] = info
+ try:
+ statusCode = popen.wait()
+ if 0 != statusCode:
+ if statusCode != -9 or not info.killIntended:
+ log.error("Got exit code %i (indicating failure) "
+ "from job %s.", statusCode,
+ self.jobs[jobID])
+ finally:
+ self.runningJobs.pop(jobID)
+ finally:
+ if statusCode is not None and not info.killIntended:
+ self.outputQueue.put((jobID, statusCode,
+ time.time() - startTime))
+ except ResourcePool.AcquisitionTimeoutException as e:
+ log.debug('Could not acquire enough (%s) to run job. Requested: (%s), '
+ 'Available: %s. Sleeping for 10s.', e.resource, e.requested,
+ e.available)
+ with self.aquisitionCondition:
+ # Make threads sleep for the given delay, or until another job finishes.
+ # Whichever is sooner.
+ self.aquisitionCondition.wait(timeout=self.acquisitionRetryDelay)
+ continue
+ else:
+ log.debug('Finished job. self.coreFractions ~ %s and self.memory ~ %s',
+ self.coreFractions.value, self.memory.value)
+ with self.aquisitionCondition:
+ # Wake up sleeping threads
+ self.aquisitionCondition.notifyAll()
+ break
+ log.debug('Exiting worker thread normally.')
+
+ def issueBatchJob(self, jobNode):
+ """
+ Adds the command and resources to a queue to be run.
+ """
+ # Round cores to minCores and apply scale
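+ # For example, with scale = 1 and minCores = 0.1, a request for 0.25 cores is rounded up to 3 * 0.1 = 0.3 cores.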
+ cores = math.ceil(jobNode.cores * self.scale / self.minCores) * self.minCores
+ assert cores <= self.maxCores, ('The job is requesting {} cores, more than the maximum of '
+ '{} cores this batch system was configured with. Scale is '
+ 'set to {}.'.format(cores, self.maxCores, self.scale))
+ assert cores >= self.minCores
+ assert jobNode.memory <= self.maxMemory, ('The job is requesting {} bytes of memory, more than '
+ 'the maximum of {} this batch system was configured '
+ 'with.'.format(jobNode.memory, self.maxMemory))
+
+ self.checkResourceRequest(jobNode.memory, cores, jobNode.disk)
+ log.debug("Issuing the command: %s with memory: %i, cores: %i, disk: %i" % (
+ jobNode.command, jobNode.memory, cores, jobNode.disk))
+ with self.jobIndexLock:
+ jobID = self.jobIndex
+ self.jobIndex += 1
+ self.jobs[jobID] = jobNode.command
+ self.inputQueue.put((jobNode.command, jobID, cores, jobNode.memory,
+ jobNode.disk, self.environment.copy()))
+ return jobID
+
+ def killBatchJobs(self, jobIDs):
+ """
+ Kills jobs by ID
+ """
+ log.debug('Killing jobs: {}'.format(jobIDs))
+ for jobID in jobIDs:
+ if jobID in self.runningJobs:
+ info = self.runningJobs[jobID]
+ info.killIntended = True
+ os.kill(info.popen.pid, 9)
+ while jobID in self.runningJobs:
+ pass
+
+ def getIssuedBatchJobIDs(self):
+ """
+ Just returns all the jobs that have been issued, but not yet returned as updated.
+ """
+ return self.jobs.keys()
+
+ def getRunningBatchJobIDs(self):
+ now = time.time()
+ return {jobID: now - info.time for jobID, info in iteritems(self.runningJobs)}
+
+ def shutdown(self):
+ """
+ Cleanly terminate worker threads. Adds one sentinel to the input queue per worker thread,
+ then joins all worker threads.
+ """
+ # Remove reference to inputQueue (raises exception if inputQueue is used after method call)
+ inputQueue = self.inputQueue
+ self.inputQueue = None
+ for i in xrange(self.numWorkers):
+ inputQueue.put(None)
+ for thread in self.workerThreads:
+ thread.join()
+ BatchSystemSupport.workerCleanup(self.workerCleanupInfo)
+
+ def getUpdatedBatchJob(self, maxWait):
+ """
+ Returns the ID, exit value and wall time of a finished job, or None if no job finished within maxWait seconds.
+ """
+ try:
+ item = self.outputQueue.get(timeout=maxWait)
+ except Empty:
+ return None
+ jobID, exitValue, wallTime = item
+ jobCommand = self.jobs.pop(jobID)
+ log.debug("Ran jobID: %s with exit value: %i", jobID, exitValue)
+ return jobID, exitValue, wallTime
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ """
+ This should not really occur without an error. To exercise the system we allow it every 90 minutes.
+ """
+ return 5400
+
+class Info(object):
+ # Can't use namedtuple here since killIntended needs to be mutable
+ def __init__(self, startTime, popen, killIntended):
+ self.time = startTime
+ self.popen = popen
+ self.killIntended = killIntended
+
+
+class ResourcePool(object):
+ def __init__(self, initial_value, resourceType, timeout):
+ super(ResourcePool, self).__init__()
+ self.condition = Condition()
+ self.value = initial_value
+ self.resourceType = resourceType
+ self.timeout = timeout
+
+ def acquire(self, amount):
+ with self.condition:
+ startTime = time.time()
+ while amount > self.value:
+ if time.time() - startTime >= self.timeout:
+ # This means the thread timed out waiting for the resource. We exit the nested
+ # context managers in worker to prevent blocking of a resource due to
+ # unavailability of a nested resource request.
+ raise self.AcquisitionTimeoutException(resource=self.resourceType,
+ requested=amount, available=self.value)
+ # Allow the timeout (5 seconds) to get the resource, else quit via the above
+ # if-condition. This wait + timeout is the last thing in the loop so that a
+ # request that takes longer than the timeout due to multiple wakes under the
+ # threshold is still honored.
+ self.condition.wait(timeout=self.timeout)
+ self.value -= amount
+ self.__validate()
+
+ def release(self, amount):
+ with self.condition:
+ self.value += amount
+ self.__validate()
+ self.condition.notifyAll()
+
+ def __validate(self):
+ assert 0 <= self.value
+
+ def __str__(self):
+ return str(self.value)
+
+ def __repr__(self):
+ return "ResourcePool(%i)" % self.value
+
+ @contextmanager
+ def acquisitionOf(self, amount):
+ self.acquire(amount)
+ try:
+ yield
+ finally:
+ self.release(amount)
+
+ class AcquisitionTimeoutException(Exception):
+ """
+ To be raised when a resource request times out.
+ """
+
+ def __init__(self, resource, requested, available):
+ """
+ Creates an instance of this exception that indicates which resource is insufficient for
+ current demands, as well as the amount requested and amount actually available.
+
+ :param str resource: string representing the resource type
+
+ :param int|float requested: the amount of the particular resource requested that resulted
+ in this exception
+
+ :param int|float available: amount of the particular resource actually available
+ """
+ self.requested = requested
+ self.available = available
+ self.resource = resource
+
+
diff --git a/src/toil/batchSystems/slurm.py b/src/toil/batchSystems/slurm.py
new file mode 100644
index 0000000..e8799d4
--- /dev/null
+++ b/src/toil/batchSystems/slurm.py
@@ -0,0 +1,434 @@
+# Copyright (c) 2016 Duke Center for Genomic and Computational Biology
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+import os
+from pipes import quote
+import subprocess
+import time
+import math
+from threading import Thread
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+from six import iteritems
+
+from toil.batchSystems import MemoryString
+from toil.batchSystems.abstractBatchSystem import BatchSystemSupport
+
+logger = logging.getLogger(__name__)
+
+sleepSeconds = 1
+
+
+class Worker(Thread):
+ def __init__(self, newJobsQueue, updatedJobsQueue, killQueue, killedJobsQueue, boss):
+ Thread.__init__(self)
+ self.newJobsQueue = newJobsQueue
+ self.updatedJobsQueue = updatedJobsQueue
+ self.killQueue = killQueue
+ self.killedJobsQueue = killedJobsQueue
+ self.waitingJobs = list()
+ self.runningJobs = set()
+ self.boss = boss
+ self.allocatedCpus = dict()
+ self.slurmJobIDs = dict()
+
+ def parse_elapsed(self, elapsed):
+        # Slurm returns elapsed time in days-hours:minutes:seconds format.
+        # Sometimes it will only return minutes:seconds, so days may be omitted.
+        # For ease of calculation, we make sure all the delimiters are ':',
+        # then reverse the list so that we always count up from seconds -> minutes -> hours -> days.
+ total_seconds = 0
+ try:
+ elapsed = elapsed.replace('-', ':').split(':')
+ elapsed.reverse()
+ seconds_per_unit = [1, 60, 3600, 86400]
+ for index, multiplier in enumerate(seconds_per_unit):
+ if index < len(elapsed):
+ total_seconds += multiplier * int(elapsed[index])
+ except ValueError:
+ pass # slurm may return INVALID instead of a time
+ return total_seconds
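+
+    # Illustrative worked example (not from the original source):
+    #     parse_elapsed('1-02:03:04') -> 1*86400 + 2*3600 + 3*60 + 4 = 93784 seconds
+    #     parse_elapsed('03:04')      -> 184 seconds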
+
+ def getRunningJobIDs(self):
+        # Should return a dictionary mapping each running job ID to the number of seconds it has been running
+ times = {}
+ currentjobs = dict((str(self.slurmJobIDs[x]), x) for x in self.runningJobs)
+ # currentjobs is a dictionary that maps a slurm job id (string) to our own internal job id
+ # squeue arguments:
+ # -h for no header
+        # --format to get jobid %i, state %t and elapsed time %M (days-hours:minutes:seconds)
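+        # Example squeue output line (illustrative, not from the original source):
+        #     2954103 R 1-02:03:04
+        # i.e. slurm job 2954103 is running and has been for 1 day, 2 hours, 3 minutes and 4 seconds.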
+
+ lines = subprocess.check_output(['squeue', '-h', '--format', '%i %t %M']).split('\n')
+ for line in lines:
+ values = line.split()
+ if len(values) < 3:
+ continue
+ slurm_jobid, state, elapsed_time = values
+ if slurm_jobid in currentjobs and state == 'R':
+ seconds_running = self.parse_elapsed(elapsed_time)
+ times[currentjobs[slurm_jobid]] = seconds_running
+
+ return times
+
+ def getSlurmID(self, jobID):
+        if jobID not in self.slurmJobIDs:
+            raise RuntimeError("Unknown jobID, could not be converted")
+
+ job = self.slurmJobIDs[jobID]
+ return str(job)
+
+ def forgetJob(self, jobID):
+ self.runningJobs.remove(jobID)
+ del self.allocatedCpus[jobID]
+ del self.slurmJobIDs[jobID]
+
+ def killJobs(self):
+ # Load hit list:
+ killList = list()
+ while True:
+ try:
+ jobId = self.killQueue.get(block=False)
+ except Empty:
+ break
+ else:
+ killList.append(jobId)
+
+ if not killList:
+ return False
+
+ # Do the dirty job
+ for jobID in list(killList):
+ if jobID in self.runningJobs:
+ logger.debug('Killing job: %s', jobID)
+ subprocess.check_call(['scancel', self.getSlurmID(jobID)])
+ else:
+ if jobID in self.waitingJobs:
+ self.waitingJobs.remove(jobID)
+ self.killedJobsQueue.put(jobID)
+ killList.remove(jobID)
+
+ # Wait to confirm the kill
+ while killList:
+ for jobID in list(killList):
+ if self.getJobExitCode(self.slurmJobIDs[jobID]) is not None:
+ logger.debug('Adding jobID %s to killedJobsQueue', jobID)
+ self.killedJobsQueue.put(jobID)
+ killList.remove(jobID)
+ self.forgetJob(jobID)
+ if len(killList) > 0:
+ logger.warn("Some jobs weren't killed, trying again in %is.", sleepSeconds)
+ time.sleep(sleepSeconds)
+
+ return True
+
+ def createJobs(self, newJob):
+ activity = False
+        # Load the new job, if present:
+ if newJob is not None:
+ self.waitingJobs.append(newJob)
+ # Launch jobs as necessary:
+ while (len(self.waitingJobs) > 0
+ and sum(self.allocatedCpus.values()) < int(self.boss.maxCores)):
+ activity = True
+ jobID, cpu, memory, command = self.waitingJobs.pop(0)
+ sbatch_line = self.prepareSbatch(cpu, memory, jobID) + ['--wrap={}'.format(command)]
+ slurmJobID = self.sbatch(sbatch_line)
+ self.slurmJobIDs[jobID] = slurmJobID
+ self.runningJobs.add(jobID)
+ self.allocatedCpus[jobID] = cpu
+ return activity
+
+ def checkOnJobs(self):
+ activity = False
+ logger.debug('List of running jobs: %r', self.runningJobs)
+ for jobID in list(self.runningJobs):
+ logger.debug("Checking status of internal job id %d", jobID)
+ status = self.getJobExitCode(self.slurmJobIDs[jobID])
+ if status is not None:
+ activity = True
+ self.updatedJobsQueue.put((jobID, status))
+ self.forgetJob(jobID)
+ return activity
+
+ def run(self):
+ while True:
+ activity = False
+ newJob = None
+ if not self.newJobsQueue.empty():
+ activity = True
+ newJob = self.newJobsQueue.get()
+ if newJob is None:
+ logger.debug('Received queue sentinel.')
+ break
+ activity |= self.killJobs()
+ activity |= self.createJobs(newJob)
+ activity |= self.checkOnJobs()
+ if not activity:
+ logger.debug('No activity, sleeping for %is', sleepSeconds)
+ time.sleep(sleepSeconds)
+
+ def prepareSbatch(self, cpu, mem, jobID):
+ # Returns the sbatch command line before the script to run
+ sbatch_line = ['sbatch', '-Q', '-J', 'toil_job_{}'.format(jobID)]
+
+ if self.boss.environment:
+ argList = []
+
+ for k, v in iteritems(self.boss.environment):
+ quoted_value = quote(os.environ[k] if v is None else v)
+ argList.append('{}={}'.format(k, quoted_value))
+
+ sbatch_line.append('--export=' + ','.join(argList))
+
+ if mem is not None:
+ # memory passed in is in bytes, but slurm expects megabytes
+ sbatch_line.append('--mem={}'.format(int(mem) / 2 ** 20))
+ if cpu is not None:
+ sbatch_line.append('--cpus-per-task={}'.format(int(math.ceil(cpu))))
+
+ # "Native extensions" for SLURM (see DRMAA or SAGA)
+ nativeConfig = os.getenv('TOIL_SLURM_ARGS')
+ if nativeConfig is not None:
+            logger.debug("Native SLURM options appended to sbatch from TOIL_SLURM_ARGS env. variable: {}".format(nativeConfig))
+            if "--mem" in nativeConfig or "--cpus-per-task" in nativeConfig:
+                raise ValueError("Some resource arguments are incompatible: {}".format(nativeConfig))
+
+ sbatch_line.extend([nativeConfig])
+
+ return sbatch_line
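+
+    # Illustrative example (not from the original source): for jobID 5, cpu 2.0 and
+    # mem 4294967296 bytes (4 GiB), prepareSbatch returns roughly
+    #     ['sbatch', '-Q', '-J', 'toil_job_5', '--mem=4096', '--cpus-per-task=2']
+    # (plus an --export option when environment variables are set); the caller then
+    # appends '--wrap=<command>' before submitting.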
+
+ def sbatch(self, sbatch_line):
+ logger.debug("Running %r", sbatch_line)
+ try:
+ output = subprocess.check_output(sbatch_line, stderr=subprocess.STDOUT)
+ # sbatch prints a line like 'Submitted batch job 2954103'
+ result = int(output.strip().split()[-1])
+ logger.debug("sbatch submitted job %d", result)
+ return result
+ except subprocess.CalledProcessError as e:
+ logger.error("sbatch command failed with code %d: %s", e.returncode, e.output)
+ raise e
+ except OSError as e:
+ logger.error("sbatch command failed")
+ raise e
+
+ def getJobExitCode(self, slurmJobID):
+ logger.debug("Getting exit code for slurm job %d", slurmJobID)
+
+ state, rc = self._getJobDetailsFromSacct(slurmJobID)
+
+ if rc == -999:
+ state, rc = self._getJobDetailsFromScontrol(slurmJobID)
+
+        logger.debug("Job state is %s", state)
+ # If Job is in a running state, return None to indicate we don't have an update
+ if state in ('PENDING', 'RUNNING', 'CONFIGURING', 'COMPLETING', 'RESIZING', 'SUSPENDED'):
+ return None
+
+ return rc
+
+ def _getJobDetailsFromSacct(self, slurmJobID):
+ # SLURM job exit codes are obtained by running sacct.
+ args = ['sacct',
+ '-n', # no header
+ '-j', str(slurmJobID), # job
+ '--format', 'State,ExitCode', # specify output columns
+ '-P', # separate columns with pipes
+ '-S', '1970-01-01'] # override start time limit
+
+        process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        # Wait for sacct to finish so that the return code is meaningful; the output is small,
+        # so reading it from the pipe afterwards will not block.
+        rc = process.wait()
+
+        if rc != 0:
+            # no accounting system or some other error
+            return (None, -999)
+
+ for line in process.stdout:
+ values = line.strip().split('|')
+ if len(values) < 2:
+ continue
+ state, exitcode = values
+ logger.debug("sacct job state is %s", state)
+ # If Job is in a running state, return None to indicate we don't have an update
+ status, _ = exitcode.split(':')
+ logger.debug("sacct exit code is %s, returning status %s", exitcode, status)
+ return (state, int(status))
+
+ logger.debug("Did not find exit code for job in sacct output")
+ return (None, None)
+
+ def _getJobDetailsFromScontrol(self, slurmJobID):
+ args = ['scontrol',
+ 'show',
+ 'job',
+ str(slurmJobID)]
+
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ job = dict()
+ for line in process.stdout:
+ values = line.strip().split()
+
+ # If job information is not available an error is issued:
+ # slurm_load_jobs error: Invalid job id specified
+ # There is no job information, so exit.
+ if len(values)>0 and values[0] == 'slurm_load_jobs':
+ return None
+
+ # Output is in the form of many key=value pairs, multiple pairs on each line
+ # and multiple lines in the output. Each pair is pulled out of each line and
+ # added to a dictionary
+ for v in values:
+ bits = v.split('=')
+ job[bits[0]] = bits[1]
+
+ state = job['JobState']
+ try:
+ exitcode = job['ExitCode']
+ if exitcode is not None:
+ status, _ = exitcode.split(':')
+ logger.debug("scontrol exit code is %s, returning status %s", exitcode, status)
+ rc = int(status)
+ else:
+ rc = None
+ except KeyError:
+ rc = None
+
+ return (state, rc)
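+
+    # Illustrative example (not from the original source): scontrol prints key=value pairs, e.g.
+    #     JobId=2954103 JobName=toil_job_5 JobState=COMPLETED ExitCode=0:0 ...
+    # from which this parser extracts state 'COMPLETED' and return code 0.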
+
+class SlurmBatchSystem(BatchSystemSupport):
+ """
+ The interface for SLURM
+ """
+
+ @classmethod
+ def supportsWorkerCleanup(cls):
+ return False
+
+ @classmethod
+ def supportsHotDeployment(cls):
+ return False
+
+ def __init__(self, config, maxCores, maxMemory, maxDisk):
+ super(SlurmBatchSystem, self).__init__(config, maxCores, maxMemory, maxDisk)
+ self.slurmResultsFile = self._getResultsFileName(config.jobStore)
+ # Reset the job queue and results (initially, we do this again once we've killed the jobs)
+ self.slurmResultsFileHandle = open(self.slurmResultsFile, 'w')
+        # We lose any previous state in this file, and ensure the file's existence
+ self.slurmResultsFileHandle.close()
+ self.currentJobs = set()
+ self.maxCPU, self.maxMEM = self.obtainSystemConstants()
+ self.nextJobID = 0
+ self.newJobsQueue = Queue()
+ self.updatedJobsQueue = Queue()
+ self.killQueue = Queue()
+ self.killedJobsQueue = Queue()
+ self.worker = Worker(self.newJobsQueue, self.updatedJobsQueue, self.killQueue,
+ self.killedJobsQueue, self)
+ self.worker.start()
+
+    def __del__(self):
+ # Closes the file handle associated with the results file.
+ self.slurmResultsFileHandle.close()
+
+ def issueBatchJob(self, jobNode):
+ self.checkResourceRequest(jobNode.memory, jobNode.cores, jobNode.disk)
+ jobID = self.nextJobID
+ self.nextJobID += 1
+ self.currentJobs.add(jobID)
+ self.newJobsQueue.put((jobID, jobNode.cores, jobNode.memory, jobNode.command))
+ logger.debug("Issued the job command: %s with job id: %s ", jobNode.command, str(jobID))
+ return jobID
+
+ def killBatchJobs(self, jobIDs):
+ """
+ Kills the given jobs, represented as Job ids, then checks they are dead by checking
+ they are not in the list of issued jobs.
+ """
+ jobIDs = set(jobIDs)
+ logger.debug('Jobs to be killed: %r', jobIDs)
+ for jobID in jobIDs:
+ self.killQueue.put(jobID)
+ while jobIDs:
+ killedJobId = self.killedJobsQueue.get()
+ if killedJobId is None:
+ break
+ jobIDs.remove(killedJobId)
+ if killedJobId in self.currentJobs:
+ self.currentJobs.remove(killedJobId)
+ if jobIDs:
+ logger.debug('Some kills (%s) still pending, sleeping %is', len(jobIDs),
+ sleepSeconds)
+ time.sleep(sleepSeconds)
+
+ def getIssuedBatchJobIDs(self):
+ """
+ Gets the list of jobs issued to SLURM.
+ """
+ return list(self.currentJobs)
+
+ def getRunningBatchJobIDs(self):
+ return self.worker.getRunningJobIDs()
+
+ def getUpdatedBatchJob(self, maxWait):
+ try:
+ item = self.updatedJobsQueue.get(timeout=maxWait)
+ except Empty:
+ return None
+ logger.debug('UpdatedJobsQueue Item: %s', item)
+ jobID, retcode = item
+ self.currentJobs.remove(jobID)
+ return jobID, retcode, None
+
+ def shutdown(self):
+ """
+ Signals worker to shutdown (via sentinel) then cleanly joins the thread
+ """
+ newJobsQueue = self.newJobsQueue
+ self.newJobsQueue = None
+
+ newJobsQueue.put(None)
+ self.worker.join()
+
+ def getWaitDuration(self):
+ return 1.0
+
+ @classmethod
+ def getRescueBatchJobFrequency(cls):
+ return 30 * 60 # Half an hour
+
+ @staticmethod
+ def obtainSystemConstants():
+        # sinfo -Nhe --format '%m %c'
+        # sinfo arguments:
+        # -N for node-oriented
+        # -h for no header
+        # -e for exact values (e.g. don't return 32+)
+        # --format to get memory (megabytes) and cpu count
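+        # Example output lines (illustrative, not from the original source):
+        #     64000 16
+        #     128000 32
+        # which would yield max_cpu == 32 and max_mem == MemoryString('128000M').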
+ max_cpu = 0
+ max_mem = MemoryString('0')
+ lines = subprocess.check_output(['sinfo', '-Nhe', '--format', '%m %c']).split('\n')
+ for line in lines:
+ values = line.split()
+ if len(values) < 2:
+ continue
+ mem, cpu = values
+ max_cpu = max(max_cpu, int(cpu))
+ max_mem = max(max_mem, MemoryString(mem + 'M'))
+ if max_cpu == 0 or max_mem.byteVal() == 0:
+            raise RuntimeError('sinfo did not return memory or cpu info')
+ return max_cpu, max_mem
diff --git a/src/toil/common.py b/src/toil/common.py
new file mode 100644
index 0000000..1d4f820
--- /dev/null
+++ b/src/toil/common.py
@@ -0,0 +1,1094 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import logging
+import os
+import re
+import sys
+import tempfile
+import time
+from argparse import ArgumentParser
+from threading import Thread
+
+# Python 3 compatibility imports
+from six.moves import cPickle
+from six import iteritems
+
+from bd2k.util.exceptions import require
+from bd2k.util.humanize import bytes2human
+
+from toil import logProcessContext
+from toil.lib.bioio import addLoggingOptions, getLogLevelString, setLoggingFromOptions
+from toil.realtimeLogger import RealtimeLogger
+
+logger = logging.getLogger(__name__)
+
+# This constant is set to the default value used on unix for block size (in bytes) when
+# os.stat(<file>).st_blocks is called.
+unixBlockSize = 512
+
+
+class Config(object):
+ """
+ Class to represent configuration operations for a toil workflow run.
+ """
+ def __init__(self):
+ # Core options
+        self.workflowID = None
+        """This attribute uniquely identifies the job store and therefore the workflow. It is
+        necessary in order to distinguish between two consecutive workflows for which
+        self.jobStore is the same, e.g. when a job store name is reused after a previous run has
+        finished successfully and its job store has been cleaned up."""
+ self.workflowAttemptNumber = None
+ self.jobStore = None
+ self.logLevel = getLogLevelString()
+ self.workDir = None
+ self.stats = False
+
+        # Because the stats option needs the jobStore to persist past the end of the run,
+        # the clean default value depends on the specified stats option and is determined in setOptions
+ self.clean = None
+ self.cleanWorkDir = None
+ self.clusterStats = None
+
+ #Restarting the workflow options
+ self.restart = False
+
+ #Batch system options
+ self.batchSystem = "singleMachine"
+ self.scale = 1
+ self.mesosMasterAddress = 'localhost:5050'
+ self.parasolCommand = "parasol"
+ self.parasolMaxBatches = 10000
+ self.environment = {}
+
+ #Autoscaling options
+ self.provisioner = None
+ self.nodeType = None
+ self.nodeOptions = None
+ self.minNodes = 0
+ self.maxNodes = 10
+ self.preemptableNodeType = None
+ self.preemptableNodeOptions = None
+ self.minPreemptableNodes = 0
+ self.maxPreemptableNodes = 0
+ self.alphaPacking = 0.8
+ self.betaInertia = 1.2
+ self.scaleInterval = 10
+ self.preemptableCompensation = 0.0
+
+        # Parameters to limit service jobs, thus preventing deadlock scheduling scenarios
+ self.maxPreemptableServiceJobs = sys.maxint
+ self.maxServiceJobs = sys.maxint
+ self.deadlockWait = 60 # Wait one minute before declaring a deadlock
+
+ #Resource requirements
+ self.defaultMemory = 2147483648
+ self.defaultCores = 1
+ self.defaultDisk = 2147483648
+ self.readGlobalFileMutableByDefault = False
+ self.defaultPreemptable = False
+ self.maxCores = sys.maxint
+ self.maxMemory = sys.maxint
+ self.maxDisk = sys.maxint
+
+ #Retrying/rescuing jobs
+ self.retryCount = 0
+ self.maxJobDuration = sys.maxint
+ self.rescueJobsFrequency = 3600
+
+ #Misc
+ self.disableCaching = False
+ self.maxLogFileSize = 64000
+ self.writeLogs = None
+ self.writeLogsGzip = None
+ self.sseKey = None
+ self.cseKey = None
+ self.servicePollingInterval = 60
+ self.useAsync = True
+
+ #Debug options
+ self.badWorker = 0.0
+ self.badWorkerFailInterval = 0.01
+
+ def setOptions(self, options):
+ """
+ Creates a config object from the options object.
+ """
+        from bd2k.util.humanize import human2bytes  # This import is used to convert
+        # from human-readable quantities to integers
+ def setOption(varName, parsingFn=None, checkFn=None):
+ #If options object has the option "varName" specified
+ #then set the "varName" attrib to this value in the config object
+ x = getattr(options, varName, None)
+ if x is not None:
+ if parsingFn is not None:
+ x = parsingFn(x)
+ if checkFn is not None:
+ try:
+ checkFn(x)
+ except AssertionError:
+ raise RuntimeError("The %s option has an invalid value: %s"
+ % (varName, x))
+ setattr(self, varName, x)
+
+ # Function to parse integer from string expressed in different formats
+ h2b = lambda x : human2bytes(str(x))
+
+ def iC(minValue, maxValue=sys.maxint):
+ # Returns function that checks if a given int is in the given half-open interval
+ assert isinstance(minValue, int) and isinstance(maxValue, int)
+ return lambda x: minValue <= x < maxValue
+
+ def fC(minValue, maxValue=None):
+ # Returns function that checks if a given float is in the given half-open interval
+ assert isinstance(minValue, float)
+ if maxValue is None:
+ return lambda x: minValue <= x
+ else:
+ assert isinstance(maxValue, float)
+ return lambda x: minValue <= x < maxValue
+
+ def parseJobStore(s):
+ name, rest = Toil.parseLocator(s)
+ if name == 'file':
+ # We need to resolve relative paths early, on the leader, because the worker process
+ # may have a different working directory than the leader, e.g. under Mesos.
+ return Toil.buildLocator(name, os.path.abspath(rest))
+ else:
+ return s
+
+ #Core options
+ setOption("jobStore", parsingFn=parseJobStore)
+ #TODO: LOG LEVEL STRING
+ setOption("workDir")
+ if self.workDir is not None:
+ self.workDir = os.path.abspath(self.workDir)
+ if not os.path.exists(self.workDir):
+ raise RuntimeError("The path provided to --workDir (%s) does not exist."
+ % self.workDir)
+ setOption("stats")
+ setOption("cleanWorkDir")
+ setOption("clean")
+ if self.stats:
+ if self.clean != "never" and self.clean is not None:
+ raise RuntimeError("Contradicting options passed: Clean flag is set to %s "
+ "despite the stats flag requiring "
+ "the jobStore to be intact at the end of the run. "
+ "Set clean to \'never\'" % self.clean)
+ self.clean = "never"
+ elif self.clean is None:
+ self.clean = "onSuccess"
+ setOption('clusterStats')
+
+ #Restarting the workflow options
+ setOption("restart")
+
+ #Batch system options
+ setOption("batchSystem")
+ setOption("scale", float, fC(0.0))
+ setOption("mesosMasterAddress")
+ setOption("parasolCommand")
+ setOption("parasolMaxBatches", int, iC(1))
+
+ setOption("environment", parseSetEnv)
+
+ #Autoscaling options
+ setOption("provisioner")
+ setOption("nodeType")
+ setOption("nodeOptions")
+ setOption("minNodes", int)
+ setOption("maxNodes", int)
+ setOption("preemptableNodeType")
+ setOption("preemptableNodeOptions")
+ setOption("minPreemptableNodes", int)
+ setOption("maxPreemptableNodes", int)
+ setOption("alphaPacking", float)
+ setOption("betaInertia", float)
+ setOption("scaleInterval", float)
+
+ setOption("preemptableCompensation", float)
+ require(0.0 <= self.preemptableCompensation <= 1.0,
+ '--preemptableCompensation (%f) must be >= 0.0 and <= 1.0',
+ self.preemptableCompensation)
+
+ # Parameters to limit service jobs / detect deadlocks
+ setOption("maxServiceJobs", int)
+ setOption("maxPreemptableServiceJobs", int)
+ setOption("deadlockWait", int)
+
+ # Resource requirements
+ setOption("defaultMemory", h2b, iC(1))
+ setOption("defaultCores", float, fC(1.0))
+ setOption("defaultDisk", h2b, iC(1))
+ setOption("readGlobalFileMutableByDefault")
+ setOption("maxCores", int, iC(1))
+ setOption("maxMemory", h2b, iC(1))
+ setOption("maxDisk", h2b, iC(1))
+ setOption("defaultPreemptable")
+
+ #Retrying/rescuing jobs
+ setOption("retryCount", int, iC(0))
+ setOption("maxJobDuration", int, iC(1))
+ setOption("rescueJobsFrequency", int, iC(1))
+
+ #Misc
+ setOption("disableCaching")
+ setOption("maxLogFileSize", h2b, iC(1))
+ setOption("writeLogs")
+ setOption("writeLogsGzip")
+
+ def checkSse(sseKey):
+ with open(sseKey) as f:
+ assert(len(f.readline().rstrip()) == 32)
+ setOption("sseKey", checkFn=checkSse)
+ setOption("cseKey", checkFn=checkSse)
+ setOption("servicePollingInterval", float, fC(0.0))
+
+ #Debug options
+ setOption("badWorker", float, fC(0.0, 1.0))
+ setOption("badWorkerFailInterval", float, fC(0.0))
+
+ def __eq__(self, other):
+ return self.__dict__ == other.__dict__
+
+ def __hash__(self):
+ return self.__dict__.__hash__()
+
+jobStoreLocatorHelp = ("A job store holds persistent information about the jobs and files in a "
+                       "workflow. If the workflow is run with a distributed batch system, the job "
+                       "store must be accessible by all worker nodes. Depending on the desired "
+                       "job store implementation, the location should be formatted according to "
+                       "one of the following schemes:\n\n"
+                       "file:<path> where <path> points to a directory on the file system\n\n"
+                       "aws:<region>:<prefix> where <region> is the name of an AWS region like "
+                       "us-west-2 and <prefix> will be prepended to the names of any top-level "
+                       "AWS resources in use by the job store, e.g. S3 buckets.\n\n"
+                       "azure:<account>:<prefix>\n\n"
+                       "google:<project_id>:<prefix> TODO: explain\n\n"
+                       "For backwards compatibility, you may also specify ./foo (equivalent to "
+                       "file:./foo or just file:foo) or /bar (equivalent to file:/bar).")
+
+def _addOptions(addGroupFn, config):
+ #
+ #Core options
+ #
+ addOptionFn = addGroupFn("toil core options",
+ "Options to specify the location of the Toil workflow and turn on "
+ "stats collation about the performance of jobs.")
+ addOptionFn('jobStore', type=str,
+ help="The location of the job store for the workflow. " + jobStoreLocatorHelp)
+    addOptionFn("--workDir", dest="workDir", default=None,
+                help="Absolute path to directory where temporary files generated during the Toil "
+                     "run should be placed. Temp files and folders will be placed in a directory "
+                     "toil-<workflowID> within workDir (the workflowID is generated by Toil and "
+                     "will be reported in the workflow logs). Default is determined by the "
+                     "user-defined environment variable TOIL_TEMPDIR, or the environment "
+                     "variables (TMPDIR, TEMP, TMP) via mkdtemp. This directory needs to exist on "
+                     "all machines running jobs.")
+ addOptionFn("--stats", dest="stats", action="store_true", default=None,
+ help="Records statistics about the toil workflow to be used by 'toil stats'.")
+    addOptionFn("--clean", dest="clean", choices=['always', 'onError', 'never', 'onSuccess'],
+                default=None,
+                help=("Determines the deletion of the jobStore upon completion of the program. "
+                      "Choices: 'always', 'onError', 'never', 'onSuccess'. The --stats option requires "
+                      "information from the jobStore upon completion so the jobStore will never be deleted "
+                      "with that flag. If you wish to be able to restart the run, choose \'never\' or \'onSuccess\'. "
+                      "Default is \'never\' if stats is enabled, and \'onSuccess\' otherwise"))
+    addOptionFn("--cleanWorkDir", dest="cleanWorkDir",
+                choices=['always', 'never', 'onSuccess', 'onError'], default='always',
+                help=("Determines deletion of the temporary worker directory upon completion of a job. Choices: 'always', "
+                      "'never', 'onSuccess', 'onError'. Default = always. WARNING: This option should be changed for "
+                      "debugging only. Running a full pipeline with this option set to anything other than 'always' "
+                      "could fill your disk with intermediate data."))
+    addOptionFn("--clusterStats", dest="clusterStats", nargs='?', action='store',
+                default=None, const=os.getcwd(),
+                help="If enabled, writes out JSON resource usage statistics to a file. "
+                     "The default location for this file is the current working directory, "
+                     "but an absolute path can also be passed to specify where this file "
+                     "should be written. This option only applies when using scalable batch "
+                     "systems.")
+ #
+ #Restarting the workflow options
+ #
+ addOptionFn = addGroupFn("toil options for restarting an existing workflow",
+ "Allows the restart of an existing workflow")
+    addOptionFn("--restart", dest="restart", default=None, action="store_true",
+                help="If --restart is specified then Toil will attempt to restart the existing workflow "
+                "at the location pointed to by the --jobStore option. It will raise an exception if the workflow does not exist.")
+
+ #
+ #Batch system options
+ #
+
+ addOptionFn = addGroupFn("toil options for specifying the batch system",
+ "Allows the specification of the batch system, and arguments to the batch system/big batch system (see below).")
+    addOptionFn("--batchSystem", dest="batchSystem", default=None,
+                help=("The type of batch system to run the job(s) with, currently can be one "
+                      "of singleMachine, parasol, gridEngine, lsf, slurm or mesos. default=%s" % config.batchSystem))
+    addOptionFn("--scale", dest="scale", default=None,
+                help=("A scaling factor to change the value of all submitted tasks' cores. "
+                      "Used in the singleMachine batch system. default=%s" % config.scale))
+    addOptionFn("--mesosMaster", dest="mesosMasterAddress", default=None,
+                help=("The host and port of the Mesos master separated by a colon. default=%s" % config.mesosMasterAddress))
+    addOptionFn("--parasolCommand", dest="parasolCommand", default=None,
+                help="The name or path of the parasol program. Will be looked up on PATH "
+                     "unless it starts with a slash. default=%s" % config.parasolCommand)
+    addOptionFn("--parasolMaxBatches", dest="parasolMaxBatches", default=None,
+                help="Maximum number of job batches the Parasol batch system is allowed to create. One "
+                     "batch is created for jobs with a unique set of resource requirements. "
+                     "default=%i" % config.parasolMaxBatches)
+
+ #
+ #Auto scaling options
+ #
+ addOptionFn = addGroupFn("toil options for autoscaling the cluster of worker nodes",
+ "Allows the specification of the minimum and maximum number of nodes "
+ "in an autoscaled cluster, as well as parameters to control the "
+ "level of provisioning.")
+
+    addOptionFn("--provisioner", dest="provisioner", choices=['cgcloud', 'aws'],
+                help="The provisioner for cluster auto-scaling. The currently supported choices are "
+                     "'cgcloud' or 'aws'. The default is %s." % config.provisioner)
+
+ for preemptable in (False, True):
+ def _addOptionFn(*name, **kwargs):
+ name = list(name)
+ if preemptable:
+ name.insert(-1, 'preemptable' )
+ name = ''.join((s[0].upper() + s[1:]) if i else s for i, s in enumerate(name))
+ terms = re.compile(r'\{([^{}]+)\}')
+ _help = kwargs.pop('help')
+ _help = ''.join((term.split('|') * 2)[int(preemptable)] for term in terms.split(_help))
+ addOptionFn('--' + name, dest=name,
+ help=_help + ' The default is %s.' % getattr(config, name),
+ **kwargs)
+
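+        # Illustrative expansion (not from the original source): the calls below run once per
+        # pass of the loop above. On the first pass 'nodeType' becomes --nodeType and on the
+        # second it becomes --preemptableNodeType; a help template such as
+        # "{non-|}preemptable nodes" renders as "non-preemptable nodes" or "preemptable nodes".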
+ _addOptionFn('nodeType', metavar='TYPE',
+ help="Node type for {non-|}preemptable nodes. The syntax depends on the "
+ "provisioner used. For the cgcloud and AWS provisioners this is the name "
+ "of an EC2 instance type{|, followed by a colon and the price in dollar "
+ "to bid for a spot instance}, for example 'c3.8xlarge{|:0.42}'.")
+ _addOptionFn('nodeOptions', metavar='OPTIONS',
+ help="Provisioning options for the {non-|}preemptable node type. The syntax "
+ "depends on the provisioner used. Neither the CGCloud nor the AWS "
+ "provisioner support any node options.")
+ for p, q in [('min', 'Minimum'), ('max', 'Maximum')]:
+ _addOptionFn(p, 'nodes', default=None, metavar='NUM',
+ help=q + " number of {non-|}preemptable nodes in the cluster, if using "
+ "auto-scaling.")
+
+ # TODO: DESCRIBE THE FOLLOWING TWO PARAMETERS
+ addOptionFn("--alphaPacking", dest="alphaPacking", default=None,
+ help=("The total number of nodes estimated to be required to compute the issued "
+ "jobs is multiplied by the alpha packing parameter to produce the actual "
+ "number of nodes requested. Values of this coefficient greater than one will "
+ "tend to over provision and values less than one will under provision. default=%s" % config.alphaPacking))
+ addOptionFn("--betaInertia", dest="betaInertia", default=None,
+ help=("A smoothing parameter to prevent unnecessary oscillations in the "
+ "number of provisioned nodes. If the number of nodes is within the beta "
+ "inertia of the currently provisioned number of nodes then no change is made "
+ "to the number of requested nodes. default=%s" % config.betaInertia))
+ addOptionFn("--scaleInterval", dest="scaleInterval", default=None,
+ help=("The interval (seconds) between assessing if the scale of"
+ " the cluster needs to change. default=%s" % config.scaleInterval))
+ addOptionFn("--preemptableCompensation", dest="preemptableCompensation",
+ default=None,
+ help=("The preference of the autoscaler to replace preemptable nodes with "
+ "non-preemptable nodes, when preemptable nodes cannot be started for some "
+ "reason. Defaults to %s. This value must be between 0.0 and 1.0, inclusive. "
+ "A value of 0.0 disables such compensation, a value of 0.5 compensates two "
+ "missing preemptable nodes with a non-preemptable one. A value of 1.0 "
+ "replaces every missing pre-emptable node with a non-preemptable one." %
+ config.preemptableCompensation))
+
+ #
+ # Parameters to limit service jobs / detect service deadlocks
+ #
+    addOptionFn = addGroupFn("toil options for limiting the number of service jobs and detecting service deadlocks",
+                             "Allows the specification of the maximum number of service jobs "
+                             "in a cluster. By keeping this limited "
+                             "we can avoid all the nodes being occupied with services, which would cause a deadlock.")
+ addOptionFn("--maxServiceJobs", dest="maxServiceJobs", default=None,
+ help=("The maximum number of service jobs that can be run concurrently, excluding service jobs running on preemptable nodes. default=%s" % config.maxServiceJobs))
+ addOptionFn("--maxPreemptableServiceJobs", dest="maxPreemptableServiceJobs", default=None,
+ help=("The maximum number of service jobs that can run concurrently on preemptable nodes. default=%s" % config.maxPreemptableServiceJobs))
+ addOptionFn("--deadlockWait", dest="deadlockWait", default=None,
+ help=("The minimum number of seconds to observe the cluster stuck running only the same service jobs before throwing a deadlock exception. default=%s" % config.deadlockWait))
+
+ #
+ #Resource requirements
+ #
+ addOptionFn = addGroupFn("toil options for cores/memory requirements",
+ "The options to specify default cores/memory requirements (if not "
+ "specified by the jobs themselves), and to limit the total amount of "
+ "memory/cores requested from the batch system.")
+ addOptionFn('--defaultMemory', dest='defaultMemory', default=None, metavar='INT',
+ help='The default amount of memory to request for a job. Only applicable to jobs '
+ 'that do not specify an explicit value for this requirement. Standard '
+ 'suffixes like K, Ki, M, Mi, G or Gi are supported. Default is %s' %
+ bytes2human( config.defaultMemory, symbols='iec' ))
+ addOptionFn('--defaultCores', dest='defaultCores', default=None, metavar='FLOAT',
+                help='The default number of CPU cores to dedicate to a job. Only applicable to jobs '
+ 'that do not specify an explicit value for this requirement. Fractions of a '
+ 'core (for example 0.1) are supported on some batch systems, namely Mesos '
+ 'and singleMachine. Default is %.1f ' % config.defaultCores)
+ addOptionFn('--defaultDisk', dest='defaultDisk', default=None, metavar='INT',
+                help='The default amount of disk space to dedicate to a job. Only applicable to jobs '
+ 'that do not specify an explicit value for this requirement. Standard '
+ 'suffixes like K, Ki, M, Mi, G or Gi are supported. Default is %s' %
+ bytes2human( config.defaultDisk, symbols='iec' ))
+ assert not config.defaultPreemptable, 'User would be unable to reset config.defaultPreemptable'
+ addOptionFn('--defaultPreemptable', dest='defaultPreemptable', action='store_true')
+ addOptionFn("--readGlobalFileMutableByDefault", dest="readGlobalFileMutableByDefault",
+                action='store_true', default=None, help='Toil disallows modification of read '
+                                                        'global files by default. This flag makes '
+                                                        'read global files mutable by default, '
+ 'however it also defeats the purpose of '
+ 'shared caching via hard links to save '
+ 'space. Default is False')
+ addOptionFn('--maxCores', dest='maxCores', default=None, metavar='INT',
+ help='The maximum number of CPU cores to request from the batch system at any one '
+ 'time. Standard suffixes like K, Ki, M, Mi, G or Gi are supported. Default '
+ 'is %s' % bytes2human(config.maxCores, symbols='iec'))
+ addOptionFn('--maxMemory', dest='maxMemory', default=None, metavar='INT',
+ help="The maximum amount of memory to request from the batch system at any one "
+ "time. Standard suffixes like K, Ki, M, Mi, G or Gi are supported. Default "
+ "is %s" % bytes2human( config.maxMemory, symbols='iec'))
+ addOptionFn('--maxDisk', dest='maxDisk', default=None, metavar='INT',
+ help='The maximum amount of disk space to request from the batch system at any '
+ 'one time. Standard suffixes like K, Ki, M, Mi, G or Gi are supported. '
+ 'Default is %s' % bytes2human(config.maxDisk, symbols='iec'))
+
+ #
+ #Retrying/rescuing jobs
+ #
+    addOptionFn = addGroupFn("toil options for rescuing/killing/restarting jobs",
+                             "The options for jobs that either run too long/fail or get lost "
+                             "(some batch systems have issues!)")
+    addOptionFn("--retryCount", dest="retryCount", default=None,
+                help=("Number of times to retry a failing job before giving up and "
+                      "labeling the job as failed. default=%s" % config.retryCount))
+ addOptionFn("--maxJobDuration", dest="maxJobDuration", default=None,
+ help=("Maximum runtime of a job (in seconds) before we kill it "
+ "(this is a lower bound, and the actual time before killing "
+ "the job may be longer). default=%s" % config.maxJobDuration))
+ addOptionFn("--rescueJobsFrequency", dest="rescueJobsFrequency", default=None,
+ help=("Period of time to wait (in seconds) between checking for "
+ "missing/overlong jobs, that is jobs which get lost by the batch system. Expert parameter. default=%s" % config.rescueJobsFrequency))
+
+ #
+ #Misc options
+ #
+ addOptionFn = addGroupFn("toil miscellaneous options", "Miscellaneous options")
+ addOptionFn('--disableCaching', dest='disableCaching', action='store_true', default=False,
+ help='Disables caching in the file store. This flag must be set to use '
+ 'a batch system that does not support caching such as Grid Engine, Parasol, '
+ 'LSF, or Slurm')
+    addOptionFn("--maxLogFileSize", dest="maxLogFileSize", default=None,
+                help=("The maximum size of a job log file to keep (in bytes); log files "
+                      "larger than this will be truncated to the last X bytes. Setting "
+                      "this option to zero will prevent any truncation. Setting this "
+                      "option to a negative value will truncate from the beginning. "
+                      "Default=%s" % bytes2human(config.maxLogFileSize)))
+ addOptionFn("--writeLogs", dest="writeLogs", nargs='?', action='store',
+ default=None, const=os.getcwd(),
+ help="Write worker logs received by the leader into their own files at the "
+ "specified path. The current working directory will be used if a path is "
+ "not specified explicitly. Note: By default "
+ "only the logs of failed jobs are returned to leader. Set log level to "
+ "'debug' to get logs back from successful jobs, and adjust 'maxLogFileSize' "
+ "to control the truncation limit for worker logs.")
+ addOptionFn("--writeLogsGzip", dest="writeLogsGzip", nargs='?', action='store',
+ default=None, const=os.getcwd(),
+ help="Identical to --writeLogs except the logs files are gzipped on the leader.")
+ addOptionFn("--realTimeLogging", dest="realTimeLogging", action="store_true", default=False,
+ help="Enable real-time logging from workers to masters")
+
+ addOptionFn("--sseKey", dest="sseKey", default=None,
+ help="Path to file containing 32 character key to be used for server-side encryption on awsJobStore. SSE will "
+ "not be used if this flag is not passed.")
+ addOptionFn("--cseKey", dest="cseKey", default=None,
+ help="Path to file containing 256-bit key to be used for client-side encryption on "
+ "azureJobStore. By default, no encryption is used.")
+ addOptionFn("--setEnv", '-e', metavar='NAME=VALUE or NAME',
+ dest="environment", default=[], action="append",
+ help="Set an environment variable early on in the worker. If VALUE is omitted, "
+ "it will be looked up in the current environment. Independently of this "
+ "option, the worker will try to emulate the leader's environment before "
+ "running a job. Using this option, a variable can be injected into the "
+ "worker process itself before it is started.")
+    addOptionFn("--servicePollingInterval", dest="servicePollingInterval", default=None,
+                help="Interval of time service jobs wait between polling for the existence "
+                     "of the keep-alive flag (default=%s)" % config.servicePollingInterval)
+ #
+ #Debug options
+ #
+ addOptionFn = addGroupFn("toil debug options", "Debug options")
+ addOptionFn("--badWorker", dest="badWorker", default=None,
+ help=("For testing purposes randomly kill 'badWorker' proportion of jobs using SIGKILL, default=%s" % config.badWorker))
+ addOptionFn("--badWorkerFailInterval", dest="badWorkerFailInterval", default=None,
+ help=("When killing the job pick uniformly within the interval from 0.0 to "
+ "'badWorkerFailInterval' seconds after the worker starts, default=%s" % config.badWorkerFailInterval))
+
+def addOptions(parser, config=Config()):
+    """
+    Adds Toil options to a parser object. Currently only argparse.ArgumentParser is supported.
+    """
+    # Adds the Toil option groups to an argparse parser; any other parser class is rejected
+    # with a RuntimeError below.
+ addLoggingOptions(parser) # This adds the logging stuff.
+ if isinstance(parser, ArgumentParser):
+ def addGroup(headingString, bodyString):
+ return parser.add_argument_group(headingString, bodyString).add_argument
+ _addOptions(addGroup, config)
+ else:
+ raise RuntimeError("Unanticipated class passed to addOptions(), %s. Expecting "
+ "argparse.ArgumentParser" % parser.__class__)
+
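+# Illustrative wiring sketch for addOptions (not from the original source; assumes an
+# argparse-based script):
+#
+#     parser = ArgumentParser()
+#     addOptions(parser)
+#     options = parser.parse_args()
+#     # 'options' can then be passed to Toil(options) below.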
+
+class Toil(object):
+ """
+ A context manager that represents a Toil workflow, specifically the batch system, job store,
+ and its configuration.
+ """
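+
+    # Typical usage sketch (illustrative, not part of the original source; MyRootJob is a
+    # hypothetical toil.job.Job subclass):
+    #
+    #     with Toil(options) as toil:
+    #         if options.restart:
+    #             returnValue = toil.restart()
+    #         else:
+    #             returnValue = toil.start(MyRootJob())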
+
+ def __init__(self, options):
+ """
+ Initialize a Toil object from the given options. Note that this is very light-weight and
+ that the bulk of the work is done when the context is entered.
+
+ :param argparse.Namespace options: command line options specified by the user
+ """
+ super(Toil, self).__init__()
+ self.options = options
+ self.config = None
+ """
+ :type: toil.common.Config
+ """
+ self._jobStore = None
+ """
+ :type: toil.jobStores.abstractJobStore.AbstractJobStore
+ """
+ self._batchSystem = None
+ """
+ :type: toil.batchSystems.abstractBatchSystem.AbstractBatchSystem
+ """
+ self._provisioner = None
+ """
+ :type: toil.provisioners.abstractProvisioner.AbstractProvisioner
+ """
+ self._jobCache = dict()
+ self._inContextManager = False
+
+ def __enter__(self):
+ """
+ Derive configuration from the command line options, load the job store and, on restart,
+ consolidate the derived configuration with the one from the previous invocation of the
+ workflow.
+ """
+ setLoggingFromOptions(self.options)
+ config = Config()
+ config.setOptions(self.options)
+ jobStore = self.getJobStore(config.jobStore)
+ if not config.restart:
+ config.workflowAttemptNumber = 0
+ jobStore.initialize(config)
+ else:
+ jobStore.resume()
+ # Merge configuration from job store with command line options
+ config = jobStore.config
+ config.setOptions(self.options)
+ config.workflowAttemptNumber += 1
+ jobStore.writeConfig()
+ self.config = config
+ self._jobStore = jobStore
+ self._inContextManager = True
+ return self
+
+ # noinspection PyUnusedLocal
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Clean up after a workflow invocation. Depending on the configuration, delete the job store.
+ """
+ try:
+ if (exc_type is not None and self.config.clean == "onError" or
+ exc_type is None and self.config.clean == "onSuccess" or
+ self.config.clean == "always"):
+ logger.info("Attempting to delete the job store")
+ self._jobStore.destroy()
+ logger.info("Successfully deleted the job store")
+ except Exception as e:
+ if exc_type is None:
+ raise
+ else:
+ logger.exception('The following error was raised during clean up:')
+ self._inContextManager = False
+ return False # let exceptions through
+
+ def start(self, rootJob):
+ """
+ Invoke a Toil workflow with the given job as the root for an initial run. This method
+ must be called in the body of a ``with Toil(...) as toil:`` statement. This method should
+ not be called more than once for a workflow that has not finished.
+
+ :param toil.job.Job rootJob: The root job of the workflow
+ :return: The root job's return value
+ """
+ self._assertContextManagerUsed()
+ if self.config.restart:
+ raise ToilRestartException('A Toil workflow can only be started once. Use '
+ 'Toil.restart() to resume it.')
+
+ self._batchSystem = self.createBatchSystem(self.config)
+ self._setupHotDeployment(rootJob.getUserScript())
+ try:
+ self._setBatchSystemEnvVars()
+ self._serialiseEnv()
+ self._cacheAllJobs()
+
+ # Pickle the promised return value of the root job, then write the pickled promise to
+ # a shared file, where we can find and unpickle it at the end of the workflow.
+ # Unpickling the promise will automatically substitute the promise for the actual
+ # return value.
+ with self._jobStore.writeSharedFileStream('rootJobReturnValue') as fH:
+ rootJob.prepareForPromiseRegistration(self._jobStore)
+ promise = rootJob.rv()
+ cPickle.dump(promise, fH)
+
+            # Set up the first job graph (the serialised root job) and cache it
+ rootJobGraph = rootJob._serialiseFirstJob(self._jobStore)
+ self._cacheJob(rootJobGraph)
+
+ self._setProvisioner()
+ return self._runMainLoop(rootJobGraph)
+ finally:
+ self._shutdownBatchSystem()
+
+ def restart(self):
+ """
+ Restarts a workflow that has been interrupted. This method should be called if and only
+ if a workflow has previously been started and has not finished.
+
+ :return: The root job's return value
+ """
+ self._assertContextManagerUsed()
+ if not self.config.restart:
+ raise ToilRestartException('A Toil workflow must be initiated with Toil.start(), '
+ 'not restart().')
+
+ self._batchSystem = self.createBatchSystem(self.config)
+ self._setupHotDeployment()
+ try:
+ self._setBatchSystemEnvVars()
+ self._serialiseEnv()
+ self._cacheAllJobs()
+ self._setProvisioner()
+ rootJobGraph = self._jobStore.clean(jobCache=self._jobCache)
+ return self._runMainLoop(rootJobGraph)
+ finally:
+ self._shutdownBatchSystem()
+
+ def _setProvisioner(self):
+ if self.config.provisioner is None:
+ self._provisioner = None
+ elif self.config.provisioner == 'cgcloud':
+ logger.info('Using cgcloud provisioner.')
+ from toil.provisioners.cgcloud.provisioner import CGCloudProvisioner
+ self._provisioner = CGCloudProvisioner(self.config, self._batchSystem)
+ elif self.config.provisioner == 'aws':
+ logger.info('Using AWS provisioner.')
+ from bd2k.util.ec2.credentials import enable_metadata_credential_caching
+ from toil.provisioners.aws.awsProvisioner import AWSProvisioner
+ enable_metadata_credential_caching()
+ self._provisioner = AWSProvisioner(self.config, self._batchSystem)
+ else:
+            # Command line parser should have checked argument validity already
+ assert False, self.config.provisioner
+
+ @classmethod
+ def getJobStore(cls, locator):
+ """
+ Create an instance of the concrete job store implementation that matches the given locator.
+
+ :param str locator: The location of the job store to be represent by the instance
+
+ :return: an instance of a concrete subclass of AbstractJobStore
+ :rtype: toil.jobStores.abstractJobStore.AbstractJobStore
+ """
+ name, rest = cls.parseLocator(locator)
+ if name == 'file':
+ from toil.jobStores.fileJobStore import FileJobStore
+ return FileJobStore(rest)
+ elif name == 'aws':
+ from bd2k.util.ec2.credentials import enable_metadata_credential_caching
+ from toil.jobStores.aws.jobStore import AWSJobStore
+ enable_metadata_credential_caching()
+ return AWSJobStore(rest)
+ elif name == 'azure':
+ from toil.jobStores.azureJobStore import AzureJobStore
+ return AzureJobStore(rest)
+ elif name == 'google':
+ from toil.jobStores.googleJobStore import GoogleJobStore
+ projectID, namePrefix = rest.split(':', 1)
+ return GoogleJobStore(namePrefix, projectID)
+ else:
+ raise RuntimeError("Unknown job store implementation '%s'" % name)
+
+ @staticmethod
+ def parseLocator(locator):
+ if locator[0] in '/.' or ':' not in locator:
+ return 'file', locator
+ else:
+ try:
+ name, rest = locator.split(':', 1)
+ except ValueError:
+ raise RuntimeError('Invalid job store locator syntax.')
+ else:
+ return name, rest
+
+ @staticmethod
+ def buildLocator(name, rest):
+ assert ':' not in name
+ return name + ':' + rest
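+
+    # Illustrative examples (not from the original source):
+    #     parseLocator('aws:us-west-2:my-prefix')   -> ('aws', 'us-west-2:my-prefix')
+    #     parseLocator('/tmp/my-job-store')         -> ('file', '/tmp/my-job-store')
+    #     buildLocator('file', '/tmp/my-job-store') -> 'file:/tmp/my-job-store'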
+
+ @classmethod
+ def resumeJobStore(cls, locator):
+ jobStore = cls.getJobStore(locator)
+ jobStore.resume()
+ return jobStore
+
+ @staticmethod
+ def createBatchSystem(config):
+ """
+ Creates an instance of the batch system specified in the given config.
+
+ :param toil.common.Config config: the current configuration
+
+ :rtype: batchSystems.abstractBatchSystem.AbstractBatchSystem
+
+ :return: an instance of a concrete subclass of AbstractBatchSystem
+ """
+ kwargs = dict(config=config,
+ maxCores=config.maxCores,
+ maxMemory=config.maxMemory,
+ maxDisk=config.maxDisk)
+
+ if config.batchSystem == 'parasol':
+ from toil.batchSystems.parasol import ParasolBatchSystem
+ batchSystemClass = ParasolBatchSystem
+
+ elif config.batchSystem == 'single_machine' or config.batchSystem == 'singleMachine':
+ from toil.batchSystems.singleMachine import SingleMachineBatchSystem
+ batchSystemClass = SingleMachineBatchSystem
+
+ elif config.batchSystem == 'gridengine' or config.batchSystem == 'gridEngine':
+ from toil.batchSystems.gridengine import GridengineBatchSystem
+ batchSystemClass = GridengineBatchSystem
+
+ elif config.batchSystem == 'lsf' or config.batchSystem == 'LSF':
+ from toil.batchSystems.lsf import LSFBatchSystem
+ batchSystemClass = LSFBatchSystem
+
+ elif config.batchSystem == 'mesos' or config.batchSystem == 'Mesos':
+ from toil.batchSystems.mesos.batchSystem import MesosBatchSystem
+ batchSystemClass = MesosBatchSystem
+
+ kwargs['masterAddress'] = config.mesosMasterAddress
+
+ elif config.batchSystem == 'slurm' or config.batchSystem == 'Slurm':
+ from toil.batchSystems.slurm import SlurmBatchSystem
+ batchSystemClass = SlurmBatchSystem
+
+ else:
+ raise RuntimeError('Unrecognised batch system: %s' % config.batchSystem)
+
+ if not config.disableCaching and not batchSystemClass.supportsWorkerCleanup():
+ raise RuntimeError('%s currently does not support shared caching. Set the '
+ '--disableCaching flag if you want to '
+ 'use this batch system.' % config.batchSystem)
+        logger.info('Using the %s' %
+                    re.sub("([a-z])([A-Z])", r"\g<1> \g<2>", batchSystemClass.__name__).lower())
+
+ return batchSystemClass(**kwargs)
+
+ def _setupHotDeployment(self, userScript=None):
+ """
+ Determine the user script, save it to the job store and inject a reference to the saved
+ copy into the batch system such that it can hot-deploy the resource on the worker
+ nodes.
+
+ :param toil.resource.ModuleDescriptor userScript: the module descriptor referencing the
+ user script. If None, it will be looked up in the job store.
+ """
+ if userScript is not None:
+ # This branch is hit when a workflow is being started
+ if userScript.belongsToToil:
+ logger.info('User script %s belongs to Toil. No need to hot-deploy it.', userScript)
+ userScript = None
+ else:
+ if self._batchSystem.supportsHotDeployment():
+ # Note that by saving the ModuleDescriptor, and not the Resource we allow for
+ # redeploying a potentially modified user script on workflow restarts.
+ with self._jobStore.writeSharedFileStream('userScript') as f:
+ cPickle.dump(userScript, f, protocol=cPickle.HIGHEST_PROTOCOL)
+ else:
+ from toil.batchSystems.singleMachine import SingleMachineBatchSystem
+ if not isinstance(self._batchSystem, SingleMachineBatchSystem):
+ logger.warn('Batch system does not support hot-deployment. The user '
+ 'script %s will have to be present at the same location on '
+ 'every worker.', userScript)
+ userScript = None
+ else:
+ # This branch is hit on restarts
+ from toil.jobStores.abstractJobStore import NoSuchFileException
+ try:
+ with self._jobStore.readSharedFileStream('userScript') as f:
+ userScript = cPickle.load(f)
+ except NoSuchFileException:
+ logger.info('User script neither set explicitly nor present in the job store.')
+ userScript = None
+ if userScript is None:
+ logger.info('No user script to hot-deploy.')
+ else:
+ logger.debug('Saving user script %s as a resource', userScript)
+ userScriptResource = userScript.saveAsResourceTo(self._jobStore)
+ logger.debug('Injecting user script %s into batch system.', userScriptResource)
+ self._batchSystem.setUserScript(userScriptResource)
+ thread = Thread(target=self._refreshUserScript,
+ name='refreshUserScript',
+ kwargs=dict(userScriptResource=userScriptResource))
+ thread.daemon = True
+ thread.start()
+
+ def _refreshUserScript(self, userScriptResource):
+ """
+ Periodically refresh the user script in the job store to prevent credential
+ expiration from causing the public URL to the user script to expire.
+ """
+ while True:
+ # Boto refreshes IAM credentials if they will be expiring within the next five
+ # minutes, but it will only check the expiry if and when credentials are needed to
+ # sign an actual AWS request. This means that we should be refreshing the user script
+ # at least every 5 minutes. Note that refreshing the user script in the job store
+ # involves an S3 request requiring credentials and therefore also triggers refreshing
+ # the IAM role credentials. In the worst case, refresh() is called 5 minutes plus
+ # epsilon before IAM credential expiration. The resource is refreshed three minutes
+        # after that, leaving two minutes plus epsilon for generating a new signed URL, this time
+ # with refreshed IAM role credentials. This consideration only applies to AWS and
+ # Boto2, of course. See https://github.com/BD2KGenomics/toil/issues/1372.
+ time.sleep(3 * 60)
+ logger.debug('Refreshing user script resource %s.', userScriptResource)
+ userScriptResource = userScriptResource.refresh(self._jobStore)
+ logger.debug('Injecting refreshed user script %s into batch system.', userScriptResource)
+ self._batchSystem.setUserScript(userScriptResource)
+
+ def importFile(self, srcUrl, sharedFileName=None):
+ self._assertContextManagerUsed()
+ return self._jobStore.importFile(srcUrl, sharedFileName=sharedFileName)
+
+ def exportFile(self, jobStoreFileID, dstUrl):
+ self._assertContextManagerUsed()
+ self._jobStore.exportFile(jobStoreFileID, dstUrl)
+
+ def _setBatchSystemEnvVars(self):
+ """
+ Sets the environment variables required by the job store and those passed on command line.
+ """
+ for envDict in (self._jobStore.getEnv(), self.config.environment):
+ for k, v in iteritems(envDict):
+ self._batchSystem.setEnv(k, v)
+
+ def _serialiseEnv(self):
+ """
+ Puts the environment in a globally accessible pickle file.
+ """
+ # Dump out the environment of this process in the environment pickle file.
+ with self._jobStore.writeSharedFileStream("environment.pickle") as fileHandle:
+ cPickle.dump(os.environ, fileHandle, cPickle.HIGHEST_PROTOCOL)
+        logger.info("Wrote the environment for the jobs to the environment file")
+
+ def _cacheAllJobs(self):
+ """
+ Downloads all jobs in the current job store into self.jobCache.
+ """
+ logger.info('Caching all jobs in job store')
+ self._jobCache = {jobGraph.jobStoreID: jobGraph for jobGraph in self._jobStore.jobs()}
+ logger.info('{} jobs downloaded.'.format(len(self._jobCache)))
+
+ def _cacheJob(self, job):
+ """
+ Adds given job to current job cache.
+
+ :param toil.jobGraph.JobGraph job: job to be added to current job cache
+ """
+ self._jobCache[job.jobStoreID] = job
+
+ @staticmethod
+ def getWorkflowDir(workflowID, configWorkDir=None):
+ """
+ Returns a path to the directory where worker directories and the cache will be located
+ for this workflow.
+
+ :param str workflowID: Unique identifier for the workflow
+ :param str configWorkDir: Value passed to the program using the --workDir flag
+ :return: Path to the workflow directory
+ :rtype: str
+ """
+ workDir = configWorkDir or os.getenv('TOIL_WORKDIR') or tempfile.gettempdir()
+ if not os.path.exists(workDir):
+ raise RuntimeError("The directory specified by --workDir or TOIL_WORKDIR (%s) does not "
+ "exist." % workDir)
+ # Create the workflow dir
+ workflowDir = os.path.join(workDir, 'toil-%s' % workflowID)
+ try:
+ # Directory creation is atomic
+ os.mkdir(workflowDir)
+        except OSError as err:
+            if err.errno != 17:
+                raise
+            # Otherwise errno 17 (EEXIST): the directory already exists because a previous
+            # worker set it up, which is fine.
+ else:
+ logger.info('Created the workflow directory at %s' % workflowDir)
+ return workflowDir
+
+ def _runMainLoop(self, rootJob):
+ """
+ Runs the main loop with the given job.
+ :param toil.job.Job rootJob: The root job for the workflow.
+ :rtype: Any
+ """
+ logProcessContext(self.config)
+
+ with RealtimeLogger(self._batchSystem,
+ level=self.options.logLevel if self.options.realTimeLogging else None):
+ # FIXME: common should not import from leader
+ from toil.leader import Leader
+ return Leader(config=self.config,
+ batchSystem=self._batchSystem,
+ provisioner=self._provisioner,
+ jobStore=self._jobStore,
+ rootJob=rootJob,
+ jobCache=self._jobCache).run()
+
+ def _shutdownBatchSystem(self):
+ """
+ Shuts down current batch system if it has been created.
+ """
+ assert self._batchSystem is not None
+
+ startTime = time.time()
+ logger.debug('Shutting down batch system ...')
+ self._batchSystem.shutdown()
+ logger.debug('... finished shutting down the batch system in %s seconds.'
+ % (time.time() - startTime))
+
+ def _assertContextManagerUsed(self):
+ if not self._inContextManager:
+ raise ToilContextManagerException()
+
+
+class ToilRestartException(Exception):
+ def __init__(self, message):
+ super(ToilRestartException, self).__init__(message)
+
+
+class ToilContextManagerException(Exception):
+ def __init__(self):
+ super(ToilContextManagerException, self).__init__(
+ 'This method cannot be called outside the "with Toil(...)" context manager.')
+
+# Nested functions can't have doctests so we have to make this global
+
+
+def parseSetEnv(l):
+ """
+ Parses a list of strings of the form "NAME=VALUE" or just "NAME" into a dictionary. Strings
+    of the latter form will result in dictionary entries whose value is None.
+
+ :type l: list[str]
+ :rtype: dict[str,str]
+
+ >>> parseSetEnv([])
+ {}
+ >>> parseSetEnv(['a'])
+ {'a': None}
+ >>> parseSetEnv(['a='])
+ {'a': ''}
+ >>> parseSetEnv(['a=b'])
+ {'a': 'b'}
+ >>> parseSetEnv(['a=a', 'a=b'])
+ {'a': 'b'}
+ >>> parseSetEnv(['a=b', 'c=d'])
+ {'a': 'b', 'c': 'd'}
+ >>> parseSetEnv(['a=b=c'])
+ {'a': 'b=c'}
+ >>> parseSetEnv([''])
+ Traceback (most recent call last):
+ ...
+ ValueError: Empty name
+ >>> parseSetEnv(['=1'])
+ Traceback (most recent call last):
+ ...
+ ValueError: Empty name
+ """
+ d = dict()
+ for i in l:
+ try:
+ k, v = i.split('=', 1)
+ except ValueError:
+ k, v = i, None
+ if not k:
+ raise ValueError('Empty name')
+ d[k] = v
+ return d
+
+
+def cacheDirName(workflowID):
+ """
+ :return: Name of the cache directory.
+ """
+ return 'cache-' + workflowID
+
+
+def getDirSizeRecursively(dirPath):
+ """
+ This method will walk through a directory and return the cumulative filesize in bytes of all
+ the files in the directory and its subdirectories.
+
+ :param dirPath: Path to a directory.
+ :return: cumulative size in bytes of all files in the directory.
+ :rtype: int
+ """
+ totalSize = 0
+    # Hard links all point to the same underlying file, so running stat on each of them
+    # returns identical values. To prevent the same file from being counted multiple times,
+    # we save the inodes of files that have more than one nlink associated with them.
+ seenInodes = set()
+ for dirPath, dirNames, fileNames in os.walk(dirPath):
+ folderSize = 0
+ for f in fileNames:
+ fp = os.path.join(dirPath, f)
+ fileStats = os.stat(fp)
+ if fileStats.st_nlink > 1:
+ if fileStats.st_ino not in seenInodes:
+ folderSize += fileStats.st_blocks * unixBlockSize
+ seenInodes.add(fileStats.st_ino)
+ else:
+ continue
+ else:
+ folderSize += fileStats.st_blocks * unixBlockSize
+ totalSize += folderSize
+ return totalSize
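+
+
+# Illustrative sketch (hypothetical path, not part of the upstream module): a typical
+# caller measures how much disk a job's temporary directory consumed; hard-linked
+# files are counted only once by the function above.
+def _exampleGetDirSizeRecursively():  # pragma: no cover
+    used = getDirSizeRecursively('/tmp/toil-workdir/worker-tmp')
+    logger.info('The directory used %s bytes of disk.', used)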
diff --git a/src/toil/cwl/__init__.py b/src/toil/cwl/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/src/toil/cwl/__init__.py
@@ -0,0 +1 @@
+
diff --git a/src/toil/cwl/conftest.py b/src/toil/cwl/conftest.py
new file mode 100644
index 0000000..2eca681
--- /dev/null
+++ b/src/toil/cwl/conftest.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://pytest.org/latest/example/pythoncollection.html
+
+collect_ignore = []
+
+try:
+ import cwltool
+except ImportError:
+ collect_ignore.append("cwltoil.py")
diff --git a/src/toil/cwl/cwltoil.py b/src/toil/cwl/cwltoil.py
new file mode 100755
index 0000000..d380d9c
--- /dev/null
+++ b/src/toil/cwl/cwltoil.py
@@ -0,0 +1,740 @@
+# Implement support for Common Workflow Language (CWL) for Toil.
+#
+# Copyright (C) 2015 Curoverse, Inc
+# Copyright (C) 2016 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from toil.job import Job
+from toil.common import Toil
+from toil.version import baseVersion
+from toil.lib.bioio import setLoggingFromOptions
+
+from argparse import ArgumentParser
+import cwltool.load_tool
+import cwltool.main
+import cwltool.workflow
+import cwltool.expression
+import cwltool.builder
+import cwltool.resolver
+import cwltool.stdfsaccess
+from cwltool.pathmapper import adjustFiles
+from cwltool.process import shortname, adjustFilesWithSecondary, fillInDefaults, compute_checksums
+from cwltool.utils import aslist
+import schema_salad.validate as validate
+import schema_salad.ref_resolver
+import os
+import tempfile
+import json
+import sys
+import logging
+import copy
+import shutil
+import functools
+
+# Python 3 compatibility imports
+from six.moves import xrange
+from six import iteritems, string_types
+import six.moves.urllib.parse as urlparse
+
+cwllogger = logging.getLogger("cwltool")
+
+# The job object passed into CWLJob and CWLWorkflow is a dict whose values are
+# (key, dict) tuples; the final input object is derived by evaluating each
+# tuple, looking up the key in the supplied dict.
+#
+# This is necessary because Toil jobs return a single value (a dict)
+# but CWL permits steps to have multiple output parameters that may
+# feed into multiple other steps. This transformation maps the key in the
+# output object to the correct key of the input object.
+
+class IndirectDict(dict):
+ pass
+
+class MergeInputs(object):
+ def __init__(self, sources):
+ self.sources = sources
+ def resolve(self):
+ raise NotImplementedError()
+
+class MergeInputsNested(MergeInputs):
+ def resolve(self):
+ return [v[1][v[0]] for v in self.sources]
+
+class MergeInputsFlattened(MergeInputs):
+ def resolve(self):
+ r = []
+ for v in self.sources:
+ v = v[1][v[0]]
+ if isinstance(v, list):
+ r.extend(v)
+ else:
+ r.append(v)
+ return r
+
+class StepValueFrom(object):
+ def __init__(self, expr, inner, req):
+ self.expr = expr
+ self.inner = inner
+ self.req = req
+
+ def do_eval(self, inputs, ctx):
+ return cwltool.expression.do_eval(self.expr, inputs, self.req,
+ None, None, {}, context=ctx)
+
+def resolve_indirect_inner(d):
+ if isinstance(d, IndirectDict):
+ r = {}
+ for k, v in d.items():
+ if isinstance(v, MergeInputs):
+ r[k] = v.resolve()
+ else:
+ r[k] = v[1][v[0]]
+ return r
+ else:
+ return d
+
+def resolve_indirect(d):
+ inner = IndirectDict() if isinstance(d, IndirectDict) else {}
+ needEval = False
+ for k, v in iteritems(d):
+ if isinstance(v, StepValueFrom):
+ inner[k] = v.inner
+ needEval = True
+ else:
+ inner[k] = v
+ res = resolve_indirect_inner(inner)
+ if needEval:
+ ev = {}
+ for k, v in iteritems(d):
+ if isinstance(v, StepValueFrom):
+ ev[k] = v.do_eval(res, res[k])
+ else:
+ ev[k] = res[k]
+ return ev
+ else:
+ return res
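+
+
+# Illustrative sketch (hypothetical values, not part of the upstream module): how the
+# (key, dict) indirection described above resolves. Each IndirectDict value is an
+# (output_key, upstream_result_dict) tuple; resolve_indirect() looks the key up in
+# the dict to produce the final input object.
+def _exampleResolveIndirect():  # pragma: no cover
+    example = IndirectDict({
+        'reads': ('sorted_bam', {'sorted_bam': 'filestore-id-1'}),
+        'threads': ('cores', {'cores': 4}),
+    })
+    assert resolve_indirect(example) == {'reads': 'filestore-id-1', 'threads': 4}
+    # A plain dict with no indirection passes through unchanged.
+    assert resolve_indirect({'threads': 4}) == {'threads': 4}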
+
+def getFile(fileStore, dir, fileTuple, index=None, export=False, primary=None, rename_collision=False):
+ # File literal outputs with no path, from writeFile
+ if fileTuple is None:
+ raise cwltool.process.UnsupportedRequirement("CWL expression file inputs not yet supported in Toil")
+ fileStoreID, fileName = fileTuple
+
+ if rename_collision is False:
+ if primary:
+ dir = os.path.dirname(primary)
+ else:
+ dir = tempfile.mkdtemp(dir=dir)
+
+ dstPath = os.path.join(dir, fileName)
+ if rename_collision:
+ n = 1
+ while os.path.exists(dstPath):
+ n += 1
+ stem, ext = os.path.splitext(dstPath)
+ stem = "%s_%s" % (stem, n)
+ dstPath = stem + ext
+
+ if export:
+ fileStore.exportFile(fileStoreID, "file://" + dstPath)
+ else:
+ srcPath = fileStore.readGlobalFile(fileStoreID)
+ if srcPath != dstPath:
+ # Note: 'copy' here resolves to the imported copy module (always truthy), so
+ # this branch is always taken and the symlink fallback below is never reached.
+ if copy:
+ shutil.copyfile(srcPath, dstPath)
+ else:
+ if os.path.exists(dstPath):
+ if index.get(dstPath, None) != fileStoreID:
+ raise Exception("Conflicting filesStoreID %s and %s both trying to link to %s" % (index.get(dstPath, None), fileStoreID, dstPath))
+ else:
+ os.symlink(srcPath, dstPath)
+ index[dstPath] = fileStoreID
+ return dstPath
+
+def writeFile(writeFunc, index, x):
+ # Toil fileStore references are tuples of pickle and internal file
+ if isinstance(x, tuple):
+ return x
+ # File literal outputs with no path, we don't write these and will fail
+ # with unsupportedRequirement when retrieving later with getFile
+ elif x.startswith("_:"):
+ return None
+ else:
+ if x not in index:
+ if not urlparse.urlparse(x).scheme:
+ rp = os.path.realpath(x)
+ else:
+ rp = x
+ try:
+ index[x] = (writeFunc(rp), os.path.basename(x))
+ except Exception as e:
+ cwllogger.error("Got exception '%s' while copying '%s'", e, x)
+ raise
+ return index[x]
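+
+# Illustrative note (not part of the upstream module): writeFile is normally curried
+# with the import/write function and a shared index dict, then applied to every file
+# reference via cwltool's adjustFiles, as done later in this module:
+#
+#     adjustFiles(output, functools.partial(writeFile, fileStore.writeGlobalFile, {}))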
+
+def computeFileChecksums(fs_access, f):
+ # File literal inputs with no path, no checksum
+ if isinstance(f, dict) and f.get("location", "").startswith("_:"):
+ return f
+ else:
+ return compute_checksums(fs_access, f)
+
+def addFilePartRefs(p):
+ """Provides new v1.0 functionality for referencing file parts.
+ """
+ if p.get("class") == "File" and p.get("path"):
+ dirname, basename = os.path.split(p["path"])
+ nameroot, nameext = os.path.splitext(basename)
+ for k, v in [("dirname", dirname,), ("basename", basename),
+ ("nameroot", nameroot), ("nameext", nameext)]:
+ if k not in p:
+ p[k] = v
+ return p
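+
+
+# Illustrative sketch (hypothetical file path, not part of the upstream module): the
+# part references that addFilePartRefs() adds to a File object.
+def _exampleAddFilePartRefs():  # pragma: no cover
+    p = addFilePartRefs({"class": "File", "path": "/data/sample.fastq.gz"})
+    assert p["dirname"] == "/data"
+    assert p["basename"] == "sample.fastq.gz"
+    assert p["nameroot"] == "sample.fastq"
+    assert p["nameext"] == ".gz"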
+
+def locToPath(p):
+ """Back compatibility -- handle converting locations into paths.
+ """
+ if "path" not in p and "location" in p:
+ p["path"] = p["location"].replace("file:", "")
+ return p
+
+def pathToLoc(p):
+ """Associate path with location.
+
+ v1.0 should be specifying location but older YAML uses path
+ -- this provides back compatibility.
+ """
+ if "path" in p:
+ p["location"] = p["path"]
+ return p
+
+class ResolveIndirect(Job):
+ def __init__(self, cwljob):
+ super(ResolveIndirect, self).__init__()
+ self.cwljob = cwljob
+
+ def run(self, fileStore):
+ return resolve_indirect(self.cwljob)
+
+
+class CWLJob(Job):
+ """Execute a CWL tool wrapper."""
+
+ def __init__(self, tool, cwljob, **kwargs):
+ builder = cwltool.builder.Builder()
+ builder.job = {}
+ builder.requirements = []
+ builder.outdir = None
+ builder.tmpdir = None
+ builder.timeout = 0
+ builder.resources = {}
+ req = tool.evalResources(builder, {})
+ self.cwltool = remove_pickle_problems(tool)
+ # pass the default of None if baseCommand is empty
+ unitName = self.cwltool.tool.get("baseCommand", None)
+ if isinstance(unitName, (list, tuple)):
+ unitName = ' '.join(unitName)
+ super(CWLJob, self).__init__(cores=req["cores"],
+ memory=(req["ram"]*1024*1024),
+ disk=((req["tmpdirSize"]*1024*1024) + (req["outdirSize"]*1024*1024)),
+ unitName=unitName)
+ #super(CWLJob, self).__init__()
+ self.cwljob = cwljob
+ try:
+ self.jobName = str(self.cwltool.tool['id'])
+ except KeyError:
+ # fall back to the Toil defined class name if the tool doesn't have an identifier
+ pass
+ self.executor_options = kwargs
+
+ def run(self, fileStore):
+ cwljob = resolve_indirect(self.cwljob)
+ fillInDefaults(self.cwltool.tool["inputs"], cwljob)
+
+ inpdir = os.path.join(fileStore.getLocalTempDir(), "inp")
+ outdir = os.path.join(fileStore.getLocalTempDir(), "out")
+ tmpdir = os.path.join(fileStore.getLocalTempDir(), "tmp")
+ os.mkdir(inpdir)
+ os.mkdir(outdir)
+ os.mkdir(tmpdir)
+
+ # Copy input files out of the global file store.
+ index = {}
+ adjustFilesWithSecondary(cwljob, functools.partial(getFile, fileStore, inpdir, index=index))
+
+ # Run the tool
+ opts = copy.deepcopy(self.executor_options)
+ # Exports temporary directory for batch systems that reset TMPDIR
+ os.environ["TMPDIR"] = os.path.realpath(opts.pop("tmpdir", None) or tmpdir)
+ output = cwltool.main.single_job_executor(self.cwltool, cwljob,
+ basedir=os.getcwd(),
+ outdir=outdir,
+ tmpdir=tmpdir,
+ tmpdir_prefix="tmp",
+ **opts)
+ cwltool.builder.adjustDirObjs(output, locToPath)
+ cwltool.builder.adjustFileObjs(output, locToPath)
+ cwltool.builder.adjustFileObjs(output, functools.partial(computeFileChecksums,
+ cwltool.stdfsaccess.StdFsAccess(outdir)))
+ # Copy output files into the global file store.
+ adjustFiles(output, functools.partial(writeFile, fileStore.writeGlobalFile, {}))
+
+ return output
+
+
+def makeJob(tool, jobobj, **kwargs):
+ if tool.tool["class"] == "Workflow":
+ wfjob = CWLWorkflow(tool, jobobj, **kwargs)
+ followOn = ResolveIndirect(wfjob.rv())
+ wfjob.addFollowOn(followOn)
+ return (wfjob, followOn)
+ else:
+ job = CWLJob(tool, jobobj, **kwargs)
+ return (job, job)
+
+
+class CWLScatter(Job):
+ def __init__(self, step, cwljob, **kwargs):
+ super(CWLScatter, self).__init__()
+ self.step = step
+ self.cwljob = cwljob
+ self.executor_options = kwargs
+
+ def flat_crossproduct_scatter(self, joborder, scatter_keys, outputs, postScatterEval):
+ scatter_key = shortname(scatter_keys[0])
+ l = len(joborder[scatter_key])
+ for n in xrange(0, l):
+ jo = copy.copy(joborder)
+ jo[scatter_key] = joborder[scatter_key][n]
+ if len(scatter_keys) == 1:
+ jo = postScatterEval(jo)
+ (subjob, followOn) = makeJob(self.step.embedded_tool, jo, **self.executor_options)
+ self.addChild(subjob)
+ outputs.append(followOn.rv())
+ else:
+ self.flat_crossproduct_scatter(jo, scatter_keys[1:], outputs, postScatterEval)
+
+ def nested_crossproduct_scatter(self, joborder, scatter_keys, postScatterEval):
+ scatter_key = shortname(scatter_keys[0])
+ l = len(joborder[scatter_key])
+ outputs = []
+ for n in xrange(0, l):
+ jo = copy.copy(joborder)
+ jo[scatter_key] = joborder[scatter_key][n]
+ if len(scatter_keys) == 1:
+ jo = postScatterEval(jo)
+ (subjob, followOn) = makeJob(self.step.embedded_tool, jo, **self.executor_options)
+ self.addChild(subjob)
+ outputs.append(followOn.rv())
+ else:
+ outputs.append(self.nested_crossproduct_scatter(jo, scatter_keys[1:], postScatterEval))
+ return outputs
+
+ def run(self, fileStore):
+ cwljob = resolve_indirect(self.cwljob)
+
+ if isinstance(self.step.tool["scatter"], string_types):
+ scatter = [self.step.tool["scatter"]]
+ else:
+ scatter = self.step.tool["scatter"]
+
+ scatterMethod = self.step.tool.get("scatterMethod", None)
+ if len(scatter) == 1:
+ scatterMethod = "dotproduct"
+ outputs = []
+
+ valueFrom = {shortname(i["id"]): i["valueFrom"] for i in self.step.tool["inputs"] if "valueFrom" in i}
+ def postScatterEval(io):
+ shortio = {shortname(k): v for k, v in iteritems(io)}
+ def valueFromFunc(k, v):
+ if k in valueFrom:
+ return cwltool.expression.do_eval(
+ valueFrom[k], shortio, self.step.requirements,
+ None, None, {}, context=v)
+ else:
+ return v
+ return {k: valueFromFunc(k, v) for k,v in io.items()}
+
+ if scatterMethod == "dotproduct":
+ for i in xrange(0, len(cwljob[shortname(scatter[0])])):
+ copyjob = copy.copy(cwljob)
+ for sc in [shortname(x) for x in scatter]:
+ copyjob[sc] = cwljob[sc][i]
+ copyjob = postScatterEval(copyjob)
+ (subjob, followOn) = makeJob(self.step.embedded_tool, copyjob, **self.executor_options)
+ self.addChild(subjob)
+ outputs.append(followOn.rv())
+ elif scatterMethod == "nested_crossproduct":
+ outputs = self.nested_crossproduct_scatter(cwljob, scatter, postScatterEval)
+ elif scatterMethod == "flat_crossproduct":
+ self.flat_crossproduct_scatter(cwljob, scatter, outputs, postScatterEval)
+ else:
+ if scatterMethod:
+ raise validate.ValidationException(
+ "Unsupported complex scatter type '%s'" % scatterMethod)
+ else:
+ raise validate.ValidationException(
+ "Must provide scatterMethod to scatter over multiple inputs")
+
+ return outputs
+
+
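+# Illustrative sketch (independent of the Toil job machinery, not part of the
+# upstream module): the output shapes the three scatter methods produce for
+# hypothetical two-parameter inputs.
+def _exampleScatterShapes():  # pragma: no cover
+    import itertools
+    a, b = ["a1", "a2"], ["b1", "b2"]
+    dotproduct = list(zip(a, b))                     # pairs element-wise; inputs must match in length
+    flat_cross = list(itertools.product(a, b))       # every combination, one flat list of 4 pairs
+    nested_cross = [[(x, y) for y in b] for x in a]  # every combination, grouped per element of `a`
+    assert len(dotproduct) == 2 and len(flat_cross) == 4
+    assert [len(row) for row in nested_cross] == [2, 2]
+
+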
+class CWLGather(Job):
+ def __init__(self, step, outputs):
+ super(CWLGather, self).__init__()
+ self.step = step
+ self.outputs = outputs
+
+ def allkeys(self, obj, keys):
+ if isinstance(obj, dict):
+ for k in obj.keys():
+ keys.add(k)
+ elif isinstance(obj, list):
+ for l in obj:
+ self.allkeys(l, keys)
+
+ def extract(self, obj, k):
+ if isinstance(obj, dict):
+ return obj.get(k)
+ elif isinstance(obj, list):
+ cp = []
+ for l in obj:
+ cp.append(self.extract(l, k))
+ return cp
+
+ def run(self, fileStore):
+ outobj = {}
+ keys = set()
+ self.allkeys(self.outputs, keys)
+
+ for k in keys:
+ outobj[k] = self.extract(self.outputs, k)
+
+ return outobj
+
+
+class SelfJob(object):
+ """Fake job object to facilitate implementation of CWLWorkflow.run()"""
+
+ def __init__(self, j, v):
+ self.j = j
+ self.v = v
+
+ def rv(self):
+ return self.v
+
+ def addChild(self, c):
+ return self.j.addChild(c)
+
+ def hasChild(self, c):
+ return self.j.hasChild(c)
+
+def remove_pickle_problems(obj):
+ """doc_loader does not pickle correctly, causing Toil errors, remove from objects.
+ """
+ if hasattr(obj, "doc_loader"):
+ obj.doc_loader = None
+ if hasattr(obj, "embedded_tool"):
+ obj.embedded_tool = remove_pickle_problems(obj.embedded_tool)
+ if hasattr(obj, "steps"):
+ obj.steps = [remove_pickle_problems(s) for s in obj.steps]
+ return obj
+
+class CWLWorkflow(Job):
+ """Traverse a CWL workflow graph and schedule a Toil job graph."""
+
+ def __init__(self, cwlwf, cwljob, **kwargs):
+ super(CWLWorkflow, self).__init__()
+ self.cwlwf = cwlwf
+ self.cwljob = cwljob
+ self.executor_options = kwargs
+ self.cwlwf = remove_pickle_problems(self.cwlwf)
+
+ def run(self, fileStore):
+ cwljob = resolve_indirect(self.cwljob)
+
+ # `promises` dict
+ # from: each parameter (workflow input or step output)
+ # that may be used as a "source" for a step input or workflow output
+ # parameter
+ # to: the job that will produce that value.
+ promises = {}
+
+ # `jobs` dict from step id to job that implements that step.
+ jobs = {}
+
+ for inp in self.cwlwf.tool["inputs"]:
+ promises[inp["id"]] = SelfJob(self, cwljob)
+
+ alloutputs_fufilled = False
+ while not alloutputs_fufilled:
+ # Iteratively go over the workflow steps, scheduling jobs as their
+ # dependencies can be fulfilled by upstream workflow inputs or
+ # step outputs. Loop exits when the workflow outputs
+ # are satisfied.
+
+ alloutputs_fufilled = True
+
+ for step in self.cwlwf.steps:
+ if step.tool["id"] not in jobs:
+ stepinputs_fufilled = True
+ for inp in step.tool["inputs"]:
+ if "source" in inp:
+ for s in aslist(inp["source"]):
+ if s not in promises:
+ stepinputs_fufilled = False
+ if stepinputs_fufilled:
+ jobobj = {}
+
+ for inp in step.tool["inputs"]:
+ key = shortname(inp["id"])
+ if "source" in inp:
+ if inp.get("linkMerge") or len(aslist(inp["source"])) > 1:
+ linkMerge = inp.get("linkMerge", "merge_nested")
+ if linkMerge == "merge_nested":
+ jobobj[key] = (
+ MergeInputsNested([(shortname(s), promises[s].rv())
+ for s in aslist(inp["source"])]))
+ elif linkMerge == "merge_flattened":
+ jobobj[key] = (
+ MergeInputsFlattened([(shortname(s), promises[s].rv())
+ for s in aslist(inp["source"])]))
+ else:
+ raise validate.ValidationException(
+ "Unsupported linkMerge '%s'", linkMerge)
+ else:
+ jobobj[key] = (
+ shortname(inp["source"]), promises[inp["source"]].rv())
+ elif "default" in inp:
+ d = copy.copy(inp["default"])
+ jobobj[key] = ("default", {"default": d})
+
+ if "valueFrom" in inp and "scatter" not in step.tool:
+ if key in jobobj:
+ jobobj[key] = StepValueFrom(inp["valueFrom"],
+ jobobj[key],
+ self.cwlwf.requirements)
+ else:
+ jobobj[key] = StepValueFrom(inp["valueFrom"],
+ ("None", {"None": None}),
+ self.cwlwf.requirements)
+
+ if "scatter" in step.tool:
+ wfjob = CWLScatter(step, IndirectDict(jobobj), **self.executor_options)
+ followOn = CWLGather(step, wfjob.rv())
+ wfjob.addFollowOn(followOn)
+ else:
+ (wfjob, followOn) = makeJob(step.embedded_tool, IndirectDict(jobobj),
+ **self.executor_options)
+
+ jobs[step.tool["id"]] = followOn
+
+ connected = False
+ for inp in step.tool["inputs"]:
+ for s in aslist(inp.get("source", [])):
+ if not promises[s].hasChild(wfjob):
+ promises[s].addChild(wfjob)
+ connected = True
+ if not connected:
+ # workflow step has default inputs only, isn't connected to other jobs,
+ # so add it as child of workflow.
+ self.addChild(wfjob)
+
+ for out in step.tool["outputs"]:
+ promises[out["id"]] = followOn
+
+ for inp in step.tool["inputs"]:
+ for s in aslist(inp.get("source", [])):
+ if s not in promises:
+ alloutputs_fufilled = False
+
+ # may need a test
+ for out in self.cwlwf.tool["outputs"]:
+ if "source" in out:
+ if out["source"] not in promises:
+ alloutputs_fufilled = False
+
+ outobj = {}
+ for out in self.cwlwf.tool["outputs"]:
+ outobj[shortname(out["id"])] = (shortname(out["outputSource"]), promises[out["outputSource"]].rv())
+
+ return IndirectDict(outobj)
+
+
+cwltool.process.supportedProcessRequirements = ("DockerRequirement",
+ "ExpressionEngineRequirement",
+ "InlineJavascriptRequirement",
+ "SchemaDefRequirement",
+ "EnvVarRequirement",
+ "CreateFileRequirement",
+ "SubworkflowFeatureRequirement",
+ "ScatterFeatureRequirement",
+ "ShellCommandRequirement",
+ "MultipleInputFeatureRequirement",
+ "StepInputExpressionRequirement",
+ "ResourceRequirement")
+
+def unsupportedInputCheck(p):
+ """Check for file inputs we don't current support in Toil:
+
+ - Directories
+ - File literals
+ """
+ if p.get("class") == "Directory":
+ raise cwltool.process.UnsupportedRequirement("CWL Directory inputs not yet supported in Toil")
+ if p.get("contents") and (not p.get("path") and not p.get("location")):
+ raise cwltool.process.UnsupportedRequirement("CWL File literals not yet supported in Toil")
+
+def unsupportedDefaultCheck(tool):
+ """Check for file-based defaults, which don't get staged correctly in Toil.
+ """
+ for inp in tool["in"]:
+ if isinstance(inp, dict) and "default" in inp:
+ if isinstance(inp["default"], dict) and inp["default"].get("class") == "File":
+ raise cwltool.process.UnsupportedRequirement("CWL default file inputs not yet supported in Toil")
+
+def main(args=None, stdout=sys.stdout):
+ parser = ArgumentParser()
+ Job.Runner.addToilOptions(parser)
+ parser.add_argument("cwltool", type=str)
+ parser.add_argument("cwljob", type=str, nargs="?", default=None)
+
+ # Will override the "jobStore" positional argument, enabling the
+ # user to select a jobStore or fall back to the default logic below.
+ parser.add_argument("--jobStore", type=str)
+ parser.add_argument("--conformance-test", action="store_true")
+ parser.add_argument("--no-container", action="store_true")
+ parser.add_argument("--quiet", dest="logLevel", action="store_const", const="ERROR")
+ parser.add_argument("--basedir", type=str)
+ parser.add_argument("--outdir", type=str, default=os.getcwd())
+ parser.add_argument("--version", action='version', version=baseVersion)
+ parser.add_argument("--preserve-environment", type=str, nargs='+',
+ help="Preserve specified environment variables when running CommandLineTools",
+ metavar=("VAR1,VAR2"),
+ default=("PATH",),
+ dest="preserve_environment")
+
+ # mkdtemp actually creates the directory, but
+ # toil requires that the directory not exist,
+ # so make it and delete it and allow
+ # toil to create it again (!)
+ workdir = tempfile.mkdtemp()
+ os.rmdir(workdir)
+
+ if args is None:
+ args = sys.argv[1:]
+
+ options = parser.parse_args([workdir] + args)
+
+ use_container = not options.no_container
+
+ setLoggingFromOptions(options)
+ if options.logLevel:
+ cwllogger.setLevel(options.logLevel)
+
+ try:
+ t = cwltool.load_tool.load_tool(options.cwltool, cwltool.workflow.defaultMakeTool,
+ resolver=cwltool.resolver.tool_resolver)
+ except cwltool.process.UnsupportedRequirement as e:
+ logging.error(e)
+ return 33
+
+ if options.conformance_test:
+ loader = schema_salad.ref_resolver.Loader({})
+ else:
+ jobloaderctx = {"path": {"@type": "@id"}, "format": {"@type": "@id"}}
+ jobloaderctx.update(t.metadata.get("$namespaces", {}))
+ loader = schema_salad.ref_resolver.Loader(jobloaderctx)
+
+ if options.cwljob:
+ uri = (options.cwljob if urlparse.urlparse(options.cwljob).scheme
+ else "file://" + os.path.abspath(options.cwljob))
+ job, _ = loader.resolve_ref(uri, checklinks=False)
+ else:
+ job = {}
+
+ try:
+ cwltool.builder.adjustDirObjs(job, unsupportedInputCheck)
+ cwltool.builder.adjustFileObjs(job, unsupportedInputCheck)
+ except cwltool.process.UnsupportedRequirement as e:
+ logging.error(e)
+ return 33
+
+ cwltool.builder.adjustDirObjs(job, pathToLoc)
+ cwltool.builder.adjustFileObjs(job, pathToLoc)
+
+ if type(t) == int:
+ return t
+
+ fillInDefaults(t.tool["inputs"], job)
+
+ if options.conformance_test:
+ adjustFiles(job, lambda x: x.replace("file://", ""))
+ stdout.write(json.dumps(
+ cwltool.main.single_job_executor(t, job, basedir=options.basedir,
+ tmpdir_prefix="tmp",
+ conformance_test=True, use_container=use_container,
+ preserve_environment=options.preserve_environment), indent=4))
+ return 0
+
+ if not options.basedir:
+ options.basedir = os.path.dirname(os.path.abspath(options.cwljob or options.cwltool))
+
+ outdir = options.outdir
+
+ with Toil(options) as toil:
+ def importDefault(tool):
+ cwltool.builder.adjustDirObjs(tool, locToPath)
+ cwltool.builder.adjustFileObjs(tool, locToPath)
+ adjustFiles(tool, lambda x: "file://%s" % x if not urlparse.urlparse(x).scheme else x)
+ adjustFiles(tool, functools.partial(writeFile, toil.importFile, {}))
+ t.visit(importDefault)
+
+ if options.restart:
+ outobj = toil.restart()
+ else:
+ basedir = os.path.dirname(os.path.abspath(options.cwljob or options.cwltool))
+ builder = t._init_job(job, basedir=basedir, use_container=use_container)
+ (wf1, wf2) = makeJob(t, {}, use_container=use_container, preserve_environment=options.preserve_environment, tmpdir=os.path.realpath(outdir))
+ try:
+ if isinstance(wf1, CWLWorkflow):
+ [unsupportedDefaultCheck(s.tool) for s in wf1.cwlwf.steps]
+ except cwltool.process.UnsupportedRequirement as e:
+ logging.error(e)
+ return 33
+
+ cwltool.builder.adjustDirObjs(builder.job, locToPath)
+ cwltool.builder.adjustFileObjs(builder.job, locToPath)
+ adjustFiles(builder.job, lambda x: "file://%s" % os.path.abspath(os.path.join(basedir, x))
+ if not urlparse.urlparse(x).scheme else x)
+ cwltool.builder.adjustDirObjs(builder.job, pathToLoc)
+ cwltool.builder.adjustFileObjs(builder.job, pathToLoc)
+ cwltool.builder.adjustFileObjs(builder.job, addFilePartRefs)
+ adjustFiles(builder.job, functools.partial(writeFile, toil.importFile, {}))
+ wf1.cwljob = builder.job
+ outobj = toil.start(wf1)
+
+ outobj = resolve_indirect(outobj)
+
+ try:
+ adjustFilesWithSecondary(outobj, functools.partial(getFile, toil, outdir, index={},
+ export=True, rename_collision=True))
+ except cwltool.process.UnsupportedRequirement as e:
+ logging.error(e)
+ return 33
+
+ stdout.write(json.dumps(outobj, indent=4))
+
+ return 0
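+
+# Illustrative usage (hypothetical file names, not part of the upstream module):
+# main() takes the CWL document, an optional job order file and any Toil options;
+# a temporary job store location is prepended automatically (see the mkdtemp call
+# above), e.g.
+#
+#     sys.exit(main(["--outdir", "/tmp/cwl-out", "workflow.cwl", "job.yml"]))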
diff --git a/src/toil/fileStore.py b/src/toil/fileStore.py
new file mode 100644
index 0000000..6b9d13d
--- /dev/null
+++ b/src/toil/fileStore.py
@@ -0,0 +1,1884 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, print_function
+from abc import abstractmethod, ABCMeta
+
+from bd2k.util.objects import abstractclassmethod
+
+import base64
+from collections import namedtuple, defaultdict
+
+import dill
+import errno
+import logging
+import os
+import shutil
+import stat
+import tempfile
+import time
+import uuid
+
+from contextlib import contextmanager
+from fcntl import flock, LOCK_EX, LOCK_UN
+from functools import partial
+from hashlib import sha1
+from threading import Thread, Semaphore, Event
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+from six.moves import xrange
+
+from bd2k.util.humanize import bytes2human
+from toil.common import cacheDirName, getDirSizeRecursively
+from toil.lib.bioio import makePublicDir
+from toil.resource import ModuleDescriptor
+
+logger = logging.getLogger(__name__)
+
+
+class DeferredFunction(namedtuple('DeferredFunction', 'function args kwargs name module')):
+ """
+ >>> df = DeferredFunction.create(defaultdict, None, {'x':1}, y=2)
+ >>> df
+ DeferredFunction(defaultdict, ...)
+ >>> df.invoke() == defaultdict(None, x=1, y=2)
+ True
+ """
+ @classmethod
+ def create(cls, function, *args, **kwargs):
+ """
+ Capture the given callable and arguments as an instance of this class.
+
+ :param callable function: The deferred action to take in the form of a function
+ :param tuple args: Non-keyword arguments to the function
+ :param dict kwargs: Keyword arguments to the function
+ """
+ # The general principle is to deserialize as late as possible, i.e. when the function is
+ # to be invoked, as that will avoid redundantly deserializing deferred functions for
+ # concurrently running jobs when the cache state is loaded from disk. By implication we
+ # should serialize as early as possible. We need to serialize the function as well as its
+ # arguments.
+ return cls(*map(dill.dumps, (function, args, kwargs)),
+ name=function.__name__,
+ module=ModuleDescriptor.forModule(function.__module__).globalize())
+
+ def invoke(self):
+ """
+ Invoke the captured function with the captured arguments.
+ """
+ logger.debug('Running deferred function %s.', self)
+ self.module.makeLoadable()
+ function, args, kwargs = map(dill.loads, (self.function, self.args, self.kwargs))
+ return function(*args, **kwargs)
+
+ def __str__(self):
+ return '%s(%s, ...)' % (self.__class__.__name__, self.name)
+
+ __repr__ = __str__
+
+
+class FileStore(object):
+ """
+ An abstract base class to represent the interface between a worker and the job store. Concrete
+ subclasses are used to manage temporary files, read and write files from the job store, and
+ log messages; an instance is passed as an argument to the :meth:`toil.job.Job.run` method.
+ """
+ # Variables used for syncing reads/writes
+ _pendingFileWritesLock = Semaphore()
+ _pendingFileWrites = set()
+ _terminateEvent = Event() # Used to signify crashes in threads
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self, jobStore, jobGraph, localTempDir, inputBlockFn):
+ self.jobStore = jobStore
+ self.jobGraph = jobGraph
+ self.localTempDir = os.path.abspath(localTempDir)
+ self.workFlowDir = os.path.dirname(self.localTempDir)
+ self.jobName = self.jobGraph.command.split()[1]
+ self.inputBlockFn = inputBlockFn
+ self.loggingMessages = []
+ self.filesToDelete = set()
+ self.jobsToDelete = set()
+
+ @staticmethod
+ def createFileStore(jobStore, jobGraph, localTempDir, inputBlockFn, caching):
+ fileStoreCls = CachingFileStore if caching else NonCachingFileStore
+ return fileStoreCls(jobStore, jobGraph, localTempDir, inputBlockFn)
+
+ @abstractmethod
+ @contextmanager
+ def open(self, job):
+ """
+ The context manager used to conduct tasks prior to, and after, a job has been run.
+
+ :param toil.job.Job job: The job instance of the toil job to run.
+ """
+ raise NotImplementedError()
+
+ # Functions related to temp files and directories
+ def getLocalTempDir(self):
+ """
+ Get a new local temporary directory in which to write files that persist for the duration of
+ the job.
+
+ :return: The absolute path to a new local temporary directory. This directory will exist
+ for the duration of the job only, and is guaranteed to be deleted once the job terminates,
+ removing all files it contains recursively.
+ :rtype: str
+ """
+ return os.path.abspath(tempfile.mkdtemp(prefix="t", dir=self.localTempDir))
+
+ def getLocalTempFile(self):
+ """
+ Get a new local temporary file that will persist for the duration of the job.
+
+ :return: The absolute path to a local temporary file. This file will exist for the duration
+ of the job only, and is guaranteed to be deleted once the job terminates.
+ :rtype: str
+ """
+ handle, tmpFile = tempfile.mkstemp(prefix="tmp", suffix=".tmp", dir=self.localTempDir)
+ os.close(handle)
+ return os.path.abspath(tmpFile)
+
+ def getLocalTempFileName(self):
+ """
+ Get a valid name for a new local file. Don't actually create a file at the path.
+
+ :return: Path to valid file
+ :rtype: str
+ """
+ # Create, and then delete a temp file. Creating will guarantee you a unique, unused
+ # file name. There is a very, very, very low chance that another job will create the
+ # same file name in the span of this one being deleted and then being used by the user.
+ tempFile = self.getLocalTempFile()
+ os.remove(tempFile)
+ return tempFile
+
+ # Functions related to reading, writing and removing files to/from the job store
+ @abstractmethod
+ def writeGlobalFile(self, localFileName, cleanup=False):
+ """
+ Takes a file (as a path) and uploads it to the job store.
+
+ :param string localFileName: The path to the local file to upload.
+ :param Boolean cleanup: if True then the copy of the global file will be deleted once the
+ job and all its successors have completed running. If not the global file must be
+ deleted manually.
+ :return: an ID that can be used to retrieve the file.
+ :rtype: FileID
+ """
+ raise NotImplementedError()
+
+ def writeGlobalFileStream(self, cleanup=False):
+ """
+ Similar to writeGlobalFile, but allows the writing of a stream to the job store.
+ The yielded file handle does not need to and should not be closed explicitly.
+
+ :param Boolean cleanup: is as in :func:`toil.fileStore.FileStore.writeGlobalFile`.
+ :return: A context manager yielding a tuple of
+ 1) a file handle which can be written to and
+ 2) the ID of the resulting file in the job store.
+ """
+ # TODO: Make this work with FileID
+ return self.jobStore.writeFileStream(None if not cleanup else self.jobGraph.jobStoreID)
+
+ @abstractmethod
+ def readGlobalFile(self, fileStoreID, userPath=None, cache=True, mutable=None):
+ """
+ Downloads a file described by fileStoreID from the file store to the local directory.
+
+ If a user path is specified, it is used as the destination. If a user path isn't
+ specified, the file is stored in the local temp directory with an encoded name.
+
+ :param FileID fileStoreID: job store id for the file
+ :param string userPath: a path to the name of file to which the global file will be copied
+ or hard-linked (see below).
+ :param boolean cache: Described in :func:`~filestore.Filestore.readGlobalFile`
+ :param boolean mutable: Described in :func:`~filestore.Filestore.readGlobalFile`
+ :return: An absolute path to a local, temporary copy of the file keyed by fileStoreID.
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def readGlobalFileStream(self, fileStoreID):
+ """
+ Similar to readGlobalFile, but allows a stream to be read from the job store. The yielded
+ file handle does not need to and should not be closed explicitly.
+
+ :return: a context manager yielding a file handle which can be read from.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def deleteLocalFile(self, fileStoreID):
+ """
+ Deletes local copies of files associated with the provided job store ID.
+
+ :param str fileStoreID: File Store ID of the file to be deleted.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def deleteGlobalFile(self, fileStoreID):
+ """
+ Deletes local files with the provided job store ID and then permanently deletes them from
+ the job store. To ensure that the job can be restarted if necessary, the delete will not
+ happen until after the job's run method has completed.
+
+ :param fileStoreID: the job store ID of the file to be deleted.
+ """
+ raise NotImplementedError()
+
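+    # Illustrative sketch (hypothetical job, not part of the upstream module): the
+    # typical pattern inside a Job.run() method, using the fileStore argument passed
+    # to it. self.inputID is a hypothetical file ID stored on the job.
+    #
+    #     def run(self, fileStore):
+    #         localDir = fileStore.getLocalTempDir()
+    #         inputPath = fileStore.readGlobalFile(self.inputID,
+    #                                              userPath=os.path.join(localDir, 'in.txt'))
+    #         outputPath = os.path.join(localDir, 'out.txt')
+    #         # ... produce outputPath from inputPath ...
+    #         return fileStore.writeGlobalFile(outputPath)
+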
+ # Functions used to read and write files directly between a source url and the job store.
+ def importFile(self, srcUrl, sharedFileName=None):
+ return self.jobStore.importFile(srcUrl, sharedFileName=sharedFileName)
+
+ def exportFile(self, jobStoreFileID, dstUrl):
+ self.jobStore.exportFile(jobStoreFileID, dstUrl)
+
+ # A utility method for accessing filenames
+ def _resolveAbsoluteLocalPath(self, filePath):
+ """
+ Return the absolute path to filePath. This is a wrapper for os.path.abspath because mac OS
+ symlinks /tmp and /var (the most common places for a default tempdir) to /private/tmp and
+ /private/var respectively.
+
+ :param str filePath: The absolute or relative path to the file. If relative, it must be
+ relative to the local temp working dir
+ :return: Absolute path to the file
+ :rtype: str
+ """
+ if os.path.isabs(filePath):
+ return os.path.abspath(filePath)
+ else:
+ return os.path.join(self.localTempDir, filePath)
+
+ class _StateFile(object):
+ """
+ Utility class to read and write dill-ed state dictionaries from/to a file into a namespace.
+ """
+ def __init__(self, stateDict):
+ assert isinstance(stateDict, dict)
+ self.__dict__.update(stateDict)
+
+ @abstractclassmethod
+ @contextmanager
+ def open(cls, outer=None):
+ """
+ This is a context manager that loads the state file and reads it into an object that is returned
+ to the user in the yield.
+
+ :param outer: Instance of the calling class (to use outer methods).
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def _load(cls, fileName):
+ """
+ Load the state of the cache from the state file
+
+ :param str fileName: Path to the cache state file.
+ :return: An instance of the state as a namespace.
+ :rtype: _StateFile
+ """
+ # Read the value from the cache state file then initialize an instance of
+ # _CacheState with it.
+ with open(fileName, 'r') as fH:
+ infoDict = dill.load(fH)
+ return cls(infoDict)
+
+ def write(self, fileName):
+ """
+ Write the current state into a temporary file then atomically rename it to the main
+ state file.
+
+ :param str fileName: Path to the state file.
+ """
+ with open(fileName + '.tmp', 'w') as fH:
+ # Based on answer by user "Mark" at:
+ # http://stackoverflow.com/questions/2709800/how-to-pickle-yourself
+ # We can't pickle nested classes. So we have to pickle the variables of the class
+ # If we ever change this, we need to ensure it doesn't break FileID
+ dill.dump(self.__dict__, fH)
+ os.rename(fileName + '.tmp', fileName)
+
+ # Methods related to the deferred function logic
+ @abstractclassmethod
+ def findAndHandleDeadJobs(cls, nodeInfo, batchSystemShutdown=False):
+ """
+ This function looks at the state of all jobs registered on the node and will handle them
+ (clean up their presence on the node, and run any registered deferred functions).
+
+ :param nodeInfo: Information regarding the node required for identifying dead jobs.
+ :param bool batchSystemShutdown: Is the batch system in the process of shutting down?
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _registerDeferredFunction(self, deferredFunction):
+ """
+ Register the given deferred function with this job.
+
+ :param DeferredFunction deferredFunction: the function to register
+ """
+ raise NotImplementedError()
+
+ @staticmethod
+ def _runDeferredFunctions(deferredFunctions):
+ """
+ Invoke the specified deferred functions and return a list of names of functions that
+ raised an exception while being invoked.
+
+ :param list[DeferredFunction] deferredFunctions: the DeferredFunctions to run
+ :rtype: list[str]
+ """
+ failures = []
+ for deferredFunction in deferredFunctions:
+ try:
+ deferredFunction.invoke()
+ except:
+ failures.append(deferredFunction.name)
+ logger.exception('%s failed.', deferredFunction)
+ return failures
+
+ # Functions related to logging
+ def logToMaster(self, text, level=logging.INFO):
+ """
+ Send a logging message to the leader. The message will also be \
+ logged by the worker at the same level.
+
+ :param text: The string to log.
+ :param int level: The logging level.
+ """
+ logger.log(level=level, msg=("LOG-TO-MASTER: " + text))
+ self.loggingMessages.append(dict(text=text, level=level))
+
+ # Functions run after the completion of the job.
+ @abstractmethod
+ def _updateJobWhenDone(self):
+ """
+ Update the status of the job on the disk.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _blockFn(self):
+ """
+ Blocks while _updateJobWhenDone is running. This function is called by this job's
+ successor to ensure that it does not begin modifying the job store until after this job has
+ finished doing so.
+ """
+ raise NotImplementedError()
+
+ # Utility function used to identify if a pid is still running on the node.
+ @staticmethod
+ def _pidExists(pid):
+ """
+ This will return True if the process associated with pid is still running on the machine.
+ This is based on stackoverflow question 568271.
+
+ :param int pid: ID of the process to check for
+ :return: True/False
+ :rtype: bool
+ """
+ assert pid > 0
+ try:
+ os.kill(pid, 0)
+ except OSError as err:
+ if err.errno == errno.ESRCH:
+ # ESRCH == No such process
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ @abstractclassmethod
+ def shutdown(cls, dir_):
+ """
+ Shutdown the filestore on this node.
+
+ This is intended to be called on batch system shutdown.
+
+ :param dir_: The directory containing the required information for fixing the state
+ of failed workers on the node before cleaning up.
+ """
+ raise NotImplementedError()
+
+
+class CachingFileStore(FileStore):
+ """
+ A cache-enabled file store that attempts to use hard-links and asynchronous job store writes to
+ reduce I/O between, and during jobs.
+ """
+
+ def __init__(self, jobStore, jobGraph, localTempDir, inputBlockFn):
+ super(CachingFileStore, self).__init__(jobStore, jobGraph, localTempDir, inputBlockFn)
+ # Variables related to asynchronous writes.
+ self.workerNumber = 2
+ self.queue = Queue()
+ self.updateSemaphore = Semaphore()
+ self.mutable = self.jobStore.config.readGlobalFileMutableByDefault
+ self.workers = map(lambda i: Thread(target=self.asyncWrite),
+ range(self.workerNumber))
+ for worker in self.workers:
+ worker.start()
+ # Variables related to caching
+ # cacheDir has to be one level above the local worker tempdir, at the same level as the
+ # worker dirs. At this point, localTempDir is the worker directory, not the job
+ # directory.
+ self.localCacheDir = os.path.join(os.path.dirname(localTempDir),
+ cacheDirName(self.jobStore.config.workflowID))
+ self.cacheLockFile = os.path.join(self.localCacheDir, '.cacheLock')
+ self.cacheStateFile = os.path.join(self.localCacheDir, '_cacheState')
+ # Since each worker has its own unique CachingFileStore instance, and only one Job can run
+ # at a time on a worker, we can track the files the job's file store operates on in a
+ # dictionary.
+ self.jobSpecificFiles = {}
+ self.jobName = str(self.jobGraph)
+ self.jobID = sha1(self.jobName).hexdigest()
+ logger.info('Starting job (%s) with ID (%s).', self.jobName, self.jobID)
+ # A variable to describe how many hard links an unused file in the cache will have.
+ self.nlinkThreshold = None
+ self.workflowAttemptNumber = self.jobStore.config.workflowAttemptNumber
+ # This is a flag to better resolve cache equation imbalances at cleanup time.
+ self.cleanupInProgress = False
+ # Now that we've set up all the required variables, set up the cache directory for the
+ # job if required.
+ self._setupCache()
+
+ @contextmanager
+ def open(self, job):
+ """
+ This context manager decorated method allows cache-specific operations to be conducted
+ before and after the execution of a job in worker.py
+ """
+ # Create a working directory for the job
+ startingDir = os.getcwd()
+ self.localTempDir = makePublicDir(os.path.join(self.localTempDir, str(uuid.uuid4())))
+ # Check the status of all jobs on this node. If there are jobs that started and died before
+ # cleaning up their presence from the cache state file, restore the cache file to a state
+ # where the jobs don't exist.
+ with self._CacheState.open(self) as cacheInfo:
+ self.findAndHandleDeadJobs(cacheInfo)
+ # Get the requirements for the job and clean the cache if necessary. cleanCache will
+ # ensure that the requirements for this job are stored in the state file.
+ jobReqs = job.disk
+ # Cleanup the cache to free up enough space for this job (if needed)
+ self.cleanCache(jobReqs)
+ try:
+ os.chdir(self.localTempDir)
+ yield
+ finally:
+ diskUsed = getDirSizeRecursively(self.localTempDir)
+ logString = ("Job {jobName} used {percent:.2f}% ({humanDisk}B [{disk}B] used, "
+ "{humanRequestedDisk}B [{requestedDisk}B] requested) at the end of "
+ "its run.".format(jobName=self.jobName,
+ percent=(float(diskUsed) / jobReqs * 100 if
+ jobReqs > 0 else 0.0),
+ humanDisk=bytes2human(diskUsed),
+ disk=diskUsed,
+ humanRequestedDisk=bytes2human(jobReqs),
+ requestedDisk=jobReqs))
+ self.logToMaster(logString, level=logging.DEBUG)
+ if diskUsed > jobReqs:
+ self.logToMaster("Job used more disk than requested. Please reconsider modifying "
+ "the user script to avoid the chance of failure due to "
+ "incorrectly requested resources. " + logString,
+ level=logging.WARNING)
+ os.chdir(startingDir)
+ self.cleanupInProgress = True
+ # Delete all the job specific files and return sizes to jobReqs
+ self.returnJobReqs(jobReqs)
+ with self._CacheState.open(self) as cacheInfo:
+ # Carry out any user-defined cleanup actions
+ deferredFunctions = cacheInfo.jobState[self.jobID]['deferredFunctions']
+ failures = self._runDeferredFunctions(deferredFunctions)
+ for failure in failures:
+ self.logToMaster('Deferred function "%s" failed.' % failure, logging.WARN)
+ # Finally delete the job from the cache state file
+ cacheInfo.jobState.pop(self.jobID)
+
+ # Functions related to reading, writing and removing files to/from the job store
+ def writeGlobalFile(self, localFileName, cleanup=False):
+ """
+ Takes a file (as a path) and uploads it to the job store. Depending on the jobstore
+ used, carry out the appropriate cache functions.
+ """
+ absLocalFileName = self._resolveAbsoluteLocalPath(localFileName)
+ # What does this do?
+ cleanupID = None if not cleanup else self.jobGraph.jobStoreID
+ # If the file is from the scope of local temp dir
+ if absLocalFileName.startswith(self.localTempDir):
+ # If the job store is of type FileJobStore and the job store and the local temp dir
+ # are on the same file system, then we want to hard link the files instead of copying,
+ # barring the case where the file being written was one that was previously read
+ # from the file store. In that case, you want to copy to the file store so that
+ # the two have distinct nlink counts.
+ # Can read without a lock because we're only reading job-specific info.
+ jobSpecificFiles = self._CacheState._load(self.cacheStateFile).jobState[
+ self.jobID]['filesToFSIDs'].keys()
+ # Saying nlink is 2 implicitly means we are using the job file store, and it is on
+ # the same device as the work dir.
+ if self.nlinkThreshold == 2 and absLocalFileName not in jobSpecificFiles:
+ jobStoreFileID = self.jobStore.getEmptyFileStoreID(cleanupID)
+ # getEmptyFileStoreID creates the file in the scope of the job store hence we
+ # need to delete it before linking.
+ os.remove(self.jobStore._getAbsPath(jobStoreFileID))
+ os.link(absLocalFileName, self.jobStore._getAbsPath(jobStoreFileID))
+ # If they're not on the file system, or if the file is already linked with an
+ # existing file, we need to copy to the job store.
+ # Check if the user allows asynchronous file writes
+ elif self.jobStore.config.useAsync:
+ jobStoreFileID = self.jobStore.getEmptyFileStoreID(cleanupID)
+ # Before we can start the async process, we should also create a dummy harbinger
+ # file in the cache such that any subsequent jobs asking for this file will not
+ # attempt to download it from the job store till the write is complete. We do
+ # this now instead of in the writing thread because there is an edge case where
+ # readGlobalFile in a subsequent job is called before the writing thread has
+ # received the message to write the file and has created the dummy harbinger
+ # (and the file was unable to be cached/was evicted from the cache).
+ harbingerFile = self.HarbingerFile(self, fileStoreID=jobStoreFileID)
+ harbingerFile.write()
+ fileHandle = open(absLocalFileName, 'r')
+ with self._pendingFileWritesLock:
+ self._pendingFileWrites.add(jobStoreFileID)
+ # A file handle added to the queue allows the asyncWrite threads to remove their
+ # jobID from _pendingFileWrites. Therefore, a file should only be added after
+ # its fileID is added to _pendingFileWrites
+ self.queue.put((fileHandle, jobStoreFileID))
+ # Else write directly to the job store.
+ else:
+ jobStoreFileID = self.jobStore.writeFile(absLocalFileName, cleanupID)
+ # Local files are cached by default, unless they were written from previously read
+ # files.
+ if absLocalFileName not in jobSpecificFiles:
+ self.addToCache(absLocalFileName, jobStoreFileID, 'write')
+ else:
+ self._JobState.updateJobSpecificFiles(self, jobStoreFileID, absLocalFileName,
+ 0.0, False)
+ # Else write directly to the job store.
+ else:
+ jobStoreFileID = self.jobStore.writeFile(absLocalFileName, cleanupID)
+ # Non local files are NOT cached by default, but they are tracked as local files.
+ self._JobState.updateJobSpecificFiles(self, jobStoreFileID, None,
+ 0.0, False)
+ return FileID.forPath(jobStoreFileID, absLocalFileName)
+
+ def writeGlobalFileStream(self, cleanup=False):
+ # TODO: Make this work with caching
+ return super(CachingFileStore, self).writeGlobalFileStream(cleanup)
+
+ def readGlobalFile(self, fileStoreID, userPath=None, cache=True, mutable=None):
+ """
+ Downloads a file described by fileStoreID from the file store to the local directory.
+ The function first looks for the file in the cache and if found, it hardlinks to the
+ cached copy instead of downloading.
+
+ The cache parameter will be used only if the file isn't already in the cache, and the
+ provided user path (if specified) is within the local temp dir.
+
+ :param boolean cache: If True, a copy of the file will be saved into a cache that can be
+ used by other workers. caching supports multiple concurrent workers requesting the
+ same file by allowing only one to download the file while the others wait for it to
+ complete.
+ :param boolean mutable: If True, the file path returned points to a file that is
+ modifiable by the user. Using False is recommended as it saves disk by making
+ multiple workers share a file via hard links. The default is False unless backwards
+ compatibility was requested.
+ """
+ # Check that the file hasn't been deleted by the user
+ if fileStoreID in self.filesToDelete:
+ raise RuntimeError('Trying to access a file in the jobStore you\'ve deleted: ' + \
+ '%s' % fileStoreID)
+ # Set up the modifiable variable if it wasn't provided by the user in the function call.
+ if mutable is None:
+ mutable = self.mutable
+ # Get the name of the file as it would be in the cache
+ cachedFileName = self.encodedFileID(fileStoreID)
+ # Set up the harbinger variable for the file. This is an indicator that the file is
+ # currently being downloaded by another job and will be in the cache shortly. It is used
+ # to prevent multiple jobs from simultaneously downloading the same file from the file
+ # store.
+ harbingerFile = self.HarbingerFile(self, cachedFileName=cachedFileName)
+ # setup the output filename. If a name is provided, use it - This makes it a Named
+ # Local File. If a name isn't provided, use the base64 encoded name such that we can
+ # easily identify the files later on.
+ if userPath is not None:
+ localFilePath = self._resolveAbsoluteLocalPath(userPath)
+ if os.path.exists(localFilePath):
+ # yes, this is illegal now.
+ raise RuntimeError('File %s exists. Cannot overwrite.' % localFilePath)
+ fileIsLocal = True if localFilePath.startswith(self.localTempDir) else False
+ else:
+ localFilePath = self.getLocalTempFileName()
+ fileIsLocal = True
+ # First check whether the file is in cache. If it is, then hardlink the file to
+ # userPath. Cache operations can only occur on local files.
+ with self.cacheLock() as lockFileHandle:
+ if fileIsLocal and self._fileIsCached(fileStoreID):
+ logger.debug('CACHE: Cache hit on file with ID \'%s\'.' % fileStoreID)
+ assert not os.path.exists(localFilePath)
+ if mutable:
+ shutil.copyfile(cachedFileName, localFilePath)
+ cacheInfo = self._CacheState._load(self.cacheStateFile)
+ jobState = self._JobState(cacheInfo.jobState[self.jobID])
+ jobState.addToJobSpecFiles(fileStoreID, localFilePath, -1, None)
+ cacheInfo.jobState[self.jobID] = jobState.__dict__
+ cacheInfo.write(self.cacheStateFile)
+ else:
+ os.link(cachedFileName, localFilePath)
+ self.returnFileSize(fileStoreID, localFilePath, lockFileHandle,
+ fileAlreadyCached=True)
+ # If the file is not in cache, check whether the .harbinger file for the given
+ # FileStoreID exists. If it does, then wait and periodically check for the removal
+ # of the harbinger and the addition of the completed download into the cache by
+ # the other job. Then we link to it.
+ elif fileIsLocal and harbingerFile.exists():
+ harbingerFile.waitOnDownload(lockFileHandle)
+ # If the code reaches here, the harbinger file has been removed. This means
+ # either the file was successfully downloaded and added to cache, or something
+ # failed. To prevent code duplication, we recursively call readGlobalFile.
+ flock(lockFileHandle, LOCK_UN)
+ return self.readGlobalFile(fileStoreID, userPath=userPath, cache=cache,
+ mutable=mutable)
+ # If the file is not in cache, then download it to the userPath and then add to
+ # cache if specified.
+ else:
+ logger.debug('CACHE: Cache miss on file with ID \'%s\'.' % fileStoreID)
+ if fileIsLocal and cache:
+ # If caching of the downloaded file is desired, first create the harbinger
+ # file so other jobs know not to redundantly download the same file. Write
+ # the PID of this process into the file so other jobs know who is carrying
+ # out the download.
+ harbingerFile.write()
+ # Now release the file lock while the file is downloaded as download could
+ # take a while.
+ flock(lockFileHandle, LOCK_UN)
+ # Use try:finally: so that the .harbinger file is removed whether the
+ # download succeeds or not.
+ try:
+ self.jobStore.readFile(fileStoreID,
+ '/.'.join(os.path.split(cachedFileName)))
+ except:
+ if os.path.exists('/.'.join(os.path.split(cachedFileName))):
+ os.remove('/.'.join(os.path.split(cachedFileName)))
+ raise
+ else:
+ # If the download succeded, officially add the file to cache (by
+ # recording it in the cache lock file) if possible.
+ if os.path.exists('/.'.join(os.path.split(cachedFileName))):
+ os.rename('/.'.join(os.path.split(cachedFileName)), cachedFileName)
+ self.addToCache(localFilePath, fileStoreID, 'read', mutable)
+ # We don't need to return the file size here because addToCache
+ # already does it for us
+ finally:
+ # In any case, delete the harbinger file.
+ harbingerFile.delete()
+ else:
+ # Release the cache lock since the remaining stuff is not cache related.
+ flock(lockFileHandle, LOCK_UN)
+ self.jobStore.readFile(fileStoreID, localFilePath)
+ os.chmod(localFilePath, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
+ # Now that we have the file, we have 2 options. It's modifiable or not.
+ # Either way, we need to account for FileJobStore making links instead of
+ # copies.
+ if mutable:
+ if self.nlinkThreshold == 2:
+ # nlinkThreshold can only be 1 or 2 and it can only be 2 iff the
+ # job store is FileJobStore, and the job store and local temp dir
+ # are on the same device. An atomic rename removes the nlink on the
+ # file handle linked from the job store.
+ shutil.copyfile(localFilePath, localFilePath + '.tmp')
+ os.rename(localFilePath + '.tmp', localFilePath)
+ self._JobState.updateJobSpecificFiles(self, fileStoreID, localFilePath,
+ -1, False)
+ # If it was immutable
+ else:
+ if self.nlinkThreshold == 2:
+ self._accountForNlinkEquals2(localFilePath)
+ self._JobState.updateJobSpecificFiles(self, fileStoreID, localFilePath,
+ 0.0, False)
+ return localFilePath
+
+ def readGlobalFileStream(self, fileStoreID):
+ if fileStoreID in self.filesToDelete:
+ raise RuntimeError(
+ "Trying to access a file in the jobStore you've deleted: %s" % fileStoreID)
+
+ # If fileStoreID is in the cache provide a handle from the local cache
+ if self._fileIsCached(fileStoreID):
+ logger.debug('CACHE: Cache hit on file with ID \'%s\'.' % fileStoreID)
+ return open(self.encodedFileID(fileStoreID), 'r')
+ else:
+ logger.debug('CACHE: Cache miss on file with ID \'%s\'.' % fileStoreID)
+ return self.jobStore.readFileStream(fileStoreID)
+
+ def deleteLocalFile(self, fileStoreID):
+ # The local file may or may not have been cached. If it was, we need to do some
+ # bookkeeping. If it wasn't, we just delete the file and continue, though we might need
+ # some bookkeeping if the file store and cache live on the same filesystem. We can know
+ # if a file was cached or not based on the value held in the third tuple value for the
+ # dict item having key = fileStoreID. If it was cached, it holds the value True else
+ # False.
+ with self._CacheState.open(self) as cacheInfo:
+ jobState = self._JobState(cacheInfo.jobState[self.jobID])
+ if fileStoreID not in jobState.jobSpecificFiles.keys():
+ # ENOENT indicates that the file did not exist
+ raise OSError(errno.ENOENT, "Attempting to delete a non-local file")
+ # filesToDelete is a dictionary of file: fileSize
+ filesToDelete = jobState.jobSpecificFiles[fileStoreID]
+ allOwnedFiles = jobState.filesToFSIDs
+ for (fileToDelete, fileSize) in filesToDelete.items():
+ # Handle the case where a file not in the local temp dir was written to
+ # filestore
+ if fileToDelete is None:
+ filesToDelete.pop(fileToDelete)
+ allOwnedFiles[fileToDelete].remove(fileStoreID)
+ cacheInfo.jobState[self.jobID] = jobState.__dict__
+ cacheInfo.write(self.cacheStateFile)
+ continue
+ # If the file size is zero (copied into the local temp dir) or -1 (mutable), we
+ # can safely delete without any bookkeeping
+ if fileSize in (0, -1):
+ # Only remove the file if there is only one FSID associated with it.
+ if len(allOwnedFiles[fileToDelete]) == 1:
+ try:
+ os.remove(fileToDelete)
+ except OSError as err:
+ if err.errno == errno.ENOENT and fileSize == -1:
+ logger.debug('%s was read mutably and deleted by the user',
+ fileToDelete)
+ else:
+ raise IllegalDeletionCacheError(fileToDelete)
+ allOwnedFiles[fileToDelete].remove(fileStoreID)
+ filesToDelete.pop(fileToDelete)
+ cacheInfo.jobState[self.jobID] = jobState.__dict__
+ cacheInfo.write(self.cacheStateFile)
+ continue
+ # If not, we need to do bookkeeping
+ # Get the size of the file to be deleted, and the number of jobs using the file
+ # at the moment.
+ if not os.path.exists(fileToDelete):
+ raise IllegalDeletionCacheError(fileToDelete)
+ fileStats = os.stat(fileToDelete)
+ if fileSize != fileStats.st_size:
+ logger.warn("the size on record differed from the real size by " +
+ "%s bytes" % str(fileSize - fileStats.st_size))
+ # Remove the file and return file size to the job
+ if len(allOwnedFiles[fileToDelete]) == 1:
+ os.remove(fileToDelete)
+ cacheInfo.sigmaJob += fileSize
+ filesToDelete.pop(fileToDelete)
+ allOwnedFiles[fileToDelete].remove(fileStoreID)
+ jobState.updateJobReqs(fileSize, 'remove')
+ cacheInfo.jobState[self.jobID] = jobState.__dict__
+ # If the job is not in the process of cleaning up, then we may need to remove the
+ # cached copy of the file as well.
+ if not self.cleanupInProgress:
+ # If the file is cached and if other jobs are using the cached copy of the file,
+ # or if retaining the file in the cache doesn't affect the cache equation, then
+ # don't remove it from cache.
+ if self._fileIsCached(fileStoreID):
+ cachedFile = self.encodedFileID(fileStoreID)
+ jobsUsingFile = os.stat(cachedFile).st_nlink
+ if not cacheInfo.isBalanced() and jobsUsingFile == self.nlinkThreshold:
+ os.remove(cachedFile)
+ cacheInfo.cached -= fileSize
+ self.logToMaster('Successfully deleted cached copy of file with ID '
+ '\'%s\'.' % fileStoreID, level=logging.DEBUG)
+ self.logToMaster('Successfully deleted local copies of file with ID '
+ '\'%s\'.' % fileStoreID, level=logging.DEBUG)
+
+ def deleteGlobalFile(self, fileStoreID):
+ jobStateIsPopulated = False
+ with self._CacheState.open(self) as cacheInfo:
+ if self.jobID in cacheInfo.jobState:
+ jobState = self._JobState(cacheInfo.jobState[self.jobID])
+ jobStateIsPopulated = True
+ if jobStateIsPopulated and fileStoreID in jobState.jobSpecificFiles.keys():
+ # Use deleteLocalFile in the backend to delete the local copy of the file.
+ self.deleteLocalFile(fileStoreID)
+ # At this point, the local file has been deleted, and possibly the cached copy. If
+ # the cached copy exists, it is either because another job is using the file, or
+ # because retaining the file in cache doesn't unbalance the caching equation. The
+ # first case is unacceptable for deleteGlobalFile and the second requires explicit
+ # deletion of the cached copy.
+ # Check if the fileStoreID is in the cache. If it is, ensure only the current job is
+ # using it.
+ cachedFile = self.encodedFileID(fileStoreID)
+ if os.path.exists(cachedFile):
+ self.removeSingleCachedFile(fileStoreID)
+ # Add the file to the list of files to be deleted once the run method completes.
+ self.filesToDelete.add(fileStoreID)
+ self.logToMaster('Added file with ID \'%s\' to the list of files to be' % fileStoreID +
+ ' globally deleted.', level=logging.DEBUG)
+
+ # Cache related methods
+ @contextmanager
+ def cacheLock(self):
+ """
+ This is a context manager that acquires an exclusive lock on the lock file, used to
+ prevent concurrent cache operations between workers.
+ :yields: Open file handle for the cache lock file
+ """
+ cacheLockFile = open(self.cacheLockFile, 'w')
+ try:
+ flock(cacheLockFile, LOCK_EX)
+ logger.debug("CACHE: Obtained lock on file %s" % self.cacheLockFile)
+ yield cacheLockFile
+ except IOError:
+ logger.critical('CACHE: Unable to acquire lock on %s' % self.cacheLockFile)
+ raise
+ finally:
+ cacheLockFile.close()
+ logger.debug("CACHE: Released lock")
+
+ def _setupCache(self):
+ """
+ Set up the cache based on the provided value of localCacheDir.
+ """
+ # we first check whether the cache directory exists. If it doesn't, create it.
+ if not os.path.exists(self.localCacheDir):
+ # Create a temporary directory as this worker's private cache. If all goes well, it
+ # will be renamed into the cache for this node.
+ personalCacheDir = ''.join([os.path.dirname(self.localCacheDir), '/.ctmp-',
+ str(uuid.uuid4())])
+ os.mkdir(personalCacheDir, 0o755)
+ self._createCacheLockFile(personalCacheDir)
+ try:
+ os.rename(personalCacheDir, self.localCacheDir)
+ except OSError as err:
+ # The only acceptable failure case is that the destination is a non-empty
+ # directory. Assuming (the semantics are ambiguous) atomic renaming of directories, if the
+ # dst is non-empty, it only means that another worker has beaten this one to the
+ # rename.
+ if err.errno == errno.ENOTEMPTY:
+ # Cleanup your own mess. It's only polite.
+ shutil.rmtree(personalCacheDir)
+ else:
+ raise
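+ # In other words, each worker builds a private candidate directory and races to
+ # rename it into place; exactly one rename wins, and the losers clean up their
+ # own candidates.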
+ # You can't reach here unless a local cache directory has been created successfully
+ with self._CacheState.open(self) as cacheInfo:
+ # Ensure this cache is from the correct attempt at the workflow! If it isn't, we
+ # need to reset the cache lock file
+ if cacheInfo.attemptNumber != self.workflowAttemptNumber:
+ if cacheInfo.nlink == 2:
+ cacheInfo.cached = 0 # cached file sizes are accounted for by job store
+ else:
+ allCachedFiles = [os.path.join(self.localCacheDir, x)
+ for x in os.listdir(self.localCacheDir)
+ if not self._isHidden(x)]
+ cacheInfo.cached = sum([os.stat(cachedFile).st_size
+ for cachedFile in allCachedFiles])
+ # TODO: Delete the working directories
+ cacheInfo.sigmaJob = 0
+ cacheInfo.attemptNumber = self.workflowAttemptNumber
+ self.nlinkThreshold = cacheInfo.nlink
+
+ def _createCacheLockFile(self, tempCacheDir):
+ """
+ Create the cache lock file and cache state file used to hold the state of the cache on the node.
+
+ :param str tempCacheDir: Temporary directory to use for setting up a cache lock file the
+ first time.
+ """
+ # The nlink threshold is set up along with the first instance of the cache class on the
+ # node.
+ self.setNlinkThreshold()
+ # Get the free space on the device
+ diskStats = os.statvfs(tempCacheDir)
+ freeSpace = diskStats.f_frsize * diskStats.f_bavail
+ # Create the cache lock file.
+ open(os.path.join(tempCacheDir, os.path.basename(self.cacheLockFile)), 'w').close()
+ # Setup the cache state file
+ personalCacheStateFile = os.path.join(tempCacheDir,
+ os.path.basename(self.cacheStateFile))
+ # Setup the initial values for the cache state file in a dict
+ cacheInfo = self._CacheState({
+ 'nlink': self.nlinkThreshold,
+ 'attemptNumber': self.workflowAttemptNumber,
+ 'total': freeSpace,
+ 'cached': 0,
+ 'sigmaJob': 0,
+ 'cacheDir': self.localCacheDir,
+ 'jobState': {}})
+ cacheInfo.write(personalCacheStateFile)
+
+ def encodedFileID(self, jobStoreFileID):
+ """
+ Uses a url safe base64 encoding to encode the jobStoreFileID into a unique identifier to
+ use as filename within the cache folder. jobstore IDs are essentially urls/paths to
+ files and thus cannot be used as is. Base64 encoding is used since it is reversible.
+
+ :param jobStoreFileID: string representing a job store file ID
+ :return: outCachedFile: A path to the encoded file name in localCacheDir
+ :rtype: str
+ """
+ outCachedFile = os.path.join(self.localCacheDir,
+ base64.urlsafe_b64encode(jobStoreFileID))
+ return outCachedFile
+
+ def _fileIsCached(self, jobStoreFileID):
+ """
+ Return whether the file identified by jobStoreFileID is in the cache.
+ """
+ return os.path.exists(self.encodedFileID(jobStoreFileID))
+
+ def decodedFileID(self, cachedFilePath):
+ """
+ Decode a cached fileName back to a job store file ID.
+
+ :param str cachedFilePath: Path to the cached file
+ :return: The jobstore file ID associated with the file
+ :rtype: str
+ """
+ fileDir, fileName = os.path.split(cachedFilePath)
+ assert fileDir == self.localCacheDir, 'Can\'t decode uncached file names'
+ return base64.urlsafe_b64decode(fileName)
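+ # A minimal round-trip sketch (illustrative only; `fs` is assumed to be a
+ # CachingFileStore instance and the ID below is made up):
+ #
+ #   cached = fs.encodedFileID('files/no-job/file-abc')  # a path inside localCacheDir
+ #   assert fs.decodedFileID(cached) == 'files/no-job/file-abc'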
+
+ def addToCache(self, localFilePath, jobStoreFileID, callingFunc, mutable=None):
+ """
+ Used to process the caching of a file. This depends on whether a file is being written
+ to file store, or read from it.
+ WRITING
+ The file is in localTempDir. It needs to be linked into cache if possible.
+ READING
+ The file is already in the cache dir. Depending on whether it is modifiable or not, it
+ either needs to be linked to the required location or copied there. If it is copied, can
+ the file still be retained in cache?
+
+ :param str localFilePath: Path to the Source file
+ :param jobStoreFileID: jobStoreID for the file
+ :param str callingFunc: Who called this function, 'write' or 'read'
+ :param boolean mutable: See modifiable in readGlobalFile
+ """
+ assert callingFunc in ('read', 'write')
+ # Set up the mutable flag if it wasn't provided by the user in the function call.
+ if mutable is None:
+ mutable = self.mutable
+ assert isinstance(mutable, bool)
+ with self.cacheLock() as lockFileHandle:
+ cachedFile = self.encodedFileID(jobStoreFileID)
+ # The file to be cached MUST originate in the environment of the TOIL temp directory
+ if (os.stat(self.localCacheDir).st_dev !=
+ os.stat(os.path.dirname(localFilePath)).st_dev):
+ raise InvalidSourceCacheError('Attempting to cache a file across file systems '
+ 'cachedir = %s, file = %s.' % (self.localCacheDir,
+ localFilePath))
+ if not localFilePath.startswith(self.localTempDir):
+ raise InvalidSourceCacheError('Attempting a cache operation on a non-local file '
+ '%s.' % localFilePath)
+ if callingFunc == 'read' and mutable:
+ shutil.copyfile(cachedFile, localFilePath)
+ fileSize = os.stat(cachedFile).st_size
+ cacheInfo = self._CacheState._load(self.cacheStateFile)
+ cacheInfo.cached += fileSize if cacheInfo.nlink != 2 else 0
+ if not cacheInfo.isBalanced():
+ os.remove(cachedFile)
+ cacheInfo.cached -= fileSize if cacheInfo.nlink != 2 else 0
+ logger.debug('Could not both download %s as mutable and add it to the ' %
+ os.path.basename(localFilePath) +
+ 'cache. Only the mutable copy was retained.')
+ else:
+ logger.info('CACHE: Added file with ID \'%s\' to the cache.' %
+ jobStoreFileID)
+ jobState = self._JobState(cacheInfo.jobState[self.jobID])
+ jobState.addToJobSpecFiles(jobStoreFileID, localFilePath, -1, False)
+ cacheInfo.jobState[self.jobID] = jobState.__dict__
+ cacheInfo.write(self.cacheStateFile)
+ else:
+ # There are two possibilities: read-and-immutable, and write. Both cases do
+ # almost the same thing except for the direction of the os.link, hence we
+ # handle them together.
+ if callingFunc == 'read': # and mutable is inherently False
+ src = cachedFile
+ dest = localFilePath
+ # To mirror behaviour of shutil.copyfile
+ if os.path.exists(dest):
+ os.remove(dest)
+ else: # write
+ src = localFilePath
+ dest = cachedFile
+ try:
+ os.link(src, dest)
+ except OSError as err:
+ if err.errno != errno.EEXIST:
+ raise
+ # If we get the EEXIST error, it can only be from write since in read we are
+ # explicitly deleting the file. This shouldn't happen with the .partial
+ # logic hence we raise a cache error.
+ raise CacheError('Attempting to recache a file %s.' % src)
+ else:
+ # Chmod the cached file. Cached files can never be modified.
+ os.chmod(cachedFile, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
+ # Return the filesize of cachedFile to the job and increase the cached size
+ # The values passed here don't matter since rFS looks at the file only for
+ # the stat
+ self.returnFileSize(jobStoreFileID, localFilePath, lockFileHandle,
+ fileAlreadyCached=False)
+ if callingFunc == 'read':
+ logger.debug('CACHE: Read file with ID \'%s\' from the cache.' %
+ jobStoreFileID)
+ else:
+ logger.debug('CACHE: Added file with ID \'%s\' to the cache.' %
+ jobStoreFileID)
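+ # The link direction above, in brief (illustrative; immutable path only):
+ #
+ #   read:  os.link(cachedFile, localFilePath)  # cache -> job's local temp dir
+ #   write: os.link(localFilePath, cachedFile)  # job's local temp dir -> cache
+ #
+ # Either way the data exists once on disk and both names share the same inode.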
+
+ def returnFileSize(self, fileStoreID, cachedFileSource, lockFileHandle,
+ fileAlreadyCached=False):
+ """
+ Returns the fileSize of the file described by fileStoreID to the job requirements pool
+ if the file was recently added to, or read from cache (A job that reads n bytes from
+ cache doesn't really use those n bytes as a part of its job disk since the cache is already
+ accounting for that disk space).
+
+ :param fileStoreID: fileStore ID of the file being added to the cache
+ :param str cachedFileSource: File being added to cache
+ :param file lockFileHandle: Open file handle to the cache lock file
+ :param bool fileAlreadyCached: A flag to indicate whether the file was already cached or
+ not. If it was, then it means that you don't need to add the filesize to cache again.
+ """
+ fileSize = os.stat(cachedFileSource).st_size
+ cacheInfo = self._CacheState._load(self.cacheStateFile)
+ # If the file isn't cached, add the size of the file to the cache pool. However, if the
+ # nlink threshold is not 1 - i.e. it is 2 (it can only be 1 or 2), then don't do this
+ # since the size of the file is accounted for by the file store copy.
+ if not fileAlreadyCached and self.nlinkThreshold == 1:
+ cacheInfo.cached += fileSize
+ cacheInfo.sigmaJob -= fileSize
+ if not cacheInfo.isBalanced():
+ self.logToMaster('CACHE: The cache was not balanced on returning file size',
+ logging.WARN)
+ # Add the info to the job specific cache info
+ jobState = self._JobState(cacheInfo.jobState[self.jobID])
+ jobState.addToJobSpecFiles(fileStoreID, cachedFileSource, fileSize, True)
+ cacheInfo.jobState[self.jobID] = jobState.__dict__
+ cacheInfo.write(self.cacheStateFile)
+
+ @staticmethod
+ def _isHidden(filePath):
+ """
+ Check whether filePath refers to a hidden file.
+
+ :param str filePath: Path to the file under consideration
+ :return: A boolean indicating whether the file is hidden or not.
+ :rtype: bool
+ """
+ assert isinstance(filePath, str)
+ # I can safely assume I will never see an empty string because this is always called on
+ # the results of an os.listdir()
+ return filePath[0] in ('.', '_')
+
+ def cleanCache(self, newJobReqs):
+ """
+ Clean up files in the cache directory to ensure that at least newJobReqs bytes are available
+ for use.
+
+ :param float newJobReqs: the disk space (in bytes) requested by the incoming job.
+ """
+ with self._CacheState.open(self) as cacheInfo:
+ # Add the new job's disk requirements to the sigmaJobDisk variable
+ cacheInfo.sigmaJob += newJobReqs
+ # Initialize the job state here. we use a partial in the jobSpecificFiles call so
+ # that this entire thing is pickleable. Based on answer by user Nathaniel Gentile at
+ # http://stackoverflow.com/questions/2600790
+ assert self.jobID not in cacheInfo.jobState
+ cacheInfo.jobState[self.jobID] = {
+ 'jobName': self.jobName,
+ 'jobReqs': newJobReqs,
+ 'jobDir': self.localTempDir,
+ 'jobSpecificFiles': defaultdict(partial(defaultdict,int)),
+ 'filesToFSIDs': defaultdict(set),
+ 'pid': os.getpid(),
+ 'deferredFunctions': []}
+ # If the caching equation is balanced, do nothing.
+ if cacheInfo.isBalanced():
+ return None
+
+ # List of deletable cached files. A deletable cache file is one
+ # that is not in use by any other worker (identified by the number of hard links to
+ # the file)
+ allCacheFiles = [os.path.join(self.localCacheDir, x)
+ for x in os.listdir(self.localCacheDir)
+ if not self._isHidden(x)]
+ allCacheFiles = [(path, os.stat(path)) for path in allCacheFiles]
+ # TODO mtime vs ctime
+ deletableCacheFiles = {(path, inode.st_mtime, inode.st_size)
+ for path, inode in allCacheFiles
+ if inode.st_nlink == self.nlinkThreshold}
+
+ # Sort in descending order of mtime so the first items to be popped from the list
+ # are the least recently created.
+ deletableCacheFiles = sorted(deletableCacheFiles, key=lambda x: (-x[1], -x[2]))
+ logger.debug('CACHE: Need %s bytes for new job. Have %s' %
+ (newJobReqs, cacheInfo.cached + cacheInfo.sigmaJob - newJobReqs))
+ logger.debug('CACHE: Evicting files to make room for the new job.')
+
+ # Now do the actual file removal
+ while not cacheInfo.isBalanced() and len(deletableCacheFiles) > 0:
+ cachedFile, fileCreateTime, cachedFileSize = deletableCacheFiles.pop()
+ os.remove(cachedFile)
+ cacheInfo.cached -= cachedFileSize if self.nlinkThreshold != 2 else 0
+ assert cacheInfo.cached >= 0
+ # self.logToMaster('CACHE: Evicted file with ID \'%s\' (%s bytes)' %
+ # (self.decodedFileID(cachedFile), cachedFileSize))
+ logger.debug('CACHE: Evicted file with ID \'%s\' (%s bytes)' %
+ (self.decodedFileID(cachedFile), cachedFileSize))
+ if not cacheInfo.isBalanced():
+ raise CacheUnbalancedError()
+ logger.debug('CACHE: After Evictions, ended up with %s.' %
+ (cacheInfo.cached + cacheInfo.sigmaJob))
+
+ def removeSingleCachedFile(self, fileStoreID):
+ """
+ Removes a single file described by the fileStoreID from the cache forcibly.
+ """
+ with self._CacheState.open(self) as cacheInfo:
+ cachedFile = self.encodedFileID(fileStoreID)
+ cachedFileStats = os.stat(cachedFile)
+ # We know the file exists because this function was called in the if block. So we
+ # have to ensure nothing has changed since then.
+ assert cachedFileStats.st_nlink == self.nlinkThreshold, 'Attempting to delete ' + \
+ 'a global file that is in use by another job.'
+ # Remove the file size from the cached file size if the jobstore is not fileJobStore
+ # and then delete the file
+ os.remove(cachedFile)
+ if self.nlinkThreshold != 2:
+ cacheInfo.cached -= cachedFileStats.st_size
+ if not cacheInfo.isBalanced():
+ self.logToMaster('CACHE: The cache was not balanced on removing single file',
+ logging.WARN)
+ self.logToMaster('CACHE: Successfully removed file with ID \'%s\'.' % fileStoreID)
+ return None
+
+ def setNlinkThreshold(self):
+ # FIXME Can't do this at the top because of loopy (circular) import errors
+ from toil.jobStores.fileJobStore import FileJobStore
+ if (isinstance(self.jobStore, FileJobStore) and
+ os.stat(os.path.dirname(self.localCacheDir)).st_dev == os.stat(
+ self.jobStore.jobStoreDir).st_dev):
+ self.nlinkThreshold = 2
+ else:
+ self.nlinkThreshold = 1
+
+ def _accountForNlinkEquals2(self, localFilePath):
+ """
+ This is a utility function that accounts for the fact that if nlinkThreshold == 2, the
+ size of the file is accounted for by the file store copy of the file and thus the file
+ size shouldn't be added to the cached file sizes.
+
+ :param str localFilePath: Path to the local file that was linked to the file store copy.
+ """
+ fileStats = os.stat(localFilePath)
+ assert fileStats.st_nlink >= self.nlinkThreshold
+ with self._CacheState.open(self) as cacheInfo:
+ cacheInfo.sigmaJob -= fileStats.st_size
+ jobState = self._JobState(cacheInfo.jobState[self.jobID])
+ jobState.updateJobReqs(fileStats.st_size, 'remove')
+
+ def returnJobReqs(self, jobReqs):
+ """
+ This function returns the effective job requirements back to the pool after the job
+ completes. It also deletes the local copies of files with the cache lock held.
+
+ :param float jobReqs: Original size requirement of the job
+ """
+ # Since we are only reading this job's specific values from the state file, we don't
+ # need a lock
+ jobState = self._JobState(self._CacheState._load(self.cacheStateFile
+ ).jobState[self.jobID])
+ for x in jobState.jobSpecificFiles.keys():
+ self.deleteLocalFile(x)
+ with self._CacheState.open(self) as cacheInfo:
+ cacheInfo.sigmaJob -= jobReqs
+ # assert cacheInfo.isBalanced() # commenting this out for now. God speed
+
+ class _CacheState(FileStore._StateFile):
+ """
+ Utility class to read and write the cache state file. Also for checking whether the
+ caching equation is balanced or not. It extends the _StateFile class to add other cache
+ related functions.
+ """
+ @classmethod
+ @contextmanager
+ def open(cls, outer=None):
+ """
+ This is a context manager that opens the cache state file and reads it into an object
+ that is returned to the user in the yield
+ """
+ assert outer is not None
+ with outer.cacheLock():
+ cacheInfo = cls._load(outer.cacheStateFile)
+ yield cacheInfo
+ cacheInfo.write(outer.cacheStateFile)
+
+ def isBalanced(self):
+ """
+ Checks for the inequality of the caching equation, i.e.
+ cachedSpace + sigmaJobDisk <= totalFreeSpace
+ Essentially, the sum of all cached file + disk requirements of all running jobs
+ should always be less than the available space on the system
+
+ :return: Boolean for equation is balanced (T) or not (F)
+ :rtype: bool
+ """
+ return self.cached + self.sigmaJob <= self.total
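+ # Worked example (numbers are illustrative): with total = 100 GiB of free space,
+ # cached = 60 GiB of cached files and sigmaJob = 30 GiB reserved by running jobs,
+ # 60 + 30 <= 100 holds and the cache is balanced. Admitting a job that reserves
+ # another 20 GiB would give 60 + 50 > 100, so files would have to be evicted first.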
+
+ def purgeRequired(self, jobReqs):
+ """
+ Similar to isBalanced, however it looks at the actual state of the system and
+ decides whether an eviction is required.
+
+ :return: Is a purge required (True) or not (False)?
+ :rtype: bool
+ """
+ return not self.isBalanced()
+ # totalStats = os.statvfs(self.cacheDir)
+ # totalFree = totalStats.f_bavail * totalStats.f_frsize
+ # return totalFree < jobReqs
+
+ # Methods related to the deferred function logic
+ @classmethod
+ def findAndHandleDeadJobs(cls, nodeInfo, batchSystemShutdown=False):
+ """
+ Look at the state of all jobs registered in the cache state file, and handle any that
+ have died prematurely: return their disk to the cache, run their registered deferred
+ functions, and remove them from the state file.
+
+ :param toil.fileStore.CachingFileStore._CacheState nodeInfo: The state of the node cache as
+ a _CacheState object
+ :param bool batchSystemShutdown: Is the batch system in the process of shutting down?
+ """
+ # A list of tuples of (hashed job id, pid or process running job)
+ registeredJobs = [(jid, state['pid']) for jid, state in nodeInfo.jobState.items()]
+ for jobID, jobPID in registeredJobs:
+ if not cls._pidExists(jobPID):
+ jobState = CachingFileStore._JobState(nodeInfo.jobState[jobID])
+ logger.warning('Detected that job (%s) prematurely terminated. Fixing the state '
+ 'of the cache.', jobState.jobName)
+ if not batchSystemShutdown:
+ logger.debug("Returning dead job's used disk to cache.")
+ # Delete the old work directory if it still exists, to remove unwanted nlinks.
+ # Do this only during the life of the program and don't do it during the
+ # batch system cleanup. Leave that to the batch system cleanup code.
+ if os.path.exists(jobState.jobDir):
+ shutil.rmtree(jobState.jobDir)
+ nodeInfo.sigmaJob -= jobState.jobReqs
+ logger.debug('Running user-defined deferred functions.')
+ cls._runDeferredFunctions(jobState.deferredFunctions)
+ # Remove job from the cache state file
+ nodeInfo.jobState.pop(jobID)
+
+ def _registerDeferredFunction(self, deferredFunction):
+ with self._CacheState.open(self) as cacheInfo:
+ cacheInfo.jobState[self.jobID]['deferredFunctions'].append(deferredFunction)
+ logger.debug('Registered "%s" with job "%s".', deferredFunction, self.jobName)
+
+ class _JobState(object):
+ """
+ This is a utility class to handle the state of a job in terms of its current disk
+ requirements, working directory, and job specific files.
+ """
+
+ def __init__(self, dictObj):
+ assert isinstance(dictObj, dict)
+ self.__dict__.update(dictObj)
+
+ @classmethod
+ def updateJobSpecificFiles(cls, outer, jobStoreFileID, filePath, fileSize, cached):
+ """
+ This method will update the job specific files in the job state object. It deals with
+ opening a cache lock file, etc.
+
+ :param toil.fileStore.CachingFileStore outer: An instance of CachingFileStore
+ :param str jobStoreFileID: job store Identifier for the file
+ :param str filePath: The path to the file
+ :param float fileSize: The size of the file (may be deprecated soon)
+ :param bool cached: T : F : None :: cached : not cached : mutably read
+ """
+ with outer._CacheState.open(outer) as cacheInfo:
+ jobState = cls(cacheInfo.jobState[outer.jobID])
+ jobState.addToJobSpecFiles(jobStoreFileID, filePath, fileSize, cached)
+ cacheInfo.jobState[outer.jobID] = jobState.__dict__
+
+ def addToJobSpecFiles(self, jobStoreFileID, filePath, fileSize, cached):
+ """
+ This is the method that actually performs the updates.
+
+ :param jobStoreFileID: job store Identifier for the file
+ :param filePath: The path to the file
+ :param fileSize: The size of the file (may be deprecated soon)
+ :param cached: T : F : None :: cached : not cached : mutably read
+ """
+ # If there is no entry for the jsfID, make one. self.jobSpecificFiles is a default
+ # dict of default dicts and the absence of a key will return an empty dict
+ # (equivalent to a None for the if)
+ if not self.jobSpecificFiles[jobStoreFileID]:
+ self.jobSpecificFiles[jobStoreFileID][filePath] = fileSize
+ else:
+ # If there's no entry for the filepath, create one
+ if not self.jobSpecificFiles[jobStoreFileID][filePath]:
+ self.jobSpecificFiles[jobStoreFileID][filePath] = fileSize
+ # This should never happen
+ else:
+ raise RuntimeError()
+ # Now add the file to the reverse mapper. This will speed up cleanup and local file
+ # deletion.
+ self.filesToFSIDs[filePath].add(jobStoreFileID)
+ if cached:
+ self.updateJobReqs(fileSize, 'add')
+
+ def updateJobReqs(self, fileSize, actions):
+ """
+ This method will update the current state of the disk required by the job after the
+ most recent cache operation.
+
+ :param fileSize: Size of the last file added/removed from the cache
+ :param actions: 'add' or 'remove'
+ """
+ assert actions in ('add', 'remove')
+ multiplier = 1 if actions == 'add' else -1
+ # If the file was added to the cache, the value is subtracted from the requirements,
+ # and it is added if the file was removed form the cache.
+ self.jobReqs -= (fileSize * multiplier)
+
+ def isPopulated(self):
+ return self.__dict__ != {}
+
+ class HarbingerFile(object):
+ """
+ Represents the placeholder file that signals the impending arrival in the cache of a
+ local copy of a file from the job store.
+ """
+
+ def __init__(self, fileStore, fileStoreID=None, cachedFileName=None):
+ """
+ Returns the harbinger file name for a cached file, or for a job store ID
+
+ :param class fileStore: The 'self' object of the fileStore class
+ :param str fileStoreID: The file store ID for an input file
+ :param str cachedFileName: The cache file name corresponding to a given file
+ """
+ # We need either a file store ID, or a cached file name, but not both (XOR).
+ assert (fileStoreID is None) != (cachedFileName is None)
+ if fileStoreID is not None:
+ self.fileStoreID = fileStoreID
+ cachedFileName = fileStore.encodedFileID(fileStoreID)
+ else:
+ self.fileStoreID = fileStore.decodedFileID(cachedFileName)
+ self.fileStore = fileStore
+ self.harbingerFileName = '/.'.join(os.path.split(cachedFileName)) + '.harbinger'
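+ # Name construction sketch (path is illustrative): for a cached file '/cache/AbCd',
+ # os.path.split gives ('/cache', 'AbCd'), so the harbinger becomes
+ # '/cache/.AbCd.harbinger', a hidden sibling that _isHidden keeps out of the
+ # cache size accounting.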
+
+ def write(self):
+ self.fileStore.logToMaster('CACHE: Creating a harbinger file for (%s). '
+ % self.fileStoreID, logging.DEBUG)
+ with open(self.harbingerFileName + '.tmp', 'w') as harbingerFile:
+ harbingerFile.write(str(os.getpid()))
+ # Make this File read only to prevent overwrites
+ os.chmod(self.harbingerFileName + '.tmp', 0o444)
+ os.rename(self.harbingerFileName + '.tmp', self.harbingerFileName)
+
+ def waitOnDownload(self, lockFileHandle):
+ """
+ This method is called when a readGlobalFile process is waiting on another process to
+ write a file to the cache.
+
+ :param lockFileHandle: The open handle to the cache lock file
+ """
+ while self.exists():
+ logger.info('CACHE: Waiting for another worker to download file with ID %s.'
+ % self.fileStoreID)
+ # Ensure that the process downloading the file is still alive. The PID will
+ # be in the harbinger file.
+ pid = self.read()
+ if FileStore._pidExists(pid):
+ # Release the file lock and then wait for a bit before repeating.
+ flock(lockFileHandle, LOCK_UN)
+ time.sleep(20)
+ # Grab the file lock before repeating.
+ flock(lockFileHandle, LOCK_EX)
+ else:
+ # The process that was supposed to download the file has died so we need
+ # to remove the harbinger.
+ self._delete()
+
+ def read(self):
+ return int(open(self.harbingerFileName).read())
+
+ def exists(self):
+ return os.path.exists(self.harbingerFileName)
+
+ def delete(self):
+ """
+ Acquires the cache lock then attempts to delete the harbinger file.
+ """
+ with self.fileStore.cacheLock():
+ self._delete()
+
+ def _delete(self):
+ """
+ This function assumes you already have the cache lock!
+ """
+ assert self.exists()
+ self.fileStore.logToMaster('CACHE: Deleting the harbinger file for (%s)' %
+ self.fileStoreID, logging.DEBUG)
+ os.remove(self.harbingerFileName)
+
+ # Functions related to async updates
+ def asyncWrite(self):
+ """
+ A function to write files asynchronously to the job store such that subsequent jobs are
+ not delayed by a long write operation.
+ """
+ try:
+ while True:
+ try:
+ # Block for up to two seconds waiting for a file
+ args = self.queue.get(timeout=2)
+ except Empty:
+ # Check if termination event is signaled
+ # (set in the event of an exception in the worker)
+ if self._terminateEvent.isSet():
+ raise RuntimeError("The termination flag is set, exiting")
+ continue
+ # Normal termination condition is getting None from queue
+ if args is None:
+ break
+ inputFileHandle, jobStoreFileID = args
+ cachedFileName = self.encodedFileID(jobStoreFileID)
+ # Ensure that the harbinger exists in the cache directory and that the PID
+ # matches that of this writing thread.
+ # If asyncWrite is ported to subprocesses instead of threads in the future,
+ # insert logic here to securely overwrite the harbinger file.
+ harbingerFile = self.HarbingerFile(self, cachedFileName=cachedFileName)
+ assert harbingerFile.exists()
+ assert harbingerFile.read() == int(os.getpid())
+ # We pass in a fileHandle, rather than the file-name, in case
+ # the file itself is deleted. The fileHandle itself should persist
+ # while we maintain the open file handle
+ with self.jobStore.updateFileStream(jobStoreFileID) as outputFileHandle:
+ shutil.copyfileobj(inputFileHandle, outputFileHandle)
+ inputFileHandle.close()
+ # Remove the file from the lock files
+ with self._pendingFileWritesLock:
+ self._pendingFileWrites.remove(jobStoreFileID)
+ # Remove the harbinger file
+ harbingerFile.delete()
+ except:
+ self._terminateEvent.set()
+ raise
+
+ def _updateJobWhenDone(self):
+ """
+ Asynchronously update the status of the job on the disk, first waiting \
+ until the writing threads have finished and the input blockFn has stopped \
+ blocking.
+ """
+
+ def asyncUpdate():
+ try:
+ # Wait till all file writes have completed
+ for i in xrange(len(self.workers)):
+ self.queue.put(None)
+
+ for thread in self.workers:
+ thread.join()
+
+ # Wait till input block-fn returns - in the event of an exception
+ # this will eventually terminate
+ self.inputBlockFn()
+
+ # Check the terminate event, if set we can not guarantee
+ # that the workers ended correctly, therefore we exit without
+ # completing the update
+ if self._terminateEvent.isSet():
+ raise RuntimeError("The termination flag is set, exiting before update")
+
+ # Indicate any files that should be deleted once the update of
+ # the job wrapper is completed.
+ self.jobGraph.filesToDelete = list(self.filesToDelete)
+
+ # Complete the job
+ self.jobStore.update(self.jobGraph)
+
+ # Delete any remnant jobs
+ map(self.jobStore.delete, self.jobsToDelete)
+
+ # Delete any remnant files
+ map(self.jobStore.deleteFile, self.filesToDelete)
+
+ # Remove the files to delete list, having successfully removed the files
+ if len(self.filesToDelete) > 0:
+ self.jobGraph.filesToDelete = []
+ # Update the job, emptying the list of files to delete
+ self.jobStore.update(self.jobGraph)
+ except:
+ self._terminateEvent.set()
+ raise
+ finally:
+ # Indicate that _blockFn can return
+ # This code will always run
+ self.updateSemaphore.release()
+
+ # The update semaphore is held while the job is written to the job store
+ try:
+ self.updateSemaphore.acquire()
+ t = Thread(target=asyncUpdate)
+ t.start()
+ except:
+ # This is to ensure that the semaphore is released in a crash to stop a deadlock
+ # scenario
+ self.updateSemaphore.release()
+ raise
+
+ def _blockFn(self):
+ self.updateSemaphore.acquire()
+ self.updateSemaphore.release() # Release so that the block function can be recalled
+ # This works, because once acquired the semaphore will not be acquired
+ # by _updateJobWhenDone again.
+ return
+
+ @classmethod
+ def shutdown(cls, dir_):
+ """
+ :param dir_: The directory that will contain the cache state file.
+ """
+ cacheInfo = cls._CacheState._load(os.path.join(dir_, '_cacheState'))
+ cls.findAndHandleDeadJobs(cacheInfo, batchSystemShutdown=True)
+ shutil.rmtree(dir_)
+
+ def __del__(self):
+ """
+ Cleanup function that is run when destroying the class instance that ensures that all the
+ file writing threads exit.
+ """
+ self.updateSemaphore.acquire()
+ for i in xrange(len(self.workers)):
+ self.queue.put(None)
+ for thread in self.workers:
+ thread.join()
+ self.updateSemaphore.release()
+
+
+class NonCachingFileStore(FileStore):
+ def __init__(self, jobStore, jobGraph, localTempDir, inputBlockFn):
+ self.jobStore = jobStore
+ self.jobGraph = jobGraph
+ self.jobName = str(self.jobGraph)
+ self.localTempDir = os.path.abspath(localTempDir)
+ self.inputBlockFn = inputBlockFn
+ self.jobsToDelete = set()
+ self.loggingMessages = []
+ self.filesToDelete = set()
+ super(NonCachingFileStore, self).__init__(jobStore, jobGraph, localTempDir, inputBlockFn)
+ # This will be defined in the `open` method.
+ self.jobStateFile = None
+ self.localFileMap = defaultdict(list)
+
+ @contextmanager
+ def open(self, job):
+ jobReqs = job.disk
+ startingDir = os.getcwd()
+ self.localTempDir = makePublicDir(os.path.join(self.localTempDir, str(uuid.uuid4())))
+ self.findAndHandleDeadJobs(self.workFlowDir)
+ self.jobStateFile = self._createJobStateFile()
+ try:
+ os.chdir(self.localTempDir)
+ yield
+ finally:
+ diskUsed = getDirSizeRecursively(self.localTempDir)
+ logString = ("Job {jobName} used {percent:.2f}% ({humanDisk}B [{disk}B] used, "
+ "{humanRequestedDisk}B [{requestedDisk}B] requested) at the end of "
+ "its run.".format(jobName=self.jobName,
+ percent=(float(diskUsed) / jobReqs * 100 if
+ jobReqs > 0 else 0.0),
+ humanDisk=bytes2human(diskUsed),
+ disk=diskUsed,
+ humanRequestedDisk=bytes2human(jobReqs),
+ requestedDisk=jobReqs))
+ self.logToMaster(logString, level=logging.DEBUG)
+ if diskUsed > jobReqs:
+ self.logToMaster("Job used more disk than requested. Cconsider modifying the user "
+ "script to avoid the chance of failure due to incorrectly "
+ "requested resources. " + logString, level=logging.WARNING)
+ os.chdir(startingDir)
+ jobState = self._readJobState(self.jobStateFile)
+ deferredFunctions = jobState['deferredFunctions']
+ failures = self._runDeferredFunctions(deferredFunctions)
+ for failure in failures:
+ self.logToMaster('Deferred function "%s" failed.' % failure, logging.WARN)
+ # Finally delete the job from the worker
+ os.remove(self.jobStateFile)
+
+ def writeGlobalFile(self, localFileName, cleanup=False):
+ absLocalFileName = self._resolveAbsoluteLocalPath(localFileName)
+ cleanupID = None if not cleanup else self.jobGraph.jobStoreID
+ fileStoreID = self.jobStore.writeFile(absLocalFileName, cleanupID)
+ self.localFileMap[fileStoreID].append(absLocalFileName)
+ return FileID.forPath(fileStoreID, absLocalFileName)
+
+ def readGlobalFile(self, fileStoreID, userPath=None, cache=True, mutable=None):
+ if userPath is not None:
+ localFilePath = self._resolveAbsoluteLocalPath(userPath)
+ if os.path.exists(localFilePath):
+ raise RuntimeError('File %s exists. Cannot overwrite.' % localFilePath)
+ else:
+ localFilePath = self.getLocalTempFileName()
+
+ self.jobStore.readFile(fileStoreID, localFilePath)
+ self.localFileMap[fileStoreID].append(localFilePath)
+ return localFilePath
+
+ @contextmanager
+ def readGlobalFileStream(self, fileStoreID):
+ with self.jobStore.readFileStream(fileStoreID) as f:
+ yield f
+
+ def deleteLocalFile(self, fileStoreID):
+ try:
+ localFilePaths = self.localFileMap.pop(fileStoreID)
+ except KeyError:
+ raise OSError(errno.ENOENT, "Attempting to delete a non-local file")
+ else:
+ for localFilePath in localFilePaths:
+ os.remove(localFilePath)
+
+ def deleteGlobalFile(self, fileStoreID):
+ try:
+ self.deleteLocalFile(fileStoreID)
+ except OSError as e:
+ if e.errno == errno.ENOENT:
+ # the file does not exist locally, so no local deletion necessary
+ pass
+ else:
+ raise
+ self.filesToDelete.add(fileStoreID)
+
+ def _blockFn(self):
+ # there is no asynchronicity in this file store so no need to block at all
+ return True
+
+ def _updateJobWhenDone(self):
+ try:
+ # Indicate any files that should be deleted once the update of
+ # the job wrapper is completed.
+ self.jobGraph.filesToDelete = list(self.filesToDelete)
+ # Complete the job
+ self.jobStore.update(self.jobGraph)
+ # Delete any remnant jobs
+ map(self.jobStore.delete, self.jobsToDelete)
+ # Delete any remnant files
+ map(self.jobStore.deleteFile, self.filesToDelete)
+ # Remove the files to delete list, having successfully removed the files
+ if len(self.filesToDelete) > 0:
+ self.jobGraph.filesToDelete = []
+ # Update the job, emptying the list of files to delete
+ self.jobStore.update(self.jobGraph)
+ except:
+ self._terminateEvent.set()
+ raise
+
+ def __del__(self):
+ """
+ Cleanup function that is run when destroying the class instance. Nothing to do since there
+ are no async write events.
+ """
+ pass
+
+ # Functions related to the deferred function logic
+ @classmethod
+ def findAndHandleDeadJobs(cls, nodeInfo, batchSystemShutdown=False):
+ """
+ Look at the state of all jobs registered in the individual job state files, and handle them
+ (clean up the disk, and run any registered defer functions)
+
+ :param str nodeInfo: The location of the workflow directory on the node.
+ :param bool batchSystemShutdown: Is the batch system in the process of shutting down?
+ :return:
+ """
+ # A list of tuples of (job name, pid or process running job, registered defer functions)
+ for jobState in cls._getAllJobStates(nodeInfo):
+ if not cls._pidExists(jobState['jobPID']):
+ # using same logic to prevent races as CachingFileStore._setupCache
+ myPID = str(os.getpid())
+ cleanupFile = os.path.join(jobState['jobDir'], '.cleanup')
+ with open(os.path.join(jobState['jobDir'], '.' + myPID), 'w') as f:
+ f.write(myPID)
+ while True:
+ try:
+ os.rename(f.name, cleanupFile)
+ except OSError as err:
+ if err.errno == errno.ENOTEMPTY:
+ with open(cleanupFile, 'r') as f:
+ cleanupPID = f.read()
+ if cls._pidExists(int(cleanupPID)):
+ # Cleanup your own mess. It's only polite.
+ os.remove(f.name)
+ break
+ else:
+ os.remove(cleanupFile)
+ continue
+ else:
+ raise
+ else:
+ logger.warning('Detected that job (%s) prematurely terminated. Fixing the '
+ 'state of the job on disk.', jobState['jobName'])
+ if not batchSystemShutdown:
+ logger.debug("Deleting the stale working directory.")
+ # Delete the old work directory if it still exists. Do this only during
+ # the life of the program and don't do it during the batch system
+ # cleanup. Leave that to the batch system cleanup code.
+ shutil.rmtree(jobState['jobDir'])
+ # Run any deferred functions associated with the job
+ logger.debug('Running user-defined deferred functions.')
+ cls._runDeferredFunctions(jobState['deferredFunctions'])
+ break
+
+ @staticmethod
+ def _getAllJobStates(workflowDir):
+ """
+ Generator function that deserializes and yields the job state for every job on the node,
+ one at a time.
+
+ :param str workflowDir: The location of the workflow directory on the node.
+ :return: dict with keys (jobName, jobPID, jobDir, deferredFunctions)
+ :rtype: dict
+ """
+ jobStateFiles = []
+ for root, dirs, files in os.walk(workflowDir):
+ for filename in files:
+ if filename == '.jobState':
+ jobStateFiles.append(os.path.join(root, filename))
+ for filename in jobStateFiles:
+ yield NonCachingFileStore._readJobState(filename)
+
+ @staticmethod
+ def _readJobState(jobStateFileName):
+ with open(jobStateFileName) as fH:
+ state = dill.load(fH)
+ return state
+
+ def _registerDeferredFunction(self, deferredFunction):
+ with open(self.jobStateFile) as fH:
+ jobState = dill.load(fH)
+ jobState['deferredFunctions'].append(deferredFunction)
+ with open(self.jobStateFile + '.tmp', 'w') as fH:
+ dill.dump(jobState, fH)
+ os.rename(self.jobStateFile + '.tmp', self.jobStateFile)
+ logger.debug('Registered "%s" with job "%s".', deferredFunction, self.jobName)
+
+ def _createJobStateFile(self):
+ """
+ Create the job state file for the current job and fill in the required
+ values.
+
+ :return: Path to the job state file
+ :rtype: str
+ """
+ jobStateFile = os.path.join(self.localTempDir, '.jobState')
+ jobState = {'jobPID': os.getpid(),
+ 'jobName': self.jobName,
+ 'jobDir': self.localTempDir,
+ 'deferredFunctions': []}
+ with open(jobStateFile + '.tmp', 'w') as fH:
+ dill.dump(jobState, fH)
+ os.rename(jobStateFile + '.tmp', jobStateFile)
+ return jobStateFile
+
+ @classmethod
+ def shutdown(cls, dir_):
+ """
+ :param dir_: The workflow directory that will contain all the individual worker directories.
+ """
+ cls.findAndHandleDeadJobs(dir_, batchSystemShutdown=True)
+
+
+class FileID(str):
+ """
+ A class to wrap the job store file id returned by writeGlobalFile and any attributes we may want
+ to add to it.
+ """
+ def __new__(cls, fileStoreID, *args):
+ return super(FileID, cls).__new__(cls, fileStoreID)
+
+ def __init__(self, fileStoreID, size):
+ super(FileID, self).__init__(fileStoreID)
+ self.size = size
+
+ @classmethod
+ def forPath(cls, fileStoreID, filePath):
+ return cls(fileStoreID, os.stat(filePath).st_size)
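+ # A short usage sketch (illustrative; the ID and path are made up): a FileID behaves
+ # like the plain string ID it wraps but also carries the file's size in bytes.
+ #
+ #   fid = FileID.forPath('files/no-job/file-abc', '/tmp/data.bin')
+ #   fid == 'files/no-job/file-abc'  # True, str comparison still works
+ #   fid.size                        # size of /tmp/data.bin in bytes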
+
+
+def shutdownFileStore(workflowDir, workflowID):
+ """
+ Run the deferred functions from any prematurely terminated jobs still lingering on the system
+ and carry out any necessary filestore-specific cleanup.
+
+ This is a destructive operation and it is important to ensure that there are no other running
+ processes on the system that are modifying or using the file store for this workflow.
+
+
+ This is intended to be the last call to the file store in a Toil run, called by the
+ batch system cleanup function upon batch system shutdown.
+
+ :param str workflowDir: The path to the workflow directory on the node
+ :param str workflowID: The workflow ID for this invocation of the workflow
+ """
+ cacheDir = os.path.join(workflowDir, cacheDirName(workflowID))
+ if os.path.exists(cacheDir):
+ # The presence of the cacheDir suggests this was a cached run. We don't need the cache lock
+ # for any of this since this is the final cleanup of a job and there should be no other
+ # conflicting processes using the cache.
+ CachingFileStore.shutdown(cacheDir)
+ else:
+ # The absence of cacheDir suggests this was a non-caching run.
+ NonCachingFileStore.shutdown(workflowDir)
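+# A typical (illustrative) call site, e.g. from a batch system's cleanup hook; the
+# directory and workflow ID below are made up:
+#
+#   from toil.fileStore import shutdownFileStore
+#   shutdownFileStore('/var/lib/toil/node-workflow-dir', 'example-workflow-id')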
+
+
+class CacheError(Exception):
+ """
+ Base class for errors raised by the file store cache.
+ """
+
+ def __init__(self, message):
+ super(CacheError, self).__init__(message)
+
+
+class CacheUnbalancedError(CacheError):
+ """
+ Raised if file store can't free enough space for caching
+ """
+ message = 'Unable to free enough space for caching'
+ def __init__(self):
+ super(CacheUnbalancedError, self).__init__(self.message)
+
+
+class IllegalDeletionCacheError(CacheError):
+ """
+ Error raised if Toil detects that the user has deleted a cached file
+ """
+
+ def __init__(self, deletedFile):
+ message = 'Cache tracked file (%s) deleted explicitly by user. Use deleteLocalFile to ' \
+ 'delete such files.' % deletedFile
+ super(IllegalDeletionCacheError, self).__init__(message)
+
+
+class InvalidSourceCacheError(CacheError):
+ """
+ Error raised if the user attempts to add a non-local file to the cache
+ """
+
+ def __init__(self, message):
+ super(InvalidSourceCacheError, self).__init__(message)
diff --git a/src/toil/job.py b/src/toil/job.py
new file mode 100644
index 0000000..406f0bd
--- /dev/null
+++ b/src/toil/job.py
@@ -0,0 +1,1727 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, print_function
+
+import collections
+import importlib
+import inspect
+import logging
+import os
+import sys
+import time
+import uuid
+import dill
+
+from abc import ABCMeta, abstractmethod
+from argparse import ArgumentParser
+from contextlib import contextmanager
+from io import BytesIO
+
+# Python 3 compatibility imports
+from six.moves import cPickle
+from six import iteritems, string_types
+
+from bd2k.util.exceptions import require
+from bd2k.util.expando import Expando
+from bd2k.util.humanize import human2bytes
+
+from toil.common import Toil, addOptions
+from toil.fileStore import DeferredFunction
+from toil.lib.bioio import (setLoggingFromOptions,
+ getTotalCpuTimeAndMemoryUsage,
+ getTotalCpuTime)
+from toil.resource import ModuleDescriptor
+
+logger = logging.getLogger(__name__)
+
+
+class JobLikeObject(object):
+ """
+ Inherit from this class to add requirement properties to a job (or job-like) object.
+ If the object doesn't specify explicit requirements, these properties will fall back
+ to the configured defaults. If the value cannot be determined, an AttributeError is raised.
+ """
+ def __init__(self, requirements, unitName, jobName=None):
+ cores = requirements.get('cores')
+ memory = requirements.get('memory')
+ disk = requirements.get('disk')
+ preemptable = requirements.get('preemptable')
+ if unitName is not None:
+ assert isinstance(unitName, str)
+ if jobName is not None:
+ assert isinstance(jobName, str)
+ self.unitName = unitName
+ self.jobName = jobName if jobName is not None else self.__class__.__name__
+ self._cores = self._parseResource('cores', cores)
+ self._memory = self._parseResource('memory', memory)
+ self._disk = self._parseResource('disk', disk)
+ self._preemptable = preemptable
+ self._config = None
+
+ @property
+ def disk(self):
+ """
+ The maximum number of bytes of disk the job will require to run.
+ """
+ if self._disk is not None:
+ return self._disk
+ elif self._config is not None:
+ return self._config.defaultDisk
+ else:
+ raise AttributeError("Default value for 'disk' cannot be determined")
+
+ @property
+ def memory(self):
+ """
+ The maximum number of bytes of memory the job will require to run.
+ """
+ if self._memory is not None:
+ return self._memory
+ elif self._config is not None:
+ return self._config.defaultMemory
+ else:
+ raise AttributeError("Default value for 'memory' cannot be determined")
+
+ @property
+ def cores(self):
+ """
+ The number of CPU cores required.
+ """
+ if self._cores is not None:
+ return self._cores
+ elif self._config is not None:
+ return self._config.defaultCores
+ else:
+ raise AttributeError("Default value for 'cores' cannot be determined")
+
+ @property
+ def preemptable(self):
+ """
+ Whether the job can be run on a preemptable node.
+ """
+ if self._preemptable is not None:
+ return self._preemptable
+ elif self._config is not None:
+ return self._config.defaultPreemptable
+ else:
+ raise AttributeError("Default value for 'preemptable' cannot be determined")
+
+ @property
+ def _requirements(self):
+ """
+ Gets a dictionary of all the object's resource requirements. Unset values are defaulted to None
+ """
+ return {'memory': getattr(self, 'memory', None),
+ 'cores': getattr(self, 'cores', None),
+ 'disk': getattr(self, 'disk', None),
+ 'preemptable': getattr(self, 'preemptable', None)}
+
+ @staticmethod
+ def _parseResource(name, value):
+ """
+ Parse a Toil job's resource requirement value and apply resource-specific type checks. If the
+ value is a string, a binary or metric unit prefix in it will be evaluated and the
+ corresponding integral value will be returned.
+
+ :param str name: The name of the resource
+
+ :param None|str|float|int value: The resource value
+
+ :rtype: int|float|None
+
+ >>> Job._parseResource('cores', None)
+ >>> Job._parseResource('cores', 1), Job._parseResource('disk', 1), \
+ Job._parseResource('memory', 1)
+ (1, 1, 1)
+ >>> Job._parseResource('cores', '1G'), Job._parseResource('disk', '1G'), \
+ Job._parseResource('memory', '1G')
+ (1073741824, 1073741824, 1073741824)
+ >>> Job._parseResource('cores', 1.1)
+ 1.1
+ >>> Job._parseResource('disk', 1.1)
+ Traceback (most recent call last):
+ ...
+ TypeError: The 'disk' requirement does not accept values that are of <type 'float'>
+ >>> Job._parseResource('memory', object())
+ Traceback (most recent call last):
+ ...
+ TypeError: The 'memory' requirement does not accept values that are of <type 'object'>
+ """
+ assert name in ('memory', 'disk', 'cores')
+ if value is None:
+ return value
+ elif isinstance(value, str):
+ value = human2bytes(value)
+ if isinstance(value, int):
+ return value
+ elif isinstance(value, float) and name == 'cores':
+ return value
+ else:
+ raise TypeError("The '%s' requirement does not accept values that are of %s"
+ % (name, type(value)))
+
+ def __str__(self):
+ printedName = "'" + self.jobName + "'"
+ if self.unitName:
+ printedName += ' ' + self.unitName
+ elif self.unitName == '':
+ printedName += ' ' + 'user passed empty string for name'
+ return printedName
+
+
+class JobNode(JobLikeObject):
+ """
+ This object bridges the job graph, job, and batch system classes
+ """
+ def __init__(self, requirements, jobName, unitName, jobStoreID,
+ command, predecessorNumber=1):
+ super(JobNode, self).__init__(requirements=requirements, unitName=unitName, jobName=jobName)
+ self.jobStoreID = jobStoreID
+ self.predecessorNumber = predecessorNumber
+ self.command = command
+
+ def __str__(self):
+ return super(JobNode, self).__str__() + ' ' + self.jobStoreID
+
+ def __hash__(self):
+ return hash(self.jobStoreID)
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ return NotImplemented
+
+ def __ne__(self, other):
+ if isinstance(other, self.__class__):
+ return not self.__eq__(other)
+ return NotImplemented
+
+ def __repr__(self):
+ return '%s( **%r )' % (self.__class__.__name__, self.__dict__)
+
+ @classmethod
+ def fromJobGraph(cls, jobGraph):
+ """
+ Takes a job graph object and returns a job node object
+ :param toil.jobGraph.JobGraph jobGraph: A job graph object to be transformed into a job node
+ :return: A job node object
+ :rtype: toil.job.JobNode
+ """
+ return cls(jobStoreID=jobGraph.jobStoreID,
+ requirements=jobGraph._requirements,
+ command=jobGraph.command,
+ jobName=jobGraph.jobName,
+ unitName=jobGraph.unitName,
+ predecessorNumber=jobGraph.predecessorNumber)
+
+ @classmethod
+ def fromJob(cls, job, command, predecessorNumber):
+ """
+ Build a job node from a job object
+ :param toil.job.Job job: the job object to be transformed into a job node
+ :param str command: the JobNode's command
+ :param int predecessorNumber: the number of predecessors that must finish
+ successfully before the job can be scheduled
+ :return: a JobNode object representing the job object parameter
+ :rtype: toil.job.JobNode
+ """
+ return cls(jobStoreID=None,
+ requirements=job._requirements,
+ command=command,
+ jobName=job.jobName,
+ unitName=job.unitName,
+ predecessorNumber=predecessorNumber)
+
+
+class Job(JobLikeObject):
+ """
+ Class representing a unit of work in Toil.
+ """
+ def __init__(self, memory=None, cores=None, disk=None, preemptable=None, unitName=None,
+ checkpoint=False):
+ """
+ This method must be called by any overriding constructor.
+
+ :param memory: the maximum number of bytes of memory the job will require to run.
+ :param cores: the number of CPU cores required.
+ :param disk: the amount of local disk space required by the job, expressed in bytes.
+ :param preemptable: if the job can be run on a preemptable node.
+ :param checkpoint: if any of this job's successor jobs completely fails,
+ exhausting all their retries, remove any successor jobs and rerun this job to restart the
+ subtree. Job must be a leaf vertex in the job graph when initially defined, see
+ :func:`toil.job.Job.checkNewCheckpointsAreCutVertices`.
+ :type cores: int or string convertible by bd2k.util.humanize.human2bytes to an int
+ :type disk: int or string convertible by bd2k.util.humanize.human2bytes to an int
+ :type preemptable: boolean
+ :type cache: int or string convertible by bd2k.util.humanize.human2bytes to an int
+ :type memory: int or string convertible by bd2k.util.humanize.human2bytes to an int
+ """
+ requirements = {'memory': memory, 'cores': cores, 'disk': disk,
+ 'preemptable': preemptable}
+ super(Job, self).__init__(requirements=requirements, unitName=unitName)
+ self.checkpoint = checkpoint
+ #Private class variables
+
+ #See Job.addChild
+ self._children = []
+ #See Job.addFollowOn
+ self._followOns = []
+ #See Job.addService
+ self._services = []
+ #A follow-on, service or child of a job A, is a "direct successor" of A; if B
+ #is a direct successor of A, then A is a "direct predecessor" of B.
+ self._directPredecessors = set()
+ # Note that self.__module__ is not necessarily this module, i.e. job.py. It is the module
+ # defining the class self is an instance of, which may be a subclass of Job that may be
+ # defined in a different module.
+ self.userModule = ModuleDescriptor.forModule(self.__module__)
+ # Maps index paths into composite return values to lists of IDs of files containing
+ # promised values for those return value items. An index path is a tuple of indices that
+ # traverses a nested data structure of lists, dicts, tuples or any other type supporting
+ # the __getitem__() protocol. The special key `()` (the empty tuple) represents the
+ # entire return value.
+ self._rvs = collections.defaultdict(list)
+ self._promiseJobStore = None
+ self.fileStore = None
+
+ def run(self, fileStore):
+ """
+ Override this function to perform work and dynamically create successor jobs.
+
+ :param toil.fileStore.FileStore fileStore: Used to create local and globally
+ sharable temporary files and to send log messages to the leader process.
+
+ :return: The return value of the function can be passed to other jobs by means of
+ :func:`toil.job.Job.rv`.
+ """
+ pass
+
+ def addChild(self, childJob):
+ """
+ Adds childJob to be run as child of this job. Child jobs will be run \
+ directly after this job's :func:`toil.job.Job.run` method has completed.
+
+ :param toil.job.Job childJob:
+ :return: childJob
+ :rtype: toil.job.Job
+ """
+ self._children.append(childJob)
+ childJob._addPredecessor(self)
+ return childJob
+
+ def hasChild(self, childJob):
+ """
+ Check if childJob is already a child of this job.
+
+ :param toil.job.Job childJob:
+ :return: True if childJob is a child of the job, else False.
+ :rtype: Boolean
+ """
+ return childJob in self._children
+
+ def addFollowOn(self, followOnJob):
+ """
+ Adds a follow-on job, follow-on jobs will be run after the child jobs and \
+ their successors have been run.
+
+ :param toil.job.Job followOnJob:
+ :return: followOnJob
+ :rtype: toil.job.Job
+ """
+ self._followOns.append(followOnJob)
+ followOnJob._addPredecessor(self)
+ return followOnJob
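+ # A small illustrative graph (A, B and C are assumed to be user-defined Job
+ # instances): B runs after A's run() method completes, and C runs only after B
+ # and all of B's successors have finished.
+ #
+ #   A.addChild(B)
+ #   A.addFollowOn(C)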
+
+ def addService(self, service, parentService=None):
+ """
+ Add a service.
+
+ The :func:`toil.job.Job.Service.start` method of the service will be called
+ after the run method has completed but before any successors are run.
+ The service's :func:`toil.job.Job.Service.stop` method will be called once
+ the successors of the job have been run.
+
+ Services allow things like databases and servers to be started and accessed
+ by jobs in a workflow.
+
+ :raises toil.job.JobException: If service has already been made the child of a job or another service.
+ :param toil.job.Job.Service service: Service to add.
+ :param toil.job.Job.Service parentService: Service that will be started before 'service' is
+ started. Allows trees of services to be established. parentService must be a service
+ of this job.
+ :return: a promise that will be replaced with the return value from
+ :func:`toil.job.Job.Service.start` of service in any successor of the job.
+ :rtype: toil.job.Promise
+ """
+ if parentService is not None:
+ # Do check to ensure that parentService is a service of this job
+ def check(services):
+ for jS in services:
+ if jS.service == parentService or check(jS.service._childServices):
+ return True
+ return False
+ if not check(self._services):
+ raise JobException("Parent service is not a service of the given job")
+ return parentService._addChild(service)
+ else:
+ if service._hasParent:
+ raise JobException("The service already has a parent service")
+ service._hasParent = True
+ jobService = ServiceJob(service)
+ self._services.append(jobService)
+ return jobService.rv()
+
+ ##Convenience functions for creating jobs
+
+ def addChildFn(self, fn, *args, **kwargs):
+ """
+ Adds a function as a child job.
+
+ :param fn: Function to be run as a child job with ``*args`` and ``**kwargs`` as \
+ arguments to this function. See toil.job.FunctionWrappingJob for reserved \
+ keyword arguments used to specify resource requirements.
+ :return: The new child job that wraps fn.
+ :rtype: toil.job.FunctionWrappingJob
+ """
+ if PromisedRequirement.convertPromises(kwargs):
+ return self.addChild(PromisedRequirementFunctionWrappingJob.create(fn, *args, **kwargs))
+ else:
+ return self.addChild(FunctionWrappingJob(fn, *args, **kwargs))
+
+ def addFollowOnFn(self, fn, *args, **kwargs):
+ """
+ Adds a function as a follow-on job.
+
+ :param fn: Function to be run as a follow-on job with ``*args`` and ``**kwargs`` as \
+ arguments to this function. See toil.job.FunctionWrappingJob for reserved \
+ keyword arguments used to specify resource requirements.
+ :return: The new follow-on job that wraps fn.
+ :rtype: toil.job.FunctionWrappingJob
+ """
+ if PromisedRequirement.convertPromises(kwargs):
+ return self.addFollowOn(PromisedRequirementFunctionWrappingJob.create(fn, *args, **kwargs))
+ else:
+ return self.addFollowOn(FunctionWrappingJob(fn, *args, **kwargs))
+
+ def addChildJobFn(self, fn, *args, **kwargs):
+ """
+ Adds a job function as a child job. See :class:`toil.job.JobFunctionWrappingJob`
+ for a definition of a job function.
+
+ :param fn: Job function to be run as a child job with ``*args`` and ``**kwargs`` as \
+ arguments to this function. See toil.job.JobFunctionWrappingJob for reserved \
+ keyword arguments used to specify resource requirements.
+ :return: The new child job that wraps fn.
+ :rtype: toil.job.JobFunctionWrappingJob
+ """
+ if PromisedRequirement.convertPromises(kwargs):
+ return self.addChild(PromisedRequirementJobFunctionWrappingJob.create(fn, *args, **kwargs))
+ else:
+ return self.addChild(JobFunctionWrappingJob(fn, *args, **kwargs))
+
+ def addFollowOnJobFn(self, fn, *args, **kwargs):
+ """
+ Add a follow-on job function. See :class:`toil.job.JobFunctionWrappingJob`
+ for a definition of a job function.
+
+ :param fn: Job function to be run as a follow-on job with ``*args`` and ``**kwargs`` as \
+ arguments to this function. See toil.job.JobFunctionWrappingJob for reserved \
+ keyword arguments used to specify resource requirements.
+ :return: The new follow-on job that wraps fn.
+ :rtype: toil.job.JobFunctionWrappingJob
+ """
+ if PromisedRequirement.convertPromises(kwargs):
+ return self.addFollowOn(PromisedRequirementJobFunctionWrappingJob.create(fn, *args, **kwargs))
+ else:
+ return self.addFollowOn(JobFunctionWrappingJob(fn, *args, **kwargs))
+
+ @staticmethod
+ def wrapFn(fn, *args, **kwargs):
+ """
+ Makes a Job out of a function. \
+ Convenience function for constructor of :class:`toil.job.FunctionWrappingJob`.
+
+ :param fn: Function to be run with ``*args`` and ``**kwargs`` as arguments. \
+ See toil.job.FunctionWrappingJob for reserved keyword arguments used \
+ to specify resource requirements.
+ :return: The new job that wraps fn.
+ :rtype: toil.job.FunctionWrappingJob
+ """
+ if PromisedRequirement.convertPromises(kwargs):
+ return PromisedRequirementFunctionWrappingJob.create(fn, *args, **kwargs)
+ else:
+ return FunctionWrappingJob(fn, *args, **kwargs)
+
+ @staticmethod
+ def wrapJobFn(fn, *args, **kwargs):
+ """
+ Makes a Job out of a job function. \
+ Convenience function for constructor of :class:`toil.job.JobFunctionWrappingJob`.
+
+ :param fn: Job function to be run with ``*args`` and ``**kwargs`` as arguments. \
+ See toil.job.JobFunctionWrappingJob for reserved keyword arguments used \
+ to specify resource requirements.
+ :return: The new job function that wraps fn.
+ :rtype: toil.job.JobFunctionWrappingJob
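+
+ A minimal sketch (``double`` is an illustrative job function, not part of Toil)::
+
+ def double(job, x):
+ return 2 * x
+
+ root = Job.wrapJobFn(double, 21, memory='100M', cores=1, disk='1G')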
+ """
+ if PromisedRequirement.convertPromises(kwargs):
+ return PromisedRequirementJobFunctionWrappingJob.create(fn, *args, **kwargs)
+ else:
+ return JobFunctionWrappingJob(fn, *args, **kwargs)
+
+ def encapsulate(self):
+ """
+ Encapsulates the job, see :class:`toil.job.EncapsulatedJob`.
+ Convenience function for constructor of :class:`toil.job.EncapsulatedJob`.
+
+ :return: an encapsulated version of this job.
+ :rtype: toil.job.EncapsulatedJob.
+ """
+ return EncapsulatedJob(self)
+
+ ####################################################
+ #The following function is used for passing return values between
+ #job run functions
+ ####################################################
+
+ def rv(self, *path):
+ """
+ Creates a *promise* (:class:`toil.job.Promise`) representing a return value of the job's
+ run method, or, in case of a function-wrapping job, the wrapped function's return value.
+
+ :param (Any) path: Optional path for selecting a component of the promised return value.
+ If absent or empty, the entire return value will be used. Otherwise, the first
+ element of the path is used to select an individual item of the return value. For
+ that to work, the return value must be a list, dictionary or of any other type
+ implementing the `__getitem__()` magic method. If the selected item is yet another
+ composite value, the second element of the path can be used to select an item from
+ it, and so on. For example, if the return value is `[6, {'a': 42}]`, `.rv(0)` would
+ select `6`, `.rv(1)` would select `{'a': 42}` and `.rv(1, 'a')` would select `42`. To
+ select a slice from a return value that is sliceable, e.g. a tuple or list, the path
+ element should be a `slice` object. For example, assuming that the return value is
+ `[6, 7, 8, 9]`, `.rv(slice(1, 3))` would select `[7, 8]`. Note that slicing
+ really only makes sense at the end of the path.
+
+ :return: A promise representing the return value of this job's :meth:`toil.job.Job.run`
+ method.
+
+ :rtype: toil.job.Promise
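+
+ A minimal sketch of passing a promised value between jobs (``produce`` and
+ ``consume`` are illustrative job functions)::
+
+ a = Job.wrapJobFn(produce) # suppose produce() returns [6, {'a': 42}]
+ b = a.addChildJobFn(consume, a.rv(1, 'a')) # consume() will receive 42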
+ """
+ return Promise(self, path)
+
+ def registerPromise(self, path):
+ if self._promiseJobStore is None:
+ raise RuntimeError('Trying to pass a promise from a promising job that is not a ' +
+ 'predecessor of the job receiving the promise')
+ jobStoreFileID = self._promiseJobStore.getEmptyFileStoreID()
+ self._rvs[path].append(jobStoreFileID)
+ return self._promiseJobStore.config.jobStore, jobStoreFileID
+
+ def prepareForPromiseRegistration(self, jobStore):
+ """
+ Ensure that a promise made by this job (the promissor) can be registered with it when
+ another job referring to the promise (the promissee) is being serialized. The promissee
+ holds the reference to the promise (usually as part of the job arguments) and when it
+ is being pickled, so are the promises it refers to. Pickling a promise triggers it to be
+ registered with the promissor.
+
+ :param toil.jobStores.abstractJobStore.AbstractJobStore jobStore: the job store used to
+ back the promised return values.
+ """
+ self._promiseJobStore = jobStore
+
+ ####################################################
+ #Cycle/connectivity checking
+ ####################################################
+
+ def checkJobGraphForDeadlocks(self):
+ """
+ See :func:`toil.job.Job.checkJobGraphConnected`,
+ :func:`toil.job.Job.checkJobGraphAcyclic` and
+ :func:`toil.job.Job.checkNewCheckpointsAreLeafVertices` for more info.
+
+ :raises toil.job.JobGraphDeadlockException: if the job graph
+ is cyclic, contains multiple roots or contains checkpoint jobs that are
+ not leaf vertices when defined (see :func:`toil.job.Job.checkNewCheckpointsAreLeafVertices`).
+ """
+ self.checkJobGraphConnected()
+ self.checkJobGraphAcylic()
+ self.checkNewCheckpointsAreLeafVertices()
+
+ def getRootJobs(self):
+ """
+ :return: The roots of the connected component of jobs that contains this job. \
+ A root is a job with no predecessors.
+
+ :rtype : set of toil.job.Job instances
+ """
+ roots = set()
+ visited = set()
+ #Function to get the roots of a job
+ def getRoots(job):
+ if job not in visited:
+ visited.add(job)
+ if len(job._directPredecessors) > 0:
+ map(lambda p : getRoots(p), job._directPredecessors)
+ else:
+ roots.add(job)
+ #The following call ensures we explore all successor edges.
+ map(lambda c : getRoots(c), job._children +
+ job._followOns)
+ getRoots(self)
+ return roots
+
+ def checkJobGraphConnected(self):
+ """
+ :raises toil.job.JobGraphDeadlockException: if :func:`toil.job.Job.getRootJobs` does \
+ not contain exactly one root job.
+
+ As execution always starts from one root job, having multiple root jobs will \
+ cause a deadlock to occur.
+ """
+ rootJobs = self.getRootJobs()
+ if len(rootJobs) != 1:
+ raise JobGraphDeadlockException("Graph does not contain exactly one"
+ " root job: %s" % rootJobs)
+
+ def checkJobGraphAcylic(self):
+ """
+ :raises toil.job.JobGraphDeadlockException: if the connected component \
+ of jobs containing this job contains any cycles of child/followOn dependencies \
+ in the *augmented job graph* (see below). Such cycles are not allowed \
+ in valid job graphs.
+
+ A follow-on edge (A, B) between two jobs A and B is equivalent \
+ to adding a child edge to B from (1) A, (2) each child of A, \
+ and (3) the successors of each child of A. We call each such edge \
+ an "implied" edge. The augmented job graph is the job graph including \
+ all the implied edges.
+
+ For a job graph G = (V, E) the algorithm is ``O(|V|^2)``. It is ``O(|V| + |E|)`` for \
+ a graph with no follow-ons. The former follow-on case could be improved!
+ """
+ #Get the root jobs
+ roots = self.getRootJobs()
+ if len(roots) == 0:
+ raise JobGraphDeadlockException("Graph contains no root jobs due to cycles")
+
+ #Get implied edges
+ extraEdges = self._getImpliedEdges(roots)
+
+ #Check for directed cycles in the augmented graph
+ visited = set()
+ for root in roots:
+ root._checkJobGraphAcylicDFS([], visited, extraEdges)
+
+ def checkNewCheckpointsAreLeafVertices(self):
+ """
+ A checkpoint job is a job that is restarted if either it fails, or if any of \
+ its successors completely fails, exhausting their retries.
+
+ A job is a leaf if it has no successors.
+
+ A checkpoint job must be a leaf when initially added to the job graph. When its \
+ run method is invoked it can then create direct successors. This restriction is made
+ to simplify implementation.
+
+ :raises toil.job.JobGraphDeadlockException: if there exists a job being added to the graph for which \
+ checkpoint=True and which is not a leaf.
+ """
+ roots = self.getRootJobs() # Root jobs of the component; these are preexisting jobs in the graph
+
+ # All jobs in the component of the job graph containing self
+ jobs = set()
+ map(lambda x : x._dfs(jobs), roots)
+
+ # Check for each job for which checkpoint is true that it is a cut vertex or leaf
+ for y in filter(lambda x : x.checkpoint, jobs):
+ if y not in roots: # The roots are the preexisting jobs
+ if len(y._children) != 0 or len(y._followOns) != 0 or len(y._services) != 0:
+ raise JobGraphDeadlockException("New checkpoint job %s is not a leaf in the job graph" % y)
+
+ def defer(self, function, *args, **kwargs):
+ """
+ Register a deferred function, i.e. a callable that will be invoked after the current
+ attempt at running this job concludes. A job attempt is said to conclude when the job
+ function (or the :meth:`Job.run` method for class-based jobs) returns, raises an
+ exception, or when the process running it terminates abnormally. A deferred function will
+ be called on the node that attempted to run the job, even if a subsequent attempt is made
+ on another node. A deferred function should be idempotent because it may be called
+ multiple times on the same node or even in the same process. More than one deferred
+ function may be registered per job attempt by calling this method repeatedly with
+ different arguments. If the same function is registered twice with the same or different
+ arguments, it will be called twice per job attempt.
+
+ Examples for deferred functions are ones that handle cleanup of resources external to
+ Toil, like Docker containers, files outside the work directory, etc.
+
+ :param callable function: The function to be called after this job concludes.
+
+ :param list args: The arguments to the function
+
+ :param dict kwargs: The keyword arguments to the function
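+
+ A minimal sketch (assumes ``shutil`` and ``tempfile`` have been imported; the
+ helper ``removeScratchDir`` is illustrative)::
+
+ def removeScratchDir(path):
+ shutil.rmtree(path, ignore_errors=True)
+
+ scratch = tempfile.mkdtemp()
+ job.defer(removeScratchDir, scratch)
+ # removeScratchDir(scratch) is called once this job attempt concludes.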
+ """
+ require( self.fileStore is not None, 'A deferred function may only be registered with a '
+ 'job while that job is running.')
+ self.fileStore._registerDeferredFunction(DeferredFunction.create(function, *args, **kwargs))
+
+
+ ####################################################
+ #The following nested classes are used for
+ #creating jobtrees (Job.Runner),
+ #and defining a service (Job.Service)
+ ####################################################
+
+ class Runner(object):
+ """
+ Used to set up and run a Toil workflow.
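+
+ A minimal sketch of setting up and running a workflow (``HelloWorldJob`` is an
+ illustrative, user-defined :class:`toil.job.Job` subclass)::
+
+ options = Job.Runner.getDefaultOptions('./toilWorkflowRun')
+ options.logLevel = 'INFO'
+ Job.Runner.startToil(HelloWorldJob(), options)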
+ """
+ @staticmethod
+ def getDefaultArgumentParser():
+ """
+ Get argument parser with added toil workflow options.
+
+ :returns: The argument parser used by a toil workflow with added Toil options.
+ :rtype: :class:`argparse.ArgumentParser`
+ """
+ parser = ArgumentParser()
+ Job.Runner.addToilOptions(parser)
+ return parser
+
+ @staticmethod
+ def getDefaultOptions(jobStore):
+ """
+ Get default options for a toil workflow.
+
+ :param string jobStore: A string describing the jobStore \
+ for the workflow.
+ :returns: The options used by a toil workflow.
+ :rtype: argparse.ArgumentParser values object
+ """
+ parser = Job.Runner.getDefaultArgumentParser()
+ return parser.parse_args(args=[jobStore])
+
+ @staticmethod
+ def addToilOptions(parser):
+ """
+ Adds the default toil options to an :mod:`optparse` or :mod:`argparse`
+ parser object.
+
+ :param parser: Options object to add toil options to.
+ :type parser: optparse.OptionParser or argparse.ArgumentParser
+ """
+ addOptions(parser)
+
+ @staticmethod
+ def startToil(job, options):
+ """
+ Deprecated by toil.common.Toil.run. Runs the toil workflow using the given options
+ (see Job.Runner.getDefaultOptions and Job.Runner.addToilOptions) starting with this
+ job.
+ :param toil.job.Job job: root job of the workflow
+ :raises: toil.leader.FailedJobsException if at the end of the function \
+ there remain failed jobs.
+ :return: The return value of the root job's run function.
+ :rtype: Any
+ """
+ setLoggingFromOptions(options)
+ with Toil(options) as toil:
+ if not options.restart:
+ return toil.start(job)
+ else:
+ return toil.restart()
+
+ class Service(JobLikeObject):
+ """
+ Abstract class used to define the interface to a service.
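+
+ A minimal sketch of a concrete service (``EchoService`` is illustrative; a real
+ service would typically start and monitor an external process)::
+
+ class EchoService(Job.Service):
+ def start(self, job):
+ return 'connection-string' # pickleable handle handed to dependent jobs
+ def stop(self, job):
+ pass
+ def check(self):
+ return True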
+ """
+ __metaclass__ = ABCMeta
+ def __init__(self, memory=None, cores=None, disk=None, preemptable=None, unitName=None):
+ """
+ Memory, core and disk requirements are specified identically to those in \
+ :func:`toil.job.Job.__init__`.
+ """
+ requirements = {'memory': memory, 'cores': cores, 'disk': disk,
+ 'preemptable': preemptable}
+ super(Job.Service, self).__init__(requirements=requirements, unitName=unitName)
+ self._childServices = []
+ self._hasParent = False
+
+ @abstractmethod
+ def start(self, job):
+ """
+ Start the service.
+
+ :param toil.job.Job job: The underlying job that is being run. Can be used to register
+ deferred functions, or to access the fileStore for creating temporary files.
+
+ :returns: An object describing how to access the service. The object must be pickleable \
+ and will be used by jobs to access the service (see :func:`toil.job.Job.addService`).
+ """
+ pass
+
+ @abstractmethod
+ def stop(self, job):
+ """
+ Stops the service. Function can block until complete.
+
+ :param toil.job.Job job: The underlying job that is being run. Can be used to register
+ deferred functions, or to access the fileStore for creating temporary files.
+ """
+ pass
+
+ def check(self):
+ """
+ Checks the service is still running.
+
+ :raise RuntimeError: If the service failed, this will cause the service job to be labeled failed.
+ :returns: True if the service is still running, else False. If False then the service job will be terminated,
+ and considered a success. Important point: if the service job exits due to a failure, it should raise a
+ RuntimeError, not return False!
+ """
+ pass
+
+ def _addChild(self, service):
+ """
+ Add a child service to start up after this service has started. This should not be
+ called by the user; instead use :func:`toil.job.Job.addService` with the
+ ``parentService`` option.
+
+ :raises toil.job.JobException: If service has already been made the child of a job or another service.
+ :param toil.job.Job.Service service: Service to add as a "child" of this service
+ :return: a promise that will be replaced with the return value from \
+ :func:`toil.job.Job.Service.start` of service after the service has started.
+ :rtype: toil.job.Promise
+ """
+ if service._hasParent:
+ raise JobException("The service already has a parent service")
+ service._hasParent = True
+ jobService = ServiceJob(service)
+ self._childServices.append(jobService)
+ return jobService.rv()
+
+ ####################################################
+ #Private functions
+ ####################################################
+
+ def _addPredecessor(self, predecessorJob):
+ """
+ Adds a predecessor job to the set of predecessor jobs. Raises a \
+ RuntimeError if the job is already a predecessor.
+ """
+ if predecessorJob in self._directPredecessors:
+ raise RuntimeError("The given job is already a predecessor of this job")
+ self._directPredecessors.add(predecessorJob)
+
+ @classmethod
+ def _loadUserModule(cls, userModule):
+ """
+ Imports and returns the module object represented by the given module descriptor.
+
+ :type userModule: ModuleDescriptor
+ """
+ return userModule.load()
+
+ @classmethod
+ def _loadJob(cls, command, jobStore):
+ """
+ Unpickles a :class:`toil.job.Job` instance by decoding command.
+
+ The command is a reference to a jobStoreFileID containing the \
+ pickle file for the job and a list of modules which must be imported so that \
+ the Job can be successfully unpickled. \
+ See :func:`toil.job.Job._serialiseFirstJob` and \
+ :func:`toil.job.Job._makeJobGraphs` to see precisely how the Job is encoded \
+ in the command.
+
+ :param string command: encoding of the job in the job store.
+ :param toil.jobStores.abstractJobStore.AbstractJobStore jobStore: The job store.
+ :returns: The job referenced by the command.
+ :rtype: toil.job.Job
+ """
+ commandTokens = command.split()
+ assert "_toil" == commandTokens[0]
+ userModule = ModuleDescriptor.fromCommand(commandTokens[2:])
+ logger.debug('Loading user module %s.', userModule)
+ userModule = cls._loadUserModule(userModule)
+ pickleFile = commandTokens[1]
+ if pickleFile == "firstJob":
+ openFileStream = jobStore.readSharedFileStream(pickleFile)
+ else:
+ openFileStream = jobStore.readFileStream(pickleFile)
+ with openFileStream as fileHandle:
+ return cls._unpickle(userModule, fileHandle, jobStore.config)
+
+
+ @classmethod
+ def _unpickle(cls, userModule, fileHandle, config):
+ """
+ Unpickles an object graph from the given file handle while loading symbols \
+ referencing the __main__ module from the given userModule instead.
+
+ :param userModule:
+ :param fileHandle:
+ :returns:
+ """
+ unpickler = cPickle.Unpickler(fileHandle)
+
+ def filter_main(module_name, class_name):
+ if module_name == '__main__':
+ logger.debug('Getting %s from user module __main__ (%s).', class_name, userModule)
+ return getattr(userModule, class_name)
+ else:
+ logger.debug('Getting %s from module %s.', class_name, module_name)
+ return getattr(importlib.import_module(module_name), class_name)
+
+ unpickler.find_global = filter_main
+ runnable = unpickler.load()
+ assert isinstance(runnable, JobLikeObject)
+ runnable._config = config
+ return runnable
+
+ def getUserScript(self):
+ return self.userModule
+
+ def _fulfillPromises(self, returnValues, jobStore):
+ """
+ Sets the values for promises using the return values from this job's run() function.
+ """
+ for path, promiseFileStoreIDs in iteritems(self._rvs):
+ if not path:
+ # Note that it's possible for returnValues to be a promise, not an actual return
+ # value. This is the case if the job returns a promise from another job. In
+ # either case, we just pass it on.
+ promisedValue = returnValues
+ else:
+ # If there is a path ...
+ if isinstance(returnValues, Promise):
+ # ... and the value itself is a Promise, we need to create a new, narrower
+ # promise and pass it on.
+ promisedValue = Promise(returnValues.job, path)
+ else:
+ # Otherwise, we just select the desired component of the return value.
+ promisedValue = returnValues
+ for index in path:
+ promisedValue = promisedValue[index]
+ for promiseFileStoreID in promiseFileStoreIDs:
+ # File may be gone if the job is a service being re-run and the accessing job is
+ # already complete.
+ if jobStore.fileExists(promiseFileStoreID):
+ with jobStore.updateFileStream(promiseFileStoreID) as fileHandle:
+ cPickle.dump(promisedValue, fileHandle, cPickle.HIGHEST_PROTOCOL)
+
+ # Functions associated with Job.checkJobGraphAcyclic to establish that the job graph does not
+ # contain any cycles of dependencies:
+
+ def _dfs(self, visited):
+ """
+ Adds the job and all jobs reachable on a directed path from the current node to the given set.
+ """
+ if self not in visited:
+ visited.add(self)
+ for successor in self._children + self._followOns:
+ successor._dfs(visited)
+
+ def _checkJobGraphAcylicDFS(self, stack, visited, extraEdges):
+ """
+ DFS traversal to detect cycles in augmented job graph.
+ """
+ if self not in visited:
+ visited.add(self)
+ stack.append(self)
+ for successor in self._children + self._followOns + extraEdges[self]:
+ successor._checkJobGraphAcylicDFS(stack, visited, extraEdges)
+ assert stack.pop() == self
+ if self in stack:
+ stack.append(self)
+ raise JobGraphDeadlockException("A cycle of job dependencies has been detected '%s'" % stack)
+
+ @staticmethod
+ def _getImpliedEdges(roots):
+ """
+ Gets the set of implied edges. See Job.checkJobGraphAcylic
+ """
+ #Get nodes in job graph
+ nodes = set()
+ for root in roots:
+ root._dfs(nodes)
+
+ ##For each follow-on edge calculate the extra implied edges
+ #Adjacency list of implied edges, i.e. map of jobs to lists of jobs
+ #connected by an implied edge
+ extraEdges = dict(map(lambda n : (n, []), nodes))
+ for job in nodes:
+ if len(job._followOns) > 0:
+ #Get set of jobs connected by a directed path to job, starting
+ #with a child edge
+ reachable = set()
+ for child in job._children:
+ child._dfs(reachable)
+ #Now add extra edges
+ for descendant in reachable:
+ extraEdges[descendant] += job._followOns[:]
+ return extraEdges
+
+ ####################################################
+ #The following functions are used to serialise
+ #a job graph to the jobStore
+ ####################################################
+
+ def _createEmptyJobGraphForJob(self, jobStore, command=None, predecessorNumber=0):
+ """
+ Create an empty job for the job.
+ """
+ # set _config to determine user determined default values for resource requirements
+ self._config = jobStore.config
+ return jobStore.create(JobNode.fromJob(self, command=command,
+ predecessorNumber=predecessorNumber))
+
+ def _makeJobGraphs(self, jobGraph, jobStore):
+ """
+ Creates a jobGraph for each job in the job graph, recursively.
+ """
+ jobsToJobGraphs = {self:jobGraph}
+ for successors in (self._followOns, self._children):
+ jobs = map(lambda successor:
+ successor._makeJobGraphs2(jobStore, jobsToJobGraphs), successors)
+ jobGraph.stack.append(jobs)
+ return jobsToJobGraphs
+
+ def _makeJobGraphs2(self, jobStore, jobsToJobGraphs):
+ #Make the jobGraph for the job, if necessary
+ if self not in jobsToJobGraphs:
+ jobGraph = self._createEmptyJobGraphForJob(jobStore, predecessorNumber=len(self._directPredecessors))
+ jobsToJobGraphs[self] = jobGraph
+ #Add followOns/children to be run after the current job.
+ for successors in (self._followOns, self._children):
+ jobs = map(lambda successor:
+ successor._makeJobGraphs2(jobStore, jobsToJobGraphs), successors)
+ jobGraph.stack.append(jobs)
+ else:
+ jobGraph = jobsToJobGraphs[self]
+ #The return value is a JobNode describing this job (its jobStoreID and
+ #resource requirements); it is stored within the predecessor's
+ #jobGraph.stack and records what is needed to run the job once its
+ #predecessors have completed.
+ return JobNode.fromJobGraph(jobGraph)
+
+ def getTopologicalOrderingOfJobs(self):
+ """
+ :returns: a list of jobs such that for all pairs of indices i, j for which i < j, \
+ the job at index i can be run before the job at index j.
+ :rtype: list
+ """
+ ordering = []
+ visited = set()
+ def getRunOrder(job):
+ #Do not add the job to the ordering until all its predecessors have been
+ #added to the ordering
+ for p in job._directPredecessors:
+ if p not in visited:
+ return
+ if job not in visited:
+ visited.add(job)
+ ordering.append(job)
+ map(getRunOrder, job._children + job._followOns)
+ getRunOrder(self)
+ return ordering
+
+ def _serialiseJob(self, jobStore, jobsToJobGraphs, rootJobGraph):
+ """
+ Pickle a job and its jobGraph to disk.
+ """
+ # Pickle the job so that its run method can be run at a later time.
+ # Drop out the children/followOns/predecessors/services - which are
+ # all recorded within the jobStore and do not need to be stored within
+ # the job
+ self._children, self._followOns, self._services = [], [], []
+ self._directPredecessors, self._promiseJobStore = set(), None
+ # The pickled job is "run" as the command of the job, see worker
+ # for the mechanism which unpickles the job and executes the Job.run
+ # method.
+ with jobStore.writeFileStream(rootJobGraph.jobStoreID) as (fileHandle, fileStoreID):
+ cPickle.dump(self, fileHandle, cPickle.HIGHEST_PROTOCOL)
+ # Note that getUserScript() may have been overridden. This is intended. If we used
+ # self.userModule directly, we'd be getting a reference to job.py if the job was
+ # specified as a function (as opposed to a class) since that is where FunctionWrappingJob
+ # is defined. What we really want is the module that was loaded as __main__,
+ # and FunctionWrappingJob overrides getUserScript() to give us just that. Only then can
+ # filter_main() in _unpickle( ) do its job of resolving any user-defined type or function.
+ userScript = self.getUserScript().globalize()
+ jobsToJobGraphs[self].command = ' '.join(('_toil', fileStoreID) + userScript.toCommand())
+ #Update the status of the jobGraph on disk
+ jobStore.update(jobsToJobGraphs[self])
+
+ def _serialiseServices(self, jobStore, jobGraph, rootJobGraph):
+ """
+ Serialises the services for a job.
+ """
+ def processService(serviceJob, depth):
+ # Extend the depth of the services if necessary
+ if depth == len(jobGraph.services):
+ jobGraph.services.append([])
+
+ # Recursively call to process child services
+ for childServiceJob in serviceJob.service._childServices:
+ processService(childServiceJob, depth+1)
+
+ # Make a job wrapper
+ serviceJobGraph = serviceJob._createEmptyJobGraphForJob(jobStore, predecessorNumber=1)
+
+ # Create the start and terminate flags
+ serviceJobGraph.startJobStoreID = jobStore.getEmptyFileStoreID()
+ serviceJobGraph.terminateJobStoreID = jobStore.getEmptyFileStoreID()
+ serviceJobGraph.errorJobStoreID = jobStore.getEmptyFileStoreID()
+ assert jobStore.fileExists(serviceJobGraph.startJobStoreID)
+ assert jobStore.fileExists(serviceJobGraph.terminateJobStoreID)
+ assert jobStore.fileExists(serviceJobGraph.errorJobStoreID)
+
+ # Create the service job tuple
+ j = ServiceJobNode(jobStoreID=serviceJobGraph.jobStoreID,
+ memory=serviceJobGraph.memory, cores=serviceJobGraph.cores,
+ disk=serviceJobGraph.disk, startJobStoreID=serviceJobGraph.startJobStoreID,
+ terminateJobStoreID=serviceJobGraph.terminateJobStoreID,
+ errorJobStoreID=serviceJobGraph.errorJobStoreID,
+ jobName=serviceJobGraph.jobName, unitName=serviceJobGraph.unitName,
+ command=serviceJobGraph.command,
+ predecessorNumber=serviceJobGraph.predecessorNumber)
+
+ # Add the service job tuple to the list of services to run
+ jobGraph.services[depth].append(j)
+
+ # Break the links between the services to stop them being serialised together
+ #childServices = serviceJob.service._childServices
+ serviceJob.service._childServices = None
+ assert serviceJob._services == []
+ #service = serviceJob.service
+
+ # Pickle the job
+ serviceJob.pickledService = cPickle.dumps(serviceJob.service)
+ serviceJob.service = None
+
+ # Serialise the service job and job wrapper
+ serviceJob._serialiseJob(jobStore, { serviceJob:serviceJobGraph }, rootJobGraph)
+
+ # Restore values
+ #serviceJob.service = service
+ #serviceJob.service._childServices = childServices
+
+ for serviceJob in self._services:
+ processService(serviceJob, 0)
+
+ self._services = []
+
+ def _serialiseJobGraph(self, jobGraph, jobStore, returnValues, firstJob):
+ """
+ Pickle the graph of jobs in the jobStore. The graph is not fully serialised \
+ until the jobGraph itself is written to disk; this is not performed by this \
+ function because of the need to coordinate this operation with other updates. \
+ """
+ #Check if the job graph has created
+ #any cycles of dependencies or has multiple roots
+ self.checkJobGraphForDeadlocks()
+
+ #Create the jobGraphs for followOns/children
+ jobsToJobGraphs = self._makeJobGraphs(jobGraph, jobStore)
+ #Get an ordering on the jobs which we use for pickling the jobs in the
+ #correct order to ensure the promises are properly established
+ ordering = self.getTopologicalOrderingOfJobs()
+ assert len(ordering) == len(jobsToJobGraphs)
+
+ # Temporarily set the jobStore locators for the promise call back functions
+ for job in ordering:
+ job.prepareForPromiseRegistration(jobStore)
+ def setForServices(serviceJob):
+ serviceJob.prepareForPromiseRegistration(jobStore)
+ for childServiceJob in serviceJob.service._childServices:
+ setForServices(childServiceJob)
+ for serviceJob in job._services:
+ setForServices(serviceJob)
+
+ ordering.reverse()
+ assert self == ordering[-1]
+ if firstJob:
+ #If the first job we serialise all the jobs, including the root job
+ for job in ordering:
+ # Pickle the services for the job
+ job._serialiseServices(jobStore, jobsToJobGraphs[job], jobGraph)
+ # Now pickle the job
+ job._serialiseJob(jobStore, jobsToJobGraphs, jobGraph)
+ else:
+ #We store the return values at this point, because if a return value
+ #is a promise from another job, we need to register the promise
+ #before we serialise the other jobs
+ self._fulfillPromises(returnValues, jobStore)
+ #Pickle the non-root jobs
+ for job in ordering[:-1]:
+ # Pickle the services for the job
+ job._serialiseServices(jobStore, jobsToJobGraphs[job], jobGraph)
+ # Pickle the job itself
+ job._serialiseJob(jobStore, jobsToJobGraphs, jobGraph)
+ # Pickle any services for the job
+ self._serialiseServices(jobStore, jobGraph, jobGraph)
+
+ def _serialiseFirstJob(self, jobStore):
+ """
+ Serialises the root job. Returns the wrapping job.
+
+ :param toil.jobStores.abstractJobStore.AbstractJobStore jobStore:
+ """
+ # Create first jobGraph
+ jobGraph = self._createEmptyJobGraphForJob(jobStore=jobStore, predecessorNumber=0)
+ # Write the graph of jobs to disk
+ self._serialiseJobGraph(jobGraph, jobStore, None, True)
+ jobStore.update(jobGraph)
+ # Store the name of the first job in a file in case of restart. Up to this point the
+ # root job is not recoverable. FIXME: "root job" or "first job", which one is it?
+ jobStore.setRootJob(jobGraph.jobStoreID)
+ return jobGraph
+
+ def _serialiseExistingJob(self, jobGraph, jobStore, returnValues):
+ """
+ Serialise an existing job.
+ """
+ self._serialiseJobGraph(jobGraph, jobStore, returnValues, False)
+ #Drop the completed command, if not dropped already
+ jobGraph.command = None
+ #Merge any children (follow-ons) created in the initial serialisation
+ #with children (follow-ons) created in the subsequent scale-up.
+ assert len(jobGraph.stack) >= 4
+ combinedChildren = jobGraph.stack[-1] + jobGraph.stack[-3]
+ combinedFollowOns = jobGraph.stack[-2] + jobGraph.stack[-4]
+ jobGraph.stack = jobGraph.stack[:-4]
+ if len(combinedFollowOns) > 0:
+ jobGraph.stack.append(combinedFollowOns)
+ if len(combinedChildren) > 0:
+ jobGraph.stack.append(combinedChildren)
+
+ ####################################################
+ #Function which worker calls to ultimately invoke
+ #a jobs Job.run method, and then handle created
+ #children/followOn jobs
+ ####################################################
+
+ def _run(self, jobGraph, fileStore):
+ return self.run(fileStore)
+
+ @contextmanager
+ def _executor(self, jobGraph, stats, fileStore):
+ """
+ This is the core wrapping method for running the job within a worker. It sets up the stats
+ and logging before yielding. After completion of the body, the function finishes up the
+ stats and logging and starts the async update process for the job.
+ """
+ if stats is not None:
+ startTime = time.time()
+ startClock = getTotalCpuTime()
+ baseDir = os.getcwd()
+
+ yield
+
+ # If the job is not a checkpoint job, add the promise files to delete
+ # to the list of jobStoreFileIDs to delete
+ if not self.checkpoint:
+ for jobStoreFileID in Promise.filesToDelete:
+ fileStore.deleteGlobalFile(jobStoreFileID)
+ else:
+ # Else copy them to the job wrapper to delete later
+ jobGraph.checkpointFilesToDelete = list(Promise.filesToDelete)
+ Promise.filesToDelete.clear()
+ # Now indicate the asynchronous update of the job can happen
+ fileStore._updateJobWhenDone()
+ # Change dir back to cwd dir, if changed by job (this is a safety issue)
+ if os.getcwd() != baseDir:
+ os.chdir(baseDir)
+ # Finish up the stats
+ if stats is not None:
+ totalCpuTime, totalMemoryUsage = getTotalCpuTimeAndMemoryUsage()
+ stats.jobs.append(
+ Expando(
+ time=str(time.time() - startTime),
+ clock=str(totalCpuTime - startClock),
+ class_name=self._jobName(),
+ memory=str(totalMemoryUsage)
+ )
+ )
+
+ def _runner(self, jobGraph, jobStore, fileStore):
+ """
+ This method actually runs the job, and serialises the next jobs.
+
+ :param class jobGraph: Instance of a jobGraph object
+ :param class jobStore: Instance of the job store
+ :param toil.fileStore.FileStore fileStore: Instance of a cached or uncached
+ filestore
+ :return:
+ """
+ # Run the job
+ returnValues = self._run(jobGraph, fileStore)
+ # Serialize the new jobs defined by the run method to the jobStore
+ self._serialiseExistingJob(jobGraph, jobStore, returnValues)
+
+ def _jobName(self):
+ """
+ :rtype: string; used as an identifier of the job class in the stats report.
+ """
+ return self.__class__.__name__
+
+
+class JobException( Exception ):
+ """
+ General job exception.
+ """
+ def __init__( self, message ):
+ super( JobException, self ).__init__( message )
+
+
+class JobGraphDeadlockException( JobException ):
+ """
+ An exception raised in the event that a workflow contains an unresolvable \
+ dependency, such as a cycle. See :func:`toil.job.Job.checkJobGraphForDeadlocks`.
+ """
+ def __init__( self, string ):
+ super( JobGraphDeadlockException, self ).__init__( string )
+
+
+class FunctionWrappingJob(Job):
+ """
+ Job used to wrap a function. In its `run` method the wrapped function is called.
+ """
+ def __init__(self, userFunction, *args, **kwargs):
+ """
+ :param callable userFunction: The function to wrap. It will be called with ``*args`` and
+ ``**kwargs`` as arguments.
+
+ The keywords ``memory``, ``cores``, ``disk``, ``preemptable`` and ``checkpoint`` are
+ reserved keyword arguments that, if specified, will be used to determine the resources
+ required for the job, as in :func:`toil.job.Job.__init__`. If they are keyword arguments to
+ the function, their default values will be extracted from the function definition, but may be
+ overridden by the user (as you would expect).
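+
+ For example (``align`` is an illustrative function; its keyword defaults act as the
+ resource requirements unless overridden)::
+
+ def align(readsFile, memory='4G', cores=4, disk='10G'):
+ pass
+
+ j1 = Job.wrapFn(align, 'reads.fq') # uses 4G / 4 cores / 10G from the defaults
+ j2 = Job.wrapFn(align, 'reads.fq', memory='8G') # explicit value overrides the default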
+ """
+ # Use the user-specified requirements, if specified, else grab the default argument
+ # from the function, if specified, else default to None
+ argSpec = inspect.getargspec(userFunction)
+ if argSpec.defaults is None:
+ argDict = {}
+ else:
+ argDict = dict(zip(argSpec.args[-len(argSpec.defaults):], argSpec.defaults))
+
+ def resolve(key, default=None, dehumanize=False):
+ try:
+ # First, try constructor arguments, ...
+ value = kwargs.pop(key)
+ except KeyError:
+ try:
+ # ..., then try default value for function keyword arguments, ...
+ value = argDict[key]
+ except KeyError:
+ # ... and finally fall back to a default value.
+ value = default
+ # Optionally, convert strings with metric or binary prefixes.
+ if dehumanize and isinstance(value, string_types):
+ value = human2bytes(value)
+ return value
+
+ Job.__init__(self,
+ memory=resolve('memory', dehumanize=True),
+ cores=resolve('cores', dehumanize=True),
+ disk=resolve('disk', dehumanize=True),
+ preemptable=resolve('preemptable'),
+ checkpoint=resolve('checkpoint', default=False),
+ unitName=resolve('name', default=None))
+
+ self.userFunctionModule = ModuleDescriptor.forModule(userFunction.__module__).globalize()
+ self.userFunctionName = str(userFunction.__name__)
+ self.jobName = self.userFunctionName
+ self._args = args
+ self._kwargs = kwargs
+
+ def _getUserFunction(self):
+ logger.debug('Loading user function %s from module %s.',
+ self.userFunctionName,
+ self.userFunctionModule)
+ userFunctionModule = self._loadUserModule(self.userFunctionModule)
+ return getattr(userFunctionModule, self.userFunctionName)
+
+ def run(self,fileStore):
+ userFunction = self._getUserFunction( )
+ return userFunction(*self._args, **self._kwargs)
+
+ def getUserScript(self):
+ return self.userFunctionModule
+
+ def _jobName(self):
+ return ".".join((self.__class__.__name__,self.userFunctionModule.name,self.userFunctionName))
+
+
+class JobFunctionWrappingJob(FunctionWrappingJob):
+ """
+ A job function is a function whose first argument is a :class:`job.Job` \
+ instance that is the wrapping job for the function. This can be used to \
+ add successor jobs for the function and perform all the functions the \
+ :class:`job.Job` class provides.
+
+ To enable the job function to get access to the :class:`toil.fileStore.FileStore` \
+ instance (see :func:`toil.job.Job.run`), it is made an attribute of the wrapping job, \
+ called ``fileStore``.
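+
+ A minimal sketch of a job function using the file store (``countLines`` and
+ ``fileID`` are illustrative)::
+
+ def countLines(job, fileID):
+ localPath = job.fileStore.readGlobalFile(fileID)
+ with open(localPath) as f:
+ return sum(1 for _ in f)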
+ """
+ def run(self, fileStore):
+ userFunction = self._getUserFunction()
+ self.fileStore = fileStore
+ rValue = userFunction(*((self,) + tuple(self._args)), **self._kwargs)
+ return rValue
+
+
+class PromisedRequirementFunctionWrappingJob(FunctionWrappingJob):
+ """
+ Handles dynamic resource allocation using :class:`toil.job.Promise` instances.
+ Spawns child function using parent function parameters and fulfilled promised
+ resource requirements.
+ """
+ def __init__(self, userFunction, *args, **kwargs):
+ self._promisedKwargs = kwargs.copy()
+ # Replace resource requirements in intermediate job with small values.
+ kwargs.update(dict(disk='1M', memory='32M', cores=0.1))
+ super(PromisedRequirementFunctionWrappingJob, self).__init__(userFunction, *args, **kwargs)
+
+ @classmethod
+ def create(cls, userFunction, *args, **kwargs):
+ """
+ Creates an encapsulated Toil job function with unfulfilled promised resource
+ requirements. After the promises are fulfilled, a child job function is created
+ using updated resource values. The subgraph is encapsulated to ensure that this
+ child job function is run before other children in the workflow. Otherwise, a
+ different child may try to use an unresolved promise return value from the parent.
+ """
+ return EncapsulatedJob(cls(userFunction, *args, **kwargs))
+
+ def run(self, fileStore):
+ # Assumes promises are fulfilled when parent job is run
+ self.evaluatePromisedRequirements()
+ userFunction = self._getUserFunction()
+ return self.addChildFn(userFunction, *self._args, **self._promisedKwargs).rv()
+
+ def evaluatePromisedRequirements(self):
+ requirements = ["disk", "memory", "cores"]
+ # Fulfill resource requirement promises
+ for requirement in requirements:
+ try:
+ if isinstance(self._promisedKwargs[requirement], PromisedRequirement):
+ self._promisedKwargs[requirement] = self._promisedKwargs[requirement].getValue()
+ except KeyError:
+ pass
+
+
+class PromisedRequirementJobFunctionWrappingJob(PromisedRequirementFunctionWrappingJob):
+ """
+ Handles dynamic resource allocation for job functions.
+ See :class:`toil.job.JobFunctionWrappingJob`
+ """
+
+ def run(self, fileStore):
+ self.evaluatePromisedRequirements()
+ userFunction = self._getUserFunction()
+ return self.addChildJobFn(userFunction, *self._args, **self._promisedKwargs).rv()
+
+
+class EncapsulatedJob(Job):
+ """
+ A convenience Job class used to make a job subgraph appear to be a single job.
+
+ Let A be the root job of a job subgraph and B be another job we'd like to run after A
+ and all its successors have completed. For this, use encapsulate::
+
+ # Job A and subgraph, Job B
+ A, B = A(), B()
+ A' = A.encapsulate()
+ A'.addChild(B)
+ # B will run after A and all its successors have completed, A and its subgraph of
+ # successors in effect appear to be just one job.
+
+ The return value of an encapsulated job (as accessed by the :func:`toil.job.Job.rv` function)
+ is the return value of the root job, e.g. A().encapsulate().rv() and A().rv() will resolve to
+ the same value after A or A.encapsulate() has been run.
+ """
+ def __init__(self, job):
+ """
+ :param toil.job.Job job: the job to encapsulate.
+ """
+ # Giving the root of the subgraph the same resources as the first job in the subgraph.
+ Job.__init__(self, **job._requirements)
+ self.encapsulatedJob = job
+ Job.addChild(self, job)
+ # Use small resource requirements for dummy Job instance.
+ self.encapsulatedFollowOn = Job(disk='1M', memory='32M', cores=0.1)
+ Job.addFollowOn(self, self.encapsulatedFollowOn)
+
+ def addChild(self, childJob):
+ return Job.addChild(self.encapsulatedFollowOn, childJob)
+
+ def addService(self, service, parentService=None):
+ return Job.addService(self.encapsulatedFollowOn, service, parentService=parentService)
+
+ def addFollowOn(self, followOnJob):
+ return Job.addFollowOn(self.encapsulatedFollowOn, followOnJob)
+
+ def rv(self, *path):
+ return self.encapsulatedJob.rv(*path)
+
+ def prepareForPromiseRegistration(self, jobStore):
+ super(EncapsulatedJob, self).prepareForPromiseRegistration(jobStore)
+ self.encapsulatedJob.prepareForPromiseRegistration(jobStore)
+
+ def getUserScript(self):
+ return self.encapsulatedJob.getUserScript()
+
+
+class ServiceJobNode(JobNode):
+ def __init__(self, jobStoreID, memory, cores, disk, startJobStoreID, terminateJobStoreID,
+ errorJobStoreID, unitName, jobName, command, predecessorNumber):
+ requirements = dict(memory=memory, cores=cores, disk=disk, preemptable=False)
+ super(ServiceJobNode, self).__init__(unitName=unitName, jobName=jobName,
+ requirements=requirements,
+ jobStoreID=jobStoreID,
+ command=command,
+ predecessorNumber=predecessorNumber)
+ self.startJobStoreID = startJobStoreID
+ self.terminateJobStoreID = terminateJobStoreID
+ self.errorJobStoreID = errorJobStoreID
+
+
+class ServiceJob(Job):
+ """
+ Job used to wrap a :class:`toil.job.Job.Service` instance.
+ """
+ def __init__(self, service):
+ """
+ This constructor should not be called by a user.
+
+ :param service: The service to wrap in a job.
+ :type service: toil.job.Job.Service
+ """
+ Job.__init__(self, **service._requirements)
+ # service.__module__ is the module defining the class service is an instance of.
+ self.serviceModule = ModuleDescriptor.forModule(service.__module__).globalize()
+
+ #The service to run - this will be replaced before serialization with a pickled version
+ self.service = service
+ self.pickledService = None
+ self.jobName = service.jobName
+ # This references the parent job wrapper. It is initialised just before
+ # the job is run. It is used to access the start and terminate flags.
+ self.jobGraph = None
+
+ def run(self, fileStore):
+
+ # we need access to the filestore from underneath the service job
+ self.fileStore = fileStore
+
+ # Unpickle the service
+ logger.debug('Loading service module %s.', self.serviceModule)
+ userModule = self._loadUserModule(self.serviceModule)
+ service = self._unpickle( userModule, BytesIO( self.pickledService ), fileStore.jobStore.config )
+ #Start the service
+ startCredentials = service.start(self)
+ try:
+ #The start credentials must be communicated to processes connecting to
+ #the service; to do this while the run method is running, we
+ #cheat and set the return value promise within the run method
+ self._fulfillPromises(startCredentials, fileStore.jobStore)
+ self._rvs = {} # Set this to avoid the return values being updated after the
+ #run method has completed!
+
+ #Now flag that the service is running so jobs can connect to it
+ logger.debug("Removing the start jobStoreID to indicate the establishment of the service")
+ assert self.jobGraph.startJobStoreID is not None
+ if fileStore.jobStore.fileExists(self.jobGraph.startJobStoreID):
+ fileStore.jobStore.deleteFile(self.jobGraph.startJobStoreID)
+ assert not fileStore.jobStore.fileExists(self.jobGraph.startJobStoreID)
+
+ #Now block until we are told to stop, which is indicated by the removal
+ #of a file
+ assert self.jobGraph.terminateJobStoreID is not None
+ while True:
+ # Check for the terminate signal
+ if not fileStore.jobStore.fileExists(self.jobGraph.terminateJobStoreID):
+ logger.debug("Detected that the terminate jobStoreID has been removed so exiting")
+ if not fileStore.jobStore.fileExists(self.jobGraph.errorJobStoreID):
+ raise RuntimeError("Detected the error jobStoreID has been removed so exiting with an error")
+ break
+
+ # Check the service's status and exit if failed or complete
+ try:
+ if not service.check():
+ logger.debug("The service has finished okay, exiting")
+ break
+ except RuntimeError:
+ logger.debug("Detected termination of the service")
+ raise
+
+ time.sleep(fileStore.jobStore.config.servicePollingInterval) #Avoid excessive polling
+
+ # Remove link to the jobGraph
+ self.jobGraph = None
+
+ logger.debug("Service is done")
+ finally:
+ # The stop function is always called
+ service.stop(self)
+
+ def _run(self, jobGraph, fileStore):
+ # Set the jobGraph for the job
+ self.jobGraph = jobGraph
+ #Run the job
+ returnValues = self.run(fileStore)
+ assert jobGraph.stack == []
+ assert jobGraph.services == []
+ # Unset the jobGraph for the job
+ self.jobGraph = None
+ # Set the stack to mimic what would be expected for a non-service job (this is a hack)
+ jobGraph.stack = [[], []]
+ return returnValues
+
+ def getUserScript(self):
+ return self.serviceModule
+
+
+class Promise(object):
+ """
+ References a return value from a :meth:`toil.job.Job.run` or
+ :meth:`toil.job.Job.Service.start` method as a *promise* before the method itself is run.
+
+ Let T be a job. Instances of :class:`Promise` (termed a *promise*) are returned by T.rv(),
+ which is used to reference the return value of T's run function. When the promise is passed
+ to the constructor (or as an argument to a wrapped function) of a different, successor job
+ the promise will be replaced by the actual referenced return value. This mechanism allows a
+ return value from one job's run method to be used as an input argument to another job
+ before the former job's run function has been executed.
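+
+ A minimal sketch (``ProducerJob`` and ``ConsumerJob`` are illustrative
+ :class:`toil.job.Job` subclasses whose constructors accept the shown arguments)::
+
+ a = ProducerJob()
+ b = ConsumerJob(a.rv()) # a.rv() is a Promise at this point
+ a.addChild(b)
+ # By the time b.run() executes, the promise has been replaced by
+ # ProducerJob.run()'s return value.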
+ """
+ _jobstore = None
+ """
+ Caches the job store instance used during unpickling to prevent it from being instantiated
+ for each promise
+
+ :type: toil.jobStores.abstractJobStore.AbstractJobStore
+ """
+
+ filesToDelete = set()
+ """
+ A set of IDs of files containing promised values, collected so that they can be deleted
+ once we know we won't need them anymore
+ """
+ def __init__(self, job, path):
+ """
+ :param Job job: the job whose return value this promise references
+ :param path: see :meth:`Job.rv`
+ """
+ self.job = job
+ self.path = path
+
+ def __reduce__(self):
+ """
+ Called during pickling when a promise (an instance of this class) is about to be
+ pickled. Returns the Promise class and construction arguments that will be evaluated
+ during unpickling, namely the job store coordinates of a file that will hold the promised
+ return value. By the time the promise is about to be unpickled, that file should be
+ populated.
+ """
+ # The allocation of the file in the job store is intentionally lazy, we only allocate an
+ # empty file in the job store if the promise is actually being pickled. This is done so
+ # that we do not allocate files for promises that are never used.
+ jobStoreLocator, jobStoreFileID = self.job.registerPromise(self.path)
+ # Returning a class object here causes the pickling machinery to attempt to instantiate
+ the class. We will catch that with __new__ and return the actual promised value instead.
+ return self.__class__, (jobStoreLocator, jobStoreFileID)
+
+ @staticmethod
+ def __new__(cls, *args):
+ assert len(args) == 2
+ if isinstance(args[0], Job):
+ # Regular instantiation when promise is created, before it is being pickled
+ return super(Promise, cls).__new__(cls, *args)
+ else:
+ # Attempted instantiation during unpickling, return promised value instead
+ return cls._resolve(*args)
+
+ @classmethod
+ def _resolve(cls, jobStoreLocator, jobStoreFileID):
+ # Initialize the cached job store if it was never initialized in the current process or
+ # if it belongs to a different workflow that was run earlier in the current process.
+ if cls._jobstore is None or cls._jobstore.config.jobStore != jobStoreLocator:
+ cls._jobstore = Toil.resumeJobStore(jobStoreLocator)
+ cls.filesToDelete.add(jobStoreFileID)
+ with cls._jobstore.readFileStream(jobStoreFileID) as fileHandle:
+ # If this doesn't work then the file containing the promise may not exist or be
+ # corrupted
+ value = cPickle.load(fileHandle)
+ return value
+
+
+class PromisedRequirement(object):
+ def __init__(self, valueOrCallable, *args):
+ """
+ Class for dynamically allocating job function resource requirements involving
+ :class:`toil.job.Promise` instances.
+
+ Use when resource requirements depend on the return value of a parent function.
+ PromisedRequirements can be modified by passing a function that takes the
+ :class:`Promise` as input.
+
+ For example, let f, g, and h be functions. Then a Toil workflow can be
+ defined as follows::
+
+ A = Job.wrapFn(f)
+ B = A.addChildFn(g, cores=PromisedRequirement(A.rv()))
+ C = B.addChildFn(h, cores=PromisedRequirement(lambda x: 2*x, B.rv()))
+
+ :param valueOrCallable: A single Promise instance or a function that
+ takes \*args as input parameters.
+ :param int|Promise \*args: variable length argument list
+ """
+ if hasattr(valueOrCallable, '__call__'):
+ assert len(args) != 0, 'Need parameters for PromisedRequirement function.'
+ func = valueOrCallable
+ else:
+ assert len(args) == 0, 'Define a PromisedRequirement function to handle multiple arguments.'
+ func = lambda x: x
+ args = [valueOrCallable]
+
+ self._func = dill.dumps(func)
+ self._args = list(args)
+
+ def getValue(self):
+ """
+ Returns the fulfilled PromisedRequirement value.
+ """
+ func = dill.loads(self._func)
+ return func(*self._args)
+
+ @staticmethod
+ def convertPromises(kwargs):
+ """
+ Returns True if any reserved resource keyword is a Promise or
+ PromisedRequirement instance. Converts Promise instances
+ to PromisedRequirements.
+
+ :param kwargs: function keyword arguments
+ :return: bool
+ """
+ requirements = ["disk", "memory", "cores"]
+ foundPromisedRequirement = False
+ for r in requirements:
+ if isinstance(kwargs.get(r), Promise):
+ kwargs[r] = PromisedRequirement(kwargs[r])
+ foundPromisedRequirement = True
+ elif isinstance(kwargs.get(r), PromisedRequirement):
+ foundPromisedRequirement = True
+ return foundPromisedRequirement
diff --git a/src/toil/jobGraph.py b/src/toil/jobGraph.py
new file mode 100644
index 0000000..0c47d60
--- /dev/null
+++ b/src/toil/jobGraph.py
@@ -0,0 +1,148 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+import logging
+
+from toil.job import JobNode
+
+logger = logging.getLogger( __name__ )
+
+
+class JobGraph(JobNode):
+ """
+ A class encapsulating the minimal state of a Toil job. Instances of this class are persisted
+ in the job store and held in memory by the master. The actual state of job objects in user
+ scripts is persisted separately since it may be much bigger than the state managed by this
+ class and should therefore only be held in memory for brief periods of time.
+ """
+ def __init__(self, command, memory, cores, disk, unitName, jobName, preemptable,
+ jobStoreID, remainingRetryCount, predecessorNumber,
+ filesToDelete=None, predecessorsFinished=None,
+ stack=None, services=None,
+ startJobStoreID=None, terminateJobStoreID=None,
+ errorJobStoreID=None,
+ logJobStoreFileID=None,
+ checkpoint=None,
+ checkpointFilesToDelete=None,
+ chainedJobs=None):
+ requirements = {'memory': memory, 'cores': cores, 'disk': disk,
+ 'preemptable': preemptable}
+ super(JobGraph, self).__init__(command=command,
+ requirements=requirements,
+ unitName=unitName, jobName=jobName,
+ jobStoreID=jobStoreID,
+ predecessorNumber=predecessorNumber)
+
+ # The number of times the job should be retried if it fails. This number is reduced by
+ # retries until it is zero and then no further retries are made
+ self.remainingRetryCount = remainingRetryCount
+
+ # This variable is used in creating a graph of jobs. If a job crashes after an update to
+ # the jobGraph but before the list of files to remove is deleted then this list can be
+ # used to clean them up.
+ self.filesToDelete = filesToDelete or []
+
+ # The number of predecessor jobs of a given job. A predecessor is a job which references
+ # this job in its stack.
+ self.predecessorNumber = predecessorNumber
+ # The IDs of predecessors that have finished. When len(predecessorsFinished) ==
+ # predecessorNumber then the job can be run.
+ self.predecessorsFinished = predecessorsFinished or set()
+
+ # The list of successor jobs to run. Successor jobs are stored as jobNodes. Successor
+ # jobs are run in reverse order from the stack.
+ self.stack = stack or []
+
+ # A jobStoreFileID of the log file for a job. This will be none unless the job failed and
+ # the logging has been captured to be reported on the leader.
+ self.logJobStoreFileID = logJobStoreFileID
+
+ # A list of lists of service jobs to run. Each sub list is a list of service jobs
+ # descriptions, each of which is stored as a 6-tuple of the form (jobStoreId, memory,
+ # cores, disk, startJobStoreID, terminateJobStoreID).
+ self.services = services or []
+
+ # An empty file in the jobStore which when deleted is used to signal that the service
+ # should cease.
+ self.terminateJobStoreID = terminateJobStoreID
+
+ # Similarly, an empty file which when deleted is used to signal that the service is
+ # established
+ self.startJobStoreID = startJobStoreID
+
+ # An empty file in the jobStore which when deleted is used to signal that the service
+ # should terminate signaling an error.
+ self.errorJobStoreID = errorJobStoreID
+
+ # None, or a copy of the original command string used to reestablish the job after failure.
+ self.checkpoint = checkpoint
+
+ # Files that can not be deleted until the job and its successors have completed
+ self.checkpointFilesToDelete = checkpointFilesToDelete
+
+ # Names of jobs that were run as part of this job's invocation, starting with
+ # this job
+ self.chainedJobs = chainedJobs
+
+ def setupJobAfterFailure(self, config):
+ """
+ Reduce the remainingRetryCount if greater than zero and set the memory
+ to be at least as big as the default memory (in case of exhaustion of memory,
+ which is common).
+ """
+ self.remainingRetryCount = max(0, self.remainingRetryCount - 1)
+ logger.warn("Due to failure we are reducing the remaining retry count of job %s with ID %s to %s",
+ self, self.jobStoreID, self.remainingRetryCount)
+ # Set the memory to be at least as large as the default, in
+ # case this was a malloc failure (we do this because of the combined
+ # batch system)
+ if self.memory < config.defaultMemory:
+ self._memory = config.defaultMemory
+ logger.warn("We have increased the default memory of the failed job %s to %s bytes",
+ self, self.memory)
+
+ def getLogFileHandle( self, jobStore ):
+ """
+ Returns a context manager that yields a file handle to the log file
+ """
+ return jobStore.readFileStream( self.logJobStoreFileID )
+
+ @classmethod
+ def fromJobNode(cls, jobNode, jobStoreID, tryCount):
+ """
+ Builds a job graph from a given job node
+ :param toil.job.JobNode jobNode: a job node object to build into a job graph
+ :param str jobStoreID: the job store ID to assign to the resulting job graph object
+ :param int tryCount: the number of times the resulting job graph object can be retried after
+ failure
+ :return: The newly created job graph object
+ :rtype: toil.jobGraph.JobGraph
+ """
+ return cls(command=jobNode.command,
+ jobStoreID=jobStoreID,
+ remainingRetryCount=tryCount,
+ predecessorNumber=jobNode.predecessorNumber,
+ unitName=jobNode.unitName, jobName=jobNode.jobName,
+ **jobNode._requirements)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, self.__class__)
+ and self.remainingRetryCount == other.remainingRetryCount
+ and self.jobStoreID == other.jobStoreID
+ and self.filesToDelete == other.filesToDelete
+ and self.stack == other.stack
+ and self.predecessorNumber == other.predecessorNumber
+ and self.predecessorsFinished == other.predecessorsFinished
+ and self.logJobStoreFileID == other.logJobStoreFileID)
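A minimal sketch of the failure-handling semantics implemented by JobGraph.setupJobAfterFailure above, assuming only a config object exposing defaultMemory; DummyConfig and DummyJob are purely illustrative stand-ins, not Toil classes:

class DummyConfig(object):
    # Assumed example values; the real toil.common.Config supplies these.
    defaultMemory = 2 * 1024 ** 3   # 2 GiB
    retryCount = 1

class DummyJob(object):
    def __init__(self, memory, remainingRetryCount):
        self.memory = memory
        self.remainingRetryCount = remainingRetryCount

    def setupJobAfterFailure(self, config):
        # Mirrors the logic above: never drop below zero retries, and raise the
        # job's memory to the configured default if it was smaller.
        self.remainingRetryCount = max(0, self.remainingRetryCount - 1)
        if self.memory < config.defaultMemory:
            self.memory = config.defaultMemory

job = DummyJob(memory=512 * 1024 ** 2, remainingRetryCount=2)
job.setupJobAfterFailure(DummyConfig())
assert job.remainingRetryCount == 1
assert job.memory == DummyConfig.defaultMemory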
diff --git a/src/toil/jobStores/__init__.py b/src/toil/jobStores/__init__.py
new file mode 100644
index 0000000..20da7b0
--- /dev/null
+++ b/src/toil/jobStores/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
diff --git a/src/toil/jobStores/abstractJobStore.py b/src/toil/jobStores/abstractJobStore.py
new file mode 100644
index 0000000..fd49f28
--- /dev/null
+++ b/src/toil/jobStores/abstractJobStore.py
@@ -0,0 +1,967 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+
+import shutil
+
+import re
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager, closing
+from datetime import timedelta
+from uuid import uuid4
+
+# Python 3 compatibility imports
+from six import itervalues
+from six.moves.urllib.request import urlopen
+import six.moves.urllib.parse as urlparse
+
+from bd2k.util.retry import retry_http
+
+from toil.job import JobException
+from bd2k.util import memoize
+from bd2k.util.objects import abstractclassmethod
+
+try:
+ import cPickle
+except ImportError:
+ import pickle as cPickle
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class InvalidImportExportUrlException(Exception):
+ def __init__(self, url):
+ """
+ :param urlparse.ParseResult url:
+ """
+ super(InvalidImportExportUrlException, self).__init__(
+ "The URL '%s' is invalid" % url.geturl())
+
+
+class NoSuchJobException(Exception):
+ def __init__(self, jobStoreID):
+ """
+ Indicates that the specified job does not exist
+
+ :param str jobStoreID: the jobStoreID that was mistakenly assumed to exist
+ """
+ super(NoSuchJobException, self).__init__("The job '%s' does not exist" % jobStoreID)
+
+
+class ConcurrentFileModificationException(Exception):
+ def __init__(self, jobStoreFileID):
+ """
+ Indicates that the file was attempted to be modified by multiple processes at once.
+
+ :param str jobStoreFileID: the ID of the file that was modified by multiple workers
+ or processes concurrently
+ """
+ super(ConcurrentFileModificationException, self).__init__(
+ 'Concurrent update to file %s detected.' % jobStoreFileID)
+
+
+class NoSuchFileException(Exception):
+ def __init__(self, jobStoreFileID, customName=None):
+ """
+ Indicates that the specified file does not exist
+
+ :param str jobStoreFileID: the ID of the file that was mistakenly assumed to exist
+
+ :param str customName: optionally, an alternate name for the nonexistent file
+ """
+ if customName is None:
+ message = "File '%s' does not exist" % jobStoreFileID
+ else:
+ message = "File '%s' (%s) does not exist" % (customName, jobStoreFileID)
+ super(NoSuchFileException, self).__init__(message)
+
+
+class NoSuchJobStoreException(Exception):
+ def __init__(self, locator):
+ super(NoSuchJobStoreException, self).__init__(
+ "The job store '%s' does not exist, so there is nothing to restart" % locator)
+
+
+class JobStoreExistsException(Exception):
+ def __init__(self, locator):
+ super(JobStoreExistsException, self).__init__(
+ "The job store '%s' already exists. Use --restart to resume the workflow, or remove "
+ "the job store with 'toil clean' to start the workflow from scratch" % locator)
+
+
+class AbstractJobStore(object):
+ """
+ Represents the physical storage for the jobs and files in a Toil workflow.
+ """
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ """
+ Create an instance of the job store. The instance will not be fully functional until
+ either :meth:`.initialize` or :meth:`.resume` is invoked. Note that the :meth:`.destroy`
+ method may be invoked on the object with or without prior invocation of either of these two
+ methods.
+ """
+ self.__config = None
+
+ def initialize(self, config):
+ """
+ Create the physical storage for this job store, allocate a workflow ID and persist the
+ given Toil configuration to the store.
+
+ :param toil.common.Config config: the Toil configuration to initialize this job store
+ with. The given configuration will be updated with the newly allocated workflow ID.
+
+ :raises JobStoreExistsException: if the physical storage for this job store already exists
+ """
+ assert config.workflowID is None
+ config.workflowID = str(uuid4())
+ logger.info("The workflow ID is: '%s'" % config.workflowID)
+ self.__config = config
+ self.writeConfig()
+
+ def writeConfig(self):
+ """
+ Persists the value of the :attr:`.config` attribute to the job store, so that it can be
+ retrieved later by other instances of this class.
+ """
+ with self.writeSharedFileStream('config.pickle', isProtected=False) as fileHandle:
+ cPickle.dump(self.__config, fileHandle, cPickle.HIGHEST_PROTOCOL)
+
+ def resume(self):
+ """
+ Connect this instance to the physical storage it represents and load the Toil configuration
+ into the :attr:`.config` attribute.
+
+ :raises NoSuchJobStoreException: if the physical storage for this job store doesn't exist
+ """
+ with self.readSharedFileStream('config.pickle') as fileHandle:
+ config = cPickle.load(fileHandle)
+ assert config.workflowID is not None
+ self.__config = config
+
+ @property
+ def config(self):
+ """
+ The Toil configuration associated with this job store.
+
+ :rtype: toil.common.Config
+ """
+ return self.__config
+
+ rootJobStoreIDFileName = 'rootJobStoreID'
+
+ def setRootJob(self, rootJobStoreID):
+ """
+ Set the root job of the workflow backed by this job store
+
+ :param str rootJobStoreID: The ID of the job to set as root
+ """
+ with self.writeSharedFileStream(self.rootJobStoreIDFileName) as f:
+ f.write(rootJobStoreID)
+
+ def loadRootJob(self):
+ """
+ Loads the root job in the current job store.
+
+ :raises toil.job.JobException: If no root job is set or if the root job doesn't exist in
+ this job store
+ :return: The root job.
+ :rtype: toil.jobGraph.JobGraph
+ """
+ try:
+ with self.readSharedFileStream(self.rootJobStoreIDFileName) as f:
+ rootJobStoreID = f.read()
+ except NoSuchFileException:
+ raise JobException('No job has been set as the root in this job store')
+ if not self.exists(rootJobStoreID):
+ raise JobException("The root job '%s' doesn't exist. Either the Toil workflow "
+ "is finished or has never been started" % rootJobStoreID)
+ return self.load(rootJobStoreID)
+
+ # FIXME: This is only used in tests, why do we have it?
+
+ def createRootJob(self, *args, **kwargs):
+ """
+ Create a new job and set it as the root job in this job store
+
+ :rtype : toil.jobGraph.JobGraph
+ """
+ rootJob = self.create(*args, **kwargs)
+ self.setRootJob(rootJob.jobStoreID)
+ return rootJob
+
+ @property
+ @memoize
+ def _jobStoreClasses(self):
+ """
+ A list of concrete AbstractJobStore implementations whose dependencies are installed.
+
+ :rtype: list[AbstractJobStore]
+ """
+ jobStoreClassNames = (
+ "toil.jobStores.azureJobStore.AzureJobStore",
+ "toil.jobStores.fileJobStore.FileJobStore",
+ "toil.jobStores.googleJobStore.GoogleJobStore",
+ "toil.jobStores.aws.jobStore.AWSJobStore",
+ "toil.jobStores.abstractJobStore.JobStoreSupport")
+ jobStoreClasses = []
+ for className in jobStoreClassNames:
+ moduleName, className = className.rsplit('.', 1)
+ from importlib import import_module
+ try:
+ module = import_module(moduleName)
+ except ImportError:
+ logger.debug("Unable to import '%s' as is expected if the corresponding extra was "
+ "omitted at installation time.", moduleName)
+ else:
+ jobStoreClass = getattr(module, className)
+ jobStoreClasses.append(jobStoreClass)
+ return jobStoreClasses
+
+ def _findJobStoreForUrl(self, url, export=False):
+ """
+ Returns the AbstractJobStore subclass that supports the given URL.
+
+ :param urlparse.ParseResult url: The given URL
+ :param bool export: Determines whether the URL is to be used for export
+ :rtype: toil.jobStore.AbstractJobStore
+ """
+ for jobStoreCls in self._jobStoreClasses:
+ if jobStoreCls._supportsUrl(url, export):
+ return jobStoreCls
+ raise RuntimeError("No job store implementation supports %sporting for URL '%s'" %
+ ('ex' if export else 'im', url.geturl()))
+
+ def importFile(self, srcUrl, sharedFileName=None):
+ """
+ Imports the file at the given URL into job store. The ID of the newly imported file is
+ returned. If the name of a shared file name is provided, the file will be imported as
+ such and None is returned.
+
+ Currently supported schemes are:
+
+ - 's3' for objects in Amazon S3
+ e.g. s3://bucket/key
+
+ - 'wasb' for blobs in Azure Blob Storage
+ e.g. wasb://container/blob
+
+ - 'file' for local files
+ e.g. file:///local/file/path
+
+ - 'http'
+ e.g. http://someurl.com/path
+
+ :param str srcUrl: URL that points to a file or object in the storage mechanism of a
+ supported URL scheme e.g. a blob in an Azure Blob Storage container.
+
+ :param str sharedFileName: Optional name to assign to the imported file within the job store
+
+ :return: The jobStoreFileId of the imported file or None if sharedFileName was given
+ :rtype: str|None
+ """
+ # Note that the helper method _importFile is used to read from the source and write to
+ # destination (which is the current job store in this case). To implement any
+ # optimizations that circumvent this, the _importFile method should be overridden by
+ # subclasses of AbstractJobStore.
+ srcUrl = urlparse.urlparse(srcUrl)
+ otherCls = self._findJobStoreForUrl(srcUrl)
+ return self._importFile(otherCls, srcUrl, sharedFileName=sharedFileName)
+
+ def _importFile(self, otherCls, url, sharedFileName=None):
+ """
+ Import the file at the given URL using the given job store class to retrieve that file.
+ See also :meth:`.importFile`. This method applies a generic approach to importing: it
+ asks the other job store class for a stream and writes that stream as either a regular or
+ a shared file.
+
+ :param AbstractJobStore otherCls: The concrete subclass of AbstractJobStore that supports
+ reading from the given URL.
+
+ :param urlparse.ParseResult url: The location of the file to import.
+
+ :param str sharedFileName: Optional name to assign to the imported file within the job store
+
+ :return: The jobStoreFileId of the imported file or None if sharedFileName was given
+ :rtype: str|None
+ """
+ if sharedFileName is None:
+ with self.writeFileStream() as (writable, jobStoreFileID):
+ otherCls._readFromUrl(url, writable)
+ return jobStoreFileID
+ else:
+ self._requireValidSharedFileName(sharedFileName)
+ with self.writeSharedFileStream(sharedFileName) as writable:
+ otherCls._readFromUrl(url, writable)
+ return None
+
+ def exportFile(self, jobStoreFileID, dstUrl):
+ """
+ Exports file to destination pointed at by the destination URL.
+
+ Refer to AbstractJobStore.importFile documentation for currently supported URL schemes.
+
+ Note that the helper method _exportFile is used to read from the source and write to
+ destination. To implement any optimizations that circumvent this, the _exportFile method
+ should be overridden by subclasses of AbstractJobStore.
+
+ :param str jobStoreFileID: The id of the file in the job store that should be exported.
+ :param str dstUrl: URL that points to a file or object in the storage mechanism of a
+ supported URL scheme e.g. a blob in an Azure Blob Storage container.
+ """
+ dstUrl = urlparse.urlparse(dstUrl)
+ otherCls = self._findJobStoreForUrl(dstUrl, export=True)
+ return self._exportFile(otherCls, jobStoreFileID, dstUrl)
+
+ def _exportFile(self, otherCls, jobStoreFileID, url):
+ """
+ Refer to exportFile docstring for information about this method.
+
+ :param AbstractJobStore otherCls: The concrete subclass of AbstractJobStore that supports
+ exporting to the given URL. Note that the type annotation here is not completely
+ accurate. This is not an instance, it's a class, but there is no way to reflect
+ that in PEP-484 type hints.
+
+ :param str jobStoreFileID: The id of the file that will be exported.
+
+ :param urlparse.ParseResult url: The parsed URL of the file to export to.
+ """
+ with self.readFileStream(jobStoreFileID) as readable:
+ otherCls._writeToUrl(readable, url)
+
+ @abstractclassmethod
+ def _readFromUrl(cls, url, writable):
+ """
+ Reads the contents of the object at the specified location and writes it to the given
+ writable stream.
+
+ Refer to AbstractJobStore.importFile documentation for currently supported URL schemes.
+
+ :param urlparse.ParseResult url: URL that points to a file or object in the storage
+ mechanism of a supported URL scheme e.g. a blob in an Azure Blob Storage container.
+
+ :param writable: a writable stream
+ """
+ raise NotImplementedError()
+
+ @abstractclassmethod
+ def _writeToUrl(cls, readable, url):
+ """
+ Reads the contents of the given readable stream and writes it to the object at the
+ specified location.
+
+ Refer to AbstractJobStore.importFile documentation for currently supported URL schemes.
+
+ :param urlparse.ParseResult url: URL that points to a file or object in the storage
+ mechanism of a supported URL scheme e.g. a blob in an Azure Blob Storage container.
+
+ :param readable: a readable stream
+ """
+ raise NotImplementedError()
+
+ @abstractclassmethod
+ def _supportsUrl(cls, url, export=False):
+ """
+ Returns True if the job store supports the URL's scheme.
+
+ Refer to AbstractJobStore.importFile documentation for currently supported URL schemes.
+
+ :param bool export: Determines if the URL is supported for export
+ :param urlparse.ParseResult url: a parsed URL that may be supported
+ :return bool: True if the class supports the URL
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def destroy(self):
+ """
+ The inverse of :meth:`.initialize`, this method deletes the physical storage represented
+ by this instance. While not being atomic, this method *is* at least idempotent,
+ as a means to counteract potential issues with eventual consistency exhibited by the
+ underlying storage mechanisms. This means that if the method fails (raises an exception),
+ it may (and should) be invoked again. If the underlying storage mechanism is eventually
+ consistent, even a successful invocation is not an ironclad guarantee that the physical
+ storage vanished completely and immediately. A successful invocation only guarantees that
+ the deletion will eventually happen. It is therefore recommended to not immediately reuse
+ the same job store location for a new Toil workflow.
+ """
+ raise NotImplementedError()
+
+ def getEnv(self):
+ """
+ Returns a dictionary of environment variables that this job store requires to be set in
+ order to function properly on a worker.
+
+ :rtype: dict[str,str]
+ """
+ return {}
+
+ # Cleanup functions
+
+ def clean(self, jobCache=None):
+ """
+ Function to cleanup the state of a job store after a restart.
+ Fixes jobs that might have been partially updated. Resets the try counts and removes jobs
+ that are not successors of the current root job.
+
+ :param dict[str,toil.jobGraph.JobGraph] jobCache: if specified, it must be a dict
+ from job ID keys to JobGraph object values. Jobs will be loaded from the cache
+ (which can be downloaded from the job store in a batch) instead of piecemeal when
+ recursed into.
+ """
+ if jobCache is None:
+ logger.warning("Cleaning jobStore recursively. This may be slow.")
+
+ # Functions to get and check the existence of jobs, using the jobCache
+ # if present
+ def getJob(jobId):
+ if jobCache is not None:
+ try:
+ return jobCache[jobId]
+ except KeyError:
+ return self.load(jobId)
+ else:
+ return self.load(jobId)
+
+ def haveJob(jobId):
+ if jobCache is not None:
+ if jobId in jobCache:
+ return True
+ else:
+ return self.exists(jobId)
+ else:
+ return self.exists(jobId)
+
+ def getJobs():
+ if jobCache is not None:
+ return itervalues(jobCache)
+ else:
+ return self.jobs()
+
+ # Iterate from the root jobGraph and collate all jobs that are reachable from it
+ # All other jobs returned by self.jobs() are orphaned and can be removed
+ reachableFromRoot = set()
+
+ def getConnectedJobs(jobGraph):
+ if jobGraph.jobStoreID in reachableFromRoot:
+ return
+ reachableFromRoot.add(jobGraph.jobStoreID)
+ # Traverse jobs in stack
+ for jobs in jobGraph.stack:
+ for successorJobStoreID in map(lambda x: x.jobStoreID, jobs):
+ if (successorJobStoreID not in reachableFromRoot
+ and haveJob(successorJobStoreID)):
+ getConnectedJobs(getJob(successorJobStoreID))
+ # Traverse service jobs
+ for jobs in jobGraph.services:
+ for serviceJobStoreID in map(lambda x: x.jobStoreID, jobs):
+ if haveJob(serviceJobStoreID):
+ assert serviceJobStoreID not in reachableFromRoot
+ reachableFromRoot.add(serviceJobStoreID)
+
+ logger.info("Checking job graph connectivity...")
+ getConnectedJobs(self.loadRootJob())
+ logger.info("%d jobs reachable from root." % len(reachableFromRoot))
+
+ # Cleanup jobs that are not reachable from the root, and therefore orphaned
+ jobsToDelete = filter(lambda x: x.jobStoreID not in reachableFromRoot, getJobs())
+ for jobGraph in jobsToDelete:
+ # clean up any associated files before deletion
+ for fileID in jobGraph.filesToDelete:
+ # Delete any files that should already be deleted
+ logger.warn("Deleting file '%s'. It is marked for deletion but has not yet been "
+ "removed.", fileID)
+ self.deleteFile(fileID)
+ # Delete the job
+ self.delete(jobGraph.jobStoreID)
+
+ # Clean up jobs that are reachable from the root
+ for jobGraph in (getJob(x) for x in reachableFromRoot):
+ # jobGraphs here are necessarily reachable from the root.
+
+ changed = [False] # This is a flag to indicate the jobGraph state has
+ # changed
+
+ # If the job has files to delete delete them.
+ if len(jobGraph.filesToDelete) != 0:
+ # Delete any files that should already be deleted
+ for fileID in jobGraph.filesToDelete:
+ logger.critical("Removing file in job store: %s that was "
+ "marked for deletion but not previously removed" % fileID)
+ self.deleteFile(fileID)
+ jobGraph.filesToDelete = []
+ changed[0] = True
+
+ # For a job whose command is already executed, remove jobs from the stack that are
+ # already deleted. This cleans up the case that the jobGraph had successors to run,
+ # but had not been updated to reflect this.
+ if jobGraph.command is None:
+ stackSizeFn = lambda: sum(map(len, jobGraph.stack))
+ startStackSize = stackSizeFn()
+ # Remove deleted jobs
+ jobGraph.stack = map(lambda x: filter(lambda y: self.exists(y.jobStoreID), x),
+ jobGraph.stack)
+ # Remove empty stuff from the stack
+ jobGraph.stack = filter(lambda x: len(x) > 0, jobGraph.stack)
+ # Check if anything got removed
+ if stackSizeFn() != startStackSize:
+ changed[0] = True
+
+ # Cleanup any services that have already been finished.
+ # Filter out deleted services and update the flags for services that exist
+ # If there are services then renew
+ # the start and terminate flags if they have been removed
+ def subFlagFile(jobStoreID, jobStoreFileID, flag):
+ if self.fileExists(jobStoreFileID):
+ return jobStoreFileID
+
+ # Make a new flag
+ newFlag = self.getEmptyFileStoreID()
+
+ # Load the jobGraph for the service and initialise the link
+ serviceJobGraph = getJob(jobStoreID)
+
+ if flag == 1:
+ logger.debug("Recreating a start service flag for job: %s, flag: %s",
+ jobStoreID, newFlag)
+ serviceJobGraph.startJobStoreID = newFlag
+ elif flag == 2:
+ logger.debug("Recreating a terminate service flag for job: %s, flag: %s",
+ jobStoreID, newFlag)
+ serviceJobGraph.terminateJobStoreID = newFlag
+ else:
+ logger.debug("Recreating a error service flag for job: %s, flag: %s",
+ jobStoreID, newFlag)
+ assert flag == 3
+ serviceJobGraph.errorJobStoreID = newFlag
+
+ # Update the service job on disk
+ self.update(serviceJobGraph)
+
+ changed[0] = True
+
+ return newFlag
+
+ servicesSizeFn = lambda: sum(map(len, jobGraph.services))
+ startServicesSize = servicesSizeFn()
+
+ def replaceFlagsIfNeeded(serviceJobNode):
+ serviceJobNode.startJobStoreID = subFlagFile(serviceJobNode.jobStoreID, serviceJobNode.startJobStoreID, 1)
+ serviceJobNode.terminateJobStoreID = subFlagFile(serviceJobNode.jobStoreID, serviceJobNode.terminateJobStoreID, 2)
+ serviceJobNode.errorJobStoreID = subFlagFile(serviceJobNode.jobStoreID, serviceJobNode.errorJobStoreID, 3)
+
+ # jobGraph.services is a list of lists containing serviceNodes
+ # remove all services that no longer exist
+ services = jobGraph.services
+ jobGraph.services = []
+ for serviceList in services:
+ existingServices = filter(lambda service: self.exists(service.jobStoreID), serviceList)
+ if existingServices:
+ jobGraph.services.append(existingServices)
+
+ map(lambda serviceList: map(replaceFlagsIfNeeded, serviceList), jobGraph.services)
+
+ if servicesSizeFn() != startServicesSize:
+ changed[0] = True
+
+ # Reset the retry count of the jobGraph
+ if jobGraph.remainingRetryCount != self._defaultTryCount():
+ jobGraph.remainingRetryCount = self._defaultTryCount()
+ changed[0] = True
+
+ # This cleans the old log file which may
+ # have been left if the jobGraph is being retried after a jobGraph failure.
+ if jobGraph.logJobStoreFileID is not None:
+ self.deleteFile(jobGraph.logJobStoreFileID)
+ jobGraph.logJobStoreFileID = None
+ changed[0] = True
+
+ if changed[0]: # Update, but only if a change has occurred
+ logger.critical("Repairing job: %s" % jobGraph.jobStoreID)
+ self.update(jobGraph)
+
+ # Remove any crufty stats/logging files from the previous run
+ logger.info("Discarding old statistics and logs...")
+ self.readStatsAndLogging(lambda x: None)
+
+ logger.info("Job store is clean")
+ # TODO: reloading of the rootJob may be redundant here
+ return self.loadRootJob()
+
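The orphan collection in clean() above is essentially a mark phase over the successor graph followed by a sweep of everything unmarked. A toy restatement of that idea on plain dictionaries; the jobs dict and the reachable_from helper are illustrative stand-ins for the job store, the jobCache and getConnectedJobs:

# Illustrative mark-and-sweep over a toy successor graph; 'jobs' maps job IDs to
# lists of successor IDs and stands in for what the job store or jobCache provides.
def reachable_from(root, jobs):
    seen = set()
    stack = [root]
    while stack:
        job_id = stack.pop()
        if job_id in seen or job_id not in jobs:
            continue
        seen.add(job_id)
        stack.extend(jobs[job_id])
    return seen

jobs = {'root': ['a', 'b'], 'a': ['c'], 'b': [], 'c': [], 'orphan': []}
live = reachable_from('root', jobs)
orphans = [job_id for job_id in jobs if job_id not in live]
assert orphans == ['orphan']   # only unreachable jobs would be deleted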
+ ##########################################
+ # The following methods deal with creating/loading/updating/writing/checking for the
+ # existence of jobs
+ ##########################################
+
+ @abstractmethod
+ def create(self, jobNode):
+ """
+ Creates a job graph from the given job node & writes it to the job store.
+
+ :rtype: toil.jobGraph.JobGraph
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def exists(self, jobStoreID):
+ """
+ Indicates whether the job with the specified jobStoreID exists in the job store
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ # One year should be sufficient to finish any pipeline ;-)
+ publicUrlExpiration = timedelta(days=365)
+
+ @abstractmethod
+ def getPublicUrl(self, fileName):
+ """
+ Returns a publicly accessible URL to the given file in the job store. The returned URL may
+ expire as early as 1h after it has been returned. Throw an exception if the file does not
+ exist.
+
+ :param str fileName: the jobStoreFileID of the file to generate a URL for
+
+ :raise NoSuchFileException: if the specified file does not exist in this job store
+
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def getSharedPublicUrl(self, sharedFileName):
+ """
+ Differs from :meth:`getPublicUrl` in that this method is for generating URLs for shared
+ files written by :meth:`writeSharedFileStream`.
+
+ Returns a publicly accessible URL to the given file in the job store. The returned URL
+ starts with 'http:', 'https:' or 'file:'. The returned URL may expire as early as 1h
+ after it has been returned. Throw an exception if the file does not exist.
+
+ :param str sharedFileName: The name of the shared file to generate a publicly accessible url for.
+
+ :raise NoSuchFileException: raised if the specified file does not exist in the store
+
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def load(self, jobStoreID):
+ """
+ Loads the job referenced by the given ID and returns it.
+
+ :param str jobStoreID: the ID of the job to load
+
+ :raise NoSuchJobException: if there is no job with the given ID
+
+ :rtype: toil.jobGraph.JobGraph
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def update(self, job):
+ """
+ Persists the job in this store atomically.
+
+ :param toil.jobGraph.JobGraph job: the job to write to this job store
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def delete(self, jobStoreID):
+ """
+ Removes the job from the store atomically. One cannot subsequently call load(), write(),
+ update(), etc. with the job.
+
+ This operation is idempotent, i.e. deleting a job twice or deleting a non-existent job
+ will succeed silently.
+
+ :param str jobStoreID: the ID of the job to delete from this job store
+ """
+ raise NotImplementedError()
+
+ def jobs(self):
+ """
+ Best effort attempt to return an iterator over all jobs in the store. The iterator may not
+ return all jobs and may also contain orphaned jobs that have already finished successfully
+ and should not be rerun. To guarantee that you get any and all jobs that can be run,
+ construct a more expensive ToilState object instead.
+
+ :return: Returns iterator on jobs in the store. The iterator may or may not contain all jobs and may contain
+ invalid jobs
+ :rtype: Iterator[toil.jobGraph.JobGraph]
+ """
+ raise NotImplementedError()
+
+ ##########################################
+ # The following provide a way of creating/reading/writing/updating files
+ # associated with a given job.
+ ##########################################
+
+ @abstractmethod
+ def writeFile(self, localFilePath, jobStoreID=None):
+ """
+ Takes a file (as a path) and places it in this job store. Returns an ID that can be used
+ to retrieve the file at a later time.
+
+ :param str localFilePath: the path to the local file that will be uploaded to the job store.
+
+ :param str|None jobStoreID: If specified the file will be associated with that job and when
+ jobStore.delete(job) is called all files written with the given job.jobStoreID will
+ be removed from the job store.
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+
+ :raise NoSuchJobException: if the job specified via jobStoreID does not exist
+
+ FIXME: some implementations may not raise this
+
+ :return: an ID that references the newly created file and can be used to read the
+ file in the future.
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ @contextmanager
+ def writeFileStream(self, jobStoreID=None):
+ """
+ Similar to writeFile, but returns a context manager yielding a tuple of
+ 1) a file handle which can be written to and 2) the ID of the resulting
+ file in the job store. The yielded file handle does not need to and
+ should not be closed explicitly.
+
+ :param str jobStoreID: the id of a job, or None. If specified, the file will be associated
+ with that job and when jobStore.delete(job) is called all files written with the
+ given job.jobStoreID will be removed from the job store.
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+
+ :raise NoSuchJobException: if the job specified via jobStoreID does not exist
+
+ FIXME: some implementations may not raise this
+
+ :return: an ID that references the newly created file and can be used to read the
+ file in the future.
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def getEmptyFileStoreID(self, jobStoreID=None):
+ """
+ Creates an empty file in the job store and returns its ID.
+ A call to fileExists(getEmptyFileStoreID(jobStoreID)) will return True.
+
+ :param str jobStoreID: the id of a job, or None. If specified, the file will be associated with
+ that job and when jobStore.delete(job) is called a best effort attempt is made to delete
+ all files written with the given job.jobStoreID
+
+ :return: a jobStoreFileID that references the newly created file and can be used to reference the
+ file in the future.
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def readFile(self, jobStoreFileID, localFilePath):
+ """
+ Copies the file referenced by jobStoreFileID to the given local file path. The version
+ will be consistent with the last copy of the file written/updated.
+
+ The file at the given local path may not be modified after this method returns!
+
+ :param str jobStoreFileID: ID of the file to be copied
+
+ :param str localFilePath: the local path indicating where to place the contents of the
+ given file in the job store
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ @contextmanager
+ def readFileStream(self, jobStoreFileID):
+ """
+ Similar to readFile, but returns a context manager yielding a file handle which can be
+ read from. The yielded file handle does not need to and should not be closed explicitly.
+
+ :param str jobStoreFileID: ID of the file to get a readable file handle for
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def deleteFile(self, jobStoreFileID):
+ """
+ Deletes the file with the given ID from this job store. This operation is idempotent, i.e.
+ deleting a file twice or deleting a non-existent file will succeed silently.
+
+ :param str jobStoreFileID: ID of the file to delete
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def fileExists(self, jobStoreFileID):
+ """
+ Determine whether a file exists in this job store.
+
+ :param str jobStoreFileID: an ID referencing the file to be checked
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def updateFile(self, jobStoreFileID, localFilePath):
+ """
+ Replaces the existing version of a file in the job store. Throws an exception if the file
+ does not exist.
+
+ :param str jobStoreFileID: the ID of the file in the job store to be updated
+
+ :param str localFilePath: the local path to a file that will overwrite the current version
+ in the job store
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+
+ :raise NoSuchFileException: if the specified file does not exist
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def updateFileStream(self, jobStoreFileID):
+ """
+ Replaces the existing version of a file in the job store. Similar to writeFile, but
+ returns a context manager yielding a file handle which can be written to. The
+ yielded file handle does not need to and should not be closed explicitly.
+
+ :param str jobStoreFileID: the ID of the file in the job store to be updated
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+
+ :raise NoSuchFileException: if the specified file does not exist
+ """
+ raise NotImplementedError()
+
+ ##########################################
+ # The following methods deal with shared files, i.e. files not associated
+ # with specific jobs.
+ ##########################################
+
+ sharedFileNameRegex = re.compile(r'^[a-zA-Z0-9._-]+$')
+
+ # FIXME: Rename to updateSharedFileStream
+
+ @abstractmethod
+ @contextmanager
+ def writeSharedFileStream(self, sharedFileName, isProtected=None):
+ """
+ Returns a context manager yielding a writable file handle to the global file referenced
+ by the given name.
+
+ :param str sharedFileName: A file name matching AbstractJobStore.sharedFileNameRegex, unique within
+ this job store
+
+ :param bool isProtected: True if the file must be encrypted, None if it may be encrypted or
+ False if it must be stored in the clear.
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ @contextmanager
+ def readSharedFileStream(self, sharedFileName):
+ """
+ Returns a context manager yielding a readable file handle to the global file referenced
+ by the given name.
+
+ :param str sharedFileName: A file name matching AbstractJobStore.sharedFileNameRegex, unique within
+ this job store
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def writeStatsAndLogging(self, statsAndLoggingString):
+ """
+ Adds the given statistics/logging string to the store of statistics info.
+
+ :param str statsAndLoggingString: the string to be written to the stats file
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def readStatsAndLogging(self, callback, readAll=False):
+ """
+ Reads stats/logging strings accumulated by the writeStatsAndLogging() method. For each
+ stats/logging string this method calls the given callback function with an open,
+ readable file handle from which the stats string can be read. Returns the number of
+ stats/logging strings processed. Each stats/logging string is only processed once unless
+ the readAll parameter is set, in which case the given callback will be invoked for all
+ existing stats/logging strings, including the ones from a previous invocation of this
+ method.
+
+ :param Callable callback: a function to be applied to each of the stats file handles found
+
+ :param bool readAll: a boolean indicating whether to read the already processed stats files
+ in addition to the unread stats files
+
+ :raise ConcurrentFileModificationException: if the file was modified concurrently during
+ an invocation of this method
+
+ :return: the number of stats files processed
+ :rtype: int
+ """
+ raise NotImplementedError()
+
+ ## Helper methods for subclasses
+
+ def _defaultTryCount(self):
+ return int(self.config.retryCount + 1)
+
+ @classmethod
+ def _validateSharedFileName(cls, sharedFileName):
+ return bool(cls.sharedFileNameRegex.match(sharedFileName))
+
+ @classmethod
+ def _requireValidSharedFileName(cls, sharedFileName):
+ if not cls._validateSharedFileName(sharedFileName):
+ raise ValueError("Not a valid shared file name: '%s'." % sharedFileName)
+
+
+class JobStoreSupport(AbstractJobStore):
+ __metaclass__ = ABCMeta
+
+ @classmethod
+ def _supportsUrl(cls, url, export=False):
+ return url.scheme.lower() in ('http', 'https', 'ftp') and not export
+
+ @classmethod
+ def _readFromUrl(cls, url, writable):
+ for attempt in retry_http():
+ with attempt:
+ with closing(urlopen(url.geturl())) as readable:
+ shutil.copyfileobj(readable, writable)
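The scheme-based dispatch performed by _findJobStoreForUrl and the import/export helpers above can be restated in a few lines. A hedged sketch with toy stand-ins; ToyS3Store and ToyHttpStore are not Toil classes, only the _supportsUrl contract is taken from the code above:

# Toy restatement of the URL dispatch used by importFile/exportFile above.
try:
    from urllib.parse import urlparse        # Python 3
except ImportError:
    from urlparse import urlparse            # Python 2

class ToyS3Store(object):
    @classmethod
    def _supportsUrl(cls, url, export=False):
        return url.scheme.lower() == 's3'

class ToyHttpStore(object):
    @classmethod
    def _supportsUrl(cls, url, export=False):
        # HTTP-style URLs can be read from but never exported to.
        return url.scheme.lower() in ('http', 'https', 'ftp') and not export

def find_store_for(url_string, stores, export=False):
    url = urlparse(url_string)
    for store in stores:
        if store._supportsUrl(url, export):
            return store
    raise RuntimeError("No job store supports URL '%s'" % url_string)

stores = (ToyS3Store, ToyHttpStore)
assert find_store_for('s3://bucket/key', stores) is ToyS3Store
assert find_store_for('http://example.com/data', stores) is ToyHttpStore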
diff --git a/src/toil/jobStores/aws/__init__.py b/src/toil/jobStores/aws/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/toil/jobStores/aws/jobStore.py b/src/toil/jobStores/aws/jobStore.py
new file mode 100644
index 0000000..a052994
--- /dev/null
+++ b/src/toil/jobStores/aws/jobStore.py
@@ -0,0 +1,1363 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+from contextlib import contextmanager, closing
+import logging
+from multiprocessing import cpu_count
+
+import os
+import re
+import uuid
+import base64
+import hashlib
+import itertools
+
+# Python 3 compatibility imports
+from six.moves import xrange, cPickle, StringIO, reprlib
+from six import iteritems
+
+from bd2k.util import strict_bool
+from bd2k.util.exceptions import panic
+from bd2k.util.objects import InnerClass
+from boto.sdb.domain import Domain
+from boto.s3.bucket import Bucket
+from boto.s3.connection import S3Connection
+from boto.sdb.connection import SDBConnection
+from boto.sdb.item import Item
+import boto.s3
+import boto.sdb
+from boto.exception import S3CreateError
+from boto.s3.key import Key
+from boto.exception import SDBResponseError, S3ResponseError
+from concurrent.futures import ThreadPoolExecutor
+
+from toil.jobStores.abstractJobStore import (AbstractJobStore,
+ NoSuchJobException,
+ ConcurrentFileModificationException,
+ NoSuchFileException,
+ NoSuchJobStoreException,
+ JobStoreExistsException)
+from toil.jobStores.aws.utils import (SDBHelper,
+ retry_sdb,
+ no_such_sdb_domain,
+ sdb_unavailable,
+ monkeyPatchSdbConnection,
+ retry_s3,
+ bucket_location_to_region,
+ region_to_bucket_location)
+from toil.jobStores.utils import WritablePipe, ReadablePipe
+from toil.jobGraph import JobGraph
+import toil.lib.encryption as encryption
+
+log = logging.getLogger(__name__)
+
+
+def copyKeyMultipart(srcKey, dstBucketName, dstKeyName, partSize, headers=None):
+ """
+ Copies a source key to a destination key in multiple parts. Note that if the
+ destination key exists it will be overwritten implicitly, and if it does not exist a new
+ key will be created. If the destination bucket does not exist an error will be raised.
+
+ :param boto.s3.key.Key srcKey: The source key to be copied from.
+ :param str dstBucketName: The name of the destination bucket for the copy.
+ :param str dstKeyName: The name of the destination key that will be created or overwritten.
+ :param int partSize: The size of each individual part, must be >= 5 MiB but large enough to
+ not exceed 10k parts for the whole file
+ :param dict headers: Any headers that should be passed.
+
+ :rtype: boto.s3.multipart.CompletedMultiPartUpload
+ :return: An object representing the completed upload.
+ """
+
+ def copyPart(partIndex):
+ if exceptions:
+ return None
+ try:
+ for attempt in retry_s3():
+ with attempt:
+ start = partIndex * partSize
+ end = min(start + partSize, totalSize)
+ part = upload.copy_part_from_key(src_bucket_name=srcKey.bucket.name,
+ src_key_name=srcKey.name,
+ src_version_id=srcKey.version_id,
+ # S3 part numbers are 1-based
+ part_num=partIndex + 1,
+ # S3 range intervals are closed at the end
+ start=start, end=end - 1,
+ headers=headers)
+ except Exception as e:
+ if len(exceptions) < 5:
+ exceptions.append(e)
+ log.error('Failed to copy part number %d:', partIndex, exc_info=True)
+ else:
+ log.warn('Also failed to copy part number %d due to %s.', partIndex, e)
+ return None
+ else:
+ log.debug('Successfully copied part %d of %d.', partIndex, totalParts)
+ # noinspection PyUnboundLocalVariable
+ return part
+
+ totalSize = srcKey.size
+ totalParts = (totalSize + partSize - 1) / partSize
+ exceptions = []
+ # We need a location-agnostic connection to S3 so we can't use the one that we
+ # normally use for interacting with the job store bucket.
+ with closing(boto.connect_s3()) as s3:
+ for attempt in retry_s3():
+ with attempt:
+ dstBucket = s3.get_bucket(dstBucketName)
+ upload = dstBucket.initiate_multipart_upload(dstKeyName, headers=headers)
+ log.info("Initiated multipart copy from 's3://%s/%s' to 's3://%s/%s'.",
+ srcKey.bucket.name, srcKey.name, dstBucketName, dstKeyName)
+ try:
+ # We can oversubscribe cores by at least a factor of 16 since each copy task just
+ # blocks, waiting on the server. Limit # of threads to 128, since threads aren't
+ # exactly free either. Lastly, we don't need more threads than we have parts.
+ with ThreadPoolExecutor(max_workers=min(cpu_count() * 16, totalParts, 128)) as executor:
+ parts = list(executor.map(copyPart, xrange(0, totalParts)))
+ if exceptions:
+ raise RuntimeError('Failed to copy at least %d part(s)' % len(exceptions))
+ assert len(filter(None, parts)) == totalParts
+ except:
+ with panic(log=log):
+ upload.cancel_upload()
+ else:
+ for attempt in retry_s3():
+ with attempt:
+ completed = upload.complete_upload()
+ log.info("Completed copy from 's3://%s/%s' to 's3://%s/%s'.",
+ srcKey.bucket.name, srcKey.name, dstBucketName, dstKeyName)
+ return completed
+
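The part sizing used by copyKeyMultipart above (and by the partSize argument of the AWSJobStore constructor below) is plain ceiling division against S3's 10,000-part limit. A small worked example, using the 50 MiB default and an assumed object size:

partSize = 50 << 20                      # 50 MiB, the default used below
totalSize = 130 * 1024 * 1024            # a hypothetical 130 MiB object
totalParts = (totalSize + partSize - 1) // partSize
assert totalParts == 3                   # parts cover [0,50), [50,100), [100,130) MiB
assert totalParts <= 10000               # S3 caps a multipart upload at 10,000 parts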
+
+class AWSJobStore(AbstractJobStore):
+ """
+ A job store that uses Amazon's S3 for file storage and SimpleDB for storing job info and
+ enforcing strong consistency on the S3 file storage. There will be SDB domains for jobs and
+ files and a versioned S3 bucket for file contents. Job objects are pickled, compressed,
+ partitioned into chunks of 1024 bytes and each chunk is stored as an attribute of the SDB
+ item representing the job. UUIDs are used to identify jobs and files.
+ """
+
+ # Dots in bucket names should be avoided because bucket names are used in HTTPS bucket
+ # URLs where they may interfere with the certificate common name. We use a double
+ # hyphen as a separator instead.
+ #
+ bucketNameRe = re.compile(r'^[a-z0-9][a-z0-9-]+[a-z0-9]$')
+
+ # See http://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html
+ #
+ minBucketNameLen = 3
+ maxBucketNameLen = 63
+ maxNameLen = 10
+ nameSeparator = '--'
+
+ def __init__(self, locator, partSize=50 << 20):
+ """
+ Create a new job store in AWS or load an existing one from there.
+
+ :param int partSize: The size of each individual part used for multipart operations like
+ upload and copy, must be >= 5 MiB but large enough to not exceed 10k parts for the
+ whole file
+ """
+ super(AWSJobStore, self).__init__()
+ region, namePrefix = locator.split(':')
+ if not self.bucketNameRe.match(namePrefix):
+ raise ValueError("Invalid name prefix '%s'. Name prefixes must contain only digits, "
+ "hyphens or lower-case letters and must not start or end in a "
+ "hyphen." % namePrefix)
+ # Reserve 13 for separator and suffix
+ if len(namePrefix) > self.maxBucketNameLen - self.maxNameLen - len(self.nameSeparator):
+ raise ValueError("Invalid name prefix '%s'. Name prefixes may not be longer than 50 "
+ "characters." % namePrefix)
+ if '--' in namePrefix:
+ raise ValueError("Invalid name prefix '%s'. Name prefixes may not contain "
+ "%s." % (namePrefix, self.nameSeparator))
+ log.debug("Instantiating %s for region %s and name prefix '%s'",
+ self.__class__, region, namePrefix)
+ self.locator = locator
+ self.region = region
+ self.namePrefix = namePrefix
+ self.partSize = partSize
+ self.jobsDomain = None
+ self.filesDomain = None
+ self.filesBucket = None
+ self.db = self._connectSimpleDB()
+ self.s3 = self._connectS3()
+
+ def initialize(self, config):
+ if self._registered:
+ raise JobStoreExistsException(self.locator)
+ self._registered = None
+ try:
+ self._bind(create=True)
+ except:
+ with panic(log):
+ self.destroy()
+ else:
+ super(AWSJobStore, self).initialize(config)
+ # Only register after the job store has been fully initialized
+ self._registered = True
+
+ @property
+ def sseKeyPath(self):
+ return self.config.sseKey
+
+ def resume(self):
+ if not self._registered:
+ raise NoSuchJobStoreException(self.locator)
+ self._bind(create=False)
+ super(AWSJobStore, self).resume()
+
+ def _bind(self, create=False, block=True):
+ def qualify(name):
+ assert len(name) <= self.maxNameLen
+ return self.namePrefix + self.nameSeparator + name
+
+ # The order in which this sequence of events happens is important. We can easily handle the
+ # inability to bind a domain, but it is a little harder to handle some cases of binding the
+ # jobstore bucket. Maintaining this order allows for an easier `destroy` method.
+ if self.jobsDomain is None:
+ self.jobsDomain = self._bindDomain(qualify('jobs'), create=create, block=block)
+ if self.filesDomain is None:
+ self.filesDomain = self._bindDomain(qualify('files'), create=create, block=block)
+ if self.filesBucket is None:
+ self.filesBucket = self._bindBucket(qualify('files'),
+ create=create,
+ block=block,
+ versioning=True)
+
+ @property
+ def _registered(self):
+ """
+ An optional boolean property indicating whether this job store is registered. The
+ registry is the authority on deciding if a job store exists or not. If True, this job
+ store exists, if None the job store is transitioning from True to False or vice versa,
+ if False the job store doesn't exist.
+
+ :type: bool|None
+ """
+ # The weird mapping of the SDB item attribute value to the property value is due to
+ # backwards compatibility. 'True' becomes True, that's easy. Toil < 3.3.0 writes this at
+ # the end of job store creation. Absence of either the registry, the item or the
+ # attribute becomes False, representing a truly absent, non-existing job store. An
+ # attribute value of 'False', which is what Toil < 3.3.0 writes at the *beginning* of job
+ # store destruction, indicates a job store in transition, reflecting the fact that 3.3.0
+ # may leak buckets or domains even though the registry reports 'False' for them. We
+ # can't handle job stores that were partially created by 3.3.0, though.
+ registry_domain = self._bindDomain(domain_name='toil-registry',
+ create=False,
+ block=False)
+ if registry_domain is None:
+ return False
+ else:
+ for attempt in retry_sdb():
+ with attempt:
+ attributes = registry_domain.get_attributes(item_name=self.namePrefix,
+ attribute_name='exists',
+ consistent_read=True)
+ try:
+ exists = attributes['exists']
+ except KeyError:
+ return False
+ else:
+ if exists == 'True':
+ return True
+ elif exists == 'False':
+ return None
+ else:
+ assert False
+
+ @_registered.setter
+ def _registered(self, value):
+
+ registry_domain = self._bindDomain(domain_name='toil-registry',
+ # Only create registry domain when registering or
+ # transitioning a store
+ create=value is not False,
+ block=False)
+ if registry_domain is None and value is False:
+ pass
+ else:
+ for attempt in retry_sdb():
+ with attempt:
+ if value is False:
+ registry_domain.delete_attributes(item_name=self.namePrefix)
+ else:
+ if value is True:
+ attributes = dict(exists='True')
+ elif value is None:
+ attributes = dict(exists='False')
+ else:
+ assert False
+ registry_domain.put_attributes(item_name=self.namePrefix,
+ attributes=attributes)
+
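The backwards-compatibility mapping spelled out in the _registered comment above translates one SDB attribute into a three-valued property. A toy restatement; the attributes dict stands in for the item read from the 'toil-registry' domain:

# Toy restatement of the registry tri-state described above.
def registered_state(attributes):
    exists = attributes.get('exists')    # None when the attribute is absent
    if exists is None:
        return False        # no record: the job store does not exist
    if exists == 'True':
        return True         # fully registered job store
    if exists == 'False':
        return None         # in transition (being created or destroyed)
    raise AssertionError(exists)

assert registered_state({}) is False
assert registered_state({'exists': 'True'}) is True
assert registered_state({'exists': 'False'}) is None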
+ def create(self, jobNode):
+ jobStoreID = self._newJobID()
+ log.debug("Creating job %s for '%s'",
+ jobStoreID, '<no command>' if jobNode.command is None else jobNode.command)
+ job = AWSJob.fromJobNode(jobNode, jobStoreID=jobStoreID, tryCount=self._defaultTryCount())
+ for attempt in retry_sdb():
+ with attempt:
+ assert self.jobsDomain.put_attributes(*job.toItem())
+ return job
+
+ def exists(self, jobStoreID):
+ for attempt in retry_sdb():
+ with attempt:
+ return bool(self.jobsDomain.get_attributes(
+ item_name=jobStoreID,
+ attribute_name=[AWSJob.presenceIndicator()],
+ consistent_read=True))
+
+ def jobs(self):
+ result = None
+ for attempt in retry_sdb():
+ with attempt:
+ result = list(self.jobsDomain.select(
+ consistent_read=True,
+ query="select * from `%s`" % self.jobsDomain.name))
+ assert result is not None
+ for jobItem in result:
+ yield AWSJob.fromItem(jobItem)
+
+ def load(self, jobStoreID):
+ item = None
+ for attempt in retry_sdb():
+ with attempt:
+ item = self.jobsDomain.get_attributes(jobStoreID, consistent_read=True)
+ if not item:
+ raise NoSuchJobException(jobStoreID)
+ job = AWSJob.fromItem(item)
+ if job is None:
+ raise NoSuchJobException(jobStoreID)
+ log.debug("Loaded job %s", jobStoreID)
+ return job
+
+ def update(self, job):
+ log.debug("Updating job %s", job.jobStoreID)
+ for attempt in retry_sdb():
+ with attempt:
+ assert self.jobsDomain.put_attributes(*job.toItem())
+
+ itemsPerBatchDelete = 25
+
+ def delete(self, jobStoreID):
+ # Remove the job itself, then any files it owns.
+ log.debug("Deleting job %s", jobStoreID)
+ for attempt in retry_sdb():
+ with attempt:
+ self.jobsDomain.delete_attributes(item_name=jobStoreID)
+ items = None
+ for attempt in retry_sdb():
+ with attempt:
+ items = list(self.filesDomain.select(
+ consistent_read=True,
+ query="select version from `%s` where ownerID='%s'" % (
+ self.filesDomain.name, jobStoreID)))
+ assert items is not None
+ if items:
+ log.debug("Deleting %d file(s) associated with job %s", len(items), jobStoreID)
+ n = self.itemsPerBatchDelete
+ batches = [items[i:i + n] for i in range(0, len(items), n)]
+ for batch in batches:
+ itemsDict = {item.name: None for item in batch}
+ for attempt in retry_sdb():
+ with attempt:
+ self.filesDomain.batch_delete_attributes(itemsDict)
+ for item in items:
+ version = item.get('version')
+ for attempt in retry_s3():
+ with attempt:
+ if version:
+ self.filesBucket.delete_key(key_name=item.name, version_id=version)
+ else:
+ self.filesBucket.delete_key(key_name=item.name)
+
+ def getEmptyFileStoreID(self, jobStoreID=None):
+ info = self.FileInfo.create(jobStoreID)
+ info.save()
+ log.debug("Created %r.", info)
+ return info.fileID
+
+ def _importFile(self, otherCls, url, sharedFileName=None):
+ if issubclass(otherCls, AWSJobStore):
+ srcKey = self._getKeyForUrl(url, existing=True)
+ try:
+ if sharedFileName is None:
+ info = self.FileInfo.create(srcKey.name)
+ else:
+ self._requireValidSharedFileName(sharedFileName)
+ jobStoreFileID = self._sharedFileID(sharedFileName)
+ info = self.FileInfo.loadOrCreate(jobStoreFileID=jobStoreFileID,
+ ownerID=str(self.sharedFileOwnerID),
+ encrypted=None)
+ info.copyFrom(srcKey)
+ info.save()
+ finally:
+ srcKey.bucket.connection.close()
+ return info.fileID if sharedFileName is None else None
+ else:
+ return super(AWSJobStore, self)._importFile(otherCls, url,
+ sharedFileName=sharedFileName)
+
+ def _exportFile(self, otherCls, jobStoreFileID, url):
+ if issubclass(otherCls, AWSJobStore):
+ dstKey = self._getKeyForUrl(url)
+ try:
+ info = self.FileInfo.loadOrFail(jobStoreFileID)
+ info.copyTo(dstKey)
+ finally:
+ dstKey.bucket.connection.close()
+ else:
+ super(AWSJobStore, self)._exportFile(otherCls, jobStoreFileID, url)
+
+ @classmethod
+ def _readFromUrl(cls, url, writable):
+ srcKey = cls._getKeyForUrl(url, existing=True)
+ try:
+ srcKey.get_contents_to_file(writable)
+ finally:
+ srcKey.bucket.connection.close()
+
+ @classmethod
+ def _writeToUrl(cls, readable, url):
+ dstKey = cls._getKeyForUrl(url)
+ try:
+ dstKey.set_contents_from_string(readable.read())
+ finally:
+ dstKey.bucket.connection.close()
+
+ @staticmethod
+ def _getKeyForUrl(url, existing=None):
+ """
+ Extracts a key from a given s3:// URL. On return, but not on exceptions, this method
+ leaks an S3Connection object. The caller is responsible for closing it by calling
+ key.bucket.connection.close().
+
+ :param bool existing: If True, key is expected to exist. If False, key is expected not to
+ exist and it will be created. If None, the key will be created if it doesn't exist.
+
+ :rtype: Key
+ """
+ # Get the bucket's region to avoid a redirect per request
+ try:
+ with closing(boto.connect_s3()) as s3:
+ location = s3.get_bucket(url.netloc).get_location()
+ region = bucket_location_to_region(location)
+ except S3ResponseError as e:
+ if e.error_code == 'AccessDenied':
+ log.warn("Could not determine location of bucket hosting URL '%s', reverting "
+ "to generic S3 endpoint.", url.geturl())
+ s3 = boto.connect_s3()
+ else:
+ raise
+ else:
+ # Note that caller is responsible for closing the connection
+ s3 = boto.s3.connect_to_region(region)
+
+ try:
+ keyName = url.path[1:]
+ bucketName = url.netloc
+ bucket = s3.get_bucket(bucketName)
+ key = bucket.get_key(keyName)
+ if existing is True:
+ if key is None:
+ raise RuntimeError("Key '%s' does not exist in bucket '%s'." %
+ (keyName, bucketName))
+ elif existing is False:
+ if key is not None:
+ raise RuntimeError("Key '%s' exists in bucket '%s'." %
+ (keyName, bucketName))
+ elif existing is None:
+ pass
+ else:
+ assert False
+ if key is None:
+ key = bucket.new_key(keyName)
+ except:
+ with panic():
+ s3.close()
+ else:
+ return key
+
+ @classmethod
+ def _supportsUrl(cls, url, export=False):
+ return url.scheme.lower() == 's3'
+
+ def writeFile(self, localFilePath, jobStoreID=None):
+ info = self.FileInfo.create(jobStoreID)
+ info.upload(localFilePath)
+ info.save()
+ log.debug("Wrote %r of from %r", info, localFilePath)
+ return info.fileID
+
+ @contextmanager
+ def writeFileStream(self, jobStoreID=None):
+ info = self.FileInfo.create(jobStoreID)
+ with info.uploadStream() as writable:
+ yield writable, info.fileID
+ info.save()
+ log.debug("Wrote %r.", info)
+
+ @contextmanager
+ def writeSharedFileStream(self, sharedFileName, isProtected=None):
+ assert self._validateSharedFileName(sharedFileName)
+ info = self.FileInfo.loadOrCreate(jobStoreFileID=self._sharedFileID(sharedFileName),
+ ownerID=str(self.sharedFileOwnerID),
+ encrypted=isProtected)
+ with info.uploadStream() as writable:
+ yield writable
+ info.save()
+ log.debug("Wrote %r for shared file %r.", info, sharedFileName)
+
+ def updateFile(self, jobStoreFileID, localFilePath):
+ info = self.FileInfo.loadOrFail(jobStoreFileID)
+ info.upload(localFilePath)
+ info.save()
+ log.debug("Wrote %r from path %r.", info, localFilePath)
+
+ @contextmanager
+ def updateFileStream(self, jobStoreFileID):
+ info = self.FileInfo.loadOrFail(jobStoreFileID)
+ with info.uploadStream() as writable:
+ yield writable
+ info.save()
+ log.debug("Wrote %r from stream.", info)
+
+ def fileExists(self, jobStoreFileID):
+ return self.FileInfo.exists(jobStoreFileID)
+
+ def readFile(self, jobStoreFileID, localFilePath):
+ info = self.FileInfo.loadOrFail(jobStoreFileID)
+ log.debug("Reading %r into %r.", info, localFilePath)
+ info.download(localFilePath)
+
+ @contextmanager
+ def readFileStream(self, jobStoreFileID):
+ info = self.FileInfo.loadOrFail(jobStoreFileID)
+ log.debug("Reading %r into stream.", info)
+ with info.downloadStream() as readable:
+ yield readable
+
+ @contextmanager
+ def readSharedFileStream(self, sharedFileName):
+ assert self._validateSharedFileName(sharedFileName)
+ jobStoreFileID = self._sharedFileID(sharedFileName)
+ info = self.FileInfo.loadOrFail(jobStoreFileID, customName=sharedFileName)
+ log.debug("Reading %r for shared file %r into stream.", info, sharedFileName)
+ with info.downloadStream() as readable:
+ yield readable
+
+ def deleteFile(self, jobStoreFileID):
+ info = self.FileInfo.load(jobStoreFileID)
+ if info is None:
+ log.debug("File %s does not exist, skipping deletion.", jobStoreFileID)
+ else:
+ info.delete()
+
+ def writeStatsAndLogging(self, statsAndLoggingString):
+ info = self.FileInfo.create(str(self.statsFileOwnerID))
+ with info.uploadStream(multipart=False) as writeable:
+ writeable.write(statsAndLoggingString)
+ info.save()
+
+ def readStatsAndLogging(self, callback, readAll=False):
+ itemsProcessed = 0
+
+ for info in self._readStatsAndLogging(callback, self.statsFileOwnerID):
+ info._ownerID = self.readStatsFileOwnerID
+ info.save()
+ itemsProcessed += 1
+
+ if readAll:
+ for _ in self._readStatsAndLogging(callback, self.readStatsFileOwnerID):
+ itemsProcessed += 1
+
+ return itemsProcessed
+
+ def _readStatsAndLogging(self, callback, ownerId):
+ items = None
+ for attempt in retry_sdb():
+ with attempt:
+ items = list(self.filesDomain.select(
+ consistent_read=True,
+ query="select * from `%s` where ownerID='%s'" % (
+ self.filesDomain.name, str(ownerId))))
+ assert items is not None
+ for item in items:
+ info = self.FileInfo.fromItem(item)
+ with info.downloadStream() as readable:
+ callback(readable)
+ yield info
+
+ def getPublicUrl(self, jobStoreFileID):
+ info = self.FileInfo.loadOrFail(jobStoreFileID)
+ if info.content is not None:
+ with info.uploadStream(allowInlining=False) as f:
+ f.write(info.content)
+ for attempt in retry_s3():
+ with attempt:
+ key = self.filesBucket.get_key(key_name=jobStoreFileID, version_id=info.version)
+ return key.generate_url(expires_in=self.publicUrlExpiration.total_seconds())
+
+ def getSharedPublicUrl(self, sharedFileName):
+ assert self._validateSharedFileName(sharedFileName)
+ return self.getPublicUrl(self._sharedFileID(sharedFileName))
+
+ def _connectSimpleDB(self):
+ """
+ :rtype: SDBConnection
+ """
+ db = boto.sdb.connect_to_region(self.region)
+ if db is None:
+ raise ValueError("Could not connect to SimpleDB. Make sure '%s' is a valid SimpleDB "
+ "region." % self.region)
+ assert db is not None
+ monkeyPatchSdbConnection(db)
+ return db
+
+ def _connectS3(self):
+ """
+ :rtype: S3Connection
+ """
+ s3 = boto.s3.connect_to_region(self.region)
+ if s3 is None:
+ raise ValueError("Could not connect to S3. Make sure '%s' is a valid S3 region." %
+ self.region)
+ return s3
+
+ def _bindBucket(self, bucket_name, create=False, block=True, versioning=False):
+ """
+ Return the Boto Bucket object representing the S3 bucket with the given name. If the
+ bucket does not exist and `create` is True, it will be created.
+
+ :param str bucket_name: the name of the bucket to bind to
+
+ :param bool create: Whether to create the bucket if it doesn't exist
+
+ :param bool block: If False, return None if the bucket doesn't exist. If True, wait until
+ bucket appears. Ignored if `create` is True.
+
+ :rtype: Bucket|None
+ :raises S3ResponseError: If `block` is True and the bucket still doesn't exist after the
+ retry timeout expires.
+ """
+ assert self.minBucketNameLen <= len(bucket_name) <= self.maxBucketNameLen
+ assert self.bucketNameRe.match(bucket_name)
+ log.debug("Binding to job store bucket '%s'.", bucket_name)
+
+ def bucket_creation_pending(e):
+ # https://github.com/BD2KGenomics/toil/issues/955
+ # https://github.com/BD2KGenomics/toil/issues/995
+ # https://github.com/BD2KGenomics/toil/issues/1093
+ return (isinstance(e, (S3CreateError, S3ResponseError))
+ and e.error_code in ('BucketAlreadyOwnedByYou', 'OperationAborted'))
+
+ bucketExisted = True
+ for attempt in retry_s3(predicate=bucket_creation_pending):
+ with attempt:
+ try:
+ bucket = self.s3.get_bucket(bucket_name, validate=True)
+ except S3ResponseError as e:
+ if e.error_code == 'NoSuchBucket':
+ bucketExisted = False
+ log.debug("Bucket '%s' does not exist.", bucket_name)
+ if create:
+ log.debug("Creating bucket '%s'.", bucket_name)
+ location = region_to_bucket_location(self.region)
+ bucket = self.s3.create_bucket(bucket_name, location=location)
+ assert self.__getBucketRegion(bucket) == self.region
+ elif block:
+ raise
+ else:
+ return None
+ elif e.status == 301:
+ # This is raised if the user attempts to access a bucket that lives in a region
+ # other than the specified one, unless the specified region is `us-east-1`. The
+ # us-east-1 endpoint allows a user to access buckets from any region.
+ bucket = self.s3.get_bucket(bucket_name, validate=False)
+ raise BucketLocationConflictException(self.__getBucketRegion(bucket))
+ else:
+ raise
+ else:
+ if self.__getBucketRegion(bucket) != self.region:
+ raise BucketLocationConflictException(self.__getBucketRegion(bucket))
+ if versioning:
+ bucket.configure_versioning(True)
+ else:
+ bucket_versioning = self.__getBucketVersioning(bucket)
+ if bucket_versioning is True:
+ assert False, 'Cannot disable bucket versioning if it is already enabled'
+ elif bucket_versioning is None:
+ assert False, 'Cannot use a bucket with versioning suspended'
+ if bucketExisted:
+ log.debug("Using pre-existing job store bucket '%s'.", bucket_name)
+ else:
+ log.debug("Created new job store bucket '%s'.", bucket_name)
+
+ return bucket
+
+ def _bindDomain(self, domain_name, create=False, block=True):
+ """
+ Return the Boto Domain object representing the SDB domain of the given name. If the
+ domain does not exist and `create` is True, it will be created.
+
+ :param str domain_name: the name of the domain to bind to
+
+ :param bool create: True if domain should be created if it doesn't exist
+
+ :param bool block: If False, return None if the domain doesn't exist. If True, wait until
+ domain appears. This parameter is ignored if create is True.
+
+ :rtype: Domain|None
+ :raises SDBResponseError: If `block` is True and the domain still doesn't exist after the
+ retry timeout expires.
+ """
+ log.debug("Binding to job store domain '%s'.", domain_name)
+ for attempt in retry_sdb(predicate=lambda e: no_such_sdb_domain(e) or sdb_unavailable(e)):
+ with attempt:
+ try:
+ return self.db.get_domain(domain_name)
+ except SDBResponseError as e:
+ if no_such_sdb_domain(e):
+ if create:
+ return self.db.create_domain(domain_name)
+ elif block:
+ raise
+ else:
+ return None
+ else:
+ raise
+
+ def _newJobID(self):
+ return str(uuid.uuid4())
+
+ # A dummy job ID under which all shared files are stored
+ sharedFileOwnerID = uuid.UUID('891f7db6-e4d9-4221-a58e-ab6cc4395f94')
+
+ # A dummy job ID under which all unread stats files are stored
+ statsFileOwnerID = uuid.UUID('bfcf5286-4bc7-41ef-a85d-9ab415b69d53')
+
+ # A dummy job ID under which all read stats files are stored
+ readStatsFileOwnerID = uuid.UUID('e77fc3aa-d232-4255-ae04-f64ee8eb0bfa')
+
+ def _sharedFileID(self, sharedFileName):
+ return str(uuid.uuid5(self.sharedFileOwnerID, str(sharedFileName)))
+
+ @InnerClass
+ class FileInfo(SDBHelper):
+ """
+ Represents a file in this job store.
+ """
+ outer = None
+ """
+ :type: AWSJobStore
+ """
+
+ def __init__(self, fileID, ownerID, encrypted,
+ version=None, content=None, numContentChunks=0):
+ """
+ :type fileID: str
+ :param fileID: the file's ID
+
+ :type ownerID: str
+ :param ownerID: ID of the entity owning this file, typically a job ID aka jobStoreID
+
+ :type encrypted: bool
+ :param encrypted: whether the file is stored in encrypted form
+
+ :type version: str|None
+ :param version: a non-empty string containing the most recent version of the S3
+ object storing this file's content, None if the file is new, or empty string if the
+ file is inlined.
+
+ :type content: str|None
+ :param content: this file's inlined content
+
+ :type numContentChunks: int
+ :param numContentChunks: the number of SDB domain attributes occupied by this file's
+ inlined content. Note that an inlined empty string still occupies one chunk.
+ """
+ super(AWSJobStore.FileInfo, self).__init__()
+ self._fileID = fileID
+ self._ownerID = ownerID
+ self.encrypted = encrypted
+ self._version = version
+ self._previousVersion = version
+ self._content = content
+ self._numContentChunks = numContentChunks
+
+ @property
+ def fileID(self):
+ return self._fileID
+
+ @property
+ def ownerID(self):
+ return self._ownerID
+
+ @property
+ def version(self):
+ return self._version
+
+ @version.setter
+ def version(self, version):
+ # Version should only change once
+ assert self._previousVersion == self._version
+ self._version = version
+ if version:
+ self.content = None
+
+ @property
+ def previousVersion(self):
+ return self._previousVersion
+
+ @property
+ def content(self):
+ return self._content
+
+ @content.setter
+ def content(self, content):
+ self._content = content
+ if content is not None:
+ self.version = ''
+
+ @classmethod
+ def create(cls, ownerID):
+ return cls(str(uuid.uuid4()), ownerID, encrypted=cls.outer.sseKeyPath is not None)
+
+ @classmethod
+ def presenceIndicator(cls):
+ return 'encrypted'
+
+ @classmethod
+ def exists(cls, jobStoreFileID):
+ for attempt in retry_sdb():
+ with attempt:
+ return bool(cls.outer.filesDomain.get_attributes(
+ item_name=jobStoreFileID,
+ attribute_name=[cls.presenceIndicator()],
+ consistent_read=True))
+
+ @classmethod
+ def load(cls, jobStoreFileID):
+ for attempt in retry_sdb():
+ with attempt:
+ self = cls.fromItem(
+ cls.outer.filesDomain.get_attributes(item_name=jobStoreFileID,
+ consistent_read=True))
+ return self
+
+ @classmethod
+ def loadOrCreate(cls, jobStoreFileID, ownerID, encrypted):
+ self = cls.load(jobStoreFileID)
+ if encrypted is None:
+ encrypted = cls.outer.sseKeyPath is not None
+ if self is None:
+ self = cls(jobStoreFileID, ownerID, encrypted=encrypted)
+ else:
+ assert self.fileID == jobStoreFileID
+ assert self.ownerID == ownerID
+ self.encrypted = encrypted
+ return self
+
+ @classmethod
+ def loadOrFail(cls, jobStoreFileID, customName=None):
+ """
+ :rtype: AWSJobStore.FileInfo
+ :return: an instance of this class representing the file with the given ID
+ :raises NoSuchFileException: if given file does not exist
+ """
+ self = cls.load(jobStoreFileID)
+ if self is None:
+ raise NoSuchFileException(jobStoreFileID, customName=customName)
+ else:
+ return self
+
+ @classmethod
+ def fromItem(cls, item):
+ """
+ Convert an SDB item to an instance of this class.
+
+ :type item: Item
+ """
+ assert item is not None
+
+ # Strings come back from SDB as unicode
+ def strOrNone(s):
+ return s if s is None else str(s)
+
+ # ownerID and encrypted are the only mandatory attributes
+ ownerID = strOrNone(item.get('ownerID'))
+ encrypted = item.get('encrypted')
+ if ownerID is None:
+ assert encrypted is None
+ return None
+ else:
+ version = strOrNone(item['version'])
+ encrypted = strict_bool(encrypted)
+ content, numContentChunks = cls.attributesToBinary(item)
+ if encrypted:
+ sseKeyPath = cls.outer.sseKeyPath
+ if sseKeyPath is None:
+ raise AssertionError('Content is encrypted but no key was provided.')
+ if content is not None:
+ content = encryption.decrypt(content, sseKeyPath)
+ self = cls(fileID=item.name, ownerID=ownerID, encrypted=encrypted, version=version,
+ content=content, numContentChunks=numContentChunks)
+ return self
+
+ def toItem(self):
+ """
+ Convert this instance to an attribute dictionary suitable for SDB put_attributes().
+
+ :rtype: (dict,int)
+
+ :return: the attributes dict and an integer specifying the number of chunk
+ attributes in the dictionary that are used for storing inlined content.
+ """
+ if self.content is None:
+ numChunks = 0
+ attributes = {}
+ else:
+ content = self.content
+ if self.encrypted:
+ sseKeyPath = self.outer.sseKeyPath
+ if sseKeyPath is None:
+ raise AssertionError('Encryption requested but no key was provided.')
+ content = encryption.encrypt(content, sseKeyPath)
+ attributes = self.binaryToAttributes(content)
+ numChunks = len(attributes)
+ attributes.update(dict(ownerID=self.ownerID,
+ encrypted=self.encrypted,
+ version=self.version or ''))
+ return attributes, numChunks
+
+ @classmethod
+ def _reservedAttributes(cls):
+ return 3
+
+ @classmethod
+ def maxInlinedSize(cls, encrypted):
+ return cls.maxBinarySize() - (encryption.overhead if encrypted else 0)
+
+ def _maxInlinedSize(self):
+ return self.maxInlinedSize(self.encrypted)
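+ # A rough worked example derived from the constants above (a sketch, not a hard
+ # guarantee): SDBHelper allows 256 attributes per item, _reservedAttributes() keeps
+ # 3 of them for ownerID, encrypted and version, and each remaining chunk holds
+ # 1024 * 3 / 4 = 768 raw bytes, so maxBinarySize() is (256 - 3) * 768 - 1 = 194303
+ # bytes, roughly 190 KiB. Files at or below that size (minus encryption.overhead
+ # when encrypted) are inlined into SDB attributes; larger files are uploaded to S3
+ # by upload() and uploadStream().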
+
+ def save(self):
+ attributes, numNewContentChunks = self.toItem()
+ # False stands for absence
+ expected = ['version', False if self.previousVersion is None else self.previousVersion]
+ try:
+ for attempt in retry_sdb():
+ with attempt:
+ assert self.outer.filesDomain.put_attributes(item_name=self.fileID,
+ attributes=attributes,
+ expected_value=expected)
+ # clean up the old version of the file if necessary and safe
+ if self.previousVersion and (self.previousVersion != self.version):
+ for attempt in retry_s3():
+ with attempt:
+ self.outer.filesBucket.delete_key(self.fileID,
+ version_id=self.previousVersion)
+ self._previousVersion = self._version
+ if numNewContentChunks < self._numContentChunks:
+ residualChunks = xrange(numNewContentChunks, self._numContentChunks)
+ attributes = [self._chunkName(i) for i in residualChunks]
+ for attempt in retry_sdb():
+ with attempt:
+ self.outer.filesDomain.delete_attributes(self.fileID,
+ attributes=attributes)
+ self._numContentChunks = numNewContentChunks
+ except SDBResponseError as e:
+ if e.error_code == 'ConditionalCheckFailed':
+ raise ConcurrentFileModificationException(self.fileID)
+ else:
+ raise
+
+ def upload(self, localFilePath):
+ file_size, file_time = self._fileSizeAndTime(localFilePath)
+ if file_size <= self._maxInlinedSize():
+ with open(localFilePath) as f:
+ self.content = f.read()
+ else:
+ headers = self._s3EncryptionHeaders()
+ if file_size <= self.outer.partSize:
+ key = self.outer.filesBucket.new_key(key_name=self.fileID)
+ key.name = self.fileID
+ for attempt in retry_s3():
+ with attempt:
+ key.set_contents_from_filename(localFilePath, headers=headers)
+ self.version = key.version_id
+ else:
+ with open(localFilePath, 'rb') as f:
+ for attempt in retry_s3():
+ with attempt:
+ upload = self.outer.filesBucket.initiate_multipart_upload(
+ key_name=self.fileID,
+ headers=headers)
+ try:
+ start = 0
+ part_num = itertools.count()
+ while start < file_size:
+ end = min(start + self.outer.partSize, file_size)
+ assert f.tell() == start
+ for attempt in retry_s3():
+ with attempt:
+ upload.upload_part_from_file(fp=f,
+ part_num=next(part_num) + 1,
+ size=end - start,
+ headers=headers)
+ start = end
+ assert f.tell() == file_size == start
+ except:
+ with panic(log=log):
+ for attempt in retry_s3():
+ with attempt:
+ upload.cancel_upload()
+ else:
+ for attempt in retry_s3():
+ with attempt:
+ self.version = upload.complete_upload().version_id
+ for attempt in retry_s3():
+ with attempt:
+ key = self.outer.filesBucket.get_key(self.fileID,
+ headers=headers,
+ version_id=self.version)
+ assert key.size == file_size
+ # Make reasonably sure that the file wasn't touched during the upload
+ assert self._fileSizeAndTime(localFilePath) == (file_size, file_time)
+
+ @contextmanager
+ def uploadStream(self, multipart=True, allowInlining=True):
+ info = self
+ store = self.outer
+
+ class MultiPartPipe(WritablePipe):
+ def readFrom(self, readable):
+ buf = readable.read(store.partSize)
+ if allowInlining and len(buf) <= info._maxInlinedSize():
+ info.content = buf
+ else:
+ headers = info._s3EncryptionHeaders()
+ for attempt in retry_s3():
+ with attempt:
+ upload = store.filesBucket.initiate_multipart_upload(
+ key_name=info.fileID,
+ headers=headers)
+ try:
+ for part_num in itertools.count():
+ # There must be at least one part, even if the file is empty.
+ if len(buf) == 0 and part_num > 0:
+ break
+ for attempt in retry_s3():
+ with attempt:
+ upload.upload_part_from_file(fp=StringIO(buf),
+ # part numbers are 1-based
+ part_num=part_num + 1,
+ headers=headers)
+ if len(buf) == 0:
+ break
+ buf = readable.read(info.outer.partSize)
+ except:
+ with panic(log=log):
+ for attempt in retry_s3():
+ with attempt:
+ upload.cancel_upload()
+ else:
+ for attempt in retry_s3():
+ with attempt:
+ info.version = upload.complete_upload().version_id
+
+ class SinglePartPipe(WritablePipe):
+ def readFrom(self, readable):
+ buf = readable.read()
+ if allowInlining and len(buf) <= info._maxInlinedSize():
+ info.content = buf
+ else:
+ key = store.filesBucket.new_key(key_name=info.fileID)
+ buf = StringIO(buf)
+ headers = info._s3EncryptionHeaders()
+ for attempt in retry_s3():
+ with attempt:
+ assert buf.len == key.set_contents_from_file(fp=buf,
+ headers=headers)
+ info.version = key.version_id
+
+ with MultiPartPipe() if multipart else SinglePartPipe() as writable:
+ yield writable
+
+ assert bool(self.version) == (self.content is None)
+
+ def copyFrom(self, srcKey):
+ """
+ Copies contents of source key into this file.
+
+ :param srcKey: The key that will be copied from
+ """
+ assert srcKey.size is not None
+ if srcKey.size <= self._maxInlinedSize():
+ self.content = srcKey.get_contents_as_string()
+ else:
+ self.version = self._copyKey(srcKey=srcKey,
+ dstBucketName=self.outer.filesBucket.name,
+ dstKeyName=self._fileID,
+ headers=self._s3EncryptionHeaders()).version_id
+
+ def copyTo(self, dstKey):
+ """
+ Copies contents of this file to the given key.
+
+ :param Key dstKey: The key to copy this file's content to
+ """
+ if self.content is not None:
+ for attempt in retry_s3():
+ with attempt:
+ dstKey.set_contents_from_string(self.content)
+ elif self.version:
+ for attempt in retry_s3():
+ srcKey = self.outer.filesBucket.get_key(self.fileID,
+ validate=False)
+ srcKey.version_id = self.version
+ with attempt:
+ headers = {k.replace('amz-', 'amz-copy-source-', 1): v
+ for k, v in iteritems(self._s3EncryptionHeaders())}
+ self._copyKey(srcKey=srcKey,
+ dstBucketName=dstKey.bucket.name,
+ dstKeyName=dstKey.name,
+ headers=headers)
+ else:
+ assert False
+
+ def _copyKey(self, srcKey, dstBucketName, dstKeyName, headers=None):
+ headers = headers or {}
+ if srcKey.size > self.outer.partSize:
+ return copyKeyMultipart(srcKey=srcKey,
+ dstBucketName=dstBucketName,
+ dstKeyName=dstKeyName,
+ partSize=self.outer.partSize,
+ headers=headers)
+ else:
+ # We need a location-agnostic connection to S3 so we can't use the one that we
+ # normally use for interacting with the job store bucket.
+ with closing(boto.connect_s3()) as s3:
+ for attempt in retry_s3():
+ with attempt:
+ dstBucket = s3.get_bucket(dstBucketName)
+ return dstBucket.copy_key(new_key_name=dstKeyName,
+ src_bucket_name=srcKey.bucket.name,
+ src_version_id=srcKey.version_id,
+ src_key_name=srcKey.name,
+ metadata=srcKey.metadata,
+ headers=headers)
+
+ def download(self, localFilePath):
+ if self.content is not None:
+ with open(localFilePath, 'w') as f:
+ f.write(self.content)
+ elif self.version:
+ headers = self._s3EncryptionHeaders()
+ key = self.outer.filesBucket.get_key(self.fileID, validate=False)
+ for attempt in retry_s3():
+ with attempt:
+ key.get_contents_to_filename(localFilePath,
+ version_id=self.version,
+ headers=headers)
+ else:
+ assert False
+
+ @contextmanager
+ def downloadStream(self):
+ info = self
+
+ class DownloadPipe(ReadablePipe):
+ def writeTo(self, writable):
+ if info.content is not None:
+ writable.write(info.content)
+ elif info.version:
+ headers = info._s3EncryptionHeaders()
+ key = info.outer.filesBucket.get_key(info.fileID, validate=False)
+ for attempt in retry_s3():
+ with attempt:
+ key.get_contents_to_file(writable,
+ headers=headers,
+ version_id=info.version)
+ else:
+ assert False
+
+ with DownloadPipe() as readable:
+ yield readable
+
+ def delete(self):
+ store = self.outer
+ if self.previousVersion is not None:
+ for attempt in retry_sdb():
+ with attempt:
+ store.filesDomain.delete_attributes(
+ self.fileID,
+ expected_values=['version', self.previousVersion])
+ if self.previousVersion:
+ for attempt in retry_s3():
+ with attempt:
+ store.filesBucket.delete_key(key_name=self.fileID,
+ version_id=self.previousVersion)
+
+ def _s3EncryptionHeaders(self):
+ sseKeyPath = self.outer.sseKeyPath
+ if self.encrypted:
+ if sseKeyPath is None:
+ raise AssertionError('Content is encrypted but no key was provided.')
+ else:
+ with open(sseKeyPath) as f:
+ sseKey = f.read()
+ assert len(sseKey) == 32
+ encodedSseKey = base64.b64encode(sseKey)
+ encodedSseKeyMd5 = base64.b64encode(hashlib.md5(sseKey).digest())
+ return {'x-amz-server-side-encryption-customer-algorithm': 'AES256',
+ 'x-amz-server-side-encryption-customer-key': encodedSseKey,
+ 'x-amz-server-side-encryption-customer-key-md5': encodedSseKeyMd5}
+ else:
+ return {}
+
+ def _fileSizeAndTime(self, localFilePath):
+ file_stat = os.stat(localFilePath)
+ file_size, file_time = file_stat.st_size, file_stat.st_mtime
+ return file_size, file_time
+
+ def __repr__(self):
+ r = custom_repr
+ d = (('fileID', r(self.fileID)),
+ ('ownerID', r(self.ownerID)),
+ ('encrypted', r(self.encrypted)),
+ ('version', r(self.version)),
+ ('previousVersion', r(self.previousVersion)),
+ ('content', r(self.content)),
+ ('_numContentChunks', r(self._numContentChunks)))
+ return "{}({})".format(type(self).__name__,
+ ', '.join('%s=%s' % (k, v) for k, v in d))
+
+ versionings = dict(Enabled=True, Disabled=False, Suspended=None)
+
+ def __getBucketVersioning(self, bucket):
+ """
+ For newly created buckets get_versioning_status returns an empty dict. In the past we've
+ seen None in this case. We map both to a return value of False.
+
+ Otherwise, the 'Versioning' entry in the dictionary returned by get_versioning_status can
+ be 'Enabled', 'Suspended' or 'Disabled' which we map to True, None and False
+ respectively. Note that we've never seen a versioning status of 'Disabled', only the
+ empty dictionary. Calling configure_versioning with False on a bucket will cause
+ get_versioning_status to then return 'Suspended' even on a new bucket that never had
+ versioning enabled.
+ """
+ for attempt in retry_s3():
+ with attempt:
+ status = bucket.get_versioning_status()
+ return self.versionings[status['Versioning']] if status else False
+
+ def __getBucketRegion(self, bucket):
+ for attempt in retry_s3():
+ with attempt:
+ return bucket_location_to_region(bucket.get_location())
+
+ def destroy(self):
+ # FIXME: Destruction of encrypted stores only works after initialize() or .resume()
+ # See https://github.com/BD2KGenomics/toil/issues/1041
+ try:
+ self._bind(create=False, block=False)
+ except BucketLocationConflictException:
+ # If the unique jobstore bucket name existed, _bind would have raised a
+ # BucketLocationConflictException before calling destroy. Calling _bind here again
+ # would reraise the same exception so we need to catch and ignore that exception.
+ pass
+ # TODO: Add other failure cases to be ignored here.
+ self._registered = None
+ if self.filesBucket is not None:
+ self._delete_bucket(self.filesBucket)
+ self.filesBucket = None
+ for name in 'filesDomain', 'jobsDomain':
+ domain = getattr(self, name)
+ if domain is not None:
+ self._delete_domain(domain)
+ setattr(self, name, None)
+ self._registered = False
+
+ def _delete_domain(self, domain):
+ for attempt in retry_sdb():
+ with attempt:
+ try:
+ domain.delete()
+ except SDBResponseError as e:
+ if no_such_sdb_domain(e):
+ pass
+ else:
+ raise
+
+ def _delete_bucket(self, bucket):
+ for attempt in retry_s3():
+ with attempt:
+ try:
+ for upload in bucket.list_multipart_uploads():
+ upload.cancel_upload()
+ for key in list(bucket.list_versions()):
+ bucket.delete_key(key.name, version_id=key.version_id)
+ bucket.delete()
+ except S3ResponseError as e:
+ if e.error_code == 'NoSuchBucket':
+ pass
+ else:
+ raise
+
+
+aRepr = reprlib.Repr()
+aRepr.maxstring = 38 # so UUIDs don't get truncated (36 for UUID plus 2 for quotes)
+custom_repr = aRepr.repr
+
+
+class AWSJob(JobGraph, SDBHelper):
+ """
+ A Job that can be converted to and from an SDB item.
+ """
+
+ @classmethod
+ def fromItem(cls, item):
+ """
+ :type item: Item
+ :rtype: AWSJob
+ """
+ binary, _ = cls.attributesToBinary(item)
+ assert binary is not None
+ return cPickle.loads(binary)
+
+ def toItem(self):
+ """
+ Due to a peculiarity of Boto's SDB bindings, this method does not return an Item,
+ but a tuple. The returned tuple can be used with put_attributes like so
+
+ domain.put_attributes( *toItem(...) )
+
+ :rtype: (str,dict)
+ :return: a str for the item's name and a dictionary for the item's attributes
+ """
+ return self.jobStoreID, self.binaryToAttributes(cPickle.dumps(self))
+
+
+class BucketLocationConflictException(Exception):
+ def __init__(self, bucketRegion):
+ super(BucketLocationConflictException, self).__init__(
+ 'A bucket with the same name as the jobstore was found in another region (%s). '
+ 'Cannot proceed as the unique bucket name is already in use.' % bucketRegion)
diff --git a/src/toil/jobStores/aws/utils.py b/src/toil/jobStores/aws/utils.py
new file mode 100644
index 0000000..0f37587
--- /dev/null
+++ b/src/toil/jobStores/aws/utils.py
@@ -0,0 +1,267 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import base64
+import bz2
+import socket
+import logging
+import types
+
+import errno
+from ssl import SSLError
+
+# Python 3 compatibility imports
+from six import iteritems
+
+from bd2k.util.retry import retry
+from boto.exception import (SDBResponseError,
+ BotoServerError,
+ S3ResponseError,
+ S3CreateError,
+ S3CopyError)
+
+log = logging.getLogger(__name__)
+
+
+class SDBHelper(object):
+ """
+ A mixin with methods for storing limited amounts of binary data in an SDB item
+
+ >>> import os
+ >>> H=SDBHelper
+ >>> H.presenceIndicator()
+ '000'
+ >>> H.binaryToAttributes(None)
+ {}
+ >>> H.attributesToBinary({})
+ (None, 0)
+ >>> H.binaryToAttributes('')
+ {'000': 'VQ=='}
+ >>> H.attributesToBinary({'000': 'VQ=='})
+ ('', 1)
+
+ Good pseudo-random data is very likely smaller than its bzip2ed form. Subtract 1 for the type
+ character, i.e. 'C' or 'U', with which the string is prefixed. We should get one full chunk:
+
+ >>> s = os.urandom(H.maxRawValueSize-1)
+ >>> d = H.binaryToAttributes(s)
+ >>> len(d), len(d['000'])
+ (1, 1024)
+ >>> H.attributesToBinary(d) == (s, 1)
+ True
+
+ One byte more and we should overflow four bytes into the second chunk, two bytes for
+ base64-encoding the additional character and two bytes for base64-padding to the next quartet.
+
+ >>> s += s[0]
+ >>> d = H.binaryToAttributes(s)
+ >>> len(d), len(d['000']), len(d['001'])
+ (2, 1024, 4)
+ >>> H.attributesToBinary(d) == (s, 2)
+ True
+
+ """
+ # The SDB documentation is not clear as to whether the attribute value size limit of 1024
+ # applies to the base64-encoded value or the raw value. It suggests that responses are
+ # automatically encoded, from which I conclude that the limit should apply to the raw,
+ # unencoded value. However, there seems to be a discrepancy between how Boto computes the
+ # request signature if a value contains binary data, and how SDB does it. This causes
+ # requests to fail signature verification, resulting in a 403. We therefore have to
+ # base64-encode values ourselves even if that means we lose a quarter of the capacity.
+
+ maxAttributesPerItem = 256
+ maxValueSize = 1024
+ maxRawValueSize = maxValueSize * 3 / 4
+ # Just make sure we don't have a problem with padding or integer truncation:
+ assert len(base64.b64encode(' ' * maxRawValueSize)) == 1024
+ assert len(base64.b64encode(' ' * (1 + maxRawValueSize))) > 1024
+
+ @classmethod
+ def _reservedAttributes(cls):
+ """
+ Override in subclass to reserve a certain number of attributes that can't be used for
+ chunks.
+ """
+ return 0
+
+ @classmethod
+ def _maxChunks(cls):
+ return cls.maxAttributesPerItem - cls._reservedAttributes()
+
+ @classmethod
+ def maxBinarySize(cls):
+ return cls._maxChunks() * cls.maxRawValueSize - 1 # for the 'C' or 'U' prefix
+
+ @classmethod
+ def _maxEncodedSize(cls):
+ return cls._maxChunks() * cls.maxValueSize
+
+ @classmethod
+ def binaryToAttributes(cls, binary):
+ if binary is None: return {}
+ assert len(binary) <= cls.maxBinarySize()
+ # The use of compression is just an optimization. We can't include it in the maxValueSize
+ # computation because the compression ratio depends on the input.
+ compressed = bz2.compress(binary)
+ if len(compressed) > len(binary):
+ compressed = 'U' + binary
+ else:
+ compressed = 'C' + compressed
+ encoded = base64.b64encode(compressed)
+ assert len(encoded) <= cls._maxEncodedSize()
+ n = cls.maxValueSize
+ chunks = (encoded[i:i + n] for i in range(0, len(encoded), n))
+ return {cls._chunkName(i): chunk for i, chunk in enumerate(chunks)}
+
+ @classmethod
+ def _chunkName(cls, i):
+ return str(i).zfill(3)
+
+ @classmethod
+ def _isValidChunkName(cls, s):
+ return len(s) == 3 and s.isdigit()
+
+ @classmethod
+ def presenceIndicator(cls):
+ """
+ The key that is guaranteed to be present in the return value of binaryToAttributes().
+ Assuming that binaryToAttributes() is used with SDB's PutAttributes, the return value of
+ this method could be used to detect the presence/absence of an item in SDB.
+ """
+ return cls._chunkName(0)
+
+ @classmethod
+ def attributesToBinary(cls, attributes):
+ """
+ :rtype: (str|None,int)
+ :return: the binary data and the number of chunks it was composed from
+ """
+ chunks = [(int(k), v) for k, v in iteritems(attributes) if cls._isValidChunkName(k)]
+ chunks.sort()
+ numChunks = len(chunks)
+ if numChunks:
+ assert len(set(k for k, v in chunks)) == chunks[-1][0] + 1 == numChunks
+ serializedJob = ''.join(v for k, v in chunks)
+ compressed = base64.b64decode(serializedJob)
+ if compressed[0] == 'C':
+ binary = bz2.decompress(compressed[1:])
+ elif compressed[0] == 'U':
+ binary = compressed[1:]
+ else:
+ assert False
+ else:
+ binary = None
+ return binary, numChunks
+
+
+from boto.sdb.connection import SDBConnection
+
+
+def _put_attributes_using_post(self, domain_or_name, item_name, attributes,
+ replace=True, expected_value=None):
+ """
+ Monkey-patched version of SDBConnection.put_attributes that uses POST instead of GET
+
+ The GET version is subject to the URL length limit which kicks in before the 256 x 1024 limit
+ for attribute values. Using POST prevents that.
+
+ https://github.com/BD2KGenomics/toil/issues/502
+ """
+ domain, domain_name = self.get_domain_and_name(domain_or_name)
+ params = {'DomainName': domain_name,
+ 'ItemName': item_name}
+ self._build_name_value_list(params, attributes, replace)
+ if expected_value:
+ self._build_expected_value(params, expected_value)
+ # The addition of the verb keyword argument is the only difference to put_attributes (Hannes)
+ return self.get_status('PutAttributes', params, verb='POST')
+
+
+def monkeyPatchSdbConnection(sdb):
+ """
+ :type sdb: SDBConnection
+ """
+ sdb.put_attributes = types.MethodType(_put_attributes_using_post, sdb)
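+# A minimal usage sketch for the monkey patch above (the region and domain names are
+# made up). Once patched, put_attributes calls issued through the connection (including
+# those made via a Domain object) go over POST, so items built with
+# SDBHelper.binaryToAttributes() no longer hit the GET URL length limit. This mirrors
+# what AWSJobStore._connectSimpleDB() does with its connection:
+#
+#   sdb = boto.sdb.connect_to_region('us-west-2')
+#   monkeyPatchSdbConnection(sdb)
+#   domain = sdb.create_domain('example-domain')
+#   domain.put_attributes('item-0', SDBHelper.binaryToAttributes('some binary payload'))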
+
+
+default_delays = (0, 1, 1, 4, 16, 64)
+default_timeout = 300
+
+
+def connection_reset(e):
+ # We sometimes see 'error: [Errno 104] Connection reset by peer' on platforms where
+ # errno.ECONNRESET is 54, i.e. the exception carries the Linux value (104) rather
+ # than the local one. To be safe, we check for both:
+ return isinstance(e, socket.error) and e.errno in (errno.ECONNRESET, 104)
+
+
+def sdb_unavailable(e):
+ return isinstance(e, BotoServerError) and e.status == 503
+
+
+def no_such_sdb_domain(e):
+ return (isinstance(e, SDBResponseError)
+ and e.error_code
+ and e.error_code.endswith('NoSuchDomain'))
+
+
+def retryable_ssl_error(e):
+ # https://github.com/BD2KGenomics/toil/issues/978
+ return isinstance(e, SSLError) and e.reason == 'DECRYPTION_FAILED_OR_BAD_RECORD_MAC'
+
+
+def retryable_sdb_errors(e):
+ return (sdb_unavailable(e)
+ or no_such_sdb_domain(e)
+ or connection_reset(e)
+ or retryable_ssl_error(e))
+
+
+def retry_sdb(delays=default_delays, timeout=default_timeout, predicate=retryable_sdb_errors):
+ return retry(delays=delays, timeout=timeout, predicate=predicate)
+
+
+def retryable_s3_errors(e):
+ return (isinstance(e, (S3CreateError, S3ResponseError))
+ and e.status == 409
+ and 'try again' in e.message
+ or connection_reset(e)
+ or isinstance(e, BotoServerError) and e.status == 500
+ or isinstance(e, S3CopyError) and 'try again' in e.message)
+
+
+def retry_s3(delays=default_delays, timeout=default_timeout, predicate=retryable_s3_errors):
+ return retry(delays=delays, timeout=timeout, predicate=predicate)
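+# A sketch of the retry idiom used with these helpers throughout the job store code
+# above: the generator yields one context manager per attempt, and an exception that
+# matches the predicate is swallowed and retried until the delays/timeout run out,
+# after which it propagates. The bucket name below is made up:
+#
+#   for attempt in retry_s3():
+#       with attempt:
+#           bucket = s3.get_bucket('some-job-store-bucket')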
+
+
+def region_to_bucket_location(region):
+ if region == 'us-east-1':
+ return ''
+ else:
+ return region
+
+
+def bucket_location_to_region(location):
+ if location == '':
+ return 'us-east-1'
+ else:
+ return location
+
+
+def bucket_location_to_http_url(location):
+ if location:
+ return 'https://s3-' + location + '.amazonaws.com'
+ else:
+ return 'https://s3.amazonaws.com'
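+# A small sketch of the mapping convention above: S3 reports the empty location string
+# for buckets in 'us-east-1', so region_to_bucket_location('us-east-1') == '' and
+# bucket_location_to_region('') == 'us-east-1'; any other region name maps to itself
+# in both directions.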
diff --git a/src/toil/jobStores/azureJobStore.py b/src/toil/jobStores/azureJobStore.py
new file mode 100644
index 0000000..dc725bc
--- /dev/null
+++ b/src/toil/jobStores/azureJobStore.py
@@ -0,0 +1,827 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import bz2
+import inspect
+import logging
+import os
+import re
+import socket
+import uuid
+from collections import namedtuple
+from contextlib import contextmanager
+from datetime import datetime, timedelta
+
+# Python 3 compatibility imports
+from six.moves import cPickle
+from six.moves.http_client import HTTPException
+from six.moves.configparser import RawConfigParser, NoOptionError
+
+from azure.common import AzureMissingResourceHttpError, AzureException
+from azure.storage import SharedAccessPolicy, AccessPolicy
+from azure.storage.blob import BlobService, BlobSharedAccessPermissions
+from azure.storage.table import TableService, EntityProperty
+
+# noinspection PyPackageRequirements
+# (pulled in transitively)
+import requests
+from bd2k.util import strict_bool, memoize
+from bd2k.util.exceptions import panic
+from bd2k.util.retry import retry
+
+from toil.jobStores.utils import WritablePipe, ReadablePipe
+from toil.jobGraph import JobGraph
+from toil.jobStores.abstractJobStore import (AbstractJobStore,
+ NoSuchJobException,
+ ConcurrentFileModificationException,
+ NoSuchFileException,
+ InvalidImportExportUrlException,
+ JobStoreExistsException,
+ NoSuchJobStoreException)
+import toil.lib.encryption as encryption
+
+logger = logging.getLogger(__name__)
+
+credential_file_path = '~/.toilAzureCredentials'
+
+
+def _fetchAzureAccountKey(accountName):
+ """
+ Find the account key for a given Azure storage account.
+
+ The account key is taken from the AZURE_ACCOUNT_KEY_<account> environment variable if it
+ exists, then from plain AZURE_ACCOUNT_KEY, and then from looking in the file
+ ~/.toilAzureCredentials. That file has format:
+
+ [AzureStorageCredentials]
+ accountName1=ACCOUNTKEY1==
+ accountName2=ACCOUNTKEY2==
+ """
+ try:
+ return os.environ['AZURE_ACCOUNT_KEY_' + accountName]
+ except KeyError:
+ try:
+ return os.environ['AZURE_ACCOUNT_KEY']
+ except KeyError:
+ configParser = RawConfigParser()
+ configParser.read(os.path.expanduser(credential_file_path))
+ try:
+ return configParser.get('AzureStorageCredentials', accountName)
+ except NoOptionError:
+ raise RuntimeError("No account key found for '%s', please provide it in '%s'" %
+ (accountName, credential_file_path))
+
+
+maxAzureTablePropertySize = 64 * 1024
+
+
+class AzureJobStore(AbstractJobStore):
+ """
+ A job store that uses Azure's blob store for file storage and Table Service to store job info
+ with strong consistency.
+ """
+
+ # Dots in container names should be avoided because container names are used in HTTPS bucket
+ # URLs where they may interfere with the certificate common name. We use an alphanumeric
+ # separator (see nameSeparator below) instead.
+ #
+ containerNameRe = re.compile(r'^[a-z0-9](-?[a-z0-9]+)+[a-z0-9]$')
+
+ # See https://msdn.microsoft.com/en-us/library/azure/dd135715.aspx
+ #
+ minContainerNameLen = 3
+ maxContainerNameLen = 63
+ maxNameLen = 10
+ nameSeparator = 'xx' # Table names must be alphanumeric
+ # Length of a jobID - used to test if a stats file has been read already or not
+ jobIDLength = len(str(uuid.uuid4()))
+
+ def __init__(self, locator, jobChunkSize=maxAzureTablePropertySize):
+ super(AzureJobStore, self).__init__()
+ accountName, namePrefix = locator.split(':', 1)
+ if '--' in namePrefix:
+ raise ValueError("Invalid name prefix '%s'. Name prefixes may not contain '--'."
+ % namePrefix)
+ if not self.containerNameRe.match(namePrefix):
+ raise ValueError("Invalid name prefix '%s'. Name prefixes must contain only digits, "
+ "hyphens or lower-case letters and must not start or end in a "
+ "hyphen." % namePrefix)
+ # Reserve 13 for separator and suffix
+ if len(namePrefix) > self.maxContainerNameLen - self.maxNameLen - len(self.nameSeparator):
+ raise ValueError(("Invalid name prefix '%s'. Name prefixes may not be longer than 50 "
+ "characters." % namePrefix))
+ self.locator = locator
+ self.jobChunkSize = jobChunkSize
+ self.accountKey = _fetchAzureAccountKey(accountName)
+ self.accountName = accountName
+ # Table names have strict requirements in Azure
+ self.namePrefix = self._sanitizeTableName(namePrefix)
+ # These are the main API entry points.
+ self.tableService = TableService(account_key=self.accountKey, account_name=accountName)
+ self.blobService = BlobService(account_key=self.accountKey, account_name=accountName)
+ # Serialized jobs table
+ self.jobItems = None
+ # Job<->file mapping table
+ self.jobFileIDs = None
+ # Container for all shared and unshared files
+ self.files = None
+ # Stats and logging strings
+ self.statsFiles = None
+ # File IDs that contain stats and logging strings
+ self.statsFileIDs = None
+
+ @property
+ def keyPath(self):
+ return self.config.cseKey
+
+ def initialize(self, config):
+ if self._jobStoreExists():
+ raise JobStoreExistsException(self.locator)
+ logger.debug("Creating job store at '%s'" % self.locator)
+ self._bind(create=True)
+ super(AzureJobStore, self).initialize(config)
+
+ def resume(self):
+ if not self._jobStoreExists():
+ raise NoSuchJobStoreException(self.locator)
+ logger.debug("Using existing job store at '%s'" % self.locator)
+ self._bind(create=False)
+ super(AzureJobStore, self).resume()
+
+ def destroy(self):
+ for name in 'jobItems', 'jobFileIDs', 'files', 'statsFiles', 'statsFileIDs':
+ resource = getattr(self, name)
+ if resource is not None:
+ if isinstance(resource, AzureTable):
+ resource.delete_table()
+ elif isinstance(resource, AzureBlobContainer):
+ resource.delete_container()
+ else:
+ assert False
+ setattr(self, name, None)
+
+ def _jobStoreExists(self):
+ """
+ Checks whether the job store exists by querying for the statsFileIDs table. Note that
+ this is the last component that is deleted in :meth:`.destroy`.
+ """
+ for attempt in retry_azure():
+ with attempt:
+ try:
+ table = self.tableService.query_tables(table_name=self._qualify('statsFileIDs'))
+ except AzureMissingResourceHttpError as e:
+ if e.status_code == 404:
+ return False
+ else:
+ raise
+ else:
+ return table is not None
+
+ def _bind(self, create=False):
+ table = self._bindTable
+ container = self._bindContainer
+ for name, binder in (('jobItems', table),
+ ('jobFileIDs', table),
+ ('files', container),
+ ('statsFiles', container),
+ ('statsFileIDs', table)):
+ if getattr(self, name) is None:
+ setattr(self, name, binder(self._qualify(name), create=create))
+
+ def _qualify(self, name):
+ return self.namePrefix + self.nameSeparator + name.lower()
+
+ def jobs(self):
+
+ # How many jobs have we done?
+ total_processed = 0
+
+ for jobEntity in self.jobItems.query_entities_auto():
+ # Process the items in the page
+ yield AzureJob.fromEntity(jobEntity)
+ total_processed += 1
+
+ if total_processed % 1000 == 0:
+ # Produce some feedback for the user, because this can take
+ # a long time on, for example, Azure
+ logger.debug("Processed %d total jobs" % total_processed)
+
+ logger.debug("Processed %d total jobs" % total_processed)
+
+ def create(self, jobNode):
+ jobStoreID = self._newJobID()
+ job = AzureJob.fromJobNode(jobNode, jobStoreID, self._defaultTryCount())
+ entity = job.toItem(chunkSize=self.jobChunkSize)
+ entity['RowKey'] = jobStoreID
+ self.jobItems.insert_entity(entity=entity)
+ return job
+
+ def exists(self, jobStoreID):
+ if self.jobItems.get_entity(row_key=jobStoreID) is None:
+ return False
+ return True
+
+ def load(self, jobStoreID):
+ jobEntity = self.jobItems.get_entity(row_key=jobStoreID)
+ if jobEntity is None:
+ raise NoSuchJobException(jobStoreID)
+ return AzureJob.fromEntity(jobEntity)
+
+ def update(self, job):
+ self.jobItems.update_entity(row_key=job.jobStoreID,
+ entity=job.toItem(chunkSize=self.jobChunkSize))
+
+ def delete(self, jobStoreID):
+ try:
+ self.jobItems.delete_entity(row_key=jobStoreID)
+ except AzureMissingResourceHttpError:
+ # Job deletion is idempotent, and this job has been deleted already
+ return
+ filterString = "PartitionKey eq '%s'" % jobStoreID
+ for fileEntity in self.jobFileIDs.query_entities(filter=filterString):
+ jobStoreFileID = fileEntity.RowKey
+ self.deleteFile(jobStoreFileID)
+
+ def getEnv(self):
+ return dict(AZURE_ACCOUNT_KEY=self.accountKey)
+
+ class BlobInfo(namedtuple('BlobInfo', ('account', 'container', 'name'))):
+ @property
+ @memoize
+ def service(self):
+ return BlobService(account_name=self.account,
+ account_key=_fetchAzureAccountKey(self.account))
+
+ @classmethod
+ def _readFromUrl(cls, url, writable):
+ blob = cls._parseWasbUrl(url)
+ blob.service.get_blob_to_file(container_name=blob.container,
+ blob_name=blob.name,
+ stream=writable)
+
+ @classmethod
+ def _writeToUrl(cls, readable, url):
+ blob = cls._parseWasbUrl(url)
+ blob.service.put_block_blob_from_file(container_name=blob.container,
+ blob_name=blob.name,
+ stream=readable)
+
+ @classmethod
+ def _parseWasbUrl(cls, url):
+ """
+ :param urlparse.ParseResult url: x
+ :rtype: AzureJobStore.BlobInfo
+ """
+ assert url.scheme in ('wasb', 'wasbs')
+ try:
+ container, account = url.netloc.split('@')
+ except ValueError:
+ raise InvalidImportExportUrlException(url)
+ suffix = '.blob.core.windows.net'
+ if account.endswith(suffix):
+ account = account[:-len(suffix)]
+ else:
+ raise InvalidImportExportUrlException(url)
+ assert url.path[0] == '/'
+ return cls.BlobInfo(account=account, container=container, name=url.path[1:])
+
+ @classmethod
+ def _supportsUrl(cls, url, export=False):
+ return url.scheme.lower() in ('wasb', 'wasbs')
+
+ def writeFile(self, localFilePath, jobStoreID=None):
+ jobStoreFileID = self._newFileID()
+ self.updateFile(jobStoreFileID, localFilePath)
+ self._associateFileWithJob(jobStoreFileID, jobStoreID)
+ return jobStoreFileID
+
+ def updateFile(self, jobStoreFileID, localFilePath):
+ with open(localFilePath) as read_fd:
+ with self._uploadStream(jobStoreFileID, self.files) as write_fd:
+ while True:
+ buf = read_fd.read(self._maxAzureBlockBytes)
+ write_fd.write(buf)
+ if len(buf) == 0:
+ break
+
+ def readFile(self, jobStoreFileID, localFilePath):
+ try:
+ with self._downloadStream(jobStoreFileID, self.files) as read_fd:
+ with open(localFilePath, 'w') as write_fd:
+ while True:
+ buf = read_fd.read(self._maxAzureBlockBytes)
+ write_fd.write(buf)
+ if not buf:
+ break
+ except AzureMissingResourceHttpError:
+ raise NoSuchFileException(jobStoreFileID)
+
+ def deleteFile(self, jobStoreFileID):
+ try:
+ self.files.delete_blob(blob_name=jobStoreFileID)
+ self._dissociateFileFromJob(jobStoreFileID)
+ except AzureMissingResourceHttpError:
+ pass
+
+ def fileExists(self, jobStoreFileID):
+ # As Azure doesn't have a blob_exists method (at least in the
+ # python API) we just try to download the metadata, and hope
+ # the metadata is small so the call will be fast.
+ try:
+ self.files.get_blob_metadata(blob_name=jobStoreFileID)
+ return True
+ except AzureMissingResourceHttpError:
+ return False
+
+ @contextmanager
+ def writeFileStream(self, jobStoreID=None):
+ # TODO: this (and all stream methods) should probably use the
+ # Append Blob type, but that is not currently supported by the
+ # Azure Python API.
+ jobStoreFileID = self._newFileID()
+ with self._uploadStream(jobStoreFileID, self.files) as fd:
+ yield fd, jobStoreFileID
+ self._associateFileWithJob(jobStoreFileID, jobStoreID)
+
+ @contextmanager
+ def updateFileStream(self, jobStoreFileID):
+ with self._uploadStream(jobStoreFileID, self.files, checkForModification=True) as fd:
+ yield fd
+
+ def getEmptyFileStoreID(self, jobStoreID=None):
+ jobStoreFileID = self._newFileID()
+ self.files.put_blob(blob_name=jobStoreFileID, blob='',
+ x_ms_blob_type='BlockBlob')
+ self._associateFileWithJob(jobStoreFileID, jobStoreID)
+ return jobStoreFileID
+
+ @contextmanager
+ def readFileStream(self, jobStoreFileID):
+ if not self.fileExists(jobStoreFileID):
+ raise NoSuchFileException(jobStoreFileID)
+ with self._downloadStream(jobStoreFileID, self.files) as fd:
+ yield fd
+
+ @contextmanager
+ def writeSharedFileStream(self, sharedFileName, isProtected=None):
+ assert self._validateSharedFileName(sharedFileName)
+ sharedFileID = self._newFileID(sharedFileName)
+ with self._uploadStream(sharedFileID, self.files, encrypted=isProtected) as fd:
+ yield fd
+
+ @contextmanager
+ def readSharedFileStream(self, sharedFileName):
+ assert self._validateSharedFileName(sharedFileName)
+ sharedFileID = self._newFileID(sharedFileName)
+ if not self.fileExists(sharedFileID):
+ raise NoSuchFileException(sharedFileID)
+ with self._downloadStream(sharedFileID, self.files) as fd:
+ yield fd
+
+ def writeStatsAndLogging(self, statsAndLoggingString):
+ # TODO: would be a great use case for the append blobs, once implemented in the Azure SDK
+ jobStoreFileID = self._newFileID()
+ encrypted = self.keyPath is not None
+ if encrypted:
+ statsAndLoggingString = encryption.encrypt(statsAndLoggingString, self.keyPath)
+ self.statsFiles.put_block_blob_from_text(blob_name=jobStoreFileID,
+ text=statsAndLoggingString,
+ x_ms_meta_name_values=dict(
+ encrypted=str(encrypted)))
+ self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID})
+
+ def readStatsAndLogging(self, callback, readAll=False):
+ suffix = '_old'
+ numStatsFiles = 0
+ for entity in self.statsFileIDs.query_entities():
+ jobStoreFileID = entity.RowKey
+ hasBeenRead = len(jobStoreFileID) > self.jobIDLength
+ if not hasBeenRead:
+ with self._downloadStream(jobStoreFileID, self.statsFiles) as fd:
+ callback(fd)
+ # Mark this entity as read by appending the suffix
+ self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID + suffix})
+ self.statsFileIDs.delete_entity(row_key=jobStoreFileID)
+ numStatsFiles += 1
+ elif readAll:
+ # Strip the suffix to get the original ID
+ jobStoreFileID = jobStoreFileID[:-len(suffix)]
+ with self._downloadStream(jobStoreFileID, self.statsFiles) as fd:
+ callback(fd)
+ numStatsFiles += 1
+ return numStatsFiles
+
+ _azureTimeFormat = "%Y-%m-%dT%H:%M:%SZ"
+
+ def getPublicUrl(self, jobStoreFileID):
+ try:
+ self.files.get_blob_properties(blob_name=jobStoreFileID)
+ except AzureMissingResourceHttpError:
+ raise NoSuchFileException(jobStoreFileID)
+ # Compensate for a little bit of clock skew
+ startTimeStr = (datetime.utcnow() - timedelta(minutes=5)).strftime(self._azureTimeFormat)
+ endTime = datetime.utcnow() + self.publicUrlExpiration
+ endTimeStr = endTime.strftime(self._azureTimeFormat)
+ sap = SharedAccessPolicy(AccessPolicy(startTimeStr, endTimeStr,
+ BlobSharedAccessPermissions.READ))
+ sas_token = self.files.generate_shared_access_signature(blob_name=jobStoreFileID,
+ shared_access_policy=sap)
+ return self.files.make_blob_url(blob_name=jobStoreFileID) + '?' + sas_token
+
+ def getSharedPublicUrl(self, sharedFileName):
+ jobStoreFileID = self._newFileID(sharedFileName)
+ return self.getPublicUrl(jobStoreFileID)
+
+ def _newJobID(self):
+ # raw UUIDs don't work for Azure property names because the '-' character is disallowed.
+ return str(uuid.uuid4()).replace('-', '_')
+
+ # A dummy job ID under which all shared files are stored.
+ sharedFileJobID = uuid.UUID('891f7db6-e4d9-4221-a58e-ab6cc4395f94')
+
+ def _newFileID(self, sharedFileName=None):
+ if sharedFileName is None:
+ ret = str(uuid.uuid4())
+ else:
+ ret = str(uuid.uuid5(self.sharedFileJobID, str(sharedFileName)))
+ return ret.replace('-', '_')
+
+ def _associateFileWithJob(self, jobStoreFileID, jobStoreID=None):
+ if jobStoreID is not None:
+ self.jobFileIDs.insert_entity(entity={'PartitionKey': jobStoreID,
+ 'RowKey': jobStoreFileID})
+
+ def _dissociateFileFromJob(self, jobStoreFileID):
+ entities = self.jobFileIDs.query_entities(filter="RowKey eq '%s'" % jobStoreFileID)
+ if entities:
+ assert len(entities) == 1
+ jobStoreID = entities[0].PartitionKey
+ self.jobFileIDs.delete_entity(partition_key=jobStoreID, row_key=jobStoreFileID)
+
+ def _bindTable(self, tableName, create=False):
+ for attempt in retry_azure():
+ with attempt:
+ try:
+ tables = self.tableService.query_tables(table_name=tableName)
+ except AzureMissingResourceHttpError as e:
+ if e.status_code != 404:
+ raise
+ else:
+ if tables:
+ assert tables[0].name == tableName
+ return AzureTable(self.tableService, tableName)
+ if create:
+ self.tableService.create_table(tableName)
+ return AzureTable(self.tableService, tableName)
+ else:
+ return None
+
+ def _bindContainer(self, containerName, create=False):
+ for attempt in retry_azure():
+ with attempt:
+ try:
+ self.blobService.get_container_properties(containerName)
+ except AzureMissingResourceHttpError as e:
+ if e.status_code == 404:
+ if create:
+ self.blobService.create_container(containerName)
+ else:
+ return None
+ else:
+ raise
+ return AzureBlobContainer(self.blobService, containerName)
+
+ def _sanitizeTableName(self, tableName):
+ """
+ Azure table names must start with a letter and be alphanumeric.
+
+ This will never cause a collision if uuids are used, but
+ otherwise may not be safe.
+ """
+ return 'a' + filter(lambda x: x.isalnum(), tableName)
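+ # For example (a sketch with a hypothetical prefix, not a doctest): a locator
+ # prefix like 'my-prefix' is sanitized to 'amyprefix', and _qualify('jobItems')
+ # then yields 'amyprefixxxjobitems', satisfying Azure's alphanumeric table name rule.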
+
+ # Maximum bytes that can be in any block of an Azure block blob
+ # https://github.com/Azure/azure-storage-python/blob/4c7666e05a9556c10154508335738ee44d7cb104/azure/storage/blob/blobservice.py#L106
+ _maxAzureBlockBytes = 4 * 1024 * 1024
+
+ @contextmanager
+ def _uploadStream(self, jobStoreFileID, container, checkForModification=False, encrypted=None):
+ """
+ :param encrypted: True to enforce encryption (will raise exception unless key is set),
+ False to prevent encryption or None to encrypt if key is set.
+ """
+ if checkForModification:
+ try:
+ expectedVersion = container.get_blob_properties(blob_name=jobStoreFileID)['etag']
+ except AzureMissingResourceHttpError:
+ expectedVersion = None
+
+ if encrypted is None:
+ encrypted = self.keyPath is not None
+ elif encrypted:
+ if self.keyPath is None:
+ raise RuntimeError('Encryption requested but no key was provided')
+
+ maxBlockSize = self._maxAzureBlockBytes
+ if encrypted:
+ # There is a small overhead for encrypted data.
+ maxBlockSize -= encryption.overhead
+
+ store = self
+
+ class UploadPipe(WritablePipe):
+
+ def readFrom(self, readable):
+ blockIDs = []
+ try:
+ while True:
+ buf = readable.read(maxBlockSize)
+ if len(buf) == 0:
+ # We're safe to break here even if we never read anything, since
+ # putting an empty block list creates an empty blob.
+ break
+ if encrypted:
+ buf = encryption.encrypt(buf, store.keyPath)
+ blockID = store._newFileID()
+ container.put_block(blob_name=jobStoreFileID,
+ block=buf,
+ blockid=blockID)
+ blockIDs.append(blockID)
+ except:
+ with panic(log=logger):
+ # This is guaranteed to delete any uncommitted blocks.
+ container.delete_blob(blob_name=jobStoreFileID)
+
+ if checkForModification and expectedVersion is not None:
+ # Acquire a (60-second) write lock,
+ leaseID = container.lease_blob(blob_name=jobStoreFileID,
+ x_ms_lease_action='acquire')['x-ms-lease-id']
+ # check for modification,
+ blobProperties = container.get_blob_properties(blob_name=jobStoreFileID)
+ if blobProperties['etag'] != expectedVersion:
+ container.lease_blob(blob_name=jobStoreFileID,
+ x_ms_lease_action='release',
+ x_ms_lease_id=leaseID)
+ raise ConcurrentFileModificationException(jobStoreFileID)
+ # commit the file,
+ container.put_block_list(blob_name=jobStoreFileID,
+ block_list=blockIDs,
+ x_ms_lease_id=leaseID,
+ x_ms_meta_name_values=dict(
+ encrypted=str(encrypted)))
+ # then release the lock.
+ container.lease_blob(blob_name=jobStoreFileID,
+ x_ms_lease_action='release',
+ x_ms_lease_id=leaseID)
+ else:
+ # No need to check for modification, just blindly write over whatever
+ # was there.
+ container.put_block_list(blob_name=jobStoreFileID,
+ block_list=blockIDs,
+ x_ms_meta_name_values=dict(encrypted=str(encrypted)))
+
+ with UploadPipe() as writable:
+ yield writable
+
+ @contextmanager
+ def _downloadStream(self, jobStoreFileID, container):
+ # The reason this is not in the writer is so that we catch non-existent blobs early
+
+ blobProps = container.get_blob_properties(blob_name=jobStoreFileID)
+
+ encrypted = strict_bool(blobProps['x-ms-meta-encrypted'])
+ if encrypted and self.keyPath is None:
+ raise AssertionError('Content is encrypted but no key was provided.')
+
+ outer_self = self
+
+ class DownloadPipe(ReadablePipe):
+ def writeTo(self, writable):
+ chunkStart = 0
+ fileSize = int(blobProps['Content-Length'])
+ while chunkStart < fileSize:
+ chunkEnd = chunkStart + outer_self._maxAzureBlockBytes - 1
+ buf = container.get_blob(blob_name=jobStoreFileID,
+ x_ms_range="bytes=%d-%d" % (chunkStart, chunkEnd))
+ if encrypted:
+ buf = encryption.decrypt(buf, outer_self.keyPath)
+ writable.write(buf)
+ chunkStart = chunkEnd + 1
+
+ with DownloadPipe() as readable:
+ yield readable
+
+
+class AzureTable(object):
+ """
+ A shim over the Azure TableService API, specific to a single table.
+
+ This class automatically forwards method calls to the TableService
+ API, including the proper table name and default partition key if
+ needed. To avoid confusion, all method calls must use *only*
+ keyword arguments.
+
+ In addition, this wrapper:
+ - allows a default partition key to be used when one is not specified
+ - returns None when attempting to get a non-existent entity.
+ """
+
+ def __init__(self, tableService, tableName):
+ self.tableService = tableService
+ self.tableName = tableName
+
+ defaultPartition = 'default'
+
+ def __getattr__(self, name):
+ def f(*args, **kwargs):
+ assert len(args) == 0
+ function = getattr(self.tableService, name)
+ funcArgs, _, _, _ = inspect.getargspec(function)
+ kwargs['table_name'] = self.tableName
+ if 'partition_key' not in kwargs and 'partition_key' in funcArgs:
+ kwargs['partition_key'] = self.defaultPartition
+ if 'entity' in kwargs:
+ if 'PartitionKey' not in kwargs['entity']:
+ kwargs['entity']['PartitionKey'] = self.defaultPartition
+
+ for attempt in retry_azure():
+ with attempt:
+ return function(**kwargs)
+
+ return f
+
+ def get_entity(self, **kwargs):
+ try:
+ return self.__getattr__('get_entity')(**kwargs)
+ except AzureMissingResourceHttpError:
+ return None
+
+ def query_entities_auto(self, **kwargs):
+ """
+ An automatically-paged version of query_entities. The iterator just
+ yields all entities matching the query, occasionally going back to Azure
+ for the next page.
+ """
+
+ # We need to page through the results, since we only get some of them at
+ # a time. Just like in the BlobService. See the only documentation
+ # available: the API bindings source code, at:
+ # https://github.com/Azure/azure-storage-python/blob/09e9f186740407672777d6cb6646c33a2273e1a8/azure/storage/table/tableservice.py#L385
+
+ # These two together constitute the primary key for an item.
+ next_partition_key = None
+ next_row_key = None
+
+ while True:
+ # Get a page (up to 1000 items)
+ kwargs['next_partition_key'] = next_partition_key
+ kwargs['next_row_key'] = next_row_key
+ page = self.query_entities(**kwargs)
+
+ for result in page:
+ # Yield each item one at a time
+ yield result
+
+ if hasattr(page, 'x_ms_continuation'):
+ # Next time ask for the next page. If you use .get() you need
+ # the lower-case versions, but this is some kind of fancy case-
+ # insensitive dictionary.
+ next_partition_key = page.x_ms_continuation['NextPartitionKey']
+ next_row_key = page.x_ms_continuation['NextRowKey']
+ else:
+ # No continuation to check
+ next_partition_key = None
+ next_row_key = None
+
+ if not next_partition_key and not next_row_key:
+ # If we run out of pages, stop
+ break
+
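+# A minimal usage sketch for the shim above (the table name is made up): every call
+# is forwarded to TableService with table_name filled in and, where the wrapped method
+# accepts one, a default partition_key, which is why the job store can simply write:
+#
+#   table = AzureTable(tableService, 'aprefixxxjobitems')
+#   table.insert_entity(entity={'RowKey': 'some_job_id'})
+#
+# without repeating the table name or partition key on every call.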
+
+class AzureBlobContainer(object):
+ """
+ A shim over the BlobService API, so that the container name is automatically filled in.
+
+ To avoid confusion over the position of any remaining positional arguments, all method calls
+ must use *only* keyword arguments.
+ """
+
+ def __init__(self, blobService, containerName):
+ self.blobService = blobService
+ self.containerName = containerName
+
+ def __getattr__(self, name):
+ def f(*args, **kwargs):
+ assert len(args) == 0
+ function = getattr(self.blobService, name)
+ kwargs['container_name'] = self.containerName
+
+ for attempt in retry_azure():
+ with attempt:
+ return function(**kwargs)
+
+ return f
+
+
+class AzureJob(JobGraph):
+ """
+ Serialize and unserialize a job for storage on Azure.
+
+ Copied almost entirely from AWSJob, except to take into account the
+ fact that Azure properties must start with a letter or underscore.
+ """
+
+ defaultAttrs = ['PartitionKey', 'RowKey', 'etag', 'Timestamp']
+
+ @classmethod
+ def fromEntity(cls, jobEntity):
+ """
+ :type jobEntity: Entity
+ :rtype: AzureJob
+ """
+ jobEntity = jobEntity.__dict__
+ for attr in cls.defaultAttrs:
+ del jobEntity[attr]
+ return cls.fromItem(jobEntity)
+
+ @classmethod
+ def fromItem(cls, item):
+ """
+ :type item: dict
+ :rtype: AzureJob
+ """
+ chunkedJob = item.items()
+ chunkedJob.sort()
+ if len(chunkedJob) == 1:
+ # First element of list = tuple, second element of tuple = serialized job
+ wholeJobString = chunkedJob[0][1].value
+ else:
+ wholeJobString = ''.join(item[1].value for item in chunkedJob)
+ return cPickle.loads(bz2.decompress(wholeJobString))
+
+ def toItem(self, chunkSize=maxAzureTablePropertySize):
+ """
+ :param chunkSize: the size of a chunk for splitting up the serialized job into chunks
+ that each fit into a property value of an Azure table entity
+ :rtype: dict
+ """
+ assert chunkSize <= maxAzureTablePropertySize
+ item = {}
+ serializedAndEncodedJob = bz2.compress(cPickle.dumps(self))
+ jobChunks = [serializedAndEncodedJob[i:i + chunkSize]
+ for i in range(0, len(serializedAndEncodedJob), chunkSize)]
+ for attributeOrder, chunk in enumerate(jobChunks):
+ item['_' + str(attributeOrder).zfill(3)] = EntityProperty('Edm.Binary', chunk)
+ return item
+
+
+def defaultRetryPredicate(exception):
+ """
+ >>> defaultRetryPredicate(socket.error())
+ True
+ >>> defaultRetryPredicate(socket.gaierror())
+ True
+ >>> defaultRetryPredicate(HTTPException())
+ True
+ >>> defaultRetryPredicate(requests.ConnectionError())
+ True
+ >>> defaultRetryPredicate(AzureException('x could not be completed within the specified time'))
+ True
+ >>> defaultRetryPredicate(AzureException('x service unavailable'))
+ True
+ >>> defaultRetryPredicate(AzureException('x server is busy'))
+ True
+ >>> defaultRetryPredicate(AzureException('x'))
+ False
+ >>> defaultRetryPredicate(RuntimeError())
+ False
+ """
+ return (isinstance(exception, (socket.error,
+ socket.gaierror,
+ HTTPException,
+ requests.ConnectionError))
+ or isinstance(exception, AzureException) and
+ any(message in str(exception).lower() for message in (
+ "could not be completed within the specified time",
+ "service unavailable",
+ "server is busy")))
+
+
+def retry_azure(delays=(0, 1, 1, 4, 16, 64), timeout=300, predicate=defaultRetryPredicate):
+ return retry(delays=delays, timeout=timeout, predicate=predicate)
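The AzureTable and AzureBlobContainer shims above follow the same idea: delegate attribute access, force keyword-only calls, inject the table or container name, and retry transient failures via retry_azure(). A minimal standalone sketch of that pattern follows; TransientError, FakeTableService and RetryingShim are hypothetical stand-ins used purely for illustration and are not part of toil.

class TransientError(Exception):
    pass

class FakeTableService(object):
    # Hypothetical stand-in for the Azure TableService, for illustration only.
    def insert_entity(self, table_name=None, partition_key=None, entity=None):
        return table_name, partition_key, entity

class RetryingShim(object):
    # Delegates attribute access, forbids positional arguments, fills in the
    # table name and a default partition key, and retries transient failures.
    def __init__(self, tableService, tableName, attempts=3):
        self.tableService = tableService
        self.tableName = tableName
        self.attempts = attempts

    def __getattr__(self, name):
        def f(*args, **kwargs):
            assert len(args) == 0, 'only keyword arguments are allowed'
            function = getattr(self.tableService, name)
            kwargs['table_name'] = self.tableName
            kwargs.setdefault('partition_key', 'default')
            for attempt in range(self.attempts):
                try:
                    return function(**kwargs)
                except TransientError:
                    if attempt == self.attempts - 1:
                        raise
        return f

table = RetryingShim(FakeTableService(), 'toilJobs')
print(table.insert_entity(entity={'RowKey': 'job1'}))

The real wrapper additionally inspects the target function's signature before injecting partition_key and delegates the retry policy to retry_azure().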
diff --git a/src/toil/jobStores/conftest.py b/src/toil/jobStores/conftest.py
new file mode 100644
index 0000000..0d32e61
--- /dev/null
+++ b/src/toil/jobStores/conftest.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://pytest.org/latest/example/pythoncollection.html
+
+collect_ignore = []
+
+try:
+ import azure
+except ImportError:
+ collect_ignore.append("azureJobStore.py")
+
+try:
+ import boto
+except ImportError:
+ collect_ignore.append("aws")
diff --git a/src/toil/jobStores/fileJobStore.py b/src/toil/jobStores/fileJobStore.py
new file mode 100644
index 0000000..54f2fd0
--- /dev/null
+++ b/src/toil/jobStores/fileJobStore.py
@@ -0,0 +1,423 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+from contextlib import contextmanager
+import logging
+import pickle as pickler
+import random
+import shutil
+import os
+import tempfile
+import stat
+import errno
+
+# Python 3 compatibility imports
+from six.moves import xrange
+
+from bd2k.util.exceptions import require
+
+from toil.lib.bioio import absSymPath
+from toil.jobStores.abstractJobStore import (AbstractJobStore,
+ NoSuchJobException,
+ NoSuchFileException,
+ JobStoreExistsException,
+ NoSuchJobStoreException)
+from toil.jobGraph import JobGraph
+
+logger = logging.getLogger( __name__ )
+
+
+class FileJobStore(AbstractJobStore):
+ """
+ A job store that uses a directory on a locally attached file system. To be compatible with
+ distributed batch systems, that file system must be shared by all worker nodes.
+ """
+
+ # Parameters controlling the creation of temporary files
+ validDirs = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+ levels = 2
+
+ def __init__(self, path):
+ """
+ :param str path: Path to directory holding the job store
+ """
+ super(FileJobStore, self).__init__()
+ self.jobStoreDir = absSymPath(path)
+ logger.debug("Path to job store directory is '%s'.", self.jobStoreDir)
+ # Directory where temporary files go
+ self.tempFilesDir = os.path.join(self.jobStoreDir, 'tmp')
+
+ def initialize(self, config):
+ try:
+ os.mkdir(self.jobStoreDir)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ raise JobStoreExistsException(self.jobStoreDir)
+ else:
+ raise
+ os.mkdir(self.tempFilesDir)
+ super(FileJobStore, self).initialize(config)
+
+ def resume(self):
+ if not os.path.exists(self.jobStoreDir):
+ raise NoSuchJobStoreException(self.jobStoreDir)
+ require( os.path.isdir, "'%s' is not a directory", self.jobStoreDir)
+ super(FileJobStore, self).resume()
+
+ def destroy(self):
+ if os.path.exists(self.jobStoreDir):
+ shutil.rmtree(self.jobStoreDir)
+
+ ##########################################
+ # The following methods deal with creating/loading/updating/writing/checking for the
+ # existence of jobs
+ ##########################################
+
+ def create(self, jobNode):
+ # The absolute path to the job directory.
+ absJobDir = tempfile.mkdtemp(prefix="job", dir=self._getTempSharedDir())
+ # Sub directory to put temporary files associated with the job in
+ os.mkdir(os.path.join(absJobDir, "g"))
+ # Make the job
+ job = JobGraph.fromJobNode(jobNode, jobStoreID=self._getRelativePath(absJobDir),
+ tryCount=self._defaultTryCount())
+ # Write job file to disk
+ self.update(job)
+ return job
+
+ def exists(self, jobStoreID):
+ return os.path.exists(self._getJobFileName(jobStoreID))
+
+ def getPublicUrl(self, jobStoreFileID):
+ self._checkJobStoreFileID(jobStoreFileID)
+ jobStorePath = self._getAbsPath(jobStoreFileID)
+ if os.path.exists(jobStorePath):
+ return 'file:' + jobStorePath
+ else:
+ raise NoSuchFileException(jobStoreFileID)
+
+ def getSharedPublicUrl(self, sharedFileName):
+ jobStorePath = self.jobStoreDir + '/' + sharedFileName
+ if os.path.exists(jobStorePath):
+ return 'file:' + jobStorePath
+ else:
+ raise NoSuchFileException(sharedFileName)
+
+ def load(self, jobStoreID):
+ self._checkJobStoreId(jobStoreID)
+ # Load a valid version of the job
+ jobFile = self._getJobFileName(jobStoreID)
+ with open(jobFile, 'r') as fileHandle:
+ job = pickler.load(fileHandle)
+ # The following cleans up any issues resulting from the failure of the
+ # job during writing by the batch system.
+ if os.path.isfile(jobFile + ".new"):
+ logger.warn("There was a .new file for the job: %s", jobStoreID)
+ os.remove(jobFile + ".new")
+ job.setupJobAfterFailure(self.config)
+ return job
+
+ def update(self, job):
+ # The job is serialised to a file suffixed by ".new"
+ # The file is then moved to its correct path.
+ # Atomicity is guaranteed by the fact that the underlying file system's
+ # rename ("move") operation is atomic.
+ with open(self._getJobFileName(job.jobStoreID) + ".new", 'w') as f:
+ pickler.dump(job, f)
+ # This should be atomic for the file system
+ os.rename(self._getJobFileName(job.jobStoreID) + ".new", self._getJobFileName(job.jobStoreID))
+
+ def delete(self, jobStoreID):
+ # The jobStoreID is the relative path to the directory containing the job,
+ # removing this directory deletes the job.
+ if self.exists(jobStoreID):
+ shutil.rmtree(self._getAbsPath(jobStoreID))
+
+ def jobs(self):
+ # Walk through list of temporary directories searching for jobs
+ for tempDir in self._tempDirectories():
+ for i in os.listdir(tempDir):
+ if i.startswith( 'job' ):
+ try:
+ yield self.load(self._getRelativePath(os.path.join(tempDir, i)))
+ except NoSuchJobException:
+ # An orphaned job may leave an empty or incomplete job file which we can safely ignore
+ pass
+
+ ##########################################
+ # Functions that deal with temporary files associated with jobs
+ ##########################################
+
+ def _importFile(self, otherCls, url, sharedFileName=None):
+ if issubclass(otherCls, FileJobStore):
+ if sharedFileName is None:
+ fd, absPath = self._getTempFile()
+ shutil.copyfile(self._extractPathFromUrl(url), absPath)
+ os.close(fd)
+ return self._getRelativePath(absPath)
+ else:
+ self._requireValidSharedFileName(sharedFileName)
+ with self.writeSharedFileStream(sharedFileName) as writable:
+ with open(self._extractPathFromUrl(url), 'r') as readable:
+ shutil.copyfileobj(readable, writable)
+ return None
+ else:
+ return super(FileJobStore, self)._importFile(otherCls, url,
+ sharedFileName=sharedFileName)
+
+ def _exportFile(self, otherCls, jobStoreFileID, url):
+ if issubclass(otherCls, FileJobStore):
+ shutil.copyfile(self._getAbsPath(jobStoreFileID), self._extractPathFromUrl(url))
+ else:
+ super(FileJobStore, self)._exportFile(otherCls, jobStoreFileID, url)
+
+ @classmethod
+ def _readFromUrl(cls, url, writable):
+ with open(cls._extractPathFromUrl(url), 'r') as f:
+ writable.write(f.read())
+
+ @classmethod
+ def _writeToUrl(cls, readable, url):
+ with open(cls._extractPathFromUrl(url), 'w') as f:
+ f.write(readable.read())
+
+ @staticmethod
+ def _extractPathFromUrl(url):
+ """
+ :return: local file path of file pointed at by the given URL
+ """
+ if url.netloc != '' and url.netloc != 'localhost':
+ raise RuntimeError("The URL '%s' is invalid" % url.geturl())
+ return url.netloc + url.path
+
+ @classmethod
+ def _supportsUrl(cls, url, export=False):
+ return url.scheme.lower() == 'file'
+
+ def writeFile(self, localFilePath, jobStoreID=None):
+ fd, absPath = self._getTempFile(jobStoreID)
+ shutil.copyfile(localFilePath, absPath)
+ os.close(fd)
+ return self._getRelativePath(absPath)
+
+ @contextmanager
+ def writeFileStream(self, jobStoreID=None):
+ fd, absPath = self._getTempFile(jobStoreID)
+ with open(absPath, 'w') as f:
+ yield f, self._getRelativePath(absPath)
+ os.close(fd) # Close the os level file descriptor
+
+ def getEmptyFileStoreID(self, jobStoreID=None):
+ with self.writeFileStream(jobStoreID) as (fileHandle, jobStoreFileID):
+ return jobStoreFileID
+
+ def updateFile(self, jobStoreFileID, localFilePath):
+ self._checkJobStoreFileID(jobStoreFileID)
+ shutil.copyfile(localFilePath, self._getAbsPath(jobStoreFileID))
+
+ def readFile(self, jobStoreFileID, localFilePath):
+ self._checkJobStoreFileID(jobStoreFileID)
+ jobStoreFilePath = self._getAbsPath(jobStoreFileID)
+ localDirPath = os.path.dirname(localFilePath)
+ # If local file would end up on same file system as the one hosting this job store ...
+ if os.stat(jobStoreFilePath).st_dev == os.stat(localDirPath).st_dev:
+ # ... we can hard-link the file, ...
+ try:
+ os.link(jobStoreFilePath, localFilePath)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ # Overwrite existing file, emulating shutil.copyfile().
+ os.unlink(localFilePath)
+ # It is very unlikely to fail again for the same reason, but if it does
+ # we simply give up and let the exception propagate.
+ os.link(jobStoreFilePath, localFilePath)
+ else:
+ raise
+ else:
+ # ... otherwise we have to copy it.
+ shutil.copyfile(jobStoreFilePath, localFilePath)
+
+ def deleteFile(self, jobStoreFileID):
+ if not self.fileExists(jobStoreFileID):
+ return
+ os.remove(self._getAbsPath(jobStoreFileID))
+
+ def fileExists(self, jobStoreFileID):
+ absPath = self._getAbsPath(jobStoreFileID)
+ try:
+ st = os.stat(absPath)
+ except os.error:
+ return False
+ if not stat.S_ISREG(st.st_mode):
+ raise NoSuchFileException("Path %s is not a file in the jobStore" % jobStoreFileID)
+ return True
+
+ @contextmanager
+ def updateFileStream(self, jobStoreFileID):
+ self._checkJobStoreFileID(jobStoreFileID)
+ # File objects are context managers (CM) so we could simply return what open returns.
+ # However, it is better to wrap it in another CM so as to prevent users from accessing
+ # the file object directly, without a with statement.
+ with open(self._getAbsPath(jobStoreFileID), 'w') as f:
+ yield f
+
+ @contextmanager
+ def readFileStream(self, jobStoreFileID):
+ self._checkJobStoreFileID(jobStoreFileID)
+ with open(self._getAbsPath(jobStoreFileID), 'r') as f:
+ yield f
+
+ ##########################################
+ # The following methods deal with shared files, i.e. files not associated
+ # with specific jobs.
+ ##########################################
+
+ @contextmanager
+ def writeSharedFileStream(self, sharedFileName, isProtected=None):
+ # the isProtected parameter has no effect on the fileStore
+ assert self._validateSharedFileName( sharedFileName )
+ with open( os.path.join( self.jobStoreDir, sharedFileName ), 'w' ) as f:
+ yield f
+
+ @contextmanager
+ def readSharedFileStream(self, sharedFileName):
+ assert self._validateSharedFileName( sharedFileName )
+ try:
+ with open(os.path.join(self.jobStoreDir, sharedFileName), 'r') as f:
+ yield f
+ except IOError as e:
+ if e.errno == errno.ENOENT:
+ raise NoSuchFileException(sharedFileName,sharedFileName)
+ else:
+ raise
+
+ def writeStatsAndLogging(self, statsAndLoggingString):
+ # Temporary files are placed in the set of temporary files/directories
+ fd, tempStatsFile = tempfile.mkstemp(prefix="stats", suffix=".new", dir=self._getTempSharedDir())
+ with open(tempStatsFile, "w") as f:
+ f.write(statsAndLoggingString)
+ os.close(fd)
+ os.rename(tempStatsFile, tempStatsFile[:-4]) # This operation is atomic
+
+ def readStatsAndLogging(self, callback, readAll=False):
+ numberOfFilesProcessed = 0
+ for tempDir in self._tempDirectories():
+ for tempFile in os.listdir(tempDir):
+ if tempFile.startswith('stats'):
+ absTempFile = os.path.join(tempDir, tempFile)
+ if readAll or not tempFile.endswith('.new'):
+ with open(absTempFile, 'r') as fH:
+ callback(fH)
+ numberOfFilesProcessed += 1
+ newName = tempFile.rsplit('.', 1)[0] + '.new'
+ newAbsTempFile = os.path.join(tempDir, newName)
+ # Mark this item as read
+ os.rename(absTempFile, newAbsTempFile)
+ return numberOfFilesProcessed
+
+ ##########################################
+ # Private methods
+ ##########################################
+
+ def _getAbsPath(self, relativePath):
+ """
+ :rtype : string, the absolute path corresponding to a path given relative
+ to self.tempFilesDir.
+ """
+ return os.path.join(self.tempFilesDir, relativePath)
+
+ def _getRelativePath(self, absPath):
+ """
+ absPath is the absolute path to a file in the store.
+
+ :rtype : string, the path of absPath relative to
+ self.tempFilesDir
+
+ """
+ return absPath[len(self.tempFilesDir)+1:]
+
+ def _getJobFileName(self, jobStoreID):
+ """
+ Return the path to the file containing the serialised JobGraph instance for the given
+ job.
+
+ :rtype: str
+ """
+ return os.path.join(self._getAbsPath(jobStoreID), "job")
+
+ def _checkJobStoreId(self, jobStoreID):
+ """
+ Raises a NoSuchJobException if the jobStoreID does not exist.
+ """
+ if not self.exists(jobStoreID):
+ raise NoSuchJobException(jobStoreID)
+
+ def _checkJobStoreFileID(self, jobStoreFileID):
+ """
+ :raise NoSuchFileException: if the jobStoreFileID does not exist or is not a file
+ """
+ if not self.fileExists(jobStoreFileID):
+ raise NoSuchFileException("File %s does not exist in jobStore" % jobStoreFileID)
+
+ def _getTempSharedDir(self):
+ """
+ Gets a temporary directory in the hierarchy of directories in self.tempFilesDir.
+ This directory may contain multiple shared jobs/files.
+
+ :rtype : string, path to temporary directory in which to place files/directories.
+ """
+ tempDir = self.tempFilesDir
+ for i in xrange(self.levels):
+ tempDir = os.path.join(tempDir, random.choice(self.validDirs))
+ if not os.path.exists(tempDir):
+ try:
+ os.mkdir(tempDir)
+ except os.error:
+ if not os.path.exists(tempDir): # If another process created the directory
+ # concurrently we ignore the error; otherwise re-raise
+ raise
+ return tempDir
+
+ def _tempDirectories(self):
+ """
+ :rtype : an iterator to the temporary directories containing jobs/stats files
+ in the hierarchy of directories in self.tempFilesDir
+ """
+ def _dirs(path, levels):
+ if levels > 0:
+ for subPath in os.listdir(path):
+ for i in _dirs(os.path.join(path, subPath), levels-1):
+ yield i
+ else:
+ yield path
+ for tempDir in _dirs(self.tempFilesDir, self.levels):
+ yield tempDir
+
+ def _getTempFile(self, jobStoreID=None):
+ """
+ :rtype : (file-descriptor, string), where the string is the absolute path to a temporary
+ file within the given job's (referenced by jobStoreID) temporary file directory. The file
+ descriptor is an integer referring to an open operating system file handle; it should be
+ closed with os.close() after writing to the file.
+ """
+ if jobStoreID is not None:
+ # Make a temporary file within the job's directory
+ self._checkJobStoreId(jobStoreID)
+ return tempfile.mkstemp(suffix=".tmp",
+ dir=os.path.join(self._getAbsPath(jobStoreID), "g"))
+ else:
+ # Make a temporary file within the temporary file structure
+ return tempfile.mkstemp(prefix="tmp", suffix=".tmp", dir=self._getTempSharedDir())
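FileJobStore.update above relies on a write-then-rename idiom: the job is serialised to a sibling file suffixed ".new" and then os.rename()d over the target, which POSIX guarantees to be atomic on a single file system. A self-contained sketch of that idiom, with a hypothetical atomicWrite helper:

import os, tempfile

def atomicWrite(path, data):
    # Serialise to a sibling ".new" file first ...
    tmpPath = path + ".new"
    with open(tmpPath, "w") as f:
        f.write(data)
    # ... then atomically move it over the target: readers observe either the
    # old or the complete new contents, never a partially written file.
    os.rename(tmpPath, path)

jobFile = os.path.join(tempfile.mkdtemp(), "job")
atomicWrite(jobFile, "pickled JobGraph goes here")
atomicWrite(jobFile, "updated JobGraph goes here")
print(open(jobFile).read())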
diff --git a/src/toil/jobStores/googleJobStore.py b/src/toil/jobStores/googleJobStore.py
new file mode 100644
index 0000000..fa07a34
--- /dev/null
+++ b/src/toil/jobStores/googleJobStore.py
@@ -0,0 +1,476 @@
+import base64
+from contextlib import contextmanager
+import hashlib
+import os
+import uuid
+from bd2k.util.threading import ExceptionalThread
+import boto
+import logging
+import time
+
+# Python 3 compatibility imports
+from six.moves import cPickle, StringIO
+
+from toil.jobStores.abstractJobStore import (AbstractJobStore, NoSuchJobException,
+ NoSuchFileException,
+ ConcurrentFileModificationException)
+from toil.jobStores.utils import WritablePipe, ReadablePipe
+from toil.jobGraph import JobGraph
+
+log = logging.getLogger(__name__)
+
+GOOGLE_STORAGE = 'gs'
+
+
+class GoogleJobStore(AbstractJobStore):
+
+ @classmethod
+ def initialize(cls, locator, config=None):
+ try:
+ projectID, namePrefix = locator.split(":", 1)
+ except ValueError:
+ # we don't have a specified projectID
+ namePrefix = locator
+ projectID = None
+ return cls(namePrefix, projectID, config)
+
+ # BOTO WILL UPDATE HEADERS WITHOUT COPYING THEM FIRST. To enforce immutability & prevent
+ # this, we use getters that return copies of our original dictionaries. reported:
+ # https://github.com/boto/boto/issues/3517
+ @property
+ def encryptedHeaders(self):
+ return self._encryptedHeaders.copy()
+
+ @encryptedHeaders.setter
+ def encryptedHeaders(self, value):
+ self._encryptedHeaders = value
+
+ @property
+ def headerValues(self):
+ return self._headerValues.copy()
+
+ @headerValues.setter
+ def headerValues(self, value):
+ self._headerValues = value
+
+ def __init__(self, namePrefix, projectID=None, config=None):
+ # create 2 buckets
+ self.projectID = projectID
+
+ self.bucketName = namePrefix+"--toil"
+ log.debug("Instantiating google jobStore with name: %s", self.bucketName)
+ self.gsBucketURL = "gs://"+self.bucketName
+
+ self._headerValues = {"x-goog-project-id": projectID} if projectID else {}
+ self._encryptedHeaders = self.headerValues
+
+ self.uri = boto.storage_uri(self.gsBucketURL, GOOGLE_STORAGE)
+ self.files = None
+
+ exists = True
+ try:
+ self.files = self.uri.get_bucket(headers=self.headerValues, validate=True)
+ except boto.exception.GSResponseError:
+ exists = False
+
+ create = config is not None
+ self._checkJobStoreCreation(create, exists, projectID+':'+namePrefix)
+
+ if not exists:
+ self.files = self._retryCreateBucket(self.uri, self.headerValues)
+
+ super(GoogleJobStore, self).__init__(config=config)
+ self.sseKeyPath = self.config.sseKey
+ # functionally equivalent to dictionary1.update(dictionary2) but works with our immutable dicts
+ self.encryptedHeaders = dict(self.encryptedHeaders, **self._resolveEncryptionHeaders())
+
+ self.statsBaseID = 'f16eef0c-b597-4b8b-9b0c-4d605b4f506c'
+ self.statsReadPrefix = '_'
+ self.readStatsBaseID = self.statsReadPrefix+self.statsBaseID
+
+ def destroy(self):
+ # There is no upper time limit on this call; keep issuing delete calls until we succeed.
+ # Eventual consistency can make us fail in two ways: 1) unlisted objects in the bucket that
+ # are meant to be deleted get skipped 2) ghost objects get listed when trying to delete the bucket
+ while True:
+ try:
+ self.uri.delete_bucket()
+ except boto.exception.GSResponseError as e:
+ if e.status == 404:
+ return # the bucket doesn't exist so we are done
+ else:
+ # bucket could still have objects, or contain ghost objects
+ time.sleep(0.5)
+ else:
+ # we have successfully deleted the bucket
+ return
+
+ # object could have been deleted already
+ for obj in self.files.list():
+ try:
+ obj.delete()
+ except boto.exception.GSResponseError:
+ pass
+
+ def create(self, jobNode):
+ jobStoreID = self._newID()
+ job = JobGraph(jobStoreID=jobStoreID, unitName=jobNode.name, jobName=jobNode.job,
+ command=jobNode.command, remainingRetryCount=self._defaultTryCount(),
+ logJobStoreFileID=None, predecessorNumber=jobNode.predecessorNumber,
+ **jobNode._requirements)
+ self._writeString(jobStoreID, cPickle.dumps(job))
+ return job
+
+ def exists(self, jobStoreID):
+ # used on job files, which will be encrypted if available
+ headers = self.encryptedHeaders
+ try:
+ self._getKey(jobStoreID, headers)
+ except NoSuchFileException:
+ return False
+ return True
+
+ def getPublicUrl(self, fileName):
+ try:
+ key = self._getKey(fileName)
+ except:
+ raise NoSuchFileException(fileName)
+ return key.generate_url(expires_in=self.publicUrlExpiration.total_seconds())
+
+ def getSharedPublicUrl(self, sharedFileName):
+ return self.getPublicUrl(sharedFileName)
+
+ def load(self, jobStoreID):
+ try:
+ jobString = self._readContents(jobStoreID)
+ except NoSuchFileException:
+ raise NoSuchJobException(jobStoreID)
+ return cPickle.loads(jobString)
+
+ def update(self, job):
+ self._writeString(job.jobStoreID, cPickle.dumps(job), update=True)
+
+ def delete(self, jobStoreID):
+ # jobs will always be encrypted when available
+ self._delete(jobStoreID, encrypt=True)
+
+ def jobs(self):
+ for key in self.files.list(prefix='job'):
+ jobStoreID = key.name
+ if len(jobStoreID) == 39:
+ yield self.load(jobStoreID)
+
+ def writeFile(self, localFilePath, jobStoreID=None):
+ fileID = self._newID(isFile=True, jobStoreID=jobStoreID)
+ with open(localFilePath) as f:
+ self._writeFile(fileID, f)
+ return fileID
+
+ @contextmanager
+ def writeFileStream(self, jobStoreID=None):
+ fileID = self._newID(isFile=True, jobStoreID=jobStoreID)
+ key = self._newKey(fileID)
+ with self._uploadStream(key, update=False) as writable:
+ yield writable, key.name
+
+ def getEmptyFileStoreID(self, jobStoreID=None):
+ fileID = self._newID(isFile=True, jobStoreID=jobStoreID)
+ self._writeFile(fileID, StringIO(""))
+ return fileID
+
+ def readFile(self, jobStoreFileID, localFilePath):
+ # used on non-shared files which will be encrypted if available
+ headers = self.encryptedHeaders
+ # checking for JobStoreID existence
+ if not self.exists(jobStoreFileID):
+ raise NoSuchFileException(jobStoreFileID)
+ with open(localFilePath, 'w') as writeable:
+ self._getKey(jobStoreFileID, headers).get_contents_to_file(writeable, headers=headers)
+
+ @contextmanager
+ def readFileStream(self, jobStoreFileID):
+ with self.readSharedFileStream(jobStoreFileID, isProtected=True) as readable:
+ yield readable
+
+ def deleteFile(self, jobStoreFileID):
+ headers = self.encryptedHeaders
+ try:
+ self._getKey(jobStoreFileID, headers).delete(headers)
+ except boto.exception.GSDataError as e:
+ # we tried to delete unencrypted file with encryption headers. unfortunately,
+ # we can't determine whether the file passed in is encrypted or not beforehand.
+ if e.status == 400:
+ headers = self.headerValues
+ self._getKey(jobStoreFileID, headers).delete(headers)
+ else:
+ raise e
+
+ def fileExists(self, jobStoreFileID):
+ try:
+ self._getKey(jobStoreFileID)
+ return True
+ except (NoSuchFileException, boto.exception.GSResponseError) as e:
+ if isinstance(e, NoSuchFileException):
+ return False
+ elif e.status == 400:
+ # will happen with self.fileExists(encryptedFile). If the file didn't exist the status would be 404
+ return True
+ else:
+ return False
+
+ def updateFile(self, jobStoreFileID, localFilePath):
+ with open(localFilePath) as f:
+ self._writeFile(jobStoreFileID, f, update=True)
+
+ @contextmanager
+ def updateFileStream(self, jobStoreFileID):
+ headers = self.encryptedHeaders
+ key = self._getKey(jobStoreFileID, headers)
+ with self._uploadStream(key, update=True) as readable:
+ yield readable
+
+ @contextmanager
+ def writeSharedFileStream(self, sharedFileName, isProtected=True):
+ key = self._newKey(sharedFileName)
+ with self._uploadStream(key, encrypt=isProtected, update=True) as readable:
+ yield readable
+
+ @contextmanager
+ def readSharedFileStream(self, sharedFileName, isProtected=True):
+ headers = self.encryptedHeaders if isProtected else self.headerValues
+ key = self._getKey(sharedFileName, headers=headers)
+ with self._downloadStream(key, encrypt=isProtected) as readable:
+ yield readable
+
+ @staticmethod
+ def _getResources(url):
+ projectID = url.host
+ bucketAndKey = url.path
+ return projectID, 'gs://'+bucketAndKey
+
+ @classmethod
+ def _readFromUrl(cls, url, writable):
+ # gs://projectid/bucket/key
+ projectID, uri = GoogleJobStore._getResources(url)
+ uri = boto.storage_uri(uri, GOOGLE_STORAGE)
+ headers = {"x-goog-project-id": projectID}
+ uri.get_contents_to_file(writable, headers=headers)
+
+ @classmethod
+ def _supportsUrl(cls, url, export=False):
+ return url.scheme.lower() == 'gs'
+
+ @classmethod
+ def _writeToUrl(cls, readable, url):
+ projectID, uri = GoogleJobStore._getResources(url)
+ uri = boto.storage_uri(uri, GOOGLE_STORAGE)
+ headers = {"x-goog-project-id": projectID}
+ uri.set_contents_from_file(readable, headers=headers)
+
+ def writeStatsAndLogging(self, statsAndLoggingString):
+ statsID = self.statsBaseID + str(uuid.uuid4())
+ key = self._newKey(statsID)
+ log.debug("Writing stats file: %s", key.name)
+ with self._uploadStream(key, encrypt=False, update=False) as f:
+ f.write(statsAndLoggingString)
+
+ def readStatsAndLogging(self, callback, readAll=False):
+ prefix = self.readStatsBaseID if readAll else self.statsBaseID
+ filesRead = 0
+ lastTry = False
+
+ while True:
+ filesReadThisLoop = 0
+ for key in list(self.files.list(prefix=prefix)):
+ try:
+ with self.readSharedFileStream(key.name) as readable:
+ log.debug("Reading stats file: %s", key.name)
+ callback(readable)
+ filesReadThisLoop += 1
+ if not readAll:
+ # rename this file by copying it and deleting the old version to avoid
+ # rereading it
+ newID = self.readStatsBaseID + key.name[len(self.statsBaseID):]
+ self.files.copy_key(newID, self.files.name, key.name)
+ key.delete()
+ except NoSuchFileException:
+ log.debug("Stats file not found: %s", key.name)
+ if readAll:
+ # The readAll parameter is only used by the toil stats utility after the completion of the
+ # pipeline. Assume that this means the bucket is in a consistent state when readAll
+ # is passed.
+ return filesReadThisLoop
+ if filesReadThisLoop == 0:
+ # Listing is unfortunately eventually consistent so we can't be 100% sure there
+ # really aren't any stats files left to read
+ if lastTry:
+ # this was our second try, we are reasonably sure there aren't any stats
+ # left to gather
+ break
+ # Try one more time in a couple seconds
+ time.sleep(5)
+ lastTry = True
+ continue
+ else:
+ lastTry = False
+ filesRead += filesReadThisLoop
+
+ return filesRead
+
+ @staticmethod
+ def _retryCreateBucket(uri, headers):
+ # FIXME: This should use retry from utils
+ while True:
+ try:
+ # FIXME: this leaks a connection on exceptions
+ return uri.create_bucket(headers=headers)
+ except boto.exception.GSResponseError as e:
+ if e.status == 429:
+ time.sleep(10)
+ else:
+ raise
+
+ @staticmethod
+ def _newID(isFile=False, jobStoreID=None):
+ if isFile and jobStoreID: # file associated with job
+ return jobStoreID+str(uuid.uuid4())
+ elif isFile: # nonassociated file
+ return str(uuid.uuid4())
+ else: # job id
+ return "job"+str(uuid.uuid4())
+
+ def _resolveEncryptionHeaders(self):
+ sseKeyPath = self.sseKeyPath
+ if sseKeyPath is None:
+ return {}
+ else:
+ with open(sseKeyPath) as f:
+ sseKey = f.read()
+ assert len(sseKey) == 32
+ encodedSseKey = base64.b64encode(sseKey)
+ encodedSseKeyMd5 = base64.b64encode(hashlib.sha256(sseKey).digest())
+ return {'x-goog-encryption-algorithm': 'AES256',
+ 'x-goog-encryption-key': encodedSseKey,
+ 'x-goog-encryption-key-sha256': encodedSseKeyMd5,
+ "Cache-Control": "no-store"}
+
+ def _delete(self, jobStoreID, encrypt=True):
+ headers = self.encryptedHeaders if encrypt else self.headerValues
+ try:
+ key = self._getKey(jobStoreID, headers)
+ except NoSuchFileException:
+ pass
+ else:
+ try:
+ key.delete()
+ except boto.exception.GSResponseError as e:
+ if e.status == 404:
+ pass
+ # best effort delete associated files
+ for fileID in self.files.list(prefix=jobStoreID):
+ try:
+ self.deleteFile(fileID)
+ except NoSuchFileException:
+ pass
+
+ def _getKey(self, jobStoreID=None, headers=None):
+ # gets remote key, in contrast to self._newKey
+ key = None
+ try:
+ key = self.files.get_key(jobStoreID, headers=headers)
+ except boto.exception.GSDataError:
+ if headers == self.encryptedHeaders:
+ # https://github.com/boto/boto/issues/3518
+ # see self._writeFile for more
+ pass
+ else:
+ raise
+ if key is None:
+ raise NoSuchFileException(jobStoreID)
+ else:
+ return key
+
+ def _newKey(self, jobStoreID):
+ # Does not create a key remotely. Careful -- does not ensure key name is unique
+ return self.files.new_key(jobStoreID)
+
+ def _readContents(self, jobStoreID):
+ # used on job files only, which will be encrypted if available
+ headers = self.encryptedHeaders
+ return self._getKey(jobStoreID, headers).get_contents_as_string(headers=headers)
+
+ def _writeFile(self, jobStoreID, fileObj, update=False, encrypt=True):
+ headers = self.encryptedHeaders if encrypt else self.headerValues
+ if update:
+ key = self._getKey(jobStoreID=jobStoreID, headers=headers)
+ else:
+ key = self._newKey(jobStoreID=jobStoreID)
+ headers = self.encryptedHeaders if encrypt else self.headerValues
+ try:
+ key.set_contents_from_file(fileObj, headers=headers)
+ except boto.exception.GSDataError:
+ if encrypt:
+ # Per https://cloud.google.com/storage/docs/encryption#customer-supplied_encryption_keys
+ # the etag and md5 will not match with customer supplied
+ # keys. However boto didn't get the memo apparently, and will raise this error if
+ # they dont match. Reported: https://github.com/boto/boto/issues/3518
+ pass
+ else:
+ raise
+
+ def _writeString(self, jobStoreID, stringToUpload, **kwarg):
+ self._writeFile(jobStoreID, StringIO(stringToUpload), **kwarg)
+
+ @contextmanager
+ def _uploadStream(self, key, update=False, encrypt=True):
+ store = self
+
+ class UploadPipe(WritablePipe):
+ def readFrom(self, readable):
+ headers = store.encryptedHeaders if encrypt else store.headerValues
+ if update:
+ try:
+ key.set_contents_from_stream(readable, headers=headers)
+ except boto.exception.GSDataError:
+ if encrypt:
+ # https://github.com/boto/boto/issues/3518
+ # see self._writeFile for more
+ pass
+ else:
+ raise
+ else:
+ try:
+ # The if_generation argument ensures that the existing key matches the
+ # given generation, i.e. version, before modifying anything. Passing a
+ # generation of 0 ensures that the key does not exist remotely.
+ key.set_contents_from_stream(readable, headers=headers, if_generation=0)
+ except (boto.exception.GSResponseError, boto.exception.GSDataError) as e:
+ if isinstance(e, boto.exception.GSResponseError):
+ if e.status == 412:
+ raise ConcurrentFileModificationException(key.name)
+ else:
+ raise e
+ elif encrypt:
+ # https://github.com/boto/boto/issues/3518
+ # see self._writeFile for more
+ pass
+ else:
+ raise
+
+ with UploadPipe() as writable:
+ yield writable
+
+ @contextmanager
+ def _downloadStream(self, key, encrypt=True):
+ store = self
+
+ class DownloadPipe(ReadablePipe):
+ def writeTo(self, writable):
+ headers = store.encryptedHeaders if encrypt else store.headerValues
+ try:
+ key.get_file(writable, headers=headers)
+ finally:
+ writable.close()
+
+ with DownloadPipe() as readable:
+ yield readable
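For reference, the customer-supplied encryption headers that _resolveEncryptionHeaders builds above can be reproduced in a few lines. The key below is a throwaway generated on the spot purely for illustration; the job store itself reads its 32-byte key from config.sseKey.

import base64, hashlib, os

sseKey = os.urandom(32)  # customer-supplied Google Cloud Storage keys must be exactly 32 bytes
headers = {
    'x-goog-encryption-algorithm': 'AES256',
    'x-goog-encryption-key': base64.b64encode(sseKey),
    'x-goog-encryption-key-sha256': base64.b64encode(hashlib.sha256(sseKey).digest()),
    'Cache-Control': 'no-store',
}
print(sorted(headers))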
diff --git a/src/toil/jobStores/utils.py b/src/toil/jobStores/utils.py
new file mode 100644
index 0000000..5615297
--- /dev/null
+++ b/src/toil/jobStores/utils.py
@@ -0,0 +1,236 @@
+import logging
+import os
+from abc import ABCMeta
+from abc import abstractmethod
+
+from bd2k.util.threading import ExceptionalThread
+
+log = logging.getLogger(__name__)
+
+class WritablePipe(object):
+ """
+ An object-oriented wrapper for os.pipe. Clients should subclass it, implement
+ :meth:`.readFrom` to consume the readable end of the pipe, then instantiate the class as a
+ context manager to get the writable end. See the example below.
+
+ >>> import sys, shutil
+ >>> class MyPipe(WritablePipe):
+ ... def readFrom(self, readable):
+ ... shutil.copyfileobj(readable, sys.stdout)
+ >>> with MyPipe() as writable:
+ ... writable.write('Hello, world!\\n')
+ Hello, world!
+
+ Each instance of this class creates a thread and invokes the readFrom method in that thread.
+ The thread will be join()ed upon normal exit from the context manager, i.e. the body of the
+ `with` statement. If an exception occurs, the thread will not be joined but a well-behaved
+ :meth:`.readFrom` implementation will terminate shortly thereafter due to the pipe having
+ been closed.
+
+ Now, exceptions in the reader thread will be reraised in the main thread:
+
+ >>> class MyPipe(WritablePipe):
+ ... def readFrom(self, readable):
+ ... raise RuntimeError('Hello, world!')
+ >>> with MyPipe() as writable:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Hello, world!
+
+ More complicated, less illustrative tests:
+
+ Same as above, but proving that handles are closed:
+
+ >>> x = os.dup(0); os.close(x)
+ >>> class MyPipe(WritablePipe):
+ ... def readFrom(self, readable):
+ ... raise RuntimeError('Hello, world!')
+ >>> with MyPipe() as writable:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Hello, world!
+ >>> y = os.dup(0); os.close(y); x == y
+ True
+
+ Exceptions in the body of the with statement aren't masked, and handles are closed:
+
+ >>> x = os.dup(0); os.close(x)
+ >>> class MyPipe(WritablePipe):
+ ... def readFrom(self, readable):
+ ... pass
+ >>> with MyPipe() as writable:
+ ... raise RuntimeError('Hello, world!')
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Hello, world!
+ >>> y = os.dup(0); os.close(y); x == y
+ True
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def readFrom(self, readable):
+ """
+ Implement this method to read data from the pipe.
+
+ :param file readable: the file object representing the readable end of the pipe. Do not
+ explicitly invoke the close() method of the object; that will be done automatically.
+ """
+ raise NotImplementedError()
+
+ def _reader(self):
+ with os.fdopen(self.readable_fh, 'r') as readable:
+ # FIXME: another race here, causing a redundant attempt to close in the main thread
+ self.readable_fh = None # signal to parent thread that we've taken over
+ self.readFrom(readable)
+
+ def __init__(self):
+ super(WritablePipe, self).__init__()
+ self.readable_fh = None
+ self.writable = None
+ self.thread = None
+
+ def __enter__(self):
+ self.readable_fh, writable_fh = os.pipe()
+ self.writable = os.fdopen(writable_fh, 'w')
+ self.thread = ExceptionalThread(target=self._reader)
+ self.thread.start()
+ return self.writable
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ try:
+ self.writable.close()
+ # Closing the writable end will send EOF to the readable end and cause the reader thread
+ # to finish.
+ if exc_type is None:
+ if self.thread is not None:
+ # reraises any exception that was raised in the thread
+ self.thread.join()
+ finally:
+ # The responsibility for closing the readable end is generally that of the reader
+ # thread. To cover the small window before the reader takes over we also close it here.
+ readable_fh = self.readable_fh
+ if readable_fh is not None:
+ # FIXME: This is still racy. The reader thread could close it now, and someone
+ # else may immediately open a new file, reusing the file handle.
+ os.close(readable_fh)
+
+
+# FIXME: Unfortunately these two classes are almost an exact mirror image of each other.
+# Basically, read and write are swapped. The only asymmetry lies in how shutdown is handled. I
+# tried generalizing but the code becomes inscrutable. Until I (or someone else) has a better
+# idea how to solve this, I think its better to have code that is readable at the expense of
+# duplication.
+
+
+class ReadablePipe(object):
+ """
+ An object-oriented wrapper for os.pipe. Clients should subclass it, implement
+ :meth:`.writeTo` to place data into the writable end of the pipe, then instantiate the class
+ as a context manager to get the readable end. See the example below.
+
+ >>> import sys, shutil
+ >>> class MyPipe(ReadablePipe):
+ ... def writeTo(self, writable):
+ ... writable.write('Hello, world!\\n')
+ >>> with MyPipe() as readable:
+ ... shutil.copyfileobj(readable, sys.stdout)
+ Hello, world!
+
+ Each instance of this class creates a thread and invokes the :meth:`.writeTo` method in that
+ thread. The thread will be join()ed upon normal exit from the context manager, i.e. the body
+ of the `with` statement. If an exception occurs, the thread will not be joined but a
+ well-behaved :meth:`.writeTo` implementation will terminate shortly thereafter due to the
+ pipe having been closed.
+
+ Now, exceptions in the writer thread will be reraised in the main thread:
+
+ >>> class MyPipe(ReadablePipe):
+ ... def writeTo(self, writable):
+ ... raise RuntimeError('Hello, world!')
+ >>> with MyPipe() as readable:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Hello, world!
+
+ More complicated, less illustrative tests:
+
+ Same as above, but proving that handles are closed:
+
+ >>> x = os.dup(0); os.close(x)
+ >>> class MyPipe(ReadablePipe):
+ ... def writeTo(self, writable):
+ ... raise RuntimeError('Hello, world!')
+ >>> with MyPipe() as readable:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Hello, world!
+ >>> y = os.dup(0); os.close(y); x == y
+ True
+
+ Exceptions in the body of the with statement aren't masked, and handles are closed:
+
+ >>> x = os.dup(0); os.close(x)
+ >>> class MyPipe(ReadablePipe):
+ ... def writeTo(self, writable):
+ ... pass
+ >>> with MyPipe() as readable:
+ ... raise RuntimeError('Hello, world!')
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Hello, world!
+ >>> y = os.dup(0); os.close(y); x == y
+ True
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def writeTo(self, writable):
+ """
+ Implement this method to write data to the pipe.
+
+ :param file writable: the file object representing the writable end of the pipe. Do not
+ explicitly invoke the close() method of the object, that will be done automatically.
+ """
+ raise NotImplementedError()
+
+ def _writer(self):
+ with os.fdopen(self.writable_fh, 'w') as writable:
+ # FIXME: another race here, causing a redundant attempt to close in the main thread
+ self.writable_fh = None # signal to parent thread that we've taken over
+ self.writeTo(writable)
+
+ def __init__(self):
+ super(ReadablePipe, self).__init__()
+ self.writable_fh = None
+ self.readable = None
+ self.thread = None
+
+ def __enter__(self):
+ readable_fh, self.writable_fh = os.pipe()
+ self.readable = os.fdopen(readable_fh, 'r')
+ self.thread = ExceptionalThread(target=self._writer)
+ self.thread.start()
+ return self.readable
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ try:
+ if exc_type is None:
+ if self.thread is not None:
+ # reraises any exception that was raised in the thread
+ self.thread.join()
+ finally:
+ self.readable.close()
+ # The responsibility for closing the writable end is generally that of the writer
+ # thread. To cover the small window before the writer takes over we also close it here.
+ writable_fh = self.writable_fh
+ if writable_fh is not None:
+ # FIXME: This is still racy. The writer thread could close it now, and someone
+ # else may immediately open a new file, reusing the file handle.
+ os.close(writable_fh)
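The os.pipe-plus-thread idea behind WritablePipe and ReadablePipe can be sketched without the class machinery. Unlike the real classes, this toy version does not re-raise consumer exceptions in the main thread; pipeTo and countChars are hypothetical helpers for illustration only.

import os, threading

def pipeTo(consume):
    # Hand the readable end to a background consumer thread, return the writable end.
    readFd, writeFd = os.pipe()
    readable = os.fdopen(readFd, 'r')
    thread = threading.Thread(target=consume, args=(readable,))
    thread.start()
    return os.fdopen(writeFd, 'w'), thread

def countChars(readable):
    with readable:
        print('consumed %i characters' % len(readable.read()))

writable, thread = pipeTo(countChars)
with writable:
    writable.write('Hello, pipe!\n')  # closing the writable end sends EOF to the reader
thread.join()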
diff --git a/src/toil/leader.py b/src/toil/leader.py
new file mode 100644
index 0000000..ccf66b3
--- /dev/null
+++ b/src/toil/leader.py
@@ -0,0 +1,909 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+The leader script (of the leader/worker pair) for running jobs.
+"""
+from __future__ import absolute_import
+
+import logging
+import gzip
+import os
+import time
+from collections import namedtuple
+
+# Python 3 compatibility imports
+from six.moves import cPickle
+
+from bd2k.util.expando import Expando
+from bd2k.util.humanize import bytes2human
+
+from toil import resolveEntryPoint
+from toil.jobStores.abstractJobStore import NoSuchJobException
+from toil.provisioners.clusterScaler import ClusterScaler
+from toil.serviceManager import ServiceManager
+from toil.statsAndLogging import StatsAndLogging
+from toil.jobGraph import JobNode
+from toil.toilState import ToilState
+
+logger = logging.getLogger( __name__ )
+
+####################################################
+# Exception thrown by the Leader class when one or more jobs fails
+####################################################
+
+class FailedJobsException( Exception ):
+ def __init__(self, jobStoreLocator, failedJobs, jobStore):
+ msg = "The job store '%s' contains %i failed jobs" % (jobStoreLocator, len(failedJobs))
+ try:
+ msg += ": %s" % ", ".join((str(failedJob) for failedJob in failedJobs))
+ for jobNode in failedJobs:
+ job = jobStore.load(jobNode.jobStoreID)
+ if job.logJobStoreFileID:
+ msg += "\n=========> Failed job %s \n" % jobNode
+ with job.getLogFileHandle(jobStore) as fH:
+ msg += fH.read()
+ msg += "<=========\n"
+ # if preparing the more detailed message fails, fall back to reporting just the basics
+ except:
+ logger.exception('Exception when compiling information about failed jobs')
+ super( FailedJobsException, self ).__init__(msg)
+ self.jobStoreLocator = jobStoreLocator
+ self.numberOfFailedJobs = len(failedJobs)
+
+####################################################
+# Exception thrown by the Leader class when a deadlock is encountered due to insufficient
+# resources to run the workflow
+####################################################
+
+class DeadlockException( Exception ):
+ def __init__(self, msg):
+ msg = "Deadlock encountered: " + msg
+ super( DeadlockException, self ).__init__(msg)
+
+####################################################
+## The following class represents the leader
+####################################################
+
+class Leader:
+ """ Class that encapsulates the logic of the leader.
+ """
+ def __init__(self, config, batchSystem, provisioner, jobStore, rootJob, jobCache=None):
+ """
+ :param toil.common.Config config:
+ :param toil.batchSystems.abstractBatchSystem.AbstractBatchSystem batchSystem:
+ :param toil.provisioners.abstractProvisioner.AbstractProvisioner provisioner:
+ :param toil.jobStores.abstractJobStore.AbstractJobStore jobStore:
+ :param toil.jobGraph.JobGraph rootJob:
+
+ If jobCache is passed, it must be a dict from job ID to pre-existing
+ JobGraph objects. Jobs will be loaded from the cache (which can be
+ downloaded from the jobStore in a batch) during the construction of the ToilState object.
+ """
+ # Object containing parameters for the run
+ self.config = config
+
+ # The job store
+ self.jobStore = jobStore
+ self.jobStoreLocator = config.jobStore
+
+ # Get a snapshot of the current state of the jobs in the jobStore
+ self.toilState = ToilState(jobStore, rootJob, jobCache=jobCache)
+ logger.info("Found %s jobs to start and %i jobs with successors to run",
+ len(self.toilState.updatedJobs), len(self.toilState.successorCounts))
+
+ # Batch system
+ self.batchSystem = batchSystem
+ assert len(self.batchSystem.getIssuedBatchJobIDs()) == 0 #Batch system must start with no active jobs!
+ logger.info("Checked batch system has no running jobs and no updated jobs")
+
+ # Map of batch system IDs to IssuedJob tuples
+ self.jobBatchSystemIDToIssuedJob = {}
+
+ # Number of preemptable jobs currently being run by batch system
+ self.preemptableJobsIssued = 0
+
+ # Tracks the number of service jobs issued;
+ # this is used to limit the number of services issued to the batch system
+ self.serviceJobsIssued = 0
+ self.serviceJobsToBeIssued = [] # A queue of service jobs that await scheduling
+ #Equivalents for service jobs to be run on preemptable nodes
+ self.preemptableServiceJobsIssued = 0
+ self.preemptableServiceJobsToBeIssued = []
+
+ # Hash to store number of times a job is lost by the batch system,
+ # used to decide whether to reissue an apparently missing job
+ self.reissueMissingJobs_missingHash = {}
+
+ # Class used to create/destroy nodes in the cluster, may be None if
+ # using a statically defined cluster
+ self.provisioner = provisioner
+
+ # Create cluster scaling thread if the provisioner is not None
+ self.clusterScaler = None if self.provisioner is None else ClusterScaler(self.provisioner, self, self.config)
+
+ # A service manager thread to start and terminate services
+ self.serviceManager = ServiceManager(jobStore, self.toilState)
+
+ # A thread to manage the aggregation of statistics and logging from the run
+ self.statsAndLogging = StatsAndLogging(self.jobStore, self.config)
+
+ # Set used to monitor deadlocked jobs
+ self.potentialDeadlockedJobs = None
+ self.potentialDeadlockTime = 0
+
+ # internal jobs we should not expose at top level debugging
+ self.debugJobNames = ("CWLJob", "CWLWorkflow", "CWLScatter", "CWLGather",
+ "ResolveIndirect")
+
+ def run(self):
+ """
+ This runs the leader process to issue and manage jobs.
+
+ :raises: toil.leader.FailedJobsException if at the end of the function there remain \
+ failed jobs
+
+ :return: The return value of the root job's run function.
+ :rtype: Any
+ """
+ # Start the stats/logging aggregation thread
+ self.statsAndLogging.start()
+ try:
+
+ # Start service manager thread
+ self.serviceManager.start()
+ try:
+
+ # Create cluster scaling processes if not None
+ if self.clusterScaler is not None:
+ self.clusterScaler.start()
+
+ try:
+ # Run the main loop
+ self.innerLoop()
+ finally:
+ if self.clusterScaler is not None:
+ logger.info('Waiting for workers to shutdown')
+ startTime = time.time()
+ self.clusterScaler.shutdown()
+ logger.info('Worker shutdown complete in %s seconds', time.time() - startTime)
+
+ finally:
+ # Ensure service manager thread is properly shutdown
+ self.serviceManager.shutdown()
+
+ finally:
+ # Ensure the stats and logging thread is properly shutdown
+ self.statsAndLogging.shutdown()
+
+ # Filter the failed jobs
+ self.toilState.totalFailedJobs = filter(lambda j : self.jobStore.exists(j.jobStoreID), self.toilState.totalFailedJobs)
+
+ logger.info("Finished toil run %s" %
+ ("successfully" if len(self.toilState.totalFailedJobs) == 0 else ("with %s failed jobs" % len(self.toilState.totalFailedJobs))))
+
+ if len(self.toilState.totalFailedJobs):
+ logger.info("Failed jobs at end of the run: %s", ' '.join(str(job) for job in self.toilState.totalFailedJobs))
+ # Cleanup
+ if len(self.toilState.totalFailedJobs) > 0:
+ raise FailedJobsException(self.config.jobStore, self.toilState.totalFailedJobs, self.jobStore)
+
+ # Parse out the return value from the root job
+ with self.jobStore.readSharedFileStream('rootJobReturnValue') as fH:
+ try:
+ return cPickle.load(fH)
+ except EOFError:
+ logger.exception('Failed to unpickle root job return value')
+ raise FailedJobsException(self.config.jobStore, self.toilState.totalFailedJobs, self.jobStore)
+
+ def innerLoop(self):
+ """
+ The main loop for processing jobs by the leader.
+ """
+ # Sets up the timing of the jobGraph rescuing method
+ timeSinceJobsLastRescued = time.time()
+
+ logger.info("Starting the main loop")
+ while True:
+ # Process jobs that are ready to be scheduled/have successors to schedule
+ if len(self.toilState.updatedJobs) > 0:
+ logger.debug('Built the jobs list, currently have %i jobs to update and %i jobs issued',
+ len(self.toilState.updatedJobs), self.getNumberOfJobsIssued())
+
+ updatedJobs = self.toilState.updatedJobs # The updated jobs to consider below
+ self.toilState.updatedJobs = set() # Resetting the list for the next set
+
+ for jobGraph, resultStatus in updatedJobs:
+
+ logger.debug('Updating status of job %s with ID %s: with result status: %s',
+ jobGraph, jobGraph.jobStoreID, resultStatus)
+
+ # This stops a job with services being issued by the serviceManager from
+ # being considered further in this loop. This catch is necessary because
+ # the job's services can fail while being issued, causing the job to be
+ # added to updated jobs.
+ if jobGraph in self.serviceManager.jobGraphsWithServicesBeingStarted:
+ logger.debug("Got a job to update which is still owned by the service "
+ "manager: %s", jobGraph.jobStoreID)
+ continue
+
+ # If some of the job's successors failed then either fail the job
+ # or restart it if it has retries left and is a checkpoint job
+ if jobGraph.jobStoreID in self.toilState.hasFailedSuccessors:
+
+ # If the job has services running, signal for them to be killed;
+ # once they are killed the jobGraph will be re-added to the
+ # updatedJobs set and then scheduled to be removed
+ if jobGraph.jobStoreID in self.toilState.servicesIssued:
+ logger.debug("Telling job: %s to terminate its services due to successor failure",
+ jobGraph.jobStoreID)
+ self.serviceManager.killServices(self.toilState.servicesIssued[jobGraph.jobStoreID],
+ error=True)
+
+ # If the job has non-service jobs running, wait for them to finish;
+ # the job will be re-added to the updated jobs when these jobs are done
+ elif jobGraph.jobStoreID in self.toilState.successorCounts:
+ logger.debug("Job %s with ID: %s with failed successors still has successor jobs running",
+ jobGraph, jobGraph.jobStoreID)
+ continue
+
+ # If the job is a checkpoint and has remaining retries then reissue it.
+ elif jobGraph.checkpoint is not None and jobGraph.remainingRetryCount > 0:
+ logger.warn('Job: %s is being restarted as a checkpoint after the total '
+ 'failure of jobs in its subtree.', jobGraph.jobStoreID)
+ self.issueJob(JobNode.fromJobGraph(jobGraph))
+ else: # Mark it totally failed
+ logger.debug("Job %s is being processed as completely failed", jobGraph.jobStoreID)
+ self.processTotallyFailedJob(jobGraph)
+
+ # If the jobGraph has a command it must be run before any successors.
+ # Similarly, if the job previously failed we rerun it, even if it doesn't have a
+ # command to run, to eliminate any parts of the stack now completed.
+ elif jobGraph.command is not None or resultStatus != 0:
+ isServiceJob = jobGraph.jobStoreID in self.toilState.serviceJobStoreIDToPredecessorJob
+
+ # If the job has run out of retries or is a service job whose error flag has
+ # been raised, fail the job.
+ if (jobGraph.remainingRetryCount == 0
+ or isServiceJob and not self.jobStore.fileExists(jobGraph.errorJobStoreID)):
+ self.processTotallyFailedJob(jobGraph)
+ logger.warn("Job %s with ID %s is completely failed",
+ jobGraph, jobGraph.jobStoreID)
+ else:
+ # Otherwise try the job again
+ self.issueJob(JobNode.fromJobGraph(jobGraph))
+
+ # If the job has services to run, which have not been started, start them
+ elif len(jobGraph.services) > 0:
+ # Build a map from the service jobs to the job and a map
+ # of the services created for the job
+ assert jobGraph.jobStoreID not in self.toilState.servicesIssued
+ self.toilState.servicesIssued[jobGraph.jobStoreID] = {}
+ for serviceJobList in jobGraph.services:
+ for serviceTuple in serviceJobList:
+ serviceID = serviceTuple.jobStoreID
+ assert serviceID not in self.toilState.serviceJobStoreIDToPredecessorJob
+ self.toilState.serviceJobStoreIDToPredecessorJob[serviceID] = jobGraph
+ self.toilState.servicesIssued[jobGraph.jobStoreID][serviceID] = serviceTuple
+
+ # Use the service manager to start the services
+ self.serviceManager.scheduleServices(jobGraph)
+
+ logger.debug("Giving job: %s to service manager to schedule its jobs", jobGraph.jobStoreID)
+
+ # There exist successors to run
+ elif len(jobGraph.stack) > 0:
+ assert len(jobGraph.stack[-1]) > 0
+ logger.debug("Job: %s has %i successors to schedule",
+ jobGraph.jobStoreID, len(jobGraph.stack[-1]))
+ #Record the number of successors that must be completed before
+ #the jobGraph can be considered again
+ assert jobGraph.jobStoreID not in self.toilState.successorCounts
+ self.toilState.successorCounts[jobGraph.jobStoreID] = len(jobGraph.stack[-1])
+ #List of successors to schedule
+ successors = []
+
+ #For each successor schedule if all predecessors have been completed
+ for jobNode in jobGraph.stack[-1]:
+ successorJobStoreID = jobNode.jobStoreID
+ #Build map from successor to predecessors.
+ if successorJobStoreID not in self.toilState.successorJobStoreIDToPredecessorJobs:
+ self.toilState.successorJobStoreIDToPredecessorJobs[successorJobStoreID] = []
+ self.toilState.successorJobStoreIDToPredecessorJobs[successorJobStoreID].append(jobGraph)
+ #Case that the jobGraph has multiple predecessors
+ if jobNode.predecessorNumber > 1:
+ logger.debug("Successor job: %s of job: %s has multiple "
+ "predecessors", jobNode, jobGraph)
+
+ # Get the successor job, using a cache
+ # (if the successor job has already been seen it will be in this cache,
+ # but otherwise put it in the cache)
+ if successorJobStoreID not in self.toilState.jobsToBeScheduledWithMultiplePredecessors:
+ self.toilState.jobsToBeScheduledWithMultiplePredecessors[successorJobStoreID] = self.jobStore.load(successorJobStoreID)
+ successorJobGraph = self.toilState.jobsToBeScheduledWithMultiplePredecessors[successorJobStoreID]
+
+ #Add the jobGraph as a finished predecessor to the successor
+ successorJobGraph.predecessorsFinished.add(jobGraph.jobStoreID)
+
+ # If the successor is in the set of successors of failed jobs
+ if successorJobStoreID in self.toilState.failedSuccessors:
+ logger.debug("Successor job: %s of job: %s has failed "
+ "predecessors", jobNode, jobGraph)
+
+ # Add the job to the set having failed successors
+ self.toilState.hasFailedSuccessors.add(jobGraph.jobStoreID)
+
+ # Reduce active successor count and remove the successor as an active successor of the job
+ self.toilState.successorCounts[jobGraph.jobStoreID] -= 1
+ assert self.toilState.successorCounts[jobGraph.jobStoreID] >= 0
+ self.toilState.successorJobStoreIDToPredecessorJobs[successorJobStoreID].remove(jobGraph)
+ if len(self.toilState.successorJobStoreIDToPredecessorJobs[successorJobStoreID]) == 0:
+ self.toilState.successorJobStoreIDToPredecessorJobs.pop(successorJobStoreID)
+
+ # If the job now has no active successors add to active jobs
+ # so it can be processed as a job with failed successors
+ if self.toilState.successorCounts[jobGraph.jobStoreID] == 0:
+ logger.debug("Job: %s has no successors to run "
+ "and some are failed, adding to list of jobs "
+ "with failed successors", jobGraph)
+ self.toilState.successorCounts.pop(jobGraph.jobStoreID)
+ self.toilState.updatedJobs.add((jobGraph, 0))
+ continue
+
+ # If the successor job's predecessors have not all completed then
+ # ignore the successor as it is not yet ready to run
+ assert len(successorJobGraph.predecessorsFinished) <= successorJobGraph.predecessorNumber
+ if len(successorJobGraph.predecessorsFinished) < successorJobGraph.predecessorNumber:
+ continue
+ else:
+ # Remove the successor job from the cache
+ self.toilState.jobsToBeScheduledWithMultiplePredecessors.pop(successorJobStoreID)
+
+ # Add successor to list of successors to schedule
+ successors.append(jobNode)
+ self.issueJobs(successors)
+
+ elif jobGraph.jobStoreID in self.toilState.servicesIssued:
+ logger.debug("Telling job: %s to terminate its services due to the "
+ "successful completion of its successor jobs",
+ jobGraph)
+ self.serviceManager.killServices(self.toilState.servicesIssued[jobGraph.jobStoreID], error=False)
+
+ #There are no remaining tasks to schedule within the jobGraph, but
+ #we schedule it anyway to allow it to be deleted.
+
+ #TODO: An alternative would be to simply delete it here and add it to the
+ #list of jobs to process, or (better) to create an asynchronous
+ #process that deletes jobs and then feeds them back into the set
+ #of jobs to be processed
+ else:
+ # Remove the job
+ if jobGraph.remainingRetryCount > 0:
+ self.issueJob(JobNode.fromJobGraph(jobGraph))
+ logger.debug("Job: %s is empty, we are scheduling to clean it up", jobGraph.jobStoreID)
+ else:
+ self.processTotallyFailedJob(jobGraph)
+ logger.warn("Job: %s is empty but completely failed - something is very wrong", jobGraph.jobStoreID)
+
+ # Start any service jobs available from the service manager
+ self.issueQueingServiceJobs()
+ while True:
+ serviceJob = self.serviceManager.getServiceJobsToStart(0)
+ # Stop trying to get jobs when function returns None
+ if serviceJob is None:
+ break
+ logger.debug('Launching service job: %s', serviceJob)
+ self.issueServiceJob(serviceJob)
+
+ # Get jobs whose services have started
+ while True:
+ jobGraph = self.serviceManager.getJobGraphWhoseServicesAreRunning(0)
+ if jobGraph is None: # Stop trying to get jobs when function returns None
+ break
+ logger.debug('Job: %s has established its services.', jobGraph.jobStoreID)
+ jobGraph.services = []
+ self.toilState.updatedJobs.add((jobGraph, 0))
+
+ # Gather any new, updated jobGraph from the batch system
+ updatedJobTuple = self.batchSystem.getUpdatedBatchJob(2)
+ if updatedJobTuple is not None:
+ jobID, result, wallTime = updatedJobTuple
+ # easy, track different state
+ try:
+ updatedJob = self.jobBatchSystemIDToIssuedJob[jobID]
+ except KeyError:
+ logger.warn("A result seems to already have been processed "
+ "for job %s", jobID)
+ else:
+ if result == 0:
+ cur_logger = (logger.debug if str(updatedJob.jobName).startswith(self.debugJobNames)
+ else logger.info)
+ cur_logger('Job ended successfully: %s', updatedJob)
+ else:
+ logger.warn('Job failed with exit value %i: %s',
+ result, updatedJob)
+ self.processFinishedJob(jobID, result, wallTime=wallTime)
+
+ else:
+ # Process jobs that have gone awry
+
+ #In the case that there is nothing happening
+ #(no updated jobs to gather for 10 seconds)
+ #check if there are any jobs that have run too long
+ #(see self.reissueOverLongJobs) or which
+ #have gone missing from the batch system (see self.reissueMissingJobs)
+ if (time.time() - timeSinceJobsLastRescued >=
+ self.config.rescueJobsFrequency): #We only
+ #rescue jobs every N seconds, and when we have
+ #apparently exhausted the current jobGraph supply
+ self.reissueOverLongJobs()
+ logger.info("Reissued any over long jobs")
+
+ hasNoMissingJobs = self.reissueMissingJobs()
+ if hasNoMissingJobs:
+ timeSinceJobsLastRescued = time.time()
+ else:
+ timeSinceJobsLastRescued += 60 #This means we'll try again
+ #in a minute, providing things are quiet
+ logger.info("Rescued any (long) missing jobs")
+
+ # Check on the associated threads and exit if a failure is detected
+ self.statsAndLogging.check()
+ self.serviceManager.check()
+ # the cluster scaler object will only be instantiated if autoscaling is enabled
+ if self.clusterScaler is not None:
+ self.clusterScaler.check()
+
+ # The exit criterion
+ if len(self.toilState.updatedJobs) == 0 and self.getNumberOfJobsIssued() == 0 and self.serviceManager.jobsIssuedToServiceManager == 0:
+ logger.info("No jobs left to run so exiting.")
+ break
+
+ # Check for deadlocks
+ self.checkForDeadlocks()
+
+ logger.info("Finished the main loop")
+
+ # Consistency check the toil state
+ assert self.toilState.updatedJobs == set()
+ assert self.toilState.successorCounts == {}
+ assert self.toilState.successorJobStoreIDToPredecessorJobs == {}
+ assert self.toilState.serviceJobStoreIDToPredecessorJob == {}
+ assert self.toilState.servicesIssued == {}
+ # assert self.toilState.jobsToBeScheduledWithMultiplePredecessors # These are not properly emptied yet
+ # assert self.toilState.hasFailedSuccessors == set() # These are not properly emptied yet
+
+ def checkForDeadlocks(self):
+ """
+ Checks if the system is deadlocked running service jobs.
+ """
+ # If there are no updated jobs and at least some jobs issued
+ if len(self.toilState.updatedJobs) == 0 and self.getNumberOfJobsIssued() > 0:
+
+ # If all scheduled jobs are services
+ assert self.serviceJobsIssued + self.preemptableServiceJobsIssued <= self.getNumberOfJobsIssued()
+ if self.serviceJobsIssued + self.preemptableServiceJobsIssued == self.getNumberOfJobsIssued():
+
+ # Sanity check that all issued jobs are actually services
+ for jobNode in self.jobBatchSystemIDToIssuedJob.values():
+ assert jobNode.jobStoreID in self.toilState.serviceJobStoreIDToPredecessorJob
+
+ # An active service job is one that is not in the process of terminating
+ activeServiceJobs = filter(lambda x : self.serviceManager.isActive(x), self.jobBatchSystemIDToIssuedJob.values())
+
+ # If all the service jobs are active then we have a potential deadlock
+ if len(activeServiceJobs) == len(self.jobBatchSystemIDToIssuedJob):
+ # We wait self.config.deadlockWait seconds before declaring the system deadlocked
+ if self.potentialDeadlockedJobs != activeServiceJobs:
+ self.potentialDeadlockedJobs = activeServiceJobs
+ self.potentialDeadlockTime = time.time()
+ elif time.time() - self.potentialDeadlockTime >= self.config.deadlockWait:
+ raise DeadlockException("The system is service deadlocked - all %s issued jobs are active services" % self.getNumberOfJobsIssued())
+
+
+ def issueJob(self, jobNode):
+ """
+ Add a job to the queue of jobs
+ """
+ if jobNode.preemptable:
+ self.preemptableJobsIssued += 1
+ jobNode.command = ' '.join((resolveEntryPoint('_toil_worker'),
+ self.jobStoreLocator, jobNode.jobStoreID))
+ jobBatchSystemID = self.batchSystem.issueBatchJob(jobNode)
+ self.jobBatchSystemIDToIssuedJob[jobBatchSystemID] = jobNode
+ cur_logger = (logger.debug if jobNode.jobName.startswith(self.debugJobNames)
+ else logger.info)
+ cur_logger("Issued job %s with job batch system ID: "
+ "%s and cores: %s, disk: %s, and memory: %s",
+ jobNode, str(jobBatchSystemID), int(jobNode.cores),
+ bytes2human(jobNode.disk), bytes2human(jobNode.memory))
+
+ def issueJobs(self, jobs):
+ """
+ Add a list of jobs, each represented as a jobNode object
+ """
+ for job in jobs:
+ self.issueJob(job)
+
+ def issueServiceJob(self, jobNode):
+ """
+ Issue a service job, putting it on a queue if the maximum number of service
+ jobs to be scheduled has been reached.
+ """
+ if jobNode.preemptable:
+ self.preemptableServiceJobsToBeIssued.append(jobNode)
+ else:
+ self.serviceJobsToBeIssued.append(jobNode)
+ self.issueQueingServiceJobs()
+
+ def issueQueingServiceJobs(self):
+ """
+ Issues any queuing service jobs up to the limit of the maximum allowed.
+ """
+ while len(self.serviceJobsToBeIssued) > 0 and self.serviceJobsIssued < self.config.maxServiceJobs:
+ self.issueJob(self.serviceJobsToBeIssued.pop())
+ self.serviceJobsIssued += 1
+ while len(self.preemptableServiceJobsToBeIssued) > 0 and self.preemptableServiceJobsIssued < self.config.maxPreemptableServiceJobs:
+ self.issueJob(self.preemptableServiceJobsToBeIssued.pop())
+ self.preemptableServiceJobsIssued += 1
+
+ def getNumberOfJobsIssued(self, preemptable=None):
+ """
+ Gets number of jobs that have been added by issueJob(s) and not
+ removed by removeJob
+
+ :param None or boolean preemptable: If none, return all types of jobs.
+ If true, return just the number of preemptable jobs. If false, return
+ just the number of non-preemptable jobs.
+ """
+ #assert self.jobsIssued >= 0 and self._preemptableJobsIssued >= 0
+ if preemptable is None:
+ return len(self.jobBatchSystemIDToIssuedJob)
+ elif preemptable:
+ return self.preemptableJobsIssued
+ else:
+ assert len(self.jobBatchSystemIDToIssuedJob) >= self.preemptableJobsIssued
+ return len(self.jobBatchSystemIDToIssuedJob) - self.preemptableJobsIssued
+
+ def getNumberAndAvgRuntimeOfCurrentlyRunningJobs(self):
+ """
+ Returns a tuple (x, y) where x is number of currently running jobs and y
+ is the average number of seconds (as a float)
+ the jobs have been running for.
+ """
+ runningJobs = self.batchSystem.getRunningBatchJobIDs()
+ return len(runningJobs), 0 if len(runningJobs) == 0 else float(sum(runningJobs.values()))/len(runningJobs)
+
+ def getJobStoreID(self, jobBatchSystemID):
+ """
+ Gets the job store ID associated with a given batch system ID.
+ """
+ return self.jobBatchSystemIDToIssuedJob[jobBatchSystemID].jobStoreID
+
+ def removeJob(self, jobBatchSystemID):
+ """
+ Removes a job from the system.
+ """
+ assert jobBatchSystemID in self.jobBatchSystemIDToIssuedJob
+ jobNode = self.jobBatchSystemIDToIssuedJob.pop(jobBatchSystemID)
+ if jobNode.preemptable:
+ assert self.preemptableJobsIssued > 0
+ self.preemptableJobsIssued -= 1
+
+ # If service job
+ if jobNode.jobStoreID in self.toilState.serviceJobStoreIDToPredecessorJob:
+ # Decrement the number of services
+ if jobNode.preemptable:
+ self.preemptableServiceJobsIssued -= 1
+ else:
+ self.serviceJobsIssued -= 1
+
+ return jobNode
+
+ def getJobIDs(self):
+ """
+ Gets the set of jobs currently issued.
+ """
+ return self.jobBatchSystemIDToIssuedJob.keys()
+
+ def killJobs(self, jobsToKill):
+ """
+ Kills the given set of jobs and then sends them for processing
+ """
+ if len(jobsToKill) > 0:
+ self.batchSystem.killBatchJobs(jobsToKill)
+ for jobBatchSystemID in jobsToKill:
+ self.processFinishedJob(jobBatchSystemID, 1)
+
+ #Following functions handle error cases for when jobs have gone awry with the batch system.
+
+ def reissueOverLongJobs(self):
+ """
+ Check each issued job - if it has been running for longer than desirable,
+ issue a kill instruction.
+ Wait for the job to die, then pass it to processFinishedJob.
+ """
+ maxJobDuration = self.config.maxJobDuration
+ jobsToKill = []
+ if maxJobDuration < 10000000: # We won't bother doing anything if the rescue
+ # time is more than 16 weeks.
+ runningJobs = self.batchSystem.getRunningBatchJobIDs()
+ for jobBatchSystemID in runningJobs.keys():
+ if runningJobs[jobBatchSystemID] > maxJobDuration:
+ logger.warn("The job: %s has been running for: %s seconds, more than the "
+ "max job duration: %s, we'll kill it",
+ str(self.getJobStoreID(jobBatchSystemID)),
+ str(runningJobs[jobBatchSystemID]),
+ str(maxJobDuration))
+ jobsToKill.append(jobBatchSystemID)
+ self.killJobs(jobsToKill)
+
+ def reissueMissingJobs(self, killAfterNTimesMissing=3):
+ """
+ Check that all the currently issued job ids are in the list of currently running batch system jobs.
+ If a job is missing, mark it as such; if it remains missing for killAfterNTimesMissing runs of
+ this function, kill it (though it is probably already lost) and
+ pass it to processFinishedJob.
+ """
+ runningJobs = set(self.batchSystem.getIssuedBatchJobIDs())
+ jobBatchSystemIDsSet = set(self.getJobIDs())
+ #Clean up the reissueMissingJobs_missingHash hash, getting rid of jobs that have turned up
+ missingJobIDsSet = set(self.reissueMissingJobs_missingHash.keys())
+ for jobBatchSystemID in missingJobIDsSet.difference(jobBatchSystemIDsSet):
+ self.reissueMissingJobs_missingHash.pop(jobBatchSystemID)
+ logger.warn("Batch system id: %s is no longer missing", str(jobBatchSystemID))
+ assert runningJobs.issubset(jobBatchSystemIDsSet) #Assert checks we have
+ #no unexpected jobs running
+ jobsToKill = []
+ for jobBatchSystemID in set(jobBatchSystemIDsSet.difference(runningJobs)):
+ jobStoreID = self.getJobStoreID(jobBatchSystemID)
+ if self.reissueMissingJobs_missingHash.has_key(jobBatchSystemID):
+ self.reissueMissingJobs_missingHash[jobBatchSystemID] += 1
+ else:
+ self.reissueMissingJobs_missingHash[jobBatchSystemID] = 1
+ timesMissing = self.reissueMissingJobs_missingHash[jobBatchSystemID]
+ logger.warn("Job store ID %s with batch system id %s is missing for the %i time",
+ jobStoreID, str(jobBatchSystemID), timesMissing)
+ if timesMissing == killAfterNTimesMissing:
+ self.reissueMissingJobs_missingHash.pop(jobBatchSystemID)
+ jobsToKill.append(jobBatchSystemID)
+ self.killJobs(jobsToKill)
+ return len( self.reissueMissingJobs_missingHash ) == 0 #We use this to inform
+ #if there are missing jobs
+
+ def processFinishedJob(self, batchSystemID, resultStatus, wallTime=None):
+ """
+ Function reads a processed jobGraph file and updates its state.
+ """
+ def processRemovedJob(issuedJob):
+ if resultStatus != 0:
+ logger.warn("Despite the batch system claiming failure the "
+ "job %s seems to have finished and been removed", issuedJob)
+ self._updatePredecessorStatus(issuedJob.jobStoreID)
+ jobNode = self.removeJob(batchSystemID)
+ jobStoreID = jobNode.jobStoreID
+ if wallTime is not None and self.clusterScaler is not None:
+ self.clusterScaler.addCompletedJob(jobNode, wallTime)
+ if self.jobStore.exists(jobStoreID):
+ logger.debug("Job %s continues to exist (i.e. has more to do)", jobNode)
+ try:
+ jobGraph = self.jobStore.load(jobStoreID)
+ except NoSuchJobException:
+ # Avoid importing AWSJobStore as the corresponding extra might be missing
+ if self.jobStore.__class__.__name__ == 'AWSJobStore':
+ # We have a ghost job - the job has been deleted but a stale read from
+ # SDB gave us a false positive when we checked for its existence.
+ # Process the job from here as any other job removed from the job store.
+ # This is a temporary work around until https://github.com/BD2KGenomics/toil/issues/1091
+ # is completed
+ logger.warn('Got a stale read from SDB for job %s', jobNode)
+ processRemovedJob(jobNode)
+ return
+ else:
+ raise
+ if jobGraph.logJobStoreFileID is not None:
+ with jobGraph.getLogFileHandle( self.jobStore ) as logFileStream:
+ # more memory efficient than read().splitlines() while leaving off the
+ # trailing \n left when using readlines()
+ # http://stackoverflow.com/a/15233739
+ messages = [line.rstrip('\n') for line in logFileStream]
+ logFormat = '\n%s ' % jobStoreID
+ logger.warn('The job seems to have left a log file, indicating failure: %s\n%s',
+ jobGraph, logFormat.join(messages))
+ StatsAndLogging.writeLogFiles(jobGraph.chainedJobs, messages, self.config)
+ if resultStatus != 0:
+ # If the batch system returned a non-zero exit code then the worker
+ # is assumed not to have captured the failure of the job, so we
+ # reduce the retry count here.
+ if jobGraph.logJobStoreFileID is None:
+ logger.warn("No log file is present, despite job failing: %s", jobNode)
+ jobGraph.setupJobAfterFailure(self.config)
+ self.jobStore.update(jobGraph)
+ elif jobStoreID in self.toilState.hasFailedSuccessors:
+ # If the job has completed okay, we can remove it from the list of jobs with failed successors
+ self.toilState.hasFailedSuccessors.remove(jobStoreID)
+
+ self.toilState.updatedJobs.add((jobGraph, resultStatus)) #Now we know the
+ #jobGraph is done we can add it to the list of updated jobGraph files
+ logger.debug("Added job: %s to active jobs", jobGraph)
+ else: #The jobGraph is done
+ processRemovedJob(jobNode)
+
+ @staticmethod
+ def getSuccessors(jobGraph, alreadySeenSuccessors, jobStore):
+ """
+ Gets successors of the given job by walking the job graph recursively.
+ Any successor in alreadySeenSuccessors is ignored and not traversed.
+ Returns the set of found successors. This set is added to alreadySeenSuccessors.
+ """
+ successors = set()
+
+ def successorRecursion(jobGraph):
+ # For lists of successors
+ for successorList in jobGraph.stack:
+
+ # For each successor in list of successors
+ for successorJobNode in successorList:
+
+ # Id of the successor
+ successorJobStoreID = successorJobNode.jobStoreID
+
+ # If successor not already visited
+ if successorJobStoreID not in alreadySeenSuccessors:
+
+ # Add to set of successors
+ successors.add(successorJobStoreID)
+ alreadySeenSuccessors.add(successorJobStoreID)
+
+ # Recurse if job exists
+ # (job may not exist if already completed)
+ if jobStore.exists(successorJobStoreID):
+ successorRecursion(jobStore.load(successorJobStoreID))
+
+ successorRecursion(jobGraph) # Recurse from jobGraph
+
+ return successors
+
+ def processTotallyFailedJob(self, jobGraph):
+ """
+ Processes a totally failed job.
+ """
+ # Mark job as a totally failed job
+ self.toilState.totalFailedJobs.add(JobNode.fromJobGraph(jobGraph))
+
+ if jobGraph.jobStoreID in self.toilState.serviceJobStoreIDToPredecessorJob: # Is
+ # a service job
+ logger.debug("Service job is being processed as a totally failed job: %s", jobGraph)
+
+ predecesssorJobGraph = self.toilState.serviceJobStoreIDToPredecessorJob[jobGraph.jobStoreID]
+
+ # This removes the service job as a service of the predecessor
+ # and potentially makes the predecessor active
+ self._updatePredecessorStatus(jobGraph.jobStoreID)
+
+ # Remove the start flag, if it still exists. This indicates
+ # to the service manager that the job has "started"; it prevents
+ # the service manager from deadlocking while waiting for the service to start
+ self.jobStore.deleteFile(jobGraph.startJobStoreID)
+
+ # Signal to any other services in the group that they should
+ # terminate. We do this to prevent other services in the set
+ # of services from deadlocking waiting for this service to start properly
+ if predecesssorJobGraph.jobStoreID in self.toilState.servicesIssued:
+ self.serviceManager.killServices(self.toilState.servicesIssued[predecesssorJobGraph.jobStoreID], error=True)
+ logger.debug("Job: %s is instructing all the services of its parent job to quit", jobGraph)
+
+ self.toilState.hasFailedSuccessors.add(predecesssorJobGraph.jobStoreID) # This ensures that the
+ # job will not attempt to run any of its successors on the stack
+ else:
+ # Is a non-service job
+ assert jobGraph.jobStoreID not in self.toilState.servicesIssued
+
+ # Traverse failed job's successor graph and get the jobStoreID of new successors.
+ # Any successor already in toilState.failedSuccessors will not be traversed
+ # All successors traversed will be added to toilState.failedSuccessors and returned
+ # as a set (unseenSuccessors).
+ unseenSuccessors = self.getSuccessors(jobGraph, self.toilState.failedSuccessors,
+ self.jobStore)
+ logger.debug("Found new failed successors: %s of job: %s", " ".join(
+ unseenSuccessors), jobGraph)
+
+ # For each newly found successor
+ for successorJobStoreID in unseenSuccessors:
+
+ # If the successor is a successor of other jobs that have already tried to schedule it
+ if successorJobStoreID in self.toilState.successorJobStoreIDToPredecessorJobs:
+
+ # For each such predecessor job
+ # (we remove the successor from toilState.successorJobStoreIDToPredecessorJobs to avoid doing
+ # this multiple times for each failed predecessor)
+ for predecessorJob in self.toilState.successorJobStoreIDToPredecessorJobs.pop(successorJobStoreID):
+
+ # Reduce the predecessor job's successor count.
+ self.toilState.successorCounts[predecessorJob.jobStoreID] -= 1
+
+ # Indicate that it has failed jobs.
+ self.toilState.hasFailedSuccessors.add(predecessorJob.jobStoreID)
+ logger.debug("Marking job: %s as having failed successors (found by "
+ "reading the successors of a failed job)", predecessorJob)
+
+ # If the predecessor has no remaining successors, add to list of active jobs
+ assert self.toilState.successorCounts[predecessorJob.jobStoreID] >= 0
+ if self.toilState.successorCounts[predecessorJob.jobStoreID] == 0:
+ self.toilState.updatedJobs.add((predecessorJob, 0))
+
+ # Remove the predecessor job from the set of jobs with successors.
+ self.toilState.successorCounts.pop(predecessorJob.jobStoreID)
+
+ # If the job has predecessor(s)
+ if jobGraph.jobStoreID in self.toilState.successorJobStoreIDToPredecessorJobs:
+
+ # For each predecessor of the job
+ for predecessorJobGraph in self.toilState.successorJobStoreIDToPredecessorJobs[jobGraph.jobStoreID]:
+
+ # Mark the predecessor as failed
+ self.toilState.hasFailedSuccessors.add(predecessorJobGraph.jobStoreID)
+ logger.debug("Totally failed job: %s is marking direct predecessor: %s "
+ "as having failed jobs", jobGraph, predecessorJobGraph)
+
+ self._updatePredecessorStatus(jobGraph.jobStoreID)
+
+ def _updatePredecessorStatus(self, jobStoreID):
+ """
+ Update status of predecessors for finished successor job.
+ """
+ if jobStoreID in self.toilState.serviceJobStoreIDToPredecessorJob:
+ # Is a service job
+ predecessorJob = self.toilState.serviceJobStoreIDToPredecessorJob.pop(jobStoreID)
+ self.toilState.servicesIssued[predecessorJob.jobStoreID].pop(jobStoreID)
+ if len(self.toilState.servicesIssued[predecessorJob.jobStoreID]) == 0: # Predecessor job has
+ # all its services terminated
+ self.toilState.servicesIssued.pop(predecessorJob.jobStoreID) # The job has no running services
+ self.toilState.updatedJobs.add((predecessorJob, 0)) # Now we know
+ # the job is done we can add it to the list of updated job files
+ logger.debug("Job %s services have completed or totally failed, adding to updated jobs", predecessorJob)
+
+ elif jobStoreID not in self.toilState.successorJobStoreIDToPredecessorJobs:
+ #We have reached the root job
+ assert len(self.toilState.updatedJobs) == 0
+ assert len(self.toilState.successorJobStoreIDToPredecessorJobs) == 0
+ assert len(self.toilState.successorCounts) == 0
+ logger.debug("Reached root job %s so no predecessors to clean up" % jobStoreID)
+
+ else:
+ # Is a non-root, non-service job
+ logger.debug("Cleaning the predecessors of %s" % jobStoreID)
+
+ # For each predecessor
+ for predecessorJob in self.toilState.successorJobStoreIDToPredecessorJobs.pop(jobStoreID):
+
+ # Reduce the predecessor's number of successors by one to indicate the
+ # completion of the jobStoreID job
+ self.toilState.successorCounts[predecessorJob.jobStoreID] -= 1
+
+ # If all of the predecessor's successors are now complete
+ if self.toilState.successorCounts[predecessorJob.jobStoreID] == 0:
+
+ # Remove it from the set of jobs with active successors
+ self.toilState.successorCounts.pop(predecessorJob.jobStoreID)
+
+ # Pop stack at this point, as we can get rid of its successors
+ predecessorJob.stack.pop()
+
+ # Now we know the job is done we can add it to the list of updated job files
+ assert predecessorJob not in self.toilState.updatedJobs
+ self.toilState.updatedJobs.add((predecessorJob, 0))
+
+ logger.debug('Job %s has all its non-service successors completed or totally '
+ 'failed', predecessorJob)
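
The leader above tracks, for every job with successors, how many of those successors are still outstanding (toilState.successorCounts) and, for every successor, which predecessors are waiting on it (toilState.successorJobStoreIDToPredecessorJobs); when a successor finishes, each waiting predecessor's count is decremented and the predecessor is re-added to the updated set once the count reaches zero. A minimal sketch of that bookkeeping, using plain dicts and hypothetical names rather than the classes in this commit:

    from collections import defaultdict

    successorCounts = {}                # predecessor ID -> number of outstanding successors
    predecessorsOf = defaultdict(list)  # successor ID -> list of predecessor IDs
    updatedJobs = set()                 # predecessors that are ready to be reconsidered

    def schedule(predecessorID, successorIDs):
        # Record how many successors must complete before the predecessor is revisited.
        successorCounts[predecessorID] = len(successorIDs)
        for successorID in successorIDs:
            predecessorsOf[successorID].append(predecessorID)

    def finish(successorID):
        # Decrement each waiting predecessor's count; at zero the predecessor is updated.
        for predecessorID in predecessorsOf.pop(successorID, []):
            successorCounts[predecessorID] -= 1
            if successorCounts[predecessorID] == 0:
                del successorCounts[predecessorID]
                updatedJobs.add(predecessorID)

    schedule('parent', ['childA', 'childB'])
    finish('childA')
    finish('childB')
    assert 'parent' in updatedJobs and not successorCounts
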
diff --git a/src/toil/lib/__init__.py b/src/toil/lib/__init__.py
new file mode 100644
index 0000000..20da7b0
--- /dev/null
+++ b/src/toil/lib/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
diff --git a/src/toil/lib/bioio.py b/src/toil/lib/bioio.py
new file mode 100644
index 0000000..3a3d1c7
--- /dev/null
+++ b/src/toil/lib/bioio.py
@@ -0,0 +1,309 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import socket
+import sys
+import os
+import logging
+import resource
+import logging.handlers
+import tempfile
+import random
+import math
+import shutil
+from argparse import ArgumentParser
+from optparse import OptionContainer, OptionGroup
+import subprocess
+
+# Python 3 compatibility imports
+from six.moves import xrange
+from six import string_types
+
+import xml.etree.cElementTree as ET
+from xml.dom import minidom # For making stuff pretty
+
+defaultLogLevel = logging.INFO
+logger = logging.getLogger(__name__)
+rootLogger = logging.getLogger()
+toilLogger = logging.getLogger('toil')
+
+
+def getLogLevelString(logger=None):
+ if logger is None:
+ logger = rootLogger
+ return logging.getLevelName(logger.getEffectiveLevel())
+
+__loggingFiles = []
+def addLoggingFileHandler(fileName, rotatingLogging=False):
+ if fileName in __loggingFiles:
+ return
+ __loggingFiles.append(fileName)
+ if rotatingLogging:
+ handler = logging.handlers.RotatingFileHandler(fileName, maxBytes=1000000, backupCount=1)
+ else:
+ handler = logging.FileHandler(fileName)
+ rootLogger.addHandler(handler)
+ return handler
+
+
+def setLogLevel(level, logger=None):
+ """
+ Sets the log level to a given string level (like "INFO"). Operates on the
+ root logger by default, but another logger can be specified instead.
+ """
+ if logger is None:
+ logger = rootLogger
+ level = level.upper()
+ if level == "OFF": level = "CRITICAL"
+ # Note that getLevelName works in both directions, numeric to textual and textual to numeric
+ numericLevel = logging.getLevelName(level)
+ assert logging.getLevelName(numericLevel) == level
+ logger.setLevel(numericLevel)
+ # There are quite a few cases where we expect AWS requests to fail, but it seems
+ # that boto handles these by logging the error *and* raising an exception. We
+ # don't want to confuse the user with those error messages.
+ logging.getLogger( 'boto' ).setLevel( logging.CRITICAL )
+
+def logFile(fileName, printFunction=logger.info):
+ """Writes out a formatted version of the given log file
+ """
+ printFunction("Reporting file: %s" % fileName)
+ shortName = fileName.split("/")[-1]
+ fileHandle = open(fileName, 'r')
+ line = fileHandle.readline()
+ while line != '':
+ if line[-1] == '\n':
+ line = line[:-1]
+ printFunction("%s:\t%s" % (shortName, line))
+ line = fileHandle.readline()
+ fileHandle.close()
+
+def logStream(fileHandle, shortName, printFunction=logger.info):
+ """Writes out a formatted version of the given log stream.
+ """
+ printFunction("Reporting file: %s" % shortName)
+ line = fileHandle.readline()
+ while line != '':
+ if line[-1] == '\n':
+ line = line[:-1]
+ printFunction("%s:\t%s" % (shortName, line))
+ line = fileHandle.readline()
+ fileHandle.close()
+
+def addLoggingOptions(parser):
+ # Wrapper function that allows toil to be used with both the optparse and
+ # argparse option parsing modules
+ if isinstance(parser, ArgumentParser):
+ group = parser.add_argument_group("Logging Options",
+ "Options that control logging")
+ _addLoggingOptions(group.add_argument)
+ else:
+ raise RuntimeError("Unanticipated class passed to "
+ "addLoggingOptions(), %s. Expecting "
+ "argparse.ArgumentParser" % parser.__class__)
+
+supportedLogLevels = (logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
+
+def _addLoggingOptions(addOptionFn):
+ """
+ Adds logging options
+ """
+ # BEFORE YOU ADD OR REMOVE OPTIONS TO THIS FUNCTION, KNOW THAT YOU MAY ONLY USE VARIABLES ACCEPTED BY BOTH
+ # optparse AND argparse. FOR EXAMPLE, YOU MAY NOT USE default=%default OR default=%(default)s
+ defaultLogLevelName = logging.getLevelName( defaultLogLevel )
+ addOptionFn("--logOff", dest="logCritical", action="store_true", default=False,
+ help="Same as --logCritical")
+ for level in supportedLogLevels:
+ levelName = logging.getLevelName(level)
+ levelNameCapitalized = levelName.capitalize()
+ addOptionFn("--log" + levelNameCapitalized, dest="logLevel",
+ action="store_const", const=levelName,
+ help="Turn on logging at level %s and above. (default is %s)" % (levelName, defaultLogLevelName))
+ addOptionFn("--logLevel", dest="logLevel", default=defaultLogLevelName,
+ help=("Log at given level (may be either OFF (or CRITICAL), ERROR, WARN (or WARNING), INFO or DEBUG). "
+ "(default is %s)" % defaultLogLevelName))
+ addOptionFn("--logFile", dest="logFile", help="File to log in")
+ addOptionFn("--rotatingLogging", dest="logRotating", action="store_true", default=False,
+ help="Turn on rotating logging, which prevents log files getting too big.")
+
+def setLoggingFromOptions(options):
+ """
+ Sets up logging from the parsed command-line options.
+ """
+ formatStr = ' '.join([socket.gethostname(), '%(asctime)s', '%(threadName)s',
+ '%(levelname)s', '%(name)s:', '%(message)s'])
+ logging.basicConfig(format=formatStr)
+ rootLogger.setLevel(defaultLogLevel)
+ if options.logLevel is not None:
+ setLogLevel(options.logLevel)
+ else:
+ # Ensure that any other log level overrides are in effect even if no log level is explicitly set
+ setLogLevel(getLogLevelString())
+ logger.info("Root logger is at level '%s', 'toil' logger at level '%s'.",
+ getLogLevelString(logger=rootLogger), getLogLevelString(logger=toilLogger))
+ if options.logFile is not None:
+ addLoggingFileHandler(options.logFile, rotatingLogging=options.logRotating)
+ logger.info("Logging to file '%s'." % options.logFile)
+
+
+def system(command):
+ """
+ A convenience wrapper around subprocess.check_call that logs the command before passing it
+ on. The command can be either a string or a sequence of strings. If it is a string shell=True
+ will be passed to subprocess.check_call.
+
+ :type command: str | sequence[string]
+ """
+ logger.debug('Running: %r', command)
+ subprocess.check_call(command, shell=isinstance(command, string_types), bufsize=-1)
+
+def getTotalCpuTimeAndMemoryUsage():
+ """Gives the total cpu time and memory usage of the current process and its children.
+ """
+ me = resource.getrusage(resource.RUSAGE_SELF)
+ childs = resource.getrusage(resource.RUSAGE_CHILDREN)
+ totalCPUTime = me.ru_utime+me.ru_stime+childs.ru_utime+childs.ru_stime
+ totalMemoryUsage = me.ru_maxrss + childs.ru_maxrss
+ return totalCPUTime, totalMemoryUsage
+
+def getTotalCpuTime():
+ """Gives the total cpu time, including the children.
+ """
+ return getTotalCpuTimeAndMemoryUsage()[0]
+
+def getTotalMemoryUsage():
+ """Gets the amount of memory used by the process and its children.
+ """
+ return getTotalCpuTimeAndMemoryUsage()[1]
+
+def absSymPath(path):
+ """like os.path.abspath except it doesn't dereference symlinks
+ """
+ curr_path = os.getcwd()
+ return os.path.normpath(os.path.join(curr_path, path))
+
+#########################################################
+#########################################################
+#########################################################
+#testing settings
+#########################################################
+#########################################################
+#########################################################
+
+class TestStatus:
+ ###Global variables used by testing framework to run tests.
+ TEST_SHORT = 0
+ TEST_MEDIUM = 1
+ TEST_LONG = 2
+ TEST_VERY_LONG = 3
+
+ TEST_STATUS = TEST_SHORT
+
+ SAVE_ERROR_LOCATION = None
+
+ def getTestStatus():
+ return TestStatus.TEST_STATUS
+ getTestStatus = staticmethod(getTestStatus)
+
+ def setTestStatus(status):
+ assert status in (TestStatus.TEST_SHORT, TestStatus.TEST_MEDIUM, TestStatus.TEST_LONG, TestStatus.TEST_VERY_LONG)
+ TestStatus.TEST_STATUS = status
+ setTestStatus = staticmethod(setTestStatus)
+
+ def getSaveErrorLocation():
+ """Location in which to write inputs which created a test error.
+ """
+ return TestStatus.SAVE_ERROR_LOCATION
+ getSaveErrorLocation = staticmethod(getSaveErrorLocation)
+
+ def setSaveErrorLocation(dir):
+ """Set location in which to write inputs which created a test error.
+ """
+ logger.info("Location to save error files in: %s" % dir)
+ assert os.path.isdir(dir)
+ TestStatus.SAVE_ERROR_LOCATION = dir
+ setSaveErrorLocation = staticmethod(setSaveErrorLocation)
+
+ def getTestSetup(shortTestNo=1, mediumTestNo=5, longTestNo=100, veryLongTestNo=0):
+ if TestStatus.TEST_STATUS == TestStatus.TEST_SHORT:
+ return shortTestNo
+ elif TestStatus.TEST_STATUS == TestStatus.TEST_MEDIUM:
+ return mediumTestNo
+ elif TestStatus.TEST_STATUS == TestStatus.TEST_LONG:
+ return longTestNo
+ else: #Used for long example tests
+ return veryLongTestNo
+ getTestSetup = staticmethod(getTestSetup)
+
+ def getPathToDataSets():
+ """Returns the location of the directory in which
+ the data sets used by tests for analysis are kept.
+ These are not kept in the distribution itself for reasons of size.
+ """
+ assert "SON_TRACE_DATASETS" in os.environ
+ return os.environ["SON_TRACE_DATASETS"]
+ getPathToDataSets = staticmethod(getPathToDataSets)
+
+def getBasicOptionParser( parser=None):
+ if parser is None:
+ parser = ArgumentParser()
+
+ addLoggingOptions(parser)
+
+ parser.add_argument("--tempDirRoot", dest="tempDirRoot", type=str,
+ help="Path under which the temporary directory containing all temp files is created; by default the system temporary directory is used.",
+ default=tempfile.gettempdir())
+
+ return parser
+
+def parseBasicOptions(parser):
+ """Sets up the standard options added by getBasicOptionParser.
+ """
+ options = parser.parse_args()
+
+ setLoggingFromOptions(options)
+
+ #Set up the temp dir root
+ if options.tempDirRoot == "None": # FIXME: Really, a string containing the word None?
+ options.tempDirRoot = tempfile.gettempdir()
+
+ return options
+
+def getRandomAlphaNumericString(length=10):
+ """Returns a random alpha numeric string of the given length.
+ """
+ return "".join([ random.choice('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') for i in xrange(0, length) ])
+
+def makePublicDir(dirName):
+ """Makes a given subdirectory if it doesn't already exist, making sure it is public.
+ """
+ if not os.path.exists(dirName):
+ os.mkdir(dirName)
+ os.chmod(dirName, 0o777)
+ return dirName
+
+def getTempFile(suffix="", rootDir=None):
+ """Returns a string representing a temporary file, that must be manually deleted
+ """
+ if rootDir is None:
+ handle, tmpFile = tempfile.mkstemp(suffix)
+ os.close(handle)
+ return tmpFile
+ else:
+ tmpFile = os.path.join(rootDir, "tmp_" + getRandomAlphaNumericString() + suffix)
+ open(tmpFile, 'w').close()
+ os.chmod(tmpFile, 0o777) #Ensure everyone has access to the file.
+ return tmpFile
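
The option helpers in bioio.py are meant to be combined: getBasicOptionParser() returns an ArgumentParser with the logging options already attached, and parseBasicOptions() parses the command line and applies --logLevel, --logFile and --rotatingLogging. A small hypothetical script using them (the --name flag is made up for the example):

    from toil.lib.bioio import getBasicOptionParser, parseBasicOptions, logger

    def main():
        parser = getBasicOptionParser()
        parser.add_argument("--name", default="world")  # example-specific flag
        options = parseBasicOptions(parser)             # also configures logging
        logger.info("Hello, %s", options.name)

    if __name__ == "__main__":
        main()
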
diff --git a/src/toil/lib/encryption/__init__.py b/src/toil/lib/encryption/__init__.py
new file mode 100644
index 0000000..33d9a2e
--- /dev/null
+++ b/src/toil/lib/encryption/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+ from _nacl import *
+except ImportError:
+ from _dummy import *
diff --git a/src/toil/lib/encryption/_dummy.py b/src/toil/lib/encryption/_dummy.py
new file mode 100644
index 0000000..1b270b3
--- /dev/null
+++ b/src/toil/lib/encryption/_dummy.py
@@ -0,0 +1,32 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+overhead = 0
+
+
+# noinspection PyUnusedLocal
+def encrypt(message, keyPath):
+ _bail()
+
+
+# noinspection PyUnusedLocal
+def decrypt(ciphertext, keyPath):
+ _bail()
+
+
+def _bail():
+ raise NotImplementedError("Encryption support is not installed. Consider re-installing toil "
+ "with the 'encryption' extra along with any other extras you might "
+ "want, e.g. 'pip install toil[encryption,...]'.")
diff --git a/src/toil/lib/encryption/_nacl.py b/src/toil/lib/encryption/_nacl.py
new file mode 100644
index 0000000..d440a94
--- /dev/null
+++ b/src/toil/lib/encryption/_nacl.py
@@ -0,0 +1,89 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nacl
+from nacl.secret import SecretBox
+
+
+# 16-byte MAC plus a nonce is added to every message.
+overhead = 16 + SecretBox.NONCE_SIZE
+
+def encrypt(message, keyPath):
+ """
+ Encrypts a message given a path to a local file containing a key.
+
+ :param message: The message to be encrypted.
+ :param keyPath: A path to a file containing a 256-bit key (and nothing else).
+ :type message: str
+ :type keyPath: str
+ :rtype: str
+
+ A constant overhead is added to every encrypted message (for the nonce and MAC).
+ >>> import tempfile
+ >>> k = tempfile.mktemp()
+ >>> with open(k, 'w') as f:
+ ... f.write(nacl.utils.random(SecretBox.KEY_SIZE))
+ >>> message = 'test'
+ >>> len(encrypt(message, k)) == overhead + len(message)
+ True
+ """
+ with open(keyPath) as f:
+ key = f.read()
+ if len(key) != SecretBox.KEY_SIZE:
+ raise ValueError("Key is %d bytes, but must be exactly %d bytes" % (len(key),
+ SecretBox.KEY_SIZE))
+ sb = SecretBox(key)
+ # We generate the nonce using secure random bits. For long enough
+ # nonce size, the chance of a random nonce collision becomes
+ # *much* smaller than the chance of a subtle coding error causing
+ # a nonce reuse. Currently the nonce size is 192 bits--the chance
+ # of a collision is astronomically low. (This approach is
+ # recommended in the libsodium documentation.)
+ nonce = nacl.utils.random(SecretBox.NONCE_SIZE)
+ assert len(nonce) == SecretBox.NONCE_SIZE
+ return str(sb.encrypt(message, nonce))
+
+def decrypt(ciphertext, keyPath):
+ """
+ Decrypts a given message that was encrypted with the encrypt() method.
+
+ :param ciphertext: The encrypted message (as a string).
+ :param keyPath: A path to a file containing a 256-bit key (and nothing else).
+ :type keyPath: str
+ :rtype: str
+
+ Raises an error if ciphertext was modified
+ >>> import tempfile
+ >>> k = tempfile.mktemp()
+ >>> with open(k, 'w') as f:
+ ... f.write(nacl.utils.random(SecretBox.KEY_SIZE))
+ >>> ciphertext = encrypt("testMessage", k)
+ >>> ciphertext = chr(ord(ciphertext[0]) ^ 1) + ciphertext[1:]
+ >>> decrypt(ciphertext, k)
+ Traceback (most recent call last):
+ ...
+ CryptoError: Decryption failed. Ciphertext failed verification
+
+ Otherwise works correctly
+ >>> decrypt(encrypt("testMessage", k), k)
+ 'testMessage'
+ """
+ with open(keyPath) as f:
+ key = f.read()
+ if len(key) != SecretBox.KEY_SIZE:
+ raise ValueError("Key is %d bytes, but must be exactly %d bytes" % (len(key),
+ SecretBox.KEY_SIZE))
+ sb = SecretBox(key)
+ # The nonce is kept with the message.
+ return sb.decrypt(ciphertext)
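
The NaCl backend expects the key file to contain exactly SecretBox.KEY_SIZE raw bytes and nothing else; encrypt() prepends a random nonce and a MAC, which is what the module-level overhead value accounts for. A sketch of a round trip, assuming PyNaCl is installed (the 'encryption' extra), Python 2 str semantics as used in the module, and a throwaway key path chosen for the example:

    import nacl.utils
    from nacl.secret import SecretBox
    from toil.lib.encryption import encrypt, decrypt, overhead

    keyPath = '/tmp/example-toil.key'                     # hypothetical location
    with open(keyPath, 'wb') as f:
        f.write(nacl.utils.random(SecretBox.KEY_SIZE))    # exactly 32 random bytes

    ciphertext = encrypt('hello', keyPath)
    assert len(ciphertext) == len('hello') + overhead     # nonce + MAC prepended
    assert decrypt(ciphertext, keyPath) == 'hello'
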
diff --git a/src/toil/lib/encryption/conftest.py b/src/toil/lib/encryption/conftest.py
new file mode 100644
index 0000000..d2572f6
--- /dev/null
+++ b/src/toil/lib/encryption/conftest.py
@@ -0,0 +1,8 @@
+# https://pytest.org/latest/example/pythoncollection.html
+
+collect_ignore = []
+
+try:
+ import nacl
+except ImportError:
+ collect_ignore.append("_nacl.py")
diff --git a/src/toil/provisioners/__init__.py b/src/toil/provisioners/__init__.py
new file mode 100644
index 0000000..76424a4
--- /dev/null
+++ b/src/toil/provisioners/__init__.py
@@ -0,0 +1,76 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+import datetime
+import logging
+import os
+
+from bd2k.util import parse_iso_utc, less_strict_bool
+
+
+logger = logging.getLogger(__name__)
+
+
+def awsRemainingBillingInterval(instance):
+ def partialBillingInterval(instance):
+ """
+ Returns a floating point value between 0 and 1.0 representing how far we are into the
+ current billing cycle for the given instance. If the return value is .25, we are one
+ quarter into the billing cycle, with three quarters remaining before we will be charged
+ again for that instance.
+ """
+ launch_time = parse_iso_utc(instance.launch_time)
+ now = datetime.datetime.utcnow()
+ delta = now - launch_time
+ return delta.total_seconds() / 3600.0 % 1.0
+
+ return 1.0 - partialBillingInterval(instance)
+
+
+def awsFilterImpairedNodes(nodes, ec2):
+ # if TOIL_AWS_NODE_DEBUG is set don't terminate nodes with
+ # failing status checks so they can be debugged
+ nodeDebug = less_strict_bool(os.environ.get('TOIL_AWS_NODE_DEBUG'))
+ if not nodeDebug:
+ return nodes
+ nodeIDs = [node.id for node in nodes]
+ statuses = ec2.get_all_instance_status(instance_ids=nodeIDs)
+ statusMap = {status.id: status.instance_status for status in statuses}
+ healthyNodes = [node for node in nodes if statusMap.get(node.id, None) != 'impaired']
+ impairedNodes = [node.id for node in nodes if statusMap.get(node.id, None) == 'impaired']
+ logger.warn('TOIL_AWS_NODE_DEBUG is set and nodes %s have failed EC2 status checks so '
+ 'will not be terminated.', ' '.join(impairedNodes))
+ return healthyNodes
+
+
+class Cluster(object):
+ def __init__(self, clusterName, provisioner):
+ self.clusterName = clusterName
+ if provisioner == 'aws':
+ from toil.provisioners.aws.awsProvisioner import AWSProvisioner
+ self.provisioner = AWSProvisioner
+ elif provisioner == 'cgcloud':
+ from toil.provisioners.cgcloud.provisioner import CGCloudProvisioner
+ self.provisioner = CGCloudProvisioner
+ else:
+ assert False, "Invalid provisioner '%s'" % provisioner
+
+ def sshCluster(self, args):
+ self.provisioner.sshLeader(self.clusterName, args)
+
+ def rsyncCluster(self, args):
+ self.provisioner.rsyncLeader(self.clusterName, args)
+
+ def destroyCluster(self):
+ self.provisioner.destroyCluster(self.clusterName)
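
awsRemainingBillingInterval() above reduces to modular arithmetic on the instance's age: the fraction of the current (hourly) billing cycle already used is (age in seconds / 3600) mod 1, and the remaining fraction is one minus that. A worked example with a made-up launch offset:

    elapsed_seconds = 2.5 * 3600                            # instance launched 2.5 hours ago
    fraction_into_cycle = elapsed_seconds / 3600.0 % 1.0    # 0.5: halfway through the hour
    remaining = 1.0 - fraction_into_cycle                   # 0.5 of the billing hour left
    assert abs(remaining - 0.5) < 1e-9
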
diff --git a/src/toil/provisioners/abstractProvisioner.py b/src/toil/provisioners/abstractProvisioner.py
new file mode 100644
index 0000000..c5ab5c3
--- /dev/null
+++ b/src/toil/provisioners/abstractProvisioner.py
@@ -0,0 +1,262 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import os
+import threading
+from abc import ABCMeta, abstractmethod
+
+from collections import namedtuple
+
+from itertools import islice
+
+import time
+
+from bd2k.util.threading import ExceptionalThread
+
+from toil.batchSystems.abstractBatchSystem import AbstractScalableBatchSystem
+
+log = logging.getLogger(__name__)
+
+
+Shape = namedtuple("_Shape", "wallTime memory cores disk")
+"""
+Represents a job or a node's "shape", in terms of the dimensions of memory, cores, disk and
+wall-time allocation. All attributes are integers.
+
+The wallTime attribute stores the number of seconds of a node allocation, e.g. 3600 for AWS,
+or 60 for Azure. FIXME: and for jobs?
+
+The memory and disk attributes store the number of bytes required by a job (or provided by a
+node) in RAM or on disk (SSD or HDD), respectively.
+"""
+
+
+class AbstractProvisioner(object):
+ """
+ An abstract base class to represent the interface for provisioning worker nodes to use in a
+ Toil cluster.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self, config, batchSystem):
+ self.config = config
+ self.batchSystem = batchSystem
+ self.stop = False
+ self.stats = {}
+ self.statsThreads = []
+ self.statsPath = config.clusterStats
+ self.scaleable = isinstance(self.batchSystem, AbstractScalableBatchSystem)
+
+ def shutDown(self, preemptable):
+ if not self.stop:
+ # only shutdown the stats threads once
+ self._shutDownStats()
+ log.debug('Forcing provisioner to reduce cluster size to zero.')
+ totalNodes = self.setNodeCount(numNodes=0, preemptable=preemptable, force=True)
+ if totalNodes != 0:
+ raise RuntimeError('Provisioner was not able to reduce cluster size to zero.')
+
+ def _shutDownStats(self):
+ def getFileName():
+ extension = '.json'
+ file = '%s-stats' % self.config.jobStore
+ counter = 0
+ while True:
+ suffix = str(counter).zfill(3) + extension
+ fullName = os.path.join(self.statsPath, file + suffix)
+ if not os.path.exists(fullName):
+ return fullName
+ counter += 1
+ if self.config.clusterStats and self.scaleable:
+ self.stop = True
+ for thread in self.statsThreads:
+ thread.join()
+ fileName = getFileName()
+ with open(fileName, 'w') as f:
+ json.dump(self.stats, f)
+
+ def startStats(self, preemptable):
+ thread = ExceptionalThread(target=self._gatherStats, args=[preemptable])
+ thread.start()
+ self.statsThreads.append(thread)
+
+ def checkStats(self):
+ for thread in self.statsThreads:
+ # propagate any errors raised in the threads execution
+ thread.join(timeout=0)
+
+ def _gatherStats(self, preemptable):
+ def toDict(nodeInfo):
+ # namedtuples don't retain attribute names when dumped to JSON.
+ # convert them to dicts instead to improve stats output. Also add
+ # time.
+ return dict(memory=nodeInfo.memory,
+ cores=nodeInfo.cores,
+ workers=nodeInfo.workers,
+ time=time.time()
+ )
+ if self.scaleable:
+ stats = {}
+ try:
+ while not self.stop:
+ nodeInfo = self.batchSystem.getNodes(preemptable)
+ for nodeIP in nodeInfo.keys():
+ nodeStats = nodeInfo[nodeIP]
+ if nodeStats is not None:
+ nodeStats = toDict(nodeStats)
+ try:
+ # if the node is already registered update the dictionary with
+ # the newly reported stats
+ stats[nodeIP].append(nodeStats)
+ except KeyError:
+ # create a new entry for the node
+ stats[nodeIP] = [nodeStats]
+ time.sleep(60)
+ finally:
+ threadName = 'Preemptable' if preemptable else 'Non-preemptable'
+ log.debug('%s provisioner stats thread shut down successfully.', threadName)
+ self.stats[threadName] = stats
+ else:
+ pass
+
+ def setNodeCount(self, numNodes, preemptable=False, force=False):
+ """
+ Attempt to grow or shrink the number of preemptable or non-preemptable worker nodes in
+ the cluster to the given value, or as close a value as possible, and, after performing
+ the necessary additions or removals of worker nodes, return the resulting number of
+ preemptable or non-preemptable nodes currently in the cluster.
+
+ :param int numNodes: Desired size of the cluster
+
+ :param bool preemptable: whether the added nodes will be preemptable, i.e. whether they
+ may be removed spontaneously by the underlying platform at any time.
+
+ :param bool force: If False, the provisioner is allowed to deviate from the given number
+ of nodes. For example, when downsizing a cluster, a provisioner might leave nodes
+ running if they have active jobs running on them.
+
+ :rtype: int :return: the number of nodes in the cluster after making the necessary
+ adjustments. This value should be, but is not guaranteed to be, close or equal to
+ the `numNodes` argument. It represents the closest possible approximation of the
+ actual cluster size at the time this method returns.
+ """
+ workerInstances = self._getWorkersInCluster(preemptable)
+ numCurrentNodes = len(workerInstances)
+ delta = numNodes - numCurrentNodes
+ if delta > 0:
+ log.info('Adding %i %s nodes to get to desired cluster size of %i.', delta, 'preemptable' if preemptable else 'non-preemptable', numNodes)
+ numNodes = numCurrentNodes + self._addNodes(workerInstances,
+ numNodes=delta,
+ preemptable=preemptable)
+ elif delta < 0:
+ log.info('Removing %i %s nodes to get to desired cluster size of %i.', -delta, 'preemptable' if preemptable else 'non-preemptable', numNodes)
+ numNodes = numCurrentNodes - self._removeNodes(workerInstances,
+ numNodes=-delta,
+ preemptable=preemptable,
+ force=force)
+ else:
+ log.info('Cluster already at desired size of %i. Nothing to do.', numNodes)
+ return numNodes
+
+ def _removeNodes(self, instances, numNodes, preemptable=False, force=False):
+ # If the batch system is scalable, we can use the number of currently running workers on
+ # each node as the primary criterion to select which nodes to terminate.
+ if isinstance(self.batchSystem, AbstractScalableBatchSystem):
+ nodes = self.batchSystem.getNodes(preemptable)
+ # Join nodes and instances on private IP address.
+ nodes = [(instance, nodes.get(instance.private_ip_address)) for instance in instances]
+ log.debug('Nodes considered to terminate: %s', ' '.join(map(str, nodes)))
+ # Unless forced, exclude nodes with running workers. Note that it is possible for
+ # the batch system to report stale nodes for which the corresponding instance was
+ # terminated already. There can also be instances that the batch system doesn't have
+ # nodes for yet. We'll ignore those, too, unless forced.
+ nodesToTerminate = []
+ for instance, nodeInfo in nodes:
+ if force:
+ nodesToTerminate.append((instance, nodeInfo))
+ elif nodeInfo is not None and nodeInfo.workers < 1:
+ nodesToTerminate.append((instance, nodeInfo))
+ else:
+ log.debug('Not terminating instances %s. Node info: %s', instance, nodeInfo)
+ # Sort nodes by number of workers and time left in billing cycle
+ nodesToTerminate.sort(key=lambda (instance, nodeInfo): (
+ nodeInfo.workers if nodeInfo else 1,
+ self._remainingBillingInterval(instance)))
+ nodesToTerminate = nodesToTerminate[:numNodes]
+ if log.isEnabledFor(logging.DEBUG):
+ for instance, nodeInfo in nodesToTerminate:
+ log.debug("Instance %s is about to be terminated. Its node info is %r. It "
+ "would be billed again in %s minutes.", instance.id, nodeInfo,
+ 60 * self._remainingBillingInterval(instance))
+ instanceIds = [instance.id for instance, nodeInfo in nodesToTerminate]
+ else:
+ # Without load info all we can do is sort instances by time left in billing cycle.
+ instances = sorted(instances, key=self._remainingBillingInterval)
+ instanceIds = [instance.id for instance in islice(instances, numNodes)]
+ log.info('Terminating %i instance(s).', len(instanceIds))
+ if instanceIds:
+ self._logAndTerminate(instanceIds)
+ return len(instanceIds)
+
+ @abstractmethod
+ def _addNodes(self, instances, numNodes, preemptable):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _logAndTerminate(self, instanceIDs):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _getWorkersInCluster(self, preemptable):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _remainingBillingInterval(self, instance):
+ raise NotImplementedError
+
+ @abstractmethod
+ def getNodeShape(self, preemptable=False):
+ """
+ The shape of a preemptable or non-preemptable node managed by this provisioner. The node
+ shape defines key properties of a machine, such as its number of cores or the time
+ between billing intervals.
+
+ :param preemptable: Whether to return the shape of preemptable nodes or that of
+ non-preemptable ones.
+
+ :rtype: Shape
+ """
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def rsyncLeader(cls, clusterName, src, dst):
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None):
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def sshLeader(cls, clusterName, args):
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def destroyCluster(cls, clusterName):
+ raise NotImplementedError
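
setNodeCount() above boils down to comparing the desired size with the current worker count and dispatching to _addNodes() or _removeNodes() with the absolute difference. A minimal sketch of that decision, with hypothetical names and no cloud calls:

    def planResize(numCurrentNodes, numDesiredNodes):
        # Mirrors the delta logic in setNodeCount(): positive -> add, negative -> remove.
        delta = numDesiredNodes - numCurrentNodes
        if delta > 0:
            return ('add', delta)
        elif delta < 0:
            return ('remove', -delta)
        return ('noop', 0)

    assert planResize(3, 5) == ('add', 2)
    assert planResize(5, 3) == ('remove', 2)
    assert planResize(4, 4) == ('noop', 0)
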
diff --git a/src/toil/provisioners/aws/__init__.py b/src/toil/provisioners/aws/__init__.py
new file mode 100644
index 0000000..f6b855e
--- /dev/null
+++ b/src/toil/provisioners/aws/__init__.py
@@ -0,0 +1,289 @@
+# Copyright (C) 2015 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from collections import namedtuple
+from operator import attrgetter
+import datetime
+from cgcloud.lib.util import std_dev, mean
+
+logger = logging.getLogger(__name__)
+
+ZoneTuple = namedtuple('ZoneTuple', ['name', 'price_deviation'])
+
+
+def getSpotZone(spotBid, nodeType, ctx):
+ return _getCurrentAWSZone(spotBid, nodeType, ctx)
+
+
+def getCurrentAWSZone():
+ return _getCurrentAWSZone()
+
+
+def _getCurrentAWSZone(spotBid=None, nodeType=None, ctx=None):
+ zone = None
+ try:
+ import boto
+ from boto.utils import get_instance_metadata
+ except ImportError:
+ pass
+ else:
+ zone = os.environ.get('TOIL_AWS_ZONE', None)
+ if spotBid:
+ # if spot bid is present, all the other parameters must be as well
+ assert bool(spotBid) == bool(nodeType) == bool(ctx)
+ # if the zone is unset and we are using the spot market, optimize our
+ # choice based on the spot history
+ return optimize_spot_bid(ctx=ctx, instance_type=nodeType, spot_bid=float(spotBid))
+ if not zone:
+ zone = boto.config.get('Boto', 'ec2_region_name')
+ if zone is not None:
+ zone += 'a' # derive an availability zone in the region
+ if not zone:
+ try:
+ zone = get_instance_metadata()['placement']['availability-zone']
+ except KeyError:
+ pass
+ return zone
+
+
+def choose_spot_zone(zones, bid, spot_history):
+ """
+ Returns the zone in which to place the spot request, based on, in order of priority:
+
+ 1) zones with prices currently under the bid
+
+ 2) zones with the most stable price
+
+ :param list[boto.ec2.zone.Zone] zones:
+ :param float bid:
+ :param list[boto.ec2.spotpricehistory.SpotPriceHistory] spot_history:
+
+ :rtype: str
+ :return: the name of the selected zone
+
+ >>> from collections import namedtuple
+ >>> FauxHistory = namedtuple( 'FauxHistory', [ 'price', 'availability_zone' ] )
+ >>> ZoneTuple = namedtuple( 'ZoneTuple', [ 'name' ] )
+
+ >>> zones = [ ZoneTuple( 'us-west-2a' ), ZoneTuple( 'us-west-2b' ) ]
+ >>> spot_history = [ FauxHistory( 0.1, 'us-west-2a' ), \
+ FauxHistory( 0.2,'us-west-2a'), \
+ FauxHistory( 0.3,'us-west-2b'), \
+ FauxHistory( 0.6,'us-west-2b')]
+ >>> # noinspection PyProtectedMember
+ >>> choose_spot_zone( zones, 0.15, spot_history )
+ 'us-west-2a'
+
+ >>> spot_history=[ FauxHistory( 0.3, 'us-west-2a' ), \
+ FauxHistory( 0.2, 'us-west-2a' ), \
+ FauxHistory( 0.1, 'us-west-2b'), \
+ FauxHistory( 0.6, 'us-west-2b') ]
+ >>> # noinspection PyProtectedMember
+ >>> choose_spot_zone(zones, 0.15, spot_history)
+ 'us-west-2b'
+
+ >>> spot_history={ FauxHistory( 0.1, 'us-west-2a' ), \
+ FauxHistory( 0.7, 'us-west-2a' ), \
+ FauxHistory( 0.1, "us-west-2b" ), \
+ FauxHistory( 0.6, 'us-west-2b' ) }
+ >>> # noinspection PyProtectedMember
+ >>> choose_spot_zone(zones, 0.15, spot_history)
+ 'us-west-2b'
+ """
+
+    # Create two lists of ZoneTuples of the form (zone.name, price_deviation): one for zones
+    # whose most recent price is under the bid and one for zones whose price is over it.
+    # The zone with the smallest price deviation is then chosen, preferring under-bid zones.
+ #
+ markets_under_bid, markets_over_bid = [], []
+ for zone in zones:
+ zone_histories = filter(lambda zone_history:
+ zone_history.availability_zone == zone.name, spot_history)
+ price_deviation = std_dev([history.price for history in zone_histories])
+ recent_price = zone_histories[0]
+ zone_tuple = ZoneTuple(name=zone.name, price_deviation=price_deviation)
+ (markets_over_bid, markets_under_bid)[recent_price.price < bid].append(zone_tuple)
+
+ return min(markets_under_bid or markets_over_bid,
+ key=attrgetter('price_deviation')).name
+
+
+def optimize_spot_bid(ctx, instance_type, spot_bid):
+ """
+    Checks whether the bid is sane and makes an effort to place the instances in a sensible zone.
+ """
+ spot_history = _get_spot_history(ctx, instance_type)
+ _check_spot_bid(spot_bid, spot_history)
+ zones = ctx.ec2.get_all_zones()
+ most_stable_zone = choose_spot_zone(zones, spot_bid, spot_history)
+ logger.info("Placing spot instances in zone %s.", most_stable_zone)
+ return most_stable_zone
+
+
+def _check_spot_bid(spot_bid, spot_history):
+ """
+    Warns users who are potentially over-paying for instances
+
+ Note: this checks over the whole region, not a particular zone
+
+    :param float spot_bid:
+
+ :type spot_history: list[SpotPriceHistory]
+
+    Logs a warning if the bid is more than double the average spot price over the last week.
+
+ >>> from collections import namedtuple
+ >>> FauxHistory = namedtuple( "FauxHistory", [ "price", "availability_zone" ] )
+ >>> spot_data = [ FauxHistory( 0.1, "us-west-2a" ), \
+ FauxHistory( 0.2, "us-west-2a" ), \
+ FauxHistory( 0.3, "us-west-2b" ), \
+ FauxHistory( 0.6, "us-west-2b" ) ]
+ >>> # noinspection PyProtectedMember
+ >>> _check_spot_bid( 0.1, spot_data )
+    >>> # A bid of more than double the average spot price (e.g. 2.0 here) is not rejected;
+    >>> # it only triggers a log warning.
+ """
+ average = mean([datum.price for datum in spot_history])
+ if spot_bid > average * 2:
+ logger.warn("Your bid $ %f is more than double this instance type's average "
+ "spot price ($ %f) over the last week", spot_bid, average)
+
+def _get_spot_history(ctx, instance_type):
+ """
+ Returns list of 1,000 most recent spot market data points represented as SpotPriceHistory
+ objects. Note: The most recent object/data point will be first in the list.
+
+ :rtype: list[SpotPriceHistory]
+ """
+
+ one_week_ago = datetime.datetime.now() - datetime.timedelta(days=7)
+ spot_data = ctx.ec2.get_spot_price_history(start_time=one_week_ago.isoformat(),
+ instance_type=instance_type,
+ product_description="Linux/UNIX")
+ spot_data.sort(key=attrgetter("timestamp"), reverse=True)
+ return spot_data
+
+ec2FullPolicy = dict(Version="2012-10-17", Statement=[
+ dict(Effect="Allow", Resource="*", Action="ec2:*")])
+
+s3FullPolicy = dict(Version="2012-10-17", Statement=[
+ dict(Effect="Allow", Resource="*", Action="s3:*")])
+
+sdbFullPolicy = dict(Version="2012-10-17", Statement=[
+ dict(Effect="Allow", Resource="*", Action="sdb:*")])
+
+iamFullPolicy = dict(Version="2012-10-17", Statement=[
+ dict(Effect="Allow", Resource="*", Action="iam:*")])
+
+
+logDir = '--log_dir=/var/lib/mesos'
+leaderArgs = logDir + ' --registry=in_memory --cluster={name}'
+workerArgs = '{keyPath} --work_dir=/var/lib/mesos --master={ip}:5050 --attributes=preemptable:{preemptable} ' + logDir
+
+awsUserData = """#cloud-config
+
+write_files:
+ - path: "/home/core/volumes.sh"
+ permissions: "0777"
+ owner: "root"
+ content: |
+ #!/bin/bash
+ set -x
+ ephemeral_count=0
+ possible_drives="/dev/xvdb /dev/xvdc /dev/xvdd /dev/xvde"
+ drives=""
+ directories="toil mesos docker"
+ for drive in $possible_drives; do
+ echo checking for $drive
+ if [ -b $drive ]; then
+ echo found it
+ ephemeral_count=$((ephemeral_count + 1 ))
+ drives="$drives $drive"
+ echo increased ephemeral count by one
+ fi
+ done
+ if (("$ephemeral_count" == "0" )); then
+ echo no ephemeral drive
+ for directory in $directories; do
+ sudo mkdir -p /var/lib/$directory
+ done
+ exit 0
+ fi
+ sudo mkdir /mnt/ephemeral
+ if (("$ephemeral_count" == "1" )); then
+ echo one ephemeral drive to mount
+ sudo mkfs.ext4 -F $drives
+ sudo mount $drives /mnt/ephemeral
+ fi
+ if (("$ephemeral_count" > "1" )); then
+ echo multiple drives
+ for drive in $drives; do
+ dd if=/dev/zero of=$drive bs=4096 count=1024
+ done
+ sudo mdadm --create -f --verbose /dev/md0 --level=0 --raid-devices=$ephemeral_count $drives # determine force flag
+ sudo mkfs.ext4 -F /dev/md0
+ sudo mount /dev/md0 /mnt/ephemeral
+ fi
+ for directory in $directories; do
+ sudo mkdir -p /mnt/ephemeral/var/lib/$directory
+ sudo mkdir -p /var/lib/$directory
+ sudo mount --bind /mnt/ephemeral/var/lib/$directory /var/lib/$directory
+ done
+
+coreos:
+ update:
+ reboot-strategy: off
+ units:
+ - name: "volume-mounting.service"
+ command: "start"
+ content: |
+ [Unit]
+ Description=mounts ephemeral volumes & bind mounts toil directories
+ Author=cketchum at ucsc.edu
+ Before=docker.service
+
+ [Service]
+ Type=oneshot
+ Restart=no
+ ExecStart=/usr/bin/bash /home/core/volumes.sh
+
+ - name: "toil-{role}.service"
+ command: "start"
+ content: |
+ [Unit]
+ Description=toil-{role} container
+ Author=cketchum at ucsc.edu
+ After=docker.service
+
+ [Service]
+ Restart=on-failure
+ ExecStart=/usr/bin/docker run \
+ --entrypoint={entrypoint} \
+ --net=host \
+ -v /var/run/docker.sock:/var/run/docker.sock \
+ -v /var/lib/mesos:/var/lib/mesos \
+ -v /var/lib/docker:/var/lib/docker \
+ -v /var/lib/toil:/var/lib/toil \
+ --name=toil_{role} \
+ {image} \
+ {args}
+
+ssh_authorized_keys:
+ - "ssh-rsa {sshKey}"
+"""
diff --git a/src/toil/provisioners/aws/awsProvisioner.py b/src/toil/provisioners/aws/awsProvisioner.py
new file mode 100644
index 0000000..e20dc9f
--- /dev/null
+++ b/src/toil/provisioners/aws/awsProvisioner.py
@@ -0,0 +1,628 @@
+# Copyright (C) 2015 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from StringIO import StringIO
+import pipes
+import socket
+import subprocess
+import logging
+
+import time
+
+import sys
+
+# Python 3 compatibility imports
+from six.moves import xrange
+
+from bd2k.util import memoize
+from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
+from boto.exception import BotoServerError, EC2ResponseError
+from cgcloud.lib.ec2 import (ec2_instance_types, a_short_time, create_ondemand_instances,
+ create_spot_instances, wait_instances_running)
+from itertools import count
+
+from toil import applianceSelf
+from toil.provisioners.abstractProvisioner import AbstractProvisioner, Shape
+from toil.provisioners.aws import *
+from cgcloud.lib.context import Context
+from boto.utils import get_instance_metadata
+from bd2k.util.retry import retry
+from toil.provisioners import awsRemainingBillingInterval, awsFilterImpairedNodes
+
+logger = logging.getLogger(__name__)
+
+
+class AWSProvisioner(AbstractProvisioner):
+
+ def __init__(self, config, batchSystem):
+ super(AWSProvisioner, self).__init__(config, batchSystem)
+ self.instanceMetaData = get_instance_metadata()
+ self.clusterName = self.instanceMetaData['security-groups']
+ self.ctx = self._buildContext(clusterName=self.clusterName)
+ self.spotBid = None
+ assert config.preemptableNodeType or config.nodeType
+ if config.preemptableNodeType is not None:
+ nodeBidTuple = config.preemptableNodeType.split(':', 1)
+ self.spotBid = nodeBidTuple[1]
+ self.instanceType = ec2_instance_types[nodeBidTuple[0]]
+ else:
+ self.instanceType = ec2_instance_types[config.nodeType]
+ self.leaderIP = self.instanceMetaData['local-ipv4']
+ self.keyName = self.instanceMetaData['public-keys'].keys()[0]
+ self.masterPublicKey = self.setSSH()
+
+ def setSSH(self):
+ if not os.path.exists('/root/.sshSuccess'):
+ subprocess.check_call(['ssh-keygen', '-f', '/root/.ssh/id_rsa', '-t', 'rsa', '-N', ''])
+ with open('/root/.sshSuccess', 'w') as f:
+ f.write('written here because of restrictive permissions on .ssh dir')
+ os.chmod('/root/.ssh', 0o700)
+ subprocess.check_call(['bash', '-c', 'eval $(ssh-agent) && ssh-add -k'])
+ with open('/root/.ssh/id_rsa.pub') as f:
+ masterPublicKey = f.read()
+ masterPublicKey = masterPublicKey.split(' ')[1] # take 'body' of key
+ # confirm it really is an RSA public key
+ assert masterPublicKey.startswith('AAAAB3NzaC1yc2E'), masterPublicKey
+ return masterPublicKey
+
+ def getNodeShape(self, preemptable=False):
+ instanceType = self.instanceType
+ return Shape(wallTime=60 * 60,
+ memory=instanceType.memory * 2 ** 30,
+ cores=instanceType.cores,
+ disk=(instanceType.disks * instanceType.disk_capacity * 2 ** 30))
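+
+    # Illustrative sketch, not part of the upstream file: for a hypothetical instance type
+    # with 15 GiB of memory, 4 cores and two 80 GiB ephemeral disks, the returned shape
+    # would be
+    #
+    #     Shape(wallTime=3600, memory=15 * 2 ** 30, cores=4, disk=2 * 80 * 2 ** 30)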
+
+ @classmethod
+ def _buildContext(cls, clusterName, zone=None):
+ if zone is None:
+ zone = getCurrentAWSZone()
+ if zone is None:
+ raise RuntimeError(
+                'Could not determine availability zone. Ensure that one of the following '
+ 'is true: the --zone flag is set, the TOIL_AWS_ZONE environment variable '
+ 'is set, ec2_region_name is set in the .boto file, or that '
+ 'you are running on EC2.')
+ return Context(availability_zone=zone, namespace=cls._toNameSpace(clusterName))
+
+ @classmethod
+ def sshLeader(cls, clusterName, args=None, zone=None, **kwargs):
+ leader = cls._getLeader(clusterName)
+ logger.info('SSH ready')
+ kwargs['tty'] = sys.stdin.isatty()
+ command = args if args else ['bash']
+ cls._sshAppliance(leader.ip_address, *command, **kwargs)
+
+ def _remainingBillingInterval(self, instance):
+ return awsRemainingBillingInterval(instance)
+
+ @classmethod
+ @memoize
+ def _discoverAMI(cls, ctx):
+ def descriptionMatches(ami):
+ return ami.description is not None and 'stable 1068.9.0' in ami.description
+ coreOSAMI = os.environ.get('TOIL_AWS_AMI')
+ if coreOSAMI is not None:
+ return coreOSAMI
+ # that ownerID corresponds to coreOS
+ coreOSAMI = [ami for ami in ctx.ec2.get_all_images(owners=['679593333241']) if
+ descriptionMatches(ami)]
+ assert len(coreOSAMI) == 1
+ return coreOSAMI.pop().id
+
+ @classmethod
+ def dockerInfo(cls):
+ try:
+ return os.environ['TOIL_APPLIANCE_SELF']
+ except KeyError:
+ raise RuntimeError('Please set TOIL_APPLIANCE_SELF environment variable to the '
+ 'image of the Toil Appliance you wish to use. For example: '
+ "'quay.io/ucsc_cgl/toil:3.5.0a1--80c340c5204bde016440e78e84350e3c13bd1801'. "
+ 'See https://quay.io/repository/ucsc_cgl/toil-leader?tab=tags '
+ 'for a full list of available versions.')
+
+ @classmethod
+ def _sshAppliance(cls, leaderIP, *args, **kwargs):
+ """
+ :param str leaderIP: IP of the master
+ :param args: arguments to execute in the appliance
+        :param kwargs: tty=bool tells docker whether or not to create a TTY shell for
+            interactive SSHing. The default value is False. input=string is passed to the
+            underlying process on stdin.
+ """
+ kwargs['appliance'] = True
+ return cls._coreSSH(leaderIP, *args, **kwargs)
+
+
+ @classmethod
+ def _sshInstance(cls, nodeIP, *args, **kwargs):
+ # returns the output from the command
+ kwargs['collectStdout'] = True
+ return cls._coreSSH(nodeIP, *args, **kwargs)
+
+ @classmethod
+ def _coreSSH(cls, nodeIP, *args, **kwargs):
+ """
+ kwargs: input, tty, appliance, collectStdout, sshOptions
+ """
+ commandTokens = ['ssh', '-o', "StrictHostKeyChecking=no", '-t']
+ sshOptions = kwargs.pop('sshOptions', None)
+ if sshOptions:
+ # add specified options to ssh command
+ assert isinstance(sshOptions, list)
+ commandTokens.extend(sshOptions)
+ # specify host
+ commandTokens.append('core@%s' % nodeIP)
+ appliance = kwargs.pop('appliance', None)
+ if appliance:
+ # run the args in the appliance
+ tty = kwargs.pop('tty', None)
+ ttyFlag = '-t' if tty else ''
+ commandTokens += ['docker', 'exec', '-i', ttyFlag, 'toil_leader']
+ inputString = kwargs.pop('input', None)
+ if inputString is not None:
+ kwargs['stdin'] = subprocess.PIPE
+ collectStdout = kwargs.pop('collectStdout', None)
+ if collectStdout:
+ kwargs['stdout'] = subprocess.PIPE
+ logger.debug('Node %s: %s', nodeIP, ' '.join(args))
+ args = map(pipes.quote, args)
+ commandTokens += args
+ logger.debug('Full command %s', ' '.join(commandTokens))
+ popen = subprocess.Popen(commandTokens, **kwargs)
+ stdout, stderr = popen.communicate(input=inputString)
+ # at this point the process has already exited, no need for a timeout
+ resultValue = popen.wait()
+ if resultValue != 0:
+ raise RuntimeError('Executing the command "%s" on the appliance returned a non-zero '
+ 'exit code %s with stdout %s and stderr %s' % (' '.join(args), resultValue, stdout, stderr))
+ assert stderr is None
+ return stdout
+
+ @classmethod
+ def rsyncLeader(cls, clusterName, args):
+ leader = cls._getLeader(clusterName)
+ cls._rsyncNode(leader.ip_address, args)
+
+ @classmethod
+ def _rsyncNode(cls, ip, args, applianceName='toil_leader'):
+ sshCommand = 'ssh -o "StrictHostKeyChecking=no"' # Skip host key checking
+ remoteRsync = "docker exec -i %s rsync" % applianceName # Access rsync inside appliance
+ parsedArgs = []
+ hostInserted = False
+ # Insert remote host address
+ for i in args:
+ if i.startswith(":") and not hostInserted:
+ i = ("core@%s" % ip) + i
+ hostInserted = True
+ elif i.startswith(":") and hostInserted:
+ raise ValueError("Cannot rsync between two remote hosts")
+ parsedArgs.append(i)
+ if not hostInserted:
+ raise ValueError("No remote host found in argument list")
+ command = ['rsync', '-e', sshCommand, '--rsync-path', remoteRsync]
+ logger.debug("Running %r.", command + parsedArgs)
+
+ return subprocess.check_call(command + parsedArgs)
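+
+    # Illustrative sketch, not part of the upstream file: an argument starting with ':' is
+    # rewritten to address the remote appliance, so (with a hypothetical IP)
+    #
+    #     AWSProvisioner._rsyncNode('203.0.113.10', ['key.pem', ':/var/lib/toil/'])
+    #
+    # runs rsync with 'key.pem core@203.0.113.10:/var/lib/toil/' through the leader's
+    # dockerized rsync.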
+
+ @classmethod
+ def _toNameSpace(cls, clusterName):
+ assert isinstance(clusterName, str)
+ if any((char.isupper() for char in clusterName)) or '_' in clusterName:
+ raise RuntimeError("The cluster name must be lowercase and cannot contain the '_' "
+ "character.")
+ namespace = clusterName
+ if not namespace.startswith('/'):
+ namespace = '/'+namespace+'/'
+ return namespace.replace('-','/')
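+
+    # Illustrative sketch, not part of the upstream file: the namespace is derived purely
+    # textually from the cluster name, e.g.
+    #
+    #     AWSProvisioner._toNameSpace('my-cluster')  # -> '/my/cluster/'
+    #     AWSProvisioner._toNameSpace('test')        # -> '/test/'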
+
+ @classmethod
+ def _getLeader(cls, clusterName, wait=False, zone=None):
+ ctx = cls._buildContext(clusterName=clusterName, zone=zone)
+ instances = cls.__getNodesInCluster(ctx, clusterName, both=True)
+ instances.sort(key=lambda x: x.launch_time)
+ leader = instances[0] # assume leader was launched first
+ if wait:
+ logger.info("Waiting for toil_leader to enter 'running' state...")
+ cls._tagWhenRunning(ctx.ec2, [leader], clusterName)
+ logger.info('... toil_leader is running')
+ cls._waitForNode(leader, 'toil_leader')
+ return leader
+
+ @classmethod
+ def _tagWhenRunning(cls, ec2, instances, tag):
+ wait_instances_running(ec2, instances)
+ for instance in instances:
+ instance.add_tag("Name", tag)
+
+ @classmethod
+ def _waitForNode(cls, instance, role):
+ # returns the node's IP
+ cls._waitForIP(instance)
+ instanceIP = instance.ip_address
+ cls._waitForSSHPort(instanceIP)
+ cls._waitForSSHKeys(instanceIP)
+ # wait here so docker commands can be used reliably afterwards
+ cls._waitForDockerDaemon(instanceIP)
+ cls._waitForAppliance(instanceIP, role=role)
+ return instanceIP
+
+ @classmethod
+ def _waitForSSHKeys(cls, instanceIP):
+        # the propagation of public ssh keys vs. opening the SSH port is racy, so this method blocks until
+        # the keys are propagated and the instance can be SSHed into
+ while True:
+ try:
+ logger.info('Attempting to establish SSH connection...')
+ cls._sshInstance(instanceIP, 'ps', sshOptions=['-oBatchMode=yes'])
+ except RuntimeError:
+ logger.info('Connection rejected, waiting for public SSH key to be propagated. Trying again in 10s.')
+ time.sleep(10)
+ else:
+ logger.info('...SSH connection established.')
+ # ssh succeeded
+ return
+
+
+ @classmethod
+ def _waitForDockerDaemon(cls, ip_address):
+ logger.info('Waiting for docker on %s to start...', ip_address)
+ while True:
+ output = cls._sshInstance(ip_address, '/usr/bin/ps', 'aux')
+ time.sleep(5)
+ if 'docker daemon' in output:
+ # docker daemon has started
+ break
+ else:
+ logger.info('... Still waiting...')
+ logger.info('Docker daemon running')
+
+ @classmethod
+ def _waitForAppliance(cls, ip_address, role):
+ logger.info('Waiting for %s Toil appliance to start...', role)
+ while True:
+ output = cls._sshInstance(ip_address, '/usr/bin/docker', 'ps')
+ if role in output:
+ logger.info('...Toil appliance started')
+ break
+ else:
+ logger.info('...Still waiting, trying again in 10sec...')
+ time.sleep(10)
+
+ @classmethod
+ def _waitForIP(cls, instance):
+ """
+        Wait until the instance has a public IP address assigned to it.
+
+ :type instance: boto.ec2.instance.Instance
+ """
+ logger.info('Waiting for ip...')
+ while True:
+ time.sleep(a_short_time)
+ instance.update()
+ if instance.ip_address or instance.public_dns_name:
+ logger.info('...got ip')
+ break
+
+ @classmethod
+ def _waitForSSHPort(cls, ip_address):
+ """
+ Wait until the instance represented by this box is accessible via SSH.
+
+        :return: the number of unsuccessful attempts to connect to the port before the first
+ success
+ """
+ logger.info('Waiting for ssh port to open...')
+ for i in count():
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ s.settimeout(a_short_time)
+ s.connect((ip_address, 22))
+ logger.info('...ssh port open')
+ return i
+ except socket.error:
+ pass
+ finally:
+ s.close()
+
+ @classmethod
+ def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None, zone=None):
+ ctx = cls._buildContext(clusterName=clusterName, zone=zone)
+ profileARN = cls._getProfileARN(ctx)
+ # the security group name is used as the cluster identifier
+ cls._createSecurityGroup(ctx, clusterName)
+ bdm = cls._getBlockDeviceMapping(ec2_instance_types[instanceType])
+ leaderData = dict(role='leader',
+ image=applianceSelf(),
+ entrypoint='mesos-master',
+ sshKey='AAAAB3NzaC1yc2Enoauthorizedkeyneeded',
+ args=leaderArgs.format(name=clusterName))
+ userData = awsUserData.format(**leaderData)
+ kwargs = {'key_name': keyName, 'security_groups': [clusterName],
+ 'instance_type': instanceType,
+ 'user_data': userData, 'block_device_map': bdm,
+ 'instance_profile_arn': profileARN}
+ if not spotBid:
+ logger.info('Launching non-preemptable leader')
+ create_ondemand_instances(ctx.ec2, image_id=cls._discoverAMI(ctx),
+ spec=kwargs, num_instances=1)
+ else:
+ logger.info('Launching preemptable leader')
+ # force generator to evaluate
+ list(create_spot_instances(ec2=ctx.ec2,
+ price=spotBid,
+ image_id=cls._discoverAMI(ctx),
+ tags={'clusterName': clusterName},
+ spec=kwargs,
+ num_instances=1))
+ return cls._getLeader(clusterName=clusterName, wait=True)
+
+ @classmethod
+ def destroyCluster(cls, clusterName, zone=None):
+ def expectedShutdownErrors(e):
+ return e.status == 400 and 'dependent object' in e.body
+
+ ctx = cls._buildContext(clusterName=clusterName, zone=zone)
+ instances = cls.__getNodesInCluster(ctx, clusterName, both=True)
+ spotIDs = cls._getSpotRequestIDs(ctx, clusterName)
+ if spotIDs:
+ ctx.ec2.cancel_spot_instance_requests(request_ids=spotIDs)
+ instancesToTerminate = awsFilterImpairedNodes(instances, ctx.ec2)
+ if instancesToTerminate:
+ cls._deleteIAMProfiles(instances=instancesToTerminate, ctx=ctx)
+ cls._terminateInstances(instances=instancesToTerminate, ctx=ctx)
+ if len(instances) == len(instancesToTerminate):
+ logger.info('Deleting security group...')
+ for attempt in retry(timeout=300, predicate=expectedShutdownErrors):
+ with attempt:
+ try:
+ ctx.ec2.delete_security_group(name=clusterName)
+ except BotoServerError as e:
+ if e.error_code == 'InvalidGroup.NotFound':
+ pass
+ else:
+ raise
+            logger.info('... Successfully deleted security group')
+ else:
+ assert len(instances) > len(instancesToTerminate)
+ # the security group can't be deleted until all nodes are terminated
+ logger.warning('The TOIL_AWS_NODE_DEBUG environment variable is set and some nodes '
+ 'have failed health checks. As a result, the security group & IAM '
+ 'roles will not be deleted.')
+
+ @classmethod
+ def _terminateInstances(cls, instances, ctx):
+ instanceIDs = [x.id for x in instances]
+ cls._terminateIDs(instanceIDs, ctx)
+
+ @classmethod
+ def _terminateIDs(cls, instanceIDs, ctx):
+ logger.info('Terminating instance(s): %s', instanceIDs)
+ ctx.ec2.terminate_instances(instance_ids=instanceIDs)
+ logger.info('Instance(s) terminated.')
+
+ def _logAndTerminate(self, instanceIDs):
+ self._terminateIDs(instanceIDs, self.ctx)
+
+ @classmethod
+ def _deleteIAMProfiles(cls, instances, ctx):
+ instanceProfiles = [x.instance_profile['arn'] for x in instances]
+ for profile in instanceProfiles:
+ # boto won't look things up by the ARN so we have to parse it to get
+ # the profile name
+ profileName = profile.rsplit('/')[-1]
+ try:
+ profileResult = ctx.iam.get_instance_profile(profileName)
+ except BotoServerError as e:
+ if e.status == 404:
+ return
+ else:
+ raise
+ # wade through EC2 response object to get what we want
+ profileResult = profileResult['get_instance_profile_response']
+ profileResult = profileResult['get_instance_profile_result']
+ profile = profileResult['instance_profile']
+ # this is based off of our 1:1 mapping of profiles to roles
+ role = profile['roles']['member']['role_name']
+ try:
+ ctx.iam.remove_role_from_instance_profile(profileName, role)
+ except BotoServerError as e:
+ if e.status == 404:
+ pass
+ else:
+ raise
+ policyResults = ctx.iam.list_role_policies(role)
+ policyResults = policyResults['list_role_policies_response']
+ policyResults = policyResults['list_role_policies_result']
+ policies = policyResults['policy_names']
+ for policyName in policies:
+ try:
+ ctx.iam.delete_role_policy(role, policyName)
+ except BotoServerError as e:
+ if e.status == 404:
+ pass
+ else:
+ raise
+ try:
+ ctx.iam.delete_role(role)
+ except BotoServerError as e:
+ if e.status == 404:
+ pass
+ else:
+ raise
+ try:
+ ctx.iam.delete_instance_profile(profileName)
+ except BotoServerError as e:
+ if e.status == 404:
+ pass
+ else:
+ raise
+
+ def _addNodes(self, instances, numNodes, preemptable=False):
+ bdm = self._getBlockDeviceMapping(self.instanceType)
+ arn = self._getProfileARN(self.ctx)
+ keyPath = '' if not self.config.sseKey else self.config.sseKey
+ entryPoint = 'mesos-slave' if not self.config.sseKey else "waitForKey.sh"
+ workerData = dict(role='worker',
+ image=applianceSelf(),
+ entrypoint=entryPoint,
+ sshKey=self.masterPublicKey,
+ args=workerArgs.format(ip=self.leaderIP, preemptable=preemptable, keyPath=keyPath))
+ userData = awsUserData.format(**workerData)
+ kwargs = {'key_name': self.keyName,
+ 'security_groups': [self.clusterName],
+ 'instance_type': self.instanceType.name,
+ 'user_data': userData,
+ 'block_device_map': bdm,
+ 'instance_profile_arn': arn}
+
+ instancesLaunched = []
+
+ if not preemptable:
+ logger.info('Launching %s non-preemptable nodes', numNodes)
+            instancesLaunched = create_ondemand_instances(
+                self.ctx.ec2, image_id=self._discoverAMI(self.ctx),
+                spec=kwargs, num_instances=numNodes)
+ else:
+ logger.info('Launching %s preemptable nodes', numNodes)
+ kwargs['placement'] = getSpotZone(self.spotBid, self.instanceType.name, self.ctx)
+ # force generator to evaluate
+ instancesLaunched = list(create_spot_instances(ec2=self.ctx.ec2,
+ price=self.spotBid,
+ image_id=self._discoverAMI(self.ctx),
+ tags={'clusterName': self.clusterName},
+ spec=kwargs,
+ num_instances=numNodes,
+ tentative=True)
+ )
+ # flatten the list
+ instancesLaunched = [item for sublist in instancesLaunched for item in sublist]
+ self._tagWhenRunning(self.ctx.ec2, instancesLaunched, self.clusterName)
+ self._propagateKey(instancesLaunched)
+ logger.info('Launched %s new instance(s)', numNodes)
+ return len(instancesLaunched)
+
+ def _propagateKey(self, instances):
+ if not self.config.sseKey:
+ return
+ for node in instances:
+ # since we're going to be rsyncing into the appliance we need the appliance to be running first
+ ipAddress = self._waitForNode(node, 'toil_worker')
+ self._rsyncNode(ipAddress, [self.config.sseKey, ':' + self.config.sseKey], applianceName='toil_worker')
+
+ @classmethod
+ def _getBlockDeviceMapping(cls, instanceType):
+ # determine number of ephemeral drives via cgcloud-lib
+ bdtKeys = ['', '/dev/xvdb', '/dev/xvdc', '/dev/xvdd']
+ bdm = BlockDeviceMapping()
+ # the first disk is already attached for us so start with 2nd.
+ for disk in xrange(1, instanceType.disks + 1):
+ bdm[bdtKeys[disk]] = BlockDeviceType(
+ ephemeral_name='ephemeral{}'.format(disk - 1)) # ephemeral counts start at 0
+
+ logger.debug('Device mapping: %s', bdm)
+ return bdm
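+
+    # Illustrative sketch, not part of the upstream file: for an instance type with two
+    # ephemeral disks the resulting mapping associates
+    #
+    #     '/dev/xvdb' -> ephemeral0
+    #     '/dev/xvdc' -> ephemeral1
+    #
+    # while the first, already attached device is skipped.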
+
+ @classmethod
+ def __getNodesInCluster(cls, ctx, clusterName, preemptable=False, both=False):
+ pendingInstances = ctx.ec2.get_only_instances(filters={'instance.group-name': clusterName,
+ 'instance-state-name': 'pending'})
+ runningInstances = ctx.ec2.get_only_instances(filters={'instance.group-name': clusterName,
+ 'instance-state-name': 'running'})
+ instances = set(pendingInstances)
+ if not preemptable and not both:
+ return [x for x in instances.union(set(runningInstances)) if x.spot_instance_request_id is None]
+ elif preemptable and not both:
+ return [x for x in instances.union(set(runningInstances)) if x.spot_instance_request_id is not None]
+ elif both:
+ return [x for x in instances.union(set(runningInstances))]
+
+    def _getNodesInCluster(self, preemptable=False, both=False):
+        if not both:
+            return self.__getNodesInCluster(self.ctx, self.clusterName, preemptable=preemptable)
+ else:
+ return self.__getNodesInCluster(self.ctx, self.clusterName, both=both)
+
+ def _getWorkersInCluster(self, preemptable):
+ entireCluster = self._getNodesInCluster(both=True)
+ logger.debug('All nodes in cluster %s', entireCluster)
+ workerInstances = [i for i in entireCluster if i.private_ip_address != self.leaderIP and
+ preemptable != (i.spot_instance_request_id is None)]
+ logger.debug('Workers found in cluster %s', workerInstances)
+ workerInstances = awsFilterImpairedNodes(workerInstances, self.ctx.ec2)
+ return workerInstances
+
+ @classmethod
+ def _getSpotRequestIDs(cls, ctx, clusterName):
+ requests = ctx.ec2.get_all_spot_instance_requests()
+ tags = ctx.ec2.get_all_tags({'tag:': {'clusterName': clusterName}})
+ idsToCancel = [tag.id for tag in tags]
+ return [request for request in requests if request.id in idsToCancel]
+
+ @classmethod
+ def _createSecurityGroup(cls, ctx, name):
+ def groupNotFound(e):
+ retry = (e.status == 400 and 'does not exist in default VPC' in e.body)
+ return retry
+
+ # security group create/get. ssh + all ports open within the group
+ try:
+ web = ctx.ec2.create_security_group(name, 'Toil appliance security group')
+ except EC2ResponseError as e:
+ if e.status == 400 and 'already exists' in e.body:
+ pass # group exists- nothing to do
+ else:
+ raise
+ else:
+ for attempt in retry(predicate=groupNotFound, timeout=300):
+ with attempt:
+ # open port 22 for ssh-ing
+ web.authorize(ip_protocol='tcp', from_port=22, to_port=22, cidr_ip='0.0.0.0/0')
+ for attempt in retry(predicate=groupNotFound, timeout=300):
+ with attempt:
+ # the following authorizes all port access within the web security group
+ web.authorize(ip_protocol='tcp', from_port=0, to_port=65535, src_group=web)
+
+ @classmethod
+ def _getProfileARN(cls, ctx):
+ def addRoleErrors(e):
+ return e.status == 404
+ roleName = '-toil'
+ policy = dict(iam_full=iamFullPolicy, ec2_full=ec2FullPolicy,
+ s3_full=s3FullPolicy, sbd_full=sdbFullPolicy)
+ iamRoleName = ctx.setup_iam_ec2_role(role_name=roleName, policies=policy)
+
+ try:
+ profile = ctx.iam.get_instance_profile(iamRoleName)
+ except BotoServerError as e:
+ if e.status == 404:
+ profile = ctx.iam.create_instance_profile(iamRoleName)
+ profile = profile.create_instance_profile_response.create_instance_profile_result
+ else:
+ raise
+ else:
+ profile = profile.get_instance_profile_response.get_instance_profile_result
+ profile = profile.instance_profile
+ profile_arn = profile.arn
+
+ if len(profile.roles) > 1:
+ raise RuntimeError('Did not expect profile to contain more than one role')
+ elif len(profile.roles) == 1:
+ # this should be profile.roles[0].role_name
+ if profile.roles.member.role_name == iamRoleName:
+ return profile_arn
+ else:
+ ctx.iam.remove_role_from_instance_profile(iamRoleName,
+ profile.roles.member.role_name)
+ for attempt in retry(predicate=addRoleErrors):
+ with attempt:
+ ctx.iam.add_role_to_instance_profile(iamRoleName, iamRoleName)
+ return profile_arn
diff --git a/src/toil/provisioners/cgcloud/__init__.py b/src/toil/provisioners/cgcloud/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/toil/provisioners/cgcloud/provisioner.py b/src/toil/provisioners/cgcloud/provisioner.py
new file mode 100644
index 0000000..d1ed24a
--- /dev/null
+++ b/src/toil/provisioners/cgcloud/provisioner.py
@@ -0,0 +1,338 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import re
+import time
+from collections import Iterable
+
+# Python 3 compatibility imports
+from six import iterkeys, itervalues
+from six.moves.urllib.request import urlopen
+
+import boto.ec2
+from bd2k.util import memoize, parse_iso_utc, less_strict_bool
+from bd2k.util.exceptions import require
+from bd2k.util.throttle import throttle
+from boto.ec2.instance import Instance
+from cgcloud.lib.ec2 import (ec2_instance_types,
+ create_spot_instances,
+ create_ondemand_instances,
+ tag_object_persistently)
+from cgcloud.lib.util import (allocate_cluster_ordinals,
+ thread_pool)
+
+from toil.batchSystems.abstractBatchSystem import (AbstractScalableBatchSystem,
+ AbstractBatchSystem)
+from toil.common import Config
+from toil.provisioners import awsRemainingBillingInterval, awsFilterImpairedNodes
+from toil.provisioners.abstractProvisioner import (AbstractProvisioner,
+ Shape)
+
+log = logging.getLogger(__name__)
+
+# The maximum time to allow for instance creation and nodes joining the cluster. Note that it may
+# take twice that time to provision preemptable instances since the timeout is reset after spot
+# instance creation.
+#
+provisioning_timeout = 10 * 60
+
+
+class CGCloudProvisioner(AbstractProvisioner):
+ """
+ A provisioner that uses CGCloud's toil-box role to boot up worker nodes in EC2. It uses the
+ spot market to provision preemptable instances, but defaults to on-demand instances.
+
+ This provisioner assumes that
+
+ * It is running the leader node, i.e. the EC2 instance that's running the Mesos master
+ process and the Toil leader process.
+
+    * The leader node was created from a toil-box image using CGCloud.
+
+    * The version of cgcloud used to create the toil-box image is compatible with the one that
+      this provisioner, and therefore Toil, depends on.
+
+ * The SSH keypair, security group and user data applied to the leader also apply to the workers
+
+ * An instance type with ephemeral volumes is being used (it asserts that assumption)
+ """
+
+ def __init__(self, config, batchSystem):
+ """
+ :type config: Config
+ :type batchSystem: AbstractBatchSystem
+ """
+ super(CGCloudProvisioner, self).__init__(config, batchSystem)
+ self.batchSystem = batchSystem
+ self.imageId = self._instance.image_id
+ require(config.nodeType, 'Must pass --nodeType when using the cgcloud provisioner')
+ instanceType = self._resolveInstanceType(config.nodeType)
+ self._requireEphemeralDrives(instanceType)
+ if config.preemptableNodeType:
+ try:
+ preemptableInstanceType, spotBid = config.preemptableNodeType.split(':')
+ except ValueError:
+ raise ValueError("Preemptible node type '%s' is not valid for this provisioner. "
+ "Use format INSTANCE_TYPE:SPOT_BID, e.g. m3.large:0.10 instead"
+ % config.preemptableNodeType)
+ preemptableInstanceType = self._resolveInstanceType(preemptableInstanceType)
+ self._requireEphemeralDrives(preemptableInstanceType)
+ try:
+ self.spotBid = float(spotBid)
+ except ValueError:
+ raise ValueError("The spot bid '%s' is not valid. Use a floating point dollar "
+ "amount such as '0.42' instead." % spotBid)
+ else:
+ preemptableInstanceType, self.spotBid = None, None
+ self.instanceType = {False: instanceType, True: preemptableInstanceType}
+
+ def _requireEphemeralDrives(self, workerType):
+ require(workerType.disks > 0,
+ "This provisioner only supports instance types with one or more ephemeral "
+ "volumes. The requested type '%s' does not have any.", workerType.name)
+ leaderType = self._resolveInstanceType(self._instance.instance_type)
+ require(workerType.disks == leaderType.disks,
+ 'The instance type selected for worker nodes (%s) offers %i ephemeral volumes but '
+ 'this type of leader (%s) has %i. The number of drives must match between leader '
+ 'and worker nodes. Please specify a different worker node type or use a different '
+ 'leader.', workerType.name, workerType.disks, leaderType.name, leaderType.disks)
+
+ def _resolveInstanceType(self, instanceType):
+ """
+ :param str instanceType: the instance type as a string, e.g. 'm3.large'
+ :rtype: cgcloud.lib.ec2.InstanceType
+ """
+ try:
+ return ec2_instance_types[instanceType]
+ except KeyError:
+ raise RuntimeError("Invalid or unknown instance type '%s'" % instanceType)
+
+ def _getWorkersInCluster(self, preemptable):
+ instances = list(self._getAllRunningInstances())
+ workerInstances = [i for i in instances
+ if i.id != self._instanceId # exclude leader
+ and preemptable != (i.spot_instance_request_id is None)]
+ instancesToTerminate = awsFilterImpairedNodes(workerInstances, self._ec2)
+ return instancesToTerminate
+
+ @classmethod
+ def launchCluster(cls, instanceType, keyName, clusterName, spotBid=None):
+ raise NotImplementedError
+
+ @classmethod
+ def sshLeader(cls, clusterName, args):
+ raise NotImplementedError
+
+ @classmethod
+ def destroyCluster(cls, clusterName):
+ raise NotImplementedError
+
+ def _remainingBillingInterval(self, instance):
+ return awsRemainingBillingInterval(instance)
+
+ def _addNodes(self, instances, numNodes, preemptable=False):
+ deadline = time.time() + provisioning_timeout
+ spec = dict(key_name=self._keyName,
+ user_data=self._userData(),
+ instance_type=self.instanceType[preemptable].name,
+ instance_profile_arn=self._instanceProfileArn,
+ security_group_ids=self._securityGroupIds,
+ ebs_optimized=self.ebsOptimized,
+ dry_run=False)
+        # Offset the ordinals of the preemptable nodes to be disjoint from the non-preemptable
+ # ones. Without this, the two scaler threads would inevitably allocate colliding ordinals.
+ offset = 1000 if preemptable else 0
+ used_ordinals = {int(i.tags['cluster_ordinal']) - offset for i in instances}
+ # Since leader is absent from the instances iterable, we need to explicitly reserve its
+ # ordinal unless we're allocating offset ordinals reserved for preemptable instances:
+ assert len(used_ordinals) == len(instances) # check for collisions
+ if not preemptable:
+ used_ordinals.add(0)
+ ordinals = (ordinal + offset for ordinal in allocate_cluster_ordinals(num=numNodes,
+ used=used_ordinals))
+
+ def createInstances():
+ """
+ :rtype: Iterable[list[Instance]]
+ """
+ if preemptable:
+ for batch in create_spot_instances(self._ec2, self.spotBid, self.imageId, spec,
+ # Don't insist on spot requests and don't raise
+ # if no requests were fulfilled:
+ tentative=True,
+ num_instances=numNodes,
+ timeout=deadline - time.time()):
+ yield batch
+ else:
+ yield create_ondemand_instances(self._ec2, self.imageId, spec,
+ num_instances=numNodes)
+
+ instancesByAddress = {}
+
+ def handleInstance(instance):
+ log.debug('Tagging instance %s.', instance.id)
+ leader_tags = self._instance.tags
+ name = leader_tags['Name'].replace('toil-leader', 'toil-worker')
+ tag_object_persistently(instance, dict(leader_tags,
+ Name=name,
+ cluster_ordinal=next(ordinals)))
+ assert instance.private_ip_address
+ instancesByAddress[instance.private_ip_address] = instance
+
+ # Each instance gets a different ordinal so we can't tag an entire batch at once but have
+ # to tag each instance individually. It needs to be done quickly because the tags are
+ # crucial for the boot code running inside the instance to join the cluster. Hence we do
+        # it in a thread pool. If the pool is too large, we'll hit the EC2 limit on the number
+ # of concurrent requests. If it is too small, we won't be able to tag all instances in
+ # time.
+ with thread_pool(min(numNodes, 32)) as pool:
+ for batch in createInstances():
+ log.debug('Got a batch of %i instance(s).', len(batch))
+ for instance in batch:
+ log.debug('Submitting instance %s to thread pool for tagging.', instance.id)
+ pool.apply_async(handleInstance, (instance,))
+ numInstancesAdded = len(instancesByAddress)
+ log.info('Created and tagged %i instance(s).', numInstancesAdded)
+
+ if preemptable:
+ # Reset deadline such that slow spot creation does not take away from instance boot-up
+ deadline = time.time() + provisioning_timeout
+ if isinstance(self.batchSystem, AbstractScalableBatchSystem):
+ while instancesByAddress and time.time() < deadline:
+ with throttle(10):
+ log.debug('Waiting for batch system to report back %i node(s).',
+ len(instancesByAddress))
+ # Get all nodes to be safe, not just the ones whose preemptability matches,
+ # in case there's a problem with a node determining its own preemptability.
+ nodes = self.batchSystem.getNodes()
+ for nodeAddress in iterkeys(nodes):
+ instancesByAddress.pop(nodeAddress, None)
+ if instancesByAddress:
+ log.warn('%i instance(s) out of %i did not join the cluster as worker nodes. They '
+ 'will be terminated.', len(instancesByAddress), numInstancesAdded)
+ instanceIds = [i.id for i in itervalues(instancesByAddress)]
+ self._logAndTerminate(instanceIds)
+ numInstancesAdded -= len(instanceIds)
+ else:
+ log.info('All %i node(s) joined the cluster.', numInstancesAdded)
+ else:
+ log.warn('Batch system is not scalable. Assuming all instances joined the cluster.')
+ return numInstancesAdded
+
+ def _logAndTerminate(self, instanceIds):
+ log.debug('IDs of terminated instances: %r', instanceIds)
+ self._ec2.terminate_instances(instance_ids=instanceIds)
+
+ def getNodeShape(self, preemptable=False):
+ instanceType = self.instanceType[preemptable]
+ return Shape(wallTime=60 * 60,
+ memory=instanceType.memory * 2 ** 30,
+ cores=instanceType.cores,
+ disk=(instanceType.disks * instanceType.disk_capacity * 2 ** 30))
+
+ def _getAllRunningInstances(self):
+ """
+ ... including the leader.
+
+ :rtype: Iterable[Instance]
+ """
+ return self._ec2.get_only_instances(filters={
+ 'tag:leader_instance_id': self._instanceId,
+ 'instance-state-name': 'running'})
+
+ @classmethod
+ def _instanceData(cls, path):
+ return urlopen('http://169.254.169.254/latest/' + path).read()
+
+ @classmethod
+ def _metaData(cls, path):
+ return cls._instanceData('meta-data/' + path)
+
+ @classmethod
+ def _userData(cls):
+ user_data = cls._instanceData('user-data')
+ log.info("User data is '%s'", user_data)
+ return user_data
+
+ @property
+ @memoize
+ def _nodeIP(self):
+ ip = self._metaData('local-ipv4')
+ log.info("Local IP is '%s'", ip)
+ return ip
+
+ @property
+ @memoize
+ def _instanceId(self):
+ instance_id = self._metaData('instance-id')
+ log.info("Instance ID is '%s'", instance_id)
+ return instance_id
+
+ @property
+ @memoize
+ def _availabilityZone(self):
+ zone = self._metaData('placement/availability-zone')
+ log.info("Availability zone is '%s'", zone)
+ return zone
+
+ @property
+ @memoize
+ def _region(self):
+ m = re.match(r'^([a-z]{2}-[a-z]+-[1-9][0-9]*)([a-z])$', self._availabilityZone)
+ assert m
+ region = m.group(1)
+ log.info("Region is '%s'", region)
+ return region
+
+ @property
+ @memoize
+ def _ec2(self):
+ return boto.ec2.connect_to_region(self._region)
+
+ @property
+ @memoize
+ def _keyName(self):
+ return self._instance.key_name
+
+ @property
+ @memoize
+ def _instance(self):
+ return self._getInstance(self._instanceId)
+
+ @property
+ @memoize
+ def _securityGroupIds(self):
+ return [sg.id for sg in self._instance.groups]
+
+ @property
+ @memoize
+ def _instanceProfileArn(self):
+ return self._instance.instance_profile['arn']
+
+ def _getInstance(self, instance_id):
+ """
+ :rtype: Instance
+ """
+ reservations = self._ec2.get_all_reservations(instance_ids=[instance_id])
+ instances = (i for r in reservations for i in r.instances if i.id == instance_id)
+ instance = next(instances)
+ assert next(instances, None) is None
+ return instance
+
+ @property
+ @memoize
+ def ebsOptimized(self):
+ return self._instance.ebs_optimized
diff --git a/src/toil/provisioners/clusterScaler.py b/src/toil/provisioners/clusterScaler.py
new file mode 100644
index 0000000..4facdbb
--- /dev/null
+++ b/src/toil/provisioners/clusterScaler.py
@@ -0,0 +1,423 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import logging
+from collections import deque
+from threading import Lock
+
+from bd2k.util.exceptions import require
+from bd2k.util.threading import ExceptionalThread
+from bd2k.util.throttle import throttle
+
+from toil.batchSystems.abstractBatchSystem import AbstractScalableBatchSystem
+from toil.common import Config
+from toil.provisioners.abstractProvisioner import AbstractProvisioner, Shape
+
+logger = logging.getLogger(__name__)
+
+# A *deficit* exists when we have more jobs that can run on preemptable nodes than we have
+# preemptable nodes. In order to not block these jobs, we want to increase the number of
+# non-preemptable nodes beyond what we need for just the non-preemptable jobs. However, we may
+# still prefer waiting for preemptable instances to become available.
+#
+# To accommodate this, we set the deficit to the difference between the number of preemptable
+# nodes that were requested and the number that were actually provisioned. When the
+# non-preemptable thread wants to provision nodes, it multiplies this deficit by a configured
+# preference for preemptable vs. non-preemptable nodes.
+
+_preemptableNodeDeficit = 0
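+
+# Illustrative sketch, not part of the upstream module: with a deficit of, say, 5 preemptable
+# nodes and a preemptableCompensation of 0.6 in the config, the non-preemptable scaler would
+# add
+#
+#     compensationNodes = int(round(5 * 0.6))  # -> 3
+#
+# extra non-preemptable nodes at its next scaling decision.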
+
+class RecentJobShapes(object):
+ """
+ Used to track the 'shapes' of the last N jobs run (see Shape).
+ """
+
+ def __init__(self, config, nodeShape, N=1000):
+        # As a prior we start off with 10 jobs, each with the default memory, cores, and disk. To
+        # estimate the running time we use the default wall time of each node allocation,
+        # so that one job will fill the time per node.
+ self.jobShapes = deque(maxlen=N,
+ iterable=10 * [Shape(wallTime=nodeShape.wallTime,
+ memory=config.defaultMemory,
+ cores=config.defaultCores,
+ disk=config.defaultDisk)])
+ # Calls to add and getLastNJobShapes may be concurrent
+ self.lock = Lock()
+ # Number of jobs to average over
+ self.N = N
+
+ def add(self, jobShape):
+ """
+ Adds a job shape as the last completed job.
+ :param Shape jobShape: The memory, core and disk requirements of the completed job
+ """
+ with self.lock:
+ self.jobShapes.append(jobShape)
+
+ def get(self):
+ """
+ Gets the last N job shapes added.
+ """
+ with self.lock:
+ return list(self.jobShapes)
+
+
+def binPacking(jobShapes, nodeShape):
+ """
+    Use a first-fit decreasing (FFD) bin-packing-like algorithm to calculate an approximate
+    minimum number of nodes that will fit the given list of jobs.
+
+    Let a *node reservation* be an interval of time that a node is reserved for; it is defined
+    by an integer number of node allocations. For a node reservation, its *jobs* are the set
+    of jobs that will be run within the node reservation. A minimal node reservation has time
+    equal to one atomic node allocation, or the minimum number of node allocations needed to
+    run the longest running job among its jobs.
+
+    :param Shape nodeShape: The properties of an atomic node allocation, in terms of wall-time,
+    memory, cores and local disk.
+
+    :param list[Shape] jobShapes: A list of shapes, each representing a job.
+
+    :rtype: int
+    :returns: The minimum number of minimal node allocations estimated to be required to run all
+    the jobs in jobShapes.
+ """
+ logger.debug('Running bin packing for node shape %s and %s job(s).', nodeShape, len(jobShapes))
+    # Sort in descending order from largest to smallest. The FFD-like strategy will pack the
+    # jobs in order from longest to shortest.
+ jobShapes.sort()
+ jobShapes.reverse()
+ assert len(jobShapes) == 0 or jobShapes[0] >= jobShapes[-1]
+
+ class NodeReservation(object):
+ """
+        Represents a node reservation. To track the resources available over time, a node
+        reservation is represented as a linked sequence of Shapes, each giving the resources
+        free within a given interval of time.
+ """
+
+ def __init__(self, shape):
+ # The wall-time and resource available
+ self.shape = shape
+ # The next portion of the reservation
+ self.nReservation = None
+
+ nodeReservations = [] # The list of node reservations
+
+ for jS in jobShapes:
+ def addToReservation():
+ """
+            Adds the job jS to the first node reservation in which it will fit (this
+            is the bin-packing aspect).
+ """
+
+ def fits(x, y):
+ """
+ Check if a job shape's resource requirements will fit within a given node allocation
+ """
+ return y.memory <= x.memory and y.cores <= x.cores and y.disk <= x.disk
+
+ def subtract(x, y):
+ """
+ Adjust available resources of a node allocation as a job is scheduled within it.
+ """
+ return Shape(x.wallTime, x.memory - y.memory, x.cores - y.cores, x.disk - y.disk)
+
+ def split(x, y, t):
+ """
+ Partition a node allocation into two
+ """
+ return (Shape(t, x.memory - y.memory, x.cores - y.cores, x.disk - y.disk),
+ NodeReservation(Shape(x.wallTime - t, x.memory, x.cores, x.disk)))
+
+ i = 0 # Index of node reservation
+ while True:
+ # Case a new node reservation is required
+ if i == len(nodeReservations):
+ x = NodeReservation(subtract(nodeShape, jS))
+ nodeReservations.append(x)
+ t = nodeShape.wallTime
+ while t < jS.wallTime:
+ y = NodeReservation(x.shape)
+ t += nodeShape.wallTime
+ x.nReservation = y
+ x = y
+ return
+
+ # Attempt to add the job to node reservation i
+ x = nodeReservations[i]
+ y = x
+ t = 0
+
+ while True:
+ if fits(y.shape, jS):
+ t += y.shape.wallTime
+
+ # If the jS fits in the node allocation from x to y
+ if t >= jS.wallTime:
+ t = 0
+ while x != y:
+ x.shape = subtract(x.shape, jS)
+ t += x.shape.wallTime
+ x = x.nReservation
+ assert x == y
+ assert jS.wallTime - t <= x.shape.wallTime
+ if jS.wallTime - t < x.shape.wallTime:
+ x.shape, nS = split(x.shape, jS, jS.wallTime - t)
+ nS.nReservation = x.nReservation
+ x.nReservation = nS
+ else:
+ assert jS.wallTime - t == x.shape.wallTime
+ x.shape = subtract(x.shape, jS)
+ return
+
+ # If the job would fit, but is longer than the total node allocation
+ # extend the node allocation
+ elif y.nReservation == None and x == nodeReservations[i]:
+ # Extend the node reservation to accommodate jS
+ y.nReservation = NodeReservation(nodeShape)
+
+ else: # Does not fit, reset
+ x = y.nReservation
+ t = 0
+
+ y = y.nReservation
+ if y is None:
+ # Reached the end of the reservation without success so stop trying to
+ # add to reservation i
+ break
+ i += 1
+
+ addToReservation()
+ logger.debug("Done running bin packing for node shape %s and %s job(s) resulting in %s node "
+ "reservations.", nodeShape, len(jobShapes), len(nodeReservations))
+ return len(nodeReservations)
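+
+# Illustrative sketch, not part of the upstream module: using the Shape class imported above,
+# ten single-core jobs whose wall-time is half a node allocation pack into a few four-core node
+# reservations. The shapes below are examples only.
+#
+#     small = Shape(wallTime=1800, memory=2 * 2 ** 30, cores=1, disk=10 * 2 ** 30)
+#     node = Shape(wallTime=3600, memory=8 * 2 ** 30, cores=4, disk=100 * 2 ** 30)
+#     binPacking([small] * 10, node)  # -> a small number of node reservations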
+
+
+class ClusterScaler(object):
+ def __init__(self, provisioner, leader, config):
+ """
+        Manages automatically scaling the number of worker nodes.
+ :param AbstractProvisioner provisioner: Provisioner instance to scale.
+ :param toil.leader.Leader leader:
+ :param Config config: Config object from which to draw parameters.
+ """
+ self.provisioner = provisioner
+ self.leader = leader
+ self.config = config
+ # Indicates that the scaling threads should shutdown
+ self.stop = False
+
+ assert config.maxPreemptableNodes >= 0 and config.maxNodes >= 0
+ require(config.maxPreemptableNodes + config.maxNodes > 0,
+ 'Either --maxNodes or --maxPreemptableNodes must be non-zero.')
+
+ self.preemptableScaler = ScalerThread(self, preemptable=True) if self.config.maxPreemptableNodes > 0 else None
+
+ self.scaler = ScalerThread(self, preemptable=False) if self.config.maxNodes > 0 else None
+
+ def start(self):
+ """
+ Start the cluster scaler thread(s).
+ """
+        if self.preemptableScaler is not None:
+ self.preemptableScaler.start()
+
+        if self.scaler is not None:
+ self.scaler.start()
+
+ def check(self):
+ """
+        Attempt to join any existing scaler threads that may have died or finished. This ensures
+ any exceptions raised in the threads are propagated in a timely fashion.
+ """
+ exception = False
+ for scalerThread in [self.preemptableScaler, self.scaler]:
+ if scalerThread is not None:
+ try:
+ scalerThread.join(timeout=0)
+ except Exception as e:
+ logger.exception(e)
+ exception = True
+ if exception:
+ raise RuntimeError('The cluster scaler has exited due to an exception')
+
+ def shutdown(self):
+ """
+ Shutdown the cluster.
+ """
+ self.stop = True
+ for scaler in self.preemptableScaler, self.scaler:
+ if scaler is not None:
+ scaler.join()
+
+ def addCompletedJob(self, job, wallTime):
+ """
+        Adds the shape of a completed job to the queue, allowing the scaler to use the last N
+        completed jobs in factoring how many nodes are required in the cluster.
+        :param toil.job.JobNode job: The completed job, from which the memory, core and disk
+            requirements are taken
+ :param int wallTime: The wall-time taken to complete the job in seconds.
+ """
+ s = Shape(wallTime=wallTime, memory=job.memory, cores=job.cores, disk=job.disk)
+ if job.preemptable and self.preemptableScaler is not None:
+ self.preemptableScaler.jobShapes.add(s)
+ else:
+ self.scaler.jobShapes.add(s)
+
+
+class ScalerThread(ExceptionalThread):
+ """
+ A thread that automatically scales the number of either preemptable or non-preemptable worker
+ nodes according to the number of jobs queued and the resource requirements of the last N
+ completed jobs.
+
+    The scaling calculation is essentially as follows: use the RecentJobShapes instance to
+    calculate how many nodes, n, can be used to productively compute the last N completed
+    jobs. Let M be the number of jobs issued to the batch system. The number of nodes
+    required is then estimated to be alpha * n * M/N, where alpha is a scaling factor used to
+    adjust the balance between under- and over-provisioning the cluster.
+
+    At each scaling decision point a comparison between the current number of nodes, C, and the
+    newly estimated number is made. If the absolute difference is less than beta * C then no
+    change is made, else the size of the cluster is adapted. The beta factor is an inertia
+    parameter that prevents continual fluctuations in the number of nodes.
+ """
+ def __init__(self, scaler, preemptable):
+ """
+ :param ClusterScaler scaler: the parent class
+ """
+ super(ScalerThread, self).__init__(name='preemptable-scaler' if preemptable else 'scaler')
+ self.scaler = scaler
+ self.preemptable = preemptable
+ self.nodeTypeString = ("preemptable" if self.preemptable else "non-preemptable") + " nodes" # Used for logging
+ # Resource requirements and wall-time of an atomic node allocation
+ self.nodeShape = scaler.provisioner.getNodeShape(preemptable=preemptable)
+ # Monitors the requirements of the N most recently completed jobs
+ self.jobShapes = RecentJobShapes(scaler.config, self.nodeShape)
+ # Minimum/maximum number of either preemptable or non-preemptable nodes in the cluster
+ self.minNodes = scaler.config.minPreemptableNodes if preemptable else scaler.config.minNodes
+ self.maxNodes = scaler.config.maxPreemptableNodes if preemptable else scaler.config.maxNodes
+ if isinstance(self.scaler.leader.batchSystem, AbstractScalableBatchSystem):
+ self.totalNodes = len(self.scaler.leader.batchSystem.getNodes(self.preemptable))
+ else:
+ self.totalNodes = 0
+        logger.info('Starting with %s %s in the cluster.', self.totalNodes, self.nodeTypeString)
+
+ if scaler.config.clusterStats:
+ self.scaler.provisioner.startStats(preemptable=preemptable)
+
+ def tryRun(self):
+ global _preemptableNodeDeficit
+
+ while not self.scaler.stop:
+ with throttle(self.scaler.config.scaleInterval):
+ # Estimate the number of nodes to run the issued jobs.
+
+ # Number of jobs issued
+ queueSize = self.scaler.leader.getNumberOfJobsIssued(preemptable=self.preemptable)
+
+ # Job shapes of completed jobs
+ recentJobShapes = self.jobShapes.get()
+ assert len(recentJobShapes) > 0
+
+ # Estimate of number of nodes needed to run recent jobs
+ nodesToRunRecentJobs = binPacking(recentJobShapes, self.nodeShape)
+
+ # Actual calculation of the estimated number of nodes required
+ estimatedNodes = 0 if queueSize == 0 else max(1, int(round(
+ self.scaler.config.alphaPacking
+ * nodesToRunRecentJobs
+ * float(queueSize) / len(recentJobShapes))))
+
+ # Account for case where the average historical runtime of completed jobs is less
+ # than the runtime of currently running jobs. This is important
+ # to avoid a deadlock where the estimated number of nodes to run the jobs
+                # is too small to schedule a set of service jobs and their dependent jobs, leading
+ # to service jobs running indefinitely.
+
+ # How many jobs are currently running and their average runtime.
+ numberOfRunningJobs, currentAvgRuntime = self.scaler.leader.getNumberAndAvgRuntimeOfCurrentlyRunningJobs()
+
+ # Average runtime of recently completed jobs
+                historicalAvgRuntime = sum(map(lambda jS : jS.wallTime, recentJobShapes)) / float(len(recentJobShapes))
+
+ # Ratio of avg. runtime of currently running and completed jobs
+ runtimeCorrection = float(currentAvgRuntime)/historicalAvgRuntime if currentAvgRuntime > historicalAvgRuntime and numberOfRunningJobs >= estimatedNodes else 1.0
+
+ # Make correction, if necessary (only do so if cluster is busy and average runtime is higher than historical
+ # average)
+ if runtimeCorrection != 1.0:
+ logger.warn("Historical avg. runtime (%s) is less than current avg. runtime (%s) and cluster"
+ " is being well utilised (%s running jobs), increasing cluster requirement by: %s" %
+ (historicalAvgRuntime, currentAvgRuntime, numberOfRunningJobs, runtimeCorrection))
+ estimatedNodes *= runtimeCorrection
+
+ # If we're the non-preemptable scaler, we need to see if we have a deficit of
+ # preemptable nodes that we should compensate for.
+ if not self.preemptable:
+ compensation = self.scaler.config.preemptableCompensation
+ assert 0.0 <= compensation <= 1.0
+ # The number of nodes we provision as compensation for missing preemptable
+ # nodes is the product of the deficit (the number of preemptable nodes we did
+ # _not_ allocate) and configuration preference.
+ compensationNodes = int(round(_preemptableNodeDeficit * compensation))
+ logger.info('Adding %d non-preemptable nodes to compensate for a deficit of %d '
+ 'preemptable ones.', compensationNodes, _preemptableNodeDeficit)
+ estimatedNodes += compensationNodes
+
+ jobsPerNode = (0 if nodesToRunRecentJobs <= 0
+ else len(recentJobShapes) / float(nodesToRunRecentJobs))
+ logger.info('Estimating that cluster needs %s %s of shape %s, from current '
+ 'size of %s, given a queue size of %s, the number of jobs per node '
+ 'estimated to be %s, an alpha parameter of %s and a run-time length correction of %s.',
+ estimatedNodes, self.nodeTypeString, self.nodeShape,
+ self.totalNodes, queueSize, jobsPerNode,
+ self.scaler.config.alphaPacking, runtimeCorrection)
+
+ # Use inertia parameter to stop small fluctuations
+ if self.totalNodes / self.scaler.config.betaInertia <= estimatedNodes <= self.totalNodes * self.scaler.config.betaInertia:
+ logger.debug('Difference in new (%s) and previous estimates in number of '
+ '%s (%s) required is within beta (%s), making no change.',
+ estimatedNodes, self.nodeTypeString, self.totalNodes, self.scaler.config.betaInertia)
+ estimatedNodes = self.totalNodes
+
+ # Bound number using the max and min node parameters
+ if estimatedNodes > self.maxNodes:
+ logger.info('Limiting the estimated number of necessary %s (%s) to the '
+ 'configured maximum (%s).', self.nodeTypeString, estimatedNodes, self.maxNodes)
+ estimatedNodes = self.maxNodes
+ elif estimatedNodes < self.minNodes:
+ logger.info('Raising the estimated number of necessary %s (%s) to the '
+ 'configured minimum (%s).', self.nodeTypeString, estimatedNodes, self.minNodes)
+ estimatedNodes = self.minNodes
+
+ if estimatedNodes != self.totalNodes:
+ logger.info('Changing the number of %s from %s to %s.', self.nodeTypeString, self.totalNodes,
+ estimatedNodes)
+ self.totalNodes = self.scaler.provisioner.setNodeCount(numNodes=estimatedNodes,
+ preemptable=self.preemptable)
+
+ # If we were scaling up the number of preemptable nodes and failed to meet
+ # our target, we need to update the slack so that non-preemptable nodes will
+ # be allocated instead and we won't block. If we _did_ meet our target,
+ # we need to reset the slack to 0.
+ if self.preemptable:
+ if self.totalNodes < estimatedNodes:
+ deficit = estimatedNodes - self.totalNodes
+ logger.info('Preemptable scaler detected deficit of %d nodes.', deficit)
+ _preemptableNodeDeficit = deficit
+ else:
+ _preemptableNodeDeficit = 0
+
+ self.scaler.provisioner.checkStats()
+
+ self.scaler.provisioner.shutDown(preemptable=self.preemptable)
+ logger.info('Scaler exited normally.')
diff --git a/src/toil/realtimeLogger.py b/src/toil/realtimeLogger.py
new file mode 100644
index 0000000..6d26e83
--- /dev/null
+++ b/src/toil/realtimeLogger.py
@@ -0,0 +1,246 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Implements a real-time UDP-based logging system that user scripts can use for debugging.
+"""
+
+from __future__ import absolute_import
+import os
+import os.path
+import json
+import logging
+import logging.handlers
+import socket
+import threading
+
+# Python 3 compatibility imports
+from six.moves import socketserver as SocketServer
+
+import toil.lib.bioio
+
+log = logging.getLogger(__name__)
+
+
+class LoggingDatagramHandler(SocketServer.BaseRequestHandler):
+ """
+ Receive logging messages from the jobs and display them on the leader.
+
+ Uses bare JSON message encoding.
+ """
+
+ def handle(self):
+ """
+ Handle a single message. SocketServer takes care of splitting out the messages.
+
+ Messages are JSON-encoded logging module records.
+ """
+ # Unpack the data from the request
+ data, socket = self.request
+ try:
+ # Parse it as JSON
+ message_attrs = json.loads(data)
+ # Fluff it up into a proper logging record
+ record = logging.makeLogRecord(message_attrs)
+ except Exception:
+ # Complain that someone is sending us bad logging data
+ log.error("Malformed log message from {}".format(self.client_address[0]))
+ else:
+ # Log level filtering should have been done on the remote end. The handle() method
+ # skips it on this end.
+ log.handle(record)
+
+
+class JSONDatagramHandler(logging.handlers.DatagramHandler):
+ """
+ Send logging records over UDP serialized as JSON.
+
+ They have to fit in a single UDP datagram, so don't try to log more than 64kb at once.
+ """
+
+ def makePickle(self, record):
+ """
+ Actually, encode the record as bare JSON instead.
+ """
+ return json.dumps(record.__dict__)
+
+
+class RealtimeLoggerMetaclass(type):
+ """
+ Metaclass for RealtimeLogger that lets you do things like RealtimeLogger.warning(),
+ RealtimeLogger.info(), etc.
+ """
+
+ def __getattr__(self, name):
+ """
+ If a real attribute can't be found, try one of the logging methods on the actual logger
+ object.
+ """
+ return getattr(self.getLogger(), name)
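+ # Illustrative note (not part of the upstream code): because of this metaclass,
+ # RealtimeLogger.warning('msg') resolves to RealtimeLogger.getLogger().warning('msg'),
+ # so class-level calls behave like calls on the underlying 'toil-rt' logger.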
+
+
+class RealtimeLogger(object):
+ """
+ Provides a logger that logs over UDP to the leader. To use in a Toil job, do:
+
+ >>> from toil.realtimeLogger import RealtimeLogger
+ >>> RealtimeLogger.info("This logging message goes straight to the leader")
+
+ That's all a user of Toil would need to do. On the leader, Job.Runner.startToil()
+ automatically starts the UDP server by using an instance of this class as a context manager.
+ """
+ # Enable RealtimeLogger.info() syntactic sugar
+ __metaclass__ = RealtimeLoggerMetaclass
+
+ # The names of all environment variables used by this class are prefixed with this string
+ envPrefix = "TOIL_RT_LOGGING_"
+
+ # Avoid duplicating the default level everywhere
+ defaultLevel = 'INFO'
+
+ # State maintained on server and client
+
+ lock = threading.RLock()
+
+ # Server-side state
+
+ # The leader keeps a server and thread
+ loggingServer = None
+ serverThread = None
+
+ initialized = 0
+
+ # Client-side state
+
+ logger = None
+
+ @classmethod
+ def _startLeader(cls, batchSystem, level=defaultLevel):
+ with cls.lock:
+ if cls.initialized == 0:
+ cls.initialized += 1
+ if level:
+ log.info('Starting real-time logging.')
+ # Start up the logging server
+ cls.loggingServer = SocketServer.ThreadingUDPServer(
+ server_address=('0.0.0.0', 0),
+ RequestHandlerClass=LoggingDatagramHandler)
+
+ # Set up a thread to do all the serving in the background and exit when we do
+ cls.serverThread = threading.Thread(target=cls.loggingServer.serve_forever)
+ cls.serverThread.daemon = True
+ cls.serverThread.start()
+
+ # Set options for logging in the environment so they get sent out to jobs
+ fqdn = socket.getfqdn()
+ try:
+ ip = socket.gethostbyname(fqdn)
+ except socket.gaierror:
+ # FIXME: Does this only happen for me? Should we librarize the work-around?
+ import platform
+ if platform.system() == 'Darwin' and '.' not in fqdn:
+ ip = socket.gethostbyname(fqdn + '.local')
+ else:
+ raise
+ port = cls.loggingServer.server_address[1]
+
+ def _setEnv(name, value):
+ name = cls.envPrefix + name
+ os.environ[name] = value
+ batchSystem.setEnv(name)
+
+ _setEnv('ADDRESS', '%s:%i' % (ip, port))
+ _setEnv('LEVEL', level)
+ else:
+ log.info('Real-time logging disabled')
+ else:
+ if level:
+ log.warn('Ignoring nested request to start real-time logging')
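+ # Illustrative note (assumption, not in the upstream code): after _startLeader() runs,
+ # worker environments carry TOIL_RT_LOGGING_ADDRESS='<leader-ip>:<port>' and
+ # TOIL_RT_LOGGING_LEVEL, which getLogger() below reads to decide whether to attach a
+ # JSONDatagramHandler.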
+
+ @classmethod
+ def _stopLeader(cls):
+ """
+ Stop the server on the leader.
+ """
+ with cls.lock:
+ assert cls.initialized > 0
+ cls.initialized -= 1
+ if cls.initialized == 0:
+ if cls.loggingServer:
+ log.info('Stopping real-time logging server.')
+ cls.loggingServer.shutdown()
+ cls.loggingServer = None
+ if cls.serverThread:
+ log.info('Joining real-time logging server thread.')
+ cls.serverThread.join()
+ cls.serverThread = None
+ for k in os.environ.keys():
+ if k.startswith(cls.envPrefix):
+ os.environ.pop(k)
+
+ @classmethod
+ def getLogger(cls):
+ """
+ Get the logger that logs real-time to the leader.
+
+ Note that if the returned logger is used on the leader, you will see the message twice,
+ since it still goes to the normal log handlers, too.
+ """
+ # Only do the setup once, so we don't add a handler every time we log. Use a lock to do
+ # so safely even if we're being called in different threads. Use double-checked locking
+ # to reduce the overhead introduced by the lock.
+ if cls.logger is None:
+ with cls.lock:
+ if cls.logger is None:
+ cls.logger = logging.getLogger('toil-rt')
+ try:
+ level = os.environ[cls.envPrefix + 'LEVEL']
+ except KeyError:
+ # There is no server running on the leader, so suppress most log messages
+ # and skip the UDP stuff.
+ cls.logger.setLevel(logging.CRITICAL)
+ else:
+ # Adopt the logging level set on the leader.
+ toil.lib.bioio.setLogLevel(level, cls.logger)
+ try:
+ address = os.environ[cls.envPrefix + 'ADDRESS']
+ except KeyError:
+ pass
+ else:
+ # We know where to send messages to, so send them.
+ host, port = address.split(':')
+ cls.logger.addHandler(JSONDatagramHandler(host, int(port)))
+ return cls.logger
+
+ def __init__(self, batchSystem, level=defaultLevel):
+ """
+ A context manager that starts up the UDP server.
+
+ Should only be invoked on the leader. Python logging should have already been configured.
+ This method takes an optional log level, as a string level name, from the set supported
+ by bioio. If the level is None, False or the empty string, real-time logging will be
+ disabled, i.e. no UDP server will be started on the leader and log messages will be
+ suppressed on the workers. Note that this is different from passing level='OFF',
+ which is equivalent to level='CRITICAL' and does not disable the server.
+ """
+ super(RealtimeLogger, self).__init__()
+ self.__level = level
+ self.__batchSystem = batchSystem
+
+ def __enter__(self):
+ RealtimeLogger._startLeader(self.__batchSystem, level=self.__level)
+
+ # noinspection PyUnusedLocal
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ RealtimeLogger._stopLeader()
diff --git a/src/toil/resource.py b/src/toil/resource.py
new file mode 100644
index 0000000..3462426
--- /dev/null
+++ b/src/toil/resource.py
@@ -0,0 +1,573 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import errno
+import hashlib
+import importlib
+import json
+import logging
+import os
+import shutil
+import sys
+from collections import namedtuple
+from contextlib import closing
+from io import BytesIO
+from pydoc import locate
+from tempfile import mkdtemp
+from zipfile import ZipFile, PyZipFile
+
+# Python 3 compatibility imports
+from six.moves.urllib.request import urlopen
+
+from bd2k.util import strict_bool
+from bd2k.util.iterables import concat
+from bd2k.util.exceptions import require
+
+from toil import inVirtualEnv
+
+log = logging.getLogger(__name__)
+
+
+class Resource(namedtuple('Resource', ('name', 'pathHash', 'url', 'contentHash'))):
+ """
+ Represents a file or directory that will be deployed to each node before any jobs in the user
+ script are invoked. Each instance is a namedtuple with the following elements:
+
+ The pathHash element contains the MD5 (in hexdigest form) of the path to the resource on the
+ leader node. The path, and therefore its hash, is unique within a job store.
+
+ The url element is a "file:" or "http:" URL at which the resource can be obtained.
+
+ The contentHash element is an MD5 checksum of the resource, allowing for validation and
+ caching of resources.
+
+ If the resource is a regular file, the type attribute will be 'file'.
+
+ If the resource is a directory, the type attribute will be 'dir' and the URL will point at a
+ ZIP archive of that directory.
+ """
+
+ resourceEnvNamePrefix = 'JTRES_'
+
+ rootDirPathEnvName = resourceEnvNamePrefix + 'ROOT'
+
+ @classmethod
+ def create(cls, jobStore, leaderPath):
+ """
+ Saves the content of the file or directory at the given path to the given job store
+ and returns a resource object representing that content for the purpose of obtaining it
+ again at a generic, public URL. This method should be invoked on the leader node.
+
+ :param toil.jobStores.abstractJobStore.AbstractJobStore jobStore:
+
+ :param str leaderPath:
+
+ :rtype: Resource
+ """
+ pathHash = cls._pathHash(leaderPath)
+ contentHash = hashlib.md5()
+ # noinspection PyProtectedMember
+ with cls._load(leaderPath) as src:
+ with jobStore.writeSharedFileStream(sharedFileName=pathHash, isProtected=False) as dst:
+ userScript = src.read()
+ contentHash.update(userScript)
+ dst.write(userScript)
+ return cls(name=os.path.basename(leaderPath),
+ pathHash=pathHash,
+ url=jobStore.getSharedPublicUrl(sharedFileName=pathHash),
+ contentHash=contentHash.hexdigest())
+
+ def refresh(self, jobStore):
+ return type(self)(name=self.name,
+ pathHash=self.pathHash,
+ url=jobStore.getSharedPublicUrl(sharedFileName=self.pathHash),
+ contentHash=self.contentHash)
+
+ @classmethod
+ def prepareSystem(cls):
+ """
+ Prepares this system for the downloading and lookup of resources. This method should only
+ be invoked on a worker node. It is idempotent but not thread-safe.
+ """
+ try:
+ resourceRootDirPath = os.environ[cls.rootDirPathEnvName]
+ except KeyError:
+ # Create directory holding local copies of requested resources ...
+ resourceRootDirPath = mkdtemp()
+ # .. and register its location in an environment variable such that child processes
+ # can find it.
+ os.environ[cls.rootDirPathEnvName] = resourceRootDirPath
+ assert os.path.isdir(resourceRootDirPath)
+
+ @classmethod
+ def cleanSystem(cls):
+ """
+ Removes all downloaded, localized resources
+ """
+ resourceRootDirPath = os.environ[cls.rootDirPathEnvName]
+ os.environ.pop(cls.rootDirPathEnvName)
+ shutil.rmtree(resourceRootDirPath)
+ for k, v in os.environ.items():
+ if k.startswith(cls.resourceEnvNamePrefix):
+ os.environ.pop(k)
+
+ def register(self):
+ """
+ Register this resource for later retrieval via lookup(), possibly in a child process.
+ """
+ os.environ[self.resourceEnvNamePrefix + self.pathHash] = self.pickle()
+
+ @classmethod
+ def lookup(cls, leaderPath):
+ """
+ Returns a resource object representing a resource created from a file or directory at the
+ given path on the leader. This method should be invoked on the worker. The given path
+ does not need to refer to an existing file or directory on the worker; it only identifies
+ the resource within an instance of toil. This method returns None if no resource for the
+ given path exists.
+
+ :rtype: Resource
+ """
+ pathHash = cls._pathHash(leaderPath)
+ try:
+ s = os.environ[cls.resourceEnvNamePrefix + pathHash]
+ except KeyError:
+ log.warn("Can't find resource for leader path '%s'", leaderPath)
+ return None
+ else:
+ self = cls.unpickle(s)
+ assert self.pathHash == pathHash
+ return self
+
+ def download(self, callback=None):
+ """
+ Downloads this resource from its URL to a file on the local system. This method should
+ only be invoked on a worker node after the node was setup for accessing resources via
+ prepareSystem().
+ """
+ dirPath = self.localDirPath
+ if not os.path.exists(dirPath):
+ tempDirPath = mkdtemp(dir=os.path.dirname(dirPath), prefix=self.contentHash + "-")
+ self._save(tempDirPath)
+ if callback is not None:
+ callback(tempDirPath)
+ try:
+ os.rename(tempDirPath, dirPath)
+ except OSError as e:
+ # If dirPath already exists & is non-empty either ENOTEMPTY or EEXIST will be raised
+ if e.errno == errno.ENOTEMPTY or e.errno == errno.EEXIST:
+ # Another process beat us to it.
+ # TODO: This is correct but inefficient since multiple processes download the resource redundantly
+ pass
+ else:
+ raise
+
+ @property
+ def localPath(self):
+ """
+ The path to the resource on the worker. The file or directory at the returned path may or may
+ not yet exist. Invoking download() will ensure that it does.
+ """
+ raise NotImplementedError
+
+ @property
+ def localDirPath(self):
+ """
+ The path to the directory containing the resource on the worker.
+ """
+ rootDirPath = os.environ[self.rootDirPathEnvName]
+ return os.path.join(rootDirPath, self.contentHash)
+
+ def pickle(self):
+ return self.__class__.__module__ + "." + self.__class__.__name__ + ':' + json.dumps(self)
+
+ @classmethod
+ def unpickle(cls, s):
+ """
+ :rtype: Resource
+ """
+ className, _json = s.split(':', 1)
+ return locate(className)(*json.loads(_json))
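+ # Illustrative round trip (not part of the upstream code): Resource.unpickle(r.pickle())
+ # reconstructs an equal Resource, which is how register() and lookup() pass resources
+ # between processes via environment variables.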
+
+ @classmethod
+ def _pathHash(cls, path):
+ return hashlib.md5(path).hexdigest()
+
+ @classmethod
+ def _load(cls, path):
+ """
+ Returns a readable file-like object for the given path. If the path refers to a regular
+ file, this method returns the result of invoking open() on the given path. If the path
+ refers to a directory, this method returns a ZIP file with all files and subdirectories
+ in the directory at the given path.
+
+ :type path: str
+ :rtype: io.FileIO
+ """
+ raise NotImplementedError()
+
+ def _save(self, dirPath):
+ """
+ Save this resource to the directory at the given parent path.
+
+ :type dirPath: str
+ """
+ raise NotImplementedError()
+
+ def _download(self, dstFile):
+ """
+ Download this resource from its URL to the given file object.
+
+ :type dstFile: io.BytesIO|io.FileIO
+ """
+ with closing(urlopen(self.url)) as content:
+ buf = content.read()
+ contentHash = hashlib.md5(buf)
+ assert contentHash.hexdigest() == self.contentHash
+ dstFile.write(buf)
+
+
+class FileResource(Resource):
+ """
+ A resource read from a file on the leader.
+ """
+
+ @classmethod
+ def _load(cls, path):
+ return open(path)
+
+ def _save(self, dirPath):
+ with open(os.path.join(dirPath, self.name), mode='w') as localFile:
+ self._download(localFile)
+
+ @property
+ def localPath(self):
+ return os.path.join(self.localDirPath, self.name)
+
+
+class DirectoryResource(Resource):
+ """
+ A resource read from a directory on the leader. The URL will point to a ZIP archive of the
+ directory. Only Python scripts/modules will be included. The directory may be a package but it
+ does not need to be.
+ """
+
+ @classmethod
+ def _load(cls, path):
+ """
+ :type path: str
+ """
+ bytesIO = BytesIO()
+ # PyZipFile compiles .py files on the fly, filters out any non-Python files and
+ # distinguishes between packages and simple directories.
+ with PyZipFile(file=bytesIO, mode='w') as zipFile:
+ zipFile.writepy(path)
+ bytesIO.seek(0)
+ return bytesIO
+
+ def _save(self, dirPath):
+ bytesIO = BytesIO()
+ self._download(bytesIO)
+ bytesIO.seek(0)
+ with ZipFile(file=bytesIO, mode='r') as zipFile:
+ zipFile.extractall(path=dirPath)
+
+ @property
+ def localPath(self):
+ return self.localDirPath
+
+
+class VirtualEnvResource(DirectoryResource):
+ """
+ A resource read from a virtualenv on the leader. All modules and packages found in the
+ virtualenv's site-packages directory will be included. Any .pth or .egg-link files will be
+ ignored.
+ """
+
+ @classmethod
+ def _load(cls, path):
+ sitePackages = path
+ assert os.path.basename(sitePackages) == 'site-packages'
+ bytesIO = BytesIO()
+ with PyZipFile(file=bytesIO, mode='w') as zipFile:
+ # This adds the .py files but omits subdirectories since site-packages is not a package
+ zipFile.writepy(sitePackages)
+ # Now add the missing packages
+ for name in os.listdir(sitePackages):
+ path = os.path.join(sitePackages, name)
+ if os.path.isdir(path) and os.path.isfile(os.path.join(path, '__init__.py')):
+ zipFile.writepy(path)
+ bytesIO.seek(0)
+ return bytesIO
+
+
+class ModuleDescriptor(namedtuple('ModuleDescriptor', ('dirPath', 'name', 'fromVirtualEnv'))):
+ """
+ A path to a Python module decomposed into a namedtuple of three elements, namely
+
+ - dirPath, the path to the directory that should be added to sys.path before importing the
+ module,
+
+ - name, the fully qualified name of the module, with leading package names separated by
+ dots, and
+
+ - fromVirtualEnv, whether the module was loaded from the site-packages directory of a
+ virtualenv.
+
+ >>> import toil.resource
+ >>> ModuleDescriptor.forModule('toil.resource') # doctest: +ELLIPSIS
+ ModuleDescriptor(dirPath='/.../src', name='toil.resource', fromVirtualEnv=False)
+
+ >>> import subprocess, tempfile, os
+ >>> dirPath = tempfile.mkdtemp()
+ >>> path = os.path.join( dirPath, 'foo.py' )
+ >>> with open(path,'w') as f:
+ ... f.write('from toil.resource import ModuleDescriptor\\n'
+ ... 'print ModuleDescriptor.forModule(__name__)')
+ >>> subprocess.check_output([ sys.executable, path ]) # doctest: +ELLIPSIS
+ "ModuleDescriptor(dirPath='...', name='foo', fromVirtualEnv=False)\\n"
+
+ Now test a collision. As funny as it sounds, the robotparser module is included in the Python
+ standard library.
+ >>> dirPath = tempfile.mkdtemp()
+ >>> path = os.path.join( dirPath, 'robotparser.py' )
+ >>> with open(path,'w') as f:
+ ... f.write('from toil.resource import ModuleDescriptor\\n'
+ ... 'ModuleDescriptor.forModule(__name__)')
+
+ This should fail and return exit status 1 due to the collision with the standard library's 'robotparser' module:
+ >>> subprocess.call([ sys.executable, path ])
+ 1
+
+ Clean up
+ >>> from shutil import rmtree
+ >>> rmtree( dirPath )
+ """
+
+ @classmethod
+ def forModule(cls, name):
+ """
+ Return an instance of this class representing the module of the given name. If the given
+ module name is "__main__", it will be translated to the actual file name of the top-level
+ script without the .py or .pyc extension. This method assumes that the module with the
+ specified name has already been loaded.
+ """
+ module = sys.modules[name]
+ filePath = os.path.abspath(module.__file__)
+ filePath = filePath.split(os.path.sep)
+ filePath[-1], extension = os.path.splitext(filePath[-1])
+ require(extension in ('.py', '.pyc'),
+ 'The name of a user script/module must end in .py or .pyc.')
+ if name == '__main__':
+ # User script/module was invoked as the main program
+ if module.__package__:
+ # Invoked as a module via python -m foo.bar
+ name = [filePath.pop()]
+ for package in reversed(module.__package__.split('.')):
+ dirPathTail = filePath.pop()
+ assert dirPathTail == package
+ name.append(dirPathTail)
+ name = '.'.join(reversed(name))
+ dirPath = os.path.sep.join(filePath)
+ else:
+ # Invoked as a script via python foo/bar.py
+ name = filePath.pop()
+ dirPath = os.path.sep.join(filePath)
+ cls._check_conflict(dirPath, name)
+ else:
+ # User module was imported. Determine the directory containing the top-level package
+ for package in reversed(name.split('.')):
+ dirPathTail = filePath.pop()
+ assert dirPathTail == package
+ dirPath = os.path.sep.join(filePath)
+ assert os.path.isdir(dirPath)
+ fromVirtualEnv = inVirtualEnv() and dirPath.startswith(sys.prefix)
+ return cls(dirPath=dirPath, name=name, fromVirtualEnv=fromVirtualEnv)
+
+ @classmethod
+ def _check_conflict(cls, dirPath, name):
+ """
+ Check whether the module of the given name conflicts with another module on the sys.path.
+
+ :param dirPath: the directory from which the module was originally loaded
+ :param name: the module name
+ """
+ old_sys_path = sys.path
+ try:
+ sys.path = [d for d in old_sys_path if os.path.realpath(d) != os.path.realpath(dirPath)]
+ try:
+ colliding_module = importlib.import_module(name)
+ except ImportError:
+ pass
+ else:
+ raise ResourceException(
+ "The user module '%s' collides with module '%s from '%s'." % (
+ name, colliding_module.__name__, colliding_module.__file__))
+ finally:
+ sys.path = old_sys_path
+
+ @property
+ def belongsToToil(self):
+ """
+ True if this module is part of the Toil distribution
+ """
+ return self.name.startswith('toil.')
+
+ def saveAsResourceTo(self, jobStore):
+ """
+ Store the file containing this module--or even the Python package directory hierarchy
+ containing that file--as a resource to the given job store and return the
+ corresponding resource object. Should only be called on a leader node.
+
+ :type jobStore: toil.jobStores.abstractJobStore.AbstractJobStore
+ :rtype: toil.resource.Resource
+ """
+ return self._getResourceClass().create(jobStore, self._resourcePath)
+
+ def _getResourceClass(self):
+ """
+ Return the concrete subclass of Resource that's appropriate for hot-deploying this module.
+ """
+ if self.fromVirtualEnv:
+ subcls = VirtualEnvResource
+ elif os.path.isdir(self._resourcePath):
+ subcls = DirectoryResource
+ elif os.path.isfile(self._resourcePath):
+ subcls = FileResource
+ elif os.path.exists(self._resourcePath):
+ raise AssertionError("Neither a file or a directory: '%s'" % self._resourcePath)
+ else:
+ raise AssertionError("No such file or directory: '%s'" % self._resourcePath)
+ return subcls
+
+ def localize(self):
+ """
+ Check if this module was saved as a resource. If it was, return a new module descriptor
+ that points to a local copy of that resource. Should only be called on a worker node. On
+ the leader, this method returns this module descriptor, i.e. self.
+
+ :rtype: toil.resource.ModuleDescriptor
+ """
+ if not self._runningOnWorker():
+ log.warn('The localize() method should only be invoked on a worker.')
+ resource = Resource.lookup(self._resourcePath)
+ if resource is None:
+ log.warn("Can't localize module %r", self)
+ return self
+ else:
+ def stash(tmpDirPath):
+ # Save the original dirPath such that we can restore it in globalize()
+ with open(os.path.join(tmpDirPath, '.stash'), 'w') as f:
+ f.write('1' if self.fromVirtualEnv else '0')
+ f.write(self.dirPath)
+
+ resource.download(callback=stash)
+ return self.__class__(dirPath=resource.localDirPath,
+ name=self.name,
+ fromVirtualEnv=self.fromVirtualEnv)
+
+ def _runningOnWorker(self):
+ try:
+ mainModule = sys.modules['__main__']
+ except KeyError:
+ log.warning('Cannot determine main program module.')
+ return False
+ else:
+ mainModuleFile = os.path.basename(mainModule.__file__)
+ workerModuleFiles = concat(('worker' + ext for ext in self.moduleExtensions),
+ '_toil_worker') # the setuptools entry point
+ return mainModuleFile in workerModuleFiles
+
+ def globalize(self):
+ """
+ Reverse the effect of localize().
+ """
+ try:
+ with open(os.path.join(self.dirPath, '.stash')) as f:
+ fromVirtualEnv = [False, True][int(f.read(1))]
+ dirPath = f.read()
+ except IOError as e:
+ if e.errno == errno.ENOENT:
+ if self._runningOnWorker():
+ log.warn("Can't globalize module %r.", self)
+ return self
+ else:
+ raise
+ else:
+ return self.__class__(dirPath=dirPath,
+ name=self.name,
+ fromVirtualEnv=fromVirtualEnv)
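+ # Illustrative round trip (assumption, not in the upstream code): on a worker,
+ # descriptor.localize() downloads the resource and returns a descriptor rooted in the
+ # local copy; calling globalize() on that copy reads the '.stash' file written by
+ # localize() and restores the original leader-side dirPath.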
+
+ @property
+ def _resourcePath(self):
+ """
+ The path to the directory that should be used when shipping this module and its siblings
+ around as a resource.
+ """
+ if self.fromVirtualEnv:
+ return self.dirPath
+ elif '.' in self.name:
+ return os.path.join(self.dirPath, self._rootPackage())
+ else:
+ initName = self._initModuleName(self.dirPath)
+ if initName:
+ raise ResourceException(
+ "Toil does not support loading a user script from a package directory. You "
+ "may want to remove %s from %s or invoke the user script as a module via "
+ "'PYTHONPATH=\"%s\" python -m %s.%s'." %
+ tuple(concat(initName, self.dirPath, os.path.split(self.dirPath), self.name)))
+ return self.dirPath
+
+ moduleExtensions = ('.py', '.pyc', '.pyo')
+
+ @classmethod
+ def _initModuleName(cls, dirPath):
+ for extension in cls.moduleExtensions:
+ name = '__init__' + extension
+ if os.path.exists(os.path.join(dirPath, name)):
+ return name
+ return None
+
+ def _rootPackage(self):
+ try:
+ head, tail = self.name.split('.', 1)
+ except ValueError:
+ raise ValueError('%r is a stand-alone module.' % self)
+ else:
+ return head
+
+ def toCommand(self):
+ return tuple(map(str, self))
+
+ @classmethod
+ def fromCommand(cls, command):
+ assert len(command) == 3
+ return cls(dirPath=command[0], name=command[1], fromVirtualEnv=strict_bool(command[2]))
+
+ def makeLoadable(self):
+ module = self if self.belongsToToil else self.localize()
+ if module.dirPath not in sys.path:
+ sys.path.append(module.dirPath)
+ return module
+
+ def load(self):
+ module = self.makeLoadable()
+ try:
+ return importlib.import_module(module.name)
+ except ImportError:
+ log.error('Failed to import user module %r from sys.path (%r).', module, sys.path)
+ raise
+
+
+class ResourceException(Exception):
+ pass
diff --git a/src/toil/serviceManager.py b/src/toil/serviceManager.py
new file mode 100644
index 0000000..2176560
--- /dev/null
+++ b/src/toil/serviceManager.py
@@ -0,0 +1,197 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import logging
+import time
+from threading import Thread, Event
+
+# Python 3 compatibility imports
+from six.moves.queue import Empty, Queue
+
+logger = logging.getLogger( __name__ )
+
+class ServiceManager( object ):
+ """
+ Manages the scheduling of services.
+ """
+ def __init__(self, jobStore, toilState):
+ self.jobStore = jobStore
+
+ self.toilState = toilState
+
+ self.jobGraphsWithServicesBeingStarted = set()
+
+ self._terminate = Event() # This is used to terminate the thread associated
+ # with the service manager
+
+ self._jobGraphsWithServicesToStart = Queue() # This is the input queue of
+ # jobGraphs that have services that need to be started
+
+ self._jobGraphsWithServicesThatHaveStarted = Queue() # This is the output queue
+ # of jobGraphs that have services that are already started
+
+ self._serviceJobGraphsToStart = Queue() # This is the queue of services for the
+ # batch system to start
+
+ self.jobsIssuedToServiceManager = 0 # The number of jobs the service manager
+ # is scheduling
+
+ # Start a thread that starts the services of jobGraphs in the
+ # jobsWithServicesToStart input queue and puts the jobGraphs whose services
+ # are running on the jobGraphsWithServicesThatHaveStarted output queue
+ self._serviceStarter = Thread(target=self._startServices,
+ args=(self._jobGraphsWithServicesToStart,
+ self._jobGraphsWithServicesThatHaveStarted,
+ self._serviceJobGraphsToStart, self._terminate,
+ self.jobStore))
+
+ def start(self):
+ """
+ Start the service scheduling thread.
+ """
+ self._serviceStarter.start()
+
+ def scheduleServices(self, jobGraph):
+ """
+ Schedule the services of a job asynchronously.
+ When the job's services are running, the jobGraph for the job will
+ be returned by toil.serviceManager.ServiceManager.getJobGraphWhoseServicesAreRunning.
+
+ :param toil.jobGraph.JobGraph jobGraph: wrapper of job with services to schedule.
+ """
+ # Add jobGraph to set being processed by the service manager
+ self.jobGraphsWithServicesBeingStarted.add(jobGraph)
+
+ # Add number of jobs managed by ServiceManager
+ self.jobsIssuedToServiceManager += sum(map(len, jobGraph.services)) + 1 # The plus one accounts for the root job
+
+ # Asynchronously schedule the services
+ self._jobGraphsWithServicesToStart.put(jobGraph)
+
+ def getJobGraphWhoseServicesAreRunning(self, maxWait):
+ """
+ :param float maxWait: Time in seconds to wait to get a jobGraph before returning
+ :return: a jobGraph added to scheduleServices whose services are running, or None if
+ no such job is available.
+ :rtype: JobGraph
+ """
+ try:
+ jobGraph = self._jobGraphsWithServicesThatHaveStarted.get(timeout=maxWait)
+ self.jobGraphsWithServicesBeingStarted.remove(jobGraph)
+ assert self.jobsIssuedToServiceManager >= 0
+ self.jobsIssuedToServiceManager -= 1
+ return jobGraph
+ except Empty:
+ return None
+
+ def getServiceJobsToStart(self, maxWait):
+ """
+ :param float maxWait: Time in seconds to wait to get a job before returning.
+ :return: a tuple of (serviceJobStoreID, memory, cores, disk, ..) representing
+ a service job to start.
+ :rtype: toil.job.ServiceJobNode
+ """
+ try:
+ serviceJob = self._serviceJobGraphsToStart.get(timeout=maxWait)
+ assert self.jobsIssuedToServiceManager >= 0
+ self.jobsIssuedToServiceManager -= 1
+ return serviceJob
+ except Empty:
+ return None
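+ # Illustrative leader-side flow (sketch, not part of the upstream code):
+ # manager.scheduleServices(jobGraph) # enqueue a job whose services must start
+ # serviceJob = manager.getServiceJobsToStart(maxWait=1) # issue this to the batch system
+ # ready = manager.getJobGraphWhoseServicesAreRunning(maxWait=1) # job whose services run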
+
+ def killServices(self, services, error=False):
+ """
+ :param dict services: Maps service jobStoreIDs to the communication flags for the service
+ """
+ for serviceJobStoreID in services:
+ serviceJob = services[serviceJobStoreID]
+ if error:
+ self.jobStore.deleteFile(serviceJob.errorJobStoreID)
+ self.jobStore.deleteFile(serviceJob.terminateJobStoreID)
+
+ def isActive(self, serviceJobNode):
+ """
+ Returns true if the service job has not been told to terminate.
+ :rtype: boolean
+ """
+ return self.jobStore.fileExists(serviceJobNode.terminateJobStoreID)
+
+ def check(self):
+ """
+ Check on the service manager thread.
+ :raise RuntimeError: If the underlying thread has quit.
+ """
+ if not self._serviceStarter.is_alive():
+ raise RuntimeError("Service manager has quit")
+
+ def shutdown(self):
+ """
+ Cleanly terminate the worker thread that starts services, then kill any services that
+ are still issued. Blocks until this is complete.
+ """
+ logger.info('Waiting for service manager thread to finish ...')
+ startTime = time.time()
+ self._terminate.set()
+ self._serviceStarter.join()
+ # Kill any services still running to avoid deadlock
+ for services in self.toilState.servicesIssued.values():
+ self.killServices(services, error=True)
+ logger.info('... finished shutting down the service manager. Took %s seconds', time.time() - startTime)
+
+ @staticmethod
+ def _startServices(jobGraphsWithServicesToStart,
+ jobGraphsWithServicesThatHaveStarted,
+ serviceJobsToStart,
+ terminate, jobStore):
+ """
+ Thread used to schedule services.
+ """
+ while True:
+ try:
+ # Get a jobGraph with services to start, waiting a short period
+ jobGraph = jobGraphsWithServicesToStart.get(timeout=1.0)
+ except Empty:
+ # Check if the thread should quit
+ if terminate.is_set():
+ logger.debug('Received signal to quit starting services.')
+ break
+ continue
+
+ if jobGraph is None: # Nothing was ready, loop again
+ continue
+
+ # Start the service jobs in batches, waiting for each batch
+ # to become established before starting the next batch
+ for serviceJobList in jobGraph.services:
+ for serviceJob in serviceJobList:
+ logger.debug("Service manager is starting service job: %s, start ID: %s", serviceJob, serviceJob.startJobStoreID)
+ assert jobStore.fileExists(serviceJob.startJobStoreID)
+ # At this point the terminateJobStoreID and errorJobStoreID could have been deleted!
+ serviceJobsToStart.put(serviceJob)
+
+ # Wait until all the services of the batch are running
+ for serviceJob in serviceJobList:
+ while jobStore.fileExists(serviceJob.startJobStoreID):
+ # Sleep to avoid thrashing
+ time.sleep(1.0)
+
+ # Check if the thread should quit
+ if terminate.is_set():
+ logger.debug('Received signal to quit starting services.')
+ break
+
+ # Add the jobGraph to the output queue of jobs whose services have been started
+ jobGraphsWithServicesThatHaveStarted.put(jobGraph)
\ No newline at end of file
diff --git a/src/toil/statsAndLogging.py b/src/toil/statsAndLogging.py
new file mode 100644
index 0000000..25190da
--- /dev/null
+++ b/src/toil/statsAndLogging.py
@@ -0,0 +1,154 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import gzip
+import json
+import logging
+import os
+import time
+from threading import Thread, Event
+
+from bd2k.util.expando import Expando
+from toil.lib.bioio import getTotalCpuTime
+
+logger = logging.getLogger( __name__ )
+
+class StatsAndLogging( object ):
+ """
+ Manages a thread that aggregates statistics and logging information from a Toil run.
+ """
+
+ def __init__(self, jobStore, config):
+ self._stop = Event()
+ self._worker = Thread(target=self.statsAndLoggingAggregator,
+ args=(jobStore, self._stop, config))
+
+ def start(self):
+ """
+ Start the stats and logging thread.
+ """
+ self._worker.start()
+
+ @classmethod
+ def writeLogFiles(cls, jobNames, jobLogList, config):
+ def createName(logPath, jobName, logExtension):
+ logName = jobName.replace('-', '--')
+ logName = logName.replace('/', '-')
+ logName = logName.replace(' ', '_')
+ logName = logName.replace("'", '')
+ logName = logName.replace('"', '')
+ counter = 0
+ while True:
+ suffix = str(counter).zfill(3) + logExtension
+ fullName = os.path.join(logPath, logName + suffix)
+ if not os.path.exists(fullName):
+ return fullName
+ counter += 1
+
+ mainFileName = jobNames[0]
+ extension = '.log'
+
+ assert not (config.writeLogs and config.writeLogsGzip), \
+ "Cannot use both --writeLogs and --writeLogsGzip at the same time."
+
+ if config.writeLogs:
+ path = config.writeLogs
+ writeFn = open
+ elif config.writeLogsGzip:
+ path = config.writeLogsGzip
+ writeFn = gzip.open
+ extension += '.gz'
+ else:
+ # we don't have anywhere to write the logs, return now
+ return
+
+ fullName = createName(path, mainFileName, extension)
+ with writeFn(fullName, 'w') as f:
+ f.writelines(l + '\n' for l in jobLogList)
+ for alternateName in jobNames[1:]:
+ # There are chained jobs in this output - indicate this with a symlink
+ # of the job's name to this file
+ name = createName(path, alternateName, extension)
+ os.symlink(os.path.relpath(fullName, path), name)
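+ # Illustrative example (assumption, not in the upstream code): with --writeLogs /logs, a
+ # job named "file/reads chr1.bam" is written to /logs/file-reads_chr1.bam000.log, and any
+ # chained job names are symlinked to that file.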
+
+ @classmethod
+ def statsAndLoggingAggregator(cls, jobStore, stop, config):
+ """
+ Collates stats and reports log messages from the workers.
+ Runs inside a thread and keeps collating until the stop flag is set.
+ """
+ # Overall timing
+ startTime = time.time()
+ startClock = getTotalCpuTime()
+
+ def callback(fileHandle):
+ stats = json.load(fileHandle, object_hook=Expando)
+ try:
+ logs = stats.workers.logsToMaster
+ except AttributeError:
+ # To be expected if there were no calls to logToMaster()
+ pass
+ else:
+ for message in logs:
+ logger.log(int(message.level),
+ 'Got message from job at time %s: %s',
+ time.strftime('%m-%d-%Y %H:%M:%S'), message.text)
+ try:
+ logs = stats.logs
+ except AttributeError:
+ pass
+ else:
+ def logWithFormatting(jobStoreID, jobLogs):
+ logFormat = '\n%s ' % jobStoreID
+ logger.debug('Received Toil worker log. Disable debug level '
+ 'logging to hide this output\n%s', logFormat.join(jobLogs))
+ # we may have multiple jobs per worker
+ jobNames = logs.names
+ messages = logs.messages
+ logWithFormatting(jobNames[0], messages)
+ cls.writeLogFiles(jobNames, messages, config=config)
+
+ while True:
+ # This is an indirect way of telling the thread to exit
+ if stop.is_set():
+ jobStore.readStatsAndLogging(callback)
+ break
+ if jobStore.readStatsAndLogging(callback) == 0:
+ time.sleep(0.5) # Avoid cycling too fast
+
+ # Finish the stats file
+ text = json.dumps(dict(total_time=str(time.time() - startTime),
+ total_clock=str(getTotalCpuTime() - startClock)))
+ jobStore.writeStatsAndLogging(text)
+
+ def check(self):
+ """
+ Check on the stats and logging aggregator.
+ :raise RuntimeError: If the underlying thread has quit.
+ """
+ if not self._worker.is_alive():
+ raise RuntimeError("Stats and logging thread has quit")
+
+ def shutdown(self):
+ """
+ Finish up the stats/logging aggregation thread
+ """
+ logger.info('Waiting for stats and logging collator thread to finish ...')
+ startTime = time.time()
+ self._stop.set()
+ self._worker.join()
+ logger.info('... finished collating stats and logs. Took %s seconds', time.time() - startTime)
+ # in addition to cleaning on exceptions, onError should clean if there are any failed jobs
\ No newline at end of file
diff --git a/src/toil/test/__init__.py b/src/toil/test/__init__.py
new file mode 100644
index 0000000..31ac2a2
--- /dev/null
+++ b/src/toil/test/__init__.py
@@ -0,0 +1,863 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import logging
+import multiprocessing
+import os
+import re
+import shutil
+import signal
+import tempfile
+import threading
+import time
+import unittest
+import uuid
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from inspect import getsource
+from subprocess import PIPE, Popen, CalledProcessError, check_output
+from textwrap import dedent
+from unittest.util import strclass
+
+# Python 3 compatibility imports
+from six import iteritems, itervalues
+from six.moves.urllib.request import urlopen
+
+from bd2k.util import less_strict_bool, memoize
+from bd2k.util.files import mkdir_p
+from bd2k.util.iterables import concat
+from bd2k.util.processes import which
+from bd2k.util.threading import ExceptionalThread
+
+from toil import toilPackageDirPath, applianceSelf
+from toil.version import distVersion
+
+log = logging.getLogger(__name__)
+
+
+class ToilTest(unittest.TestCase):
+ """
+ A common base class for Toil tests. Please have every test case directly or indirectly
+ inherit this one.
+
+ When running tests you may optionally set the TOIL_TEST_TEMP environment variable to the path
+ of a directory where you want temporary test files to be placed. The directory will be created
+ if it doesn't exist. The path may be relative in which case it will be assumed to be relative
+ to the project root. If TOIL_TEST_TEMP is not defined, temporary files and directories will
+ be created in the system's default location for such files, and any temporary files or
+ directories left over from tests will be removed automatically during tear down.
+ Otherwise, left-over files will not be removed.
+ """
+
+ _tempBaseDir = None
+ _tempDirs = None
+
+ @classmethod
+ def setUpClass(cls):
+ super(ToilTest, cls).setUpClass()
+ cls._tempDirs = []
+ tempBaseDir = os.environ.get('TOIL_TEST_TEMP', None)
+ if tempBaseDir is not None and not os.path.isabs(tempBaseDir):
+ tempBaseDir = os.path.abspath(os.path.join(cls._projectRootPath(), tempBaseDir))
+ mkdir_p(tempBaseDir)
+ cls._tempBaseDir = tempBaseDir
+
+ @classmethod
+ def awsRegion(cls):
+ """
+ Use us-west-2 unless running on EC2, in which case use the region in which
+ the instance is located
+ """
+ if runningOnEC2():
+ return cls._region()
+ else:
+ return 'us-west-2'
+
+ @classmethod
+ def _availabilityZone(cls):
+ """
+ Used only when running on EC2. Query this instance's metadata to determine
+ in which availability zone it is running
+ """
+ return urlopen('http://169.254.169.254/latest/meta-data/placement/availability-zone').read()
+
+ @classmethod
+ @memoize
+ def _region(cls):
+ """
+ Used only when running on EC2. Determines in what region this instance is running.
+ The region will not change over the life of the instance so the result
+ is memoized to avoid unnecessary work.
+ """
+ m = re.match(r'^([a-z]{2}-[a-z]+-[1-9][0-9]*)([a-z])$', cls._availabilityZone())
+ assert m
+ region = m.group(1)
+ return region
+
+ @classmethod
+ def _getUtilScriptPath(cls, script_name):
+ return os.path.join(toilPackageDirPath(), 'utils', script_name + '.py')
+
+ @classmethod
+ def _projectRootPath(cls):
+ """
+ Returns the path to the project root, i.e. the directory that typically contains the .git
+ and src subdirectories. This method has limited utility. It only works if in "develop"
+ mode, since it assumes the existence of a src subdirectory which, in a regular install,
+ wouldn't exist. Then again, in that mode the project root has no meaning anyway.
+ """
+ assert re.search(r'__init__\.pyc?$', __file__)
+ projectRootPath = os.path.dirname(os.path.abspath(__file__))
+ packageComponents = __name__.split('.')
+ expectedSuffix = os.path.join('src', *packageComponents)
+ assert projectRootPath.endswith(expectedSuffix)
+ projectRootPath = projectRootPath[:-len(expectedSuffix)]
+ return projectRootPath
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls._tempBaseDir is None:
+ while cls._tempDirs:
+ tempDir = cls._tempDirs.pop()
+ if os.path.exists(tempDir):
+ shutil.rmtree(tempDir)
+ else:
+ cls._tempDirs = []
+ super(ToilTest, cls).tearDownClass()
+
+ def setUp(self):
+ log.info("Setting up %s ...", self.id())
+ super(ToilTest, self).setUp()
+
+ def _createTempDir(self, purpose=None):
+ return self._createTempDirEx(self._testMethodName, purpose)
+
+ @classmethod
+ def _createTempDirEx(cls, *names):
+ prefix = ['toil', 'test', strclass(cls)]
+ prefix.extend(filter(None, names))
+ prefix.append('')
+ temp_dir_path = tempfile.mkdtemp(dir=cls._tempBaseDir, prefix='-'.join(prefix))
+ cls._tempDirs.append(temp_dir_path)
+ return temp_dir_path
+
+ def tearDown(self):
+ super(ToilTest, self).tearDown()
+ log.info("Tore down %s", self.id())
+
+ def _getTestJobStorePath(self):
+ path = self._createTempDir(purpose='jobstore')
+ # We only need a unique path; the directory shouldn't actually exist. This of course is racy
+ # and insecure because another thread could now allocate the same path as a temporary
+ # directory. However, the built-in tempfile module randomizes temp dir name suffixes
+ # reasonably well (1 in 63 ^ 6 chance of collision), making this an unlikely scenario.
+ os.rmdir(path)
+ return path
+
+ @classmethod
+ def _getSourceDistribution(cls):
+ """
+ Find the sdist tarball for this project, check whether it is up to date, and return the
+ path to it.
+
+ :rtype: str
+ """
+ sdistPath = os.path.join(cls._projectRootPath(), 'dist', 'toil-%s.tar.gz' % distVersion)
+ assert os.path.isfile(
+ sdistPath), "Can't find Toil source distribution at %s. Run 'make sdist'." % sdistPath
+ excluded = set(cls._run('git', 'ls-files', '--others', '-i', '--exclude-standard',
+ capture=True,
+ cwd=cls._projectRootPath()).splitlines())
+ dirty = cls._run('find', 'src', '-type', 'f', '-newer', sdistPath,
+ capture=True,
+ cwd=cls._projectRootPath()).splitlines()
+ assert all(path.startswith('src') for path in dirty)
+ dirty = set(dirty)
+ dirty.difference_update(excluded)
+ assert not dirty, \
+ "Run 'make clean_sdist sdist'. Files newer than %s: %r" % (sdistPath, list(dirty))
+ return sdistPath
+
+ @classmethod
+ def _run(cls, command, *args, **kwargs):
+ """
+ Run a command. Convenience wrapper for subprocess.check_call and subprocess.check_output.
+
+ :param str command: The command to be run.
+
+ :param str args: Any arguments to be passed to the command.
+
+ :param Any kwargs: keyword arguments for subprocess.Popen constructor. Pass capture=True
+ to have the process' stdout returned. Pass input='some string' to feed input to the
+ process' stdin.
+
+ :rtype: None|str
+
+ :return: The output of the process' stdout if capture=True was passed, None otherwise.
+ """
+ args = list(concat(command, args))
+ log.info('Running %r', args)
+ capture = kwargs.pop('capture', False)
+ _input = kwargs.pop('input', None)
+ if capture:
+ kwargs['stdout'] = PIPE
+ if _input is not None:
+ kwargs['stdin'] = PIPE
+ popen = Popen(args, **kwargs)
+ stdout, stderr = popen.communicate(input=_input)
+ assert stderr is None
+ if popen.returncode != 0:
+ raise CalledProcessError(popen.returncode, args)
+ if capture:
+ return stdout
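+ # Illustrative usage (not part of the upstream code):
+ # cls._run('git', 'ls-files', capture=True) returns the command's stdout as a string,
+ # while cls._run('make', 'sdist') returns None and raises CalledProcessError on failure.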
+
+ def _getScriptSource(self, callable_):
+ """
+ Returns the source code of the body of the given callable as a string, dedented. This is a
+ naughty but incredibly useful trick that lets you embed user scripts as nested functions
+ and expose them to the syntax checker of your IDE.
+ """
+ return dedent('\n'.join(getsource(callable_).split('\n')[1:]))
+
+
+try:
+ # noinspection PyUnresolvedReferences
+ from _pytest.mark import MarkDecorator
+except ImportError:
+ # noinspection PyUnusedLocal
+ def _mark_test(name, test_item):
+ return test_item
+else:
+ def _mark_test(name, test_item):
+ return MarkDecorator(name)(test_item)
+
+
+def needs_aws(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if AWS is usable.
+ """
+ test_item = _mark_test('aws', test_item)
+ try:
+ # noinspection PyUnresolvedReferences
+ from boto import config
+ except ImportError:
+ return unittest.skip("Install toil with the 'aws' extra to include this test.")(test_item)
+ except:
+ raise
+ else:
+ dot_aws_credentials_path = os.path.expanduser('~/.aws/credentials')
+ boto_credentials = config.get('Credentials', 'aws_access_key_id')
+ if boto_credentials:
+ return test_item
+ if os.path.exists(dot_aws_credentials_path) or runningOnEC2():
+ # Assume that EC2 machines like the Jenkins slave that we run CI on will have IAM roles
+ return test_item
+ else:
+ return unittest.skip("Configure ~/.aws/credentials with AWS credentials to include "
+ "this test.")(test_item)
+
+
+def file_begins_with(path, prefix):
+ with open(path) as f:
+ return f.read(len(prefix)) == prefix
+
+
+def runningOnEC2():
+ hv_uuid_path = '/sys/hypervisor/uuid'
+ return os.path.exists(hv_uuid_path) and file_begins_with(hv_uuid_path, 'ec2')
+
+
+def needs_google(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if Google Storage is usable.
+ """
+ test_item = _mark_test('google', test_item)
+ try:
+ # noinspection PyUnresolvedReferences
+ from boto import config
+ except ImportError:
+ return unittest.skip(
+ "Install Toil with the 'google' extra to include this test.")(test_item)
+ else:
+ boto_credentials = config.get('Credentials', 'gs_access_key_id')
+ if boto_credentials:
+ return test_item
+ else:
+ return unittest.skip(
+ "Configure ~/.boto with Google Cloud credentials to include this test.")(test_item)
+
+
+def needs_azure(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if Azure is usable.
+ """
+ test_item = _mark_test('azure', test_item)
+ try:
+ # noinspection PyUnresolvedReferences
+ import azure.storage
+ except ImportError:
+ return unittest.skip("Install Toil with the 'azure' extra to include this test.")(test_item)
+ except:
+ raise
+ else:
+ from toil.jobStores.azureJobStore import credential_file_path
+ full_credential_file_path = os.path.expanduser(credential_file_path)
+ if not os.path.exists(full_credential_file_path):
+ return unittest.skip("Configure %s with the access key for the 'toiltest' storage "
+ "account." % credential_file_path)(test_item)
+ return test_item
+
+
+def needs_gridengine(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if GridEngine is installed.
+ """
+ test_item = _mark_test('gridengine', test_item)
+ if next(which('qsub'), None):
+ return test_item
+ else:
+ return unittest.skip("Install GridEngine to include this test.")(test_item)
+
+
+def needs_mesos(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if Mesos is installed
+ and configured.
+ """
+ test_item = _mark_test('mesos', test_item)
+ try:
+ # noinspection PyUnresolvedReferences
+ import mesos.native
+ except ImportError:
+ return unittest.skip(
+ "Install Mesos (and Toil with the 'mesos' extra) to include this test.")(test_item)
+ except:
+ raise
+ else:
+ return test_item
+
+
+def needs_parasol(test_item):
+ """
+ Use as decorator so tests are only run if Parasol is installed.
+ """
+ test_item = _mark_test('parasol', test_item)
+ if next(which('parasol'), None):
+ return test_item
+ else:
+ return unittest.skip("Install Parasol to include this test.")(test_item)
+
+
+def needs_slurm(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if Slurm is installed.
+ """
+ test_item = _mark_test('slurm', test_item)
+ if next(which('squeue'), None):
+ return test_item
+ else:
+ return unittest.skip("Install Slurm to include this test.")(test_item)
+
+
+def needs_encryption(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if PyNaCl is installed
+ and configured.
+ """
+ test_item = _mark_test('encryption', test_item)
+ try:
+ # noinspection PyUnresolvedReferences
+ import nacl
+ except ImportError:
+ return unittest.skip(
+ "Install Toil with the 'encryption' extra to include this test.")(test_item)
+ except:
+ raise
+ else:
+ return test_item
+
+
+def needs_cwl(test_item):
+ """
+ Use as a decorator before test classes or methods to only run them if CWLTool is installed
+ and configured.
+ """
+ test_item = _mark_test('cwl', test_item)
+ try:
+ # noinspection PyUnresolvedReferences
+ import cwltool
+ except ImportError:
+ return unittest.skip("Install Toil with the 'cwl' extra to include this test.")(test_item)
+ except:
+ raise
+ else:
+ return test_item
+
+
+def needs_appliance(test_item):
+ import json
+ test_item = _mark_test('appliance', test_item)
+ if next(which('docker'), None):
+ image = applianceSelf()
+ try:
+ images = check_output(['docker', 'inspect', image])
+ except CalledProcessError:
+ images = []
+ else:
+ images = {i['Id'] for i in json.loads(images) if image in i['RepoTags']}
+ if len(images) == 0:
+ return unittest.skip("Cannot find appliance image %s. Be sure to run 'make docker' "
+ "prior to running this test." % image)(test_item)
+ elif len(images) == 1:
+ return test_item
+ else:
+ assert False, 'Expected `docker inspect` to return zero or one image.'
+ else:
+ return unittest.skip('Install Docker to include this test.')(test_item)
+
+
+def experimental(test_item):
+ """
+ Use this to decorate experimental or brittle tests in order to skip them during regular builds.
+ """
+ # We'll pytest.mark_test the test as experimental but we'll also unittest.skip it via an
+ # environment variable.
+ test_item = _mark_test('experimental', test_item)
+ if less_strict_bool(os.getenv('TOIL_TEST_EXPERIMENTAL')):
+ return test_item
+ else:
+ return unittest.skip(
+ 'Set TOIL_TEST_EXPERIMENTAL="True" to include this experimental test.')(test_item)
+
+
+def integrative(test_item):
+ """
+ Use this to decorate integration tests so as to skip them during regular builds. We define
+ integration tests as A) involving other, non-Toil software components that we develop and/or
+ B) having a higher cost (time or money). Note that brittleness does not qualify a test for
+ being integrative. Neither does involvement of external services such as AWS, since that
+ would cover most of Toil's tests.
+ """
+ # We'll pytest.mark_test the test as integrative but we'll also unittest.skip it via an
+ # environment variable.
+ test_item = _mark_test('integrative', test_item)
+ if less_strict_bool(os.getenv('TOIL_TEST_INTEGRATIVE')):
+ return test_item
+ else:
+ return unittest.skip(
+ 'Set TOIL_TEST_INTEGRATIVE="True" to include this integration test.')(test_item)
+
+
+methodNamePartRegex = re.compile('^[a-zA-Z_0-9]+$')
+
+
+@contextmanager
+def timeLimit(seconds):
+ """
+ http://stackoverflow.com/a/601168
+ Use to limit the execution time of a function. Raises an exception if the execution of the
+ function takes more than the specified amount of time.
+
+ :param seconds: maximum allowable time, in seconds
+ >>> import time
+ >>> with timeLimit(5):
+ ... time.sleep(4)
+ >>> import time
+ >>> with timeLimit(5):
+ ... time.sleep(6)
+ Traceback (most recent call last):
+ ...
+ RuntimeError: Timed out
+ """
+
+ # noinspection PyUnusedLocal
+ def signal_handler(signum, frame):
+ raise RuntimeError('Timed out')
+
+ signal.signal(signal.SIGALRM, signal_handler)
+ signal.alarm(seconds)
+ try:
+ yield
+ finally:
+ signal.alarm(0)
+
+
+# FIXME: move to bd2k-python-lib
+
+
+def make_tests(generalMethod, targetClass=None, **kwargs):
+ """
+ This method dynamically generates test methods using the generalMethod as a template. Each
+ generated function is the result of a unique combination of parameters applied to the
+ generalMethod. Each of the parameters has a corresponding string that will be used to name
+ the method. These generated functions are named according to the scheme:
+ test_[generalMethodName]__[firstParameterName]_[someValueName]__[secondParameterName]_...
+
+ The arguments following the generalMethodName should be a series of one or more dictionaries
+ of the form {str : type, ...} where the key represents the name of the value. The names will
+ be used to represent the permutation of values passed for each parameter in the generalMethod.
+
+ :param generalMethod: A method that will be parametrized with values passed as kwargs. Note
+ that the generalMethod must be a regular method.
+
+ :param targetClass: This represents the class to which the generated test methods will be
+ bound. If no targetClass is specified the class of the generalMethod is assumed the
+ target.
+
+ :param kwargs: a series of dictionaries defining values, and their respective names where
+ each keyword is the name of a parameter in generalMethod.
+
+ >>> class Foo:
+ ... def has(self, num, letter):
+ ... return num, letter
+ ...
+ ... def hasOne(self, num):
+ ... return num
+
+ >>> class Bar(Foo):
+ ... pass
+
+ >>> make_tests(Foo.has, targetClass=Bar, num={'one':1, 'two':2}, letter={'a':'a', 'b':'b'})
+
+ >>> b = Bar()
+
+ >>> assert b.test_has__num_one__letter_a() == b.has(1, 'a')
+
+ >>> assert b.test_has__num_one__letter_b() == b.has(1, 'b')
+
+ >>> assert b.test_has__num_two__letter_a() == b.has(2, 'a')
+
+ >>> assert b.test_has__num_two__letter_b() == b.has(2, 'b')
+
+ >>> f = Foo()
+
+ >>> hasattr(f, 'test_has__num_one__letter_a') # should be false because Foo has no test methods
+ False
+
+ >>> make_tests(Foo.has, num={'one':1, 'two':2}, letter={'a':'a', 'b':'b'})
+
+ >>> hasattr(f, 'test_has__num_one__letter_a')
+ True
+
+ >>> assert f.test_has__num_one__letter_a() == f.has(1, 'a')
+
+ >>> assert f.test_has__num_one__letter_b() == f.has(1, 'b')
+
+ >>> assert f.test_has__num_two__letter_a() == f.has(2, 'a')
+
+ >>> assert f.test_has__num_two__letter_b() == f.has(2, 'b')
+
+ >>> make_tests(Foo.hasOne, num={'one':1, 'two':2})
+
+ >>> assert f.test_hasOne__num_one() == f.hasOne(1)
+
+ >>> assert f.test_hasOne__num_two() == f.hasOne(2)
+
+ """
+
+ def pop(d):
+ """
+ Pops an arbitrary key value pair from a given dict.
+
+ :param d: a dictionary
+
+ :return: the popped key, value tuple
+ """
+ k, v = next(iter(iteritems(d)))
+ d.pop(k)
+ return k, v
+
+ def permuteIntoLeft(left, rParamName, right):
+ """
+ Permutes values in the right dictionary into each parameter-value dict pair in the left
+ dictionary, so that the left dictionary will contain a new set of keys, each of which is
+ a combination of one of its original parameter-value names appended with some
+ parameter-value name from the right dictionary. Each original key in the left dictionary
+ is deleted once the permutation of that key with every parameter-value name from the
+ right has been added to the left dictionary.
+
+ For example if left is {'__PrmOne_ValName':{'ValName':Val}} and right is
+ {'rValName1':rVal1, 'rValName2':rVal2} then left will become
+ {'__PrmOne_ValName__rParamName_rValName1':{'ValName':Val, 'rValName1':rVal1},
+ '__PrmOne_ValName__rParamName_rValName2':{'ValName':Val, 'rValName2':rVal2}}
+
+ :param left: A dictionary pairing each paramNameValue to a nested dictionary that
+ contains each ValueName and value pair described in the outer dict's paramNameValue
+ key.
+
+ :param rParamName: The name of the parameter that each value in the right dict represents.
+
+ :param right: A dict that pairs 1 or more valueNames and values for the rParamName
+ parameter.
+ """
+ for prmValName, lDict in left.items():
+ for rValName, rVal in right.items():
+ nextPrmVal = ('__%s_%s' % (rParamName, rValName.lower()))
+ if methodNamePartRegex.match(nextPrmVal) is None:
+ raise RuntimeError("The name '%s' cannot be used in a method name" % nextPrmVal)
+ aggDict = dict(lDict)
+ aggDict[rParamName] = rVal
+ left[prmValName + nextPrmVal] = aggDict
+ left.pop(prmValName)
+
+ def insertMethodToClass():
+ """
+ Generates and inserts test methods.
+ """
+
+ def fx(self, prms=prms):
+ if prms is not None:
+ return generalMethod(self, **prms)
+ else:
+ return generalMethod(self)
+
+ setattr(targetClass, 'test_%s%s' % (generalMethod.__name__, prmNames), fx)
+
+ if len(kwargs) > 0:
+ # create first left dict
+ left = {}
+ prmName, vals = pop(kwargs)
+ for valName, val in vals.items():
+ pvName = '__%s_%s' % (prmName, valName.lower())
+ if methodNamePartRegex.match(pvName) is None:
+ raise RuntimeError("The name '%s' cannot be used in a method name" % pvName)
+ left[pvName] = {prmName: val}
+
+ # get cartesian product
+ while len(kwargs) > 0:
+ permuteIntoLeft(left, *pop(kwargs))
+
+ # set class attributes
+ targetClass = targetClass or generalMethod.im_class
+ for prmNames, prms in left.items():
+ insertMethodToClass()
+ else:
+ prms = None
+ prmNames = ""
+ insertMethodToClass()
+
+
+@contextmanager
+def tempFileContaining(content, suffix=''):
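+ """
+ A context manager that creates a temporary file containing the given content, yields the
+ path of that file, and deletes the file once the block exits.
+ """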
+ fd, path = tempfile.mkstemp(suffix=suffix)
+ try:
+ os.write(fd, content)
+ except:
+ os.close(fd)
+ raise
+ else:
+ os.close(fd)
+ yield path
+ finally:
+ os.unlink(path)
+
+
+class ApplianceTestSupport(ToilTest):
+ """
+ A Toil test that runs a user script on a minimal cluster of appliance containers,
+ i.e. one leader container and one worker container.
+ """
+
+ @contextmanager
+ def _applianceCluster(self, mounts=None, numCores=None):
+ """
+ A context manager for creating and tearing down an appliance cluster.
+
+ :param dict|None mounts: Dictionary mapping host paths to container paths. Both the leader
+ and the worker container will be started with one -v argument per dictionary entry,
+ as in -v KEY:VALUE.
+
+ Beware that if KEY is a path to a directory, its entire content will be deleted
+ when the cluster is torn down.
+
+ :param int numCores: The number of cores to be offered by the Mesos slave process running
+ in the worker container.
+
+ :rtype: (ApplianceTestSupport.Appliance, ApplianceTestSupport.Appliance)
+
+ :return: A tuple of the form `(leader, worker)` containing the Appliance instances
+ representing the respective appliance containers
+ """
+ if numCores is None:
+ numCores = multiprocessing.cpu_count()
+ # The last container to stop (and the first to start) should clean the mounts.
+ with self.LeaderThread(self, mounts, cleanMounts=True) as leader:
+ with self.WorkerThread(self, mounts, numCores) as worker:
+ yield leader, worker
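+
+ # Illustrative usage sketch (hypothetical test method; userScript is assumed to be a
+ # function defined in the test):
+ #
+ #     with self._applianceCluster(numCores=2) as (leader, worker):
+ #         leader.deployScript(path='/home', packagePath='foo.bar', script=userScript)
+ #         leader.runOnAppliance('python2.7', '-m', 'foo.bar')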
+
+ class Appliance(ExceptionalThread):
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def _getRole(self):
+ return 'leader'
+
+ @abstractmethod
+ def _containerCommand(self):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _entryPoint(self):
+ raise NotImplementedError()
+
+ # Lock is used because subprocess is NOT thread safe: http://tinyurl.com/pkp5pgq
+ lock = threading.Lock()
+
+ def __init__(self, outer, mounts, cleanMounts=False):
+ """
+ :param ApplianceTestSupport outer:
+ """
+ assert all(' ' not in v for v in itervalues(mounts)), 'No spaces allowed in mounts'
+ super(ApplianceTestSupport.Appliance, self).__init__()
+ self.outer = outer
+ self.mounts = mounts
+ self.cleanMounts = cleanMounts
+ self.containerName = str(uuid.uuid4())
+ self.popen = None
+
+ def __enter__(self):
+ with self.lock:
+ image = applianceSelf()
+ # Omitting --rm, it's unreliable, see https://github.com/docker/docker/issues/16575
+ args = list(concat('docker', 'run',
+ '--entrypoint=' + self._entryPoint(),
+ '--net=host',
+ '-i',
+ '--name=' + self.containerName,
+ ['--volume=%s:%s' % mount for mount in iteritems(self.mounts)],
+ image,
+ self._containerCommand()))
+ log.info('Running %r', args)
+ self.popen = Popen(args)
+ self.start()
+ self.__wait_running()
+ return self
+
+ # noinspection PyUnusedLocal
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ try:
+ try:
+ self.outer._run('docker', 'stop', self.containerName)
+ self.join()
+ finally:
+ if self.cleanMounts:
+ self.__cleanMounts()
+ finally:
+ self.outer._run('docker', 'rm', '-f', self.containerName)
+ return False # don't swallow exception
+
+ def __wait_running(self):
+ log.info("Waiting for %s container process to appear. "
+ "Expect to see 'Error: No such image or container'.", self._getRole())
+ while self.isAlive():
+ try:
+ running = self.outer._run('docker', 'inspect',
+ '--format={{ .State.Running }}',
+ self.containerName,
+ capture=True).strip()
+ except CalledProcessError:
+ pass
+ else:
+ if 'true' == running:
+ break
+ time.sleep(1)
+
+ def __cleanMounts(self):
+ """
+ Deletes all files in every mounted directory. Without this step, we risk leaking
+ files owned by root on the host. To avoid races, this method should be called after
+ the appliance container has stopped; otherwise the running container might still be
+ writing files.
+ """
+ # Delete all files within each mounted directory, but not the directory itself.
+ cmd = 'shopt -s dotglob && rm -rf ' + ' '.join(v + '/*'
+ for k, v in iteritems(self.mounts)
+ if os.path.isdir(k))
+ self.outer._run('docker', 'run',
+ '--rm',
+ '--entrypoint=/bin/bash',
+ applianceSelf(),
+ '-c',
+ cmd)
+
+ def tryRun(self):
+ self.popen.wait()
+ log.info('Exiting %s', self.__class__.__name__)
+
+ def runOnAppliance(self, *args, **kwargs):
+ # Check if thread is still alive. Note that ExceptionalThread.join raises the
+ # exception that occurred in the thread.
+ self.join(timeout=0)
+ # noinspection PyProtectedMember
+ self.outer._run('docker', 'exec', '-i', self.containerName, *args, **kwargs)
+
+ def writeToAppliance(self, path, contents):
+ self.runOnAppliance('tee', path, input=contents)
+
+ def deployScript(self, path, packagePath, script):
+ """
+ Deploy a Python module on the appliance.
+
+ :param path: the path (absolute or relative to the WORKDIR of the appliance container)
+ to the root of the package hierarchy where the given module should be placed.
+ The given directory should be on the Python path.
+
+ :param packagePath: the desired fully qualified module name (dotted form) of the module
+
+ :param str|callable script: the contents of the Python module. If a callable is given,
+ its source code will be extracted. This is a convenience that lets you embed
+ user scripts into test code as nested functions.
+ """
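+ # For example (illustrative): deployScript('/home', 'foo.bar', script) creates
+ # /home/foo/__init__.py and writes the module to /home/foo/bar.py.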
+ if callable(script):
+ script = self.outer._getScriptSource(script)
+ packagePath = packagePath.split('.')
+ packages, module = packagePath[:-1], packagePath[-1]
+ for package in packages:
+ path += '/' + package
+ self.runOnAppliance('mkdir', '-p', path)
+ self.writeToAppliance(path + '/__init__.py', '')
+ self.writeToAppliance(path + '/' + module + '.py', script)
+
+ class LeaderThread(Appliance):
+ def _getRole(self):
+ return 'leader'
+
+ def _entryPoint(self):
+ return 'mesos-master'
+
+ def _containerCommand(self):
+ return ['--registry=in_memory',
+ '--ip=127.0.0.1',
+ '--port=5050',
+ '--allocation_interval=500ms']
+
+ class WorkerThread(Appliance):
+ def __init__(self, outer, mounts, numCores):
+ self.numCores = numCores
+ super(ApplianceTestSupport.WorkerThread, self).__init__(outer, mounts)
+
+ def _entryPoint(self):
+ return 'mesos-slave'
+
+ def _getRole(self):
+ return 'worker'
+
+ def _containerCommand(self):
+ return ['--work_dir=/var/lib/mesos',
+ '--ip=127.0.0.1',
+ '--master=127.0.0.1:5050',
+ '--attributes=preemptable:False',
+ '--resources=cpus(*):%i' % self.numCores]
diff --git a/src/toil/toilState.py b/src/toil/toilState.py
new file mode 100644
index 0000000..06abbcc
--- /dev/null
+++ b/src/toil/toilState.py
@@ -0,0 +1,174 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import logging
+
+logger = logging.getLogger( __name__ )
+
+class ToilState( object ):
+ """
+ Represents a snapshot of the jobs in the jobStore. Used by the leader to manage the batch.
+ """
+ def __init__( self, jobStore, rootJob, jobCache=None):
+ """
+ Loads the state from the jobStore, using the rootJob
+ as the source of the job graph.
+
+ The jobCache is a map from jobStoreIDs to jobGraphs, or None. It is used to
+ speed up the building of the state.
+
+ :param toil.jobStores.abstractJobStore.AbstractJobStore jobStore:
+ :param toil.jobGraph.JobGraph rootJob:
+ """
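+ # Typical construction (illustrative): ToilState(jobStore, jobStore.loadRootJob())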
+ # This is a hash of jobs, referenced by jobStoreID, to their predecessor jobs.
+ self.successorJobStoreIDToPredecessorJobs = { }
+
+ # Hash of jobStoreIDs to counts of successors issued.
+ # There are no entries for jobs
+ # without successors in this map.
+ self.successorCounts = { }
+
+ # This is a hash of service jobs, referenced by jobStoreID, to their predecessor job
+ self.serviceJobStoreIDToPredecessorJob = { }
+
+ # Hash of jobStoreIDs to maps of services issued for the job
+ # For each job, the map is a dictionary of service jobStoreIDs
+ # to the flags used to communicate with the service
+ self.servicesIssued = { }
+
+ # Jobs that are ready to be processed
+ self.updatedJobs = set( )
+
+ # The set of totally failed jobs - this needs to be filtered at the
+ # end to remove jobs that were removed by checkpoints
+ self.totalFailedJobs = set()
+
+ # Jobs (as jobStoreIDs) with successors that have totally failed
+ self.hasFailedSuccessors = set()
+
+ # The set of successors of failed jobs as a set of jobStoreIds
+ self.failedSuccessors = set()
+
+ # Set of jobs that have multiple predecessors, of which one or more (but not all)
+ # have finished. This acts as a cache for these jobs.
+ # Stored as hash from jobStoreIDs to job graphs
+ self.jobsToBeScheduledWithMultiplePredecessors = {}
+
+ ##Algorithm to build this information
+ logger.info("(Re)building internal scheduler state")
+ self._buildToilState(rootJob, jobStore, jobCache)
+
+ def _buildToilState(self, jobGraph, jobStore, jobCache=None):
+ """
+ Traverses tree of jobs from the root jobGraph (rootJob) building the
+ ToilState class.
+
+ If jobCache is passed, it must be a dict from job ID to JobGraph
+ object. Jobs will be loaded from the cache (which can be downloaded from
+ the jobStore in a batch) instead of piecemeal when recursed into.
+ """
+
+ def getJob(jobId):
+ if jobCache is not None:
+ try:
+ return jobCache[jobId]
+ except KeyError:
+ return jobStore.load(jobId)
+ else:
+ return jobStore.load(jobId)
+
+ # If the jobGraph has a command, is a checkpoint, has services or is ready to be
+ # deleted, it is ready to be processed
+ if (jobGraph.command is not None
+ or jobGraph.checkpoint is not None
+ or len(jobGraph.services) > 0
+ or len(jobGraph.stack) == 0):
+ logger.debug('Found job to run: %s, with command: %s, with checkpoint: %s, '
+ 'with services: %s, with stack: %s', jobGraph.jobStoreID,
+ jobGraph.command is not None, jobGraph.checkpoint is not None,
+ len(jobGraph.services) > 0, len(jobGraph.stack) == 0)
+ self.updatedJobs.add((jobGraph, 0))
+
+ if jobGraph.checkpoint is not None:
+ jobGraph.command = jobGraph.checkpoint
+
+ else: # There exist successors
+ logger.debug("Adding job: %s to the state with %s successors" % (jobGraph.jobStoreID, len(jobGraph.stack[-1])))
+
+ # Record the number of successors
+ self.successorCounts[jobGraph.jobStoreID] = len(jobGraph.stack[-1])
+
+ def processSuccessorWithMultiplePredecessors(successorJobGraph):
+ # If jobGraph is not reported as complete by the successor
+ if jobGraph.jobStoreID not in successorJobGraph.predecessorsFinished:
+
+ # Update the successor's status to mark the predecessor complete
+ successorJobGraph.predecessorsFinished.add(jobGraph.jobStoreID)
+
+ # If the successor has no predecessors to finish
+ assert len(successorJobGraph.predecessorsFinished) <= successorJobGraph.predecessorNumber
+ if len(successorJobGraph.predecessorsFinished) == successorJobGraph.predecessorNumber:
+
+ # It is ready to be run, so remove it from the cache
+ self.jobsToBeScheduledWithMultiplePredecessors.pop(successorJobStoreID)
+
+ # Recursively consider the successor
+ self._buildToilState(successorJobGraph, jobStore, jobCache=jobCache)
+
+ # For each successor
+ for successorJobNode in jobGraph.stack[-1]:
+ successorJobStoreID = successorJobNode.jobStoreID
+
+ # If the successor jobGraph does not yet point back at a
+ # predecessor, we have not yet considered it
+ if successorJobStoreID not in self.successorJobStoreIDToPredecessorJobs:
+
+ # Add the job as a predecessor
+ self.successorJobStoreIDToPredecessorJobs[successorJobStoreID] = [jobGraph]
+
+ # If predecessor number > 1 then the successor has multiple predecessors
+ if successorJobNode.predecessorNumber > 1:
+
+ # We load the successor job
+ successorJobGraph = getJob(successorJobStoreID)
+
+ # We put the successor job in the cache of successor jobs with multiple predecessors
+ assert successorJobStoreID not in self.jobsToBeScheduledWithMultiplePredecessors
+ self.jobsToBeScheduledWithMultiplePredecessors[successorJobStoreID] = successorJobGraph
+
+ # Process successor
+ processSuccessorWithMultiplePredecessors(successorJobGraph)
+
+ else:
+ # The successor has only the jobGraph as a predecessor so
+ # recursively consider the successor
+ self._buildToilState(getJob(successorJobStoreID), jobStore, jobCache=jobCache)
+
+ else:
+ # We've already seen the successor
+
+ # Add the job as a predecessor
+ assert jobGraph not in self.successorJobStoreIDToPredecessorJobs[successorJobStoreID]
+ self.successorJobStoreIDToPredecessorJobs[successorJobStoreID].append(jobGraph)
+
+ # If the successor has multiple predecessors
+ if successorJobStoreID in self.jobsToBeScheduledWithMultiplePredecessors:
+
+ # Get the successor from cache
+ successorJobGraph = self.jobsToBeScheduledWithMultiplePredecessors[successorJobStoreID]
+
+ # Process successor
+ processSuccessorWithMultiplePredecessors(successorJobGraph)
\ No newline at end of file
diff --git a/src/toil/utils/__init__.py b/src/toil/utils/__init__.py
new file mode 100644
index 0000000..0bc10f7
--- /dev/null
+++ b/src/toil/utils/__init__.py
@@ -0,0 +1,27 @@
+from __future__ import absolute_import
+
+from toil import version
+import logging
+
+from toil.provisioners.aws import getCurrentAWSZone
+
+logger = logging.getLogger( __name__ )
+
+
+def addBasicProvisionerOptions(parser):
+ parser.add_argument("--version", action='version', version=version)
+ parser.add_argument('-p', "--provisioner", dest='provisioner', choices=['aws'], required=True,
+ help="The provisioner for cluster auto-scaling. Only aws is currently "
+ "supported")
+ currentZone = getCurrentAWSZone()
+ zoneString = currentZone if currentZone else 'No zone could be determined'
+ parser.add_argument('-z', '--zone', dest='zone', required=False, default=currentZone,
+ help="The AWS availability zone of the master. This parameter can also be "
+ "set via the TOIL_AWS_ZONE environment variable, or by the ec2_region_name "
+ "parameter in your .boto file, or derived from the instance metadata if "
+ "using this utility on an existing EC2 instance. "
+ "Currently: %s" % zoneString)
+ parser.add_argument("clusterName", help="The name that the cluster will be identifiable by. "
+ "Must be lowercase and may not contain the '_' "
+ "character.")
+ return parser
diff --git a/src/toil/utils/toilClean.py b/src/toil/utils/toilClean.py
new file mode 100644
index 0000000..c785d38
--- /dev/null
+++ b/src/toil/utils/toilClean.py
@@ -0,0 +1,37 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Delete the job store used by a previous Toil workflow invocation
+"""
+from __future__ import absolute_import
+import logging
+
+from toil.lib.bioio import getBasicOptionParser
+from toil.lib.bioio import parseBasicOptions
+from toil.common import Toil, jobStoreLocatorHelp, Config
+from toil.version import version
+
+logger = logging.getLogger( __name__ )
+
+def main():
+ parser = getBasicOptionParser()
+ parser.add_argument("jobStore", type=str,
+ help="The location of the job store to delete. " + jobStoreLocatorHelp)
+ parser.add_argument("--version", action='version', version=version)
+ config = Config()
+ config.setOptions(parseBasicOptions(parser))
+ logger.info("Attempting to delete the job store")
+ jobStore = Toil.getJobStore(config.jobStore)
+ jobStore.destroy()
+ logger.info("Successfully deleted the job store")
diff --git a/src/toil/utils/toilDestroyCluster.py b/src/toil/utils/toilDestroyCluster.py
new file mode 100644
index 0000000..c03ce30
--- /dev/null
+++ b/src/toil/utils/toilDestroyCluster.py
@@ -0,0 +1,32 @@
+# Copyright (C) 2015 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Terminates the specified cluster and associated resources
+"""
+import logging
+from toil.provisioners import Cluster
+from toil.lib.bioio import parseBasicOptions, setLoggingFromOptions, getBasicOptionParser
+from toil.utils import addBasicProvisionerOptions
+
+
+logger = logging.getLogger( __name__ )
+
+
+def main():
+ parser = getBasicOptionParser()
+ parser = addBasicProvisionerOptions(parser)
+ config = parseBasicOptions(parser)
+ setLoggingFromOptions(config)
+ cluster = Cluster(provisioner=config.provisioner, clusterName=config.clusterName)
+ cluster.destroyCluster()
diff --git a/src/toil/utils/toilKill.py b/src/toil/utils/toilKill.py
new file mode 100644
index 0000000..cadf244
--- /dev/null
+++ b/src/toil/utils/toilKill.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Kills any running job trees in a rogue toil.
+"""
+from __future__ import absolute_import
+import logging
+
+from toil.lib.bioio import getBasicOptionParser
+from toil.lib.bioio import parseBasicOptions
+from toil.common import Toil, jobStoreLocatorHelp, Config
+from toil.version import version
+
+logger = logging.getLogger( __name__ )
+
+def main():
+ parser = getBasicOptionParser()
+
+ parser.add_argument("jobStore", type=str,
+ help="The location of the job store used by the workflow whose jobs should "
+ "be killed." + jobStoreLocatorHelp)
+ parser.add_argument("--version", action='version', version=version)
+ options = parseBasicOptions(parser)
+ config = Config()
+ config.setOptions(options)
+ jobStore = Toil.resumeJobStore(config.jobStore)
+
+ logger.info("Starting routine to kill running jobs in the toil workflow: %s", config.jobStore)
+ ####This behaviour is now broken
+ batchSystem = Toil.createBatchSystem(jobStore.config) #This should automatically kill the existing jobs.. so we're good.
+ for jobID in batchSystem.getIssuedBatchJobIDs(): #Just in case we do it again.
+ batchSystem.killBatchJobs(jobID)
+ logger.info("All jobs SHOULD have been killed")
diff --git a/src/toil/utils/toilLaunchCluster.py b/src/toil/utils/toilLaunchCluster.py
new file mode 100644
index 0000000..e4bae64
--- /dev/null
+++ b/src/toil/utils/toilLaunchCluster.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2015 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Launches a toil leader instance with the specified provisioner
+"""
+import logging
+from toil.lib.bioio import parseBasicOptions, setLoggingFromOptions, getBasicOptionParser
+from toil.utils import addBasicProvisionerOptions
+
+logger = logging.getLogger( __name__ )
+
+
+def main():
+ parser = getBasicOptionParser()
+ parser = addBasicProvisionerOptions(parser)
+ parser.add_argument("--nodeType", dest='nodeType', required=True,
+ help="Node type for {non-|}preemptable nodes. The syntax depends on the "
+ "provisioner used. For the aws provisioner this is the name of an "
+ "EC2 instance type followed by a colon and the price in dollar to "
+ "bid for a spot instance, for example 'c3.8xlarge:0.42'.")
+ parser.add_argument("--keyPairName", dest='keyPairName', required=True,
+ help="The name of the AWS key pair to include on the instance")
+ config = parseBasicOptions(parser)
+ setLoggingFromOptions(config)
+ spotBid = None
+ if config.provisioner == 'aws':
+ logger.info('Using aws provisioner.')
+ try:
+ from toil.provisioners.aws.awsProvisioner import AWSProvisioner
+ except ImportError:
+ raise RuntimeError('The aws extra must be installed to use this provisioner')
+ provisioner = AWSProvisioner
+ parsedBid = config.nodeType.split(':', 1)
+ if len(config.nodeType) != len(parsedBid[0]):
+ # there is a bid
+ spotBid = float(parsedBid[1])
+ config.nodeType = parsedBid[0]
+ else:
+ assert False
+
+ provisioner.launchCluster(instanceType=config.nodeType, clusterName=config.clusterName,
+ keyName=config.keyPairName, spotBid=spotBid)
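+
+# Illustrative invocation (hypothetical cluster and key pair names):
+#
+#     toil launch-cluster -p aws -z us-west-2a my-cluster \
+#         --nodeType c3.8xlarge:0.42 --keyPairName my-key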
diff --git a/src/toil/utils/toilMain.py b/src/toil/utils/toilMain.py
new file mode 100755
index 0000000..e41c01c
--- /dev/null
+++ b/src/toil/utils/toilMain.py
@@ -0,0 +1,52 @@
+from __future__ import absolute_import, print_function
+from toil.version import version
+import pkg_resources
+import os
+import sys
+
+# Python 3 compatibility imports
+from six import iteritems, iterkeys
+
+def main():
+ modules = loadModules()
+ try:
+ command = sys.argv[1]
+ except IndexError:
+ printHelp(modules)
+ else:
+ if command == '--help':
+ printHelp(modules)
+ elif command == '--version':
+ try:
+ print(pkg_resources.get_distribution('toil').version)
+ except:
+ print("Version gathered from toil.version: "+version)
+ else:
+ try:
+ module = modules[command]
+ except KeyError:
+ print("Unknown option '%s'. "
+ "Pass --help to display usage information.\n" % command, file=sys.stderr)
+ sys.exit(1)
+ else:
+ del sys.argv[1]
+ module.main()
+
+
+def loadModules():
+ # noinspection PyUnresolvedReferences
+ from toil.utils import toilKill, toilStats, toilStatus, toilClean, toilLaunchCluster, toilDestroyCluster, toilSSHCluster, toilRsyncCluster
+ commandMapping = {name[4:].lower(): module for name, module in iteritems(locals())}
+ commandMapping = {name[:-7]+'-'+name[-7:] if name.endswith('cluster') else name: module for name, module in iteritems(commandMapping)}
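+ # For example (illustrative): 'toilKill' maps to the command 'kill', and
+ # 'toilLaunchCluster' maps to 'launch-cluster'.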
+ return commandMapping
+
+def printHelp(modules):
+ usage = ("\n"
+ "Usage: {name} COMMAND ...\n"
+ " {name} --help\n"
+ " {name} COMMAND --help\n\n"
+ "where COMMAND is one of the following:\n\n{descriptions}\n\n")
+ print(usage.format(
+ name=os.path.basename(sys.argv[0]),
+ commands='|'.join(iterkeys(modules)),
+ descriptions='\n'.join("%s - %s" % (n, m.__doc__.strip()) for n, m in iteritems(modules))))
diff --git a/src/toil/utils/toilRsyncCluster.py b/src/toil/utils/toilRsyncCluster.py
new file mode 100644
index 0000000..a0cb2c0
--- /dev/null
+++ b/src/toil/utils/toilRsyncCluster.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2015 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Rsyncs into the toil appliance container running on the leader of the cluster
+"""
+import argparse
+import logging
+
+from toil.lib.bioio import parseBasicOptions, setLoggingFromOptions, getBasicOptionParser
+from toil.provisioners import Cluster
+from toil.utils import addBasicProvisionerOptions
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+ parser = getBasicOptionParser()
+ parser = addBasicProvisionerOptions(parser)
+ parser.add_argument("args", nargs=argparse.REMAINDER, help="Arguments to pass to"
+ " `rsync`. Takes any arguments that rsync accepts. Specify the"
+ " remote with a colon. For example, to upload `example.py`,"
+ " specify `toil rsync-cluster -p aws test-cluster example.py :`."
+ "\nOr, to download a file from the remote: `toil rsync-cluster"
+ " -p aws test-cluster :example.py .`")
+ config = parseBasicOptions(parser)
+ setLoggingFromOptions(config)
+ cluster = Cluster(provisioner=config.provisioner, clusterName=config.clusterName)
+ cluster.rsyncCluster(args=config.args)
diff --git a/src/toil/utils/toilSSHCluster.py b/src/toil/utils/toilSSHCluster.py
new file mode 100644
index 0000000..f98e7c2
--- /dev/null
+++ b/src/toil/utils/toilSSHCluster.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2015 UCSC Computational Genomics Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+SSHs into the toil appliance container running on the leader of the cluster
+"""
+import argparse
+import logging
+from toil.provisioners import Cluster
+from toil.lib.bioio import parseBasicOptions, setLoggingFromOptions, getBasicOptionParser
+from toil.utils import addBasicProvisionerOptions
+
+logger = logging.getLogger( __name__ )
+
+
+def main():
+ parser = getBasicOptionParser()
+ parser = addBasicProvisionerOptions(parser)
+ parser.add_argument('args', nargs=argparse.REMAINDER)
+ config = parseBasicOptions(parser)
+ setLoggingFromOptions(config)
+ cluster = Cluster(provisioner=config.provisioner, clusterName=config.clusterName)
+ cluster.sshCluster(args=config.args)
diff --git a/src/toil/utils/toilStats.py b/src/toil/utils/toilStats.py
new file mode 100644
index 0000000..61aa3c2
--- /dev/null
+++ b/src/toil/utils/toilStats.py
@@ -0,0 +1,605 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Reports statistical data about a given Toil workflow.
+"""
+
+from __future__ import absolute_import, print_function
+from functools import partial
+import logging
+import json
+from toil.lib.bioio import getBasicOptionParser
+from toil.lib.bioio import parseBasicOptions
+from toil.common import Toil, jobStoreLocatorHelp, Config
+from toil.version import version
+from bd2k.util.expando import Expando
+
+logger = logging.getLogger( __name__ )
+
+
+class ColumnWidths(object):
+ """
+ Convenience object that stores the width of columns for printing. Helps make things pretty.
+ """
+ def __init__(self):
+ self.categories = ["time", "clock", "wait", "memory"]
+ self.fields_count = ["count", "min", "med", "ave", "max", "total"]
+ self.fields = ["min", "med", "ave", "max", "total"]
+ self.data = {}
+ for category in self.categories:
+ for field in self.fields_count:
+ self.setWidth(category, field, 8)
+ def title(self, category):
+ """ Return the total printed length of this category item.
+ """
+ return sum(
+ map(lambda x: self.getWidth(category, x), self.fields))
+ def getWidth(self, category, field):
+ category = category.lower()
+ return self.data["%s_%s" % (category, field)]
+ def setWidth(self, category, field, width):
+ category = category.lower()
+ self.data["%s_%s" % (category, field)] = width
+ def report(self):
+ for c in self.categories:
+ for f in self.fields:
+ print('%s %s %d' % (c, f, self.getWidth(c, f)))
+
+def initializeOptions(parser):
+ parser.add_argument("jobStore", type=str,
+ help="The location of the job store used by the workflow for which "
+ "statistics should be reported. " + jobStoreLocatorHelp)
+ parser.add_argument("--outputFile", dest="outputFile", default=None,
+ help="File in which to write results")
+ parser.add_argument("--raw", action="store_true", default=False,
+ help="output the raw json data.")
+ parser.add_argument("--pretty", "--human", action="store_true", default=False,
+ help=("if not raw, prettify the numbers to be "
+ "human readable."))
+ parser.add_argument("--categories",
+ help=("comma separated list from [time, clock, wait, "
+ "memory]"))
+ parser.add_argument("--sortCategory", default="time",
+ help=("how to sort Job list. may be from [alpha, "
+ "time, clock, wait, memory, count]. "
+ "default=%(default)s"))
+ parser.add_argument("--sortField", default="med",
+ help=("how to sort Job list. may be from [min, "
+ "med, ave, max, total]. "
+ "default=%(default)s"))
+ parser.add_argument("--sortReverse", "--reverseSort", default=False,
+ action="store_true",
+ help="reverse sort order.")
+ parser.add_argument("--version", action='version', version=version)
+
+def checkOptions(options, parser):
+ """ Check options, throw parser.error() if something goes wrong
+ """
+ logger.info("Parsed arguments")
+ logger.info("Checking if we have files for toil")
+ if options.jobStore is None:
+ parser.error("Specify --jobStore")
+ defaultCategories = ["time", "clock", "wait", "memory"]
+ if options.categories is None:
+ options.categories = defaultCategories
+ else:
+ options.categories = map(lambda x: x.lower(),
+ options.categories.split(","))
+ for c in options.categories:
+ if c not in defaultCategories:
+ parser.error("Unknown category %s. Must be from %s"
+ % (c, str(defaultCategories)))
+ extraSort = ["count", "alpha"]
+ if options.sortCategory is not None:
+ if (options.sortCategory not in defaultCategories and
+ options.sortCategory not in extraSort):
+ parser.error("Unknown --sortCategory %s. Must be from %s"
+ % (options.sortCategory,
+ str(defaultCategories + extraSort)))
+ sortFields = ["min", "med", "ave", "max", "total"]
+ if options.sortField is not None:
+ if (options.sortField not in sortFields):
+ parser.error("Unknown --sortField %s. Must be from %s"
+ % (options.sortField, str(sortFields)))
+ logger.info("Checked arguments")
+
+def printJson(elem):
+ """ Return a JSON formatted string
+ """
+ prettyString = json.dumps(elem, indent=4, separators=(',',': '))
+ return prettyString
+
+def padStr(s, field=None):
+ """ Pad the beginning of a string with spaces, if necessary.
+ """
+ if field is None:
+ return s
+ else:
+ if len(s) >= field:
+ return s
+ else:
+ return " " * (field - len(s)) + s
+
+def prettyMemory(k, field=None, isBytes=False):
+ """ Given input k as kilobytes, return a nicely formatted string.
+ """
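+ # For example (illustrative): prettyMemory(512) -> '512K', prettyMemory(2048) -> '2.0M',
+ # prettyMemory(3 * 1024 ** 2) -> '3.0G'.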
+ if isBytes:
+ k /= 1024
+ if k < 1024:
+ return padStr("%gK" % k, field)
+ if k < (1024 * 1024):
+ return padStr("%.1fM" % (k / 1024.0), field)
+ if k < (1024 * 1024 * 1024):
+ return padStr("%.1fG" % (k / 1024.0 / 1024.0), field)
+ if k < (1024 * 1024 * 1024 * 1024):
+ return padStr("%.1fT" % (k / 1024.0 / 1024.0 / 1024.0), field)
+ if k < (1024 * 1024 * 1024 * 1024 * 1024):
+ return padStr("%.1fP" % (k / 1024.0 / 1024.0 / 1024.0 / 1024.0), field)
+
+def prettyTime(t, field=None):
+ """ Given input t as seconds, return a nicely formatted string.
+ """
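+ # For example (illustrative): prettyTime(45) -> '45s', prettyTime(150) -> '2m30s',
+ # prettyTime(3 * 60 * 60) -> '3h0m0s'.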
+ from math import floor
+ pluralDict = {True: "s", False: ""}
+ if t < 120:
+ return padStr("%ds" % t, field)
+ if t < 120 * 60:
+ m = floor(t / 60.)
+ s = t % 60
+ return padStr("%dm%ds" % (m, s), field)
+ if t < 25 * 60 * 60:
+ h = floor(t / 60. / 60.)
+ m = floor((t - (h * 60. * 60.)) / 60.)
+ s = t % 60
+ return padStr("%dh%gm%ds" % (h, m, s), field)
+ if t < 7 * 24 * 60 * 60:
+ d = floor(t / 24. / 60. / 60.)
+ h = floor((t - (d * 24. * 60. * 60.)) / 60. / 60.)
+ m = floor((t
+ - (d * 24. * 60. * 60.)
+ - (h * 60. * 60.))
+ / 60.)
+ s = t % 60
+ dPlural = pluralDict[d > 1]
+ return padStr("%dday%s%dh%dm%ds" % (d, dPlural, h, m, s), field)
+ w = floor(t / 7. / 24. / 60. / 60.)
+ d = floor((t - (w * 7 * 24 * 60 * 60)) / 24. / 60. / 60.)
+ h = floor((t
+ - (w * 7. * 24. * 60. * 60.)
+ - (d * 24. * 60. * 60.))
+ / 60. / 60.)
+ m = floor((t
+ - (w * 7. * 24. * 60. * 60.)
+ - (d * 24. * 60. * 60.)
+ - (h * 60. * 60.))
+ / 60.)
+ s = t % 60
+ wPlural = pluralDict[w > 1]
+ dPlural = pluralDict[d > 1]
+ return padStr("%dweek%s%dday%s%dh%dm%ds" % (w, wPlural, d,
+ dPlural, h, m, s), field)
+
+def reportTime(t, options, field=None):
+ """ Given t seconds, report back the correct format as string.
+ """
+ if options.pretty:
+ return prettyTime(t, field=field)
+ else:
+ if field is not None:
+ return "%*.2f" % (field, t)
+ else:
+ return "%.2f" % t
+
+def reportMemory(k, options, field=None, isBytes=False):
+ """ Given k kilobytes, report back the correct format as string.
+ """
+ if options.pretty:
+ return prettyMemory(int(k), field=field, isBytes=isBytes)
+ else:
+ if isBytes:
+ k /= 1024.
+ if field is not None:
+ return "%*dK" % (field - 1, k) # -1 for the "K"
+ else:
+ return "%dK" % int(k)
+
+def reportNumber(n, options, field=None):
+ """ Given n an integer, report back the correct format as string.
+ """
+ if field is not None:
+ return "%*g" % (field, n)
+ else:
+ return "%g" % n
+
+def refineData(root, options):
+ """ walk down from the root and gather up the important bits.
+ """
+ worker = root.worker
+ job = root.jobs
+ jobTypesTree = root.job_types
+ jobTypes = []
+ for childName in jobTypesTree:
+ jobTypes.append(jobTypesTree[childName])
+ return root, worker, job, jobTypes
+
+def sprintTag(key, tag, options, columnWidths=None):
+ """ Generate a pretty-print ready string from a JTTag().
+ """
+ if columnWidths is None:
+ columnWidths = ColumnWidths()
+ header = " %7s " % decorateTitle("Count", options)
+ sub_header = " %7s " % "n"
+ tag_str = " %s" % reportNumber(tag.total_number, options, field=7)
+ out_str = ""
+ if key == "job":
+ out_str += " %-12s | %7s%7s%7s%7s\n" % ("Worker Jobs", "min",
+ "med", "ave", "max")
+ worker_str = "%s| " % (" " * 14)
+ for t in [tag.min_number_per_worker, tag.median_number_per_worker,
+ tag.average_number_per_worker, tag.max_number_per_worker]:
+ worker_str += reportNumber(t, options, field=7)
+ out_str += worker_str + "\n"
+ if "time" in options.categories:
+ header += "| %*s " % (columnWidths.title("time"),
+ decorateTitle("Time", options))
+ sub_header += decorateSubHeader("Time", columnWidths, options)
+ tag_str += " | "
+ for t, width in [
+ (tag.min_time, columnWidths.getWidth("time", "min")),
+ (tag.median_time, columnWidths.getWidth("time", "med")),
+ (tag.average_time, columnWidths.getWidth("time", "ave")),
+ (tag.max_time, columnWidths.getWidth("time", "max")),
+ (tag.total_time, columnWidths.getWidth("time", "total")),
+ ]:
+ tag_str += reportTime(t, options, field=width)
+ if "clock" in options.categories:
+ header += "| %*s " % (columnWidths.title("clock"),
+ decorateTitle("Clock", options))
+ sub_header += decorateSubHeader("Clock", columnWidths, options)
+ tag_str += " | "
+ for t, width in [
+ (tag.min_clock, columnWidths.getWidth("clock", "min")),
+ (tag.median_clock, columnWidths.getWidth("clock", "med")),
+ (tag.average_clock, columnWidths.getWidth("clock", "ave")),
+ (tag.max_clock, columnWidths.getWidth("clock", "max")),
+ (tag.total_clock, columnWidths.getWidth("clock", "total")),
+ ]:
+ tag_str += reportTime(t, options, field=width)
+ if "wait" in options.categories:
+ header += "| %*s " % (columnWidths.title("wait"),
+ decorateTitle("Wait", options))
+ sub_header += decorateSubHeader("Wait", columnWidths, options)
+ tag_str += " | "
+ for t, width in [
+ (tag.min_wait, columnWidths.getWidth("wait", "min")),
+ (tag.median_wait, columnWidths.getWidth("wait", "med")),
+ (tag.average_wait, columnWidths.getWidth("wait", "ave")),
+ (tag.max_wait, columnWidths.getWidth("wait", "max")),
+ (tag.total_wait, columnWidths.getWidth("wait", "total")),
+ ]:
+ tag_str += reportTime(t, options, field=width)
+ if "memory" in options.categories:
+ header += "| %*s " % (columnWidths.title("memory"),
+ decorateTitle("Memory", options))
+ sub_header += decorateSubHeader("Memory", columnWidths, options)
+ tag_str += " | "
+ for t, width in [
+ (tag.min_memory, columnWidths.getWidth("memory", "min")),
+ (tag.median_memory, columnWidths.getWidth("memory", "med")),
+ (tag.average_memory, columnWidths.getWidth("memory", "ave")),
+ (tag.max_memory, columnWidths.getWidth("memory", "max")),
+ (tag.total_memory, columnWidths.getWidth("memory", "total")),
+ ]:
+ tag_str += reportMemory(t, options, field=width, isBytes=True)
+ out_str += header + "\n"
+ out_str += sub_header + "\n"
+ out_str += tag_str + "\n"
+ return out_str
+
+def decorateTitle(title, options):
+ """ Add a marker to TITLE if the TITLE is sorted on.
+ """
+ if title.lower() == options.sortCategory:
+ return "%s*" % title
+ else:
+ return title
+
+def decorateSubHeader(title, columnWidths, options):
+ """ Add a marker to the correct field if the TITLE is sorted on.
+ """
+ title = title.lower()
+ if title != options.sortCategory:
+ s = "| %*s%*s%*s%*s%*s " % (
+ columnWidths.getWidth(title, "min"), "min",
+ columnWidths.getWidth(title, "med"), "med",
+ columnWidths.getWidth(title, "ave"), "ave",
+ columnWidths.getWidth(title, "max"), "max",
+ columnWidths.getWidth(title, "total"), "total")
+ return s
+ else:
+ s = "| "
+ for field, width in [("min", columnWidths.getWidth(title, "min")),
+ ("med", columnWidths.getWidth(title, "med")),
+ ("ave", columnWidths.getWidth(title, "ave")),
+ ("max", columnWidths.getWidth(title, "max")),
+ ("total", columnWidths.getWidth(title, "total"))]:
+ if options.sortField == field:
+ s += "%*s*" % (width - 1, field)
+ else:
+ s += "%*s" % (width, field)
+ s += " "
+ return s
+
+def get(tree, name):
+ """ Return a float value attribute NAME from TREE.
+ """
+ if name in tree:
+ value = tree[name]
+ else:
+ return float("nan")
+ try:
+ a = float(value)
+ except ValueError:
+ a = float("nan")
+ return a
+
+def sortJobs(jobTypes, options):
+ """ Return the jobTypes, sorted.
+ """
+ longforms = {"med": "median",
+ "ave": "average",
+ "min": "min",
+ "total": "total",
+ "max": "max",}
+ sortField = longforms[options.sortField]
+ if (options.sortCategory == "time" or
+ options.sortCategory == "clock" or
+ options.sortCategory == "wait" or
+ options.sortCategory == "memory"
+ ):
+ return sorted(
+ jobTypes,
+ key=lambda tag: getattr(tag, "%s_%s"
+ % (sortField, options.sortCategory)),
+ reverse=options.sortReverse)
+ elif options.sortCategory == "alpha":
+ return sorted(
+ jobTypes, key=lambda tag: tag.name,
+ reverse=options.sortReverse)
+ elif options.sortCategory == "count":
+ return sorted(jobTypes, key=lambda tag: tag.total_number,
+ reverse=options.sortReverse)
+
+def reportPrettyData(root, worker, job, job_types, options):
+ """ print the important bits out.
+ """
+ out_str = "Batch System: %s\n" % root.batch_system
+ out_str += ("Default Cores: %s Default Memory: %s\n"
+ "Max Cores: %s Max Threads: %s\n" % (
+ reportNumber(get(root, "default_cores"), options),
+ reportMemory(get(root, "default_memory"), options, isBytes=True),
+ reportNumber(get(root, "max_cores"), options),
+ reportNumber(get(root, "max_threads"), options),
+ ))
+ out_str += ("Total Clock: %s Total Runtime: %s\n" % (
+ reportTime(get(root, "total_clock"), options),
+ reportTime(get(root, "total_run_time"), options),
+ ))
+ job_types = sortJobs(job_types, options)
+ columnWidths = computeColumnWidths(job_types, worker, job, options)
+ out_str += "Worker\n"
+ out_str += sprintTag("worker", worker, options, columnWidths=columnWidths)
+ out_str += "Job\n"
+ out_str += sprintTag("job", job, options, columnWidths=columnWidths)
+ for t in job_types:
+ out_str += " %s\n" % t.name
+ out_str += sprintTag(t.name, t, options, columnWidths=columnWidths)
+ return out_str
+
+def computeColumnWidths(job_types, worker, job, options):
+ """ Return a ColumnWidths() object with the correct max widths.
+ """
+ cw = ColumnWidths()
+ for t in job_types:
+ updateColumnWidths(t, cw, options)
+ updateColumnWidths(worker, cw, options)
+ updateColumnWidths(job, cw, options)
+ return cw
+
+def updateColumnWidths(tag, cw, options):
+ """ Update the column width attributes for this tag's fields.
+ """
+ longforms = {"med": "median",
+ "ave": "average",
+ "min": "min",
+ "total": "total",
+ "max": "max",}
+ for category in ["time", "clock", "wait", "memory"]:
+ if category in options.categories:
+ for field in ["min", "med", "ave", "max", "total"]:
+ t = getattr(tag, "%s_%s" % (longforms[field], category))
+ if category in ["time", "clock", "wait"]:
+ s = reportTime(t, options,
+ field=cw.getWidth(category, field)).strip()
+ else:
+ s = reportMemory(t, options,
+ field=cw.getWidth(category, field), isBytes=True).strip()
+ if len(s) >= cw.getWidth(category, field):
+ # this string is larger than max, width must be increased
+ cw.setWidth(category, field, len(s) + 1)
+
+def buildElement(element, items, itemName):
+ """ Create an element for output.
+ """
+ def assertNonnegative(i,name):
+ if i < 0:
+ raise RuntimeError("Negative value %s reported for %s" %(i,name) )
+ else:
+ return float(i)
+
+ itemTimes = []
+ itemClocks = []
+ itemMemory = []
+ for item in items:
+ itemTimes.append(assertNonnegative(item["time"], "time"))
+ itemClocks.append(assertNonnegative(item["clock"], "clock"))
+ itemMemory.append(assertNonnegative(item["memory"], "memory"))
+ assert len(itemClocks) == len(itemTimes) == len(itemMemory)
+
+ itemWaits=[]
+ for index in range(0,len(itemTimes)):
+ itemWaits.append(itemClocks[index]-itemTimes[index])
+
+ itemWaits.sort()
+ itemTimes.sort()
+ itemClocks.sort()
+ itemMemory.sort()
+
+ if len(itemTimes) == 0:
+ itemTimes.append(0)
+ itemClocks.append(0)
+ itemWaits.append(0)
+ itemMemory.append(0)
+
+ element[itemName]=Expando(
+ total_number=float(len(items)),
+ total_time=float(sum(itemTimes)),
+ median_time=float(itemTimes[len(itemTimes)/2]),
+ average_time=float(sum(itemTimes)/len(itemTimes)),
+ min_time=float(min(itemTimes)),
+ max_time=float(max(itemTimes)),
+ total_clock=float(sum(itemClocks)),
+ median_clock=float(itemClocks[len(itemClocks)/2]),
+ average_clock=float(sum(itemClocks)/len(itemClocks)),
+ min_clock=float(min(itemClocks)),
+ max_clock=float(max(itemClocks)),
+ total_wait=float(sum(itemWaits)),
+ median_wait=float(itemWaits[len(itemWaits)/2]),
+ average_wait=float(sum(itemWaits)/len(itemWaits)),
+ min_wait=float(min(itemWaits)),
+ max_wait=float(max(itemWaits)),
+ total_memory=float(sum(itemMemory)),
+ median_memory=float(itemMemory[len(itemMemory)/2]),
+ average_memory=float(sum(itemMemory)/len(itemMemory)),
+ min_memory=float(min(itemMemory)),
+ max_memory=float(max(itemMemory)),
+ name=itemName
+ )
+ return element[itemName]
+
+def createSummary(element, containingItems, containingItemName, getFn):
+ itemCounts = [len(getFn(containingItem)) for
+ containingItem in containingItems]
+ itemCounts.sort()
+ if len(itemCounts) == 0:
+ itemCounts.append(0)
+ element["median_number_per_%s" % containingItemName] = itemCounts[len(itemCounts) / 2]
+ element["average_number_per_%s" % containingItemName] = float(sum(itemCounts)) / len(itemCounts)
+ element["min_number_per_%s" % containingItemName] = min(itemCounts)
+ element["max_number_per_%s" % containingItemName] = max(itemCounts)
+
+
+def getStats(jobStore):
+ """ Collect and return the stats and config data.
+ """
+ def aggregateStats(fileHandle,aggregateObject):
+ try:
+ stats = json.load(fileHandle, object_hook=Expando)
+ for key in stats.keys():
+ if key in aggregateObject:
+ aggregateObject[key].append(stats[key])
+ else:
+ aggregateObject[key]=[stats[key]]
+ except ValueError:
+ logger.critical("File %s contains corrupted json. Skipping file." % fileHandle)
+ pass # The file is corrupted.
+
+ aggregateObject = Expando()
+ callBack = partial(aggregateStats, aggregateObject=aggregateObject)
+ jobStore.readStatsAndLogging(callBack, readAll=True)
+ return aggregateObject
+
+
+def processData(config, stats):
+ ##########################################
+ # Collate the stats and report
+ ##########################################
+ if stats.get("total_time", None) is None: # Hack to allow unfinished toils.
+ stats.total_time = {"total_time": "0.0", "total_clock": "0.0"}
+ else:
+ stats.total_time = sum([float(number) for number in stats.total_time])
+ stats.total_clock = sum([float(number) for number in stats.total_clock])
+
+ collatedStatsTag = Expando(total_run_time=stats.total_time,
+ total_clock=stats.total_clock,
+ batch_system=config.batchSystem,
+ default_memory=str(config.defaultMemory),
+ default_cores=str(config.defaultCores),
+ max_cores=str(config.maxCores)
+ )
+
+ # Add worker info
+ worker = filter(None, stats.workers)
+ jobs = filter(None, stats.jobs)
+ jobs = [item for sublist in jobs for item in sublist]
+
+ def fn4(job):
+ try:
+ return list(jobs)
+ except TypeError:
+ return []
+
+ buildElement(collatedStatsTag, worker, "worker")
+ createSummary(buildElement(collatedStatsTag, jobs, "jobs"),
+ stats.workers, "worker", fn4)
+ # Get info for each job
+ jobNames = set()
+ for job in jobs:
+ jobNames.add(job.class_name)
+ jobTypesTag = Expando()
+ collatedStatsTag.job_types = jobTypesTag
+ for jobName in jobNames:
+ jobTypes = [ job for job in jobs if job.class_name == jobName ]
+ buildElement(jobTypesTag, jobTypes, jobName)
+ collatedStatsTag.name = "collatedStatsTag"
+ return collatedStatsTag
+
+def reportData(tree, options):
+ # Now dump it all out to file
+ if options.raw:
+ out_str = printJson(tree)
+ else:
+ root, worker, job, job_types = refineData(tree, options)
+ out_str = reportPrettyData(root, worker, job, job_types, options)
+ if options.outputFile is not None:
+ fileHandle = open(options.outputFile, "w")
+ fileHandle.write(out_str)
+ fileHandle.close()
+ # Now dump onto the screen
+ print(out_str)
+
+def main():
+ """ Reports stats on the workflow; use with the --stats option to toil.
+ """
+ parser = getBasicOptionParser()
+ initializeOptions(parser)
+ options = parseBasicOptions(parser)
+ checkOptions(options, parser)
+ config = Config()
+ config.setOptions(options)
+ jobStore = Toil.resumeJobStore(config.jobStore)
+ stats = getStats(jobStore)
+ collatedStatsTag = processData(jobStore.config, stats)
+ reportData(collatedStatsTag, options)
diff --git a/src/toil/utils/toilStatus.py b/src/toil/utils/toilStatus.py
new file mode 100644
index 0000000..20f6885
--- /dev/null
+++ b/src/toil/utils/toilStatus.py
@@ -0,0 +1,112 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Reports the state of a Toil workflow
+"""
+from __future__ import absolute_import
+from __future__ import print_function
+
+import logging
+import sys
+
+from toil.lib.bioio import logStream
+from toil.lib.bioio import getBasicOptionParser
+from toil.lib.bioio import parseBasicOptions
+from toil.common import Toil, jobStoreLocatorHelp, Config
+from toil.leader import ToilState
+from toil.job import JobException
+from toil.version import version
+
+logger = logging.getLogger( __name__ )
+
+def main():
+ """Reports the state of the toil.
+ """
+
+ ##########################################
+ #Construct the arguments.
+ ##########################################
+
+ parser = getBasicOptionParser()
+
+ parser.add_argument("jobStore", type=str,
+ help="The location of a job store that holds the information about the "
+ "workflow whose status is to be reported on." + jobStoreLocatorHelp)
+
+ parser.add_argument("--verbose", dest="verbose", action="store_true",
+ help="Print loads of information, particularly all the log files of \
+ jobs that failed. default=%(default)s",
+ default=False)
+
+ parser.add_argument("--failIfNotComplete", dest="failIfNotComplete", action="store_true",
+ help="Return exit value of 1 if toil jobs not all completed. default=%(default)s",
+ default=False)
+ parser.add_argument("--version", action='version', version=version)
+ options = parseBasicOptions(parser)
+ logger.info("Parsed arguments")
+
+ if len(sys.argv) == 1:
+ parser.print_help()
+ sys.exit(0)
+
+ ##########################################
+ #Do some checks.
+ ##########################################
+
+ logger.info("Checking if we have files for Toil")
+ assert options.jobStore is not None
+ config = Config()
+ config.setOptions(options)
+ ##########################################
+ #Survey the status of the job and report.
+ ##########################################
+
+ jobStore = Toil.resumeJobStore(config.jobStore)
+ try:
+ rootJob = jobStore.loadRootJob()
+ except JobException:
+ print('The root job of the job store is absent, the workflow completed successfully.',
+ file=sys.stderr)
+ sys.exit(0)
+
+ toilState = ToilState(jobStore, rootJob )
+
+ # The first element of the toilState.updatedJobs tuple is the jobGraph we want to inspect
+ totalJobs = set(toilState.successorCounts.keys()) | \
+ {jobTuple[0] for jobTuple in toilState.updatedJobs}
+
+    failedJobs = [job for job in totalJobs if job.remainingRetryCount == 0]
+
+ print('There are %i active jobs, %i parent jobs with children, and %i totally failed jobs '
+ 'currently in %s.' % (len(toilState.updatedJobs), len(toilState.successorCounts),
+ len(failedJobs), config.jobStore), file=sys.stderr)
+
+    if options.verbose: #Verbose currently means printing the log files of failed jobs.
+ for job in failedJobs:
+ if job.logJobStoreFileID is not None:
+ with job.getLogFileHandle(jobStore) as logFileHandle:
+ logStream(logFileHandle, job.jobStoreID, logger.warn)
+ else:
+ print('Log file for job %s is absent.' % job.jobStoreID, file=sys.stderr)
+ if len(failedJobs) == 0:
+ print('There are no failed jobs to report.', file=sys.stderr)
+
+ if (len(toilState.updatedJobs) + len(toilState.successorCounts)) != 0 and \
+ options.failIfNotComplete:
+ sys.exit(1)
+
+def _test():
+ import doctest
+ return doctest.testmod()
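
For reference, the same checks can be driven programmatically with the calls
main() uses; a minimal sketch, with a placeholder job store locator:

    from toil.common import Toil
    from toil.job import JobException
    from toil.leader import ToilState

    jobStore = Toil.resumeJobStore('file:/tmp/my-jobstore')  # placeholder
    try:
        rootJob = jobStore.loadRootJob()
    except JobException:
        print('Root job is gone; the workflow already completed.')
    else:
        state = ToilState(jobStore, rootJob)
        active = {jobTuple[0] for jobTuple in state.updatedJobs}
        failed = [j for j in active | set(state.successorCounts)
                  if j.remainingRetryCount == 0]
        print('%i active jobs, %i failed jobs' % (len(active), len(failed)))
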
diff --git a/src/toil/version.py b/src/toil/version.py
new file mode 100644
index 0000000..88a2117
--- /dev/null
+++ b/src/toil/version.py
@@ -0,0 +1,13 @@
+dockerShortTag = '3.5.0a1.dev321-6b22036'
+baseVersion = '3.5.0a1'
+dockerTag = '3.5.0a1.dev321-6b22036e1bb4227c6d15f2aeda126dfb5cfab716'
+dockerName = 'toil'
+buildNumber = '321'
+cgcloudVersion = '1.6.0a1.dev393'
+version = '3.5.0a1.dev321-6b22036e1bb4227c6d15f2aeda126dfb5cfab716'
+dirty = False
+shortVersion = '3.5.0a1.dev321-6b22036'
+currentCommit = '6b22036e1bb4227c6d15f2aeda126dfb5cfab716'
+dockerMinimalTag = '3.5.0a1.dev321'
+distVersion = '3.5.0a1.dev321'
+dockerRegistry = 'quay.io/ucsc_cgl'
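
The fields above are mechanically related; for the constants in this file the
following hold (a quick consistency check using nothing beyond what is defined
above):

    assert distVersion == baseVersion + '.dev' + buildNumber
    assert version == distVersion + '-' + currentCommit
    assert dockerTag == distVersion + '-' + currentCommit
    assert dockerShortTag == distVersion + '-' + currentCommit[:7]
    assert dockerMinimalTag == distVersion
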
diff --git a/src/toil/worker.py b/src/toil/worker.py
new file mode 100644
index 0000000..94e1c60
--- /dev/null
+++ b/src/toil/worker.py
@@ -0,0 +1,560 @@
+# Copyright (C) 2015-2016 Regents of the University of California
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, print_function
+import os
+import sys
+import copy
+import random
+import json
+
+import tempfile
+import traceback
+import time
+import socket
+import logging
+import shutil
+from threading import Thread
+
+# Python 3 compatibility imports
+from six.moves import cPickle
+
+from bd2k.util.expando import Expando, MagicExpando
+from toil.common import Toil
+from toil.fileStore import FileStore
+from toil import logProcessContext
+import signal
+
+logger = logging.getLogger( __name__ )
+
+
+
+
+def nextOpenDescriptor():
+ """Gets the number of the next available file descriptor.
+ """
+ descriptor = os.open("/dev/null", os.O_RDONLY)
+ os.close(descriptor)
+ return descriptor
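+# A sketch of how this can be used for leak detection (the names below are
+# illustrative, not part of this module):
+#
+#     before = nextOpenDescriptor()
+#     runSomeJob()                      # hypothetical work
+#     after = nextOpenDescriptor()
+#     if after > before:
+#         logger.debug("possible fd leak: %i -> %i", before, after)
+#
+# main() below logs this number once at startup for exactly this purpose.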
+
+class AsyncJobStoreWrite:
+ def __init__(self, jobStore):
+ pass
+
+ def writeFile(self, filePath):
+ pass
+
+ def writeFileStream(self):
+ pass
+
+ def blockUntilSync(self):
+ pass
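+# Note: AsyncJobStoreWrite is only a stub; its methods are placeholders and
+# nothing else in this module instantiates it.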
+
+def main():
+ logging.basicConfig()
+
+ ##########################################
+ #Import necessary modules
+ ##########################################
+
+ # This is assuming that worker.py is at a path ending in "/toil/worker.py".
+ sourcePath = os.path.dirname(os.path.dirname(__file__))
+ if sourcePath not in sys.path:
+ sys.path.append(sourcePath)
+
+ #Now we can import all the necessary functions
+ from toil.lib.bioio import setLogLevel
+ from toil.lib.bioio import getTotalCpuTime
+ from toil.lib.bioio import getTotalCpuTimeAndMemoryUsage
+ from toil.job import Job
+ try:
+ import boto
+ except ImportError:
+ pass
+ else:
+ # boto is installed, monkey patch it now
+ from bd2k.util.ec2.credentials import enable_metadata_credential_caching
+ enable_metadata_credential_caching()
+ ##########################################
+ #Input args
+ ##########################################
+
+ jobStoreLocator = sys.argv[1]
+ jobStoreID = sys.argv[2]
+ # we really want a list of job names but the ID will suffice if the job graph can't
+ # be loaded. If we can discover the name, we will replace this initial entry
+ listOfJobs = [jobStoreID]
+
+ ##########################################
+ #Load the jobStore/config file
+ ##########################################
+
+ jobStore = Toil.resumeJobStore(jobStoreLocator)
+ config = jobStore.config
+
+ ##########################################
+ #Create the worker killer, if requested
+ ##########################################
+
+ logFileByteReportLimit = config.maxLogFileSize
+
+ if config.badWorker > 0 and random.random() < config.badWorker:
+ def badWorker():
+            #Kill this worker process at a random point within the fail interval
+            time.sleep(config.badWorkerFailInterval * random.random())
+            os.kill(os.getpid(), signal.SIGKILL)  # could also use signal.SIGINT
+            #TODO: FIX OCCASIONAL DEADLOCK WITH SIGINT (tested on single machine)
+ t = Thread(target=badWorker)
+ # Ideally this would be a daemon thread but that causes an intermittent (but benign)
+ # exception similar to the one described here:
+ # http://stackoverflow.com/questions/20596918/python-exception-in-thread-thread-1-most-likely-raised-during-interpreter-shutd
+ # Our exception is:
+ # Exception in thread Thread-1 (most likely raised during interpreter shutdown):
+ # <type 'exceptions.AttributeError'>: 'NoneType' object has no attribute 'kill'
+ # This attribute error is caused by the call os.kill() and apparently unavoidable with a
+ # daemon
+ t.start()
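+        # For example, with badWorker=0.5 and badWorkerFailInterval=60, roughly
+        # half of all workers are SIGKILLed at a uniformly random point within
+        # their first 60 seconds, which is useful for exercising Toil's retry
+        # and recovery paths.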
+
+ ##########################################
+ #Load the environment for the jobGraph
+ ##########################################
+
+ #First load the environment for the jobGraph.
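+    #Host-specific variables (TMPDIR, TMP, HOSTNAME, HOSTTYPE) are deliberately
+    #not copied over, since they describe the leader's machine rather than this
+    #worker's; the worker sets up its own temporary directories below.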
+ with jobStore.readSharedFileStream("environment.pickle") as fileHandle:
+ environment = cPickle.load(fileHandle)
+ for i in environment:
+ if i not in ("TMPDIR", "TMP", "HOSTNAME", "HOSTTYPE"):
+ os.environ[i] = environment[i]
+ # sys.path is used by __import__ to find modules
+ if "PYTHONPATH" in environment:
+ for e in environment["PYTHONPATH"].split(':'):
+ if e != '':
+ sys.path.append(e)
+
+ setLogLevel(config.logLevel)
+
+ toilWorkflowDir = Toil.getWorkflowDir(config.workflowID, config.workDir)
+
+ ##########################################
+ #Setup the temporary directories.
+ ##########################################
+
+ # Dir to put all this worker's temp files in.
+ localWorkerTempDir = tempfile.mkdtemp(dir=toilWorkflowDir)
+ os.chmod(localWorkerTempDir, 0o755)
+
+ ##########################################
+ #Setup the logging
+ ##########################################
+
+ #This is mildly tricky because we don't just want to
+    #redirect stdout and stderr for this Python process; we want to redirect them
+    #for this process and all of its children. Consequently, we can't just replace
+ #sys.stdout and sys.stderr; we need to mess with the underlying OS-level
+ #file descriptors. See <http://stackoverflow.com/a/11632982/402891>
+
+ #When we start, standard input is file descriptor 0, standard output is
+ #file descriptor 1, and standard error is file descriptor 2.
+
+ #What file do we want to point FDs 1 and 2 to?
+ tempWorkerLogPath = os.path.join(localWorkerTempDir, "worker_log.txt")
+
+ #Save the original stdout and stderr (by opening new file descriptors to the
+ #same files)
+ origStdOut = os.dup(1)
+ origStdErr = os.dup(2)
+
+ #Open the file to send stdout/stderr to.
+ logFh = os.open(tempWorkerLogPath, os.O_WRONLY | os.O_CREAT | os.O_APPEND)
+
+ #Replace standard output with a descriptor for the log file
+ os.dup2(logFh, 1)
+
+ #Replace standard error with a descriptor for the log file
+ os.dup2(logFh, 2)
+
+ #Since we only opened the file once, all the descriptors duped from the
+ #original will share offset information, and won't clobber each others'
+ #writes. See <http://stackoverflow.com/a/5284108/402891>. This shouldn't
+ #matter, since O_APPEND seeks to the end of the file before every write, but
+ #maybe there's something odd going on...
+
+ #Close the descriptor we used to open the file
+ os.close(logFh)
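+    #The original descriptors saved in origStdOut/origStdErr are restored near
+    #the end of main(); the temporary log file they now point at is then copied
+    #back to the job store if the worker failed.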
+
+ debugging = logging.getLogger().isEnabledFor(logging.DEBUG)
+ ##########################################
+ #Worker log file trapped from here on in
+ ##########################################
+
+ workerFailed = False
+ statsDict = MagicExpando()
+ statsDict.jobs = []
+ statsDict.workers.logsToMaster = []
+ blockFn = lambda : True
+ cleanCacheFn = lambda x : True
+ try:
+
+ #Put a message at the top of the log, just to make sure it's working.
+ print("---TOIL WORKER OUTPUT LOG---")
+ sys.stdout.flush()
+
+ #Log the number of open file descriptors so we can tell if we're leaking
+ #them.
+ logger.debug("Next available file descriptor: {}".format(
+ nextOpenDescriptor()))
+
+ logProcessContext(config)
+
+ ##########################################
+ #Load the jobGraph
+ ##########################################
+
+ jobGraph = jobStore.load(jobStoreID)
+ listOfJobs[0] = str(jobGraph)
+ logger.debug("Parsed jobGraph")
+
+ ##########################################
+ #Cleanup from any earlier invocation of the jobGraph
+ ##########################################
+
+        if jobGraph.command is None:
+ # Cleanup jobs already finished
+ f = lambda jobs : filter(lambda x : len(x) > 0, map(lambda x :
+ filter(lambda y : jobStore.exists(y.jobStoreID), x), jobs))
+ jobGraph.stack = f(jobGraph.stack)
+ jobGraph.services = f(jobGraph.services)
+ logger.debug("Cleaned up any references to completed successor jobs")
+
+ #This cleans the old log file which may
+ #have been left if the job is being retried after a job failure.
+ oldLogFile = jobGraph.logJobStoreFileID
+        if oldLogFile is not None:
+ jobGraph.logJobStoreFileID = None
+ jobStore.update(jobGraph) #Update first, before deleting any files
+ jobStore.deleteFile(oldLogFile)
+
+ ##########################################
+ # If a checkpoint exists, restart from the checkpoint
+ ##########################################
+
+ # The job is a checkpoint, and is being restarted after previously completing
+        if jobGraph.checkpoint is not None:
+            logger.debug("Job is a checkpoint")
+            if len(jobGraph.stack) > 0 or len(jobGraph.services) > 0 or jobGraph.command is not None:
+                if jobGraph.command is not None:
+ assert jobGraph.command == jobGraph.checkpoint
+ logger.debug("Checkpoint job already has command set to run")
+ else:
+ jobGraph.command = jobGraph.checkpoint
+
+ # Reduce the retry count
+ assert jobGraph.remainingRetryCount >= 0
+ jobGraph.remainingRetryCount = max(0, jobGraph.remainingRetryCount - 1)
+
+ jobStore.update(jobGraph) # Update immediately to ensure that checkpoint
+ # is made before deleting any remaining successors
+
+ if len(jobGraph.stack) > 0 or len(jobGraph.services) > 0:
+                # If the subtree of successors is not complete, restart everything
+ logger.debug("Checkpoint job has unfinished successor jobs, deleting the jobs on the stack: %s, services: %s " %
+ (jobGraph.stack, jobGraph.services))
+
+ # Delete everything on the stack, as these represent successors to clean
+ # up as we restart the queue
+ def recursiveDelete(jobGraph2):
+                    # Recursively walk the stack to delete all remaining jobs
+ for jobs in jobGraph2.stack + jobGraph2.services:
+ for jobNode in jobs:
+ if jobStore.exists(jobNode.jobStoreID):
+ recursiveDelete(jobStore.load(jobNode.jobStoreID))
+ else:
+ logger.debug("Job %s has already been deleted", jobNode)
+ if jobGraph2 != jobGraph:
+ logger.debug("Checkpoint is deleting old successor job: %s", jobGraph2.jobStoreID)
+ jobStore.delete(jobGraph2.jobStoreID)
+ recursiveDelete(jobGraph)
+
+ jobGraph.stack = [ [], [] ] # Initialise the job to mimic the state of a job
+ # that has been previously serialised but which as yet has no successors
+
+ jobGraph.services = [] # Empty the services
+
+ # Update the jobStore to avoid doing this twice on failure and make this clean.
+ jobStore.update(jobGraph)
+
+ # Otherwise, the job and successors are done, and we can cleanup stuff we couldn't clean
+ # because of the job being a checkpoint
+ else:
+                logger.debug("The checkpoint job seems to have completed okay; removing any remaining checkpoint files.")
+ #Delete any remnant files
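+                #(map() is evaluated eagerly on Python 2, which this code
+                #targets; a Python 3 port would need list(map(...)) or an
+                #explicit loop for the deletions to actually happen.)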
+ map(jobStore.deleteFile, filter(jobStore.fileExists, jobGraph.checkpointFilesToDelete))
+
+ ##########################################
+ #Setup the stats, if requested
+ ##########################################
+
+ if config.stats:
+ startTime = time.time()
+ startClock = getTotalCpuTime()
+
+ #Make a temporary file directory for the jobGraph
+ #localTempDir = makePublicDir(os.path.join(localWorkerTempDir, "localTempDir"))
+
+ startTime = time.time()
+ while True:
+ ##########################################
+ #Run the jobGraph, if there is one
+ ##########################################
+
+ if jobGraph.command is not None:
+ assert jobGraph.command.startswith( "_toil " )
+ logger.debug("Got a command to run: %s" % jobGraph.command)
+ #Load the job
+ job = Job._loadJob(jobGraph.command, jobStore)
+ # If it is a checkpoint job, save the command
+ if job.checkpoint:
+ jobGraph.checkpoint = jobGraph.command
+
+ # Create a fileStore object for the job
+ fileStore = FileStore.createFileStore(jobStore, jobGraph, localWorkerTempDir, blockFn,
+ caching=not config.disableCaching)
+ with job._executor(jobGraph=jobGraph,
+ stats=statsDict if config.stats else None,
+ fileStore=fileStore):
+ with fileStore.open(job):
+ # Get the next block function and list that will contain any messages
+ blockFn = fileStore._blockFn
+
+ job._runner(jobGraph=jobGraph, jobStore=jobStore, fileStore=fileStore)
+
+ # Accumulate messages from this job & any subsequent chained jobs
+ statsDict.workers.logsToMaster += fileStore.loggingMessages
+
+ else:
+                #The command may be None, in which case
+                #the jobGraph is either a shell ready to be deleted or has
+                #been scheduled after a failure, to clean up
+ break
+
+ if FileStore._terminateEvent.isSet():
+ raise RuntimeError("The termination flag is set")
+
+ ##########################################
+ #Establish if we can run another jobGraph within the worker
+ ##########################################
+
+ #If no more jobs to run or services not finished, quit
+            if len(jobGraph.stack) == 0 or len(jobGraph.services) > 0 or jobGraph.checkpoint is not None:
+                logger.debug("Stopping running chain of jobs: length of stack: %s, services: %s, checkpoint: %s",
+                             len(jobGraph.stack), len(jobGraph.services), jobGraph.checkpoint is not None)
+ break
+
+ #Get the next set of jobs to run
+ jobs = jobGraph.stack[-1]
+ assert len(jobs) > 0
+
+ #If there are 2 or more jobs to run in parallel we quit
+ if len(jobs) >= 2:
+ logger.debug("No more jobs can run in series by this worker,"
+ " it's got %i children", len(jobs)-1)
+ break
+
+ #We check the requirements of the jobGraph to see if we can run it
+ #within the current worker
+ successorJobNode = jobs[0]
+ if successorJobNode.memory > jobGraph.memory:
+ logger.debug("We need more memory for the next job, so finishing")
+ break
+ if successorJobNode.cores > jobGraph.cores:
+ logger.debug("We need more cores for the next job, so finishing")
+ break
+ if successorJobNode.disk > jobGraph.disk:
+ logger.debug("We need more disk for the next job, so finishing")
+ break
+ if successorJobNode.predecessorNumber > 1:
+                logger.debug("The jobGraph has multiple predecessors; we must return to the leader.")
+ break
+
+ # Load the successor jobGraph
+ successorJobGraph = jobStore.load(successorJobNode.jobStoreID)
+
+ # add the successor to the list of jobs run
+ listOfJobs.append(str(successorJobGraph))
+
+ # Somewhat ugly, but check if job is a checkpoint job and quit if
+ # so
+ if successorJobGraph.command.startswith( "_toil " ):
+ #Load the job
+ successorJob = Job._loadJob(successorJobGraph.command, jobStore)
+
+ # Check it is not a checkpoint
+ if successorJob.checkpoint:
+ logger.debug("Next job is checkpoint, so finishing")
+ break
+
+ ##########################################
+ #We have a single successor job that is not a checkpoint job.
+ #We transplant the successor jobGraph command and stack
+ #into the current jobGraph object so that it can be run
+            #as if it were a command that was part of the current jobGraph.
+ #We can then delete the successor jobGraph in the jobStore, as it is
+ #wholly incorporated into the current jobGraph.
+ ##########################################
+
+ #Clone the jobGraph and its stack
+ jobGraph = copy.deepcopy(jobGraph)
+
+ #Remove the successor jobGraph
+ jobGraph.stack.pop()
+
+ #These should all match up
+ assert successorJobGraph.memory == successorJobNode.memory
+ assert successorJobGraph.cores == successorJobNode.cores
+ assert successorJobGraph.predecessorsFinished == set()
+ assert successorJobGraph.predecessorNumber == 1
+ assert successorJobGraph.command is not None
+ assert successorJobGraph.jobStoreID == successorJobNode.jobStoreID
+
+ #Transplant the command and stack to the current jobGraph
+ jobGraph.command = successorJobGraph.command
+ jobGraph.stack += successorJobGraph.stack
+ # include some attributes for better identification of chained jobs in
+ # logging output
+ jobGraph.unitName = successorJobGraph.unitName
+ jobGraph.jobName = successorJobGraph.jobName
+ assert jobGraph.memory >= successorJobGraph.memory
+ assert jobGraph.cores >= successorJobGraph.cores
+
+ #Build a fileStore to update the job
+ fileStore = FileStore.createFileStore(jobStore, jobGraph, localWorkerTempDir, blockFn,
+ caching=not config.disableCaching)
+
+ #Update blockFn
+ blockFn = fileStore._blockFn
+
+ #Add successorJobGraph to those to be deleted
+ fileStore.jobsToDelete.add(successorJobGraph.jobStoreID)
+
+ #This will update the job once the previous job is done
+ fileStore._updateJobWhenDone()
+
+ #Clone the jobGraph and its stack again, so that updates to it do
+ #not interfere with this update
+ jobGraph = copy.deepcopy(jobGraph)
+
+ logger.debug("Starting the next job")
+
+ ##########################################
+ #Finish up the stats
+ ##########################################
+ if config.stats:
+ totalCPUTime, totalMemoryUsage = getTotalCpuTimeAndMemoryUsage()
+ statsDict.workers.time = str(time.time() - startTime)
+ statsDict.workers.clock = str(totalCPUTime - startClock)
+ statsDict.workers.memory = str(totalMemoryUsage)
+
+ # log the worker log path here so that if the file is truncated the path can still be found
+ logger.info("Worker log can be found at %s. Set --cleanWorkDir to retain this log", localWorkerTempDir)
+ logger.info("Finished running the chain of jobs on this node, we ran for a total of %f seconds", time.time() - startTime)
+
+ ##########################################
+ #Trapping where worker goes wrong
+ ##########################################
+ except: #Case that something goes wrong in worker
+ traceback.print_exc()
+ logger.error("Exiting the worker because of a failed job on host %s", socket.gethostname())
+ FileStore._terminateEvent.set()
+
+ ##########################################
+ #Wait for the asynchronous chain of writes/updates to finish
+ ##########################################
+
+ blockFn()
+
+ ##########################################
+ #All the asynchronous worker/update threads must be finished now,
+ #so safe to test if they completed okay
+ ##########################################
+
+ if FileStore._terminateEvent.isSet():
+ jobGraph = jobStore.load(jobStoreID)
+ jobGraph.setupJobAfterFailure(config)
+ workerFailed = True
+
+ ##########################################
+ #Cleanup
+ ##########################################
+
+ #Close the worker logging
+ #Flush at the Python level
+ sys.stdout.flush()
+ sys.stderr.flush()
+ #Flush at the OS level
+ os.fsync(1)
+ os.fsync(2)
+
+ #Close redirected stdout and replace with the original standard output.
+ os.dup2(origStdOut, 1)
+
+ #Close redirected stderr and replace with the original standard error.
+ os.dup2(origStdErr, 2)
+
+ #sys.stdout and sys.stderr don't need to be modified at all. We don't need
+ #to call redirectLoggerStreamHandlers since they still log to sys.stderr
+
+ #Close our extra handles to the original standard output and standard error
+ #streams, so we don't leak file handles.
+ os.close(origStdOut)
+ os.close(origStdErr)
+
+ #Now our file handles are in exactly the state they were in before.
+
+ #Copy back the log file to the global dir, if needed
+ if workerFailed:
+ jobGraph.logJobStoreFileID = jobStore.getEmptyFileStoreID(jobGraph.jobStoreID)
+ jobGraph.chainedJobs = listOfJobs
+ with jobStore.updateFileStream(jobGraph.logJobStoreFileID) as w:
+ with open(tempWorkerLogPath, "r") as f:
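+                # Chained comparison: only truncate when the log is larger than
+                # the limit and the limit is nonzero (0 disables truncation); a
+                # positive limit keeps the tail of the log, a negative one the
+                # head (per the inline comments below).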
+                if os.path.getsize(tempWorkerLogPath) > logFileByteReportLimit != 0:
+ if logFileByteReportLimit > 0:
+ f.seek(-logFileByteReportLimit, 2) # seek to last tooBig bytes of file
+ elif logFileByteReportLimit < 0:
+ f.seek(logFileByteReportLimit, 0) # seek to first tooBig bytes of file
+ w.write(f.read())
+ jobStore.update(jobGraph)
+
+ elif debugging: # write log messages
+ with open(tempWorkerLogPath, 'r') as logFile:
+ if os.path.getsize(tempWorkerLogPath) > logFileByteReportLimit != 0:
+ if logFileByteReportLimit > 0:
+ logFile.seek(-logFileByteReportLimit, 2) # seek to last tooBig bytes of file
+ elif logFileByteReportLimit < 0:
+ logFile.seek(logFileByteReportLimit, 0) # seek to first tooBig bytes of file
+ logMessages = logFile.read().splitlines()
+ statsDict.logs.names = listOfJobs
+ statsDict.logs.messages = logMessages
+
+ if (debugging or config.stats or statsDict.workers.logsToMaster) and not workerFailed: # We have stats/logging to report back
+ jobStore.writeStatsAndLogging(json.dumps(statsDict))
+
+ #Remove the temp dir
+ cleanUp = config.cleanWorkDir
+ if cleanUp == 'always' or (cleanUp == 'onSuccess' and not workerFailed) or (cleanUp == 'onError' and workerFailed):
+ shutil.rmtree(localWorkerTempDir)
+
+ #This must happen after the log file is done with, else there is no place to put the log
+    if (not workerFailed) and jobGraph.command is None and len(jobGraph.stack) == 0 and len(jobGraph.services) == 0:
+ # We can now safely get rid of the jobGraph
+ jobStore.delete(jobGraph.jobStoreID)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/toil.git