[med-svn] [Git][med-team/python-treetime][upstream] New upstream version 0.11.1

Ananthu C V (@weepingclown) gitlab at salsa.debian.org
Wed Nov 1 18:42:28 GMT 2023



Ananthu C V pushed to branch upstream at Debian Med / python-treetime


Commits:
bc7ed4de by Ananthu C V at 2023-11-01T01:09:00+05:30
New upstream version 0.11.1
- - - - -


27 changed files:

- README.md
- changelog.md
- docs/source/tutorials/timetree.rst
- setup.py
- test.sh
- test/command_line_tests.sh
- + test/test_sequence_evolution_model.txt
- test/test_treetime.py
- + treetime/CLI_io.py
- treetime/__init__.py
- treetime/__main__.py
- treetime/argument_parser.py
- + treetime/clock_filter_methods.py
- treetime/clock_tree.py
- treetime/distribution.py
- treetime/gtr.py
- treetime/gtr_site_specific.py
- treetime/merger_models.py
- treetime/node_interpolator.py
- treetime/seq_utils.py
- treetime/seqgen.py
- treetime/treeanc.py
- treetime/treeregression.py
- treetime/treetime.py
- treetime/utils.py
- treetime/vcf_utils.py
- treetime/wrappers.py


Changes:

=====================================
README.md
=====================================
@@ -46,7 +46,7 @@ Have a look at our repository with [example data](https://github.com/neherlab/tr
 
 ### Installation and prerequisites
 
-TreeTime is compatible with Python 3.6 upwards and is tested on 3.6, 3.7, and 3.8.  It depends on several Python libraries:
+TreeTime is compatible with Python 3.7 upwards and is tested on 3.7 to 3.10.  It depends on several Python libraries:
 
 * numpy, scipy, pandas: for all kind of mathematical operations as matrix
   operations, numerical integration, interpolation, minimization, etc.


=====================================
changelog.md
=====================================
@@ -1,4 +1,52 @@
-# 0.9.4: bug fix and performance improvments
+# 0.11.1: bug fixes and tweaks to plotting
+- fix division by zero error during GTR inference
+- improve doc strings in parse dates
+- tweaks to background shading in timetree plot function (`plot_vs_years`)
+- allow specifying the branches on which date confidence intervals are shown.
+
+# 0.11.0: new clock filter method
+
+Previously, only a crude analysis of whether the divergence of tips roughly follows a linear trend was implemented. Tips that deviated too much from that regression line were flagged as outliers and this threshold was parameterized as number of interquartile distances of the distribution of residuals `n_iqd`.
+This filter is not very sensitive and often misses misdated tips that severely distort the tree but still fall within the distribution of root-to-tip distances at that time.
+To overcome this, we implemented a novel filtering method that fits a simple Gaussian model of divergence accumulation.
+Information on outliers is saved in a pandas DataFrame `self.outliers` of `TreeTime` and written to a tsv file when running treetime as a command line tool.
+
+### Other fixes
+ * raise an error when the rate estimate is negative during the rate susceptibility calculation, with a hint in the error message to specify the rate and its uncertainty explicitly.
+ * Fix bug [issue #250](https://github.com/neherlab/treetime/issues/250) introduced in 0.10.0 where treetime fails in absence of an alignment when trying to create an auspice json file. [PR #251](https://github.com/neherlab/treetime/pull/251)
+
+# 0.10.1: bug fix release
+
+ * avoid probability loss at the end of the domain of distributions
+ * fix erroneous check for merger model.
+ * raise error when probability is lost.
+ * improve initial guess in branch length optimizations
+
+# 0.10.0: add auspice.json output, drop python 3.6
+
+ * the output directory now contains a json file that is compatible with auspice.us. Both time scaled phylogenies and ancestral inferences can now be visualized and explored using auspice. Available colorings are "Date", "genotype", "Branch support", and "Excluded". See [PR #232](https://github.com/neherlab/treetime/pull/232) for details.
+ * move most functions related to IO of the command line wrappers into a separate file.
+ * make TreeTime own its random number generator and add `--rng-seed` to control state in CLI. Any previous usage of `numpy.random.seed` will now be ignored in favor of `--rng-seed`. See [PR #234](https://github.com/neherlab/treetime/pull/234)
+ * add flag `--greedy-resolve` (currently default) as inverse to `--stochastic-resolve` with the aim of switching the default behavior in the future.
+   Add deprecation warning for `greedy-resolve`.
+ * tighten conditions that trigger approximation of narrow distribution as a delta function in convolution using FFT [PR #235](https://github.com/neherlab/treetime/pull/235).
+ * Drop support for python 3.6.
+ * Don't attempt to show figure when calling `Phylo.draw` to suppress matplotlib back-end warning.
+
+# 0.9.6: bug fixes and new mode of polytomy resolution
+ * in cases when very large polytomies are resolved, the multiplication of the discretized message results in messages/distributions of length 1. This resulted in an error, since interpolation objects require at least two points. This is now caught and a small discrete grid created.
+ * increase recursion limit to 10000 by default. The recursion limit can now also be set via the environment variable `TREETIME_RECURSION_LIMIT`.
+ * removed unused imports, fixed typos
+ * add a new way to resolve polytomies. The previous polytomy resolution greedily pulled out pairs of child-clades at a time and merged them into a single clade. This often results in atypical caterpillar-like subtrees. This is undesirable since it (i) is very atypical, (ii) causes numerical issues due to repeated convolutions, and (iii) triggers recursion errors during newick export. The new optional way of resolving replaces a multi-furcation by a randomly generated coalescent tree in which, backwards in time, lineages mutate (all mutations are singletons and need to 'go' before coalescence) and merge. Lineages that remain when time reaches the time of the parent remain as children of the parent. This new way of resolving is much faster for large polytomies. This experimental feature can be used via the flag `--stochastic-resolve`. Note that the outcome of this stochastic resolution is stochastic!
+
+# 0.9.5: load custom GTR via CLI
+
+ * fix bug that omitted the inferred state of the root in the nexus export of the migration command
+ * add CLI flag and functionality to load sequence evolution models inferred and saved by TreeTime as human-readable text files. The flag is `--custom-gtr <filename>` and overwrites any arguments passed under the `--gtr` flag.
+ * explicitly specify the optimization method, brackets, bounds, and tolerances in calls of `scipy.optimize.minimize` to suppress scipy warning. Scipy had previously silently ignored bounds when the method wasn't explicitly set to `bounded`.
+
+
+# 0.9.4: bug fix and performance improvements
 
  * avoid negative variance associated with branch lengths in tree regression. This could happen in rare cases when marginal time tree estimation returned short negative branch length and the variance was estimated as being proportional to branch length. Variances in the `TreeRegression` clock model are now always non-negative.
  * downsample the grid during multiplication of distribution objects. This turned out to be an issue for trees with very large polytomies. In these cases, a large number of distributions get multiplied which resulted in grid sizes above 100000 points. Grid sizes are now downsampled to the average grid size.
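The 0.11.0 entry describes the earlier clock filter only in words: tips whose residual from the root-to-tip regression exceeds `n_iqd` interquartile distances are flagged. Below is a minimal NumPy sketch of that idea, assuming an ordinary least-squares fit and residuals measured from the median residual. The function name `flag_iqd_outliers` and the toy data are hypothetical; this is not the code in `treetime/clock_filter_methods.py`.

```python
import numpy as np

def flag_iqd_outliers(dates, divergence, n_iqd=3.0):
    """Flag tips whose residual from a linear root-to-tip fit is extreme.

    Hypothetical illustration of an interquartile-distance filter; measuring
    the distance from the median residual is an assumption of this sketch.
    """
    dates = np.asarray(dates, dtype=float)
    divergence = np.asarray(divergence, dtype=float)
    # Ordinary least-squares fit: divergence ~ rate * date + intercept
    rate, intercept = np.polyfit(dates, divergence, 1)
    residuals = divergence - (rate * dates + intercept)
    q25, q50, q75 = np.percentile(residuals, [25, 50, 75])
    iqd = q75 - q25
    # A tip is an outlier if it lies more than n_iqd IQDs from the median residual
    return np.abs(residuals - q50) > n_iqd * iqd

dates = [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007]
divergence = [0.010, 0.011, 0.012, 0.013, 0.014, 0.015, 0.030, 0.017]
print(flag_iqd_outliers(dates, divergence))
```

With these toy values only the seventh tip (divergence 0.030) is flagged. A misdated tip whose divergence is typical for some other date can still pass such a test, which is the weakness the changelog entry describes.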

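The 0.9.6 entry describes `--stochastic-resolve` only in prose. The following is a loose, hypothetical sketch of that kind of process as a Gillespie-style simulation, not TreeTime's implementation. It makes these assumptions: each child lineage carries a number of singleton mutations that must be shed (at rate `mu` per lineage) before it may coalesce; lineages that are ready merge pairwise at rate `1/tc`; and the simulation stops at the parent's time `t_parent`. The function `stochastic_resolve` and all rate conventions are illustrative.

```python
import numpy as np

def stochastic_resolve(n_mutations, t_parent, mu=1.0, tc=1.0, rng=None):
    """Hypothetical sketch: resolve a polytomy by a backward-in-time coalescent.

    n_mutations: singleton mutations on each child branch (must be shed first).
    Returns (merges, remaining): merges lists (lineage_a, lineage_b, time, new_id),
    and remaining holds the lineage ids still attached directly to the parent.
    """
    rng = np.random.default_rng() if rng is None else rng
    pending = {i: int(m) for i, m in enumerate(n_mutations)}  # lineage id -> mutations left
    next_id = len(pending)
    merges = []
    t = 0.0
    while True:
        mutating = [i for i, m in pending.items() if m > 0]
        ready = [i for i, m in pending.items() if m == 0]
        mut_rate = mu * len(mutating)
        coal_rate = len(ready) * (len(ready) - 1) / 2 / tc
        total = mut_rate + coal_rate
        if total == 0:
            break  # a single ready lineage is left
        t += rng.exponential(1.0 / total)
        if t >= t_parent:
            break  # remaining lineages stay children of the parent
        if rng.random() < mut_rate / total:
            # one mutating lineage sheds a singleton mutation
            i = mutating[rng.integers(len(mutating))]
            pending[i] -= 1
        else:
            # two ready lineages coalesce into a new internal node
            a, b = (int(x) for x in rng.choice(ready, size=2, replace=False))
            del pending[a], pending[b]
            merges.append((a, b, t, next_id))
            pending[next_id] = 0
            next_id += 1
    return merges, sorted(pending)

merges, remaining = stochastic_resolve([0, 0, 1, 2, 0], t_parent=2.0,
                                       rng=np.random.default_rng(1234))
print(len(merges), remaining)
```

Because every merge replaces two lineages by one, the number of children left on the parent is always the original count minus the number of merges. The specific topology depends on the random draws, in line with the changelog's warning.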

=====================================
docs/source/tutorials/timetree.rst
=====================================
@@ -182,7 +182,7 @@ The following example with a set of MtB sequences uses a fixed evolutionary rate
 
    treetime --aln data/tb/lee_2015.vcf.gz --vcf-reference data/tb/tb_ref.fasta --tree data/tb/lee_2015.nwk --clock-rate 1e-7 --dates data/tb/lee_2015.metadata.tsv
 
-For many bacterial data set were the temporal signal in the data is weak, it is advisable to fix the rate of the molecular clock explicitly.
+For many bacterial data sets where the temporal signal in the data is weak, it is advisable to fix the rate of the molecular clock explicitly.
 Divergence times, however, will depend on this choice.
 
 


=====================================
setup.py
=====================================
@@ -1,4 +1,3 @@
-import os
 from setuptools import setup
 
 def get_version():
@@ -32,14 +31,12 @@ setup(
             'scipy>=0.16.1'
         ],
         extras_require = {
-            ':python_version < "3.6"':['matplotlib>=2.0, ==2.*'],
             ':python_version >= "3.6"':['matplotlib>=2.0'],
         },
         classifiers=[
             "Development Status :: 5 - Production/Stable",
             "Topic :: Scientific/Engineering :: Bio-Informatics",
             "License :: OSI Approved :: MIT License",
-            "Programming Language :: Python :: 3.6",
             "Programming Language :: Python :: 3.7",
             "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",


=====================================
test.sh
=====================================
@@ -3,10 +3,21 @@
 set -euo pipefail
 
 cd test
+
+# Remove treetime_examples in case it exists to not fail
+rm -rf treetime_examples
 git clone https://github.com/neherlab/treetime_examples.git
+
 bash command_line_tests.sh
 OUT=$?
 if [ "$OUT" != 0 ]; then
   exit 1
 fi
+
 pytest test_treetime.py
+if [ "$OUT" != 0 ]; then
+  exit 1
+fi
+
+# Clean up, the 202* is to remove auto-generated output dirs
+rm -rf treetime_examples __pycache__ 202*


=====================================
test/command_line_tests.sh
=====================================
@@ -54,6 +54,24 @@ else
 	echo "timetree_inference on vcf data failed $retval"
 fi
 
+treetime --tree treetime_examples/data/ebola/ebola.nwk --dates treetime_examples/data/ebola/ebola.metadata.csv --aln treetime_examples/data/ebola/ebola.fasta  --coalescent skyline --gen-per-year 100
+retval="$?"
+if [ "$retval" == 0 ]; then
+	echo "skyline approximation ok"
+else
+	((all_tests++))
+	echo "skyline approximation failed $retval"
+fi
+
+# From https://github.com/neherlab/treetime/issues/250 
+treetime --tree treetime_examples/data/ebola/ebola.nwk --dates treetime_examples/data/ebola/ebola.metadata.csv --sequence-length 1000
+retval="$?"
+if [ "$retval" == 0 ]; then
+	echo "sequence length only ok"
+else
+	((all_tests++))
+	echo "sequence length only failed $retval"
+fi
 
 if [ "$all_tests" == 0 ];then
 	echo "All tests passed"


=====================================
test/test_sequence_evolution_model.txt
=====================================
@@ -0,0 +1,25 @@
+Substitution rate (mu): 1.0
+
+Equilibrium frequencies (pi_i):
+  A: 0.3088
+  C: 0.1897
+  G: 0.2335
+  T: 0.2581
+  -: 0.0099
+
+Symmetrized rates from j->i (W_ij):
+	A	C	G	T	-
+  A	0	0.7003	3.0669	0.2651	0.9742
+  C	0.7003	0	0.3354	3.399	0.999
+  G	3.0669	0.3354	0	0.4258	0.9892
+  T	0.2651	3.399	0.4258	0	0.9848
+  -	0.9742	0.999	0.9892	0.9848	0
+
+Actual rates from j->i (Q_ij):
+	A	C	G	T	-
+  A	0	0.2163	0.9472	0.0819	0.3009
+  C	0.1328	0	0.0636	0.6448	0.1895
+  G	0.716	0.0783	0	0.0994	0.2309
+  T	0.0684	0.8772	0.1099	0	0.2541
+  -	0.0097	0.0099	0.0098	0.0098	0
+
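The values in `test_sequence_evolution_model.txt` are consistent with the off-diagonal rates being Q_ij = pi_i W_ij. For example, Q_AC = 0.3088 × 0.7003 ≈ 0.2163. The short NumPy check below reproduces the file's Q from its pi and W. It assumes the diagonal is filled so that every column sums to zero, which is what `test_GTR` asserts via `gtr.Q.sum(0)`. It is not the parsing code of `GTR.from_file`.

```python
import numpy as np

alphabet = ['A', 'C', 'G', 'T', '-']
pi = np.array([0.3088, 0.1897, 0.2335, 0.2581, 0.0099])
W = np.array([[0, 0.7003, 3.0669, 0.2651, 0.9742],
              [0.7003, 0, 0.3354, 3.399, 0.999],
              [3.0669, 0.3354, 0, 0.4258, 0.9892],
              [0.2651, 3.399, 0.4258, 0, 0.9848],
              [0.9742, 0.999, 0.9892, 0.9848, 0]])

# Off-diagonal rates: Q[i, j] = pi[i] * W[i, j] (rate of j -> i); the file lists mu = 1.0
Q = pi[:, None] * W
# Fill the diagonal so that each column sums to zero (total probability is conserved)
np.fill_diagonal(Q, -Q.sum(axis=0))

# "Actual rates" table from the file (its diagonal is printed as 0)
Q_file = np.array([[0, 0.2163, 0.9472, 0.0819, 0.3009],
                   [0.1328, 0, 0.0636, 0.6448, 0.1895],
                   [0.716, 0.0783, 0, 0.0994, 0.2309],
                   [0.0684, 0.8772, 0.1099, 0, 0.2541],
                   [0.0097, 0.0099, 0.0098, 0.0098, 0]])
off_diag = ~np.eye(len(alphabet), dtype=bool)
print(np.allclose(Q[off_diag], Q_file[off_diag], atol=1e-3))
```

The comparison uses a tolerance of 1e-3 because the file rounds every entry to four decimals.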


=====================================
test/test_treetime.py
=====================================
@@ -26,15 +26,31 @@ def test_assign_gamma(root_dir=None):
     tt_kwargs = {'clock_rate': 0.0001,
                     'time_marginal':'assign'}
     myTree = TreeTime(gtr='Jukes-Cantor', tree = nwk, use_fft=False,
-                    aln = fasta, verbose = 1, dates = dates, precision=3, debug=True)
+                    aln = fasta, verbose = 1, dates = dates, precision=3, debug=True, rng_seed=1234)
     def assign_gamma(tree):
         return tree
     success = myTree.run(infer_gtr=False, assign_gamma=assign_gamma, max_iter=1, verbose=3, **seq_kwargs, **tt_kwargs)
     assert success
 
-def test_GTR():
+def test_GTR(root_dir=None):
     from treetime import GTR
     import numpy as np
+    import os
+    if root_dir is None:
+        root_dir = os.path.dirname(os.path.realpath(__file__))
+    ##check custom GTR model
+    custom_gtr = root_dir + "/test_sequence_evolution_model.txt"
+    gtr = GTR.from_file(custom_gtr)
+    assert (gtr.Pi.sum() - 1.0)**2<1e-14
+    assert np.allclose(gtr.Pi, np.array([0.3088, 0.1897, 0.2335, 0.2581, 0.0099]))
+    assert np.all(gtr.alphabet == np.array(['A', 'C', 'G', 'T', '-']))
+    assert abs(gtr.mu - 1.0) < 1e-4
+    assert abs(gtr.Q.sum(0)).sum() < 1e-14
+    assert np.allclose(gtr.W, np.array([[0, 0.7003, 3.0669, 0.2651, 0.9742],
+                            [0.7003, 0, 0.3354, 3.399, 0.999],
+                            [3.0669, 0.3354, 0, 0.4258, 0.9892],
+                            [0.2651, 3.399, 0.4258, 0, 0.9848],
+                            [0.9742, 0.999, 0.9892, 0.9848, 0]]), atol=1e-4)
     for model in ['Jukes-Cantor']:
         print('testing GTR, model:',model)
         myGTR = GTR.standard(model, alphabet='nuc')
@@ -48,6 +64,35 @@ def test_GTR():
         assert np.abs(myGTR.v.sum()) > 1e-10 # **and** v is not zero
 
 
+def test_reconstruct_discrete_traits():
+    from Bio import Phylo
+    from treetime.wrappers import reconstruct_discrete_traits
+
+    # Create a minimal tree with traits to reconstruct.
+    tiny_tree = Phylo.read(StringIO("((A:0.60100000009,B:0.3010000009):0.1,C:0.2):0.001;"), 'newick')
+    traits = {
+        "A": "?",
+        "B": "North America",
+        "C": "West Asia",
+    }
+
+    # Reconstruct traits with "?" as missing data.
+    mugration, letter_to_state, reverse_alphabet = reconstruct_discrete_traits(
+        tiny_tree,
+        traits,
+        missing_data="?",
+    )
+
+    # With two known states, the letters "A" and "B" should be in the alphabet
+    # mapping to those states.
+    assert "A" in letter_to_state
+    assert "B" in letter_to_state
+
+    # The letter for missing data should be the next letter in the alphabet,
+    # following the two known state letters.
+    assert letter_to_state["C"] == "?"
+
+
 def test_ancestral(root_dir=None):
     import os
     from Bio import AlignIO
@@ -60,7 +105,7 @@ def test_ancestral(root_dir=None):
 
     for marginal in [True, False]:
         print('loading flu example')
-        t = TreeAnc(gtr='Jukes-Cantor', tree=nwk, aln=fasta)
+        t = TreeAnc(gtr='Jukes-Cantor', tree=nwk, aln=fasta, rng_seed=1234)
         print('ancestral reconstruction' + ("marginal" if marginal else "joint"))
         t.reconstruct_anc(method='ml', marginal=marginal)
         assert t.data.compressed_to_full_sequence(t.tree.root.cseq, as_string=True) == 'ATGAATCCAAATCAAAAGATAATAACGATTGGCTCTGTTTCTCTCACCATTTCCACAATATGCTTCTTCATGCAAATTGCCATCTTGATAACTACTGTAACATTGCATTTCAAGCAATATGAATTCAACTCCCCCCCAAACAACCAAGTGATGCTGTGTGAACCAACAATAATAGAAAGAAACATAACAGAGATAGTGTATCTGACCAACACCACCATAGAGAAGGAAATATGCCCCAAACCAGCAGAATACAGAAATTGGTCAAAACCGCAATGTGGCATTACAGGATTTGCACCTTTCTCTAAGGACAATTCGATTAGGCTTTCCGCTGGTGGGGACATCTGGGTGACAAGAGAACCTTATGTGTCATGCGATCCTGACAAGTGTTATCAATTTGCCCTTGGACAGGGAACAACACTAAACAACGTGCATTCAAATAACACAGTACGTGATAGGACCCCTTATCGGACTCTATTGATGAATGAGTTGGGTGTTCCTTTTCATCTGGGGACCAAGCAAGTGTGCATAGCATGGTCCAGCTCAAGTTGTCACGATGGAAAAGCATGGCTGCATGTTTGTATAACGGGGGATGATAAAAATGCAACTGCTAGCTTCATTTACAATGGGAGGCTTGTAGATAGTGTTGTTTCATGGTCCAAAGAAATTCTCAGGACCCAGGAGTCAGAATGCGTTTGTATCAATGGAACTTGTACAGTAGTAATGACTGATGGAAGTGCTTCAGGAAAAGCTGATACTAAAATACTATTCATTGAGGAGGGGAAAATCGTTCATACTAGCACATTGTCAGGAAGTGCTCAGCATGTCGAAGAGTGCTCTTGCTATCCTCGATATCCTGGTGTCAGATGTGTCTGCAGAGACAACTGGAAAGGCTCCAATCGGCCCATCGTAGATATAAACATAAAGGATCATAGCATTGTTTCCAGTTATGTGTGTTCAGGACTTGTTGGAGACACACCCAGAAAAAACGACAGCTCCAGCAGTAGCCATTGTTTGGATCCTAACAATGAAGAAGGTGGTCATGGAGTGAAAGGCTGGGCCTTTGATGATGGAAATGACGTGTGGATGGGAAGAACAATCAACGAGACGTCACGCTTAGGGTATGAAACCTTCAAAGTCATTGAAGGCTGGTCCAACCCTAAGTCCAAATTGCAGATAAATAGGCAAGTCATAGTTGACAGAGGTGATAGGTCCGGTTATTCTGGTATTTTCTCTGTTGAAGGCAAAAGCTGCATCAATCGGTGCTTTTATGTGGAGTTGATTAGGGGAAGAAAAGAGGAAACTGAAGTCTTGTGGACCTCAAACAGTATTGTTGTGTTTTGTGGCACCTCAGGTACATATGGAACAGGCTCATGGCCTGATGGGGCGGACCTCAATCTCATGCCTATA'
@@ -73,7 +118,7 @@ def test_ancestral(root_dir=None):
                                      ">C\nACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT\n"), 'fasta')
 
     mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']), pi = np.array([0.9, 0.06, 0.02, 0.02]), W=np.ones((4,4)))
-    t = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln)
+    t = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln, rng_seed=1234)
     t.reconstruct_anc('ml', marginal=True, debug=True)
     lhsum =  np.exp(t.sequence_LH(pos=np.arange(4**3))).sum()
     print (lhsum)
@@ -97,25 +142,19 @@ def test_seq_joint_reconstruction_correct():
     from Bio import Phylo, AlignIO
     import numpy as np
     from collections import defaultdict
-    def exclusion(a, b):
-        """
-        Intersection of two lists
-        """
-        return list(set(a) - set(b))
 
     tiny_tree = Phylo.read(StringIO("((A:.060,B:.01200)C:.020,D:.0050)E:.004;"), 'newick')
     mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']),
                        pi = np.array([0.15, 0.95, 0.05, 0.3]),
                        W=np.ones((4,4)))
-    seq = np.random.choice(mygtr.alphabet, p=mygtr.Pi, size=400)
 
 
-    myTree = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=None, verbose=4)
+    myTree = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=None, verbose=4, rng_seed=1234)
 
     # simulate evolution, set resulting sequence as ref_seq
     tree = myTree.tree
     seq_len = 400
-    tree.root.ref_seq = np.random.choice(mygtr.alphabet, p=mygtr.Pi, size=seq_len)
+    tree.root.ref_seq = myTree.rng.choice(mygtr.alphabet, p=mygtr.Pi, size=seq_len)
     print ("Root sequence: " + ''.join(tree.root.ref_seq.astype('U')))
     mutation_list = defaultdict(list)
     for node in tree.find_clades():
@@ -128,7 +167,7 @@ def test_seq_joint_reconstruction_correct():
         # normalize profile
         p=(p.T/p.sum(axis=1)).T
         # sample mutations randomly
-        ref_seq_idxs = np.array([int(np.random.choice(np.arange(p.shape[1]), p=p[k])) for k in np.arange(p.shape[0])])
+        ref_seq_idxs = np.array([int(myTree.rng.choice(np.arange(p.shape[1]), p=p[k])) for k in np.arange(p.shape[0])])
 
         node.ref_seq = np.array([mygtr.alphabet[k] for k in ref_seq_idxs])
 
@@ -170,8 +209,6 @@ def test_seq_joint_reconstruction_correct():
     print ("Difference between reference and inferred LH:", (LH - LH_p).sum())
     assert ((LH - LH_p).sum())<1e-9
 
-    return myTree
-
 
 def test_seq_joint_lh_is_max():
     """
@@ -212,7 +249,7 @@ def test_seq_joint_lh_is_max():
                                          ">E\nACGTACGTACGTACGT\n"), 'fasta')
 
         myTree = TreeAnc(gtr=mygtr, tree = tiny_tree,
-                         aln =tiny_aln, verbose = 4)
+                         aln =tiny_aln, verbose = 4, rng_seed=1234)
 
         logLH_ref = myTree.ancestral_likelihood()
 
@@ -229,7 +266,7 @@ def test_seq_joint_lh_is_max():
                                            ">D\n"+D_char+"\n"), 'fasta')
 
         myTree_1 = TreeAnc(gtr=mygtr, tree = tiny_tree,
-                            aln=tiny_aln_1, verbose = 4)
+                            aln=tiny_aln_1, verbose = 4, rng_seed=1234)
 
         myTree_1.reconstruct_anc(method='ml', marginal=False, debug=True)
         logLH = myTree_1.tree.sequence_LH
@@ -241,4 +278,3 @@ def test_seq_joint_lh_is_max():
     print(abs(ref.max() - real) )
     # joint chooses the most likely realization of the tree
     assert(abs(ref.max() - real) < 1e-10)
-    return ref, real
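The test changes replace global `np.random` calls with draws from `myTree.rng` and pass `rng_seed=1234`. This matches the 0.10.0 note that TreeTime now owns its random number generator. Below is a minimal sketch of that pattern, assuming NumPy's `default_rng`. The class `SeededModel` is hypothetical; it only illustrates why calls to `numpy.random.seed` no longer influence the results.

```python
import numpy as np

class SeededModel:
    """Hypothetical object that owns its random number generator."""

    def __init__(self, rng_seed=None):
        # Each instance gets an independent Generator; the legacy global
        # state set by np.random.seed() is never consulted.
        self.rng = np.random.default_rng(rng_seed)

    def random_sequence(self, alphabet, pi, size):
        return self.rng.choice(alphabet, p=pi, size=size)

alphabet = np.array(['A', 'C', 'G', 'T'])
pi = np.array([0.4, 0.1, 0.1, 0.4])

np.random.seed(0)
seq1 = SeededModel(rng_seed=1234).random_sequence(alphabet, pi, 20)
np.random.seed(99)  # changing the global seed has no effect on the owned rng
seq2 = SeededModel(rng_seed=1234).random_sequence(alphabet, pi, 20)
print(''.join(seq1) == ''.join(seq2))  # True
```

Keeping the generator on the object makes each test reproducible on its own, regardless of which other code touched the global NumPy state first.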

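The new `test_reconstruct_discrete_traits` expects known states to be encoded as consecutive letters, with the missing-data value taking the next letter. Below is a hypothetical sketch of such a mapping. The helper `assign_state_letters` and the sorted ordering of states are assumptions, not the code of `treetime.wrappers.reconstruct_discrete_traits`.

```python
def assign_state_letters(traits, missing_data="?"):
    """Hypothetical helper: map each discrete state to a single letter.

    Known states get consecutive letters starting at 'A' (in sorted order);
    the missing-data value receives the next letter after them.
    """
    known = sorted({v for v in traits.values() if v != missing_data})
    if len(known) >= 26:
        raise ValueError("this sketch only supports up to 25 known states")
    letter_to_state = {chr(ord('A') + i): state for i, state in enumerate(known)}
    letter_to_state[chr(ord('A') + len(known))] = missing_data
    reverse_alphabet = {state: letter for letter, state in letter_to_state.items()}
    return letter_to_state, reverse_alphabet

traits = {"A": "?", "B": "North America", "C": "West Asia"}
letter_to_state, reverse_alphabet = assign_state_letters(traits)
print(letter_to_state)
# {'A': 'North America', 'B': 'West Asia', 'C': '?'}
```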

=====================================
treetime/CLI_io.py
=====================================
@@ -0,0 +1,303 @@
+import os, sys
+from Bio import AlignIO, Phylo
+from .vcf_utils import read_vcf, write_vcf
+from .seq_utils import alphabets
+from Bio import __version__ as bioversion
+from . import version as treetime_version
+import numpy as np
+
+def get_outdir(params, suffix='_treetime'):
+    if params.outdir:
+        if os.path.exists(params.outdir):
+            if os.path.isdir(params.outdir):
+                return params.outdir.rstrip('/') + '/'
+            else:
+                print("designated output location %s is not a directory"%params.outdir, file=sys.stderr)
+        else:
+            os.makedirs(params.outdir)
+            return params.outdir.rstrip('/') + '/'
+
+    from datetime import datetime
+    outdir_stem = datetime.now().date().isoformat()
+    outdir = outdir_stem + suffix.rstrip('/')+'/'
+    count = 1
+    while os.path.exists(outdir):
+        outdir = outdir_stem + '-%04d'%count + suffix.rstrip('/')+'/'
+        count += 1
+
+    os.makedirs(outdir)
+    return outdir
+
+def get_basename(params, outdir):
+    # if params.aln:
+    #     basename = outdir + '.'.join(params.aln.split('/')[-1].split('.')[:-1])
+    # elif params.tree:
+    #     basename = outdir + '.'.join(params.tree.split('/')[-1].split('.')[:-1])
+    # else:
+    basename = outdir
+    return basename
+
+def read_in_DRMs(drm_file, offset):
+    import pandas as pd
+
+    DRMs = {}
+    drmPositions = []
+
+    df = pd.read_csv(drm_file, sep='\t')
+    for mi, m in df.iterrows():
+        pos = m.GENOMIC_POSITION-1+offset #put in correct numbering
+        drmPositions.append(pos)
+
+        if pos in DRMs:
+            DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION
+        else:
+            DRMs[pos] = {}
+            DRMs[pos]['drug'] = m.DRUG
+            DRMs[pos]['alt_base'] = {}
+            DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION
+            DRMs[pos]['gene'] = m.GENE
+
+    drmPositions = np.array(drmPositions)
+    drmPositions = np.unique(drmPositions)
+    drmPositions = np.sort(drmPositions)
+
+    DRM_info = {'DRMs': DRMs,
+            'drmPositions': drmPositions}
+
+    return DRM_info
+
+
+def read_if_vcf(params):
+    """
+    Checks if input is VCF and reads in appropriately if it is
+    """
+    ref = None
+    aln = params.aln
+    fixed_pi = None
+    if hasattr(params, 'aln') and params.aln is not None:
+        if any([params.aln.lower().endswith(x) for x in ['.vcf', '.vcf.gz']]):
+            if not params.vcf_reference:
+                print("ERROR: a reference Fasta is required with VCF-format alignments")
+                return -1
+            compress_seq = read_vcf(params.aln, params.vcf_reference)
+            sequences = compress_seq['sequences']
+            ref = compress_seq['reference']
+            aln = sequences
+
+            if not hasattr(params, 'gtr') or params.gtr=="infer": #if not specified, set it:
+                alpha = alphabets['aa'] if params.aa else alphabets['nuc']
+                fixed_pi = [ref.count(base)/len(ref) for base in alpha]
+                if fixed_pi[-1] == 0:
+                    fixed_pi[-1] = 0.05
+                    fixed_pi = [v-0.01 for v in fixed_pi]
+
+    return aln, ref, fixed_pi
+
+
+def plot_rtt(tt, fname):
+    tt.plot_root_to_tip()
+
+    from matplotlib import pyplot as plt
+    plt.savefig(fname)
+    print("--- root-to-tip plot saved to  \n\t"+fname)
+
+
+def export_sequences_and_tree(tt, basename, is_vcf=False, zero_based=False,
+                              report_ambiguous=False, timetree=False, confidence=False,
+                              reconstruct_tip_states=False, tree_suffix=''):
+    seq_info = is_vcf or tt.aln
+    if is_vcf:
+        outaln_name = basename + f'ancestral_sequences{tree_suffix}.vcf'
+        write_vcf(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name)
+    elif tt.aln:
+        outaln_name = basename + f'ancestral_sequences{tree_suffix}.fasta'
+        AlignIO.write(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name, 'fasta')
+    if seq_info:
+        print("\n--- alignment including ancestral nodes saved as  \n\t %s\n"%outaln_name)
+
+    # decorate tree with inferred mutations
+    terminal_count = 0
+    offset = 0 if zero_based else 1
+    if timetree:
+        dates_fname = basename + f'dates{tree_suffix}.tsv'
+        fh_dates = open(dates_fname, 'w', encoding='utf-8')
+        if confidence:
+            fh_dates.write('#Lower and upper bound delineate the 90% max posterior region\n')
+            fh_dates.write('#node\tdate\tnumeric date\tlower bound\tupper bound\n')
+        else:
+            fh_dates.write('#node\tdate\tnumeric date\n')
+
+    mutations_out = open(basename + "branch_mutations.txt", "w")
+    mutations_out.write("node\tstate1\tpos\tstate2\n")
+    for n in tt.tree.find_clades():
+        if timetree:
+            if confidence:
+                if n.bad_branch:
+                    fh_dates.write('%s\t--\t--\t--\t--\n'%(n.name))
+                else:
+                    conf = tt.get_max_posterior_region(n, fraction=0.9)
+                    fh_dates.write('%s\t%s\t%f\t%f\t%f\n'%(n.name, n.date, n.numdate,conf[0], conf[1]))
+            else:
+                if n.bad_branch:
+                    fh_dates.write('%s\t--\t--\n'%(n.name))
+                else:
+                    fh_dates.write('%s\t%s\t%f\n'%(n.name, n.date, n.numdate))
+
+        n.confidence=None
+        # due to a bug in older versions of biopython that truncated filenames in nexus export
+        # we truncate them by hand and make them unique.
+        if n.is_terminal() and len(n.name)>40 and bioversion<"1.69":
+            n.name = n.name[:35]+'_%03d'%terminal_count
+            terminal_count+=1
+        n.comment=''
+        if seq_info and len(n.mutations):
+            if n.mask is None:
+                if report_ambiguous:
+                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations])+'"'
+                else:
+                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations
+                                                        if tt.gtr.ambiguous not in [a,d]])+'"'
+            else:
+                if report_ambiguous:
+                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations if n.mask[pos]>0])+f'",mcc="{n.mcc}"'
+                else:
+                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations
+                                                        if tt.gtr.ambiguous not in [a,d] and n.mask[pos]>0])+f'",mcc="{n.mcc}"'
+
+                for (a, pos, d) in n.mutations:
+                    if tt.gtr.ambiguous not in [a,d] or report_ambiguous:
+                        mutations_out.write("%s\t%s\t%s\t%s\n" %(n.name, a, pos + 1, d))
+        if timetree:
+            n.comment+=(',' if n.comment else '&') + 'date=%1.2f'%n.numdate
+    mutations_out.close()
+
+    # write tree to file
+    fmt_bl = "%1.6f" if tt.data.full_length<1e6 else "%1.8e"
+    if timetree:
+        outtree_name = basename + f'timetree{tree_suffix}.nexus'
+        print("--- saved divergence times in \n\t %s\n"%dates_fname)
+        Phylo.write(tt.tree, outtree_name, 'nexus')
+    else:
+        outtree_name = basename + f'annotated_tree{tree_suffix}.nexus'
+        Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl)
+    print("--- tree saved in nexus format as  \n\t %s\n"%outtree_name)
+
+    # Only create auspice json if there is sequence information
+    auspice = create_auspice_json(tt, timetree=timetree, confidence=confidence, seq_info=seq_info)
+    outtree_name_json = basename + f'auspice_tree{tree_suffix}.json'
+    with open(outtree_name_json, 'w') as fh:
+        import json
+        json.dump(auspice, fh, indent=0)
+        print("--- tree saved in auspice json format as  \n\t %s\n"%outtree_name_json)
+
+    if timetree:
+        for n in tt.tree.find_clades():
+            n.branch_length = n.mutation_length
+        outtree_name = basename + f'divergence_tree{tree_suffix}.nexus'
+        Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl)
+        print("--- divergence tree saved in nexus format as  \n\t %s\n"%outtree_name)
+
+    if hasattr(tt, 'outliers') and tt.outliers is not None:
+        print("--- saved detected outliers as " + basename + 'outliers.tsv')
+        tt.outliers.to_csv(basename + 'outliers.tsv', sep='\t')
+
+def print_save_plot_skyline(tt, n_std=2.0, screen=True, save='', plot='', gen=50):
+    if plot:
+        import matplotlib.pyplot as plt
+
+    skyline, conf = tt.merger_model.skyline_inferred(gen=gen, confidence=n_std)
+    if save: fh = open(save, 'w', encoding='utf-8')
+    header1 = "Skyline assuming "+ str(gen)+" gen/year and approximate confidence bounds (+/- %f standard deviations of the LH)\n"%n_std
+    header2 = "date \tN_e \tlower \tupper"
+    if screen: print('\t'+header1+'\t'+header2)
+    if save: fh.write("#"+ header1+'#'+header2+'\n')
+    for (x,y, y1, y2) in zip(skyline.x, skyline.y, conf[0], conf[1]):
+        if screen: print("\t%1.3f\t%1.3e\t%1.3e\t%1.3e"%(x,y, y1, y2))
+        if save: fh.write("%1.3f\t%1.3e\t%1.3e\t%1.3e\n"%(x,y, y1, y2))
+
+    if save:
+        print("\n --- written skyline to %s\n"%save)
+        fh.close()
+
+    if plot:
+        plt.figure()
+        plt.fill_between(skyline.x, conf[0], conf[1], color=(0.8, 0.8, 0.8))
+        plt.plot(skyline.x, skyline.y, label='maximum likelihood skyline')
+        plt.yscale('log')
+        plt.legend()
+        plt.ticklabel_format(axis='x',useOffset=False)
+        plt.savefig(plot)
+
+
+
+def create_auspice_json(tt, timetree=False, confidence=False, seq_info=False):
+    # mock up meta data for auspice json
+    from datetime import datetime
+    meta = {
+        "title": f"Auspice visualization of TreeTime (v{treetime_version}) analysis",
+        "build_url": "https://github.com/neherlab/treetime",
+        "last_updated": datetime.now().strftime("%Y-%m-%d"),
+        "treetime_version": treetime_version,
+        "genome_annotations": {
+            "nuc":{"start":1, "end":int(tt.data.full_length), "type":"source", "strand":"+:"}
+        },
+        "panels":["tree", "entropy"],
+        "colorings": [
+            {
+                "title": "Date",
+                "type": "continuous",
+                "key": "num_date",
+            },
+            {
+                "title": "Genotype",
+                "type": "categorical",
+                "key": "gt",
+            },
+            {
+                "title": "Excluded",
+                "type": "categorical",
+                "key": "bad_branch"
+            },
+            {
+                "title": "Branch Support",
+                "type": "continuous",
+                "key": "confidence"
+            }
+        ],
+        "display_defaults": {"color_by":"bad_branch"},
+        "filters": ["bad_branch"]
+    }
+
+    def node_to_json(n, pdiv=0.0):
+        j = {"name":n.name, "node_attrs":{}, "branch_attrs":{}}
+        if n.clades:
+            j["children"] = []
+
+        if timetree:
+            j["node_attrs"]["num_date"] = {"value":float(n.numdate)}
+            if confidence:
+                conf = tt.get_max_posterior_region(n, fraction=0.9)
+                j["node_attrs"]["num_date"]["confidence"] = (float(conf[0]), float(conf[1]))
+        j["node_attrs"]["div"] = float(pdiv + n.mutation_length)
+        j["node_attrs"]["bad_branch"] = {"value": "Yes" if n.bad_branch else "No"}
+
+        if seq_info: # only add mutations to the json if run with sequence data (fasta or vcf)
+            j["branch_attrs"]["mutations"] = {"nuc": [f"{a}{pos+1}{d}" for a,pos,d in n.mutations if d in "ACGT-"]}
+            # generate bootstrap confidence substitute via the negative exponential of the number of mutations
+            # this is the bootstrap confidence for iid mutations (only ACGT mutations)
+            j["node_attrs"]["confidence"] = {"value":round(1-np.exp(-len([pos for a,pos,d in n.mutations if d in "ACGT"])),3)
+                                            if not n.is_terminal() else 1.0}
+        return j
+
+    # create the tree data structure from the Biopython tree
+    tree = node_to_json(tt.tree.root, 0.0)
+    # dictionary to look up nodes by name
+    node_lookup = {tt.tree.root.name: tree}
+    for n in tt.tree.get_nonterminals():
+        n_json = node_lookup[n.name]
+        for c in n.clades:
+            # generate node JSONs for all children and attach them to the parent
+            n_json["children"].append(node_to_json(c, n_json["node_attrs"]["div"]))
+            node_lookup[c.name] = n_json["children"][-1]
+
+    return {"meta":meta, "tree":tree}

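For readers of the new `CLI_io.py` export: the branch support written to `node_attrs.confidence` is a heuristic, not a real bootstrap. If sites are resampled roughly Poisson(1) times, each mutation is missing from a bootstrap replicate with probability about e^-1, so at least one of k independent mutations survives with probability 1 - e^-k. The minimal sketch below reproduces the mutation-label format and that formula outside of TreeTime. `mutation_labels` and `pseudo_bootstrap` are hypothetical helper names, and mutations are assumed to be `(ancestral, 0-based position, derived)` tuples as in the diff.

```python
import math


def mutation_labels(mutations, allowed="ACGT-"):
    """Format (ancestral, 0-based position, derived) tuples as 1-based labels such as 'A4G'.

    Mutations to characters outside `allowed` (e.g. ambiguous 'N') are dropped.
    """
    return [f"{a}{pos + 1}{d}" for a, pos, d in mutations if d in allowed]


def pseudo_bootstrap(mutations, is_terminal=False):
    """Heuristic branch support 1 - exp(-k), with k = number of changes to A, C, G or T."""
    if is_terminal:
        # terminal branches are always reported with full support
        return 1.0
    k = sum(1 for _, _, d in mutations if d in "ACGT")
    return round(1 - math.exp(-k), 3)


muts = [("A", 3, "G"), ("C", 10, "N"), ("T", 0, "-")]
print(mutation_labels(muts))   # ['A4G', 'T1-']
print(pseudo_bootstrap(muts))  # 0.632 (only the A->G change counts)
```

Gaps still appear in the mutation labels but do not contribute to the support value, matching the two different character filters used in `node_to_json`.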

=====================================
treetime/__init__.py
=====================================
@@ -1,4 +1,4 @@
-version="0.9.4"
+version="0.11.1"
 ## Here we define an error class for TreeTime errors, MissingData, UnknownMethod and NotReady errors
 ## are all due to incorrect calling of TreeTime functions or input data that does not fit our base assumptions.
 ## Errors marked as TreeTimeUnknownErrors might be due to data not fulfilling base assumptions or due
@@ -28,6 +28,12 @@ class TreeTimeUnknownError(Exception):
     """TreeTimeUnknownError class raised when TreeTime fails during inference due to an unknown reason. This might be due to data not fulfilling base assumptions or due  to bugs in TreeTime. Please report them to the developers if they persist."""
     pass
 
+import os, sys
+recursion_limit = os.environ.get("TREETIME_RECURSION_LIMIT")
+if recursion_limit:
+    sys.setrecursionlimit(int(recursion_limit))
+else:
+    sys.setrecursionlimit(max(sys.getrecursionlimit(), 10000))
 
 from .treeanc import TreeAnc
 from .treetime import TreeTime, plot_vs_years

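The new import-time block in `treetime/__init__.py` lets users override Python's recursion limit through the `TREETIME_RECURSION_LIMIT` environment variable and otherwise raises it to at least 10000 (CPython's default is usually 1000), presumably because recursive traversals of very large trees can exceed the default. The sketch below wraps the same logic in a hypothetical `configure_recursion_limit` function so it can be called and tested explicitly; the function name and its parameters are illustrative, not part of TreeTime.

```python
import os
import sys


def configure_recursion_limit(env_var="TREETIME_RECURSION_LIMIT", floor=10000):
    """Apply an explicit limit from `env_var`, or raise the current limit to at least `floor`."""
    value = os.environ.get(env_var)
    if value:
        # an explicit setting wins, even if it is lower than the current limit
        limit = int(value)
    else:
        # never lower a limit that is already higher than the floor
        limit = max(sys.getrecursionlimit(), floor)
    sys.setrecursionlimit(limit)
    return limit
```

Because the real check runs when `treetime` is imported, the variable has to be set in the environment before that import happens.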

=====================================
treetime/__main__.py
=====================================
@@ -3,10 +3,8 @@
 Stub function and module used as a setuptools entry point.
 Based on augur's __main__.py and setup.py
 """
-
-from __future__ import print_function, division, absolute_import
 import sys
-from treetime import version, make_parser
+from treetime import make_parser
 
 
 # Entry point for setuptools-installed script and bin/augur dev wrapper.


=====================================
treetime/argument_parser.py
=====================================
@@ -4,8 +4,6 @@ from .wrappers import ancestral_reconstruction, mugration, scan_homoplasies,\
                       timetree, estimate_clock_model, arg_time_trees
 from . import version
 
-py2 = sys.version_info.major==2
-
 def set_default_subparser(self, name, args=None, positional_args=0):
     """default subparser selection. Call after setup, just before parse_args()
     name: is the name of the subparser to call by default
@@ -30,10 +28,6 @@ def set_default_subparser(self, name, args=None, positional_args=0):
                 args.insert(1, name)
 
 
-if py2:
-    argparse.ArgumentParser.set_default_subparser = set_default_subparser
-
-
 treetime_description = \
     "TreeTime: Maximum Likelihood Phylodynamics\n\n"
 subcommand_description = \
@@ -125,10 +119,14 @@ def add_aln_group(parser, required=True):
 
 
 def add_reroot_group(parser):
-    parser.add_argument('--clock-filter', type=float, default=3,
+    parser.add_argument('--clock-filter', type=float, default=4.0,
                               help="ignore tips that don't follow a loose clock, "
-                                   "'clock-filter=number of interquartile ranges from regression'. "
-                                   "Default=3.0, set to 0 to switch off.")
+                                   "'clock-filter=number of interquartile ranges from regression (method=`residual`)' "
+                                   "or z-score of local clock deviation (method=`local`). "
+                                   "Default=4.0, set to 0 to switch off.")
+    parser.add_argument('--clock-filter-method', choices=['residual', 'local'], default='residual',
+                        help="Use residuals from global clock (`residual`, default) or local clock deviation (`local`) "
+                             "to filter out tips that don't follow the clock")
     reroot_group = parser.add_mutually_exclusive_group()
     reroot_group.add_argument('--reroot', nargs='+', default='best', help=reroot_description)
     reroot_group.add_argument('--keep-root', required = False, action="store_true", default=False,
@@ -144,6 +142,7 @@ def add_gtr_arguments(parser):
     parser.add_argument('--gtr', default='infer', help=gtr_description)
     parser.add_argument('--gtr-params', nargs='+', help=gtr_params_description)
     parser.add_argument('--aa', action='store_true', help="use aminoacid alphabet")
+    parser.add_argument('--custom-gtr', default = None, type=str, help="filename of pre-defined custom GTR model in standard TreeTime format")
 
 def add_time_arguments(parser):
     parser.add_argument('--dates', type=str, help=dates_description)
@@ -179,6 +178,11 @@ def add_timetree_args(parser):
                              "distribution in the final round.")
     parser.add_argument('--keep-polytomies', default=False, action='store_true',
                         help="Don't resolve polytomies using temporal information.")
+    parser.add_argument('--stochastic-resolve', default=False, action='store_true',
+                        help="Resolve polytomies using a random coalescent tree.")
+    parser.add_argument('--greedy-resolve', action='store_false', dest='stochastic_resolve',
+                        help="Resolve polytomies greedily. Currently default, but will "
+                             "be switched to `stochastic-resolve` in future versions.")
     # parser.add_argument('--keep-node-order', default=False, action='store_true',
     #                     help="Don't ladderize the tree.")
     parser.add_argument('--relax',nargs=2, type=float,
@@ -193,6 +197,8 @@ def add_timetree_args(parser):
                           help=coalescent_description)
     parser.add_argument('--n-skyline', default="20", type=int,
                           help="number of grid points in skyline coalescent model")
+    parser.add_argument('--gen-per-year', default="50.0", type=float,
+                          help="number of generations per year - used for estimating N_e in coalescent models")
     parser.add_argument('--n-branches-posterior', default=False, action='store_true',
                           help= "add posterior LH to coalescent model: use the posterior probability distributions of "
                                 "divergence times for estimating the number of branches when calculating the coalescent merger"
@@ -214,11 +220,9 @@ def make_parser():
 
     subparsers = parser.add_subparsers()
 
-    if py2:
-        t_parser = subparsers.add_parser('tt', description=timetree_description)
-    else:
-        t_parser = parser
+    t_parser = parser
     t_parser.add_argument('--tree', type=str, help=tree_description)
+    t_parser.add_argument('--rng-seed', type=int, help="random number generator seed for treetime")
     add_seq_len_aln_group(t_parser)
     add_time_arguments(t_parser)
     add_timetree_args(t_parser)
@@ -242,6 +246,7 @@ def make_parser():
     h_parser = subparsers.add_parser('homoplasy', description=homoplasy_description)
     add_aln_group(h_parser)
     h_parser.add_argument('--tree', type = str,  help=tree_description)
+    h_parser.add_argument('--rng-seed', type=int, help="random number generator seed for treetime")
     h_parser.add_argument('--const', type = int, default=0, help ="number of constant sites not included in alignment")
     h_parser.add_argument('--rescale', type = float, default=1.0, help ="rescale branch lengths")
     h_parser.add_argument('--detailed', required = False, action="store_true",  help ="generate a more detailed report")
@@ -256,6 +261,7 @@ def make_parser():
     a_parser = subparsers.add_parser('ancestral', description=ancestral_description)
     add_aln_group(a_parser)
     a_parser.add_argument('--tree', type=str,  help=tree_description)
+    a_parser.add_argument('--rng-seed', type=int, help="random number generator seed for treetime")
     add_gtr_arguments(a_parser)
     a_parser.add_argument('--marginal', default=False, action="store_true", help ="marginal reconstruction of ancestral sequences")
     add_anc_arguments(a_parser)
@@ -265,6 +271,7 @@ def make_parser():
     ## MUGRATION
     m_parser = subparsers.add_parser('mugration', description=mugration_description)
     m_parser.add_argument('--tree', required = True, type=str, help=tree_description)
+    m_parser.add_argument('--rng-seed', type=int, help="random number generator seed for treetime")
     m_parser.add_argument('--name-column', type=str, help="label of the column to be used as taxon name")
     m_parser.add_argument('--attribute', type=str, help ="attribute to reconstruct, e.g. country")
     m_parser.add_argument('--states', required = True, type=str, help ="csv or tsv file with discrete characters."
@@ -291,6 +298,7 @@ def make_parser():
                         "It will reroot the tree to maximize the clock-like "
                         "signal and recalculate branch length unless run with --keep-root.")
     c_parser.add_argument('--tree', required=True, type=str,  help=tree_description)
+    c_parser.add_argument('--rng-seed', type=int, help="random number generator seed for treetime")
     add_time_arguments(c_parser)
     add_seq_len_aln_group(c_parser)
 
@@ -309,6 +317,7 @@ def make_parser():
             description="Calculates the root-to-tip regression and quantifies the 'clock-i-ness' of the tree. "
                         "It will reroot the tree to maximize the clock-like "
                         "signal and recalculate branch length unless run with --keep_root.")
+    arg_parser.add_argument('--rng-seed', type=int, help="random number generator seed for treetime")
     arg_parser.add_argument('--trees', nargs=2, required=True, type=str)
     arg_parser.add_argument('--alignments', nargs=2, required=True, type=str)
     arg_parser.add_argument('--mccs', required=True, type=str)
@@ -326,8 +335,4 @@ def make_parser():
     v_parser = subparsers.add_parser('version', description='print version')
     v_parser.set_defaults(func=lambda x: print("treetime "+version))
 
-    ## call the relevant function and return
-    if py2:
-        parser.set_default_subparser('tt')
-
     return parser

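Two of the parser changes rely on standard `argparse` behaviour: `--clock-filter-method` restricts its input with `choices`, and `--greedy-resolve` is a `store_false` flag that writes to the same `dest` as `--stochastic-resolve`. Because `argparse` fills in defaults in registration order and skips a destination that is already set, the shared default is the `False` of the first flag, so greedy resolution stays the default. The stripped-down parser below (a hypothetical `build_demo_parser`, with help strings omitted) is a sketch for checking that behaviour, not TreeTime's actual `make_parser`.

```python
import argparse


def build_demo_parser():
    """Reduced copy of the clock-filter and polytomy flags from the diff."""
    parser = argparse.ArgumentParser(prog="treetime-demo")
    parser.add_argument('--clock-filter', type=float, default=4.0)
    parser.add_argument('--clock-filter-method', choices=['residual', 'local'], default='residual')
    # both flags share one destination; the first registered default (False) is used
    parser.add_argument('--stochastic-resolve', default=False, action='store_true')
    parser.add_argument('--greedy-resolve', action='store_false', dest='stochastic_resolve')
    return parser


parser = build_demo_parser()
print(parser.parse_args([]))
print(parser.parse_args(['--stochastic-resolve', '--clock-filter-method', 'local']))
```

When both polytomy flags are given, the last one on the command line wins; any value other than `residual` or `local` for `--clock-filter-method` makes `argparse` exit with a usage error.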

=====================================
treetime/clock_filter_methods.py
=====================================
@@ -0,0 +1,195 @@
+import numpy as np
+import pandas as pd
+
+def residual_filter(tt, n_iqd):
+    terminals = tt.tree.get_terminals()
+    clock_rate = tt.clock_model['slope']
+    icpt = tt.clock_model['intercept']
+    res = {}
+    for node in terminals:
+        if hasattr(node, 'raw_date_constraint') and  (node.raw_date_constraint is not None):
+            res[node] = node.dist2root - clock_rate*np.mean(node.raw_date_constraint) - icpt
+
+    residuals = np.array(list(res.values()))
+    iqd = np.percentile(residuals,75) - np.percentile(residuals,25)
+    outliers = {}
+    for node,r in res.items():
+        if abs(r)>n_iqd*iqd and node.up.up is not None:
+            node.bad_branch=True
+            outliers[node.name] = {'tau':(node.dist2root - icpt)/clock_rate,  'avg_date': np.mean(node.raw_date_constraint),
+                                'exact_date': node.raw_date_constraint if type(node) is float else None,
+                                'residual': r/iqd}
+        else:
+            node.bad_branch=False
+
+    tt.outliers=None
+    if len(outliers):
+        outlier_df = pd.DataFrame(outliers).T.loc[:,['avg_date', 'tau', 'residual']]\
+                                .rename(columns={'avg_date':'given_date', 'tau':'apparent_date'})
+        tt.logger("Clock_filter.residual_filter marked the following outliers:", 2, warn=True)
+        if tt.verbose>=2:
+            print(outlier_df)
+        tt.outliers = outlier_df
+    return len(outliers)
+
+def local_filter(tt, z_score_threshold):
+    tt.logger(f"TreeTime.ClockFilter: starting local_outlier_detection", 2)
+
+    node_info = collect_node_info(tt)
+
+    node_info, z_scale = calculate_node_timings(tt, node_info)
+    tt.logger(f"TreeTime.ClockFilter: z-scale {z_scale:1.2f}", 2)
+
+    outliers = flag_outliers(tt, node_info, z_score_threshold, z_scale)
+
+    for n in tt.tree.get_terminals():
+        if n.name in outliers:
+            n.bad_branch = True
+
+    tt.outliers=None
+    if len(outliers):
+        outlier_df = pd.DataFrame(outliers).T.loc[:,['avg_date', 'tau', 'z', 'diagnosis']]\
+                                .rename(columns={'avg_date':'given_date', 'tau':'apparent_date', 'z':'z-score'})
+        tt.logger("Clock_filter.local_filter marked the following outliers", 2, warn=True)
+        if tt.verbose>=2:
+            print(outlier_df)
+        tt.outliers = outlier_df
+    return len(outliers)
+
+
+def flag_outliers(tt, node_info, z_score_threshold, z_scale):
+    def add_outlier_info(z, n, n_info, parent_tau, mu):
+        n_info['z'] = z
+        diagnosis=''
+        # muts = n_info["nmuts"] if n.is_terminal() else 0.0
+        # parent_tau = node_info[n.up.name]['tau'] if n.is_terminal() else n_info['tau']
+        if z<0:
+            if np.abs(n_info['avg_date']-parent_tau) > n_info["nmuts"]/mu:
+                diagnosis='date_too_early'
+            else:
+                diagnosis = 'excess_mutations'
+        else:
+            diagnosis = 'date_too_late'
+        n_info['diagnosis'] = diagnosis
+        return n_info
+
+    outliers = {}
+    mu = tt.clock_model['slope']*tt.data.full_length
+    for n in tt.tree.get_terminals():
+        if n.up.up is None:
+            continue # do not label children of the root as bad -- typically a problem with root choice that will be fixed anyway
+        n_info = node_info[n.name]
+        parent_tau = node_info[n.up.name]['tau']
+        if n_info['exact_date']:
+            z = (n_info['avg_date'] - n_info['tau'])/z_scale
+            if np.abs(z) > z_score_threshold:
+                outliers[n.name] = add_outlier_info(z, n, n_info, parent_tau, mu)
+        elif n.raw_date_constraint and len(n.raw_date_constraint):
+            zs = [(n_info['tau'] - x)/z_scale for x in n.raw_date_constraint]
+            if zs[0]*zs[1]>0 and np.min(np.abs(zs))>z_score_threshold:
+                outliers[n.name] = add_outlier_info(zs[0] if np.abs(zs[0])<np.abs(zs[1]) else zs[1],
+                                                    n, n_info, parent_tau, mu)
+
+    return outliers
+
+def calculate_node_timings(tt, node_info, eps=0.2):
+    mu = tt.clock_model['slope']*tt.data.full_length
+    sigma_sq = (3/mu)**2
+    tt.logger(f"Clockfilter.calculate_node_timings: mu={mu:1.3e}/y, sigma={3/mu:1.3e}y", 2)
+    for n in tt.tree.find_clades(order='postorder'):
+        p = node_info[n.name]
+        if not p['exact_date'] or p['skip']:
+            continue
+
+        if n.is_terminal():
+            prefactor = (p["observations"]/sigma_sq + mu**2/(p["nmuts"]+eps))
+            p["a"] = (p["avg_date"]/sigma_sq + mu*p["nmuts"]/(p["nmuts"]+eps))/prefactor
+        else:
+            children = [node_info[c.name] for c in n if (not node_info[c.name]['skip']) and node_info[c.name]['exact_date']]
+            if n==tt.tree.root:
+                tmp_children_1 = mu*np.sum([(mu*c["a"]-c["nmuts"])/(eps+c["nmuts"]) for c in children])
+                tmp_children_2 = mu**2*np.sum([(1-c["b"])/(eps+c["nmuts"]) for c in children])
+                prefactor = (p["observations"]/sigma_sq + tmp_children_2)
+                p["a"] = (p["observations"]*p["avg_date"]/sigma_sq + tmp_children_1)/prefactor
+            else:
+                tmp_children_1 = mu*np.sum([(mu*c["a"]-c["nmuts"])/(eps+c["nmuts"]) for c in children])
+                tmp_children_2 = mu**2*np.sum([(1-c["b"])/(eps+c["nmuts"]) for c in children])
+                prefactor = (p["observations"]/sigma_sq + mu**2/(p["nmuts"]+eps) + tmp_children_2)
+                p["a"] = (p["observations"]*p["avg_date"]/sigma_sq + mu*p["nmuts"]/(p["nmuts"]+eps)+tmp_children_1)/prefactor
+        p["b"] = mu**2/(p["nmuts"]+eps)/prefactor
+
+    node_info[tt.tree.root.name]["tau"] = node_info[tt.tree.root.name]["a"]
+
+    ## need to deal with tips without exact dates below.
+    dev = []
+    for n in tt.tree.get_nonterminals(order='preorder'):
+        p = node_info[n.name]
+        for c in n:
+            c_info = node_info[c.name]
+            if c_info['skip']:
+                c_info['tau']=p['tau']
+            else:
+                if c_info['exact_date']:
+                    c_info["tau"] = c_info["a"] + c_info["b"]*p["tau"]
+                else:
+                    c_info["tau"] = p["tau"] + c_info['nmuts']/mu
+            if c.is_terminal() and c_info['exact_date']:
+                dev.append(c_info['avg_date']-c_info['tau'])
+
+    return node_info, np.std(dev)
+
+
+def collect_node_info(tt, percentile_for_exact_date=90):
+    node_info = {}
+    aln = tt.aln or False
+    if aln and (not tt.sequence_reconstruction):
+        tt.infer_ancestral_sequences(infer_gtr=False)
+    L = tt.data.full_length
+
+    date_uncertainty = [np.abs(n.raw_date_constraint[1]-n.raw_date_constraint[0])
+                            if type(n.raw_date_constraint)!=float else 0.0
+                        for n in tt.tree.get_terminals()
+                            if n.raw_date_constraint is not None]
+    from scipy.stats import scoreatpercentile
+    uncertainty_cutoff = scoreatpercentile(date_uncertainty, percentile_for_exact_date)*1.01
+
+    for n in tt.tree.get_nonterminals(order='postorder'):
+        parent = {"dates": [], "tips": {}, "skip":False}
+        exact_dates = 0
+        for c in n:
+            if c.is_terminal():
+                child = {'skip':False}
+                child["nmuts"] = len([m for m in c.mutations if m[-1] in 'ACGT']) if aln \
+                                      else np.round(c.branch_length*L)
+                if c.raw_date_constraint is None:
+                    child['exact_date'] = False
+                elif type(c.raw_date_constraint)==float:
+                    child['exact_date'] = True
+                else:
+                    child['exact_date'] = np.abs(c.raw_date_constraint[1]-c.raw_date_constraint[0])<=uncertainty_cutoff
+
+                if child['exact_date']:
+                    exact_dates += 1
+                    if child["nmuts"]==0:
+                        child['skip'] = True
+                        parent["tips"][c.name]={'date': np.mean(c.raw_date_constraint),
+                                                'exact_date':child['exact_date']}
+                    else:
+                        child['skip'] = False
+                        child['observations'] = 1
+                if c.raw_date_constraint is not None:
+                    child["avg_date"] = np.mean(c.raw_date_constraint)
+                node_info[c.name] = child
+            else:
+                if node_info[c.name]['exact_date']:
+                    exact_dates += 1
+
+        parent['exact_date'] = exact_dates>0
+
+        parent["nmuts"] = len([m for m in n.mutations if m[-1] in 'ACGT']) if aln else np.round(n.branch_length*L)
+        d = [v['date'] for v in parent['tips'].values() if v['exact_date']]
+        parent["observations"] = len(d)
+        parent["avg_date"] = np.mean(d) if len(d) else 0.0
+        node_info[n.name] = parent
+
+    return node_info

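The `residual_filter` above flags a tip when its root-to-tip residual exceeds `n_iqd` interquartile distances (the `--clock-filter` value, now 4.0 by default). Stripped of the tree bookkeeping, the rule can be sketched with plain NumPy arrays. `flag_residual_outliers` is a hypothetical name, dates are assumed to be single numeric dates (TreeTime averages date ranges with `np.mean`), and the sketch does not reproduce TreeTime's exemption for children of the root.

```python
import numpy as np


def flag_residual_outliers(dist2root, dates, slope, intercept, n_iqd=4.0):
    """Flag tips whose root-to-tip residual exceeds n_iqd interquartile distances.

    Returns a boolean mask and the residuals in units of the interquartile distance.
    """
    dist2root = np.asarray(dist2root, dtype=float)
    dates = np.asarray(dates, dtype=float)
    # residual = observed divergence minus divergence predicted by the global clock
    residuals = dist2root - (slope * dates + intercept)
    iqd = np.percentile(residuals, 75) - np.percentile(residuals, 25)
    return np.abs(residuals) > n_iqd * iqd, residuals / iqd


dates = np.arange(2000, 2010, dtype=float)
noise = np.array([1e-4, -1e-4, 2e-4, -2e-4, 0.0, 0.05, 1e-4, -1e-4, 2e-4, -2e-4])
flags, scaled = flag_residual_outliers(0.001 * dates + noise, dates, slope=0.001, intercept=0.0)
print(flags)  # only the tip with the 0.05 excess divergence is flagged
```

Using the interquartile distance rather than the standard deviation keeps a single extreme tip from inflating the threshold that is supposed to catch it.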

=====================================
treetime/clock_tree.py
=====================================
@@ -1,6 +1,6 @@
 import numpy as np
 from . import config as ttconf
-from . import MissingDataError, UnknownMethodError
+from . import MissingDataError, UnknownMethodError, TreeTimeUnknownError
 from .treeanc import TreeAnc
 from .utils import numeric_date, DateConversion, datestring_from_numeric
 from .distribution import Distribution
@@ -24,7 +24,7 @@ class ClockTree(TreeAnc):
 
     def __init__(self, *args, dates=None, debug=False, real_dates=True, precision_fft = 'auto',
                 precision='auto', precision_branch='auto', branch_length_mode='joint', use_covariation=False,
-                use_fft=True,**kwargs):
+                use_fft=True, **kwargs):
 
         """
         ClockTree constructor
@@ -32,20 +32,20 @@ class ClockTree(TreeAnc):
         Parameters
         ----------
 
-         dates : dict
+        dates : dict
             :code:`{leaf_name:leaf_date}` dictionary
 
-         debug : bool
+        debug : bool
             If True, the debug mode is ON, which means no or less clean-up of
             obsolete parameters to control program execution in intermediate
             states. In debug mode, the python debugger is also allowed to interrupt
             program execution with an interactive shell if an error occurs.
 
-         real_dates : bool
+        real_dates : bool
             If True, some additional checks for the input dates sanity will be
             performed.
 
-         precision : int
+        precision : int
             Precision can be 0 (rough), 1 (default), 2 (fine), or 3 (ultra fine).
             This parameter determines the number of grid points that are used
             for the evaluation of the branch length interpolation objects.
@@ -59,11 +59,11 @@ class ClockTree(TreeAnc):
             The number of points desired to span the width of the FWHM of a distribution
             can be specified explicitly by precision_fft (default is 200).
 
-         branch_length_mode : str
+        branch_length_mode : str
             determines whether branch lengths are calculated using the 'joint' ML,
             'marginal' ML, or branch length of the input tree ('input').
 
-         use_covariation : bool
+        use_covariation : bool
             determines whether root-to-tip regression accounts for covariance
             introduced by shared ancestry.
 
@@ -142,7 +142,7 @@ class ClockTree(TreeAnc):
         if bad_branch_counter>self.tree.count_terminals()-3:
             raise MissingDataError("ERROR: ALMOST NO VALID DATE CONSTRAINTS")
 
-        self.logger("ClockTree._assign_dates: assigned date contraints to {} out of {} tips.".format(self.tree.count_terminals()-bad_branch_counter, self.tree.count_terminals()), 1)
+        self.logger("ClockTree._assign_dates: assigned date constraints to {} out of {} tips.".format(self.tree.count_terminals()-bad_branch_counter, self.tree.count_terminals()), 1)
         return ttconf.SUCCESS
 
 
@@ -272,7 +272,7 @@ class ClockTree(TreeAnc):
         self.date2dist = DateConversion.from_regression(self.clock_model)
 
 
-    def init_date_constraints(self, ancestral_inference=False, clock_rate=None, **kwarks):
+    def init_date_constraints(self, clock_rate=None, **kwarks):
         """
         Get the conversion coefficients between the dates and the branch
         lengths as they are used in ML computations. The conversion formula is
@@ -287,10 +287,6 @@ class ClockTree(TreeAnc):
         Parameters
         ----------
 
-         ancestral_inference: bool
-            If True, reinfer ancestral sequences
-            when ancestral sequences are missing
-
          clock_rate: float
             If specified, timetree optimization will be done assuming a
             fixed clock rate as specified
@@ -426,6 +422,7 @@ class ClockTree(TreeAnc):
                     # Cx.y is the branch length corresponding the optimal subtree
                     bl = node.branch_length_interpolator.x
                     x = bl + node.date_constraint.peak_pos
+                    # if a merger model is defined, add its (log) rate to the propagated distribution
                     if hasattr(self, 'merger_model') and self.merger_model:
                         node.joint_pos_Lx =  Distribution(x, -self.merger_model.integral_merger_rate(node.date_constraint.peak_pos)
                                                 + node.branch_length_interpolator(bl), min_width=self.min_width, is_log=True)
@@ -677,7 +674,6 @@ class ClockTree(TreeAnc):
         self.logger("ClockTree - Marginal reconstruction:  Propagating root -> leaves...", 2)
         from scipy.interpolate import interp1d
         for node in self.tree.find_clades(order='preorder'):
-
             ## If a delta constraint in known no further work required
             if (node.date_constraint is not None) and (not node.bad_branch) and node.date_constraint.is_delta:
                 node.marginal_pos_LH = node.date_constraint
@@ -703,7 +699,16 @@ class ClockTree(TreeAnc):
 
                     if hasattr(self, 'merger_model') and self.merger_model:
                         time_points = parent.marginal_pos_LH.x
-                        # As Lx do not include the node contribution this must be added on
+                        if len(time_points)<5:
+                            time_points = np.unique(np.concatenate([
+                                                        time_points,
+                                                        np.linspace(np.min([x.xmin for x in complementary_msgs]),
+                                                                    np.max([x.xmax for x in complementary_msgs]), 50),
+                                                        np.linspace(np.min([x.effective_support[0] for x in complementary_msgs]),
+                                                                    np.max([x.effective_support[1] for x in complementary_msgs]), 50),
+                                                        ]))
+                        # As Lx (the product of child messages) does not include the node contribution this must
+                        # be added to recover the full distribution of the parent node w/o contribution of the focal node.
                         complementary_msgs.append(self.merger_model.node_contribution(parent, time_points))
 
                         # Removed merger rate must be added back if no msgs from parent (equivalent to root node case)
@@ -738,6 +743,7 @@ class ClockTree(TreeAnc):
                 if node.marginal_pos_Lx is None:
                     node.marginal_pos_LH = node.msg_from_parent
                 else:
+                    #node.subtree_distribution contains merger model contribution of this node
                     node.marginal_pos_LH = NodeInterpolator.multiply((node.msg_from_parent, node.subtree_distribution))
 
                 self.logger('ClockTree._ml_t_root_to_leaves: computed convolution'
@@ -755,7 +761,6 @@ class ClockTree(TreeAnc):
                         plt.plot(msg_parent_to_node.x,msg_parent_to_node.y-msg_parent_to_node.peak_val, '-o')
                         plt.ylim(0,100)
                         plt.xlim(-0.05, 0.05)
-                        #import ipdb; ipdb.set_trace()
 
             # assign positions of nodes and branch length
             # note that marginal reconstruction can result in negative branch lengths
@@ -767,13 +772,26 @@ class ClockTree(TreeAnc):
             # construct the inverse cumulative distribution to evaluate confidence intervals
             if node.marginal_pos_LH.is_delta:
                 node.marginal_inverse_cdf=interp1d([0,1], node.marginal_pos_LH.peak_pos*np.ones(2), kind="linear")
+                node.marginal_cdf = interp1d(node.marginal_pos_LH.peak_pos*np.ones(2), [0,1], kind="linear")
             else:
                 dt = np.diff(node.marginal_pos_LH.x)
                 y = node.marginal_pos_LH.prob_relative(node.marginal_pos_LH.x)
                 int_y = np.concatenate(([0], np.cumsum(dt*(y[1:]+y[:-1])/2.0)))
-                int_y/=int_y[-1]
-                node.marginal_inverse_cdf = interp1d(int_y, node.marginal_pos_LH.x, kind="linear")
-                node.marginal_cdf = interp1d(node.marginal_pos_LH.x, int_y, kind="linear")
+                int_x = node.marginal_pos_LH.x
+                if int_y[-1] == 0:
+                    if len(dt)==0 or node.marginal_pos_LH.fwhm < 100*ttconf.TINY_NUMBER:
+                        ##delta function
+                        peak_idx = node.marginal_pos_LH._peak_idx
+                        int_y = np.concatenate((np.zeros(peak_idx), np.ones(len(node.marginal_pos_LH.x)-peak_idx)))
+                        if peak_idx == 0:
+                            int_y = np.concatenate(([0], int_y))
+                            int_x = np.concatenate(([int_x[0]- ttconf.TINY_NUMBER], int_x))
+                    else:
+                        raise TreeTimeUnknownError("Loss of probability in marginal time tree inference.")
+                else:
+                    int_y/=int_y[-1]
+                node.marginal_inverse_cdf = interp1d(int_y, int_x, kind="linear")
+                node.marginal_cdf = interp1d(int_x, int_y, kind="linear")
 
         if not self.debug:
             _cleanup()
@@ -844,6 +862,9 @@ class ClockTree(TreeAnc):
 
             rate_std = np.sqrt(self.clock_model['cov'][0,0])
 
+        if self.clock_model['slope']<0:
+            raise ValueError("ClockTree.calc_rate_susceptibility: rate estimate is negative. In this case the heuristic treetime uses to account for uncertainty in the rate estimate does not work. Please specify the clock-rate and its standard deviation explicitly via CLI parameters or arguments.")
+
         current_rate = np.abs(self.clock_model['slope'])
         upper_rate = self.clock_model['slope'] + rate_std
         lower_rate = max(0.1*current_rate, self.clock_model['slope'] - rate_std)
@@ -1017,8 +1038,8 @@ class ClockTree(TreeAnc):
                     interval = np.array([left(x), right(x)]).squeeze()
                     return (thres - np.diff(node.marginal_cdf(np.array(interval))))**2
 
-                # minimza and determine success
-                sol = minimize(func, bracket=[0,10], args=(fraction,))
+                # minimize and determine success
+                sol = minimize(func, bracket=[0,10], args=(fraction,), method='brent')
                 if sol['success']:
                     mutation_contribution = self.date2dist.to_numdate(np.array([right(sol['x']), left(sol['x'])]).squeeze())
                 else: # on failure, return standard confidence interval

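The new branch in the marginal reconstruction builds `node.marginal_cdf` and `node.marginal_inverse_cdf` from a trapezoidal cumulative integral and, when that integral is zero, falls back to a step function at the peak instead of dividing by zero. A minimal standalone sketch, assuming the density is given as sampled `x`/`y` arrays: `build_cdf` is a hypothetical name, the total spread of `x` stands in for TreeTime's FWHM test, and `ValueError` replaces `TreeTimeUnknownError`.

```python
import numpy as np
from scipy.interpolate import interp1d


def build_cdf(x, y, tiny=1e-12):
    """Return (cdf, inverse_cdf) interpolators for a density sampled at points x."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dt = np.diff(x)
    # cumulative trapezoidal integral, starting at 0
    int_y = np.concatenate(([0.0], np.cumsum(dt * (y[1:] + y[:-1]) / 2.0)))
    int_x = x
    if int_y[-1] == 0:
        if len(dt) == 0 or (x[-1] - x[0]) < 100 * tiny:
            # effectively a delta function: step from 0 to 1 at the peak
            peak_idx = int(np.argmax(y))
            int_y = np.concatenate((np.zeros(peak_idx), np.ones(len(x) - peak_idx)))
            if peak_idx == 0:
                # add a point just left of the peak so the step has two support points
                int_y = np.concatenate(([0.0], int_y))
                int_x = np.concatenate(([x[0] - tiny], x))
        else:
            raise ValueError("Loss of probability: density integrates to zero")
    else:
        int_y = int_y / int_y[-1]
    return interp1d(int_x, int_y, kind="linear"), interp1d(int_y, int_x, kind="linear")


cdf, inv = build_cdf([0.0, 1.0, 2.0], [1.0, 1.0, 1.0])
print(float(cdf(1.0)), float(inv(0.5)))  # 0.5 1.0
```

A single-point distribution (the `len(dt)==0` case) now yields a usable two-point CDF, while a spread-out density with zero mass is reported as an error rather than silently normalized by zero.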

=====================================
treetime/distribution.py
=====================================
@@ -1,12 +1,11 @@
 import numpy as np
+from . import TreeTimeUnknownError
 from scipy.interpolate import interp1d
 try:
     from collections.abc import Iterable
 except ImportError:
     from collections import Iterable
-from copy import deepcopy as make_copy
-from scipy.ndimage import binary_dilation
-from .config import BIG_NUMBER, MIN_LOG, MIN_INTEGRATION_PEAK, TINY_NUMBER, SUPERTINY_NUMBER
+from .config import BIG_NUMBER, MIN_INTEGRATION_PEAK, TINY_NUMBER
 from .utils import clip
 
 class Distribution(object):
@@ -101,35 +100,36 @@ class Distribution(object):
             new_xmax = np.min([k.xmax for k in dists])
 
             x_vals = np.unique(np.concatenate([k.x for k in dists]))
-            x_vals = x_vals[(x_vals> new_xmin-TINY_NUMBER)&(x_vals< new_xmax+TINY_NUMBER)]
+            x_vals = x_vals[(x_vals > new_xmin - TINY_NUMBER)&(x_vals < new_xmax + TINY_NUMBER)]
             n_dists = len(dists)
+            # reduce the number of points if there are many distributions
             if len(x_vals)>100*n_dists and n_dists>3:
+                # make sure there are at least 3 points per distribution on average
                 n_bins = len(x_vals)//n_dists - 6
                 lower_cut_off = n_dists*3
                 upper_cut_off = n_dists*(n_bins + 3)
+                # use peripheral points from the original array, average the center
                 x_vals = np.concatenate((x_vals[:lower_cut_off],
                                          x_vals[lower_cut_off:upper_cut_off].reshape((-1,n_dists)).mean(axis=1),
                                          x_vals[upper_cut_off:]))
+            # evaluate the function at the consolidated lists of x-values
             y_vals = np.sum([k.__call__(x_vals) for k in dists], axis=0)
             try:
                 peak = y_vals.min()
             except:
-                print("WARNING: Unexpected behavior detected in multiply function,"
-                        "if you see this error \n please let us know by filling an issue at: https://github.com/neherlab/treetime/issues")
-                x_vals = [0,1]
-                y_vals = [BIG_NUMBER,BIG_NUMBER]
-                res = Distribution(x_vals, y_vals, is_log=True,
-                                    min_width=min_width, kind='linear')
-                return res
+                raise TreeTimeUnknownError("Error: Unexpected behavior detected in multiply function"
+                        " when determining peak of function with y-values '"+ str(y_vals) + "'.\n\n"
+                        "If you see this error please let us know by filing an issue at: \n"
+                        "https://github.com/neherlab/treetime/issues")
+
+            # remove data points exp(-1000) less likely than the peak
             ind = (y_vals-peak)<BIG_NUMBER/1000
             n_points = ind.sum()
             if n_points == 0:
-                print("WARNING: Unexpected behavior detected in multipy function,"
-                        "if you see this error \n please let us know by filling an issue at: https://github.com/neherlab/treetime/issues")
-                x_vals = [0,1]
-                y_vals = [BIG_NUMBER,BIG_NUMBER]
-                res = Distribution(x_vals, y_vals, is_log=True,
-                                   min_width=min_width, kind='linear')
+                raise TreeTimeUnknownError("Error: Unexpected behavior detected in multiply function. "
+                        "No valid points left after reducing to plausible region.\n\n"
+                        "If you see this error please let us know by filing an issue at:\n"
+                        "https://github.com/neherlab/treetime/issues")
             elif n_points == 1:
                 res = Distribution.delta_function(x_vals[0])
             else:
@@ -159,7 +159,7 @@ class Distribution(object):
         ind = (y_vals-peak)<BIG_NUMBER/1000
         n_points = ind.sum()
         if n_points == 0:
-            print("WARNING: Unexpected behavior detected in multipy function,"
+            print("WARNING: Unexpected behavior detected in multiply function,"
                     "if you see this error \n please let us know by filling an issue at: https://github.com/neherlab/treetime/issues")
             x_vals = [0,1]
             y_vals = [BIG_NUMBER,BIG_NUMBER]
@@ -271,7 +271,7 @@ class Distribution(object):
     @property
     def y(self):
         if self.is_delta:
-            print("THIS SHOULDN'T BE CALLED ON A DELTA FUNCTION")
+            print("Warning: evaluating log probability of a delta distribution.")
             return [self.weight]
         else:
             return self._peak_val + self._func.y
@@ -313,7 +313,6 @@ class Distribution(object):
         Assess the interval on which the value of self is higher than cutoff
         relative to its peak
         """
-        from scipy.optimize import brentq
         log_cutoff = -np.log(cutoff)
         vals = log_cutoff - self.__call__(self.x) + self.peak_val
         above = vals > 0
@@ -345,7 +344,7 @@ class Distribution(object):
 
     def _adjust_grid(self, rel_tol=0.01, yc=10):
         n_iter=0
-        while len(self.y)>200 and n_iter<5:
+        while len(self.x)>200 and n_iter<5:
             interp_err = 2*self.y[1:-1] - self.y[2:] - self.y[:-2]
             ind = np.ones_like(self.y, dtype=bool)
             dy = self.y-self.peak_val
@@ -434,13 +433,24 @@ class Distribution(object):
 
     def fft(self, T, n=None, inverse_time=True):
         if self.is_delta:
-            import ipdb; ipdb.set_trace()
+            raise TreeTimeUnknownError("attempting Fourier transform of delta function.")
+
         from numpy.fft import rfft
         if n is None:
             n=len(T)
+
+        vals = self.prob_relative(T)
+        if max(vals)<1e-15:
+            # all relative probabilities underflow when the grid T only samples
+            # points next to a sharp, delta-like peak: since we interpolate
+            # logarithms, the peak value itself is missed. Rescale by the minimum
+            # negative log-value on T to recover a meaningful peak.
+            log_vals = self.__call__(T)
+            vals = np.exp(-(log_vals - log_vals.min()))
+
         if inverse_time:
-            return rfft(self.prob_relative(T), n=n)
+            return rfft(vals, n=n)
         else:
-            return rfft(self.prob_relative(T)[::-1], n=n)
+            return rfft(vals[::-1], n=n)
 
 

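The `fft` change above rescales negative log-probabilities by their minimum on the grid before exponentiating. A minimal standalone sketch of that idea (not TreeTime code; `relative_probabilities` is a hypothetical helper), assuming values are stored as negative log-probabilities, as `Distribution.__call__` returns them:

```python
import numpy as np

def relative_probabilities(neg_log_vals):
    """Turn negative log-probabilities into probabilities relative to the best grid point.

    Subtracting the minimum (the most likely point on the grid) before
    exponentiating maps that point to exactly 1.0, so the result cannot
    underflow to all zeros.
    """
    neg_log_vals = np.asarray(neg_log_vals, dtype=float)
    return np.exp(-(neg_log_vals - neg_log_vals.min()))

# Grid points far from a sharp peak: the true probabilities are all
# smaller than the smallest positive double (about 4.9e-324)
neg_log = np.array([800.0, 801.0, 805.0])
naive = np.exp(-neg_log)                    # underflows to [0., 0., 0.]
rescaled = relative_probabilities(neg_log)  # [1., exp(-1), exp(-5)]
```

Only the relative shape survives this rescaling, which is all an FFT-based convolution needs if the overall scale is tracked separately.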

=====================================
treetime/gtr.py
=====================================
@@ -172,6 +172,56 @@ class GTR(object):
 
         return eq_freq_str + W_str + Q_str
 
+    @staticmethod
+    def from_file(gtr_fname):
+        """
+        Read a GTR model from a text file and return it as a new GTR object.
+        The file is expected to be formatted exactly like the output of the `__str__` method.
+
+        Parameters
+        ----------
+
+            gtr_fname : str
+            Path to a file containing the string representation of the GTR model
+
+        """
+        try:
+            with open(gtr_fname) as f:
+                alphabet = []
+                pi = []
+                while True:
+                    line = f.readline()
+                    if not line:
+                        break
+                    if line.strip().startswith("Substitution rate (mu):"):
+                        mu = float(line.split(":")[1].strip())
+                    elif line.strip().startswith("Equilibrium frequencies (pi_i):"):
+                        line = f.readline()
+                        while line.strip()!="":
+                            alphabet.append(line.split(":")[0].strip())
+                            pi.append(float(line.split(":")[1].strip()))
+                            line = f.readline()
+                        if not np.any([len(alphabet) == len(a) and np.all(np.array(alphabet) == a) for a in alphabets.values()]):
+                            raise ValueError("GTR: was unable to read custom GTR model in "+str(gtr_fname) +" - Alphabet not recognized")
+                    elif line.strip().startswith("Symmetrized rates from j->i (W_ij):"):
+                        line = f.readline()
+                        line = f.readline()
+                        n = len(pi)
+                        W = np.ones((n,n))
+                        j = 0
+                        while line.strip()!="":
+                            values = line.split()
+                            for i in range(n):
+                                W[j,i] = float(values[i+1])
+                            j +=1
+                            line = f.readline()
+                        if j != n:
+                            raise ValueError("GTR: was unable to read custom GTR model in "+str(gtr_fname) +" - Number of lines in W matrix does not match alphabet length")
+                gtr = GTR.custom(mu, pi, W, alphabet = alphabet)
+                return gtr
+        except:
+            raise MissingDataError('GTR: was unable to read custom GTR model in '+str(gtr_fname))
+
 
     def assign_rates(self, mu=1.0, pi=None, W=None):
         """
@@ -405,7 +455,7 @@ class GTR(object):
 
 
     @classmethod
-    def random(cls, mu=1.0, alphabet='nuc'):
+    def random(cls, mu=1.0, alphabet='nuc', rng=None):
         """
         Creates a random GTR model
 
@@ -420,12 +470,14 @@ class GTR(object):
 
 
         """
+        if rng is None:
+            rng = np.random.default_rng()
 
         alphabet=alphabets[alphabet]
         gtr = cls(alphabet)
         n = gtr.alphabet.shape[0]
-        pi = 1.0*np.random.randint(0,100,size=(n))
-        W = 1.0*np.random.randint(0,100,size=(n,n)) # with gaps
+        pi = 1.0*rng.integers(0,100,size=(n))
+        W = 1.0*rng.integers(0,100,size=(n,n)) # with gaps
 
         gtr.assign_rates(mu=mu, pi=pi, W=W)
         return gtr
@@ -493,14 +545,13 @@ class GTR(object):
             pi = np.copy(fixed_pi)
         pi/=pi.sum()
         W_ij = np.ones_like(nij)
-        mu = nij.sum()/Ti.sum()
+        mu = (nij.sum()+pc)/(Ti.sum()+pc)
         # if pi is fixed, this will immediately converge
         while LA.norm(pi_old-pi) > dp and count < Nit:
             gtr.logger(' '.join(map(str, ['GTR inference iteration',count,'change:',LA.norm(pi_old-pi)])), 3)
             count += 1
             pi_old = np.copy(pi)
-            W_ij = (nij+nij.T+2*pc_mat)/mu/(np.outer(pi,Ti) + np.outer(Ti,pi)
-                                                    + ttconf.TINY_NUMBER + 2*pc_mat)
+            W_ij = (nij+nij.T+2*pc_mat)/mu/(np.outer(pi,Ti) + np.outer(Ti,pi) + ttconf.TINY_NUMBER + 2*pc_mat)
 
             np.fill_diagonal(W_ij, 0)
             scale_factor = avg_transition(W_ij,pi, gap_index=gtr.gap_index)
@@ -509,9 +560,9 @@ class GTR(object):
             if fixed_pi is None:
                 pi = (np.sum(nij+pc_mat,axis=1)+root_state)/(ttconf.TINY_NUMBER + mu*np.dot(W_ij,Ti)+root_state.sum()+np.sum(pc_mat, axis=1))
                 pi /= pi.sum()
-                mu = nij.sum()/(ttconf.TINY_NUMBER + np.sum(pi * (W_ij.dot(Ti))))
+                mu = (nij.sum() + pc)/(np.sum(pi * (W_ij.dot(Ti)))+pc)
             else:
-                mu = nij.sum()/(ttconf.TINY_NUMBER + np.sum(pi * (W_ij.dot(pi)))*Ti.sum())
+                mu = (nij.sum() + pc)/(np.sum(pi * (W_ij.dot(pi)))*Ti.sum() + pc)
 
         if count >= Nit:
             gtr.logger('WARNING: maximum number of iterations has been reached in GTR inference',3, warn=True)
@@ -800,11 +851,12 @@ class GTR(object):
             else:
                 return -1.0*self.prob_t_compressed(seq_pair, multiplicity,t**2, return_log=True)
 
+        hamming_distance = np.sum(multiplicity[seq_pair[:,1]!=seq_pair[:,0]])/np.sum(multiplicity)
         try:
             from scipy.optimize import minimize_scalar
             opt = minimize_scalar(_neg_prob,
-                    bounds=[-np.sqrt(ttconf.MAX_BRANCH_LENGTH),np.sqrt(ttconf.MAX_BRANCH_LENGTH)],
-                    args=(seq_pair, multiplicity), tol=tol)
+                    bracket=[-np.sqrt(ttconf.MAX_BRANCH_LENGTH), np.sqrt(hamming_distance), np.sqrt(ttconf.MAX_BRANCH_LENGTH)],
+                    args=(seq_pair, multiplicity), tol=tol, method='brent')
             new_len = opt["x"]**2
             if 'success' not in opt:
                 opt['success'] = True
@@ -824,7 +876,7 @@ class GTR(object):
 
         if opt["success"] != True:
             # return hamming distance: number of state pairs where state differs/all pairs
-            new_len =  np.sum(multiplicity[seq_pair[:,1]!=seq_pair[:,0]])/np.sum(multiplicity)
+            new_len =  hamming_distance
 
         return new_len
 

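`optimal_t_compressed` now minimizes over x = sqrt(t) with Brent's method and a three-point bracket whose middle point is the Hamming distance. A standalone sketch under simplifying assumptions: a Jukes-Cantor (JC69) likelihood replaces the general GTR likelihood, and `MAX_BRANCH_LENGTH` is a hypothetical stand-in for `ttconf.MAX_BRANCH_LENGTH`. Because the objective depends only on x**2, the two outer bracket points have equal values, and the Hamming distance usually gives a lower value in the middle, which is what a three-point Brent bracket requires.

```python
import numpy as np
from scipy.optimize import minimize_scalar

MAX_BRANCH_LENGTH = 4.0  # hypothetical cap standing in for ttconf.MAX_BRANCH_LENGTH

def neg_log_lh(x, n_diff, n_sites):
    """JC69 negative log-likelihood of a sequence pair at branch length t = x**2."""
    t = x**2  # squaring keeps t >= 0 without needing a bounded optimizer
    decay = np.exp(-4.0 * t / 3.0)
    p_same = 0.25 + 0.75 * decay
    # probability of one specific other base; the floor guards against log(0) at t = 0
    p_diff = np.maximum(0.25 - 0.25 * decay, 1e-300)
    return -(n_sites - n_diff) * np.log(p_same) - n_diff * np.log(p_diff)

n_diff, n_sites = 10, 100
hamming_distance = n_diff / n_sites

# symmetric outer points, Hamming distance as the (usually lower) middle point
sol = minimize_scalar(
    neg_log_lh,
    bracket=[-np.sqrt(MAX_BRANCH_LENGTH), np.sqrt(hamming_distance), np.sqrt(MAX_BRANCH_LENGTH)],
    args=(n_diff, n_sites),
    method="brent",
)
# fall back to the Hamming distance on failure, as the patched code does
t_hat = sol.x**2 if sol.get("success", True) else hamming_distance
```

For these counts the result agrees with the closed-form JC69 estimate t = -3/4 ln(1 - 4p/3) with p = 0.1.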

=====================================
treetime/gtr_site_specific.py
=====================================
@@ -105,7 +105,7 @@ class GTR_site_specific(GTR):
 
     @classmethod
     def random(cls, L=1, avg_mu=1.0, alphabet='nuc', pi_dirichlet_alpha=1,
-               W_dirichlet_alpha=3.0, mu_gamma_alpha=3.0):
+               W_dirichlet_alpha=3.0, mu_gamma_alpha=3.0, rng=None):
         """
         Creates a random GTR model
 
@@ -129,28 +129,29 @@ class GTR_site_specific(GTR):
         GTR_site_specific
             model with randomly sampled frequencies
         """
+        if rng is None:
+            rng = np.random.default_rng()
 
-        from scipy.stats import gamma
         alphabet=alphabets[alphabet]
         gtr = cls(alphabet=alphabet, seq_len=L)
         n = gtr.alphabet.shape[0]
 
         # Dirichlet distribution == l_1 normalized vector of samples of the Gamma distribution
         if pi_dirichlet_alpha:
-            pi = 1.0*gamma.rvs(pi_dirichlet_alpha, size=(n,L))
+            pi = 1.0*rng.gamma(pi_dirichlet_alpha, size=(n,L))
         else:
             pi = np.ones((n,L))
 
         pi /= pi.sum(axis=0)
         if W_dirichlet_alpha:
-            tmp = 1.0*gamma.rvs(W_dirichlet_alpha, size=(n,n))
+            tmp = 1.0*rng.gamma(W_dirichlet_alpha, size=(n,n))
         else:
             tmp = np.ones((n,n))
         tmp = np.tril(tmp,k=-1)
         W = tmp + tmp.T
 
         if mu_gamma_alpha:
-            mu = gamma.rvs(mu_gamma_alpha, size=(L,))
+            mu = rng.gamma(mu_gamma_alpha, size=(L,))
         else:
             mu = np.ones(L)
 

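The `random` constructor now draws from a NumPy `Generator` (`rng.gamma`) instead of `scipy.stats.gamma.rvs`, so passing a seeded generator makes the sampled model reproducible. A minimal sketch of the normalized-Gamma construction named in the code comment; `sample_site_frequencies` is a hypothetical helper, not part of TreeTime:

```python
import numpy as np

def sample_site_frequencies(alpha, n_states, n_sites, rng=None):
    """Draw one Dirichlet(alpha, ..., alpha) frequency vector per site (one column each)."""
    if rng is None:
        rng = np.random.default_rng()
    # Gamma(shape=alpha, scale=1) samples normalized to sum to one are Dirichlet distributed
    samples = rng.gamma(alpha, size=(n_states, n_sites))
    return samples / samples.sum(axis=0)

# the same seed yields the same frequencies
pi_a = sample_site_frequencies(1.0, 4, 10, rng=np.random.default_rng(42))
pi_b = sample_site_frequencies(1.0, 4, 10, rng=np.random.default_rng(42))
```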

=====================================
treetime/merger_models.py
=====================================
@@ -255,13 +255,13 @@ class Coalescent(object):
         '''
         from scipy.optimize import minimize_scalar
         initial_Tc = self.Tc
-        def cost(Tc):
-            self.set_Tc(Tc)
+        def cost(logTc):
+            self.set_Tc(np.exp(logTc))
             return -self.total_LH()
 
-        sol = minimize_scalar(cost, bounds=[ttconf.TINY_NUMBER,10.0])
+        sol = minimize_scalar(cost, bracket=[-20.0, 2.0], method='brent')
         if "success" in sol and sol["success"]:
-            self.set_Tc(sol['x'])
+            self.set_Tc(np.exp(sol['x']))
         else:
             self.logger("merger_models:optimize_Tc: optimization of coalescent time scale failed: " + str(sol), 0, warn=True)
             self.set_Tc(initial_Tc.y, T=initial_Tc.x)
@@ -388,3 +388,8 @@ class Coalescent(object):
             return skyline, conf
         else:
             return skyline, None
+
+
+
+
+

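`optimize_Tc` now optimizes log(Tc) with Brent's method instead of Tc within fixed bounds, which keeps Tc positive and lets the search span many orders of magnitude. A toy sketch of the same reparametrization, assuming a made-up cost (the negative log-likelihood of exponential waiting times with mean Tc, whose optimum is the sample mean) in place of TreeTime's coalescent likelihood:

```python
import numpy as np
from scipy.optimize import minimize_scalar

waiting_times = np.array([0.5, 1.0, 1.5])  # toy data; the MLE of Tc is their mean, 1.0

def cost(logTc):
    # exponential NLL n*log(Tc) + sum(t)/Tc, written in terms of logTc
    return len(waiting_times) * logTc + waiting_times.sum() * np.exp(-logTc)

# two-point bracket: SciPy starts a downhill bracket search from these points
sol = minimize_scalar(cost, bracket=[-20.0, 2.0], method="brent")
Tc_hat = np.exp(sol.x)  # map back; Tc is positive by construction
```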

=====================================
treetime/node_interpolator.py
=====================================
@@ -1,5 +1,6 @@
 import numpy as np
 from . import config as ttconf
+from . import TreeTimeUnknownError
 from .distribution import Distribution
 from .utils import clip
 from .config import FFT_FWHM_GRID_SIZE
@@ -164,64 +165,89 @@ class NodeInterpolator (Distribution):
         dt = max(branch_interp.one_mutation*0.005, min(node_interp.fwhm, branch_interp.fwhm)/fft_grid_size)
         b_effsupport = branch_interp.effective_support
         n_effsupport = node_interp.effective_support
-
-        tmax = 2*max(b_effsupport[1]-b_effsupport[0], n_effsupport[1]-n_effsupport[0])
-
-        Tb = np.arange(b_effsupport[0], b_effsupport[0] + tmax + dt, dt)
-        if inverse_time:
-            Tn = np.arange(n_effsupport[0], n_effsupport[0] + tmax + dt, dt)
-            Tmin = node_interp.xmin
-            Tmax = ttconf.MAX_BRANCH_LENGTH
-        else:
-            Tn = np.arange(n_effsupport[1] - tmax, n_effsupport[1] + dt, dt)
-            Tmin = -ttconf.MAX_BRANCH_LENGTH
-            Tmax = node_interp.xmax
-
-        raw_len = len(Tb)
-        fft_len = 2*raw_len
-
-        fftb = branch_interp.fft(Tb, n=fft_len)
-        fftn = node_interp.fft(Tn, n=fft_len, inverse_time=inverse_time)
-        if inverse_time:
-            fft_res = np.fft.irfft(fftb*fftn, fft_len)[:raw_len]
-            Tres = Tn + Tb[0]
+        b_support_range = b_effsupport[1]-b_effsupport[0]
+        n_support_range = n_effsupport[1]-n_effsupport[0]
+        # compare the support of the node distribution to the width of the branch length distribution
+        ratio = n_support_range/branch_interp.fwhm
+
+        if ratio < 1.0/fft_grid_size and 4.0*dt > node_interp.fwhm:
+            # The node distribution is much narrower than the branch distribution: proceed as if
+            # the node distribution were a delta distribution, placed at most 4 full widths at half
+            # maximum from its nominal peak so that the relevant range is not sliced to zero.
+            log_scale_node_interp = node_interp.integrate(return_log=True, a=node_interp.xmin,b=node_interp.xmax,n=max(100, len(node_interp.x))) #probability of node distribution
+            if inverse_time:
+                x = branch_interp.x + max(n_effsupport[0], node_interp._peak_pos - 4.0*node_interp.fwhm)
+                dist = Distribution(x, branch_interp(x - node_interp._peak_pos) - log_scale_node_interp,
+                                    min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
+            else:
+                x = - branch_interp.x + min(n_effsupport[1], node_interp._peak_pos + 4.0*node_interp.fwhm)
+                dist = Distribution(x, branch_interp(branch_interp.x) - log_scale_node_interp,
+                                    min_width=max(node_interp.min_width, branch_interp.min_width), is_log=True)
+            return dist
+        elif ratio > fft_grid_size and 4*dt > branch_interp.fwhm:
+            raise ValueError("ERROR: Unexpected behavior: branch distribution is much narrower than the node distribution.")
         else:
-            fft_res = np.fft.irfft(fftb*fftn, fft_len)[::-1]
-            fft_res = fft_res[raw_len:]
-            Tres = Tn - Tb[0]
-
-        # determine region in which we can trust the FFT convolution and avoid
-        # inaccuracies due to machine precision. 1e-13 seems robust
-        ind = fft_res>fft_res.max()*1e-13
-        res = -np.log(fft_res[ind]) + branch_interp.peak_val + node_interp.peak_val - np.log(dt)
-        Tres_cropped = Tres[ind]
-
-        # extrapolate the tails exponentially: use margin last data points
-        margin = np.minimum(3, Tres_cropped.shape[0]//3)
-        if margin<1 or len(res)==0:
-            import ipdb; ipdb.set_trace()
-        else:
-            left_slope = (res[margin]-res[0])/(Tres_cropped[margin]-Tres_cropped[0])
-            right_slope = (res[-1]-res[-margin-1])/(Tres_cropped[-1]-Tres_cropped[-margin-1])
-
-        # only extrapolate on the left when the slope is negative and we are not on the boundary
-        if Tmin<Tres_cropped[0] and left_slope<0:
-            Tleft = np.linspace(Tmin, Tres_cropped[0],10)[:-1]
-            res_left = res[0] + left_slope*(Tleft - Tres_cropped[0])
-        else:
-            Tleft, res_left = [], []
-
-        # only extrapolate on the right when the slope is positive and we are not on the boundary
-        if Tres_cropped[-1]<Tmax and right_slope>0:
-            Tright = np.linspace(Tres_cropped[-1], Tmax,10)[1:]
-            res_right = res[-1] + right_slope*(Tright - Tres_cropped[-1])
-        else: #otherwise
-            Tright, res_right = [], []
-
-        # instantiate the new interpolation object and return
-        return cls(np.concatenate((Tleft,Tres_cropped,Tright)),
-                   np.concatenate((res_left, res, res_right)),
-                   is_log=True, kind='linear', assume_sorted=True)
+            tmax = 2*max(b_support_range, n_support_range)
+
+            Tb = np.arange(b_effsupport[0], b_effsupport[0] + tmax + dt, dt)
+            if inverse_time:
+                Tn = np.arange(n_effsupport[0], n_effsupport[0] + tmax + dt, dt)
+                Tmin = node_interp.xmin
+                Tmax = ttconf.MAX_BRANCH_LENGTH
+            else:
+                Tn = np.arange(n_effsupport[1] - tmax, n_effsupport[1] + dt, dt)
+                Tmin = -ttconf.MAX_BRANCH_LENGTH
+                Tmax = node_interp.xmax
+
+            raw_len = len(Tb)
+            fft_len = 2*raw_len
+
+            fftb = branch_interp.fft(Tb, n=fft_len)
+            fftn = node_interp.fft(Tn, n=fft_len, inverse_time=inverse_time)
+            if inverse_time:
+                fft_res = np.fft.irfft(fftb*fftn, fft_len)[:raw_len]
+                Tres = Tn + Tb[0]
+            else:
+                fft_res = np.fft.irfft(fftb*fftn, fft_len)[::-1]
+                fft_res = fft_res[raw_len:]
+                Tres = Tn - Tb[0]
+
+            # determine region in which we can trust the FFT convolution and avoid
+            # inaccuracies due to machine precision. 1e-13 seems robust
+            ind = fft_res>fft_res.max()*1e-13
+            res = -np.log(fft_res[ind]) + branch_interp.peak_val + node_interp.peak_val - np.log(dt)
+
+            Tres_cropped = Tres[ind]
+
+            # extrapolate the tails exponentially, using `margin` data points at each end
+            margin = np.minimum(3, Tres_cropped.shape[0]//3)
+            if margin<1 or len(res)==0:
+                raise TreeTimeUnknownError("Error: Unexpected behavior detected in FFT function. "
+                        "No valid points left after reducing to plausible region.\n\n"
+                        "If you see this error please let us know by filing an issue at:\n"
+                        "https://github.com/neherlab/treetime/issues")
+            else:
+                left_slope = (res[margin]-res[0])/(Tres_cropped[margin]-Tres_cropped[0])
+                right_slope = (res[-1]-res[-margin-1])/(Tres_cropped[-1]-Tres_cropped[-margin-1])
+
+            # only extrapolate on the left when the slope is negative and we are not on the boundary
+            if Tmin<Tres_cropped[0] and left_slope<0:
+                Tleft = np.linspace(Tmin, Tres_cropped[0],10)[:-1]
+                res_left = res[0] + left_slope*(Tleft - Tres_cropped[0])
+            else:
+                Tleft, res_left = [], []
+
+            # only extrapolate on the right when the slope is positive and we are not on the boundary
+            if Tres_cropped[-1]<Tmax and right_slope>0:
+                Tright = np.linspace(Tres_cropped[-1], Tmax,10)[1:]
+                res_right = res[-1] + right_slope*(Tright - Tres_cropped[-1])
+            else: #otherwise
+                Tright, res_right = [], []
+
+            # instantiate the new interpolation object and return
+            return cls(np.concatenate((Tleft,Tres_cropped,Tright)),
+                    np.concatenate((res_left, res, res_right)),
+                    is_log=True, kind='linear', assume_sorted=True)
 
     @classmethod
     def convolve(cls, node_interp, branch_interp, max_or_integral='integral',

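In its FFT branch, `convolve_fft` multiplies the transforms of the two sampled distributions, zero-padded to twice the grid length, and rescales by the grid spacing `dt`. A standalone sketch of that core step for two densities sampled on a common grid starting at zero. `fft_convolve_densities` is hypothetical, and it omits the log-scale bookkeeping, time reversal, and tail extrapolation of the real method:

```python
import numpy as np

def fft_convolve_densities(p, q, dt):
    """Approximate the convolution (p * q)(t) on the grid of p and q via the FFT."""
    n = len(p)
    fft_len = 2 * n  # zero-padding prevents circular wrap-around of the tails
    prod = np.fft.rfft(p, n=fft_len) * np.fft.rfft(q, n=fft_len)
    return np.fft.irfft(prod, fft_len)[:n] * dt  # Riemann sum: scale by the grid spacing

dt = 0.002
t = np.arange(0.0, 10.0, dt)
exp_density = np.exp(-t)  # Exp(1) density sampled on the grid
conv = fft_convolve_densities(exp_density, exp_density, dt)
# the exact convolution of two Exp(1) densities is the Gamma(2, 1) density t*exp(-t)
```

The discretization error of this left-aligned Riemann sum is of order `dt`, which is why the real method chooses `dt` relative to the widths of both distributions.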

=====================================
treetime/seq_utils.py
=====================================
@@ -223,7 +223,7 @@ def seq2prof(seq, profile_map):
     return np.array([profile_map[k] for k in seq])
 
 
-def prof2seq(profile, gtr, sample_from_prof=False, normalize=True):
+def prof2seq(profile, gtr, sample_from_prof=False, normalize=True, rng=None):
     """
     Convert profile to sequence and normalize profile across sites.
 
@@ -247,7 +247,8 @@ def prof2seq(profile, gtr, sample_from_prof=False, normalize=True):
      idx : numpy.array
         Indices chosen from profile as array of length L
     """
-
+    if rng is None:
+        rng = np.random.default_rng()
     # normalize profile such that probabilities at each site sum to one
     if normalize:
         tmp_profile, pre=normalize_profile(profile, return_offset=False)
@@ -258,7 +259,7 @@ def prof2seq(profile, gtr, sample_from_prof=False, normalize=True):
     # (sampling from cumulative distribution over the different states)
     if sample_from_prof:
         cumdis = tmp_profile.cumsum(axis=1).T
-        randnum = np.random.random(size=cumdis.shape[1])
+        randnum = rng.random(size=cumdis.shape[1])
         idx = np.argmax(cumdis>=randnum, axis=0)
     else:
         idx = tmp_profile.argmax(axis=1)

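`prof2seq` samples one state per site by comparing a uniform draw with the cumulative probabilities in each row of the profile; the draw now comes from the `rng` argument. A minimal sketch of that step (`sample_states` is hypothetical), assuming `profile` has shape (L, n_states) with rows that sum to one:

```python
import numpy as np

def sample_states(profile, rng=None):
    """Sample one state index per site (row) of a normalized profile."""
    if rng is None:
        rng = np.random.default_rng()
    cumdis = profile.cumsum(axis=1).T            # shape (n_states, L)
    randnum = rng.random(size=cumdis.shape[1])   # one uniform draw in [0, 1) per site
    # first state whose cumulative probability reaches the draw
    return np.argmax(cumdis >= randnum, axis=0)

profile = np.array([[0.1, 0.2, 0.3, 0.4],
                    [0.7, 0.1, 0.1, 0.1],
                    [0.0, 0.0, 1.0, 0.0]])
idx_a = sample_states(profile, rng=np.random.default_rng(7))
idx_b = sample_states(profile, rng=np.random.default_rng(7))  # same seed, same states
```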

=====================================
treetime/seqgen.py
=====================================
@@ -33,7 +33,7 @@ class SeqGen(TreeAnc):
         """
         cum_p = p.cumsum(axis=1).T
 
-        prand = np.random.random(self.seq_len)
+        prand = self.rng.random(self.seq_len)
         seq = self.gtr.alphabet[np.argmax(cum_p>prand, axis=0)]
         return seq
 


=====================================
treetime/treeanc.py
=====================================
@@ -3,7 +3,6 @@ import gc
 import numpy as np
 from Bio import Phylo
 from Bio.Phylo.BaseTree import Clade
-from Bio import AlignIO
 from . import config as ttconf
 from . import MissingDataError,UnknownMethodError
 from .seq_utils import seq2prof, prof2seq, normalize_profile, extend_profile
@@ -55,7 +54,7 @@ class TreeAnc(object):
                 ref=None, verbose = ttconf.VERBOSE, ignore_gaps=True,
                 convert_upper=True, seq_multiplicity=None, log=None,
                 compress=True, seq_len=None, ignore_missing_alns=False,
-                keep_node_order=False, **kwargs):
+                keep_node_order=False, rng_seed=None, **kwargs):
         """
         TreeAnc constructor. It prepares the tree, attaches sequences to the leaf nodes,
         and sets some configuration parameters.
@@ -143,7 +142,7 @@ class TreeAnc(object):
         self.sequence_reconstruction = None
         self.ignore_missing_alns = ignore_missing_alns
         self.keep_node_order = keep_node_order
-
+        self.rng = np.random.default_rng(seed=rng_seed)
         self._tree = None
         self.tree = tree
         if tree is None:
@@ -513,13 +512,6 @@ class TreeAnc(object):
 
         self.logger("TreeAnc.infer_ancestral_sequences with method: %s, %s"%(method, 'marginal' if marginal else 'joint'), 1)
 
-        if not reconstruct_tip_states:
-            self.logger("WARNING: Previous versions of TreeTime (<0.7.0) RECONSTRUCTED sequences"
-                        " of tips at positions with AMBIGUOUS bases. This resulted in"
-                        " unexpected behavior is some cases and is no longer done by default."
-                        " If you want to replace those ambiguous sites with their most likely state,"
-                        " rerun with `reconstruct_tip_states=True` or `--reconstruct-tip-states`.", 0, warn=True, only_once=True)
-
         if method.lower() in ['ml', 'probabilistic']:
             if marginal:
                 _ml_anc = self._ml_anc_marginal
@@ -580,7 +572,7 @@ class TreeAnc(object):
                                     "in the position %d: %s, "
                                     "choosing %s" % (amb, str(self.tree.root.state[amb]),
                                                      self.tree.root.state[amb][0]), 4)
-        self.tree.root._cseq = np.array([k[np.random.randint(len(k)) if len(k)>1 else 0]
+        self.tree.root._cseq = np.array([k[self.rng.integers(len(k)) if len(k)>1 else 0]
                                            for k in self.tree.root.state])
 
 
@@ -802,7 +794,8 @@ class TreeAnc(object):
         self.tree.sequence_marginal_LH = self.tree.total_sequence_LH
         if assign_sequence:
             seq, prof_vals, idxs = prof2seq(self.tree.root.marginal_profile,
-                                        self.gtr, sample_from_prof=sample_from_profile, normalize=False)
+                                        self.gtr, sample_from_prof=sample_from_profile,
+                                        normalize=False, rng=self.rng)
             self.tree.root._cseq = seq
 
 
@@ -871,7 +864,8 @@ class TreeAnc(object):
             # choose sequence based maximal marginal LH.
             if assign_sequence:
                 seq, prof_vals, idxs = prof2seq(node.marginal_profile, self.gtr,
-                                                sample_from_prof=sample_from_profile, normalize=False)
+                                                sample_from_prof=sample_from_profile,
+                                                normalize=False, rng=self.rng)
 
                 if self.sequence_reconstruction:
                     N_diff += (seq!=node.cseq).sum()
@@ -976,8 +970,9 @@ class TreeAnc(object):
         elif isinstance(sample_from_profile, bool):
             root_sample_from_profile = sample_from_profile
 
-        seq, anc_lh_vals, idxs = prof2seq(np.exp(normalized_profile),
-                                    self.gtr, sample_from_prof = root_sample_from_profile)
+        seq, anc_lh_vals, idxs = prof2seq(np.exp(normalized_profile), self.gtr,
+                                          sample_from_prof = root_sample_from_profile,
+                                          rng=self.rng)
 
         # compute the likelihood of the most probable root sequence
         self.tree.sequence_LH = np.choose(idxs, self.tree.root.joint_Lx.T)
@@ -1200,6 +1195,9 @@ class TreeAnc(object):
             return self.one_mutation
 
         if not hasattr(node, 'branch_state'):
+            if node.cseq is None and node.is_terminal():
+                raise MissingDataError("TreeAnc.optimal_branch_length: terminal node alignments required; sequence is missing for leaf: '%s'. "
+                        "Missing terminal sequences can be inferred from sister nodes by rerunning with `reconstruct_tip_states=True` or `--reconstruct-tip-states`" % node.name)
             self.add_branch_state(node)
         return self.gtr.optimal_t_compressed(node.branch_state['pair'],
                                     node.branch_state['multiplicity'])
@@ -1355,7 +1353,7 @@ class TreeAnc(object):
         self.logger("TreeAnc.optimize_tree: sequences...", 1)
         N_diff = self.reconstruct_anc(method=method_anc, infer_gtr=infer_gtr, pc=pc,
                                       marginal=marginal_sequences, **kwargs)
-        self.optimize_branch_lengths_joint(verbose=0, store_old=False, mode=branch_length_mode)
+        self.optimize_branch_lengths_joint(store_old=False)
         n = 0
         while n<max_iter:
             n += 1
@@ -1568,7 +1566,7 @@ class TreeAnc(object):
 
         old_mu = self.gtr.mu
         try:
-            sol = minimize_scalar(cost_func,bracket=[0.01*np.sqrt(old_mu), np.sqrt(old_mu),100*np.sqrt(old_mu)])
+            sol = minimize_scalar(cost_func, bracket=[0.01*np.sqrt(old_mu), np.sqrt(old_mu),100*np.sqrt(old_mu)], method='brent')
         except:
             self.gtr.mu=old_mu
             self.logger('treeanc:optimize_gtr_rate: optimization failed, continuing with previous mu',1,warn=True)

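`TreeAnc` now owns a generator created with `np.random.default_rng(seed=rng_seed)` and passes it on via `rng=self.rng`, so fixing `rng_seed` makes the stochastic steps reproducible. A NumPy `Generator` draws integers with `integers()`; unlike the legacy `np.random` module, it has no `randint()` method. A sketch with a hypothetical class standing in for `TreeAnc`:

```python
import numpy as np

class SeededSampler:
    """Hypothetical stand-in for a class that owns a seeded generator, like TreeAnc."""

    def __init__(self, rng_seed=None):
        self.rng = np.random.default_rng(seed=rng_seed)

    def pick(self, candidates):
        # resolve an ambiguous state by choosing one candidate uniformly at random;
        # Generator.integers(n) returns an integer in [0, n)
        if len(candidates) > 1:
            return candidates[self.rng.integers(len(candidates))]
        return candidates[0]

s1, s2 = SeededSampler(rng_seed=123), SeededSampler(rng_seed=123)
picks_1 = [s1.pick("ACGT") for _ in range(20)]
picks_2 = [s2.pick("ACGT") for _ in range(20)]  # identical to picks_1
```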

=====================================
treetime/treeregression.py
=====================================
@@ -402,7 +402,7 @@ class TreeRegression(object):
         else:
             ii = np.argmin(chisq_grid)
             bounds = (0 if ii==0 else grid[ii-1], 1.0 if ii==len(grid)-1 else grid[ii+1])
-            sol = minimize_scalar(chisq, bounds=bounds, method="bounded")
+            sol = minimize_scalar(chisq, bounds=bounds, method="bounded", options={'xatol':1e-6})
             if sol["success"]:
                 return sol['x'], sol['fun']
             else:
@@ -611,6 +611,4 @@ if __name__ == '__main__':
     rtt = np.array(rtt)
     plt.plot(ti, rtt)
     plt.plot(ti, reg["slope"]*ti + reg["intercept"])
-    plt.show()
-
     Phylo.draw(T)

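The `TreeRegression` change tightens the absolute tolerance `xatol` of SciPy's bounded scalar minimizer (its default is 1e-5). A toy illustration with a made-up quadratic objective in place of the chi-squared function:

```python
from scipy.optimize import minimize_scalar

def chisq(x):
    # toy objective with its minimum at x = 0.3721
    return (x - 0.3721) ** 2 + 1.0

sol = minimize_scalar(chisq, bounds=(0.2, 0.5), method="bounded", options={"xatol": 1e-6})
```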

=====================================
treetime/treetime.py
=====================================
@@ -73,16 +73,17 @@ class TreeTime(ClockTree):
                 sys.exit(2)
 
 
-    def _run(self, root=None, infer_gtr=True, relaxed_clock=None, n_iqd = None,
-            resolve_polytomies=True, max_iter=0, Tc=None, fixed_clock_rate=None,
+    def _run(self, root=None, infer_gtr=True, relaxed_clock=None, clock_filter_method='residuals',
+             n_iqd = None, resolve_polytomies=True, max_iter=0, Tc=None, fixed_clock_rate=None,
             time_marginal='never', sequence_marginal=False, branch_length_mode='auto',
             vary_rate=False, use_covariation=False, tracelog_file=None,
-            method_anc = 'probabilistic', assign_gamma=None, **kwargs):
+            method_anc = 'probabilistic', assign_gamma=None, stochastic_resolve=False,
+            **kwargs):
 
         """
         Run TreeTime reconstruction. Based on the input parameters, it divides
         the analysis into semi-independent jobs and conquers them one-by-one,
-        gradually optimizing the tree given the temporal constarints and leaf
+        gradually optimizing the tree given the temporal constraints and leaf
         node sequences.
 
         Parameters
@@ -111,6 +112,9 @@ class TreeTime(ClockTree):
         resolve_polytomies : bool
            If True, attempt to resolve multiple mergers
 
+        stochastic_resolve : bool (default False)
+           Resolve multiple mergers via a random coalescent tree (True) or via greedy optimization (False)
+
         max_iter : int
            Maximum number of iterations to optimize the tree
 
@@ -149,7 +153,7 @@ class TreeTime(ClockTree):
 
         use_covariation : bool, optional
             default False, if False, rate estimates will be performed using simple
-            regression ignoring phylogenetic covaration between nodes. If vary_rate is True,
+            regression ignoring phylogenetic covariation between nodes. If vary_rate is True,
             use_covariation is true by default
 
         method_anc: str, optional
@@ -167,7 +171,7 @@ class TreeTime(ClockTree):
 
         Returns
         -------
-        TreeTime error/succces code : str
+        TreeTime error/success code : str
             return value depending on success or error
 
 
@@ -218,7 +222,8 @@ class TreeTime(ClockTree):
             else:
                 plot_rtt=False
             reroot_mechanism = 'least-squares' if root=='clock_filter' else root
-            self.clock_filter(reroot=reroot_mechanism, n_iqd=n_iqd, plot=plot_rtt, fixed_clock_rate=fixed_clock_rate)
+            self.clock_filter(reroot=reroot_mechanism, method=clock_filter_method,
+                              n_iqd=n_iqd, plot=plot_rtt, fixed_clock_rate=fixed_clock_rate)
         elif root is not None:
             self.reroot(root=root, clock_rate=fixed_clock_rate)
 
@@ -279,13 +284,14 @@ class TreeTime(ClockTree):
             n_resolved=0
             if resolve_polytomies:
                 # if polytomies are found, rerun the entire procedure
-                n_resolved = self.resolve_polytomies()
+                n_resolved = self.resolve_polytomies(stochastic_resolve=stochastic_resolve)
                 if n_resolved:
                     seq_kwargs['prune_short']=False
                     self.prepare_tree()
                     if self.branch_length_mode!='input': # otherwise reoptimize branch length while preserving branches without mutations
                         self.optimize_tree(max_iter=0, method_anc = method_anc,**seq_kwargs)
                     need_new_time_tree = True
+
             if assign_gamma and callable(assign_gamma):
                 self.logger("### assigning gamma",1)
                 assign_gamma(self.tree)
@@ -378,7 +384,8 @@ class TreeTime(ClockTree):
             self.branch_length_mode = 'input'
 
 
-    def clock_filter(self, reroot='least-squares', n_iqd=None, plot=False, fixed_clock_rate=None):
+    def clock_filter(self, reroot='least-squares', method='residual',
+                     n_iqd=None, plot=False, fixed_clock_rate=None):
         r'''
         Labels outlier branches that don't seem to follow a molecular clock
         and excludes them from subsequent molecular clock estimation and
@@ -400,34 +407,22 @@ class TreeTime(ClockTree):
             If True, plot the results
 
         '''
+        from .clock_filter_methods import residual_filter, local_filter
         if n_iqd is None:
             n_iqd = ttconf.NIQD
         if type(reroot) is list and len(reroot)==1:
             reroot=str(reroot[0])
 
-        terminals = self.tree.get_terminals()
         if reroot:
-            self.reroot(root='least-squares' if reroot=='best' else reroot, covariation=False, clock_rate=fixed_clock_rate)
+            self.reroot(root='least-squares' if reroot=='best' else reroot,
+                        covariation=False, clock_rate=fixed_clock_rate)
         else:
             self.get_clock_model(covariation=False, slope=fixed_clock_rate)
 
-        clock_rate = self.clock_model['slope']
-        icpt = self.clock_model['intercept']
-        res = {}
-        for node in terminals:
-            if hasattr(node, 'raw_date_constraint') and  (node.raw_date_constraint is not None):
-                res[node] = node.dist2root - clock_rate*np.mean(node.raw_date_constraint) - icpt
-
-        residuals = np.array(list(res.values()))
-        iqd = np.percentile(residuals,75) - np.percentile(residuals,25)
-        bad_branch_count = 0
-        for node,r in res.items():
-            if abs(r)>n_iqd*iqd and node.up.up is not None:
-                self.logger('TreeTime.ClockFilter: marking %s as outlier, residual %f interquartile distances'%(node.name,r/iqd), 3, warn=True)
-                node.bad_branch=True
-                bad_branch_count += 1
-            else:
-                node.bad_branch=False
+        if method=='residual':
+            bad_branch_count = residual_filter(self, n_iqd)
+        elif method=='local':
+            bad_branch_count = local_filter(self, n_iqd)
 
         if bad_branch_count>0.34*self.tree.count_terminals():
             self.logger("TreeTime.clock_filter: More than a third of leaves have been excluded by the clock filter. Please check your input data.", 0, warn=True)
@@ -567,7 +562,8 @@ class TreeTime(ClockTree):
         return new_root
 
 
-    def resolve_polytomies(self, merge_compressed=False, resolution_threshold=0.05):
+    def resolve_polytomies(self, merge_compressed=False, resolution_threshold=0.05,
+                           stochastic_resolve=False):
         """
         Resolve the polytomies on the tree.
 
@@ -581,8 +577,13 @@ class TreeTime(ClockTree):
         Parameters
         ----------
          merge_compressed : bool
-            If True, keep compressed branches as polytomies. If False,
-            return a strictly binary tree.
+            If True, keep compressed branches as polytomies. Applies only to greedy resolution.
+         resolution_threshold : float
+            minimal LH gain required to resolve a polytomy. Otherwise, the parent is kept as a polytomy.
+         stochastic_resolve : bool
+            generate a stochastic binary coalescent tree with mutations from the children of
+            a polytomy. Doesn't necessarily resolve the node fully. This step is stochastic
+            and different runs will result in different outcomes.
 
         Returns
         --------
@@ -592,14 +593,26 @@ class TreeTime(ClockTree):
         """
         self.logger("TreeTime.resolve_polytomies: resolving multiple mergers...",1)
         poly_found=0
+        if stochastic_resolve is False:
+            self.logger("DEPRECATION WARNING. TreeTime.resolve_polytomies: You are "
+                        "resolving polytomies using the old 'greedy' mode. This is not "
+                        "well suited for large polytomies. Stochastic resolution will "
+                        "become the default in future versions. To switch now, rerun "
+                        "with the flag `--stochastic-resolve`. To keep using the greedy method "
+                        "in the future, run with `--greedy-resolve` ", 0, warn=True, only_once=True)
 
         for n in self.tree.find_clades():
             if len(n.clades) > 2:
                 prior_n_clades = len(n.clades)
-                self._poly(n, merge_compressed, resolution_threshold=resolution_threshold)
+                if stochastic_resolve:
+                    self.generate_subtree(n)
+                else:
+                    self._poly(n, merge_compressed, resolution_threshold=resolution_threshold)
+
                 poly_found+=prior_n_clades - len(n.clades)
 
-        obsolete_nodes = [n for n in self.tree.find_clades() if len(n.clades)==1 and n.up is not None]
+        obsolete_nodes = [n for n in self.tree.find_clades()
+                          if len(n.clades)==1 and n.up is not None]
         for node in obsolete_nodes:
             self.logger('TreeTime.resolve_polytomies: remove obsolete node '+node.name,4)
             if node.up is not None:
@@ -616,7 +629,7 @@ class TreeTime(ClockTree):
 
         """
         Function to resolve polytomies for a given parent node. If the
-        number of the direct decendants is less than three (not a polytomy), does
+        number of the direct descendants is less than three (not a polytomy), does
         nothing. Otherwise, for each pair of nodes, assess the possible LH increase
         which could be gained by merging the two nodes. The increase in the LH is
         basically the tradeoff between the gain of the LH due to the changing the
@@ -632,11 +645,17 @@ class TreeTime(ClockTree):
             """
             cost gain if nodes n1, n2 are joined and their parent is placed at time t
             cost gain = (LH loss now) - (LH loss when placed at time t)
+            NOTE: this cost function ignores the coalescent likelihood. Given the greedy
+            and approximate nature of this calculation, this seems justified. But this
+            entire procedure is not well suited for large polytomies.
             """
-            cg2 = n2.branch_length_interpolator._func(parent.time_before_present - n2.time_before_present) - n2.branch_length_interpolator._func(t - n2.time_before_present)
+            # old - new contributions of child branches
             cg1 = n1.branch_length_interpolator._func(parent.time_before_present - n1.time_before_present) - n1.branch_length_interpolator._func(t - n1.time_before_present)
+            cg2 = n2.branch_length_interpolator._func(parent.time_before_present - n2.time_before_present) - n2.branch_length_interpolator._func(t - n2.time_before_present)
+            # old - new contribution of additional branch (no old contribution)
             cg_new = - zero_branch_slope * (parent.time_before_present - t) # loss in LH due to the new branch
-            return -(cg2+cg1+cg_new)
+
+            return -(cg2 + cg1 + cg_new)
 
         def cost_gain(n1, n2, parent):
             """
@@ -645,7 +664,7 @@ class TreeTime(ClockTree):
             try:
                 cg = sciopt.minimize_scalar(_c_gain,
                     bounds=[max(n1.time_before_present,n2.time_before_present), parent.time_before_present],
-                    method='Bounded',args=(n1,n2, parent))
+                    method='bounded',args=(n1,n2, parent), options={'xatol':1e-4*self.one_mutation})
                 return cg['x'], - cg['fun']
             except:
                 self.logger("TreeTime._poly.cost_gain: optimization of gain failed", 3, warn=True)
@@ -753,6 +772,122 @@ class TreeTime(ClockTree):
         return LH
 
 
+    def generate_subtree(self, parent):
+        from .branch_len_interpolator import BranchLenInterpolator
+        # use the random number generator of TreeTime
+        exp_dis = self.rng.exponential
+
+        L = self.data.full_length
+        mutation_rate = self.gtr.mu*L
+
+        tmax = parent.time_before_present
+        branches_by_time = sorted(parent.clades, key=lambda x:x.time_before_present)
+        # calculate the mutations on branches leading to nodes from the mutation length
+        # this excludes state changes to ambiguous states
+        mutations_per_branch = {b.name:round(b.mutation_length*L) for b in branches_by_time}
+
+        branches_alive=branches_by_time[:1]
+        branches_to_come = branches_by_time[1:]
+        t = branches_alive[-1].time_before_present
+        if t>=tmax:
+            # no time left -- keep everything as individual children.
+            return
+
+        # if there is no coalescent model, assume a rate that would typically coalesce all tips
+        # in the time window between the latest and the parent node.
+        dummy_coalescent_rate = 2.0/(tmax-t)
+        self.logger(f"TreeTime.generate_subtree: node {parent.name} has {len(branches_by_time)} children."
+                    +f" {len([b for b,k in mutations_per_branch.items() if k>0])} have mutations."
+                    +f" The time window for coalescence is {tmax-t:1.4e}",3)
+
+        # loop until time collides with the parent node or all but two branches have been dealt with
+        # the remaining two would be the children of the parent
+        while len(branches_alive)+len(branches_to_come)>2 and t<tmax:
+
+            # branches without mutations are ready to coalesce -- others have to mutate first
+            ready_to_coalesce = [b for b in branches_alive if mutations_per_branch.get(b.name,0)==0]
+            if hasattr(self, 'merger_model') and (self.merger_model is not None):
+                coalescent_rate = self.merger_model.branch_merger_rate(t) + mutation_rate
+            else:
+                coalescent_rate = 0.5*len(ready_to_coalesce)*dummy_coalescent_rate + mutation_rate
+
+            total_mutations = np.sum([mutations_per_branch.get(b.name,0) for b in branches_alive])
+            n_branches_w_mutations = len(branches_alive) - len(ready_to_coalesce)
+            # the probability of a branch without events is the sum of mutation and coalescent rates
+            # branches with mutations can only mutate, the others only coalesce. This is due to the
+            # conditioning for all branches being direct descendants of a polytomy.
+            total_mut_rate = mutation_rate*total_mutations + coalescent_rate*n_branches_w_mutations
+            total_coalescent_rate = max(0,(len(ready_to_coalesce)-1))*(coalescent_rate + mutation_rate)
+            # just a single branch and no mutations --> advance to next branch
+            if (total_mut_rate + total_coalescent_rate)==0 and len(branches_to_come):
+                branches_alive.append(branches_to_come.pop(0))
+                t = branches_alive[-1].time_before_present
+                continue
+
+            # determine the next time step
+            total_rate_inv = 1.0/(total_mut_rate + total_coalescent_rate)
+            dt = exp_dis(total_rate_inv)
+            t+=dt
+            # if the time advanced past the next branch in the branches_to_come list
+            # add this branch to branches alive and re-enter the loop
+            if len(branches_to_come) and t>branches_to_come[0].time_before_present:
+                while len(branches_to_come) and t>branches_to_come[0].time_before_present:
+                    branches_alive.append(branches_to_come.pop(0))
+            # else mutate or coalesce
+            else:
+                # determine whether to mutate or coalesce
+                p = self.rng.random()
+                mut_or_coal = p<total_mut_rate*total_rate_inv
+                if mut_or_coal:
+                    # transform p to be on a scale of 0 to total mutation
+                    p /= total_mut_rate*total_rate_inv
+                    p *= total_mutations
+                    # discount one mutation at a time until p<0, break and remove that mutation
+                    for b in branches_alive:
+                        p -= mutations_per_branch.get(b.name,0)
+                        if p<0: break
+                    mutations_per_branch[b.name] -= 1
+                else:
+                    # pick a pair to coalesce, make a new node.
+                    picks = self.rng.choice(len(ready_to_coalesce), size=2, replace=False)
+                    new_node = Phylo.BaseTree.Clade()
+                    new_node.time_before_present = t
+                    n1, n2 = ready_to_coalesce[picks[0]], ready_to_coalesce[picks[1]]
+                    new_node.clades = [n1, n2]
+                    new_node.mutation_length = 0.0
+                    n1.branch_length = t - n1.time_before_present
+                    n2.branch_length = t - n2.time_before_present
+                    n1.up = new_node
+                    n2.up = new_node
+                    if n1.mask is None or n2.mask is None:
+                        new_node.mask = None
+                        new_node.mcc = None
+                    else:
+                        new_node.mask = n1.mask * n2.mask
+                        new_node.mcc = n1.mcc if n1.mcc==n2.mcc else None
+                        self.logger('TreeTime._poly.merge_nodes: assigning mcc to new node ' + str(new_node.mcc), 4)
+                    new_node.up = parent
+                    new_node.tt = self
+                    if hasattr(parent, "_cseq"):
+                        new_node._cseq = parent._cseq
+                        self.add_branch_state(new_node)
+                    new_node.branch_length_interpolator = BranchLenInterpolator(new_node, self.gtr,
+                                pattern_multiplicity = self.data.multiplicity(mask=new_node.mask), min_width=self.min_width,
+                                one_mutation=self.one_mutation, branch_length_mode=self.branch_length_mode,
+                                n_grid_points = self.branch_grid_points)
+                    branches_alive = [b for b in branches_alive if b not in [n1,n2]] + [new_node]
+
+        remaining_branches = []
+        for b in branches_alive + branches_to_come:
+            b.branch_length = tmax - b.time_before_present
+            b.up = parent
+            remaining_branches.append(b)
+
+        self.logger(f"TreeTime.generate_subtree: node {parent.name} was resolved from {len(branches_by_time)} to {len(remaining_branches)} children.",3)
+        # assign the remaining branches as new clades to the parent.
+        parent.clades = remaining_branches
+
+
     def print_lh(self, joint=True):
         """
         Print the total likelihood of the tree given the constrained leaves
@@ -949,7 +1084,7 @@ class TreeTime(ClockTree):
         return Treg.optimal_reroot(force_positive=force_positive, slope=slope, keep_node_order=self.keep_node_order)['node']
 
 
-def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwargs):
+def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, selective_confidence=None, **kwargs):
     '''
     Converts branch length to years and plots the time tree on a time axis.
 
@@ -986,7 +1121,7 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg
     # draw tree
     if "label_func" not in kwargs:
         kwargs["label_func"] = lambda x:x.name if (x.is_terminal() and nleafs<30) else ""
-    Phylo.draw(tt.tree, axes=ax, **kwargs)
+    Phylo.draw(tt.tree, axes=ax, do_show=False, **kwargs)
 
     offset = tt.tree.root.numdate - tt.tree.root.branch_length
     date_range = np.max([n.numdate for n in tt.tree.get_terminals()])-offset
@@ -1033,14 +1168,16 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg
             pos = year - offset
             r = Rectangle((pos, ylim[1]-5),
                           step, ylim[0]-ylim[1]+10,
-                          facecolor=[0.7+0.1*(1+yi%2)] * 3,
-                          edgecolor=[1,1,1])
+                          facecolor=[0.88+0.04*(1+yi%2)] * 3,
+                          edgecolor=[0.8,0.8,0.8])
             ax.add_patch(r)
-            if year in tick_vals and pos>=xlim[0] and pos<=xlim[1] and ticks:
-                label_str = "%1.2f"%(step*(year//step)) if step<1 else  str(int(year))
-                ax.text(pos,ylim[0]-0.04*(ylim[1]-ylim[0]), label_str,
-                        horizontalalignment='center')
-        ax.set_axis_off()
+            if step>=1:
+                if year in tick_vals and pos>=xlim[0] and pos<=xlim[1] and ticks:
+                    label_str = "%1.2f"%(step*(year//step)) if step<1 else  str(int(year))
+                    ax.text(pos,ylim[0]-0.04*(ylim[1]-ylim[0]), label_str,
+                            horizontalalignment='center')
+        if step>=1:
+            ax.set_axis_off()
 
     # add confidence intervals to the tree graph -- grey bars
     if confidence:
@@ -1055,7 +1192,7 @@ def plot_vs_years(tt, step = None, ax=None, confidence=None, ticks=True, **kwarg
             raise NotReadyError("confidence needs to be either a float (for max posterior region) or a two numbers specifying lower and upper bounds")
 
         for n in tt.tree.find_clades():
-            if not n.bad_branch:
+            if not n.bad_branch and (selective_confidence is None or selective_confidence(n)):
                 pos = cfunc(n, confidence)
                 ax.plot(pos-offset, np.ones(len(pos))*n.ypos, lw=3, c=(0.5,0.5,0.5))
     return fig, ax

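The clock_filter hunk delegates the outlier test to `clock_filter_methods.residual_filter`, whose body is not shown in this diff. The lines it replaces do the following for each dated tip:

- compute the residual of the root-to-tip distance against the regression line, using the mean of the date constraint;
- flag the tip if its absolute residual exceeds `n_iqd` interquartile distances.

Below is a standard-library sketch of that removed logic. The helper name `residual_outliers` and the plain-dict inputs are hypothetical. The sketch also omits the original's exemption for tips attached directly to the root:

```python
import statistics


def residual_outliers(dist2root, dates, slope, intercept, n_iqd):
    """Return names of tips whose root-to-tip residual exceeds n_iqd IQDs.

    dist2root: dict tip name -> root-to-tip distance
    dates:     dict tip name -> numeric date, or [min, max] for ambiguous dates
    """
    residuals = {}
    for name, d2r in dist2root.items():
        if dates.get(name) is None:
            continue  # tips without a date constraint are not tested
        date = dates[name]
        if isinstance(date, (list, tuple)):
            date = sum(date) / len(date)  # mirrors np.mean(raw_date_constraint)
        residuals[name] = d2r - slope * date - intercept

    # quartiles with linear interpolation, matching numpy's default percentile
    q1, _, q3 = statistics.quantiles(residuals.values(), n=4, method="inclusive")
    iqd = q3 - q1
    return {name for name, r in residuals.items() if abs(r) > n_iqd * iqd}
```

Scaling the threshold by the interquartile distance, rather than a standard deviation, keeps a single extreme tip from inflating the cutoff that is supposed to catch it.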

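`generate_subtree` resolves a polytomy by simulating forward in time from the youngest child toward the parent. Each step has two parts:

1. Draw an exponential waiting time from the total event rate.
2. Choose the event: either remove one pending mutation from a branch, chosen in proportion to its pending count, or merge two mutation-free branches.

The sketch below is a stripped-down version of a single step. It uses Python's `random` module instead of TreeTime's NumPy generator. Note that `random.expovariate` takes a rate, whereas NumPy's `exponential` takes the scale `1/rate`; that is why the diff passes `total_rate_inv`. The function name `next_event` and its return convention are hypothetical. Branch arrival times, node creation and the merger model are all omitted:

```python
import random


def next_event(rng, mutations_per_branch, mutation_rate, coalescent_rate):
    """One Gillespie-style step, loosely following generate_subtree.

    Branches with pending mutations may only mutate; branches without
    pending mutations may only coalesce. Returns (kind, dt, who).
    """
    ready = [b for b, k in mutations_per_branch.items() if k == 0]
    n_with_mut = len(mutations_per_branch) - len(ready)
    total_mutations = sum(mutations_per_branch.values())

    total_mut_rate = mutation_rate * total_mutations + coalescent_rate * n_with_mut
    total_coal_rate = max(0, len(ready) - 1) * (coalescent_rate + mutation_rate)
    total_rate = total_mut_rate + total_coal_rate
    if total_rate == 0:
        return None, float("inf"), None  # nothing can happen

    # waiting time until the next event; expovariate takes the rate
    dt = rng.expovariate(total_rate)

    if rng.random() < total_mut_rate / total_rate:
        # pick a branch with probability proportional to its pending mutations
        u = rng.random() * total_mutations
        for b, k in mutations_per_branch.items():
            u -= k
            if u < 0:
                break
        return "mutation", dt, b

    # otherwise merge two distinct mutation-free branches
    return "coalescence", dt, tuple(rng.sample(ready, 2))
```

A caller would apply the event (decrement the branch's pending count, or create a new internal node for the pair) and advance the time by `dt`. It would stop once the parent's time is reached or only two lineages remain.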
=====================================
treetime/utils.py
=====================================
@@ -220,16 +220,23 @@ def parse_dates(date_file, name_col=None, date_col=None):
     Parameters
     ----------
     date_file : str
-        name of file to parse meta data from
+        name of csv/tsv file to parse meta data from
+    name_col : str, optional
+        name of column containing taxon names. If None, will use
+        first column that contains 'name', 'strain', or 'accession'
+    date_col : str, optional
+        name of column containing dates. If None, will use
+        a column that contains the substring 'date'
 
     Returns
     -------
-    dict
-        dictionary linking fields in a column interpreted as taxon name
-        (first column that contains 'name', 'strain', 'accession')
-        to a numerical date inferred from a column that contains 'date'.
-        It will first try to parse the column as float, than via
+    dict[str, float | list[float]]
+        dictionary mapping taxon names to numeric dates (float year).
+        It will first try to parse date column strings as float, then as a min/max
+        pair of floats (e.g. '[2018.2:2018.4]'), then as date strings using
         pandas.to_datetime, and finally as ambiguous dates such as 2018-05-XX.
+        Numeric date values are returned as a float, or as a list of two floats
+        [min, max] if the date is ambiguous.
     """
     print("\nAttempting to parse dates...")
     dates = {}


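The revised `parse_dates` docstring lists the order in which a date string is interpreted. Below is a standard-library sketch of that sequence of fallbacks for a single value. It makes three assumptions:

- `parse_date_value` and `decimal_year` are hypothetical helpers.
- `datetime.date.fromisoformat` stands in for `pandas.to_datetime`, so only ISO dates are accepted.
- The mid-day decimal-year convention is not a statement of TreeTime's exact formula.

```python
import calendar
import datetime


def decimal_year(d):
    # assumption: mid-day convention, year + (day_of_year - 0.5) / days_in_year
    days_in_year = 366 if calendar.isleap(d.year) else 365
    return d.year + (d.timetuple().tm_yday - 0.5) / days_in_year


def parse_date_value(value):
    """Parse one date-column entry into a float or a [min, max] list."""
    value = value.strip()
    # 1) plain float, e.g. '2018.25'
    try:
        return float(value)
    except ValueError:
        pass
    # 2) bracketed min/max pair, e.g. '[2018.2:2018.4]'
    if value.startswith("[") and value.endswith("]") and ":" in value:
        lo, hi = value[1:-1].split(":")
        return [float(lo), float(hi)]
    # 3) complete ISO date, e.g. '2018-05-17' (stand-in for pandas.to_datetime)
    try:
        return decimal_year(datetime.date.fromisoformat(value))
    except ValueError:
        pass
    # 4) ambiguous date with XX placeholders, e.g. '2018-05-XX' or '2018-XX-XX'
    parts = value.split("-")
    if len(parts) == 3 and parts[0].isdigit():
        year = int(parts[0])
        if parts[1].upper() == "XX":
            first, last = datetime.date(year, 1, 1), datetime.date(year, 12, 31)
        elif parts[1].isdigit() and parts[2].upper() == "XX":
            month = int(parts[1])
            last_day = calendar.monthrange(year, month)[1]
            first = datetime.date(year, month, 1)
            last = datetime.date(year, month, last_day)
        else:
            return None
        return [decimal_year(first), decimal_year(last)]
    return None  # unparseable entry
```

Ordering the attempts from most to least specific means a value like '2018.25' is never misread as a calendar date, and only genuinely ambiguous entries come back as [min, max] ranges.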
=====================================
treetime/vcf_utils.py
=====================================
@@ -435,7 +435,10 @@ def write_vcf(tree_dict, file_name):#, compress=False):
             try:
                 pattern2.append(sequences[k][pi+1])
             except KeyError:
-                pattern2.append(ref[pi+1])
+                try:
+                    pattern2.append(ref[pi+1])
+                except IndexError:
+                    pass
 
         pattern = np.array(pattern).astype('U')
         pattern2 = np.array(pattern2).astype('U')


=====================================
treetime/wrappers.py
=====================================
@@ -2,14 +2,13 @@ import os, shutil, sys
 import numpy as np
 import pandas as pd
 from textwrap import fill
-from Bio import Phylo, AlignIO
+from Bio import Phylo
 from Bio import __version__ as bioversion
 from . import TreeAnc, GTR, TreeTime
 from . import utils
-from .vcf_utils import read_vcf, write_vcf
-from .seq_utils import alphabets
 from . import TreeTimeError, MissingDataError, UnknownMethodError
 from .treetime import reduce_time_marginal_argument
+from .CLI_io import *
 
 def assure_tree(params, tmp_dir='treetime_tmp'):
     """
@@ -32,12 +31,24 @@ def assure_tree(params, tmp_dir='treetime_tmp'):
         return 1
     return 0
 
+
 def create_gtr(params):
     """
     parse the arguments referring to the GTR model and return a GTR structure
     """
     model = params.gtr
     gtr_params = params.gtr_params
+    custom_gtr = params.custom_gtr
+    if custom_gtr:
+        if model not in ['custom', 'infer']:
+            print(f'Warning: you specified a GTR model `{model}` and a custom gtr path `{custom_gtr}`. TreeTime will load the custom model and ignore the parameter `--gtr {model}`.')
+        if os.path.isfile(custom_gtr):
+            gtr = GTR.from_file(custom_gtr)
+            params.gtr = 'custom'
+            return gtr
+        else:
+            raise ValueError(f"File with custom GTR model `{custom_gtr}` does not exist!")
+
     if model == 'infer':
         gtr = GTR.standard('jc', alphabet='aa' if params.aa else 'nuc')
     else:
@@ -57,225 +68,12 @@ def create_gtr(params):
                 print ("GTR params are not specified. Creating GTR model with default parameters")
 
             gtr = GTR.standard(model, **kwargs)
-            infer_gtr = False
         except KeyError as e:
             print("\nUNKNOWN SUBSTITUTION MODEL\n")
             raise e
 
     return gtr
 
-def get_outdir(params, suffix='_treetime'):
-    if params.outdir:
-        if os.path.exists(params.outdir):
-            if os.path.isdir(params.outdir):
-                return params.outdir.rstrip('/') + '/'
-            else:
-                print("designated output location %s is not a directory"%params.outdir, file=sys.stderr)
-        else:
-            os.makedirs(params.outdir)
-            return params.outdir.rstrip('/') + '/'
-
-    from datetime import datetime
-    outdir_stem = datetime.now().date().isoformat()
-    outdir = outdir_stem + suffix.rstrip('/')+'/'
-    count = 1
-    while os.path.exists(outdir):
-        outdir = outdir_stem + '-%04d'%count + suffix.rstrip('/')+'/'
-        count += 1
-
-    os.makedirs(outdir)
-    return outdir
-
-def get_basename(params, outdir):
-    # if params.aln:
-    #     basename = outdir + '.'.join(params.aln.split('/')[-1].split('.')[:-1])
-    # elif params.tree:
-    #     basename = outdir + '.'.join(params.tree.split('/')[-1].split('.')[:-1])
-    # else:
-    basename = outdir
-    return basename
-
-def read_in_DRMs(drm_file, offset):
-    import pandas as pd
-
-    DRMs = {}
-    drmPositions = []
-
-    df = pd.read_csv(drm_file, sep='\t')
-    for mi, m in df.iterrows():
-        pos = m.GENOMIC_POSITION-1+offset #put in correct numbering
-        drmPositions.append(pos)
-
-        if pos in DRMs:
-            DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION
-        else:
-            DRMs[pos] = {}
-            DRMs[pos]['drug'] = m.DRUG
-            DRMs[pos]['alt_base'] = {}
-            DRMs[pos]['alt_base'][m.ALT_BASE] = m.SUBSTITUTION
-            DRMs[pos]['gene'] = m.GENE
-
-    drmPositions = np.array(drmPositions)
-    drmPositions = np.unique(drmPositions)
-    drmPositions = np.sort(drmPositions)
-
-    DRM_info = {'DRMs': DRMs,
-            'drmPositions': drmPositions}
-
-    return DRM_info
-
-
-def read_if_vcf(params):
-    """
-    Checks if input is VCF and reads in appropriately if it is
-    """
-    ref = None
-    aln = params.aln
-    fixed_pi = None
-    if hasattr(params, 'aln') and params.aln is not None:
-        if any([params.aln.lower().endswith(x) for x in ['.vcf', '.vcf.gz']]):
-            if not params.vcf_reference:
-                print("ERROR: a reference Fasta is required with VCF-format alignments")
-                return -1
-            compress_seq = read_vcf(params.aln, params.vcf_reference)
-            sequences = compress_seq['sequences']
-            ref = compress_seq['reference']
-            aln = sequences
-
-            if not hasattr(params, 'gtr') or params.gtr=="infer": #if not specified, set it:
-                alpha = alphabets['aa'] if params.aa else alphabets['nuc']
-                fixed_pi = [ref.count(base)/len(ref) for base in alpha]
-                if fixed_pi[-1] == 0:
-                    fixed_pi[-1] = 0.05
-                    fixed_pi = [v-0.01 for v in fixed_pi]
-
-    return aln, ref, fixed_pi
-
-
-def plot_rtt(tt, fname):
-    tt.plot_root_to_tip()
-
-    from matplotlib import pyplot as plt
-    plt.savefig(fname)
-    print("--- root-to-tip plot saved to  \n\t"+fname)
-
-
-def export_sequences_and_tree(tt, basename, is_vcf=False, zero_based=False,
-                              report_ambiguous=False, timetree=False, confidence=False,
-                              reconstruct_tip_states=False, tree_suffix=''):
-    seq_info = is_vcf or tt.aln
-    if is_vcf:
-        outaln_name = basename + f'ancestral_sequences{tree_suffix}.vcf'
-        write_vcf(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name)
-    elif tt.aln:
-        outaln_name = basename + f'ancestral_sequences{tree_suffix}.fasta'
-        AlignIO.write(tt.get_reconstructed_alignment(reconstruct_tip_states=reconstruct_tip_states), outaln_name, 'fasta')
-    if seq_info:
-        print("\n--- alignment including ancestral nodes saved as  \n\t %s\n"%outaln_name)
-
-    # decorate tree with inferred mutations
-    terminal_count = 0
-    offset = 0 if zero_based else 1
-    if timetree:
-        dates_fname = basename + f'dates{tree_suffix}.tsv'
-        fh_dates = open(dates_fname, 'w', encoding='utf-8')
-        if confidence:
-            fh_dates.write('#Lower and upper bound delineate the 90% max posterior region\n')
-            fh_dates.write('#node\tdate\tnumeric date\tlower bound\tupper bound\n')
-        else:
-            fh_dates.write('#node\tdate\tnumeric date\n')
-
-    mutations_out = open(basename + "branch_mutations.txt", "w")
-    mutations_out.write("node\tstate1\tpos\tstate2\n")
-    for n in tt.tree.find_clades():
-        if timetree:
-            if confidence:
-                if n.bad_branch:
-                    fh_dates.write('%s\t--\t--\t--\t--\n'%(n.name))
-                else:
-                    conf = tt.get_max_posterior_region(n, fraction=0.9)
-                    fh_dates.write('%s\t%s\t%f\t%f\t%f\n'%(n.name, n.date, n.numdate,conf[0], conf[1]))
-            else:
-                if n.bad_branch:
-                    fh_dates.write('%s\t--\t--\n'%(n.name))
-                else:
-                    fh_dates.write('%s\t%s\t%f\n'%(n.name, n.date, n.numdate))
-
-        n.confidence=None
-        # due to a bug in older versions of biopython that truncated filenames in nexus export
-        # we truncate them by hand and make them unique.
-        if n.is_terminal() and len(n.name)>40 and bioversion<"1.69":
-            n.name = n.name[:35]+'_%03d'%terminal_count
-            terminal_count+=1
-        n.comment=''
-        if seq_info and len(n.mutations):
-            if n.mask is None:
-                if report_ambiguous:
-                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations])+'"'
-                else:
-                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations
-                                                        if tt.gtr.ambiguous not in [a,d]])+'"'
-            else:
-                if report_ambiguous:
-                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations if n.mask[pos]>0])+f'",mcc="{n.mcc}"'
-                else:
-                    n.comment= '&mutations="' + ','.join([a+str(pos + offset)+d for (a,pos, d) in n.mutations
-                                                        if tt.gtr.ambiguous not in [a,d] and n.mask[pos]>0])+f'",mcc="{n.mcc}"'
-
-                for (a, pos, d) in n.mutations:
-                    if tt.gtr.ambiguous not in [a,d] or report_ambiguous:
-                        mutations_out.write("%s\t%s\t%s\t%s\n" %(n.name, a, pos + 1, d))
-        if timetree:
-            n.comment+=(',' if n.comment else '&') + 'date=%1.2f'%n.numdate
-    mutations_out.close()
-
-    # write tree to file
-    fmt_bl = "%1.6f" if tt.data.full_length<1e6 else "%1.8e"
-    if timetree:
-        outtree_name = basename + f'timetree{tree_suffix}.nexus'
-        print("--- saved divergence times in \n\t %s\n"%dates_fname)
-        Phylo.write(tt.tree, outtree_name, 'nexus')
-    else:
-        outtree_name = basename + f'annotated_tree{tree_suffix}.nexus'
-        Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl)
-    print("--- tree saved in nexus format as  \n\t %s\n"%outtree_name)
-
-    if timetree:
-        for n in tt.tree.find_clades():
-            n.branch_length = n.mutation_length
-        outtree_name = basename + f'divergence_tree{tree_suffix}.nexus'
-        Phylo.write(tt.tree, outtree_name, 'nexus', format_branch_length=fmt_bl)
-        print("--- divergence tree saved in nexus format as  \n\t %s\n"%outtree_name)
-
-
-def print_save_plot_skyline(tt, n_std=2.0, screen=True, save='', plot=''):
-    if plot:
-        import matplotlib.pyplot as plt
-
-    skyline, conf = tt.merger_model.skyline_inferred(gen=50, confidence=n_std)
-    if save: fh = open(save, 'w', encoding='utf-8')
-    header1 = "Skyline assuming 50 gen/year and approximate confidence bounds (+/- %f standard deviations of the LH)\n"%n_std
-    header2 = "date \tN_e \tlower \tupper"
-    if screen: print('\t'+header1+'\t'+header2)
-    if save: fh.write("#"+ header1+'#'+header2+'\n')
-    for (x,y, y1, y2) in zip(skyline.x, skyline.y, conf[0], conf[1]):
-        if screen: print("\t%1.3f\t%1.3e\t%1.3e\t%1.3e"%(x,y, y1, y2))
-        if save: fh.write("%1.3f\t%1.3e\t%1.3e\t%1.3e\n"%(x,y, y1, y2))
-
-    if save:
-        print("\n --- written skyline to %s\n"%save)
-        fh.close()
-
-    if plot:
-        plt.figure()
-        plt.fill_between(skyline.x, conf[0], conf[1], color=(0.8, 0.8, 0.8))
-        plt.plot(skyline.x, skyline.y, label='maximum likelihood skyline')
-        plt.yscale('log')
-        plt.legend()
-        plt.ticklabel_format(axis='x',useOffset=False)
-        plt.savefig(plot)
-
-
 def scan_homoplasies(params):
     """
     the function implementing treetime homoplasies
@@ -296,7 +94,7 @@ def scan_homoplasies(params):
     ### ANCESTRAL RECONSTRUCTION
     ###########################################################################
     treeanc = TreeAnc(params.tree, aln=aln, ref=ref, gtr=gtr, verbose=1,
-                      fill_overhangs=True)
+                      fill_overhangs=True, rng_seed=params.rng_seed)
     if treeanc.aln is None: # if alignment didn't load, exit
         return 1
 
@@ -369,8 +167,6 @@ def scan_homoplasies(params):
                                       for x in treeanc.tree.find_clades()])
     corrected_terminal_branch_length = np.sum([np.exp(-x.branch_length)*np.sinh(x.branch_length)
                                       for x in treeanc.tree.get_terminals()])
-    expected_mutations = L*corrected_branch_length
-    expected_terminal_mutations = L*corrected_terminal_branch_length
 
     # make histograms and sum mutations in different categories
     multiplicities = np.bincount([len(x) for x in mutations.values()])
@@ -527,10 +323,11 @@ def timetree(params):
     if params.aln is None and params.sequence_length is None:
         print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
         return 1
+
     myTree = TreeTime(dates=dates, tree=params.tree, ref=ref,
                       aln=aln, gtr=gtr, seq_len=params.sequence_length,
                       verbose=params.verbose, fill_overhangs=not params.keep_overhangs,
-                      branch_length_mode = params.branch_length_mode)
+                      branch_length_mode = params.branch_length_mode, rng_seed=params.rng_seed)
 
     return run_timetree(myTree, params, outdir)
 
@@ -562,17 +359,24 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
     # coalescent model options
     try:
         coalescent = float(params.coalescent)
-        if coalescent<10*myTree.one_mutation:
-            coalescent = None
     except:
         if params.coalescent in ['opt', 'const', 'skyline']:
             coalescent = params.coalescent
         else:
-            print("unknown coalescent model specification, has to be either "
-                  "a float, 'opt', 'const' or 'skyline' -- exiting")
-            return 1
+            raise TreeTimeError("unknown coalescent model specification, has to be either "
+                                "a float, 'opt', 'const' or 'skyline' -- exiting")
+
+    # coalescent rates faster than the time to one mutation can lead to numerical issues
+    if type(coalescent)==float and coalescent>0 and coalescent<myTree.one_mutation:
+        raise TreeTimeError(f"coalescent time scale is too low, should be at least distance"
+                            f" corresponding to one mutation {myTree.one_mutation:1.3e}")
+
+
     n_branches_posterior = params.n_branches_posterior
 
+    if hasattr(params, 'stochastic_resolve'):
+        stochastic_resolve = params.stochastic_resolve
+    else: stochastic_resolve = False
 
     # determine whether confidence intervals are to be computed and how the
     # uncertainty in the rate estimate should be treated
@@ -603,9 +407,10 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
     try:
         success = myTree.run(root=root, relaxed_clock=relaxed_clock_params,
                resolve_polytomies=(not params.keep_polytomies),
+               stochastic_resolve = stochastic_resolve,
                Tc=coalescent, max_iter=params.max_iter,
                fixed_clock_rate=params.clock_rate,
-               n_iqd=params.clock_filter,
+               n_iqd=params.clock_filter, clock_filter_method=params.clock_filter_method,
                time_marginal="confidence-only" if (calc_confidence and time_marginal=='never') else time_marginal,
                vary_rate = vary_rate,
                branch_length_mode = branch_length_mode,
@@ -638,12 +443,12 @@ def run_timetree(myTree, params, outdir, tree_suffix='', prune_short=True, metho
     if coalescent in ['skyline', 'opt', 'const']:
         print("Inferred coalescent model")
         if coalescent=='skyline':
-            print_save_plot_skyline(myTree, plot=basename+'skyline.pdf', save=basename+'skyline.tsv', screen=True)
+            print_save_plot_skyline(myTree, plot=basename+'skyline.pdf', save=basename+'skyline.tsv', screen=True, gen=params.gen_per_year)
         else:
             Tc = myTree.merger_model.Tc.y[0]
             print(" --T_c: \t %1.2e \toptimized inverse merger rate in units of substitutions"%Tc)
             print(" --T_c: \t %1.2e \toptimized inverse merger rate in years"%(Tc/myTree.date2dist.clock_rate))
-            print(" --N_e: \t %1.2e \tcorresponding 'effective population size' assuming 50 gen/year\n"%(Tc/myTree.date2dist.clock_rate*50))
+            print(" --N_e: \t %1.2e \tcorresponding 'effective population size' assuming %1.2e gen/year\n"%(Tc/myTree.date2dist.clock_rate*params.gen_per_year, params.gen_per_year))
 
     # plot
     ##IMPORTANT: after this point the functions not only plot the tree but also modify the branch length
@@ -702,7 +507,7 @@ def ancestral_reconstruction(params):
     is_vcf = True if ref is not None else False
 
     treeanc = TreeAnc(params.tree, aln=aln, ref=ref, gtr=gtr, verbose=1,
-                      fill_overhangs=not params.keep_overhangs)
+                      fill_overhangs=not params.keep_overhangs, rng_seed=params.rng_seed)
 
     try:
         ndiff = treeanc.infer_ancestral_sequences('ml', infer_gtr=params.gtr=='infer',
@@ -729,7 +534,7 @@ def ancestral_reconstruction(params):
     return 0
 
 def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling_bias_correction=None,
-                                weights=None, verbose=0, iterations=5):
+                                weights=None, verbose=0, iterations=5, rng_seed=None):
     """take a set of discrete states associated with tips of a tree
     and reconstruct their ancestral states along with a GTR model that
     approximately maximizes the likelihood of the states on the tree.
@@ -747,7 +552,7 @@ def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling
     sampling_bias_correction : float, optional
         factor to inflate overall switching rate by to counteract sampling bias
     weights : str, optional
-        name of file with equilibirum frequencies
+        name of file with equilibrium frequencies
     verbose : int, optional
         level of verbosity in output
     iterations : int, optional
@@ -768,7 +573,10 @@ def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling
     ### make a single character alphabet that maps to discrete states
     ###########################################################################
 
-    unique_states = set(traits.values())
+    # Find all unique states to reconstruct, excluding the missing data state.
+    # This missing state will get its own letter assigned after we enumerate the
+    # known states.
+    unique_states = set(traits.values()) - {missing_data}
     n_observed_states = len(unique_states)
 
     # load weights from file and convert to dict if supplied as string
@@ -837,7 +645,7 @@ def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling
     ### set up treeanc
     ###########################################################################
     treeanc = TreeAnc(tree, gtr=mugration_GTR, verbose=verbose, ref='A',
-                      convert_upper=False, one_mutation=0.001)
+                      convert_upper=False, one_mutation=0.001, rng_seed=rng_seed)
     treeanc.use_mutation_length = False
     pseudo_seqs = {n.name: {0:reverse_alphabet[traits[n.name]] if n.name in traits else missing_char}
                    for n in treeanc.tree.get_terminals()}
@@ -864,11 +672,6 @@ def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling
                                  marginal=True, normalized_rate=False,
                                  reconstruct_tip_states=True)
 
-    print(fill("NOTE: previous versions (<0.7.0) of this command made a 'short-branch length assumption. "
-          "TreeTime now optimizes the overall rate numerically and thus allows for long branches "
-          "along which multiple changes accumulated. This is expected to affect estimates of the "
-          "overall rate while leaving the relative rates mostly unchanged."))
-
     return treeanc, letter_to_state, reverse_alphabet
 
 
@@ -916,8 +719,10 @@ def mugration(params):
     leaf_to_attr = {x[taxon_name]:str(x[attr]) for xi, x in states.iterrows()
                     if x[attr]!=params.missing_data and x[attr]}
 
-    mug, letter_to_state, reverse_alphabet = reconstruct_discrete_traits(params.tree, leaf_to_attr, missing_data=params.missing_data,
-            pc=params.pc, sampling_bias_correction=params.sampling_bias_correction, verbose=params.verbose, weights=params.weights)
+    mug, letter_to_state, reverse_alphabet = reconstruct_discrete_traits(params.tree, leaf_to_attr,
+                missing_data=params.missing_data, pc=params.pc,
+                sampling_bias_correction=params.sampling_bias_correction,
+                verbose=params.verbose, weights=params.weights, rng_seed=params.rng_seed)
 
     if mug is None:
         print("Mugration inference failed, check error messages above and your input data.")
@@ -941,8 +746,6 @@ def mugration(params):
     terminal_count = 0
     for n in mug.tree.find_clades():
         n.confidence=None
-        if n.up is None:
-            continue
         # due to a bug in older versions of biopython that truncated filenames in nexus export
         # we truncate them by hand and make them unique.
         if n.is_terminal() and len(n.name)>40 and bioversion<"1.69":
@@ -998,7 +801,7 @@ def estimate_clock_model(params):
     try:
         myTree = TreeTime(dates=dates, tree=params.tree, aln=aln, gtr='JC69',
                       verbose=params.verbose, seq_len=params.sequence_length,
-                      ref=ref)
+                      ref=ref, rng_seed=params.rng_seed)
     except TreeTimeError as e:
         print("\nTreeTime setup failed. Please see above for error messages and/or rerun with --verbose 4\n")
         raise e
@@ -1006,7 +809,8 @@ def estimate_clock_model(params):
     myTree.tip_slack=params.tip_slack
     if params.clock_filter:
         n_bad = [n.name for n in myTree.tree.get_terminals() if n.bad_branch]
-        myTree.clock_filter(n_iqd=params.clock_filter, reroot=params.reroot or 'least-squares')
+        myTree.clock_filter(n_iqd=params.clock_filter, reroot=params.reroot or 'least-squares',
+                            method=params.clock_filter_method)
         n_bad_after = [n.name for n in myTree.tree.get_terminals() if n.bad_branch]
         if len(n_bad_after)>len(n_bad):
             print("The following leaves don't follow a loose clock and "
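
In `run_timetree`, the hunk above changes how `--coalescent` is validated. Previously, a numeric value below `10*myTree.one_mutation` was silently replaced by `None`. Now an unknown string raises `TreeTimeError`, and a positive value below `myTree.one_mutation` is rejected. The sketch below isolates that logic in a hypothetical helper, `parse_coalescent`, which is not part of TreeTime. `TreeTimeError` is redefined locally so the block runs on its own. It also catches only `TypeError`/`ValueError` instead of using the bare `except:` from the diff.

```python
class TreeTimeError(Exception):
    """Local stand-in for treetime's TreeTimeError (assumed here for self-containment)."""


def parse_coalescent(value, one_mutation):
    """Hypothetical helper mirroring the coalescent handling in run_timetree."""
    try:
        coalescent = float(value)
    except (TypeError, ValueError):
        # non-numeric values must name one of the inferred coalescent models
        if value in ('opt', 'const', 'skyline'):
            coalescent = value
        else:
            raise TreeTimeError("unknown coalescent model specification, has to be either "
                                "a float, 'opt', 'const' or 'skyline'")
    # a positive time scale shorter than one mutation is rejected instead of dropped
    if isinstance(coalescent, float) and 0 < coalescent < one_mutation:
        raise TreeTimeError(f"coalescent time scale is too low, should be at least "
                            f"{one_mutation:1.3e}")
    return coalescent
```

For example, `parse_coalescent("0.01", 1e-4)` returns `0.01` and `parse_coalescent("skyline", 1e-4)` returns `"skyline"`. `parse_coalescent("1e-6", 1e-4)` raises instead of quietly disabling the coalescent prior.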

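The `--T_c`/`--N_e` printout now uses `params.gen_per_year` instead of the hard-coded 50 generations per year. The conversion is two divisions and a multiplication. `Tc` is in substitutions. Dividing by the clock rate (substitutions per year) gives years. Multiplying by generations per year gives the "effective population size". The helper name below is hypothetical. Its default of 50 only reflects the previously hard-coded value, not necessarily the new CLI default.

```python
def coalescent_summary(Tc, clock_rate, gen_per_year=50.0):
    """Hypothetical helper: convert an inverse merger rate Tc (in substitutions)
    into years and into an effective population size N_e."""
    Tc_years = Tc / clock_rate          # substitutions -> years
    Ne = Tc_years * gen_per_year        # years -> generations, i.e. N_e
    return Tc_years, Ne


Tc_years, Ne = coalescent_summary(0.002, 0.001, gen_per_year=50.0)
print(f" --T_c: \t {Tc_years:1.2e} \tin years")
print(f" --N_e: \t {Ne:1.2e} \tassuming {50.0:1.2e} gen/year")
```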

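In `reconstruct_discrete_traits`, the missing-data value is now removed from `unique_states` before the states are counted. Without this, a tip annotated with `'?'` would count towards `n_observed_states` as if it were a real state. `mugration()` already drops missing values when it builds `leaf_to_attr`, so the change mainly matters for direct callers of `reconstruct_discrete_traits`. The sketch below illustrates the idea with a hypothetical `build_alphabet` function. The letter assignment (sorted states mapped to `A`, `B`, ...) is an assumption and need not match TreeTime's internal alphabet.

```python
import string


def build_alphabet(traits, missing_data='?'):
    """Hypothetical sketch: map each observed discrete state to one letter and
    give the missing-data state the next unused letter."""
    # exclude the missing-data marker so it is not counted as an observed state
    unique_states = sorted(set(traits.values()) - {missing_data})
    letters = string.ascii_uppercase
    if len(unique_states) + 1 > len(letters):
        raise ValueError("too many states for a single-letter alphabet")
    reverse_alphabet = {state: letters[i] for i, state in enumerate(unique_states)}
    missing_char = letters[len(unique_states)]
    # one-character pseudo-sequence per tip; unknown states get missing_char
    pseudo_seqs = {name: reverse_alphabet.get(state, missing_char)
                   for name, state in traits.items()}
    return reverse_alphabet, missing_char, pseudo_seqs


alphabet, missing, seqs = build_alphabet({'a': 'USA', 'b': 'China', 'c': '?'})
```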

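This commit threads a new `rng_seed` parameter through `TreeAnc`, `TreeTime` and `reconstruct_discrete_traits`, alongside the new `stochastic_resolve` option. This diff does not show how TreeTime consumes the seed internally. The sketch assumes it seeds a NumPy `Generator`, and uses a hypothetical `shuffled_children` function to show why a fixed seed makes a randomized step reproducible.

```python
import numpy as np


def shuffled_children(children, rng_seed=None):
    """Hypothetical illustration: return a random order of polytomy children.
    With a fixed seed the order is reproducible; with None it varies between runs."""
    rng = np.random.default_rng(rng_seed)
    order = rng.permutation(len(children))
    return [children[i] for i in order]


first = shuffled_children(list("abcde"), rng_seed=42)
second = shuffled_children(list("abcde"), rng_seed=42)
```

With the same seed, `first` and `second` are identical, which is what allows a stochastic step to be rerun exactly.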
View it on GitLab: https://salsa.debian.org/med-team/python-treetime/-/commit/bc7ed4dea2a0360e7cb30762ccbfaa42a2cc1c46





More information about the debian-med-commit mailing list