From 8c857700045f801667d9e2989c9b52aaa32c7839 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 18 May 2021 15:51:44 -0700 Subject: [PATCH 001/196] first attempt at making inference work with JointPartialsProvider --- .../hmc/IntegratedLoadingsGradient.java | 18 +++++--- .../continuous/JointPartialsProvider.java | 41 +++++++++++++++++++ .../hmc/IntegratedLoadingsGradientParser.java | 10 ++++- src/dr/math/matrixAlgebra/WrappedVector.java | 5 +++ 4 files changed, 68 insertions(+), 6 deletions(-) diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index 25b11efe86..d788e66792 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -2,10 +2,10 @@ import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; -import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; +import dr.evomodel.treedatalikelihood.continuous.JointPartialsProvider; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics; import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate; @@ -16,7 +16,6 @@ import dr.math.matrixAlgebra.missingData.MissingOps; import dr.util.StopWatch; import dr.util.TaskPool; -import dr.evomodelxml.continuous.hmc.TaskPoolParser; import dr.xml.*; import org.ejml.data.DenseMatrix64F; @@ -44,16 +43,19 @@ public class IntegratedLoadingsGradient implements GradientWrtParameterProvider, private final ThreadUseProvider threadUseProvider; private final RemainderCompProvider remainderCompProvider; private final TaskPool taskPool; + private final JointPartialsProvider jointPartials; public IntegratedLoadingsGradient(TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate, IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood, + JointPartialsProvider jointPartials, TaskPool taskPool, ThreadUseProvider threadUseProvider, RemainderCompProvider remainderCompProvider) { this.factorAnalysisLikelihood = factorAnalysisLikelihood; + this.jointPartials = jointPartials; String traitName = factorAnalysisLikelihood.getModelName(); @@ -275,9 +277,15 @@ private void computeGradientForOneTaxon(final int index, // for (WrappedNormalSufficientStatistics statistic : statistics) { // TODO Maybe need to re-enable - final ReadableVector meanFactor = statistic.getMean(); - final WrappedMatrix precisionFactor = statistic.getPrecision(); - final WrappedMatrix varianceFactor = statistic.getVariance(); + final WrappedNormalSufficientStatistics extendedStatistic; + if (jointPartials == null) { + extendedStatistic = statistic; + } else { + extendedStatistic = jointPartials.partitionNormalStatistics(statistic, factorAnalysisLikelihood); + } + final ReadableVector meanFactor = extendedStatistic.getMean(); + final WrappedMatrix precisionFactor = extendedStatistic.getPrecision(); + final WrappedMatrix varianceFactor = extendedStatistic.getVariance(); if (DEBUG) { System.err.println("FM" + taxon + " : " + meanFactor); diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index fa3e2f8739..067d52a585 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -2,11 +2,14 @@ import dr.evolution.tree.Tree; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; +import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics; import dr.inference.model.*; import dr.math.matrixAlgebra.WrappedMatrix; +import dr.math.matrixAlgebra.WrappedVector; import dr.math.matrixAlgebra.missingData.MissingOps; import dr.xml.*; import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; import java.util.List; @@ -265,6 +268,44 @@ public void addTreeAndRateModel(Tree treeModel, ContinuousRateTransformation rat } } + public WrappedNormalSufficientStatistics partitionNormalStatistics(WrappedNormalSufficientStatistics statistic, + ContinuousTraitPartialsProvider provider) { + + int traitOffset = 0; + for (ContinuousTraitPartialsProvider potentialProvider : providers) { + if (provider == potentialProvider) { + break; + } else { + traitOffset += potentialProvider.getTraitDimension(); + } + } + + int traitDim = provider.getTraitDimension(); + + WrappedVector originalMean = statistic.getMean(); + WrappedVector newMean = new WrappedVector.View(originalMean, traitOffset, traitDim); + + int[] varianceIndices = new int[traitDim]; + for (int i = 0; i < traitDim; i++) { + varianceIndices[i] = i + traitOffset; + } + + WrappedMatrix originalVariance = statistic.getVariance(); + DenseMatrix64F newVariance = new DenseMatrix64F(traitDim, traitDim); + + for (int i = 0; i < traitDim; i++) { + for (int j = 0; j < traitDim; j++) { + newVariance.set(i, j, originalVariance.get(varianceIndices[i], varianceIndices[j])); + } + } + + DenseMatrix64F newPrecision = new DenseMatrix64F(traitDim, traitDim); + CommonOps.invert(newVariance, newPrecision); //TODO: cholesky + + return new WrappedNormalSufficientStatistics(newMean, new WrappedMatrix.WrappedDenseMatrix(newPrecision), + new WrappedMatrix.WrappedDenseMatrix(newVariance)); + } + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { private static final String PARSER_NAME = "jointPartialsProvider"; diff --git a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java index 77bf185b8c..b7b15619e6 100644 --- a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java +++ b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java @@ -5,6 +5,7 @@ import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; +import dr.evomodel.treedatalikelihood.continuous.JointPartialsProvider; import dr.util.TaskPool; import dr.xml.*; @@ -68,12 +69,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { "=\"" + PARALLEL + "\" or remove the " + TaskPoolParser.TASK_PARSER_NAME + " element."); } + JointPartialsProvider jointPartials = (JointPartialsProvider) xo.getChild(JointPartialsProvider.class); + + // TODO Check dimensions, parameters, etc. return factory( treeDataLikelihood, continuousDataLikelihoodDelegate, factorAnalysis, + jointPartials, taskPool, threadProvider, remainderCompProvider); @@ -83,6 +88,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { protected IntegratedLoadingsGradient factory(TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate, IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood, + JointPartialsProvider jointPartialsProvider, TaskPool taskPool, IntegratedLoadingsGradient.ThreadUseProvider threadUseProvider, IntegratedLoadingsGradient.RemainderCompProvider remainderCompProvider) @@ -92,6 +98,7 @@ protected IntegratedLoadingsGradient factory(TreeDataLikelihood treeDataLikeliho treeDataLikelihood, likelihoodDelegate, factorAnalysisLikelihood, + jointPartialsProvider, taskPool, threadUseProvider, remainderCompProvider); @@ -123,7 +130,8 @@ public String getParserName() { new ElementRule(TreeDataLikelihood.class), new ElementRule(TaskPool.class, true), AttributeRule.newStringRule(THREAD_TYPE, true), - AttributeRule.newStringRule(REMAINDER_COMPUTATION, true) + AttributeRule.newStringRule(REMAINDER_COMPUTATION, true), + new ElementRule(JointPartialsProvider.class, true) }; } diff --git a/src/dr/math/matrixAlgebra/WrappedVector.java b/src/dr/math/matrixAlgebra/WrappedVector.java index 48dc3ced30..4181912b36 100644 --- a/src/dr/math/matrixAlgebra/WrappedVector.java +++ b/src/dr/math/matrixAlgebra/WrappedVector.java @@ -27,6 +27,7 @@ import dr.inference.model.Variable; +import dr.xml.AbstractXMLObjectParser; /** * @author Marc A. Suchard @@ -98,6 +99,10 @@ final class View extends Raw { public View(WrappedVector vector, int offset, int length) { super(vector.getBuffer(), vector.getOffset() + offset, length); + + if (!(vector instanceof WrappedVector.Raw)) { + throw new RuntimeException("This can only extend WrappedVector.Raw"); + } } } From 78cd317e385b80b5983c7cbf4e2a38320f3c686b Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 18 May 2021 16:04:54 -0700 Subject: [PATCH 002/196] missing data in loadings gradient --- ci/TestXML/testLoadingsGradient.xml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ci/TestXML/testLoadingsGradient.xml b/ci/TestXML/testLoadingsGradient.xml index 68e7a2c632..2383225573 100644 --- a/ci/TestXML/testLoadingsGradient.xml +++ b/ci/TestXML/testLoadingsGradient.xml @@ -12,8 +12,7 @@ - 1.056873733132347 -0.8909827170981991 1.1873505039550996 -0.39339506737983243 - 1.066150616149452 + 1.056873733132347 NA NA NA 1.066150616149452 @@ -27,7 +26,7 @@ - 0.04658821120386576 -0.8666572905938715 -0.454549337998158 -0.5001246338040305 + NA -0.8666572905938715 -0.454549337998158 -0.5001246338040305 -0.5397266841036715 @@ -37,7 +36,7 @@ - -0.42898479922720356 -0.3440100523685048 0.6436374468251037 -0.09729958495856225 + -0.42898479922720356 NA 0.6436374468251037 -0.09729958495856225 -1.396413239367357 @@ -47,8 +46,7 @@ - -0.8970245252413375 0.7619824434342454 1.5439602691637375 0.09081342420693704 - 1.6135287964494376 + NA NA NA NA 1.6135287964494376 From d024d55db96b07f93cfd8be7ba4c769b94aff626 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 18 May 2021 18:59:21 -0700 Subject: [PATCH 003/196] bug fix --- .../continuous/JointPartialsProvider.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index 067d52a585..067e210413 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -179,10 +179,10 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { double subDet = subPartial[precisionType.getDeterminantOffset(subDim)]; if (!precisionType.isMissingDeterminantValue(subDet)) { - - DenseMatrix64F prec = MissingOps.wrap(subPartial, precisionOffset, subDim, subDim); - DenseMatrix64F var = new DenseMatrix64F(subDim, subDim); - subDet = MissingOps.safeInvert2(prec, var, true).getLogDeterminant(); + //TODO: what was I trying to do here? +// DenseMatrix64F prec = MissingOps.wrap(subPartial, precisionOffset, subDim, subDim); +// DenseMatrix64F var = new DenseMatrix64F(subDim, subDim); +// subDet = MissingOps.safeInvert2(prec, var, true).getLogDeterminant(); } partial[detDim] += subDet; From 6fe95a88b7afcd792b4c97c5e2651230a910c3fb Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 19 May 2021 10:10:00 -0700 Subject: [PATCH 004/196] IntegratedLoadingsGradient more generalized --- .../hmc/IntegratedLoadingsGradient.java | 44 +++++++++---------- .../ContinuousTraitPartialsProvider.java | 14 +++++- .../continuous/JointPartialsProvider.java | 1 + .../hmc/IntegratedLoadingsGradientParser.java | 8 ++-- 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index d788e66792..c447cdf942 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -4,8 +4,8 @@ import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; -import dr.evomodel.treedatalikelihood.continuous.JointPartialsProvider; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics; import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate; @@ -34,8 +34,10 @@ public class IntegratedLoadingsGradient implements GradientWrtParameterProvider, private final TreeTrait> fullConditionalDensity; private final IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood; + private final ContinuousTraitPartialsProvider partialsProvider; protected final int dimTrait; protected final int dimFactors; + protected final int dimPartials; private final Tree tree; private final Likelihood likelihood; private final double[] data; @@ -43,19 +45,18 @@ public class IntegratedLoadingsGradient implements GradientWrtParameterProvider, private final ThreadUseProvider threadUseProvider; private final RemainderCompProvider remainderCompProvider; private final TaskPool taskPool; - private final JointPartialsProvider jointPartials; public IntegratedLoadingsGradient(TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate, IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood, - JointPartialsProvider jointPartials, + ContinuousTraitPartialsProvider partialsProvider, TaskPool taskPool, ThreadUseProvider threadUseProvider, RemainderCompProvider remainderCompProvider) { this.factorAnalysisLikelihood = factorAnalysisLikelihood; - this.jointPartials = jointPartials; + this.partialsProvider = partialsProvider; String traitName = factorAnalysisLikelihood.getModelName(); @@ -70,6 +71,7 @@ public IntegratedLoadingsGradient(TreeDataLikelihood treeDataLikelihood, this.dimTrait = factorAnalysisLikelihood.getDataDimension(); this.dimFactors = factorAnalysisLikelihood.getNumberOfFactors(); + this.dimPartials = partialsProvider.getTraitDimension(); Parameter dataParameter = factorAnalysisLikelihood.getParameter(); this.data = dataParameter.getParameterValues(); @@ -150,8 +152,8 @@ private ReadableMatrix shiftToSecondMoment(WrappedMatrix variance, ReadableVecto return variance; } - private WrappedNormalSufficientStatistics getWeightedAverage(ReadableVector m1, ReadableMatrix p1, - ReadableVector m2, ReadableMatrix p2) { + private static WrappedNormalSufficientStatistics getWeightedAverage(ReadableVector m1, ReadableMatrix p1, + ReadableVector m2, ReadableMatrix p2) { assert (m1.getDim() == m2.getDim()); assert (p1.getDim() == p2.getDim()); @@ -159,9 +161,11 @@ private WrappedNormalSufficientStatistics getWeightedAverage(ReadableVector m1, assert (m1.getDim() == p1.getMinorDim()); assert (m1.getDim() == p1.getMajorDim()); - final WrappedVector m12 = new WrappedVector.Raw(new double[m1.getDim()], 0, dimFactors); - final DenseMatrix64F p12 = new DenseMatrix64F(dimFactors, dimFactors); - final DenseMatrix64F v12 = new DenseMatrix64F(dimFactors, dimFactors); + int dim = m1.getDim(); + + final WrappedVector m12 = new WrappedVector.Raw(new double[m1.getDim()], 0, dim); + final DenseMatrix64F p12 = new DenseMatrix64F(dim, dim); + final DenseMatrix64F v12 = new DenseMatrix64F(dim, dim); final WrappedMatrix wP12 = new WrappedMatrix.WrappedDenseMatrix(p12); final WrappedMatrix wV12 = new WrappedMatrix.WrappedDenseMatrix(v12); @@ -169,7 +173,7 @@ private WrappedNormalSufficientStatistics getWeightedAverage(ReadableVector m1, MissingOps.add(p1, p2, wP12); safeInvert2(p12, v12, false); - weightedAverage(m1, p1, m2, p2, m12, wV12, dimFactors); + weightedAverage(m1, p1, m2, p2, m12, wV12, dim); return new WrappedNormalSufficientStatistics(m12, wP12, wV12); } @@ -277,15 +281,9 @@ private void computeGradientForOneTaxon(final int index, // for (WrappedNormalSufficientStatistics statistic : statistics) { // TODO Maybe need to re-enable - final WrappedNormalSufficientStatistics extendedStatistic; - if (jointPartials == null) { - extendedStatistic = statistic; - } else { - extendedStatistic = jointPartials.partitionNormalStatistics(statistic, factorAnalysisLikelihood); - } - final ReadableVector meanFactor = extendedStatistic.getMean(); - final WrappedMatrix precisionFactor = extendedStatistic.getPrecision(); - final WrappedMatrix varianceFactor = extendedStatistic.getVariance(); + final ReadableVector meanFactor = statistic.getMean(); + final WrappedMatrix precisionFactor = statistic.getPrecision(); + final WrappedMatrix varianceFactor = statistic.getVariance(); if (DEBUG) { System.err.println("FM" + taxon + " : " + meanFactor); @@ -293,10 +291,12 @@ private void computeGradientForOneTaxon(final int index, System.err.println("FV" + taxon + " : " + varianceFactor); } - final WrappedNormalSufficientStatistics convolution = getWeightedAverage( + WrappedNormalSufficientStatistics convolution = getWeightedAverage( meanFactor, precisionFactor, meanKernel, precisionKernel); + convolution = partialsProvider.partitionNormalStatistics(convolution, factorAnalysisLikelihood); + final ReadableVector mean = convolution.getMean(); // final ReadableMatrix precision = convolution.getPrecision(); final WrappedMatrix variance = convolution.getVariance(); @@ -374,8 +374,8 @@ private static double[] join(double[][] array) { // } private WrappedNormalSufficientStatistics getTipKernel(int taxonIndex) { - double[] buffer = factorAnalysisLikelihood.getTipPartial(taxonIndex, false); - return new WrappedNormalSufficientStatistics(buffer, 0, dimFactors, null, PrecisionType.FULL); + double[] buffer = partialsProvider.getTipPartial(taxonIndex, false); + return new WrappedNormalSufficientStatistics(buffer, 0, dimPartials, null, PrecisionType.FULL); } public enum ThreadUseProvider { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java index 8000b1b9ce..5064bd2c27 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java @@ -27,6 +27,7 @@ import dr.evolution.tree.Tree; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; +import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics; import dr.inference.model.CompoundParameter; import java.util.ArrayList; @@ -77,12 +78,23 @@ default boolean suppliesWishartStatistics() { return true; } - default int[] getPartitionDimensions() { return null;} + default int[] getPartitionDimensions() { + return null; + } default void addTreeAndRateModel(Tree treeModel, ContinuousRateTransformation rateTransformation) { // Do nothing } + default WrappedNormalSufficientStatistics partitionNormalStatistics(WrappedNormalSufficientStatistics statistic, + ContinuousTraitPartialsProvider provider) { + if (this == provider) { + return statistic; + } + throw new RuntimeException("This class does not currently support 'partitionNormalStatistics' with " + + "a provider other than itself."); + } + static boolean[] indicesToIndicator(List indices, int n) { if (indices == null) { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index 067e210413..e6d1a1118e 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -268,6 +268,7 @@ public void addTreeAndRateModel(Tree treeModel, ContinuousRateTransformation rat } } + @Override public WrappedNormalSufficientStatistics partitionNormalStatistics(WrappedNormalSufficientStatistics statistic, ContinuousTraitPartialsProvider provider) { diff --git a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java index b7b15619e6..cc635b80ff 100644 --- a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java +++ b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java @@ -4,6 +4,7 @@ import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; import dr.evomodel.treedatalikelihood.continuous.JointPartialsProvider; import dr.util.TaskPool; @@ -69,7 +70,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { "=\"" + PARALLEL + "\" or remove the " + TaskPoolParser.TASK_PARSER_NAME + " element."); } - JointPartialsProvider jointPartials = (JointPartialsProvider) xo.getChild(JointPartialsProvider.class); + ContinuousTraitPartialsProvider partialsProvider = (JointPartialsProvider) xo.getChild(JointPartialsProvider.class); + if (partialsProvider == null) partialsProvider = factorAnalysis; // TODO Check dimensions, parameters, etc. @@ -78,7 +80,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { treeDataLikelihood, continuousDataLikelihoodDelegate, factorAnalysis, - jointPartials, + partialsProvider, taskPool, threadProvider, remainderCompProvider); @@ -88,7 +90,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { protected IntegratedLoadingsGradient factory(TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate, IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood, - JointPartialsProvider jointPartialsProvider, + ContinuousTraitPartialsProvider jointPartialsProvider, TaskPool taskPool, IntegratedLoadingsGradient.ThreadUseProvider threadUseProvider, IntegratedLoadingsGradient.RemainderCompProvider remainderCompProvider) From cd21a56d3a425f985c9562dc5b83fa2961db442a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 19 May 2021 10:11:03 -0700 Subject: [PATCH 005/196] updated testJointPartialsProvider.xml --- ci/TestXML/testJointPartialsProvider.xml | 184 ++++++++++++++++++----- 1 file changed, 146 insertions(+), 38 deletions(-) diff --git a/ci/TestXML/testJointPartialsProvider.xml b/ci/TestXML/testJointPartialsProvider.xml index d57d942f02..37d0b17257 100644 --- a/ci/TestXML/testJointPartialsProvider.xml +++ b/ci/TestXML/testJointPartialsProvider.xml @@ -95,10 +95,10 @@ - - - - + + + + @@ -114,9 +114,9 @@ - - - + + + @@ -124,9 +124,9 @@ + value="0.9247905179016628 -0.33921029582132534 0.5584156892793313 0.3131380271405006 -1.3830823737436915"/> + value="-1.032474009627882 -0.5424164961392062 -0.7720274902578971 0.16523232493297793 0.6057788313870089"/> @@ -145,13 +145,13 @@ - + + value="7.6897114774781805 2.447408007560504 5.624132942250213 2.575119059444284 18.984513296012224"/> @@ -166,8 +166,8 @@ - - + + @@ -178,7 +178,7 @@ - + @@ -284,7 +284,7 @@ - -100.22471627319682 + -92.83480728956839 @@ -296,7 +296,7 @@ - -36.73488439033143 + -34.55739075783681 @@ -308,32 +308,140 @@ - -65.00935647098959 + -60.18281026197038 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of joint model (with correlation): + + + + + + -104.76448987237042 + + + Check log likelihood of residual model only: + + + + + + -35.65645117578014 + + + + + + Check log likelihood of factor model only: + + + + + + -70.13809042542833 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + From 8ad32d783fde18e8962e287234abcf4820f898e4 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 20 May 2021 16:44:38 -0700 Subject: [PATCH 006/196] testing inference on restricted diffusion covariance --- ci/TestXML/testJointPartialsProvider.xml | 167 +++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/ci/TestXML/testJointPartialsProvider.xml b/ci/TestXML/testJointPartialsProvider.xml index 37d0b17257..223b99f2ab 100644 --- a/ci/TestXML/testJointPartialsProvider.xml +++ b/ci/TestXML/testJointPartialsProvider.xml @@ -443,5 +443,172 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From e1c5d7582aaf77993d30352f80d28ccc1d00dff0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 20 May 2021 16:52:57 -0700 Subject: [PATCH 007/196] probably not a problem but better to be sure --- ci/TestXML/testJointPartialsProvider.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/TestXML/testJointPartialsProvider.xml b/ci/TestXML/testJointPartialsProvider.xml index 223b99f2ab..c145a58ce1 100644 --- a/ci/TestXML/testJointPartialsProvider.xml +++ b/ci/TestXML/testJointPartialsProvider.xml @@ -458,7 +458,7 @@ - From cf9d584115e57c06344f97b4ca4e1c4e9a16958b Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 26 May 2021 14:54:37 -0700 Subject: [PATCH 008/196] jointPartialsProvider merges data parameters --- .../continuous/JointPartialsProvider.java | 17 ++++- src/dr/inference/model/CompoundParameter.java | 66 +++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index fa3e2f8739..9a127af1e3 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -35,6 +35,10 @@ public class JointPartialsProvider extends AbstractModel implements ContinuousTr private String tipTraitName; + private final CompoundParameter jointDataParameter; + + private static final Boolean DEBUG = false; + public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] providers) { super(name); this.name = name; @@ -61,6 +65,16 @@ public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] prov addModel((Model) provider); } } + + CompoundParameter[] parameters = new CompoundParameter[providers.length]; + for (int i = 0; i < parameters.length; i++) { + parameters[i] = providers[i].getParameter(); + } + + this.jointDataParameter = CompoundParameter.mergeParameters(parameters); + if (DEBUG) { + CompoundParameter.checkParametersMerged(jointDataParameter, parameters); + } } @@ -207,8 +221,7 @@ public boolean[] getDataMissingIndicators() { @Override public CompoundParameter getParameter() { - System.err.println("Warning: This is broken. (JointPartialsProvider.getParameter())"); - return providers[0].getParameter(); //TODO: This is going to be the real problem, I think + return jointDataParameter; } @Override diff --git a/src/dr/inference/model/CompoundParameter.java b/src/dr/inference/model/CompoundParameter.java index 7a132303e6..6e5ea2481a 100644 --- a/src/dr/inference/model/CompoundParameter.java +++ b/src/dr/inference/model/CompoundParameter.java @@ -290,6 +290,72 @@ public double getParameterValue(int index, int parameter) { return getParameter(parameter).getParameterValue(index); } + public static CompoundParameter mergeParameters(String newName, + CompoundParameter[] parameters, + Boolean enforceSameNames) { + int nParameters = parameters[0].getParameterCount(); + for (int i = 1; i < parameters.length; i++) { + assert (nParameters == parameters[i].getParameterCount()); + } + + CompoundParameter mergedParameter = new CompoundParameter(newName); + for (int i = 0; i < nParameters; i++) { + String parameterName = parameters[0].getParameter(i).getParameterName(); + CompoundParameter rowParameter = new CompoundParameter(parameterName); + for (int j = 0; j < parameters.length; j++) { + Parameter subParameter = parameters[j].getParameter(i); + if (enforceSameNames && subParameter.getParameterName() != parameterName) { + throw new RuntimeException("parameter " + j + " with sub-parameter " + i + " has name " + + subParameter.getParameterName() + ". This does not match parameter 0 with sub-parameter " + + +i + " that has name " + parameterName + "."); + } + rowParameter.addParameter(subParameter); + } + mergedParameter.addParameter(rowParameter); + } + + return mergedParameter; + } + + public static CompoundParameter mergeParameters(CompoundParameter[] parameters) { + String newName = parameters[0].getParameterName(); + for (int i = 1; i < parameters.length; i++) { + newName += "_and_" + parameters[i].getParameterName(); + } + return mergeParameters(newName, parameters, true); + } + + public static void checkParametersMerged(CompoundParameter mergedParameter, CompoundParameter[] parameters) { + int nParams = mergedParameter.getParameterCount(); + for (int i = 0; i < nParams; i++) { + CompoundParameter subMerged = (CompoundParameter) mergedParameter.getParameter(i); + + for (int j = 0; j < parameters.length; j++) { + assert (subMerged.getParameter(j) == parameters[j].getParameter(i)); + } + + //Below is redundant but a good sanity check + int currentParameter = 0; + int currentIndex = 0; + int currentDimension = parameters[currentParameter].getParameter(i).getDimension(); + + for (int j = 0; j < subMerged.getDimension(); j++) { + if (currentIndex == currentDimension) { + currentParameter++; + currentIndex = 0; + currentDimension = parameters[currentParameter].getParameter(i).getDimension(); + } + + double v1 = subMerged.getParameterValue(j); + double v2 = parameters[currentParameter].getParameter(i).getParameterValue(currentIndex); + + assert (v1 == v2); + + currentIndex++; + } + } + } + // **************************************************************** // Private and protected stuff // **************************************************************** From 7c0c0bef55aaad35a086c62bb3fd27a9fc91ffc6 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 26 May 2021 15:47:23 -0700 Subject: [PATCH 009/196] cleaning code --- .../continuous/JointPartialsProvider.java | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index 3c3e48c125..c6de1d0254 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -110,7 +110,7 @@ public boolean bufferTips() { @Override public int getTraitCount() { - return providers[0].getTraitCount(); //TODO: make sure all have same trait count in parser + return providers[0].getTraitCount(); } @Override @@ -214,12 +214,12 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { @Override public List getMissingIndices() { - return missingIndices; //TODO: how to merge missing indices + return missingIndices; } @Override public boolean[] getDataMissingIndicators() { - return missingIndicators; //TODO: see above + return missingIndicators; } @Override @@ -336,6 +336,14 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { for (int i = 0; i < providersList.size(); i++) { providers[i] = providersList.get(i); } + + int traitCount = providers[0].getTraitCount(); + for (int i = 1; i < providers.length; i++) { + if (providers[i].getTraitCount() != traitCount) { + throw new XMLParseException("all partials providers must have the same trait count"); + } + + } return new JointPartialsProvider(PARSER_NAME, providers); } From ba4cb5eb4bd9e9b2d285d4d59d8c21aa421de7d8 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 26 May 2021 15:49:43 -0700 Subject: [PATCH 010/196] more code cleaning --- .../treedatalikelihood/continuous/JointPartialsProvider.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index c6de1d0254..f91560c678 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -23,7 +23,6 @@ public class JointPartialsProvider extends AbstractModel implements ContinuousTraitPartialsProvider { - private final String name; private final ContinuousTraitPartialsProvider[] providers; private final int traitDim; private final int dataDim; @@ -44,7 +43,6 @@ public class JointPartialsProvider extends AbstractModel implements ContinuousTr public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] providers) { super(name); - this.name = name; this.providers = providers; int traitDim = 0; From 8f8aaa10b0137f11b5006678e0a3837081ff2405 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 2 Jun 2021 10:16:24 -0700 Subject: [PATCH 011/196] more flexible CorrelationMatrixStatistic --- src/dr/inference/model/CorrelationMatrixStatistic.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dr/inference/model/CorrelationMatrixStatistic.java b/src/dr/inference/model/CorrelationMatrixStatistic.java index bd0c926170..0bb98758eb 100644 --- a/src/dr/inference/model/CorrelationMatrixStatistic.java +++ b/src/dr/inference/model/CorrelationMatrixStatistic.java @@ -38,12 +38,12 @@ public class CorrelationMatrixStatistic extends Statistic.Abstract implements Va private static final String CORRELATION_MATRIX = "correlationMatrix"; - private final MatrixParameter matrix; + private final MatrixParameterInterface matrix; private final double[][] correlation; private boolean corrKnown = false; private final boolean invert; - public CorrelationMatrixStatistic(MatrixParameter matrix, Boolean invert) { + public CorrelationMatrixStatistic(MatrixParameterInterface matrix, Boolean invert) { this.matrix = matrix; this.invert = invert; correlation = new double[matrix.getRowDimension()][matrix.getColumnDimension()]; @@ -100,7 +100,7 @@ public void variableChangedEvent(Variable variable, int index, Variable.ChangeTy @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { - MatrixParameter matrix = (MatrixParameter) xo.getChild(MatrixParameter.class); + MatrixParameterInterface matrix = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class); if (matrix.getColumnDimension() != matrix.getRowDimension()) { throw new XMLParseException("Only square matrices can be converted to correlation matrices"); } @@ -114,7 +114,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { @Override public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ - new ElementRule(MatrixParameter.class), + new ElementRule(MatrixParameterInterface.class), AttributeRule.newBooleanRule(INVERT, true) }; } From bf150cad234ba8eed5fc4a61480b403bde200e43 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 06:27:30 -0700 Subject: [PATCH 012/196] quick dimensionality check in GammaGibbsProvider --- .../operators/repeatedMeasures/GammaGibbsProvider.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java b/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java index 6b7f546171..6121270485 100644 --- a/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java +++ b/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java @@ -116,6 +116,7 @@ class NormalExtensionGibbsProvider implements GammaGibbsProvider { private final boolean[] missingVector; private double[] tipValues; + private boolean hasCheckedDimension = false; public NormalExtensionGibbsProvider(ModelExtensionProvider.NormalExtensionProvider dataModel, TreeDataLikelihood treeLikelihood) { @@ -146,6 +147,14 @@ public SufficientStatistics getSufficientStatistics(int dim) { final int taxonCount = treeLikelihood.getTree().getExternalNodeCount(); final int traitDim = dataModel.getDataDimension(); + + if (!hasCheckedDimension) { //TODO: actually check that this works + if (taxonCount * traitDim != tipValues.length) { + throw new RuntimeException("dimensions are incompatible"); + } + hasCheckedDimension = true; + } + int missingCount = 0; double SSE = 0; From 039882acccd41e41d76a165bcc148e8d594801f0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 07:47:40 -0700 Subject: [PATCH 013/196] setting up to remove code duplication --- .../inference/model/FastMatrixParameter.java | 18 +-------------- .../model/MatrixParameterInterface.java | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/dr/inference/model/FastMatrixParameter.java b/src/dr/inference/model/FastMatrixParameter.java index b131cd9585..532f76bee8 100644 --- a/src/dr/inference/model/FastMatrixParameter.java +++ b/src/dr/inference/model/FastMatrixParameter.java @@ -210,16 +210,6 @@ public int getDimension() { } - private int index(int row, int col) { - // column-major - if(col > getColumnDimension()){ - throw new RuntimeException("Column " + col + " out of bounds: Compared to " + getColumnDimension() + "maximum size."); - } - if(row > getRowDimension()){ - throw new RuntimeException("Row " + row + " out of bounds: Compared to " + getRowDimension() + "maximum size."); - } - return col * rowDimension + row; - } @Override public double getParameterValue(int row, int col) { @@ -281,13 +271,7 @@ public double[] getColumnValues(int col) { @Override public double[][] getParameterAsMatrix() { - double[][] rtn = new double[getRowDimension()][getColumnDimension()]; - for (int j = 0; j < getColumnDimension(); ++j) { - for (int i = 0; i < getRowDimension(); ++i) { - rtn[i][j] = getParameterValue(i, j); - } - } - return rtn; + return MatrixParameterInterface.getParameterAsMatrix(this); } @Override diff --git a/src/dr/inference/model/MatrixParameterInterface.java b/src/dr/inference/model/MatrixParameterInterface.java index d1debcf4d5..144911731b 100644 --- a/src/dr/inference/model/MatrixParameterInterface.java +++ b/src/dr/inference/model/MatrixParameterInterface.java @@ -64,4 +64,27 @@ public interface MatrixParameterInterface extends Parameter { String toSymmetricString(); boolean isConstrainedSymmetric(); + + default int index(int row, int col) { + // column-major + if (col > getColumnDimension()) { + throw new RuntimeException("Column " + col + " out of bounds: Compared to " + getColumnDimension() + "maximum size."); + } + if (row > getRowDimension()) { + throw new RuntimeException("Row " + row + " out of bounds: Compared to " + getRowDimension() + "maximum size."); + } + return col * getRowDimension() + row; + } + + static double[][] getParameterAsMatrix(MatrixParameterInterface parameter) { + int rowDim = parameter.getRowDimension(); + int colDim = parameter.getColumnDimension(); + double[][] rtn = new double[rowDim][colDim]; + for (int j = 0; j < colDim; ++j) { + for (int i = 0; i < rowDim; ++i) { + rtn[i][j] = parameter.getParameterValue(i, j); + } + } + return rtn; + } } From 6faf146894d82971fec4be105dc6fb3161880fe7 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 07:48:08 -0700 Subject: [PATCH 014/196] reformatting code --- src/dr/inference/model/FastMatrixParameter.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dr/inference/model/FastMatrixParameter.java b/src/dr/inference/model/FastMatrixParameter.java index 532f76bee8..d51a45e8b1 100644 --- a/src/dr/inference/model/FastMatrixParameter.java +++ b/src/dr/inference/model/FastMatrixParameter.java @@ -186,7 +186,7 @@ public Bounds getBounds() { } @Override - public void fireParameterChangedEvent(int index, ChangeType type){ + public void fireParameterChangedEvent(int index, ChangeType type) { matrix.fireParameterChangedEvent(index + column * getDimension(), type); super.fireParameterChangedEvent(index, ChangeType.VALUE_CHANGED); } @@ -299,11 +299,11 @@ public Parameter getUniqueParameter(int index) { return super.getParameter(0); } - public void addBounds(Bounds boundary){ + public void addBounds(Bounds boundary) { singleParameter.addBounds(boundary); } - public Bounds getBounds(){ + public Bounds getBounds() { return singleParameter.getBounds(); } From d6b347cef448931b5138d672986e8606baecc6d9 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 07:49:05 -0700 Subject: [PATCH 015/196] matrix-valued transformed parameter --- .../model/MatrixTransformedParameter.java | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 src/dr/inference/model/MatrixTransformedParameter.java diff --git a/src/dr/inference/model/MatrixTransformedParameter.java b/src/dr/inference/model/MatrixTransformedParameter.java new file mode 100644 index 0000000000..07453b535f --- /dev/null +++ b/src/dr/inference/model/MatrixTransformedParameter.java @@ -0,0 +1,100 @@ +package dr.inference.model; + +import dr.util.Transform; + + +public class MatrixTransformedParameter extends TransformedParameter implements MatrixParameterInterface { + + private final int rowDim; + private final int colDim; + + public MatrixTransformedParameter(Parameter parameter, Transform transform, Boolean inverse, int rowDim, int colDim) { + super(parameter, transform, inverse); + this.rowDim = rowDim; + this.colDim = colDim; + } + + + public MatrixTransformedParameter(MatrixParameterInterface parameter, Transform transform, Boolean inverse) { + this(parameter, transform, inverse, parameter.getRowDimension(), parameter.getColumnDimension()); + } + + + @Override + public double getParameterValue(int row, int col) { + return getParameterValue(index(row, col)); + } + + @Override + public Parameter getParameter(int column) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public void setParameterValue(int row, int col, double value) { + setParameterValue(index(row, col), value); + } + + @Override + public void setParameterValueQuietly(int row, int col, double value) { + setParameterValueQuietly(index(row, col), value); + } + + @Override + public void setParameterValueNotifyChangedAll(int row, int col, double value) { + setParameterValueNotifyChangedAll(index(row, col), value); + } + + @Override + public double[] getColumnValues(int col) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public double[][] getParameterAsMatrix() { + return MatrixParameterInterface.getParameterAsMatrix(this); + } + + @Override + public int getColumnDimension() { + return colDim; + } + + @Override + public int getRowDimension() { + return rowDim; + } + + @Override + public int getUniqueParameterCount() { + return 1; + } + + @Override + public Parameter getUniqueParameter(int index) { + return parameter; + } + + @Override + public void copyParameterValues(double[] destination, int offset) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public void setAllParameterValuesQuietly(double[] values, int offset) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public String toSymmetricString() { + throw new RuntimeException("not yet implemented"); + } + + @Override + public boolean isConstrainedSymmetric() { + return false; + } + +} + + From dca66f8de0f6fb5ec0efa482f2d6d3d3cd9d2108 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 07:49:21 -0700 Subject: [PATCH 016/196] parser for MatrixTransformedParameter --- .../model/TransformedParameterParser.java | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/dr/inferencexml/model/TransformedParameterParser.java b/src/dr/inferencexml/model/TransformedParameterParser.java index 0fa924e948..36f4757c09 100644 --- a/src/dr/inferencexml/model/TransformedParameterParser.java +++ b/src/dr/inferencexml/model/TransformedParameterParser.java @@ -25,6 +25,8 @@ package dr.inferencexml.model; +import dr.inference.model.MatrixParameterInterface; +import dr.inference.model.MatrixTransformedParameter; import dr.inference.model.Parameter; import dr.inference.model.TransformedParameter; import dr.util.Transform; @@ -37,6 +39,7 @@ public class TransformedParameterParser extends AbstractXMLObjectParser { public static final String TRANSFORMED_PARAMETER = "transformedParameter"; public static final String INVERSE = "inverse"; + public static final String AS_MATRIX = "asMatrix"; public Object parseXMLObject(XMLObject xo) throws XMLParseException { @@ -44,6 +47,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { final Transform.ParsedTransform parsedTransform = (Transform.ParsedTransform) xo.getChild(Transform.ParsedTransform.class); final boolean inverse = xo.getAttribute(INVERSE, false); + final boolean asMatrix = xo.getAttribute(AS_MATRIX, false); + if (asMatrix) { + if (parameter instanceof MatrixParameterInterface) { + return new MatrixTransformedParameter((MatrixParameterInterface) parameter, parsedTransform.transform, inverse); + } else { + throw new XMLParseException("'asMatrix' is 'true' but the supplied parameter is not a matrix. " + + "Not currently implemented."); + } + } + TransformedParameter transformedParameter = new TransformedParameter(parameter, parsedTransform.transform, inverse); return transformedParameter; } @@ -56,7 +69,7 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(Parameter.class), new ElementRule(Transform.ParsedTransform.class), AttributeRule.newBooleanRule(INVERSE, true), - + AttributeRule.newBooleanRule(AS_MATRIX, true) }; public String getParserDescription() { From fbb4e8e0cdca19e16e243f8888277687569b8bcb Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 08:47:54 -0700 Subject: [PATCH 017/196] Revert "parser for MatrixTransformedParameter" This reverts commit dca66f8de0f6fb5ec0efa482f2d6d3d3cd9d2108. --- .../model/TransformedParameterParser.java | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/dr/inferencexml/model/TransformedParameterParser.java b/src/dr/inferencexml/model/TransformedParameterParser.java index 36f4757c09..0fa924e948 100644 --- a/src/dr/inferencexml/model/TransformedParameterParser.java +++ b/src/dr/inferencexml/model/TransformedParameterParser.java @@ -25,8 +25,6 @@ package dr.inferencexml.model; -import dr.inference.model.MatrixParameterInterface; -import dr.inference.model.MatrixTransformedParameter; import dr.inference.model.Parameter; import dr.inference.model.TransformedParameter; import dr.util.Transform; @@ -39,7 +37,6 @@ public class TransformedParameterParser extends AbstractXMLObjectParser { public static final String TRANSFORMED_PARAMETER = "transformedParameter"; public static final String INVERSE = "inverse"; - public static final String AS_MATRIX = "asMatrix"; public Object parseXMLObject(XMLObject xo) throws XMLParseException { @@ -47,16 +44,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { final Transform.ParsedTransform parsedTransform = (Transform.ParsedTransform) xo.getChild(Transform.ParsedTransform.class); final boolean inverse = xo.getAttribute(INVERSE, false); - final boolean asMatrix = xo.getAttribute(AS_MATRIX, false); - if (asMatrix) { - if (parameter instanceof MatrixParameterInterface) { - return new MatrixTransformedParameter((MatrixParameterInterface) parameter, parsedTransform.transform, inverse); - } else { - throw new XMLParseException("'asMatrix' is 'true' but the supplied parameter is not a matrix. " + - "Not currently implemented."); - } - } - TransformedParameter transformedParameter = new TransformedParameter(parameter, parsedTransform.transform, inverse); return transformedParameter; } @@ -69,7 +56,7 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(Parameter.class), new ElementRule(Transform.ParsedTransform.class), AttributeRule.newBooleanRule(INVERSE, true), - AttributeRule.newBooleanRule(AS_MATRIX, true) + }; public String getParserDescription() { From 8cbe18407576d3b6a90f86662179da5462c6ad28 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 08:48:44 -0700 Subject: [PATCH 018/196] MatrixTransformedParameter now extends TransformedMultivariateParameter --- .../model/MatrixTransformedParameter.java | 8 ++++--- ...ransformedMultivariateParameterParser.java | 21 ++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/dr/inference/model/MatrixTransformedParameter.java b/src/dr/inference/model/MatrixTransformedParameter.java index 07453b535f..58d323c94f 100644 --- a/src/dr/inference/model/MatrixTransformedParameter.java +++ b/src/dr/inference/model/MatrixTransformedParameter.java @@ -3,19 +3,21 @@ import dr.util.Transform; -public class MatrixTransformedParameter extends TransformedParameter implements MatrixParameterInterface { +public class MatrixTransformedParameter extends TransformedMultivariateParameter implements MatrixParameterInterface { private final int rowDim; private final int colDim; - public MatrixTransformedParameter(Parameter parameter, Transform transform, Boolean inverse, int rowDim, int colDim) { + public MatrixTransformedParameter(Parameter parameter, Transform.MultivariableTransform transform, + Boolean inverse, int rowDim, int colDim) { super(parameter, transform, inverse); this.rowDim = rowDim; this.colDim = colDim; } - public MatrixTransformedParameter(MatrixParameterInterface parameter, Transform transform, Boolean inverse) { + public MatrixTransformedParameter(MatrixParameterInterface parameter, Transform.MultivariableTransform transform, + Boolean inverse) { this(parameter, transform, inverse, parameter.getRowDimension(), parameter.getColumnDimension()); } diff --git a/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java b/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java index 2b50339c77..8c5dbee146 100644 --- a/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java +++ b/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java @@ -25,9 +25,7 @@ package dr.inferencexml.model; -import dr.inference.model.Bounds; -import dr.inference.model.Parameter; -import dr.inference.model.TransformedMultivariateParameter; +import dr.inference.model.*; import dr.util.Transform; import dr.xml.*; @@ -36,6 +34,7 @@ public class TransformedMultivariateParameterParser extends AbstractXMLObjectPar private static final String TRANSFORMED_MULTIVARIATE_PARAMETER = "transformedMultivariateParameter"; public static final String INVERSE = "inverse"; private static final String BOUNDS = "bounds"; + private static final String AS_MATRIX = "asMatrix"; public Object parseXMLObject(XMLObject xo) throws XMLParseException { @@ -44,8 +43,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { xo.getChild(Transform.MultivariableTransform.class); final boolean inverse = xo.getAttribute(INVERSE, false); - TransformedMultivariateParameter transformedParameter - = new TransformedMultivariateParameter(parameter, transform, inverse); + final TransformedMultivariateParameter transformedParameter; + final boolean asMatrix = xo.getAttribute(AS_MATRIX, false); + if (asMatrix) { + if (parameter instanceof MatrixParameterInterface) { + transformedParameter = new MatrixTransformedParameter((MatrixParameterInterface) parameter, transform, inverse); + } else { + throw new XMLParseException("'asMatrix' is 'true' but the supplied parameter is not a matrix. " + + "Not currently implemented."); + } + } else { + transformedParameter = new TransformedMultivariateParameter(parameter, transform, inverse); + } + if (xo.hasChildNamed(BOUNDS)) { Bounds bounds = ((Parameter) xo.getElementFirstChild(BOUNDS)).getBounds(); transformedParameter.addBounds(bounds); @@ -64,6 +74,7 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(Parameter.class), new ElementRule(Transform.MultivariableTransform.class), AttributeRule.newBooleanRule(INVERSE, true), + AttributeRule.newBooleanRule(AS_MATRIX, true), }; From 514205a3f37b9e0a8d8f24492a08f994cb056dd7 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Jul 2021 08:49:50 -0700 Subject: [PATCH 019/196] refactoring MatrixTransformedParameter -> TransformedMatrixParameter --- ...formedParameter.java => TransformedMatrixParameter.java} | 6 +++--- .../model/TransformedMultivariateParameterParser.java | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/dr/inference/model/{MatrixTransformedParameter.java => TransformedMatrixParameter.java} (92%) diff --git a/src/dr/inference/model/MatrixTransformedParameter.java b/src/dr/inference/model/TransformedMatrixParameter.java similarity index 92% rename from src/dr/inference/model/MatrixTransformedParameter.java rename to src/dr/inference/model/TransformedMatrixParameter.java index 58d323c94f..b8cc9daf58 100644 --- a/src/dr/inference/model/MatrixTransformedParameter.java +++ b/src/dr/inference/model/TransformedMatrixParameter.java @@ -3,12 +3,12 @@ import dr.util.Transform; -public class MatrixTransformedParameter extends TransformedMultivariateParameter implements MatrixParameterInterface { +public class TransformedMatrixParameter extends TransformedMultivariateParameter implements MatrixParameterInterface { private final int rowDim; private final int colDim; - public MatrixTransformedParameter(Parameter parameter, Transform.MultivariableTransform transform, + public TransformedMatrixParameter(Parameter parameter, Transform.MultivariableTransform transform, Boolean inverse, int rowDim, int colDim) { super(parameter, transform, inverse); this.rowDim = rowDim; @@ -16,7 +16,7 @@ public MatrixTransformedParameter(Parameter parameter, Transform.MultivariableTr } - public MatrixTransformedParameter(MatrixParameterInterface parameter, Transform.MultivariableTransform transform, + public TransformedMatrixParameter(MatrixParameterInterface parameter, Transform.MultivariableTransform transform, Boolean inverse) { this(parameter, transform, inverse, parameter.getRowDimension(), parameter.getColumnDimension()); } diff --git a/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java b/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java index 8c5dbee146..556f12b4fa 100644 --- a/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java +++ b/src/dr/inferencexml/model/TransformedMultivariateParameterParser.java @@ -47,7 +47,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { final boolean asMatrix = xo.getAttribute(AS_MATRIX, false); if (asMatrix) { if (parameter instanceof MatrixParameterInterface) { - transformedParameter = new MatrixTransformedParameter((MatrixParameterInterface) parameter, transform, inverse); + transformedParameter = new TransformedMatrixParameter((MatrixParameterInterface) parameter, transform, inverse); } else { throw new XMLParseException("'asMatrix' is 'true' but the supplied parameter is not a matrix. " + "Not currently implemented."); From 78e0545a318199a73f79fdee0adad44ba2d99fa9 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 18 Aug 2021 14:37:36 -0700 Subject: [PATCH 020/196] updating geodesicHMC --- ...GeodesicHamiltonianMonteCarloOperator.java | 184 +++++++++++------- ...icHamiltonianMonteCarloOperatorParser.java | 21 +- 2 files changed, 134 insertions(+), 71 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 5c12fd1604..bd3e3311a8 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -21,7 +21,10 @@ public class GeodesicHamiltonianMonteCarloOperator extends HamiltonianMonteCarloOperator implements Reportable { - public GeodesicHamiltonianMonteCarloOperator(AdaptationMode mode, double weight, GradientWrtParameterProvider gradientProvider, Parameter parameter, Transform transform, Parameter maskParameter, Options runtimeOptions, MassPreconditioner preconditioner) { + public GeodesicHamiltonianMonteCarloOperator(AdaptationMode mode, double weight, + GradientWrtParameterProvider gradientProvider, Parameter parameter, + Transform transform, Parameter maskParameter, Options runtimeOptions, + MassPreconditioner preconditioner) { super(mode, weight, gradientProvider, parameter, transform, maskParameter, runtimeOptions, preconditioner); this.leapFrogEngine = new GeodesicLeapFrogEngine(parameter, getDefaultInstabilityHandler(), preconditioning, mask); } @@ -79,6 +82,10 @@ public String getReport() { return sb.toString(); } + public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure) { + ((GeodesicLeapFrogEngine) leapFrogEngine).setOrthogonalityStructure(oldOrthogonalityStructure); + } + public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default { @@ -93,6 +100,7 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator private final int[] subRows; private final int[] subColumns; + private final ArrayList orthogonalityStructure; GeodesicLeapFrogEngine(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, @@ -100,11 +108,13 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator super(parameter, instabilityHandler, preconditioning, mask); this.matrixParameter = (MatrixParameterInterface) parameter; - this.subRows = parseSubRowsFromMask(); this.subColumns = parseSubColumnsFromMask(); if (mask != null) checkMask(subRows, subColumns); + this.orthogonalityStructure = new ArrayList<>(); + orthogonalityStructure.add(subRows); + this.nRows = subRows.length; this.nCols = subColumns.length; this.positionMatrix = new DenseMatrix64F(nCols, nRows); @@ -114,6 +124,38 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator this.momentumMatrix = new DenseMatrix64F(nCols, nRows); } + public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure) { + orthogonalityStructure.clear(); + + ArrayList subRowList = new ArrayList<>(); + for (int i : subRows) { + subRowList.add(i); + } + + //check that orthogonalityStructure is consistent with the subRows + ArrayList alreadyOrthogonal = new ArrayList<>(); + + for (int i = 0; i < oldOrthogonalityStructure.size(); i++) { + for (int j = 0; j < oldOrthogonalityStructure.get(i).length; j++) { + if (!subRowList.contains(oldOrthogonalityStructure.get(i)[j])) { //TODO: check that we're doing this by row (or allow to do by row or column) + throw new RuntimeException("Cannot enforce orthogonality structure."); + } + if (alreadyOrthogonal.contains(oldOrthogonalityStructure.get(i)[j])) { + throw new RuntimeException("Orthogonal blocks must be non-overlapping"); + } + alreadyOrthogonal.add(oldOrthogonalityStructure.get(i)[j]); + orthogonalityStructure.add(oldOrthogonalityStructure.get(i)); + } + } + + for (int i = 0; i < subRows.length; i++) { + if (!alreadyOrthogonal.contains(subRows[i])) { + orthogonalityStructure.add(new int[]{subRows[i]}); + } + } + + } + private int[] parseSubColumnsFromMask() { int originalRows = matrixParameter.getRowDimension(); @@ -212,18 +254,19 @@ private void checkMask(int[] rows, int[] cols) { } } - private void setSubMatrix(double[] src, int srcOffset, DenseMatrix64F dest) { + private void setOrthogonalSubMatrix(double[] src, int srcOffset, int block, DenseMatrix64F dest) { int nRowsOriginal = matrixParameter.getRowDimension(); - for (int row = 0; row < subRows.length; row++) { + int[] blockRows = orthogonalityStructure.get(block); + for (int row = 0; row < blockRows.length; row++) { for (int col = 0; col < subColumns.length; col++) { - int ind = nRowsOriginal * subColumns[col] + subRows[row] + srcOffset; + int ind = nRowsOriginal * subColumns[col] + blockRows[row] + srcOffset; dest.set(col, row, src[ind]); } } } - private void setSubMatrix(double[] src, DenseMatrix64F dest) { - setSubMatrix(src, 0, dest); + private void setOrthogonalSubMatrix(double[] src, int block, DenseMatrix64F dest) { + setOrthogonalSubMatrix(src, 0, block, dest); } private void unwrapSubMatrix(DenseMatrix64F src, double[] dest, int destOffset) { @@ -252,81 +295,84 @@ public void updateMomentum(double[] position, double[] momentum, double[] gradie public void updatePosition(double[] position, WrappedVector momentum, double functionalStepSize) throws HamiltonianMonteCarloOperator.NumericInstabilityException { + for (int block = 0; block < orthogonalityStructure.size(); block++) { + // positionMatrix.setData(position); - setSubMatrix(position, positionMatrix); - setSubMatrix(momentum.getBuffer(), momentum.getOffset(), momentumMatrix); + setOrthogonalSubMatrix(position, block, positionMatrix); + setOrthogonalSubMatrix(momentum.getBuffer(), momentum.getOffset(), block, momentumMatrix); // System.arraycopy(momentum.getBuffer(), momentum.getOffset(), momentumMatrix.data, 0, momentum.getDim()); - CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); - CommonOps.multTransB(momentumMatrix, momentumMatrix, innerProduct2); + CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); + CommonOps.multTransB(momentumMatrix, momentumMatrix, innerProduct2); - double[][] VtV = new double[2 * nCols][2 * nCols]; + double[][] VtV = new double[2 * nCols][2 * nCols]; - for (int i = 0; i < nCols; i++) { - VtV[i + nCols][i] = 1; - for (int j = 0; j < nCols; j++) { - VtV[i][j] = innerProduct.get(i, j); - VtV[i + nCols][j + nCols] = innerProduct.get(i, j); - VtV[i][j + nCols] = -innerProduct2.get(j, i); + for (int i = 0; i < nCols; i++) { + VtV[i + nCols][i] = 1; + for (int j = 0; j < nCols; j++) { + VtV[i][j] = innerProduct.get(i, j); + VtV[i + nCols][j + nCols] = innerProduct.get(i, j); + VtV[i][j + nCols] = -innerProduct2.get(j, i); + } } - } - double[] expBuffer = new double[nCols * nCols]; - CommonOps.scale(-functionalStepSize, innerProduct); - SkewSymmetricMatrixExponential matExp1 = new SkewSymmetricMatrixExponential(nCols); - matExp1.exponentiate(innerProduct.data, expBuffer); + double[] expBuffer = new double[nCols * nCols]; + CommonOps.scale(-functionalStepSize, innerProduct); + SkewSymmetricMatrixExponential matExp1 = new SkewSymmetricMatrixExponential(nCols); + matExp1.exponentiate(innerProduct.data, expBuffer); - double[] expBuffer2 = new double[nCols * nCols * 4]; - SkewSymmetricMatrixExponential matExp2 = new SkewSymmetricMatrixExponential(nCols * 2); //TODO: better matrix exponential - DenseMatrix64F VtVmat = new DenseMatrix64F(VtV); - CommonOps.scale(functionalStepSize, VtVmat); - matExp2.exponentiate(VtVmat.data, expBuffer2); + double[] expBuffer2 = new double[nCols * nCols * 4]; + SkewSymmetricMatrixExponential matExp2 = new SkewSymmetricMatrixExponential(nCols * 2); //TODO: better matrix exponential + DenseMatrix64F VtVmat = new DenseMatrix64F(VtV); + CommonOps.scale(functionalStepSize, VtVmat); + matExp2.exponentiate(VtVmat.data, expBuffer2); - DenseMatrix64F X = new DenseMatrix64F(nCols * 2, nCols * 2); - DenseMatrix64F Y = new DenseMatrix64F(nCols * 2, nCols * 2); + DenseMatrix64F X = new DenseMatrix64F(nCols * 2, nCols * 2); + DenseMatrix64F Y = new DenseMatrix64F(nCols * 2, nCols * 2); - for (int i = 0; i < nCols; i++) { - for (int j = 0; j < nCols; j++) { - X.set(i, j, expBuffer[i * nCols + j]); - X.set(i + nCols, j + nCols, expBuffer[i * nCols + j]); + for (int i = 0; i < nCols; i++) { + for (int j = 0; j < nCols; j++) { + X.set(i, j, expBuffer[i * nCols + j]); + X.set(i + nCols, j + nCols, expBuffer[i * nCols + j]); + } } - } - Y.setData(expBuffer2); + Y.setData(expBuffer2); - DenseMatrix64F Z = new DenseMatrix64F(nCols * 2, nCols * 2); + DenseMatrix64F Z = new DenseMatrix64F(nCols * 2, nCols * 2); - CommonOps.mult(Y, X, Z); + CommonOps.mult(Y, X, Z); - DenseMatrix64F PM = new DenseMatrix64F(nCols * 2, nRows); - for (int i = 0; i < nRows; i++) { - for (int j = 0; j < nCols; j++) { - PM.set(j, i, positionMatrix.get(j, i)); - PM.set(j + nCols, i, momentumMatrix.get(j, i)); + DenseMatrix64F PM = new DenseMatrix64F(nCols * 2, nRows); + for (int i = 0; i < nRows; i++) { + for (int j = 0; j < nCols; j++) { + PM.set(j, i, positionMatrix.get(j, i)); + PM.set(j + nCols, i, momentumMatrix.get(j, i)); + } } - } - DenseMatrix64F W = new DenseMatrix64F(2 * nCols, nRows); - CommonOps.transpose(Z); - CommonOps.mult(Z, PM, W); + DenseMatrix64F W = new DenseMatrix64F(2 * nCols, nRows); + CommonOps.transpose(Z); + CommonOps.mult(Z, PM, W); - for (int i = 0; i < nRows; i++) { - for (int j = 0; j < nCols; j++) { - positionMatrix.set(j, i, W.get(j, i)); - momentumMatrix.set(j, i, W.get(j + nCols, i)); + for (int i = 0; i < nRows; i++) { + for (int j = 0; j < nCols; j++) { + positionMatrix.set(j, i, W.get(j, i)); + momentumMatrix.set(j, i, W.get(j + nCols, i)); + } } - } - //TODO: only run chunk below occasionally - CommonOps.multTransB(positionMatrix, positionMatrix, innerProduct); - CholeskyDecomposition cholesky = DecompositionFactory.chol(nCols, true); - cholesky.decompose(innerProduct); - TriangularSolver.invertLower(innerProduct.data, nCols); - CommonOps.mult(innerProduct, positionMatrix, projection); - System.arraycopy(projection.data, 0, positionMatrix.data, 0, positionMatrix.data.length); + //TODO: only run chunk below occasionally + CommonOps.multTransB(positionMatrix, positionMatrix, innerProduct); + CholeskyDecomposition cholesky = DecompositionFactory.chol(nCols, true); + cholesky.decompose(innerProduct); + TriangularSolver.invertLower(innerProduct.data, nCols); + CommonOps.mult(innerProduct, positionMatrix, projection); + System.arraycopy(projection.data, 0, positionMatrix.data, 0, positionMatrix.data.length); - unwrapSubMatrix(positionMatrix, position); - unwrapSubMatrix(momentumMatrix, momentum.getBuffer(), momentum.getOffset()); + unwrapSubMatrix(positionMatrix, position); + unwrapSubMatrix(momentumMatrix, momentum.getBuffer(), momentum.getOffset()); // System.arraycopy(positionMatrix.data, 0, position, 0, position.length); // System.arraycopy(momentumMatrix.data, 0, momentum.getBuffer(), momentum.getOffset(), momentum.getDim()); + } matrixParameter.setAllParameterValuesQuietly(position, 0); matrixParameter.fireParameterChangedEvent(); @@ -334,18 +380,20 @@ public void updatePosition(double[] position, WrappedVector momentum, @Override public void projectMomentum(double[] momentum, double[] position) { - setSubMatrix(position, positionMatrix); - setSubMatrix(momentum, momentumMatrix); + for (int block = 0; block < orthogonalityStructure.size(); block++) { + setOrthogonalSubMatrix(position, block, positionMatrix); + setOrthogonalSubMatrix(momentum, block, momentumMatrix); // positionMatrix.setData(position); // momentumMatrix.setData(momentum); - CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); - EJMLUtils.addWithTransposed(innerProduct); + CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); + EJMLUtils.addWithTransposed(innerProduct); - CommonOps.mult(0.5, innerProduct, positionMatrix, projection); - CommonOps.subtractEquals(momentumMatrix, projection); + CommonOps.mult(0.5, innerProduct, positionMatrix, projection); + CommonOps.subtractEquals(momentumMatrix, projection); - unwrapSubMatrix(momentumMatrix, momentum); + unwrapSubMatrix(momentumMatrix, momentum); + } } } } diff --git a/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java b/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java index bdb824bfb9..0f96689b08 100644 --- a/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java +++ b/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java @@ -1,7 +1,6 @@ package dr.inferencexml.operators.hmc; import dr.inference.hmc.GradientWrtParameterProvider; -import dr.inference.hmc.ReversibleHMCProvider; import dr.inference.model.Parameter; import dr.inference.operators.AdaptationMode; import dr.inference.operators.hmc.GeodesicHamiltonianMonteCarloOperator; @@ -13,6 +12,8 @@ import dr.xml.XMLParseException; import dr.xml.XMLSyntaxRule; +import java.util.ArrayList; + /** * @author Gabriel Hassler * @author Marc A. Suchard @@ -20,10 +21,24 @@ public class GeodesicHamiltonianMonteCarloOperatorParser extends HamiltonianMonteCarloOperatorParser { public final static String OPERATOR_NAME = "geodesicHamiltonianMonteCarloOperator"; + public final static String ORTHOGONALITY_STRUCTURE = "orthogonalityStructure"; + public final static String ROWS = "rows"; @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { - return super.parseXMLObject(xo); + GeodesicHamiltonianMonteCarloOperator hmc = (GeodesicHamiltonianMonteCarloOperator) super.parseXMLObject(xo); + if (xo.hasChildNamed(ORTHOGONALITY_STRUCTURE)) { + XMLObject cxo = xo.getChild(ORTHOGONALITY_STRUCTURE); + ArrayList orthogonalityStructure = new ArrayList<>(); + for (int i = 0; i < cxo.getChildCount(); i++) { + XMLObject group = (XMLObject) xo.getChild(i); + orthogonalityStructure.add(xo.getIntegerArrayAttribute(ROWS)); + } + + hmc.setOrthogonalityStructure(orthogonalityStructure); + } + + return hmc; } @Override @@ -39,7 +54,7 @@ protected HamiltonianMonteCarloOperator factory(AdaptationMode adaptationMode, d @Override public XMLSyntaxRule[] getSyntaxRules() { return rules; - } + } //TODO: add orthogonality structure rules @Override From d34fdb33d7c715e343ed529d44213b1d6aef6b26 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 18 Aug 2021 14:50:58 -0700 Subject: [PATCH 021/196] bug fixes to geodesic hmc --- .../hmc/GeodesicHamiltonianMonteCarloOperator.java | 2 +- .../hmc/GeodesicHamiltonianMonteCarloOperatorParser.java | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index bd3e3311a8..d2c95d9039 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -144,8 +144,8 @@ public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure throw new RuntimeException("Orthogonal blocks must be non-overlapping"); } alreadyOrthogonal.add(oldOrthogonalityStructure.get(i)[j]); - orthogonalityStructure.add(oldOrthogonalityStructure.get(i)); } + orthogonalityStructure.add(oldOrthogonalityStructure.get(i)); } for (int i = 0; i < subRows.length; i++) { diff --git a/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java b/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java index 0f96689b08..f92522100a 100644 --- a/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java +++ b/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java @@ -31,8 +31,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(ORTHOGONALITY_STRUCTURE); ArrayList orthogonalityStructure = new ArrayList<>(); for (int i = 0; i < cxo.getChildCount(); i++) { - XMLObject group = (XMLObject) xo.getChild(i); - orthogonalityStructure.add(xo.getIntegerArrayAttribute(ROWS)); + XMLObject group = (XMLObject) cxo.getChild(i); + int[] rows = group.getIntegerArrayAttribute(ROWS); + for (int j = 0; j < rows.length; j++) { + rows[j]--; + } + orthogonalityStructure.add(rows); } hmc.setOrthogonalityStructure(orthogonalityStructure); From 06bb0287e3f3ec257ad3198b0a0629bfba68f549 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 18 Aug 2021 15:47:00 -0700 Subject: [PATCH 022/196] more geodesicHMC bug fixes --- ...GeodesicHamiltonianMonteCarloOperator.java | 72 ++++++++++++------- 1 file changed, 45 insertions(+), 27 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index d2c95d9039..32e38c7d98 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -90,12 +90,12 @@ public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default { private final MatrixParameterInterface matrixParameter; - private final DenseMatrix64F positionMatrix; - private final DenseMatrix64F innerProduct; - private final DenseMatrix64F innerProduct2; - private final DenseMatrix64F projection; - private final DenseMatrix64F momentumMatrix; - private final int nRows; + // private final DenseMatrix64F positionMatrix; +// private final DenseMatrix64F innerProduct; +// private final DenseMatrix64F innerProduct2; +// private final DenseMatrix64F projection; +// private final DenseMatrix64F momentumMatrix; +// private final int nRows; private final int nCols; private final int[] subRows; @@ -115,13 +115,13 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator this.orthogonalityStructure = new ArrayList<>(); orthogonalityStructure.add(subRows); - this.nRows = subRows.length; +// this.nRows = subRows.length; this.nCols = subColumns.length; - this.positionMatrix = new DenseMatrix64F(nCols, nRows); - this.innerProduct = new DenseMatrix64F(nCols, nCols); - this.innerProduct2 = new DenseMatrix64F(nCols, nCols); - this.projection = new DenseMatrix64F(nCols, nRows); - this.momentumMatrix = new DenseMatrix64F(nCols, nRows); +// this.positionMatrix = new DenseMatrix64F(nCols, nRows); +// this.innerProduct = new DenseMatrix64F(nCols, nCols); +// this.innerProduct2 = new DenseMatrix64F(nCols, nCols); +// this.projection = new DenseMatrix64F(nCols, nRows); +// this.momentumMatrix = new DenseMatrix64F(nCols, nRows); } public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure) { @@ -254,7 +254,9 @@ private void checkMask(int[] rows, int[] cols) { } } - private void setOrthogonalSubMatrix(double[] src, int srcOffset, int block, DenseMatrix64F dest) { + private DenseMatrix64F setOrthogonalSubMatrix(double[] src, int srcOffset, int block) { + DenseMatrix64F dest = new DenseMatrix64F(nCols, orthogonalityStructure.get(block).length); + int nRowsOriginal = matrixParameter.getRowDimension(); int[] blockRows = orthogonalityStructure.get(block); for (int row = 0; row < blockRows.length; row++) { @@ -263,24 +265,27 @@ private void setOrthogonalSubMatrix(double[] src, int srcOffset, int block, Dens dest.set(col, row, src[ind]); } } + + return dest; } - private void setOrthogonalSubMatrix(double[] src, int block, DenseMatrix64F dest) { - setOrthogonalSubMatrix(src, 0, block, dest); + private DenseMatrix64F setOrthogonalSubMatrix(double[] src, int block) { + return setOrthogonalSubMatrix(src, 0, block); } - private void unwrapSubMatrix(DenseMatrix64F src, double[] dest, int destOffset) { + private void unwrapSubMatrix(DenseMatrix64F src, int block, double[] dest, int destOffset) { int nRowsOriginal = matrixParameter.getRowDimension(); - for (int row = 0; row < nRows; row++) { + int[] blockRows = orthogonalityStructure.get(block); + for (int row = 0; row < blockRows.length; row++) { for (int col = 0; col < nCols; col++) { - int ind = nRowsOriginal * subColumns[col] + subRows[row] + destOffset; + int ind = nRowsOriginal * subColumns[col] + blockRows[row] + destOffset; dest[ind] = src.get(col, row); } } } - private void unwrapSubMatrix(DenseMatrix64F src, double[] dest) { - unwrapSubMatrix(src, dest, 0); + private void unwrapSubMatrix(DenseMatrix64F src, int block, double[] dest) { + unwrapSubMatrix(src, block, dest, 0); } @Override @@ -297,10 +302,16 @@ public void updatePosition(double[] position, WrappedVector momentum, for (int block = 0; block < orthogonalityStructure.size(); block++) { + int nRows = orthogonalityStructure.get(block).length; + // positionMatrix.setData(position); - setOrthogonalSubMatrix(position, block, positionMatrix); - setOrthogonalSubMatrix(momentum.getBuffer(), momentum.getOffset(), block, momentumMatrix); + DenseMatrix64F positionMatrix = setOrthogonalSubMatrix(position, block); + DenseMatrix64F momentumMatrix = setOrthogonalSubMatrix(momentum.getBuffer(), momentum.getOffset(), block); // System.arraycopy(momentum.getBuffer(), momentum.getOffset(), momentumMatrix.data, 0, momentum.getDim()); + + DenseMatrix64F innerProduct = new DenseMatrix64F(nCols, nCols); + DenseMatrix64F innerProduct2 = new DenseMatrix64F(nCols, nCols); + CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); CommonOps.multTransB(momentumMatrix, momentumMatrix, innerProduct2); @@ -365,11 +376,14 @@ public void updatePosition(double[] position, WrappedVector momentum, CholeskyDecomposition cholesky = DecompositionFactory.chol(nCols, true); cholesky.decompose(innerProduct); TriangularSolver.invertLower(innerProduct.data, nCols); + + DenseMatrix64F projection = new DenseMatrix64F(nCols, nRows); + CommonOps.mult(innerProduct, positionMatrix, projection); System.arraycopy(projection.data, 0, positionMatrix.data, 0, positionMatrix.data.length); - unwrapSubMatrix(positionMatrix, position); - unwrapSubMatrix(momentumMatrix, momentum.getBuffer(), momentum.getOffset()); + unwrapSubMatrix(positionMatrix, block, position); + unwrapSubMatrix(momentumMatrix, block, momentum.getBuffer(), momentum.getOffset()); // System.arraycopy(positionMatrix.data, 0, position, 0, position.length); // System.arraycopy(momentumMatrix.data, 0, momentum.getBuffer(), momentum.getOffset(), momentum.getDim()); } @@ -381,18 +395,22 @@ public void updatePosition(double[] position, WrappedVector momentum, @Override public void projectMomentum(double[] momentum, double[] position) { for (int block = 0; block < orthogonalityStructure.size(); block++) { - setOrthogonalSubMatrix(position, block, positionMatrix); - setOrthogonalSubMatrix(momentum, block, momentumMatrix); + DenseMatrix64F positionMatrix = setOrthogonalSubMatrix(position, block); + DenseMatrix64F momentumMatrix = setOrthogonalSubMatrix(momentum, block); // positionMatrix.setData(position); // momentumMatrix.setData(momentum); + DenseMatrix64F innerProduct = new DenseMatrix64F(nCols, nCols); + CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); EJMLUtils.addWithTransposed(innerProduct); + DenseMatrix64F projection = new DenseMatrix64F(nCols, orthogonalityStructure.get(block).length); + CommonOps.mult(0.5, innerProduct, positionMatrix, projection); CommonOps.subtractEquals(momentumMatrix, projection); - unwrapSubMatrix(momentumMatrix, momentum); + unwrapSubMatrix(momentumMatrix, block, momentum); } } } From f68956d21a2e9224d95962ce5ef12bb9d7b526d7 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 19 Aug 2021 10:46:42 -0700 Subject: [PATCH 023/196] another bug fix in geoHMC --- .../hmc/GeodesicHamiltonianMonteCarloOperator.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 32e38c7d98..1d5d1e8bc6 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -372,14 +372,15 @@ public void updatePosition(double[] position, WrappedVector momentum, } //TODO: only run chunk below occasionally - CommonOps.multTransB(positionMatrix, positionMatrix, innerProduct); + innerProduct = new DenseMatrix64F(nRows, nRows); + CommonOps.multTransA(positionMatrix, positionMatrix, innerProduct); CholeskyDecomposition cholesky = DecompositionFactory.chol(nCols, true); cholesky.decompose(innerProduct); - TriangularSolver.invertLower(innerProduct.data, nCols); + TriangularSolver.invertLower(innerProduct.data, nRows); DenseMatrix64F projection = new DenseMatrix64F(nCols, nRows); - CommonOps.mult(innerProduct, positionMatrix, projection); + CommonOps.mult(positionMatrix, innerProduct, projection); System.arraycopy(projection.data, 0, positionMatrix.data, 0, positionMatrix.data.length); unwrapSubMatrix(positionMatrix, block, position); From 8dd0d218a4481de4b5027846f45c3d35422fbb45 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 19 Aug 2021 15:30:02 -0700 Subject: [PATCH 024/196] Revert "another bug fix in geoHMC" This reverts commit f68956d21a2e9224d95962ce5ef12bb9d7b526d7. --- .../hmc/GeodesicHamiltonianMonteCarloOperator.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 1d5d1e8bc6..32e38c7d98 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -372,15 +372,14 @@ public void updatePosition(double[] position, WrappedVector momentum, } //TODO: only run chunk below occasionally - innerProduct = new DenseMatrix64F(nRows, nRows); - CommonOps.multTransA(positionMatrix, positionMatrix, innerProduct); + CommonOps.multTransB(positionMatrix, positionMatrix, innerProduct); CholeskyDecomposition cholesky = DecompositionFactory.chol(nCols, true); cholesky.decompose(innerProduct); - TriangularSolver.invertLower(innerProduct.data, nRows); + TriangularSolver.invertLower(innerProduct.data, nCols); DenseMatrix64F projection = new DenseMatrix64F(nCols, nRows); - CommonOps.mult(positionMatrix, innerProduct, projection); + CommonOps.mult(innerProduct, positionMatrix, projection); System.arraycopy(projection.data, 0, positionMatrix.data, 0, positionMatrix.data.length); unwrapSubMatrix(positionMatrix, block, position); From 04f6202b3d914bbe86c9385617d0c0a0a1f015b9 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 19 Aug 2021 15:53:17 -0700 Subject: [PATCH 025/196] mixed up 'rows' and 'columns' --- ...GeodesicHamiltonianMonteCarloOperator.java | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 32e38c7d98..09754e82b8 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -95,8 +95,8 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator // private final DenseMatrix64F innerProduct2; // private final DenseMatrix64F projection; // private final DenseMatrix64F momentumMatrix; -// private final int nRows; - private final int nCols; + private final int nRows; +// private final int nCols; private final int[] subRows; private final int[] subColumns; @@ -113,10 +113,10 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator if (mask != null) checkMask(subRows, subColumns); this.orthogonalityStructure = new ArrayList<>(); - orthogonalityStructure.add(subRows); + orthogonalityStructure.add(subColumns); -// this.nRows = subRows.length; - this.nCols = subColumns.length; + this.nRows = subRows.length; +// this.nCols = subColumns.length; // this.positionMatrix = new DenseMatrix64F(nCols, nRows); // this.innerProduct = new DenseMatrix64F(nCols, nCols); // this.innerProduct2 = new DenseMatrix64F(nCols, nCols); @@ -127,9 +127,9 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure) { orthogonalityStructure.clear(); - ArrayList subRowList = new ArrayList<>(); - for (int i : subRows) { - subRowList.add(i); + ArrayList subColList = new ArrayList<>(); + for (int i : subColumns) { + subColList.add(i); } //check that orthogonalityStructure is consistent with the subRows @@ -137,7 +137,7 @@ public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure for (int i = 0; i < oldOrthogonalityStructure.size(); i++) { for (int j = 0; j < oldOrthogonalityStructure.get(i).length; j++) { - if (!subRowList.contains(oldOrthogonalityStructure.get(i)[j])) { //TODO: check that we're doing this by row (or allow to do by row or column) + if (!subColList.contains(oldOrthogonalityStructure.get(i)[j])) { //TODO: check that we're doing this by row (or allow to do by row or column) throw new RuntimeException("Cannot enforce orthogonality structure."); } if (alreadyOrthogonal.contains(oldOrthogonalityStructure.get(i)[j])) { @@ -255,13 +255,16 @@ private void checkMask(int[] rows, int[] cols) { } private DenseMatrix64F setOrthogonalSubMatrix(double[] src, int srcOffset, int block) { - DenseMatrix64F dest = new DenseMatrix64F(nCols, orthogonalityStructure.get(block).length); int nRowsOriginal = matrixParameter.getRowDimension(); - int[] blockRows = orthogonalityStructure.get(block); - for (int row = 0; row < blockRows.length; row++) { - for (int col = 0; col < subColumns.length; col++) { - int ind = nRowsOriginal * subColumns[col] + blockRows[row] + srcOffset; + int[] blockCols = orthogonalityStructure.get(block); + int nCols = blockCols.length; + + DenseMatrix64F dest = new DenseMatrix64F(nCols, nRows); + + for (int row = 0; row < nRows; row++) { + for (int col = 0; col < nCols; col++) { + int ind = nRowsOriginal * blockCols[col] + subRows[row] + srcOffset; dest.set(col, row, src[ind]); } } @@ -275,10 +278,10 @@ private DenseMatrix64F setOrthogonalSubMatrix(double[] src, int block) { private void unwrapSubMatrix(DenseMatrix64F src, int block, double[] dest, int destOffset) { int nRowsOriginal = matrixParameter.getRowDimension(); - int[] blockRows = orthogonalityStructure.get(block); - for (int row = 0; row < blockRows.length; row++) { - for (int col = 0; col < nCols; col++) { - int ind = nRowsOriginal * subColumns[col] + blockRows[row] + destOffset; + int[] blockCols = orthogonalityStructure.get(block); + for (int row = 0; row < nRows; row++) { + for (int col = 0; col < blockCols.length; col++) { + int ind = nRowsOriginal * blockCols[col] + subRows[row] + destOffset; dest[ind] = src.get(col, row); } } @@ -302,7 +305,7 @@ public void updatePosition(double[] position, WrappedVector momentum, for (int block = 0; block < orthogonalityStructure.size(); block++) { - int nRows = orthogonalityStructure.get(block).length; + int nCols = orthogonalityStructure.get(block).length; // positionMatrix.setData(position); DenseMatrix64F positionMatrix = setOrthogonalSubMatrix(position, block); @@ -397,6 +400,8 @@ public void projectMomentum(double[] momentum, double[] position) { for (int block = 0; block < orthogonalityStructure.size(); block++) { DenseMatrix64F positionMatrix = setOrthogonalSubMatrix(position, block); DenseMatrix64F momentumMatrix = setOrthogonalSubMatrix(momentum, block); + + int nCols = orthogonalityStructure.get(block).length; // positionMatrix.setData(position); // momentumMatrix.setData(momentum); @@ -405,7 +410,7 @@ public void projectMomentum(double[] momentum, double[] position) { CommonOps.multTransB(positionMatrix, momentumMatrix, innerProduct); EJMLUtils.addWithTransposed(innerProduct); - DenseMatrix64F projection = new DenseMatrix64F(nCols, orthogonalityStructure.get(block).length); + DenseMatrix64F projection = new DenseMatrix64F(nCols, nRows); CommonOps.mult(0.5, innerProduct, positionMatrix, projection); CommonOps.subtractEquals(momentumMatrix, projection); From 07c13ffff25e398dc30a1f53908e22b0e5ba80ed Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 19 Aug 2021 15:56:12 -0700 Subject: [PATCH 026/196] matrix inner product transform --- .../app/beast/development_parsers.properties | 2 + src/dr/util/MatrixInnerProductTransform.java | 129 ++++++++++++++++++ src/dr/util/Transform.java | 21 +++ 3 files changed, 152 insertions(+) create mode 100644 src/dr/util/MatrixInnerProductTransform.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 0599d69be2..47158d3291 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -307,3 +307,5 @@ dr.inference.model.FactorProportionStatistic # Shrinkage dr.inference.model.MaskFromTree +# Structural Equation Modeling +dr.util.MatrixInnerProductTransform \ No newline at end of file diff --git a/src/dr/util/MatrixInnerProductTransform.java b/src/dr/util/MatrixInnerProductTransform.java new file mode 100644 index 0000000000..5f2454b1f2 --- /dev/null +++ b/src/dr/util/MatrixInnerProductTransform.java @@ -0,0 +1,129 @@ +package dr.util; + +import dr.inference.model.MatrixParameterInterface; +import dr.xml.*; +import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; + +/** + * @author Gabriel Hassler + * @author Marc A. Suchard + */ + +public class MatrixInnerProductTransform extends Transform.MatrixVariateTransform { + + + private static final String MATRIX_INNER_PRODUCT = "matrixInnerProductTransform"; + + + public MatrixInnerProductTransform(int nRows, int nCols) { + super(nRows * nCols, nRows, nCols); + } + + @Override + public double[] inverse(double[] values, int from, int to, double sum) { + throw new RuntimeException(getTransformName() + " is not invertible"); + } + + @Override + public double[] gradient(double[] values, int from, int to) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public double[] gradientInverse(double[] values, int from, int to) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public String getTransformName() { + return MATRIX_INNER_PRODUCT; + } + + + @Override + protected double[] transform(double[] values) { + DenseMatrix64F X = DenseMatrix64F.wrap(rowDimension, columnDimension, values); + DenseMatrix64F XXt = new DenseMatrix64F(rowDimension, rowDimension); + CommonOps.multTransB(X, X, XXt); + + return XXt.getData(); + } + + @Override + protected double[] inverse(double[] values) { + throw new RuntimeException(getTransformName() + " is not invertible"); + } + + @Override + protected double getLogJacobian(double[] values) { + throw new RuntimeException(getTransformName() + " is not invertible"); + } + + @Override + protected double[] getGradientLogJacobianInverse(double[] values) { + throw new RuntimeException(getTransformName() + " is not invertible"); + } + + @Override + public double[][] computeJacobianMatrixInverse(double[] values) { + throw new RuntimeException(getTransformName() + " is not invertible"); + } + + @Override + protected boolean isInInteriorDomain(double[] values) { + return values.length == dim; + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + + private static final String N_ROWS = "nRows"; + private static final String N_COLS = "nColumns"; + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + final int rowDimension; + final int columnDimension; + + if (xo.getChildCount() > 0) { + MatrixParameterInterface parameter = (MatrixParameterInterface) + xo.getChild(MatrixParameterInterface.class); + + rowDimension = parameter.getRowDimension(); + columnDimension = parameter.getColumnDimension(); + } else { + rowDimension = xo.getIntegerAttribute(N_ROWS); + columnDimension = xo.getIntegerAttribute(N_COLS); + } + + return new MatrixInnerProductTransform(rowDimension, columnDimension); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ +// new XORRule( +// new ElementRule(MatrixParameterInterface.class), +// new AndRule( +// AttributeRule.newIntegerRule(N_ROWS), +// AttributeRule.newIntegerRule(N_COLS)) +// ) + }; + } + + @Override + public String getParserDescription() { + return "Takes the matrix X and transforms it to XXt"; + } + + @Override + public Class getReturnType() { + return MatrixInnerProductTransform.class; + } + + @Override + public String getParserName() { + return MATRIX_INNER_PRODUCT; + } + }; +} diff --git a/src/dr/util/Transform.java b/src/dr/util/Transform.java index 133e1a4539..4453dd4160 100644 --- a/src/dr/util/Transform.java +++ b/src/dr/util/Transform.java @@ -473,6 +473,27 @@ public final boolean isInInteriorDomain(double[] values, int from, int to) { } } + abstract class MatrixVariateTransform extends MultivariateTransform { + + protected final int rowDimension; + protected final int columnDimension; + + public MatrixVariateTransform(int inputDimension, int outputRowDimension, int outputColumnDimension) { + super(inputDimension, outputRowDimension * outputColumnDimension); + this.rowDimension = outputRowDimension; + this.columnDimension = outputColumnDimension; + } + + + public int getRowDimension() { + return rowDimension; + } + + public int getColumnDimension() { + return columnDimension; + } + } + class LogTransform extends UnivariableTransform { public double transform(double value) { From 1428201b3559e18a69c67b408bc6bf5fa2c1942f Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 19 Aug 2021 16:00:26 -0700 Subject: [PATCH 027/196] full correlation precision gradient --- .../hmc/AbstractPrecisionGradient.java | 10 ++- .../hmc/FullCorrelationPrecisionGradient.java | 72 +++++++++++++++++++ .../hmc/PrecisionGradientParser.java | 11 +++ .../model/CompoundSymmetricMatrix.java | 18 +++++ .../inference/model/TransformedParameter.java | 4 ++ 5 files changed, 112 insertions(+), 3 deletions(-) create mode 100644 src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java diff --git a/src/dr/evomodel/treedatalikelihood/hmc/AbstractPrecisionGradient.java b/src/dr/evomodel/treedatalikelihood/hmc/AbstractPrecisionGradient.java index c80b27bba0..be4726ba04 100644 --- a/src/dr/evomodel/treedatalikelihood/hmc/AbstractPrecisionGradient.java +++ b/src/dr/evomodel/treedatalikelihood/hmc/AbstractPrecisionGradient.java @@ -40,7 +40,7 @@ public abstract class AbstractPrecisionGradient extends AbstractDiffusionGradien private final GradientWrtPrecisionProvider gradientWrtPrecisionProvider; // final Likelihood likelihood; - final CompoundSymmetricMatrix compoundSymmetricMatrix; + protected final CompoundSymmetricMatrix compoundSymmetricMatrix; private final int dim; private Parametrization parametrization; @@ -146,14 +146,18 @@ public ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.Deriv return ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter.WRT_VARIANCE; } - int getDimensionCorrelation() { + protected int getDimensionCorrelation() { return dim * (dim - 1) / 2; } - int getDimensionDiagonal() { + protected int getDimensionDiagonal() { return dim; } + protected int getDimensionFull() { + return dim * dim; + } + @Override public double[] getGradientLogDensity() { double[] gradient = (gradientWrtPrecisionProvider.getBranchSpecificGradient() == null) ? null : gradientWrtPrecisionProvider.getBranchSpecificGradient().getGradientLogDensity(); // Get gradient wrt variance diff --git a/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java b/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java new file mode 100644 index 0000000000..4668fd1b4e --- /dev/null +++ b/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java @@ -0,0 +1,72 @@ +package dr.evomodel.treedatalikelihood.hmc; + +import dr.inference.model.*; +import dr.util.MatrixInnerProductTransform; +import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; + +/** + * @author Gabriel Hassler + * @author Marc A. Suchard + */ + +public class FullCorrelationPrecisionGradient extends CorrelationPrecisionGradient { + + private final MatrixParameterInterface decomposedMatrix; + private static final RuntimeException PARAMETER_EXCEPTION = new RuntimeException("off-diagonal parameter must be a mask of a inner product transform."); + + public FullCorrelationPrecisionGradient(GradientWrtPrecisionProvider gradientWrtPrecisionProvider, Likelihood likelihood, MatrixParameterInterface parameter) { + super(gradientWrtPrecisionProvider, likelihood, parameter); + + // TODO: this is super messy but get's the job done (for now). Maybe create a class just to wrap these different transforms/masks + Parameter correlationParameter = compoundSymmetricMatrix.getOffDiagonalParameter(); + if (correlationParameter instanceof MaskedParameter) { + MaskedParameter maskedCorrelation = (MaskedParameter) correlationParameter; + if (maskedCorrelation.getUnmaskedParameter() instanceof TransformedMultivariateParameter) { + TransformedMultivariateParameter transformedParameter = + (TransformedMultivariateParameter) maskedCorrelation.getUnmaskedParameter(); + if (transformedParameter.getTransform() instanceof MatrixInnerProductTransform) { //TODO: chain rul in transforms + decomposedMatrix = (MatrixParameterInterface) transformedParameter.getUntransformedParameter(); + } else { + throw PARAMETER_EXCEPTION; + } + } else { + throw PARAMETER_EXCEPTION; + } + } else { + throw PARAMETER_EXCEPTION; + } + + } + + + @Override + public int getDimension() { + return getDimensionFull(); + } + + @Override + public Parameter getParameter() { + return decomposedMatrix; + } + + @Override + double[] getGradientParameter(double[] gradient) { + int dim = decomposedMatrix.getRowDimension(); + + double[] correlationGradient = compoundSymmetricMatrix.updateGradientFullOffDiagonal(gradient); + double[] decomposedGradient = new double[gradient.length]; + + DenseMatrix64F corGradMat = DenseMatrix64F.wrap(dim, dim, correlationGradient); + DenseMatrix64F decompMat = DenseMatrix64F.wrap(dim, dim, decomposedMatrix.getParameterValues()); + DenseMatrix64F decompGradMat = DenseMatrix64F.wrap(dim, dim, decomposedGradient); + + CommonOps.mult(decompMat, corGradMat, decompGradMat); + + CommonOps.scale(2.0, decompGradMat); + + return decomposedGradient; + } + + +} diff --git a/src/dr/evomodelxml/continuous/hmc/PrecisionGradientParser.java b/src/dr/evomodelxml/continuous/hmc/PrecisionGradientParser.java index a54117c4df..c8ec2f95c1 100644 --- a/src/dr/evomodelxml/continuous/hmc/PrecisionGradientParser.java +++ b/src/dr/evomodelxml/continuous/hmc/PrecisionGradientParser.java @@ -56,6 +56,7 @@ public class PrecisionGradientParser extends AbstractXMLObjectParser { private final static String PRECISION_DIAGONAL = "diagonal"; private final static String PRECISION_DIAGONAL_OLD = "precisionDiagonal"; private final static String PRECISION_BOTH = "both"; + private final static String PRECISION_CORRELATION_DECOMPOSED = "decomposedCorrelation"; private static final String TRAIT_NAME = TreeTraitParserUtilities.TRAIT_NAME; @Override @@ -71,6 +72,8 @@ private ParameterMode parseParameterMode(XMLObject xo) throws XMLParseException mode = ParameterMode.WRT_CORRELATION; } else if (parameterString.compareTo(PRECISION_DIAGONAL) == 0 || parameterString.compareToIgnoreCase(PRECISION_DIAGONAL_OLD) == 0) { mode = ParameterMode.WRT_DIAGONAL; + } else if (parameterString.equalsIgnoreCase(PRECISION_CORRELATION_DECOMPOSED)) { + mode = ParameterMode.WRT_CORRELATION_DECOMPOSED; } return mode; } @@ -99,6 +102,14 @@ public AbstractPrecisionGradient factory(GradientWrtPrecisionProvider gradientWr MatrixParameterInterface parameter) { return new DiagonalPrecisionGradient(gradientWrtPrecisionProvider, treeDataLikelihood, parameter); } + }, + WRT_CORRELATION_DECOMPOSED { + @Override + public AbstractPrecisionGradient factory(GradientWrtPrecisionProvider gradientWrtPrecisionProvider, + TreeDataLikelihood treeDataLikelihood, + MatrixParameterInterface parameter) { + return new FullCorrelationPrecisionGradient(gradientWrtPrecisionProvider, treeDataLikelihood, parameter); + } }; abstract AbstractPrecisionGradient factory(GradientWrtPrecisionProvider gradientWrtPrecisionProvider, diff --git a/src/dr/inference/model/CompoundSymmetricMatrix.java b/src/dr/inference/model/CompoundSymmetricMatrix.java index 55770473cb..aaf7afc537 100644 --- a/src/dr/inference/model/CompoundSymmetricMatrix.java +++ b/src/dr/inference/model/CompoundSymmetricMatrix.java @@ -137,6 +137,24 @@ public double[] updateGradientOffDiagonal(double[] vecX) { return updateGradientCorrelation(vechuGradient); } + public double[] updateGradientFullOffDiagonal(double[] gradient) { + assert gradient.length == dim * dim; + + double[] diagQ = diagonalParameter.getParameterValues(); + + double[] offDiagGradient = new double[gradient.length]; + + int k = 0; + for (int i = 0; i < dim; ++i) { + for (int j = 0; j < dim; ++j) { + offDiagGradient[k] = gradient[i * dim + j] * Math.sqrt(diagQ[i] * diagQ[j]); + ++k; + } + } + + return offDiagGradient; + } + public double[] updateGradientCorrelation(double[] gradient) { if (!isCholesky) { return gradient; diff --git a/src/dr/inference/model/TransformedParameter.java b/src/dr/inference/model/TransformedParameter.java index 8dacd02984..3847b442d6 100644 --- a/src/dr/inference/model/TransformedParameter.java +++ b/src/dr/inference/model/TransformedParameter.java @@ -209,6 +209,10 @@ public double diffLogJacobian(double[] oldValues, double[] newValues) { } } + public Transform getTransform() { + return transform; + } + protected final Parameter parameter; protected final Transform transform; protected final boolean inverse; From 0ec80c8743efd8aae3df42e92ae90359d5fe20c7 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 27 Aug 2021 14:45:16 -0700 Subject: [PATCH 028/196] bug fix in GeodesicHMC --- .../hmc/GeodesicHamiltonianMonteCarloOperator.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 09754e82b8..407ddcab74 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -148,9 +148,9 @@ public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure orthogonalityStructure.add(oldOrthogonalityStructure.get(i)); } - for (int i = 0; i < subRows.length; i++) { - if (!alreadyOrthogonal.contains(subRows[i])) { - orthogonalityStructure.add(new int[]{subRows[i]}); + for (int i = 0; i < subColumns.length; i++) { + if (!alreadyOrthogonal.contains(subColumns[i])) { + orthogonalityStructure.add(new int[]{subColumns[i]}); } } From 26384d8eb24fd18bfbc0ba57f66f5906d8330d11 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 27 Aug 2021 14:45:31 -0700 Subject: [PATCH 029/196] updating test xml for GeodesicHMC --- ci/TestXML/testGeodesicHMC.xml | 157 ++++++++++++++++++++++++++++++--- 1 file changed, 144 insertions(+), 13 deletions(-) diff --git a/ci/TestXML/testGeodesicHMC.xml b/ci/TestXML/testGeodesicHMC.xml index 45ad93c7ee..b786378296 100644 --- a/ci/TestXML/testGeodesicHMC.xml +++ b/ci/TestXML/testGeodesicHMC.xml @@ -50,6 +50,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -170,6 +214,47 @@ + + + + + + + + + + + check geodesic leapfrog position + + + + + + -0.38482093085762703 0.025532145855234056 -0.47524352258225555 + -0.16143910221998026 0.7325336366554922 -0.18234052725410338 + -0.15833080607533787 -0.010559139462760325 -0.4621442194403536 + 0.04468814415719696 0.3353640677552862 -0.2628674991877096 + 0.23164750963293373 0.37313155882177257 -0.5285731992066846 + 0.09764476238442396 0.04647869140091962 0.08501541243930114 + 0.34796492729763856 -0.05209210160484462 -0.2825403364803621 + 0.48181029964268246 0.2485572959205083 -0.24282884394633733 + 0.48037324907816614 0.12245643738398809 0.1381010939300122 + 0.38937168724293414 -0.35956182072618714 0.11719588758032286 + + + + + + check geodesic leapfrog ratio + + + + + + -102.79492338116779 + + + @@ -190,10 +275,9 @@ - - - + \ No newline at end of file + +X3 = randn(p, k) +X3 = svd(X).U +X03 = copy(X3) +h, h2, X3_new = geo_hmc(X3, copy(M), dist, 5, 0.05, k, p, ortho_structure = [[1, 2], [3]]) +hastings_ratio = h - h2 + +function pretty(x::Vector) + s = join(x, ' ') + clipboard(s) + return s +end + +function pretty(x::Matrix) + n, p = size(x) + rows = [join(x[i, :], '\t') for i = 1:n] + s = join(rows, '\n') + clipboard(s) + return s +end + + --> \ No newline at end of file From b13af732f763b837cad8fe43e5f9cfba9fa9b12e Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 30 Aug 2021 15:33:59 -0700 Subject: [PATCH 030/196] small changes needed for test xml --- .../AbstractTransformedCompoundMatrix.java | 23 ++++++++++++++++--- .../model/CompoundSymmetricMatrix.java | 5 +++- .../model/CompoundSymmetricMatrixParser.java | 17 ++++++++++++-- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/dr/inference/model/AbstractTransformedCompoundMatrix.java b/src/dr/inference/model/AbstractTransformedCompoundMatrix.java index 368ad3abd4..15819adefb 100644 --- a/src/dr/inference/model/AbstractTransformedCompoundMatrix.java +++ b/src/dr/inference/model/AbstractTransformedCompoundMatrix.java @@ -37,6 +37,8 @@ abstract public class AbstractTransformedCompoundMatrix extends MatrixParameter final Parameter diagonalParameter; final Parameter offDiagonalParameter; + protected boolean isStrictlyUpperTriangular = true; + final CompoundParameter untransformedCompoundParameter; protected final int dim; @@ -59,7 +61,7 @@ abstract public class AbstractTransformedCompoundMatrix extends MatrixParameter diagonalParameter = diagonals; dim = diagonalParameter.getDimension(); offDiagonalParameter = - (transform == null) ? offDiagonal: new TransformedMultivariateParameter(offDiagonal, transform, inverse); + (transform == null) ? offDiagonal : new TransformedMultivariateParameter(offDiagonal, transform, inverse); addParameter(diagonalParameter); addParameter(offDiagonalParameter); @@ -149,7 +151,7 @@ public Parameter getUntransformedOffDiagonalParameter() { return offDiagonalParameter; } - public CompoundParameter getUntransformedCompoundParameter(){ + public CompoundParameter getUntransformedCompoundParameter() { return untransformedCompoundParameter; } @@ -185,7 +187,6 @@ public String getReport() { } int getUpperTriangularIndex(int i, int j) { - assert i != j; if (i < j) { return upperTriangularTransformation(i, j); } else { @@ -194,6 +195,22 @@ int getUpperTriangularIndex(int i, int j) { } private int upperTriangularTransformation(int i, int j) { + if (isStrictlyUpperTriangular) { + assert i != j; + return strictlyUpperTriangularTransformation(i, j); + } + return weaklyUpperTriangularTransformatino(i, j); + } + + private int strictlyUpperTriangularTransformation(int i, int j) { return i * (2 * dim - i - 1) / 2 + (j - i - 1); } + + private int weaklyUpperTriangularTransformatino(int i, int j) { + return i * (2 * dim - i + 1) / 2 + (j - i); + } + + public void setStrictlyUpperTriangular(boolean b) { + this.isStrictlyUpperTriangular = b; + } } diff --git a/src/dr/inference/model/CompoundSymmetricMatrix.java b/src/dr/inference/model/CompoundSymmetricMatrix.java index aaf7afc537..91d03c9626 100644 --- a/src/dr/inference/model/CompoundSymmetricMatrix.java +++ b/src/dr/inference/model/CompoundSymmetricMatrix.java @@ -74,8 +74,11 @@ public double getParameterValue(int row, int col) { Math.sqrt(diagonalParameter.getParameterValue(row) * diagonalParameter.getParameterValue(col)); } return offDiagonalParameter.getParameterValue(getUpperTriangularIndex(row, col)); + } else if (isStrictlyUpperTriangular) { + return diagonalParameter.getParameterValue(row); } - return diagonalParameter.getParameterValue(row); + return diagonalParameter.getParameterValue(row) * + offDiagonalParameter.getParameterValue(getUpperTriangularIndex(row, row)); } @Override diff --git a/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java b/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java index 0f777426f4..fa8a051a8c 100644 --- a/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java +++ b/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java @@ -39,6 +39,7 @@ public class CompoundSymmetricMatrixParser extends AbstractXMLObjectParser { public static final String OFF_DIAGONAL = "offDiagonal"; public static final String AS_CORRELATION = "asCorrelation"; public static final String IS_CHOLESKY = "isCholesky"; + public static final String IS_STRICTLY_UPPER = "isStrictlyUpperTriangular"; public String getParserName() { return MATRIX_PARAMETER; @@ -56,7 +57,18 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean isCholesky = xo.getAttribute(IS_CHOLESKY, false); - return new CompoundSymmetricMatrix(diagonalParameter, offDiagonalParameter, asCorrelation, isCholesky); + boolean isStrictlyUpperTriangular = xo.getAttribute(IS_STRICTLY_UPPER, true); + + CompoundSymmetricMatrix compoundSymmetricMatrix = + new CompoundSymmetricMatrix(diagonalParameter, offDiagonalParameter, asCorrelation, isCholesky); + + if (!isStrictlyUpperTriangular) { + System.err.println("Warning: attribute " + IS_STRICTLY_UPPER + " in " + MATRIX_PARAMETER + " should only be set to 'false' " + + "for debugging and testing purposes."); + compoundSymmetricMatrix.setStrictlyUpperTriangular(false); + } + + return compoundSymmetricMatrix; } //************************************************************************ @@ -77,7 +89,8 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(OFF_DIAGONAL, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), AttributeRule.newBooleanRule(AS_CORRELATION, true), - AttributeRule.newBooleanRule(IS_CHOLESKY, true) + AttributeRule.newBooleanRule(IS_CHOLESKY, true), + AttributeRule.newBooleanRule(IS_STRICTLY_UPPER, true) }; public Class getReturnType() { From 5f9705b35fac8dc4dcdd23dcd47088099233e553 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 30 Aug 2021 15:34:21 -0700 Subject: [PATCH 031/196] new test xml --- .../testDecomposedPrecisionGradient.xml | 949 ++++++++++++++++++ 1 file changed, 949 insertions(+) create mode 100644 ci/TestXML/testDecomposedPrecisionGradient.xml diff --git a/ci/TestXML/testDecomposedPrecisionGradient.xml b/ci/TestXML/testDecomposedPrecisionGradient.xml new file mode 100644 index 0000000000..00b3b28359 --- /dev/null +++ b/ci/TestXML/testDecomposedPrecisionGradient.xml @@ -0,0 +1,949 @@ + + + + + 1.4504684181045955 -0.14284447291449504 -1.0119411669208316 -1.8819540799010328 + -0.1384404093049692 + + -0.662952763140711 0.9513614285349006 -2.108932707646549 -2.0961791811470993 + 2.661172775719335 0.5350379298018195 -1.0959241570237364 + + + 0.15352097287184627 -0.09472535562758214 0.27985238903246823 -0.019946152301260983 + 0.3443372277336919 + + 0.08374875118164166 -0.7729120714689584 0.17424374792444564 0.9702674802897109 + 0.09545400167045057 -0.32162404929972754 -0.3387822914562665 + + + 2.200653223968015 -0.0795385792579889 -1.1446401668832948 1.9274697338482123 + 0.41205308749434033 + + -1.0121123585612022 0.9195365779605083 -0.6170544346845236 -1.8292824815753927 + 0.7004598382205527 0.9401836929519916 -0.7813299246039653 + + + -1.2532545176712426 -0.48790521779443685 0.884213522671786 -2.064560843371396 + -0.3158152558935746 + + 1.329453586675561 -1.582729811590773 0.8992507678788044 2.349277968849871 + -0.6601559498192575 -0.882597555414697 -0.0169224422088861 + + + 0.41868109815844307 -0.22186455310436676 -0.022988319788239897 -2.548123167548903 + -0.4622188536745763 + + -0.8531782996559609 1.0382781688865887 -1.077019750629998 -1.6182460032886408 + 0.9826706209863015 1.0842189097074624 -2.1384229043759118 + + + 4.4273171112722505 0.6480101890705183 -2.172700757793593 -0.0035088867329724216 + -0.9708496239216856 + + -0.8329657542809155 1.2584151165835098 -1.6142142131030388 -1.076611831980412 + 2.6670622356123346 0.6861714470945357 -1.175866940626436 + + + 2.3203929162532297 0.3725584132742778 -1.1769389867418278 1.2328721429338592 + -0.7713444353828667 + + -1.892674571708976 -0.16249385788628967 -0.8409999291108403 -1.4471593850810358 + 2.6709406209922237 -0.37884954072281496 -1.0623132825616932 + + + -0.7971839751371184 -0.15946626495406074 0.6818281014696681 -1.4099224016229832 + 0.09386886880711572 + + -0.2021982832276974 1.3049519389247144 -0.4452588514935441 -1.2402005454786986 + 0.19699695588271904 0.6648096594726121 -0.07910296518730908 + + + 0.7872010607307975 -0.5360955364584296 -0.12169838853620368 -0.39452949231308404 + -0.11471852018908674 + + -1.6368817092847547 1.4981731613772484 -0.9092372098564622 -2.4399308027236297 + 1.9643506737350376 0.7516103191917511 0.05616039741333009 + + + 3.6382333366720823 -0.017061858949330128 -1.9008988874383252 -0.348034849919142 + -0.9855946367623509 + + 0.1938389477846087 0.3031614518091112 -0.007510311784659662 -0.8490066503703302 + 2.532926244907736 0.9024319472203699 -1.0253068700652412 + + + -1.184684596412352 -0.3373341216589226 0.9345919035227708 -3.014912786360072 + 0.31001311760138495 + + 1.155068355934342 -0.10355915593333934 0.14687532037602075 0.8643076724641471 + -0.10816185360547724 -0.49153544892722345 0.60141465438905 + + + 4.631277219477512 0.03698746401422809 -2.268133753571182 0.0169209622400682 + -0.910164285136404 + + -0.3716484808992495 1.3068979802763248 -1.3355497680250177 -1.4125140814579382 + 2.681827609434124 0.6739624227871729 -1.153872000322871 + + + 1.3970737780256708 0.6043852118437912 -0.7623866000993361 2.8060917498377744 + -0.5503316735891083 + + -0.7944619108528178 0.1302411814772453 -0.6603824354753844 -0.9929557322017177 + -0.22388480545019415 0.13190298673215253 -0.8570216588115904 + + + -1.616788667295229 -0.806437390067797 0.39814890718108953 -1.994329097799743 + 0.4094240169625553 + + 1.5806278265045601 -1.907292482900509 1.8424407517830224 3.6027466877664915 + -1.2739828077929862 -2.3148080456463718 1.5625391832502078 + + + 0.8004136576866477 -0.533248942850482 -0.14103433675566102 -0.20608027718036043 + 0.4516581972571442 + + 2.2980274785855364 -1.5459918179314465 1.8642453790415525 3.2173428901530112 + -1.2615442212793726 -0.385434747297974 0.8937397902195097 + + + 3.6080068303925885 0.25395120462964105 -2.257122483546119 -0.3153824915228581 + -1.454448222823429 + + -0.34622690546058366 1.836944052124082 -2.1304785222347555 -2.4713047939192747 + 1.7172493840738947 1.6934157138048465 -0.6545207615746411 + + + -0.7105269291410143 -0.564751929090143 -0.06745141274519473 0.5926988328636207 + 0.0659783383091872 + + 0.06862477459277477 -0.7372113082004961 0.9166652418189181 -0.14796100702318604 + 0.1279068214962194 0.31978643343459845 0.4488992673247123 + + + 1.056102298562018 -0.7621157304426829 -0.6066763610384218 -1.2510050456942825 + 0.07589646564509389 + + -1.0218882383113286 0.951716316498584 -0.5036360302357432 -1.4166090977262955 + 2.3470914027194043 0.7003915111304753 -1.271180895506425 + + + 0.22617599455272513 -0.27359061127359 0.19403185683644306 -1.2122114623216327 + -0.9053262460621428 + + -0.3118491154509224 0.18279808408889753 -0.5726253944184869 -1.362561852120372 + 0.43053573793319855 0.46246578360400625 -1.5219897261995319 + + + -0.22130304532816447 0.0644664770620994 -0.19320468485623532 -2.1816453393228317 + 0.06801120918083173 + + -1.3079598555318348 1.0889008626479812 -0.950785435760656 -1.5311546561004854 + 0.6803402301420669 1.5656608082295063 -2.152229683653433 + + + -0.2838907742715124 -0.39409516878042883 0.7555113226363832 0.10440833751119188 + -0.40763339779631863 + + 1.777257501621698 -1.466580404150666 2.557574945310042 2.5656172080905115 + -0.6946788838565369 -2.5767760332793506 1.7516289327873902 + + + 0.6626528498060399 -0.3275965102077757 0.24223610993684527 -0.09096873543368655 + -0.4029789107634051 + + 0.6334697418991561 0.47950254670730963 0.5408008111312147 0.2075177669236997 + 0.16938120051211147 -0.3509934485713843 -0.44624475021927673 + + + 1.6796302144098492 0.32666653115600824 -0.7157290381038556 0.39203091157937114 + -0.4932915811213813 + + -1.9325920730623534 1.467159261579424 -1.3216275641324522 -2.1884346703217656 + 2.8480033681739214 -0.7810847425506641 -0.39916819781153146 + + + 1.1987535606257749 0.44705047091757916 -0.6093372000446172 0.7016476324683573 + -0.19704755205028981 + + -0.4426175008296125 0.9636603375059469 -0.19911720606715827 -1.3807208478081199 + 0.4356585482940649 0.31618292516224733 -1.167361800389819 + + + 2.1183112928125043 0.022795269100566784 -0.706312730223746 -0.4900141301168763 + -1.033334691636246 + + -2.3487414874746237 1.7285797739254058 -0.6349828003218257 -2.2977694507260074 + 2.332038004713757 0.7540213031632179 -1.4616788797096008 + + + 3.6031539008231737 0.7593285623245982 -2.3417410018912825 0.030471520206703576 + -1.0601756050140614 + + -0.0424528433076128 1.492524847428248 -1.5566842834107706 -2.7775711512521744 + 1.6401839091606571 1.85362975077714 -0.6436531688767123 + + + -0.3220924708343631 -0.5557191337459022 0.9113832725772277 -2.029114246058209 + -0.05794683472004196 + + -1.2777149496779292 0.9336746076200828 -0.5172022546593567 -1.6820017406897751 + 0.5731418043257617 -0.0829656049388475 0.40756760062873837 + + + 0.41105237565580754 0.21883586718662515 -0.5959194733208327 -2.6282034032092296 + -0.23753650184357944 + + -0.9238045908838415 1.4752081412702691 -1.570262480506506 -1.999352187312831 + 0.8052658339262135 0.9511051726735815 -1.3314904505124694 + + + 0.9260624288369403 0.20189744941457832 0.09587864300009852 1.5652912151367726 + -0.2884451767339513 + + -1.7729281899337024 0.2995941415263823 0.39217262694695354 -0.7923526828247436 + 0.128124609819587 0.8758488844989119 -1.2847129503207024 + + + -0.722008992386592 -0.34638856073688074 1.4284835305939063 -2.156182664052707 + 0.17356894965076541 + + 2.9221467487803654 -2.4726294191734794 1.630556322037672 3.5534218212665065 + -1.748234399930954 -1.6358758905851478 0.9446381657426747 + + + -0.7112318979285964 -0.06631694773257929 0.37869372378975935 -2.094053360207584 + -0.013800708379046775 + + -0.29680884213874104 -0.017436098340965267 0.3332093579688632 0.39814584645828505 + 1.3215766632077042 -0.7958698423729463 -1.4620042747333908 + + + -0.4311110564848181 -0.2407829274604897 0.38235848792939386 -1.7771324431896427 + -0.1183129189436983 + + -1.115850445146014 1.2514140671569534 -0.6035544285736413 -1.3492750675106686 + 0.12317208887937103 1.2892500294144715 -1.7667504998702162 + + + -2.9982215568260204 -0.17268567918201416 1.520299005713333 -1.7584774491934643 + 0.4155704140876405 + + 0.3580923569923377 0.07063099475589191 0.2152720709842974 0.5374146705098859 + -0.30511657453950136 0.1157800446475814 0.4100636085945001 + + + 0.4814819642466103 -0.4821003787585708 -0.17427200981657862 -0.6809792832395908 + -0.2539000538201176 + + -0.9862476915533331 1.0806570693673345 -0.791410384248968 -1.2708515940778227 + 0.5980468547237873 0.7119850511893989 -1.5122376558065533 + + + 1.516599003711587 -0.009497377912340378 -0.4084076834409431 -0.8136805303153987 + -0.08727914585772706 + + -1.7263811312630672 1.5183046917975478 -1.3313290390664052 -2.5357803711437885 + 2.822604414467734 0.7179394306238576 -0.45572900651220394 + + + 2.9616508482990542 0.13454295303382208 -1.516516849472648 0.5085184440567294 + -1.3264306465570126 + + -1.0854819696685565 1.6360339879662262 -2.089763423606385 -2.84791210471931 + 1.960133669528192 1.4389528809830336 -1.2979580276168088 + + + 2.683315390352514 0.4831213169046982 -1.8499378634646169 0.4004129926274031 + -0.6561020274325278 + + 0.44421987666811846 1.196528538640992 -2.403733352483769 -2.2922211381740425 + 1.7423160301146756 1.2988816932168126 -0.740055861827871 + + + -1.6571545757767276 0.20214315947322248 1.837357954783404 0.35917649047606914 + 0.7124617270865872 + + -0.09029346005745059 -0.073468506969504 0.31608422026570304 -0.08405129632296132 + -0.5615947095925746 0.6767322992201802 0.047908296791057126 + + + 1.7983700637108855 -0.37270671456349563 -1.0874447029580625 -0.06868876244436617 + -0.7303755266699812 + + -1.7331036423543644 1.5440123279979427 -0.6654055132403026 -2.486501149743464 + 1.7066397670722657 1.1884565428696456 -0.4665150080416766 + + + 0.5397446697222499 0.09102077471287501 -0.14962415397300588 -0.5375207519400755 + -0.5603724750172383 + + -1.3960309125189996 0.22787294965844804 1.0503448423484447 0.07897417111214299 + 0.4569341161021331 0.29485091412424436 -1.8427372186697375 + + + -0.5471743771612763 -0.1878369947260156 0.29893379156247163 -1.0356613305328137 + 0.3625124992194364 + + 2.9953785139375486 -2.3379625137870486 2.110188948371469 4.824876207752994 + -1.2681644531552334 -1.7869817352752764 0.6708938067984612 + + + -0.9337979745643583 0.09213492517095584 0.42596573980166225 -1.1985744266740472 + 0.05890294798873634 + + 2.1914223144583946 -2.8276075586318234 2.2286444575132283 4.158384691077882 + -1.387921705068349 -0.8523109540356532 0.8046381789639238 + + + -1.3123630744902304 0.0026785455358834392 0.480120200199058 2.1870411302417656 + 0.14797158235149355 + + -0.9539489628900828 2.069382657189224 -1.275790529588412 -2.1448331489307497 + -0.9413512771516541 1.8193272961312323 0.43682164703932025 + + + -1.9831767903444852 -0.573253775354061 1.5384605992966272 -4.983129457977968 + -0.4975716827525164 + + 3.105852185049976 -1.8302043541581954 0.5792267809555028 2.76896559781388 + -0.9328340873691627 -1.9650419579326028 1.5691790769699159 + + + 1.618809150082869 0.05675800765232991 -0.518410024358127 -1.5165599339777596 + -1.0490971573018375 + + 2.1597460735786607 -1.219628178433101 0.19139914584365494 1.4531559139508947 + 0.4190914359767381 -1.7508716326830631 0.978074024901255 + + + -0.8307173898767378 0.10654444778174302 0.6405845548532689 -0.9676168933810074 + 0.15167639211208472 + + 2.2092973182684155 -1.5749870332113791 1.6519545463345964 3.1305379579554655 + -0.9569962714724581 -1.4148598687219434 1.1148547787499044 + + + -1.6121168001918886 -0.4234330706678282 0.6397322958324675 0.166853400737008 + 0.5635887952475342 + + -0.5114846392291167 0.6728957918526385 -1.5163464035486054 -2.1350332695849237 + -0.7406252956326891 1.5660453415348106 0.2963978994269499 + + + -0.849465410173332 0.03610940061343518 0.798737804283377 1.0040212714524026 + 0.5870743172626081 + + -0.7924600317128573 0.1728196463733907 0.41514392465608213 -0.8287716580614396 + 0.7330044028331248 -0.4977973820150349 -0.5842320380884116 + + + 0.12492157039278184 0.4013102736049088 0.3101452754022568 -1.5567013585490281 + -0.034290186701839115 + + -1.1658230905379925 0.9993830959164887 -0.8249113737391192 -1.7589467666583338 + 1.5494776397636474 0.1367086141537926 0.31731682552318063 + + + 3.080918198030047 0.5977622692966176 -1.9517450309007207 0.46748544287426863 + -0.9225931895587817 + + -1.1094309189668474 1.038988039780252 -0.8560558821786213 -1.3940833489482944 + 2.0167240341166868 1.5766537566556682 -1.3934513291656896 + + + 0.33907766076668744 0.20374651950717282 0.017587159972954652 -1.848833417729504 + -0.24679956897869268 + + -0.42224746987496736 0.8705962934638634 -0.44964713274319346 -0.8961650296445403 + 2.1360725781382035 -0.2089430026097176 -1.4778182144809444 + + + 0.23748721441616025 0.33950511706703024 0.07138897671171872 -0.8479946671359011 + -0.24533441079143165 + + -1.8603565341982247 1.263137684642208 -1.2447696956829375 -1.4163642180723728 + 0.5118578712895288 0.7636481180814552 -0.902489452933578 + + + 1.2448589967721209 0.2597879694188224 -0.562751852762885 1.1752712579011666 + -0.3702222827040246 + + -0.7281406059397849 0.9281419958568942 -0.2086186985887098 -1.0295644302424845 + 0.46001898798505764 0.414522602100573 -1.1066734128117535 + + + 0.7525818314553964 0.4253392801464652 -0.5166956093485332 0.9669757423590708 + -0.06622375241998388 + + -1.9108521457986223 1.5544146043278646 -1.751546108590263 -3.119156550640771 + 1.4033004061808065 1.1975280132479214 -0.5352308550787401 + + + 1.0855153415944376 0.594892914230102 -0.5740417451565234 -1.6350191050983547 + -0.3204245999529538 + + 0.566632044862726 -0.024199556305300338 -0.04503347105177532 -0.43224712664736853 + 1.9583407185381525 -0.6333666734094967 -1.898827121791507 + + + -0.42895391488625734 -0.030196893133901864 0.09517310954394331 -1.0354241246125127 + 0.09183199423266761 + + -0.08859811416080238 0.39588769222624554 -0.1359460295865526 0.3282933562541942 + 0.07597370371040849 0.35217852240131553 -0.5400636631305571 + + + 2.4843847111568533 0.2001695727735501 -1.7872929009884984 0.4983220859883465 + -0.7789030556555453 + + -0.9136672766520304 0.3852030596829894 0.1771818118330906 -0.25734502542158344 + 1.709140955973959 0.7895634141704581 -1.2378614178101772 + + + 3.560564318630743 0.14069904613405163 -1.1327261150185497 0.04277960514154723 + -0.7467245787868527 + + 0.6021199393289647 0.060512536487224475 0.06807770091546825 -0.20726525893378683 + 1.4522160954782413 0.43088046393163626 -1.0332234194917902 + + + 3.2049481306298557 -0.5869473800827703 -1.519783362477308 0.130092851523213 + -0.7897297468507531 + + 0.10539531164799716 1.3578992717111218 -1.7138560390248565 -2.009672010972344 + 1.3651492464419435 1.6843153145902239 -1.4314535686105943 + + + -0.8732135062124844 -0.005511814957031458 0.40991415392606484 -2.35697560461693 + -0.12701922004566 + + -1.343468257713221 0.9534380616825053 -0.49153793137041246 -1.7467145850010932 + 0.2560227682761405 0.4169927148904515 -0.15419723657408055 + + + 2.4471629044219783 0.23757244668459218 -1.673990919008239 2.154138560313645 + -0.576861766222922 + + -1.3019501249557708 1.0081091413798915 -0.5064152377289395 -1.3205508513863478 + 0.8724685670729289 0.4967875740291248 -1.427986703280894 + + + -0.7940899538024865 -0.4306669146999051 1.05172543280212 -1.7865787003244866 + -0.21803797465191893 + + 1.8184244552047455 -1.6986310467521866 0.9683902425684878 3.3413479535828556 + -0.8930780835545424 -2.0201970268159752 1.4531415916508565 + + + 1.133150030989805 -0.3411105978786282 -1.0273462168407452 -0.1987428753501449 + -0.5877224566487076 + + -3.2192977875639013 2.25932125174942 -1.5396057457480405 -3.4050622640584236 + 2.924769178570502 0.8281079822432955 -0.3152319553895514 + + + 1.3141773522603473 0.08439959708439354 -1.0169679721403058 0.6398242878775741 + -0.842193026155539 + + -1.4985622426363725 1.888104494180459 -1.6146390579817658 -2.556566984739852 + 1.5684092099902174 1.1561683404366787 -0.2723094839836289 + + + 2.1753785780352346 0.8086744282524415 -0.4311497863563798 0.2884503912492653 + -0.11117632289584495 + + -1.7719196472620025 -0.23877491735402243 0.43832894495638125 -0.27204167140550456 + 2.841252281390278 0.06623953500076873 -1.3484741909564584 + + + 0.05022788599094996 -0.5203791264789055 0.08334026094680996 -0.07714546572378381 + -0.15129530972223396 + + -0.3273829662998836 -0.3993171163469974 0.3385388971286929 -0.14603928619518797 + 0.5038415175895067 0.7861751547527236 0.23013458857398875 + + + 0.46049384076377137 -0.4318037056311801 0.1984896903616046 -1.7492976943296135 + -0.1633899350949141 + + -0.6337849148210263 0.9436418894839859 -0.8701958077369399 -1.3880806600848157 + 0.8106798167130254 0.40949865749902625 -1.7109936686575904 + + + 0.3729245503013578 -0.07753072956861637 0.1911144813163615 -0.4462014928244675 + -0.5957955311359391 + + -2.265799123774301 2.152249859033429 -1.1408054931352454 -3.118644776124542 + 2.831736198964476 0.8351768788187263 -0.13668759635378702 + + + -0.0037396296758778758 -0.2769455917471196 0.2446696793976476 -2.663668202147154 + -0.5705094609213028 + + 0.406978874755234 0.02491366158676382 -0.11765263917371116 -0.32420824659138753 + -0.015936395943933368 0.02269841326960867 -0.8058188791611416 + + + 1.7390725067638733 -0.0893622647653759 -0.3564400152674386 -0.7992777119521316 + -0.1536859214861031 + + 0.5430419825648486 0.0950505725279309 0.2664451673804889 0.4500093978812432 + 0.5807012885043125 0.1909539569151718 0.07800159104018549 + + + -0.14796067090633921 0.3749731641088246 -0.7428996349572857 -1.430810297625129 + -0.24629717692699526 + + 1.1280882048999579 -1.355744873656693 0.7690075430192467 1.5998165256545325 + -0.1317476417813725 -1.1025165909280112 -0.3021237523242501 + + + 0.34922721327386 0.2544187267683622 -0.23797602153026748 -0.9325255350736625 + -0.5339868885922561 + + -0.34888928049964996 1.5717900352778775 -2.4235230361825377 -3.1936847156588173 + 1.033382006157963 1.4398494242490902 -1.408598692107348 + + + 4.633788329498635 0.025513529303246463 -2.443732777162112 1.2090413461783382 + -0.36235076644434394 + + -1.1175158385969923 0.7977314310652747 -0.41263195815871656 -1.1004180732091222 + 3.1683841433068776 -0.5792665847638991 -2.125789550224141 + + + 0.26199561449329933 -0.7681945069096061 -0.3333522196932032 -0.25251089597917753 + 0.3165621918267967 + + -2.9705259253836673 1.8180462291600987 -1.0024038852443646 -2.37109423669183 + 1.7039495010172414 0.6523385717003207 0.3992284223649003 + + + 3.8391267386181176 0.061335692830798594 -1.9163998620961513 -0.21829371464197167 + -1.8038733790086257 + + -0.512161766349629 0.6259026273768376 -0.9967259396816254 -0.7231226910083073 + 2.767421639199582 1.1583455108314518 -1.810162989938626 + + + 3.1235812313668436 -0.15413956210452046 -1.4572474915287827 -0.42857108953684164 + -0.8626894506789224 + + 0.20249312943869593 0.8586881236281271 -0.9994845963750876 -0.87500670309654 + 2.024879939582193 0.744016671955439 -0.8921186283447579 + + + 2.1776716264871685 0.4406415165915938 -1.6046877620978321 0.6146454271310179 + 0.3837293499874107 + + -0.5718006434394662 0.8883918013878271 -2.6613712194866785 -2.4481672338992686 + 2.4285471984949276 0.28597601677439966 -1.865778125813134 + + + 2.2870754811766894 -0.7037883745194677 -0.7634090007799686 0.5939602913039611 + -0.2804919898120651 + + -1.5953365754027373 1.4965540676528157 -1.5225801335095646 -3.146881507711289 + 1.970166451510055 1.4627229421016916 -0.5308113530396027 + + + 1.1426715463353532 -0.0999925247975047 0.09001180842232431 -0.9646165692256534 + -0.40334603713395084 + + 0.8701310224833447 -0.4169896743447055 0.14995996674615356 1.0436742157260925 + 1.083900380749431 -0.11434650747109375 -1.4015196191528552 + + + 1.273079938899428 0.4369415026913567 -0.9826076574025513 1.1740762605476174 + -0.41131607919062807 + + -0.8692973389413557 0.8407152116192129 -0.6615571078200725 -1.4608449561937624 + 0.4790459539717743 0.7519544172014794 -1.2775143982703772 + + + 2.493228173250216 0.2606437223602246 -0.9850353781834488 -0.342095435890081 + -0.6793122617468146 + + -1.2985761046741562 1.559977662746558 -1.2596828987869233 -2.582986892500161 + 2.191482907421893 1.110941863027046 -1.725131950914876 + + + 2.254921309478415 0.4108481807749138 -0.5759577339468747 -1.0496097575571455 + -0.4445100798539827 + + -0.30258144955033217 0.8754491173217693 -1.5011101755668328 -1.3544003946892313 + 2.267715545638793 0.5838311915321326 -1.5152596802137162 + + + -0.7151321685951372 0.24339185892322013 0.44881575262871953 -1.6343426922853364 + -0.2199007550600263 + + 2.062290040853019 -2.4878895897428164 2.7072976489959806 4.076461735375559 + -1.488600860827909 -0.7932600294096247 0.8611504379816385 + + + 0.4366423954981472 -0.019915088864700572 0.329550833349404 -0.31125597531265486 + 0.1596250699217144 + + 2.7369080038519704 -2.379032142677834 2.247236570931165 3.3752286660372497 + -1.0113520098220417 -1.4375116479085315 0.5961078861248504 + + + -0.19109912505008808 0.06525725755337802 0.6159151354401626 -2.4186903440497023 + -0.2394527913113511 + + -1.165523423326007 1.088516952719756 -0.7584430062830332 -1.8802763952754133 + 2.3421829541423014 0.07667501258553172 -0.026863317935092557 + + + -1.5003587098453073 0.11822780923495975 0.25409142323531975 0.5087007683262486 + -0.39114544196103546 + + -0.874162754332145 1.5907849041565474 -1.9192490320152527 -2.2272724524364547 + -0.9368204230400716 1.6013992901180072 0.2499325316583305 + + + 2.0463179182193723 0.3328225626246842 -1.0724542525659773 3.8426824741007612 + -0.30979748121207074 + + -1.2080903204748026 0.603258481038443 -0.5195797801351668 -1.084276229385134 + -0.26374137608704407 0.34849609187662023 -0.8218628220605041 + + + -1.4014683787564242 0.5867698864134574 0.36958452988610896 -0.3338814488597784 + -0.5142582202419861 + + -1.6347304490012586 0.5577797201657974 0.322759087038926 -1.2557496018953545 + 0.8114445397654532 1.1862147463946175 -0.8193139783386718 + + + 1.6606006332126921 0.03652131124686031 -1.1151345143326865 2.219236874215507 + -0.44505215450985547 + + -1.7536673837444967 0.9301902322074466 -0.7369167258508473 -1.72495814513816 + 0.30784080056051016 0.8091515420915065 -1.299157482204717 + + + 0.5570044203414558 -0.36447101789483005 0.38742583170270145 -2.7275830420569758 + 0.04142138241544252 + + -0.308356462512166 1.0811796144791508 -1.3160587371238175 -1.895189578145343 + 0.7288501869801459 0.989516016650147 -1.3218646761219353 + + + 3.5543406459263402 0.5075642073782366 -1.6562467874621971 -0.17619637173335806 + -1.2443283415848032 + + 0.10286773307164632 1.1052682214988807 -0.820626481796686 -1.1785087158182554 + 2.2924177414794693 1.506326123011939 -1.0142926607594427 + + + 0.2237236812202359 -0.2581101746121156 -0.6872732571127904 -1.336607878535933 + -0.004807005907825568 + + 0.43735632169585703 0.13731237374833052 0.3296402899174436 0.48468009859572314 + 0.9028532017953814 0.3374737124646578 -1.4223650060497173 + + + 1.5277204683946073 0.42043321284334945 -1.2917151012794923 0.35654973937399304 + -0.4928009626273261 + + 1.8263389142354534 -1.373669965900515 0.6708507740960747 2.1153837180916435 + 0.5824064502223352 -1.8819740875750386 -0.17322838711601674 + + + 2.1468019264083074 0.011688879057038415 -1.7381588811244213 0.681975760327472 + -0.40872543169552855 + + -1.0181571328463535 1.0036364468635801 -0.8444765674678829 -2.4421808777536658 + 2.955139085034866 -0.7581692402539767 -0.7262523460777199 + + + 1.87258523646835 -0.3186847068997547 -0.7677437037532908 0.38841506667748377 + -0.8063540421261158 + + -0.6238687683976614 2.080014240796421 -2.0978809082173244 -2.329337208047534 + 0.8957518415425533 1.4452416378580466 -0.1696665159708346 + + + 0.8865423247730155 0.27702669983427575 -0.5216957755253404 0.14724715556893936 + 0.11615717049572227 + + -0.05435186490269686 0.8796315496567562 -0.6588649654516191 -0.7980096342377572 + 0.5263870491434705 0.353443177624992 -1.1295877894312993 + + + -3.5802536410159083 -0.3019997850994086 2.044517919793755 -1.6524934881595539 + 1.3295171297401558 + + -1.2287449739416472 -0.5550252571660728 0.5359106796422701 0.032268148502316574 + -0.006092208824982848 0.02046339551476911 -0.6835189362199018 + + + -0.25167131733104986 0.05678541848200047 -0.04046984427760272 0.24258695110207773 + -0.1694766005702717 + + 1.965111227318497 -1.6371913029182628 1.9176947264052373 3.281015995115687 + -1.009153359748534 -1.2588301454193207 0.8709160258342199 + + + 0.6911008831420365 0.11623994725369718 0.1901466732124502 -1.0076546401990496 + -0.5486313390486429 + + -1.8601090510825014 0.23854995201877077 -0.3444372792101621 -0.598493050445065 + 1.7922664431570154 0.540417558812196 -0.25670070115306787 + + + 1.432075262155753 -0.07300098441294534 -0.30733154613135216 -1.9331798574492014 + -0.35141549371005804 + + 2.459038289692322 -2.3714353688977723 1.0807733881167962 3.3997012360526293 + -0.3551209288997105 -1.3983425741016333 0.6514554250465071 + + + + ((((((taxon_96:0.16291279166880093,(taxon_79:0.0572196284801809,taxon_92:0.0572196284801809):0.10569316318862003):0.23186176626087887,((taxon_51:0.07119012235505112,taxon_55:0.07119012235505112):0.030837938962528226,taxon_31:0.10202806131757935):0.29274649661210045):0.12404478135738557,((taxon_66:0.016379226481527308,taxon_17:0.016379226481527308):0.38422186028869326,((((taxon_38:0.1944538327742514,(taxon_97:0.17538124931356106,taxon_33:0.17538124931356106):0.01907258346069033):0.10733312106565472,taxon_88:0.3017869538399061):0.05168209762273184,taxon_18:0.35346905146263796):0.0413330467192093,((taxon_47:0.10361753835976575,(taxon_43:0.06939572377036309,taxon_86:0.06939572377036309):0.03422181458940267):0.17348881774850958,(taxon_8:0.12022227923025953,taxon_56:0.12022227923025953):0.15688407687801578):0.11769574207357193):0.005798988588373327):0.11821825251684479):0.27803487068959637,(((((taxon_59:0.14241201067687348,(taxon_37:0.13289257407002236,(taxon_26:0.012102542928237791,taxon_16:0.012102542928237791):0.12079003114178456):0.009519436606851114):0.260481530740412,((taxon_85:0.13321410402037606,taxon_99:0.13321410402037606):0.227754896360244,(taxon_95:0.26155039797143054,taxon_52:0.26155039797143054):0.09941860240918954):0.0419245410366654):0.0004594904660083952,(((taxon_39:0.1820446284968688,taxon_76:0.1820446284968688):0.09404145531357054,(taxon_57:0.04746929219772926,taxon_58:0.04746929219772926):0.2286167916127101):0.06905987063171622,((taxon_78:0.08644581555203126,taxon_54:0.08644581555203126):0.029657955856453262,(taxon_64:0.055168585423594145,taxon_9:0.055168585423594145):0.06093518598489039):0.22904218303367102):0.058207077441138316):0.08224254349660347,(taxon_1:0.07663322549298805,(taxon_82:0.06347853745888889,taxon_77:0.06347853745888889):0.013154688034099173):0.4089623498869093):0.148513778036282,(((taxon_72:0.168015463883959,((taxon_20:0.026652361900130928,taxon_5:0.026652361900130928):0.11045182095797572,((taxon_67:0.09659459826764794,((taxon_34:0.027096288296805875,taxon_19:0.027096288296805875):0.024019585999130653,(taxon_28:0.002309752323377199,taxon_90:0.002309752323377199):0.048806121972559326):0.04547872397171141):0.013229756729671536,taxon_32:0.10982435499731948):0.027279827860787177):0.030911281025852347):0.3391103070103022,(taxon_69:0.35563613977581465,(taxon_27:0.0906859359976699,taxon_60:0.0906859359976699):0.2649502037781447):0.15148963111844654):0.11657233288696214,((taxon_74:0.36685090182431235,((taxon_35:0.19671419214590474,taxon_49:0.19671419214590474):0.06845157023282315,(taxon_65:0.16339811555633743,(taxon_63:0.004995253499267516,taxon_68:0.004995253499267516):0.1584028620570699):0.10176764682239046):0.10168513944558445):0.006343194018020243,(taxon_48:0.3575083240602523,(taxon_23:0.09955167432485686,(taxon_94:0.08832463761271805,taxon_7:0.08832463761271805):0.011227036712138804):0.25795664973539545):0.0156857717820803):0.25050400793889077):0.010411249634956004):0.1627448565604824):0.08919946674857195,((((taxon_89:0.09279844397471991,((taxon_87:0.007696445808463369,taxon_13:0.007696445808463369):0.07885330559622813,(taxon_80:0.08028810668980015,(taxon_61:0.07692661240971177,taxon_3:0.07692661240971177):0.0033614942800883796):0.006261644714891361):0.0062486925700284005):0.05544442870326615,((taxon_53:0.026129327563896024,taxon_24:0.026129327563896024):0.1098590907608896,taxon_29:0.13598841832478561):0.01225445435320044):0.20291611459214093,taxon_40:0.35115898727012695):0.05046043275645241,((((taxon_75:0.1676654410835572,(((taxon_50:0.10141022456499973,taxon_10:0.10141022456499973):0.05077625235590244,(taxon_12:0.00018617878901610926,taxon_6:0.00018617878901610926):0.15200029813188606):0.009576409657510136,taxon_91:0.1617628865784123):0.005902554505144874):0.07303061086844011,(taxon_81:0.022654924388436373,taxon_25:0.022654924388436373):0.21804112756356092):0.1101992994590126,taxon_73:0.3508953514110099):0.009666472687346744,taxon_36:0.36056182409835663):0.04105759592822275):0.4844342566986543):0.11394632327476625,((((taxon_70:0.15363208451560414,taxon_11:0.15363208451560414):0.04089397813246364,(taxon_22:0.008085469685390046,taxon_2:0.008085469685390046):0.18644059296267773):0.18439221495135483,((taxon_93:0.09375019133038573,taxon_45:0.09375019133038573):0.16059515129484994,(taxon_4:0.11835388968821214,taxon_71:0.11835388968821214):0.13599145293702353):0.12457293497418694):0.083113612412,(((((taxon_84:0.13413488284455308,taxon_46:0.13413488284455308):0.021538096342384794,taxon_100:0.15567297918693787):0.016068728834798575,(((taxon_83:0.002163399488911959,taxon_42:0.002163399488911959):0.13880675809754586,taxon_15:0.14097015758645784):0.02289337795317344,taxon_98:0.1638635355396313):0.007878172482105164):0.08712454165033161,taxon_44:0.25886624967206806):0.019279221739069336,(((taxon_21:0.0395162349000485,taxon_14:0.0395162349000485):0.010381416470037045,taxon_62:0.04989765137008554):0.12477291581159908,(taxon_41:0.025744912045784222,taxon_30:0.025744912045784222):0.1489256551359004):0.10347490422945278):0.18388641860028523):0.5379681099885774); + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From dc6281abcac5440d175da69b367105b6625a4c07 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 30 Aug 2021 15:34:53 -0700 Subject: [PATCH 032/196] proper gradient in FullCorrelationPrecisionGradient --- .../hmc/FullCorrelationPrecisionGradient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java b/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java index 4668fd1b4e..3faf5896c8 100644 --- a/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java +++ b/src/dr/evomodel/treedatalikelihood/hmc/FullCorrelationPrecisionGradient.java @@ -61,7 +61,7 @@ public Parameter getParameter() { DenseMatrix64F decompMat = DenseMatrix64F.wrap(dim, dim, decomposedMatrix.getParameterValues()); DenseMatrix64F decompGradMat = DenseMatrix64F.wrap(dim, dim, decomposedGradient); - CommonOps.mult(decompMat, corGradMat, decompGradMat); + CommonOps.multTransA(corGradMat, decompMat, decompGradMat); CommonOps.scale(2.0, decompGradMat); From eb97c631a26b34cf7c34b7c81d48a14b7f26a797 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 11:11:36 -0700 Subject: [PATCH 033/196] refactoring GeodesicHMC for more flexible structures --- ...GeodesicHamiltonianMonteCarloOperator.java | 367 ++++++++++++------ ...icHamiltonianMonteCarloOperatorParser.java | 9 +- 2 files changed, 258 insertions(+), 118 deletions(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 407ddcab74..3044996cf6 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -18,6 +18,7 @@ import org.ejml.ops.CommonOps; import java.util.ArrayList; +import java.util.Collections; public class GeodesicHamiltonianMonteCarloOperator extends HamiltonianMonteCarloOperator implements Reportable { @@ -82,7 +83,7 @@ public String getReport() { return sb.toString(); } - public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure) { + public void setOrthogonalityStructure(ArrayList> oldOrthogonalityStructure) { ((GeodesicLeapFrogEngine) leapFrogEngine).setOrthogonalityStructure(oldOrthogonalityStructure); } @@ -95,12 +96,13 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator // private final DenseMatrix64F innerProduct2; // private final DenseMatrix64F projection; // private final DenseMatrix64F momentumMatrix; - private final int nRows; +// private final int nRows; // private final int nCols; - private final int[] subRows; - private final int[] subColumns; - private final ArrayList orthogonalityStructure; + // private final int[] subRows; +// private final int[] subColumns; + private final ArrayList> orthogonalityStructure; + private final ArrayList> orthogonalityBlockRows; GeodesicLeapFrogEngine(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, @@ -108,14 +110,33 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator super(parameter, instabilityHandler, preconditioning, mask); this.matrixParameter = (MatrixParameterInterface) parameter; - this.subRows = parseSubRowsFromMask(); - this.subColumns = parseSubColumnsFromMask(); - if (mask != null) checkMask(subRows, subColumns); + +// this.subRows = parseSubRowsFromMask(); +// this.subColumns = parseSubColumnsFromMask(); + this.orthogonalityStructure = new ArrayList<>(); - orthogonalityStructure.add(subColumns); + this.orthogonalityBlockRows = new ArrayList<>(); + + if (mask == null) { + ArrayList rows = new ArrayList<>(); + for (int i = 0; i < matrixParameter.getRowDimension(); i++) { + rows.add(i); + } + ArrayList cols = new ArrayList<>(); + for (int i = 0; i < matrixParameter.getColumnDimension(); i++) { + cols.add(i); + } + orthogonalityStructure.add(cols); + orthogonalityBlockRows.add(rows); + } else { + parseStructureFromMask(mask); + } +// orthogonalityStructure.add(subColumns); - this.nRows = subRows.length; +// this.nRows = subRows.length; + +// if (mask != null) checkMask(subRows, subColumns); // this.nCols = subColumns.length; // this.positionMatrix = new DenseMatrix64F(nCols, nRows); // this.innerProduct = new DenseMatrix64F(nCols, nCols); @@ -124,147 +145,246 @@ public static class GeodesicLeapFrogEngine extends HamiltonianMonteCarloOperator // this.momentumMatrix = new DenseMatrix64F(nCols, nRows); } - public void setOrthogonalityStructure(ArrayList oldOrthogonalityStructure) { - orthogonalityStructure.clear(); - - ArrayList subColList = new ArrayList<>(); - for (int i : subColumns) { - subColList.add(i); - } + private void parseStructureFromMask(double[] mask) { + int nRows = matrixParameter.getRowDimension(); + int nCols = matrixParameter.getColumnDimension(); - //check that orthogonalityStructure is consistent with the subRows - ArrayList alreadyOrthogonal = new ArrayList<>(); + ArrayList colRows = new ArrayList<>(); - for (int i = 0; i < oldOrthogonalityStructure.size(); i++) { - for (int j = 0; j < oldOrthogonalityStructure.get(i).length; j++) { - if (!subColList.contains(oldOrthogonalityStructure.get(i)[j])) { //TODO: check that we're doing this by row (or allow to do by row or column) - throw new RuntimeException("Cannot enforce orthogonality structure."); + for (int i = 0; i < nCols; i++) { + colRows.clear(); + int offset = i * nRows; + for (int j = 0; j < nRows; j++) { + if (mask[offset + j] == 1) { + colRows.add(j); } - if (alreadyOrthogonal.contains(oldOrthogonalityStructure.get(i)[j])) { - throw new RuntimeException("Orthogonal blocks must be non-overlapping"); - } - alreadyOrthogonal.add(oldOrthogonalityStructure.get(i)[j]); } - orthogonalityStructure.add(oldOrthogonalityStructure.get(i)); - } - for (int i = 0; i < subColumns.length; i++) { - if (!alreadyOrthogonal.contains(subColumns[i])) { - orthogonalityStructure.add(new int[]{subColumns[i]}); + if (!colRows.isEmpty()) { + int matchingInd = findMatchingArray(orthogonalityBlockRows, colRows); + if (matchingInd == -1) { + ArrayList newBlock = new ArrayList<>(); + newBlock.add(i); + orthogonalityStructure.add(newBlock); + orthogonalityBlockRows.add(new ArrayList<>(colRows)); + } else { + orthogonalityStructure.get(matchingInd).add(i); + } } } - } - private int[] parseSubColumnsFromMask() { - - int originalRows = matrixParameter.getRowDimension(); - int originalColumns = matrixParameter.getColumnDimension(); - - ArrayList subArray = new ArrayList(); + private int findMatchingArray(ArrayList> listOfLists, ArrayList list) { + int nLists = listOfLists.size(); + for (int i = 0; i < nLists; i++) { + ArrayList subList = listOfLists.get(i); + boolean matching = true; + if (list.size() == subList.size()) { + for (int j = 0; j < list.size(); j++) { + if (list.get(j) != subList.get(j)) { + matching = false; + break; + } + } - for (int col = 0; col < originalColumns; col++) { - int offset = col * originalRows; - for (int row = 0; row < originalRows; row++) { - int ind = offset + row; - if (mask == null || mask[ind] == 1.0) { - subArray.add(col); - break; + if (matching) { + return i; } } } - int[] subColumns = new int[subArray.size()]; - for (int i = 0; i < subColumns.length; i++) { - subColumns[i] = subArray.get(i); - } - - return subColumns; + return -1; } - private int[] parseSubRowsFromMask() { - int originalRows = matrixParameter.getRowDimension(); - int originalColumns = matrixParameter.getColumnDimension(); - - ArrayList subArray = new ArrayList(); + private int findSubArray(ArrayList> listOfLists, ArrayList list, ArrayList remainingList) { //assumes both are sorted + remainingList.clear(); + int nLists = listOfLists.size(); + for (int i = 0; i < nLists; i++) { + ArrayList subList = listOfLists.get(i); + if (list.size() <= subList.size()) { + int currentInd = 0; + for (int j = 0; j < subList.size(); j++) { + + if (currentInd < list.size() && subList.get(j) == list.get(currentInd)) { + currentInd += 1; + } else { + remainingList.add(subList.get(j)); + } + } - for (int row = 0; row < originalRows; row++) { - for (int col = 0; col < originalColumns; col++) { - int ind = col * originalRows + row; - if (mask == null || mask[ind] == 1.0) { - subArray.add(row); - break; + if (currentInd == list.size()) { + return i; } - } - } - int[] subRows = new int[subArray.size()]; - for (int i = 0; i < subRows.length; i++) { - subRows[i] = subArray.get(i); + } } - return subRows; + return -1; } - private void checkMask(int[] rows, int[] cols) { - int originalRows = matrixParameter.getRowDimension(); - int originalColumns = matrixParameter.getColumnDimension(); - int subRowInd = 0; - int subColInd = 0; + public void setOrthogonalityStructure(ArrayList> newOrthogonalColumns) { - Boolean isSubRow; - Boolean isSubCol; - - for (int row = 0; row < originalRows; row++) { - if (row == rows[subRowInd]) { - isSubRow = true; - subRowInd++; - } else { - isSubRow = false; + for (int i = 0; i < newOrthogonalColumns.size(); i++) { + ArrayList remainingList = new ArrayList<>(); + ArrayList cols = newOrthogonalColumns.get(i); + Collections.sort(cols); + int matchingCol = findSubArray(orthogonalityStructure, cols, remainingList); + if (matchingCol == -1) { + throw new RuntimeException("Orthogonality structure incompatible with itself or mask."); } - subColInd = 0; + ArrayList existingCols = orthogonalityStructure.get(matchingCol); - for (int col = 0; col < originalColumns; col++) { - if (col == cols[subColInd]) { - isSubCol = true; - subColInd++; - } else { - isSubCol = false; - } - int ind = originalRows * col + row; - - if (isSubCol && isSubRow) { - if (mask[ind] != 1.0) { - throw new RuntimeException("mask is incompatible with " + - GeodesicHamiltonianMonteCarloOperatorParser.OPERATOR_NAME + - ". All elements in sub-matrix must be set to 1."); - } - } else { - if (mask[ind] != 0.0) { - throw new RuntimeException("mask is incompatible with " + - GeodesicHamiltonianMonteCarloOperatorParser.OPERATOR_NAME + - ". All elements outside of sub-matrix must be set to 0."); - } - } + if (remainingList.size() > 0) { + orthogonalityStructure.set(matchingCol, remainingList); + orthogonalityStructure.add(cols); + orthogonalityBlockRows.add(orthogonalityBlockRows.get(matchingCol)); } + } +// ArrayList subColList = new ArrayList<>(); +// for (int i : subColumns) { +// subColList.add(i); +// } +// +// //check that orthogonalityStructure is consistent with the subRows +// ArrayList alreadyOrthogonal = new ArrayList<>(); +// +// for (int i = 0; i < newOrthogonalColumns.size(); i++) { +// for (int j = 0; j < newOrthogonalColumns.get(i).length; j++) { +// if (!subColList.contains(newOrthogonalColumns.get(i)[j])) { //TODO: check that we're doing this by row (or allow to do by row or column) +// throw new RuntimeException("Cannot enforce orthogonality structure."); +// } +// if (alreadyOrthogonal.contains(newOrthogonalColumns.get(i)[j])) { +// throw new RuntimeException("Orthogonal blocks must be non-overlapping"); +// } +// alreadyOrthogonal.add(newOrthogonalColumns.get(i)[j]); +// } +// orthogonalityStructure.add(newOrthogonalColumns.get(i)); +// } +// +// for (int i = 0; i < subColumns.length; i++) { +// if (!alreadyOrthogonal.contains(subColumns[i])) { +// orthogonalityStructure.add(new int[]{subColumns[i]}); +// } +// } + } +// private int[] parseSubColumnsFromMask() { +// +// int originalRows = matrixParameter.getRowDimension(); +// int originalColumns = matrixParameter.getColumnDimension(); +// +// ArrayList subArray = new ArrayList(); +// +// for (int col = 0; col < originalColumns; col++) { +// int offset = col * originalRows; +// for (int row = 0; row < originalRows; row++) { +// int ind = offset + row; +// if (mask == null || mask[ind] == 1.0) { +// subArray.add(col); +// break; +// } +// } +// } +// +// int[] subColumns = new int[subArray.size()]; +// for (int i = 0; i < subColumns.length; i++) { +// subColumns[i] = subArray.get(i); +// } +// +// return subColumns; +// } +// +// private int[] parseSubRowsFromMask() { +// int originalRows = matrixParameter.getRowDimension(); +// int originalColumns = matrixParameter.getColumnDimension(); +// +// ArrayList subArray = new ArrayList(); +// +// for (int row = 0; row < originalRows; row++) { +// for (int col = 0; col < originalColumns; col++) { +// int ind = col * originalRows + row; +// if (mask == null || mask[ind] == 1.0) { +// subArray.add(row); +// break; +// } +// } +// } +// +// int[] subRows = new int[subArray.size()]; +// for (int i = 0; i < subRows.length; i++) { +// subRows[i] = subArray.get(i); +// } +// +// return subRows; +// } + +// private void checkMask(int[] rows, int[] cols) { +// int originalRows = matrixParameter.getRowDimension(); +// int originalColumns = matrixParameter.getColumnDimension(); +// +// int subRowInd = 0; +// int subColInd = 0; +// +// Boolean isSubRow; +// Boolean isSubCol; +// +// for (int row = 0; row < originalRows; row++) { +// if (subRowInd < rows.length && row == rows[subRowInd]) { +// isSubRow = true; +// subRowInd++; +// } else { +// isSubRow = false; +// } +// +// subColInd = 0; +// +// for (int col = 0; col < originalColumns; col++) { +// if (subColInd < cols.length && col == cols[subColInd]) { +// isSubCol = true; +// subColInd++; +// } else { +// isSubCol = false; +// } +// +// int ind = originalRows * col + row; +// +// if (isSubCol && isSubRow) { +// if (mask[ind] != 1.0) { +// throw new RuntimeException("mask is incompatible with " + +// GeodesicHamiltonianMonteCarloOperatorParser.OPERATOR_NAME + +// ". All elements in sub-matrix must be set to 1."); +// } +// } else { +// if (mask[ind] != 0.0) { +// throw new RuntimeException("mask is incompatible with " + +// GeodesicHamiltonianMonteCarloOperatorParser.OPERATOR_NAME + +// ". All elements outside of sub-matrix must be set to 0."); +// } +// } +// +// } +// } +// } + private DenseMatrix64F setOrthogonalSubMatrix(double[] src, int srcOffset, int block) { int nRowsOriginal = matrixParameter.getRowDimension(); - int[] blockCols = orthogonalityStructure.get(block); - int nCols = blockCols.length; + ArrayList blockCols = orthogonalityStructure.get(block); + ArrayList blockRows = orthogonalityBlockRows.get(block); + int nCols = blockCols.size(); + int nRows = blockRows.size(); DenseMatrix64F dest = new DenseMatrix64F(nCols, nRows); for (int row = 0; row < nRows; row++) { for (int col = 0; col < nCols; col++) { - int ind = nRowsOriginal * blockCols[col] + subRows[row] + srcOffset; + int ind = nRowsOriginal * blockCols.get(col) + blockRows.get(row) + srcOffset; dest.set(col, row, src[ind]); } } @@ -278,10 +398,12 @@ private DenseMatrix64F setOrthogonalSubMatrix(double[] src, int block) { private void unwrapSubMatrix(DenseMatrix64F src, int block, double[] dest, int destOffset) { int nRowsOriginal = matrixParameter.getRowDimension(); - int[] blockCols = orthogonalityStructure.get(block); - for (int row = 0; row < nRows; row++) { - for (int col = 0; col < blockCols.length; col++) { - int ind = nRowsOriginal * blockCols[col] + subRows[row] + destOffset; + ArrayList blockCols = orthogonalityStructure.get(block); + ArrayList blockRows = orthogonalityBlockRows.get(block); + + for (int row = 0; row < blockRows.size(); row++) { + for (int col = 0; col < blockCols.size(); col++) { + int ind = nRowsOriginal * blockCols.get(col) + blockRows.get(row) + destOffset; dest[ind] = src.get(col, row); } } @@ -305,7 +427,8 @@ public void updatePosition(double[] position, WrappedVector momentum, for (int block = 0; block < orthogonalityStructure.size(); block++) { - int nCols = orthogonalityStructure.get(block).length; + int nCols = orthogonalityStructure.get(block).size(); + int nRows = orthogonalityBlockRows.get(block).size(); // positionMatrix.setData(position); DenseMatrix64F positionMatrix = setOrthogonalSubMatrix(position, block); @@ -383,6 +506,19 @@ public void updatePosition(double[] position, WrappedVector momentum, DenseMatrix64F projection = new DenseMatrix64F(nCols, nRows); CommonOps.mult(innerProduct, positionMatrix, projection); + + double sse = 0; + for (int i = 0; i < positionMatrix.data.length; i++) { + double diff = projection.data[i] - positionMatrix.data[i]; + sse += diff * diff; + } + + if (sse > 1e-3) { + System.err.println("unstable"); //TODO: REMOVE + throw new NumericInstabilityException(); + } + + System.arraycopy(projection.data, 0, positionMatrix.data, 0, positionMatrix.data.length); unwrapSubMatrix(positionMatrix, block, position); @@ -401,7 +537,8 @@ public void projectMomentum(double[] momentum, double[] position) { DenseMatrix64F positionMatrix = setOrthogonalSubMatrix(position, block); DenseMatrix64F momentumMatrix = setOrthogonalSubMatrix(momentum, block); - int nCols = orthogonalityStructure.get(block).length; + int nCols = orthogonalityStructure.get(block).size(); + int nRows = orthogonalityBlockRows.get(block).size(); // positionMatrix.setData(position); // momentumMatrix.setData(momentum); diff --git a/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java b/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java index f92522100a..2f86f685b6 100644 --- a/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java +++ b/src/dr/inferencexml/operators/hmc/GeodesicHamiltonianMonteCarloOperatorParser.java @@ -29,14 +29,17 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { GeodesicHamiltonianMonteCarloOperator hmc = (GeodesicHamiltonianMonteCarloOperator) super.parseXMLObject(xo); if (xo.hasChildNamed(ORTHOGONALITY_STRUCTURE)) { XMLObject cxo = xo.getChild(ORTHOGONALITY_STRUCTURE); - ArrayList orthogonalityStructure = new ArrayList<>(); + ArrayList> orthogonalityStructure = new ArrayList<>(); for (int i = 0; i < cxo.getChildCount(); i++) { XMLObject group = (XMLObject) cxo.getChild(i); int[] rows = group.getIntegerArrayAttribute(ROWS); + ArrayList rowList = new ArrayList<>(); + for (int j = 0; j < rows.length; j++) { - rows[j]--; + rowList.add(rows[j] - 1); } - orthogonalityStructure.add(rows); + + orthogonalityStructure.add(rowList); } hmc.setOrthogonalityStructure(orthogonalityStructure); From b9b710b084e445bd62091ba36b2955b7ed370ffd Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 11:29:58 -0700 Subject: [PATCH 034/196] prior on determinant of norm-constrained matrix --- .../MultivariateDistributionLikelihood.java | 40 +++++++++++ ...nstrainedDeterminantDistributionModel.java | 67 +++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 src/dr/math/distributions/ConstrainedDeterminantDistributionModel.java diff --git a/src/dr/inference/distribution/MultivariateDistributionLikelihood.java b/src/dr/inference/distribution/MultivariateDistributionLikelihood.java index 0e432cfbda..7b1e81da85 100644 --- a/src/dr/inference/distribution/MultivariateDistributionLikelihood.java +++ b/src/dr/inference/distribution/MultivariateDistributionLikelihood.java @@ -76,6 +76,7 @@ public class MultivariateDistributionLikelihood extends AbstractDistributionLike public static final String SPHERICAL_BETA_PRIOR = "sphericalBetaPrior"; public static final String SPHERICAL_BETA_SHAPE = "shapeParameter"; public static final String MV_LOG_NORMAL_PRIOR = "MVlogNormalPrior"; + public static final String DETERMINANT_PRIOR = "determinantPrior"; public static final String DATA = "data"; @@ -758,6 +759,45 @@ public Class getReturnType() { } }; + public static XMLObjectParser DETERMINANT_PRIOR_PARSER = new AbstractXMLObjectParser() { + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + double shape = xo.getDoubleAttribute(LKJ_SHAPE); + MatrixParameterInterface parameter = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class); + int dim = parameter.getRowDimension(); + if (parameter.getColumnDimension() != dim) { + throw new XMLParseException("matrix must be square"); + } + MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(new ConstrainedDeterminantDistributionModel(shape, dim)); + likelihood.addData(parameter); + return likelihood; + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + AttributeRule.newDoubleRule(LKJ_SHAPE, true), + new ElementRule(MatrixParameterInterface.class) + + }; + } + + @Override + public String getParserDescription() { + return "Calculates p(X) ∝ det(X)^a"; + } + + @Override + public Class getReturnType() { + return Likelihood.class; + } + + @Override + public String getParserName() { + return DETERMINANT_PRIOR; + } + }; + public static XMLObjectParser LKJ_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { diff --git a/src/dr/math/distributions/ConstrainedDeterminantDistributionModel.java b/src/dr/math/distributions/ConstrainedDeterminantDistributionModel.java new file mode 100644 index 0000000000..b65308ade6 --- /dev/null +++ b/src/dr/math/distributions/ConstrainedDeterminantDistributionModel.java @@ -0,0 +1,67 @@ +package dr.math.distributions; + + +import dr.inference.model.GradientProvider; +import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; + +/** + * @author Gabriel Hassler + * @author Marc A. Suchard + */ + +public class ConstrainedDeterminantDistributionModel implements MultivariateDistribution, GradientProvider { + + private final int dim; + private final double shape; + + public ConstrainedDeterminantDistributionModel(double shape, int dim) { + this.dim = dim; + this.shape = shape; + } + + + @Override + public int getDimension() { + return dim * dim; + } + + @Override + public double[] getGradientLogDensity(Object x) { + return gradLogPdf((double[]) x); + } + + private double[] gradLogPdf(double[] x) { + DenseMatrix64F X = DenseMatrix64F.wrap(dim, dim, x); + + DenseMatrix64F Xinv = new DenseMatrix64F(dim, dim); + CommonOps.invert(X, Xinv); + + CommonOps.scale(shape, Xinv); + CommonOps.transpose(Xinv); + return Xinv.getData(); + } + + @Override + public double logPdf(double[] x) { + DenseMatrix64F X = DenseMatrix64F.wrap(dim, dim, x); + double det = CommonOps.det(X); + + return shape * Math.log(Math.abs(det)); //TODO: normalizing constant + } + + @Override + public double[][] getScaleMatrix() { + throw new RuntimeException("not implemented"); + } + + @Override + public double[] getMean() { + throw new RuntimeException("not implemented"); + } + + @Override + public String getType() { + return "ConstrainedDeterminant"; + } +} From 447661afbabdc1ae785010866b24fccff2028540 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 11:33:30 -0700 Subject: [PATCH 035/196] removing unnecessary variable in NormalExtensionGibbsProviderParser --- .../operators/NormalExtensionGibbsProviderParser.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java b/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java index 83cc2b8500..9729abe77b 100644 --- a/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java +++ b/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java @@ -12,7 +12,6 @@ public class NormalExtensionGibbsProviderParser extends AbstractXMLObjectParser { - private static final String TREE_TRAIT_NAME = "treeTraitName"; private static final String NORMAL_EXTENSION = "normalExtension"; From a27200892495015a8f594536cea8852be072fc7d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 16:44:40 -0700 Subject: [PATCH 036/196] idiot-proofing & code cleanup --- .../continuous/ContinuousTraitDataModel.java | 2 +- .../continuous/RepeatedMeasuresTraitDataModel.java | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java index 7b3f06ab5a..ef98db36a3 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java @@ -268,7 +268,7 @@ public ContinuousExtensionDelegate getExtensionDelegate(ContinuousDataLikelihood @Override public double[] transformTreeTraits(double[] treeTraits) { - return treeTraits; + return treeTraits.clone(); } /* diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 41cde39d06..96c6d5fd29 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -48,7 +48,6 @@ import org.ejml.ops.CommonOps; import java.util.Arrays; -import java.util.List; /** * @author Marc A. Suchard @@ -60,7 +59,6 @@ public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel imp private final String traitName; private final MatrixParameterInterface samplingPrecisionParameter; private boolean diagonalOnly = false; - // private DenseMatrix64F samplingVariance; private boolean variableChanged = true; private boolean varianceKnown = false; @@ -256,11 +254,6 @@ public MatrixParameterInterface getExtensionPrecision() { return samplingPrecisionParameter; } - @Override - public double[] transformTreeTraits(double[] treeTraits) { - return treeTraits; - } - @Override public int getDataDimension() { return dimTrait; From 15bdff31e1d0b16f2a4780838665beceb6ea34fc Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 16:46:42 -0700 Subject: [PATCH 037/196] generalizing RepeatedMeasuresWishartStatistics --- .../continuous/RepeatedMeasuresWishartStatistics.java | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java index ac7627897d..d69c6ed64b 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java @@ -4,6 +4,7 @@ import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate; +import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; import dr.inference.model.MatrixParameterInterface; import dr.math.distributions.WishartSufficientStatistics; import dr.math.interfaces.ConjugateWishartStatisticsProvider; @@ -11,8 +12,6 @@ import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; -import static dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate.REALIZED_TIP_TRAIT; - /** * @author Gabriel Hassler * @author Marc A. Suchard @@ -20,10 +19,9 @@ public class RepeatedMeasuresWishartStatistics implements ConjugateWishartStatisticsProvider { - private final RepeatedMeasuresTraitDataModel traitModel; + private final ModelExtensionProvider.NormalExtensionProvider traitModel; private final Tree tree; private final TreeTrait tipTrait; - private final String traitName; private final ContinuousExtensionDelegate extensionDelegate; private final ContinuousDataLikelihoodDelegate likelihoodDelegate; private final double[] outerProduct; @@ -32,11 +30,10 @@ public class RepeatedMeasuresWishartStatistics implements ConjugateWishartStatis private final double[] buffer; private boolean forceResample; - public RepeatedMeasuresWishartStatistics(RepeatedMeasuresTraitDataModel traitModel, + public RepeatedMeasuresWishartStatistics(ModelExtensionProvider.NormalExtensionProvider traitModel, TreeDataLikelihood treeLikelihood, boolean forceResample) { this.traitModel = traitModel; - this.traitName = traitModel.getTraitName(); this.tree = treeLikelihood.getTree(); this.tipTrait = treeLikelihood.getTreeTrait(traitModel.getTipTraitName()); @@ -65,7 +62,7 @@ public WishartSufficientStatistics getWishartStatistics() { if (forceResample) { likelihoodDelegate.fireModelChanged(); } - double[] treeValues = (double[]) tipTrait.getTrait(tree, null); + double[] treeValues = traitModel.transformTreeTraits((double[]) tipTrait.getTrait(tree, null)); double[] dataValues = extensionDelegate.getExtendedValues(treeValues); DenseMatrix64F XminusY = DenseMatrix64F.wrap(nTaxa, dimTrait, buffer); From 0ca9f45fcab99adb2361015a6f66518b56119427 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 16:47:51 -0700 Subject: [PATCH 038/196] test class to make sure tree traits are partitioned correctly --- .../app/beast/development_parsers.properties | 3 +- .../continuous/TreeTraitProviderTest.java | 125 ++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 src/dr/evomodel/treedatalikelihood/continuous/TreeTraitProviderTest.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 47158d3291..6847796ad7 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -308,4 +308,5 @@ dr.inference.model.FactorProportionStatistic dr.inference.model.MaskFromTree # Structural Equation Modeling -dr.util.MatrixInnerProductTransform \ No newline at end of file +dr.util.MatrixInnerProductTransform +dr.evomodel.treedatalikelihood.continuous.TreeTraitProviderTest \ No newline at end of file diff --git a/src/dr/evomodel/treedatalikelihood/continuous/TreeTraitProviderTest.java b/src/dr/evomodel/treedatalikelihood/continuous/TreeTraitProviderTest.java new file mode 100644 index 0000000000..cd27ab1482 --- /dev/null +++ b/src/dr/evomodel/treedatalikelihood/continuous/TreeTraitProviderTest.java @@ -0,0 +1,125 @@ +package dr.evomodel.treedatalikelihood.continuous; + +import dr.evolution.tree.Tree; +import dr.evolution.tree.TreeTrait; +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; +import dr.math.matrixAlgebra.Matrix; +import dr.xml.*; + +public class TreeTraitProviderTest implements Reportable { + + private final TreeDataLikelihood treeDataLikelihood; + private final ModelExtensionProvider.NormalExtensionProvider modelExtensionProvider; + private final Tree tree; + private final TreeTrait tipTrait; + private static final int REPS = 1000; + + + public TreeTraitProviderTest(TreeDataLikelihood treeDataLikelihood, + ModelExtensionProvider.NormalExtensionProvider modelExtension) { + + this.treeDataLikelihood = treeDataLikelihood; + this.modelExtensionProvider = modelExtension; + this.tree = treeDataLikelihood.getTree(); + this.tipTrait = treeDataLikelihood.getTreeTrait(modelExtension.getTipTraitName()); + + } + + + @Override + public String getReport() { + double[] meanTipTraits = (double[]) tipTrait.getTrait(tree, null); + double[] meanTransformedTraits = modelExtensionProvider.transformTreeTraits(meanTipTraits); + + + for (int i = 1; i < REPS; i++) { //start at 1 because they're already initialized with data + treeDataLikelihood.fireModelChanged(); // force new sample + double[] meanTipTraitsNew = (double[]) tipTrait.getTrait(tree, null); + double[] meanTransformedTraitsNew = modelExtensionProvider.transformTreeTraits(meanTipTraitsNew); + for (int j = 0; j < meanTipTraits.length; j++) { + meanTipTraits[j] += meanTipTraitsNew[j]; + } + for (int j = 0; j < meanTransformedTraits.length; j++) { + meanTransformedTraits[j] += meanTransformedTraitsNew[j]; + } + } + + for (int j = 0; j < meanTipTraits.length; j++) { + meanTipTraits[j] /= REPS; + } + for (int j = 0; j < meanTransformedTraits.length; j++) { + meanTransformedTraits[j] /= REPS; + } + + int traitDim = modelExtensionProvider.getTraitDimension(); + int dataDim = modelExtensionProvider.getDataDimension(); + int taxonCount = tree.getTaxonCount(); + + Matrix matrix = new Matrix(taxonCount, traitDim); + Matrix matrixTransformed = new Matrix(taxonCount, dataDim); + + + for (int taxon = 0; taxon < taxonCount; taxon++) { + for (int trait = 0; trait < traitDim; trait++) { + matrix.set(taxon, trait, meanTipTraits[taxon * traitDim + trait]); + } + for (int factor = 0; factor < dataDim; factor++) { + matrixTransformed.set(taxon, factor, meanTransformedTraits[taxon * dataDim + factor]); + } + } + + StringBuilder sb = new StringBuilder("Normal extension gibbs report for trait " + + modelExtensionProvider.getTipTraitName() + ":\n"); + sb.append("\ttaxon order:"); + for (int i = 0; i < taxonCount; i++) { + sb.append(" " + tree.getTaxonId(i)); + } + sb.append("\n"); + sb.append("\ttree trait values:\n"); + sb.append(matrix.toString(2)); + sb.append("\n"); + sb.append("\ttransformed trait values:\n"); + sb.append(matrixTransformed.toString(2)); + sb.append("\n\n"); + return sb.toString(); + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + + private static final String TRAIT_PROVIDER_TEST = "treeTraitReporter"; + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + TreeDataLikelihood dataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class); + ModelExtensionProvider.NormalExtensionProvider extensionProvider = + (ModelExtensionProvider.NormalExtensionProvider) + xo.getChild(ModelExtensionProvider.NormalExtensionProvider.class); + return new TreeTraitProviderTest(dataLikelihood, extensionProvider); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(TreeDataLikelihood.class), + new ElementRule(ModelExtensionProvider.NormalExtensionProvider.class) + }; + } + + @Override + public String getParserDescription() { + return "Calculates the average tree traits (and transformed tree traits)"; + } + + @Override + public Class getReturnType() { + return TreeTraitProviderTest.class; + } + + @Override + public String getParserName() { + return TRAIT_PROVIDER_TEST; + } + }; +} From 31ecaf657a4bc2ed4fc1012a05cc45fe724e8183 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Sep 2021 16:48:34 -0700 Subject: [PATCH 039/196] test xml for tree traits under JointPartialsProvider --- .../testJointNormalExtensionProvider.xml | 443 ++++++++++++++++++ 1 file changed, 443 insertions(+) create mode 100644 ci/TestXML/testJointNormalExtensionProvider.xml diff --git a/ci/TestXML/testJointNormalExtensionProvider.xml b/ci/TestXML/testJointNormalExtensionProvider.xml new file mode 100644 index 0000000000..a4a0cf173f --- /dev/null +++ b/ci/TestXML/testJointNormalExtensionProvider.xml @@ -0,0 +1,443 @@ + + + + + 0.009384774731272527 -2.2489794932795935 0.21352181534425757 -0.9271419643920304 + 0.8908861115024072 + + -0.6626058145275399 -1.1376289492948137 0.5516715218486571 -1.3542281106636904 + 0.3425038907927388 1.9765519199971684 + + -2.4815614340693783 1.1715540851082962 -2.39142491538794 -0.27989415293453157 + + + 0.23114357993799609 0.013952561858756735 -0.5715417064973886 -0.4990804301056488 + 0.15801526456903897 + + 4.612640226355059 -2.270688683027667 0.06636421634789397 -2.3625614660756753 + -0.8862427605874439 0.128866505745016 + + 0.5874473809484401 -1.063856568127724 2.337606463609135 -0.07230648327331456 + + + 1.4187251180998413 -0.8170425382025268 0.07892424466259318 -0.3143517014495449 + -0.5472602804707967 + + -0.2728961572270844 2.309402233426405 -0.23416506480699462 0.9735063577970848 + 0.7151577438330944 -0.2347586744790381 + + -1.38029611269419 1.2045730316000665 -2.2938521607202462 -0.509902784415168 + + + -0.28339647175256877 -0.9419386451678324 0.26650579966179655 -0.6755373526537636 + -0.274877664411792 + + 1.1287595908409713 -3.1116614928659736 -0.06491592308086552 1.164874512948666 + -1.5336645088386092 -0.20428682533763876 + + -0.7828258802650758 -0.2182636672376014 0.6182103263813437 -0.9187485631673218 + + + -0.023499282722710884 0.9994051565722422 0.9578990579896408 0.8112526922930889 + -2.6974732014859786 + + 3.143509232145515 -3.366997883122405 0.1885442249025096 1.6584275614560118 + -2.66345546086915 -0.590589928945624 + + -3.421283834293406 1.5709096082117868 -2.4315878382033373 -0.2740496928659121 + + + -0.36284138795968407 1.214692899789406 0.4272574497473184 0.042273981909662195 + 0.4670760468843862 + + 0.6358323758485376 1.0193852357944098 -0.06189221405150272 1.3515630490382535 + -0.84387472620449 -0.14804040300618612 + + 0.1747727600285431 -0.4182716110970156 -0.7131407512876877 0.972963830403999 + + + 0.3988882987788642 -3.5925427617821164 -0.44162891549217226 -0.028944636921210878 + -1.1717821176724494 + + -0.22158862440099492 3.8444687603027385 1.4126292387163153 -2.928646472046473 + 0.04954047312222701 -2.7504755938766206 + + 0.021643414180460174 -0.4321387164642869 1.2959001126066132 -0.42378024950560955 + + + 0.549188719359929 -0.11646287269704525 0.12678272760083498 2.3916555239378905 + 1.0747211613241887 + + 1.0743759617094972 1.9818670093948083 1.348126601620159 -3.314507305497285 + -0.2006444474345116 -0.9088558225552659 + + 0.2630370864114636 -0.8181174721922079 -0.7089288142035404 1.6537379221433017 + + + 1.533349801858807 -0.8042394873096373 -0.01143188542494187 0.27028387196740455 + -0.5971074429239394 + + 0.5863288191090735 0.42370418900515106 -0.26783621859956946 0.04133111427212976 + 0.35876892318386006 -0.1964572683557873 + + 2.1152692357575718 0.276819521004338 0.2564027788580229 -0.40434147564812495 + + + 0.8023272733249729 0.042684825645238955 0.39572262698333827 -0.9915802597590105 + -0.8267974999597728 + + 1.9537056424001604 0.5925477606958199 -0.0497984821438642 -0.9489121452722484 + -0.7291919934293225 -1.7306473188245368 + + 2.1802316612653865 -0.2655432401198189 -0.5372409702155931 -1.1318047301845096 + + + + (taxon9:0.008348686686952765,((((taxon6:0.016572492255427156,taxon8:0.1693267472384557):0.11628259080566952,(taxon2:0.25890343887261164,((taxon3:0.0208741196691072,taxon10:0.017517882286128703):0.01551427131818077,taxon7:0.4845369967448872):0.12609158351609934):0.17835201769192324):0.11681134995255121,taxon4:0.30670476579014777):0.006431542969055358,(taxon1:0.04588745970630809,taxon5:0.15083832568450478):0.09255924641281718):0.08777650912548374); + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check 1 + + + + + + -0.0112659 -0.0421795 + 0.0423912 0.687461 + -0.0149849 1.01539 + 0.149458 0.229179 + -0.217187 -0.525185 + -0.220403 -0.65122 + -0.423775 -0.904478 + 0.00605764 -0.270155 + -0.036566 0.0716805 + 0.0792544 0.46777 + + + + + + Check 1 + + + + + + 0.021454 -0.0233714 -0.00979186 -0.0147636 0.00250097 + -0.0946056 0.301863 0.256002 0.320374 -0.0921609 + 0.000410885 0.40242 0.431092 0.517016 -0.164365 + -0.275944 0.176384 -0.0070383 0.0303856 0.0185279 + 0.406037 -0.334055 -0.0694146 -0.140392 0.00314756 + 0.415154 -0.386846 -0.119458 -0.2017 0.0217018 + 0.789103 -0.603151 -0.085595 -0.213701 -0.0126798 + -0.00384924 -0.105909 -0.11611 -0.138727 0.0444845 + 0.0641586 0.00853202 0.0546726 0.0565494 -0.0245263 + -0.155416 0.233615 0.139778 0.189526 -0.0443623 + + + + + + + + + + + + + + + Check 2 + + + + + + -0.0135582 -0.0214845 0.00397748 + 0.0757335 -0.560984 0.0827329 + 0.253474 -0.720004 0.481433 + -0.204286 -1.51033 -0.578873 + 0.436107 -0.707736 0.203309 + 0.302863 -0.897516 0.114378 + 0.923331 -0.996767 0.999364 + -0.484098 -0.499524 -0.538314 + -0.370085 -0.0871457 -0.180858 + -0.764123 -0.666975 -0.595591 + + + + + + Check 2 + + + + + + 0.0337014 0.00740647 0.0101907 -0.0137624 -0.026762 -0.00236374 + 0.69037 0.458552 0.22018 -0.396597 -0.290869 -0.430955 + 0.462661 1.65346 0.739242 -1.04027 0.0845515 -1.06344 + 2.79935 -1.09548 -0.379289 0.0785096 -1.83474 -0.00261458 + 0.613637 1.0812 0.384381 -0.686247 0.0498013 -0.949892 + 1.03273 0.822841 0.321661 -0.638018 -0.300214 -0.8367 + 0.0537815 3.55528 1.40849 -1.97783 0.952206 -2.31857 + 1.44965 -1.54838 -0.535011 0.558691 -1.29912 0.717755 + 0.4724 -0.701772 -0.187652 0.258751 -0.579387 0.469799 + 1.87889 -1.84317 -0.556674 0.597472 -1.75148 0.944242 + + + + + + + + + + + + + + + Check 3 + + + + + + -0.0518822 0.013655 -0.0107688 -0.078494 + 0.286926 -0.446739 0.0625993 0.979948 + 0.554231 -0.720789 0.328392 1.61218 + -0.439069 -0.62909 0.483508 -0.118281 + -0.476466 -0.0590634 -0.270876 -0.421286 + -0.722599 -0.054492 -0.290839 -0.760727 + -0.458698 -0.328729 0.403096 -0.431089 + -0.782394 0.0796499 -0.264029 -0.968578 + -0.341454 0.202061 -0.53683 -0.303807 + -0.557268 -0.0127222 -0.439181 -0.31635 + + + + + + Check 3 + + + + + + -0.0518822 0.013655 -0.0107688 -0.078494 + 0.286926 -0.446739 0.0625993 0.979948 + 0.554231 -0.720789 0.328392 1.61218 + -0.439069 -0.62909 0.483508 -0.118281 + -0.476466 -0.0590634 -0.270876 -0.421286 + -0.722599 -0.054492 -0.290839 -0.760727 + -0.458698 -0.328729 0.403096 -0.431089 + -0.782394 0.0796499 -0.264029 -0.968578 + -0.341454 0.202061 -0.53683 -0.303807 + -0.557268 -0.0127222 -0.439181 -0.31635 + + + From c328824883815cf4b8df849d9ce578e8e7181030 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 20 Sep 2021 10:09:48 -0700 Subject: [PATCH 040/196] testing new gradient in testxml --- ci/TestXML/testDecomposedPrecisionGradient.xml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ci/TestXML/testDecomposedPrecisionGradient.xml b/ci/TestXML/testDecomposedPrecisionGradient.xml index 00b3b28359..d118550942 100644 --- a/ci/TestXML/testDecomposedPrecisionGradient.xml +++ b/ci/TestXML/testDecomposedPrecisionGradient.xml @@ -896,9 +896,10 @@ - + + - + @@ -923,7 +924,7 @@ - + @@ -940,7 +941,7 @@ - + From ed0598825d2d272308feb11bd24b4f02333ec325 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 20 Sep 2021 10:10:37 -0700 Subject: [PATCH 041/196] less strict numerical instability condition (still a ad hoc solution) --- .../operators/hmc/GeodesicHamiltonianMonteCarloOperator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 3044996cf6..4445865e82 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -513,7 +513,7 @@ public void updatePosition(double[] position, WrappedVector momentum, sse += diff * diff; } - if (sse > 1e-3) { + if (sse / position.length > 1e-2) { //TODO: actually figure out if I want this System.err.println("unstable"); //TODO: REMOVE throw new NumericInstabilityException(); } From 2f1a7b5db93060e5d96652a02bfd58517132feb0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 23 Sep 2021 09:44:51 -0700 Subject: [PATCH 042/196] need to check if matrix is PD in LKJCorrelationDistribution --- .../distributions/LKJCorrelationDistribution.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dr/math/distributions/LKJCorrelationDistribution.java b/src/dr/math/distributions/LKJCorrelationDistribution.java index 6009a260a4..ac40ff5431 100644 --- a/src/dr/math/distributions/LKJCorrelationDistribution.java +++ b/src/dr/math/distributions/LKJCorrelationDistribution.java @@ -51,12 +51,12 @@ public double logPdf(double[] x) { assert (x.length == upperTriangularSize(dim)); - if (shape == 1.0) { // Uniform - return logNormalizationConstant; - } else { - SymmetricMatrix R = compoundCorrelationSymmetricMatrix(x, dim); - return logPdf(R); - } +// if (shape == 1.0) { // Uniform //even when it's uniform, you still want to return -inf if it's not pos. def. +// return logNormalizationConstant; +// } else { + SymmetricMatrix R = compoundCorrelationSymmetricMatrix(x, dim); + return logPdf(R); +// } } private double logPdf(Matrix R) { From 555d7c975cbf679968d023155908081baec2b126 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 23 Sep 2021 11:48:21 -0700 Subject: [PATCH 043/196] statistic that computes the determinant of a matrix --- .../app/beast/development_parsers.properties | 3 +- .../inference/model/DeterminantStatistic.java | 96 +++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 src/dr/inference/model/DeterminantStatistic.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 6847796ad7..72a867231e 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -309,4 +309,5 @@ dr.inference.model.MaskFromTree # Structural Equation Modeling dr.util.MatrixInnerProductTransform -dr.evomodel.treedatalikelihood.continuous.TreeTraitProviderTest \ No newline at end of file +dr.evomodel.treedatalikelihood.continuous.TreeTraitProviderTest +dr.inference.model.DeterminantStatistic \ No newline at end of file diff --git a/src/dr/inference/model/DeterminantStatistic.java b/src/dr/inference/model/DeterminantStatistic.java new file mode 100644 index 0000000000..04f931c066 --- /dev/null +++ b/src/dr/inference/model/DeterminantStatistic.java @@ -0,0 +1,96 @@ +package dr.inference.model; + +import dr.xml.*; +import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; + +/** + * @author Gabriel Hassler + * @author Marc A. Suchard + */ + +public class DeterminantStatistic extends Statistic.Abstract implements VariableListener { + + private final MatrixParameterInterface matrix; + private final int matrixDim; + private boolean detKnown = false; + private double det; + + public DeterminantStatistic(String name, MatrixParameterInterface matrix) { + super(name); + + this.matrix = matrix; + this.matrixDim = matrix.getRowDimension(); + matrix.addParameterListener(this); + + } + + @Override + public int getDimension() { + return 1; + } + + @Override + public double getStatisticValue(int dim) { + if (!detKnown) { + double[] values = matrix.getParameterValues(); + DenseMatrix64F M = DenseMatrix64F.wrap(matrixDim, matrixDim, values); + det = CommonOps.det(M); + detKnown = true; + } + + return det; + } + + @Override + public void variableChangedEvent(Variable variable, int index, Variable.ChangeType type) { + detKnown = false; + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + + private static final String DETERMINANT_STATISTIC = "determinant"; + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + MatrixParameterInterface matrix = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class); + if (matrix.getColumnDimension() != matrix.getRowDimension()) { + throw new XMLParseException("can only calculate determinant for square matrices"); + } + final String name; + if (xo.hasId()) { + name = xo.getId(); + } else if (matrix.getId() != null) { + name = "determinant." + matrix.getId(); + } else { + name = "determinant"; + } + return new DeterminantStatistic(name, matrix); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(MatrixParameterInterface.class) + }; + } + + @Override + public String getParserDescription() { + return "Statistic that computes the determinant of a matrix"; + } + + @Override + public Class getReturnType() { + return DeterminantStatistic.class; + } + + @Override + public String getParserName() { + return DETERMINANT_STATISTIC; + } + }; + + +} From 070a3c7a8e860da069d1c3439017f40eb6ee3595 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 4 Oct 2021 11:45:45 -0700 Subject: [PATCH 044/196] removing character that messeswith ant compilation --- .../distribution/MultivariateDistributionLikelihood.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/inference/distribution/MultivariateDistributionLikelihood.java b/src/dr/inference/distribution/MultivariateDistributionLikelihood.java index 7b1e81da85..200a4e9b3d 100644 --- a/src/dr/inference/distribution/MultivariateDistributionLikelihood.java +++ b/src/dr/inference/distribution/MultivariateDistributionLikelihood.java @@ -784,7 +784,7 @@ public XMLSyntaxRule[] getSyntaxRules() { @Override public String getParserDescription() { - return "Calculates p(X) ∝ det(X)^a"; + return "Calculates p(X) = c * det(X)^a (currently omits normalization constant c)"; } @Override From 4163a624a83ad2dae72f0c12edb4c557f281f947 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Dec 2021 15:54:33 -0800 Subject: [PATCH 045/196] logger for diagonals of matrix --- .../app/beast/development_parsers.properties | 3 +- .../inference/model/MatrixDiagonalLogger.java | 65 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 src/dr/inference/model/MatrixDiagonalLogger.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 72a867231e..19e8fd295d 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -310,4 +310,5 @@ dr.inference.model.MaskFromTree # Structural Equation Modeling dr.util.MatrixInnerProductTransform dr.evomodel.treedatalikelihood.continuous.TreeTraitProviderTest -dr.inference.model.DeterminantStatistic \ No newline at end of file +dr.inference.model.DeterminantStatistic +dr.inference.model.MatrixDiagonalLogger \ No newline at end of file diff --git a/src/dr/inference/model/MatrixDiagonalLogger.java b/src/dr/inference/model/MatrixDiagonalLogger.java new file mode 100644 index 0000000000..aad6ba53d5 --- /dev/null +++ b/src/dr/inference/model/MatrixDiagonalLogger.java @@ -0,0 +1,65 @@ +package dr.inference.model; + +import dr.xml.*; + +public class MatrixDiagonalLogger extends Statistic.Abstract { + private final MatrixParameterInterface matrix; + + public MatrixDiagonalLogger(MatrixParameterInterface matrix) { + this.matrix = matrix; + } + + + @Override + public int getDimension() { + return matrix.getColumnDimension(); + } + + @Override + public double getStatisticValue(int dim) { + return matrix.getParameterValue(dim, dim); + } + + @Override + public String getDimensionName(int dim) { + return getStatisticName() + "." + matrix.getDimensionName(dim); + } + + public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { + + private static final String MATRIX_DIAGONAL = "matrixDiagonals"; + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + MatrixParameterInterface matrix = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class); + if (matrix.getColumnDimension() != matrix.getRowDimension()) { + throw new XMLParseException("Only square matrices can be converted to correlation matrices"); + } + + return new MatrixDiagonalLogger(matrix); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(MatrixParameterInterface.class) + }; + } + + @Override + public String getParserDescription() { + return "This element returns a statistic that is the diagonals of the associated matrix."; + } + + @Override + public Class getReturnType() { + return MatrixDiagonalLogger.class; + } + + @Override + public String getParserName() { + return MATRIX_DIAGONAL; + } + }; +} From fa7167eb446d05af933af3927172b5f3940e4b6b Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Sun, 19 Dec 2021 09:14:30 -0800 Subject: [PATCH 046/196] tools for getting tree trait from TraitDataLikelihood --- .../ExtendedLatentLiabilityGibbsOperator.java | 18 +++++++++++++----- .../ContinuousTraitPartialsProvider.java | 7 +++++++ .../continuous/JointPartialsProvider.java | 14 ++++++++++++++ .../TreeTraitParserUtilities.java | 8 +++++--- 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java b/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java index 97673dd88b..68b0b7ec88 100644 --- a/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java +++ b/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java @@ -5,8 +5,10 @@ import dr.evomodel.continuous.OrderedLatentLiabilityLikelihood; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider; import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate; import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; +import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; import dr.inference.model.CompoundParameter; import dr.inference.model.Parameter; import dr.inference.operators.GibbsOperator; @@ -18,7 +20,7 @@ import java.util.ArrayList; import java.util.Arrays; -import static dr.evomodelxml.treelikelihood.TreeTraitParserUtilities.getTreeTraitFromDataLikelihood; +import static dr.evomodelxml.treelikelihood.TreeTraitParserUtilities.*; public class ExtendedLatentLiabilityGibbsOperator extends SimpleMCMCOperator implements GibbsOperator, Reportable { @@ -102,11 +104,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate(); - ModelExtensionProvider.NormalExtensionProvider dataModel = - (ModelExtensionProvider.NormalExtensionProvider) delegate.getDataModel(); + String traitName = xo.getAttribute(TRAIT_NAME, getTipTraitNameFromDataLikelihood(treeDataLikelihood)); + + TreeTrait treeTrait = treeDataLikelihood.getTreeTrait(traitName); + + + ContinuousTraitPartialsProvider superDataModel = delegate.getDataModel(); + ModelExtensionProvider.NormalExtensionProvider dataModel = + (ModelExtensionProvider.NormalExtensionProvider) superDataModel.getProviderForTrait(traitName); - TreeTrait treeTrait = getTreeTraitFromDataLikelihood(treeDataLikelihood); Tree tree = treeDataLikelihood.getTree(); ContinuousExtensionDelegate extensionDelegate = dataModel.getExtensionDelegate(delegate, treeTrait, tree); @@ -121,7 +128,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ new ElementRule(TreeDataLikelihood.class), - new ElementRule(OrderedLatentLiabilityLikelihood.class) + new ElementRule(OrderedLatentLiabilityLikelihood.class), + AttributeRule.newStringRule(TreeTraitParserUtilities.TRAIT_NAME, true) }; } diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java index 5064bd2c27..439a12a324 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java @@ -95,6 +95,13 @@ default WrappedNormalSufficientStatistics partitionNormalStatistics(WrappedNorma "a provider other than itself."); } + default ContinuousTraitPartialsProvider getProviderForTrait(String trait) { + if (trait.equals(getTipTraitName())) { + return this; + } + throw new RuntimeException("Partials provider does not have trait '" + trait + "'"); + } + static boolean[] indicesToIndicator(List indices, int n) { if (indices == null) { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index f91560c678..04779e3eaa 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -318,6 +318,20 @@ public WrappedNormalSufficientStatistics partitionNormalStatistics(WrappedNormal new WrappedMatrix.WrappedDenseMatrix(newVariance)); } + @Override + public ContinuousTraitPartialsProvider getProviderForTrait(String trait) { + if (trait.equals(getTipTraitName())) { + return this; + } + for (ContinuousTraitPartialsProvider submodel : providers) { + System.out.println(submodel.getTipTraitName()); + if (trait.equals(submodel.getTipTraitName())) { + return submodel; + } + } + throw new RuntimeException("Partials provider does not have trait '" + trait + "', nor did any of its sub-models"); + } + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { private static final String PARSER_NAME = "jointPartialsProvider"; diff --git a/src/dr/evomodelxml/treelikelihood/TreeTraitParserUtilities.java b/src/dr/evomodelxml/treelikelihood/TreeTraitParserUtilities.java index 3ad0351bec..5c6e6fb552 100644 --- a/src/dr/evomodelxml/treelikelihood/TreeTraitParserUtilities.java +++ b/src/dr/evomodelxml/treelikelihood/TreeTraitParserUtilities.java @@ -549,13 +549,15 @@ private Map drawRandomSample(int total, int length) { } public static TreeTrait getTreeTraitFromDataLikelihood(TreeDataLikelihood dataLikelihood) { + return dataLikelihood.getTreeTrait(getTipTraitNameFromDataLikelihood(dataLikelihood)); + } + + public static String getTipTraitNameFromDataLikelihood(TreeDataLikelihood dataLikelihood) { ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate) dataLikelihood.getDataLikelihoodDelegate(); ContinuousTraitPartialsProvider dataModel = delegate.getDataModel(); String traitName = dataModel.getTipTraitName(); -// String realizedTraitName = getTipTraitName(traitName); - - return dataLikelihood.getTreeTrait(traitName); + return traitName; } } From f9dbecbcbafed790c7e03a45b2bb587da4aff594 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Sun, 19 Dec 2021 09:16:16 -0800 Subject: [PATCH 047/196] geodesic hmc bug fix? (i should have committed this a long time ago) --- .../operators/hmc/GeodesicHamiltonianMonteCarloOperator.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java index 4445865e82..2d42580a99 100644 --- a/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/GeodesicHamiltonianMonteCarloOperator.java @@ -197,10 +197,11 @@ private int findMatchingArray(ArrayList> listOfLists, ArrayLi } private int findSubArray(ArrayList> listOfLists, ArrayList list, ArrayList remainingList) { //assumes both are sorted - remainingList.clear(); int nLists = listOfLists.size(); for (int i = 0; i < nLists; i++) { ArrayList subList = listOfLists.get(i); + remainingList.clear(); + if (list.size() <= subList.size()) { int currentInd = 0; for (int j = 0; j < subList.size(); j++) { From 6b3092afcd93c7ab9fa8c9fe412759d06756033d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Sun, 19 Dec 2021 09:18:40 -0800 Subject: [PATCH 048/196] new accept condition for rejection operator --- .../operators/rejection/AcceptCondition.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/dr/inference/operators/rejection/AcceptCondition.java b/src/dr/inference/operators/rejection/AcceptCondition.java index 7b2b6cf704..8909b9d456 100644 --- a/src/dr/inference/operators/rejection/AcceptCondition.java +++ b/src/dr/inference/operators/rejection/AcceptCondition.java @@ -1,5 +1,9 @@ package dr.inference.operators.rejection; +import dr.math.matrixAlgebra.CholeskyDecomposition; +import dr.math.matrixAlgebra.IllegalDimension; +import dr.math.matrixAlgebra.Matrix; + public interface AcceptCondition { boolean satisfiesCondition(double[] values); @@ -41,6 +45,21 @@ public boolean satisfiesCondition(double[] values) { } return true; } + }, + + PositiveDefinite("positiveDefinite") { + @Override + public boolean satisfiesCondition(double[] values) { + int n = (int) Math.sqrt(values.length); + Matrix M = new Matrix(values, n, n); + CholeskyDecomposition chol; + try { + chol = new CholeskyDecomposition(M); + } catch (IllegalDimension illegalDimension) { + throw new RuntimeException("Matrix must be square"); + } + return chol.isSPD(); + } }; private final String name; From be25ee5bec07dcaf6b4ce6576b3730a0ddadae2f Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Sun, 19 Dec 2021 11:47:08 -0800 Subject: [PATCH 049/196] more general way to specify which columns of the loadings get sampled --- .../GeneralizedSampleConstraints.java | 123 +++++++++++++ .../LoadingsSamplerConstraints.java | 161 ++++++++++++++++++ .../NewLoadingsGibbsOperator.java | 146 +--------------- .../LoadingsGibbsOperatorParser.java | 30 ++-- 4 files changed, 309 insertions(+), 151 deletions(-) create mode 100644 src/dr/inference/operators/factorAnalysis/GeneralizedSampleConstraints.java create mode 100644 src/dr/inference/operators/factorAnalysis/LoadingsSamplerConstraints.java diff --git a/src/dr/inference/operators/factorAnalysis/GeneralizedSampleConstraints.java b/src/dr/inference/operators/factorAnalysis/GeneralizedSampleConstraints.java new file mode 100644 index 0000000000..1633c77bc6 --- /dev/null +++ b/src/dr/inference/operators/factorAnalysis/GeneralizedSampleConstraints.java @@ -0,0 +1,123 @@ +package dr.inference.operators.factorAnalysis; + +import dr.xml.*; + +import java.util.ArrayList; + + +/** + * @author Gabriel Hassler + * @author Marc A. Suchard + */ + +public class GeneralizedSampleConstraints implements LoadingsSamplerConstraints { + private final ArrayList columnIndices; + private final ArrayList uniqueLengths; + private final int[] arrayIndices; + + + GeneralizedSampleConstraints(ArrayList columnIndices) { + this.columnIndices = columnIndices; + + int nCols = columnIndices.size(); + this.uniqueLengths = new ArrayList<>(); + this.arrayIndices = new int[nCols]; + for (int i = 0; i < nCols; i++) { + int dim = columnIndices.get(i).length; + int arrayIndex = -1; + for (int j = 0; j < uniqueLengths.size(); j++) { + if (uniqueLengths.get(j) == dim) { + arrayIndex = j; + break; + } + } + + if (arrayIndex == -1) { + arrayIndex = uniqueLengths.size(); + uniqueLengths.add(dim); + } + + arrayIndices[i] = arrayIndex; + } + } + + + @Override + public int getColumnDim(int colIndex, int nRows) { + return columnIndices.get(colIndex).length; //TODO: return int[] of actual indices, not just assume 0:(n- 1) + } + + @Override + public int getArrayIndex(int colIndex, int nRows) { + return arrayIndices[colIndex]; + } + + @Override + public void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, ArrayList meanArray, int nRows) { + for (int i = 0; i < uniqueLengths.size(); i++) { + int dim = uniqueLengths.get(i); + precisionArray.add(new double[dim][dim]); + midMeanArray.add(new double[dim]); + meanArray.add(new double[dim]); + } + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + + private static final String SAMPLE_COLUMNS = "sampleColumns"; + private static final String INDICES = "indices"; + private static final String TRAITS = "traits"; + private static final String ROWS = "rows"; + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + ArrayList columnIndices = new ArrayList<>(); + int nextTrait = 1; + for (int i = 0; i < xo.getChildCount(); i++) { + XMLObject cxo = (XMLObject) xo.getChild(i); + int[] traits = cxo.getIntegerArrayAttribute(TRAITS); + int[] rows = cxo.getIntegerArrayAttribute(ROWS); + + for (int trait : traits) { + if (trait != nextTrait) { + throw new XMLParseException("Currently only implemented for sequential '" + TRAITS + + "' values."); + } + + for (int j = 0; j < rows.length; j++) { + rows[j] -= 1; + } + columnIndices.add(rows); + nextTrait++; + } + + } + return new GeneralizedSampleConstraints(columnIndices); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(INDICES, new XMLSyntaxRule[]{ + AttributeRule.newIntegerArrayRule(TRAITS, false), + AttributeRule.newIntegerArrayRule(ROWS, false) + }, 1, Integer.MAX_VALUE) + }; + } + + @Override + public String getParserDescription() { + return "Sample from only certain elements of the loadings matrix"; + } + + @Override + public Class getReturnType() { + return LoadingsSamplerConstraints.class; + } + + @Override + public String getParserName() { + return SAMPLE_COLUMNS; + } + }; +} diff --git a/src/dr/inference/operators/factorAnalysis/LoadingsSamplerConstraints.java b/src/dr/inference/operators/factorAnalysis/LoadingsSamplerConstraints.java new file mode 100644 index 0000000000..17cf84ced7 --- /dev/null +++ b/src/dr/inference/operators/factorAnalysis/LoadingsSamplerConstraints.java @@ -0,0 +1,161 @@ +package dr.inference.operators.factorAnalysis; + + +import java.util.ArrayList; + +/** + * @author Gabriel Hassler + * @author Marc A. Suchard + */ + +public interface LoadingsSamplerConstraints { + + int getColumnDim(int colIndex, int nRows); + + int getArrayIndex(int colIndex, int nRows); + + void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, + ArrayList meanArray, int nRows); + + + enum ColumnDimProvider implements LoadingsSamplerConstraints { + + NONE("none") { + @Override + public int getColumnDim(int colIndex, int nRows) { + return nRows; + } + + @Override + public int getArrayIndex(int colIndex, int nRows) { + return 0; + } + + @Override + public void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, + ArrayList meanArray, int nRows) { + + precisionArray.add(new double[nRows][nRows]); + midMeanArray.add(new double[nRows]); + meanArray.add(new double[nRows]); + + } + }, + + UPPER_TRIANGULAR("upperTriangular") { + @Override + public int getColumnDim(int colIndex, int nRows) { + return Math.min(colIndex + 1, nRows); + } + + @Override + public int getArrayIndex(int colIndex, int nRows) { + return Math.min(colIndex, nRows - 1); + } + + @Override + public void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, + ArrayList meanArray, int nRows) { + + for (int i = 1; i <= nRows; i++) { + precisionArray.add(new double[i][i]); + midMeanArray.add(new double[i]); + meanArray.add(new double[i]); + } + + } + }, + + FIRST_ROW("firstRow") { + @Override + public int getColumnDim(int colIndex, int nRows) { + return 1; + } + + @Override + public int getArrayIndex(int colIndex, int nRows) { + return 0; + } + + @Override + public void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, + ArrayList meanArray, int nRows) { + + precisionArray.add(new double[1][1]); + midMeanArray.add(new double[1]); + meanArray.add(new double[1]); + + } + }, + + HYBRID("hybrid") { + @Override + public int getColumnDim(int colIndex, int nRows) { + + if (colIndex == 0) { + return 1; + } + return nRows; + } + + @Override + public int getArrayIndex(int colIndex, int nRows) { + if (colIndex == 0) { + return 0; + } + return 1; + } + + @Override + public void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, ArrayList meanArray, int nRows) { + + // first column + precisionArray.add(new double[1][1]); + midMeanArray.add(new double[1]); + meanArray.add(new double[1]); + + + // remaining columns + precisionArray.add(new double[nRows][nRows]); + midMeanArray.add(new double[nRows]); + meanArray.add(new double[nRows]); + + } + }; + + + private static int[] convertToIndices(int i) { + int[] indices = new int[i]; + for (int j = 0; j < i; j++) { + indices[j] = j; + } + return indices; + } + + public int[] getColumnIndices(int colIndex, int nRows) { + return convertToIndices(getColumnDim(colIndex, nRows)); + } + + + private String name; + + ColumnDimProvider(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public static ColumnDimProvider parse(String name) { + name = name.toLowerCase(); + for (ColumnDimProvider dimProvider : ColumnDimProvider.values()) { + if (name.compareTo(dimProvider.getName().toLowerCase()) == 0) { + return dimProvider; + } + } + throw new IllegalArgumentException("Unknown dimension provider type"); + } + + } +} diff --git a/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java b/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java index 6f2291d3d5..e440170df9 100644 --- a/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java +++ b/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java @@ -65,14 +65,14 @@ public class NewLoadingsGibbsOperator extends SimpleMCMCOperator implements Gibb private final FactorAnalysisOperatorAdaptor adaptor; private final ConstrainedSampler constrainedSampler; - private final ColumnDimProvider columnDimProvider; + private final LoadingsSamplerConstraints columnDimProvider; public NewLoadingsGibbsOperator(FactorAnalysisStatisticsProvider statisticsProvider, NormalStatisticsProvider prior, double weight, boolean randomScan, DistributionLikelihood workingPrior, boolean multiThreaded, int numThreads, ConstrainedSampler constrainedSampler, - ColumnDimProvider columnDimProvider) { + LoadingsSamplerConstraints columnDimProvider) { setWeight(weight); @@ -209,8 +209,10 @@ private void drawI(int i, double[][] precision, double[] midMean, double[] mean) } private void drawI(int i) { - int arrayInd = columnDimProvider.getArrayIndex(i, adaptor.getNumberOfFactors()); - drawI(i, precisionArray.get(arrayInd), meanMidArray.get(arrayInd), meanArray.get(arrayInd)); + if (columnDimProvider.getColumnDim(i, adaptor.getNumberOfFactors()) > 0) { + int arrayInd = columnDimProvider.getArrayIndex(i, adaptor.getNumberOfFactors()); + drawI(i, precisionArray.get(arrayInd), meanMidArray.get(arrayInd), meanArray.get(arrayInd)); + } } // @Override @@ -370,142 +372,6 @@ public static ConstrainedSampler parse(String name) { abstract void applyConstraint(FactorAnalysisOperatorAdaptor adaptor); } - public enum ColumnDimProvider {//TODO: don't hard code constraints, make more generalizable - - - NONE("none") { - @Override - int getColumnDim(int colIndex, int nRows) { - return nRows; - } - - @Override - int getArrayIndex(int colIndex, int nRows) { - return 0; - } - - @Override - void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, - ArrayList meanArray, int nRows) { - - precisionArray.add(new double[nRows][nRows]); - midMeanArray.add(new double[nRows]); - meanArray.add(new double[nRows]); - - } - }, - - UPPER_TRIANGULAR("upperTriangular") { - @Override - int getColumnDim(int colIndex, int nRows) { - return Math.min(colIndex + 1, nRows); - } - - @Override - int getArrayIndex(int colIndex, int nRows) { - return Math.min(colIndex, nRows - 1); - } - - @Override - void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, - ArrayList meanArray, int nRows) { - - for (int i = 1; i <= nRows; i++) { - precisionArray.add(new double[i][i]); - midMeanArray.add(new double[i]); - meanArray.add(new double[i]); - } - - } - }, - - FIRST_ROW("firstRow") { - @Override - int getColumnDim(int colIndex, int nRows) { - return 1; - } - - @Override - int getArrayIndex(int colIndex, int nRows) { - return 0; - } - - @Override - void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, - ArrayList meanArray, int nRows) { - - precisionArray.add(new double[1][1]); - midMeanArray.add(new double[1]); - meanArray.add(new double[1]); - - } - }, - - HYBRID("hybrid") { - @Override - int getColumnDim(int colIndex, int nRows) { - - if (colIndex == 0) { - return 1; - } - return nRows; - } - - @Override - int getArrayIndex(int colIndex, int nRows) { - if (colIndex == 0) { - return 0; - } - return 1; - } - - @Override - void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, ArrayList meanArray, int nRows) { - - // first column - precisionArray.add(new double[1][1]); - midMeanArray.add(new double[1]); - meanArray.add(new double[1]); - - - // remaining columns - precisionArray.add(new double[nRows][nRows]); - midMeanArray.add(new double[nRows]); - meanArray.add(new double[nRows]); - - } - }; - - abstract int getColumnDim(int colIndex, int nRows); - - abstract int getArrayIndex(int colIndex, int nRows); - - abstract void allocateStorage(ArrayList precisionArray, ArrayList midMeanArray, - ArrayList meanArray, int nRows); - - - private String name; - - ColumnDimProvider(String name) { - this.name = name; - } - - public String getName() { - return name; - } - - public static ColumnDimProvider parse(String name) { - name = name.toLowerCase(); - for (ColumnDimProvider dimProvider : ColumnDimProvider.values()) { - if (name.compareTo(dimProvider.getName().toLowerCase()) == 0) { - return dimProvider; - } - } - throw new IllegalArgumentException("Unknown dimension provider type"); - } - - } - @Override public String getReport() { int repeats = 20000; diff --git a/src/dr/inferencexml/operators/factorAnalysis/LoadingsGibbsOperatorParser.java b/src/dr/inferencexml/operators/factorAnalysis/LoadingsGibbsOperatorParser.java index a633a73722..f7ebf999a3 100644 --- a/src/dr/inferencexml/operators/factorAnalysis/LoadingsGibbsOperatorParser.java +++ b/src/dr/inferencexml/operators/factorAnalysis/LoadingsGibbsOperatorParser.java @@ -25,19 +25,13 @@ package dr.inferencexml.operators.factorAnalysis; -import dr.evomodel.treedatalikelihood.TreeDataLikelihood; -import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.MomentDistributionModel; -import dr.inference.distribution.NormalDistributionModel; import dr.inference.distribution.NormalStatisticsProvider; import dr.inference.model.LatentFactorModel; import dr.inference.model.MatrixParameterInterface; -import dr.inference.model.Parameter; import dr.inference.operators.factorAnalysis.*; import dr.math.distributions.Distribution; -import dr.math.distributions.NormalDistribution; -import dr.util.Attribute; import dr.xml.*; /** @@ -117,13 +111,26 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { xo.getAttribute(CONSTRAINT, NewLoadingsGibbsOperator.ConstrainedSampler.NONE.getName()) ); - NewLoadingsGibbsOperator.ColumnDimProvider dimProvider = - NewLoadingsGibbsOperator.ColumnDimProvider.parse(xo.getAttribute(SPARSITY_CONSTRAINT, - NewLoadingsGibbsOperator.ColumnDimProvider.UPPER_TRIANGULAR.getName()) - ); + LoadingsSamplerConstraints sparsityConstraints = + (GeneralizedSampleConstraints) + xo.getChild(GeneralizedSampleConstraints.class); + + if (sparsityConstraints != null && xo.hasAttribute(SPARSITY_CONSTRAINT)) { + throw new XMLParseException("Cannot provide both a '" + SPARSITY_CONSTRAINT + "' attribute and '" + + GeneralizedSampleConstraints.PARSER.getParserName() + + "' element."); + } + + if (sparsityConstraints == null) { + sparsityConstraints = + LoadingsSamplerConstraints.ColumnDimProvider.parse(xo.getAttribute(SPARSITY_CONSTRAINT, + LoadingsSamplerConstraints.ColumnDimProvider.UPPER_TRIANGULAR.getName()) + ); + } + return new NewLoadingsGibbsOperator(statisticsProvider, prior, weight, randomScan, WorkingPrior, - multiThreaded, numThreads, sampler, dimProvider); + multiThreaded, numThreads, sampler, sparsityConstraints); } else { // return new LoadingsGibbsOperator(LFM, prior, weight, randomScan, WorkingPrior, multiThreaded, numThreads); return null; @@ -158,6 +165,7 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(WORKING_PRIOR, new XMLSyntaxRule[]{ new ElementRule(DistributionLikelihood.class) }, true), + new ElementRule(GeneralizedSampleConstraints.class, true) }; @Override From a852fb128b9d7989a3da857d317e73bc9d858275 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Sun, 19 Dec 2021 11:56:01 -0800 Subject: [PATCH 050/196] need to find the new parser --- src/dr/app/beast/development_parsers.properties | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 19e8fd295d..4f97a35b97 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -311,4 +311,5 @@ dr.inference.model.MaskFromTree dr.util.MatrixInnerProductTransform dr.evomodel.treedatalikelihood.continuous.TreeTraitProviderTest dr.inference.model.DeterminantStatistic -dr.inference.model.MatrixDiagonalLogger \ No newline at end of file +dr.inference.model.MatrixDiagonalLogger +dr.inference.operators.factorAnalysis.GeneralizedSampleConstraints \ No newline at end of file From 8a9b44ca61dda9277512dc2681b6a94e7011d320 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 26 Jan 2022 14:44:32 -0800 Subject: [PATCH 051/196] checking that all precision types are the same --- .../continuous/JointPartialsProvider.java | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index 04779e3eaa..b4617071a9 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -33,7 +33,7 @@ public class JointPartialsProvider extends AbstractModel implements ContinuousTr private final boolean defaultAllowSingular; private final Boolean computeDeterminant; // TODO: Maybe pass as argument? - private static final PrecisionType precisionType = PrecisionType.FULL; //TODO: base on child precisionTypes (make sure they're all the same) + private final PrecisionType precisionType; //TODO: base on child precisionTypes (make sure they're all the same) private String tipTraitName; @@ -41,9 +41,10 @@ public class JointPartialsProvider extends AbstractModel implements ContinuousTr private static final Boolean DEBUG = false; - public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] providers) { + public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] providers, PrecisionType precisionType) { super(name); this.providers = providers; + this.precisionType = precisionType; int traitDim = 0; int dataDim = 0; @@ -356,7 +357,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } - return new JointPartialsProvider(PARSER_NAME, providers); + + PrecisionType precisionType = providers[0].getPrecisionType(); + for (int i = 1; i < providers.length; i++) { + if (providers[i].getPrecisionType() != precisionType) { + throw new XMLParseException("all partials providers must have the same precision type. " + + "Provider for model " + providers[0].getModelName() + " has precision type '" + precisionType + + "', while provider for model " + providers[i].getModelName() + " has precision type '" + + providers[i].getPrecisionType() + "'."); + } + } + + + return new JointPartialsProvider(PARSER_NAME, providers, precisionType); } @Override From d77b685bf181f72b6a7a12752040c77362342c01 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 26 Jan 2022 15:25:52 -0800 Subject: [PATCH 052/196] cleaning code --- .../continuous/RepeatedMeasuresTraitDataModel.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 96c6d5fd29..64a16a3723 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -306,16 +306,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { String traitName = returnValue.traitName; - //TODO diffusionModel was only used for the dimension. - // But this should be the same as the samplingPrecision dimension ? -// MultivariateDiffusionModel diffusionModel = (MultivariateDiffusionModel) -// xo.getChild(MultivariateDiffusionModel.class); - - //TODO: This was never used. -// final boolean[] missingIndicators = new boolean[returnValue.traitParameter.getDimension()]; -// for (int i : missingIndices) { -// missingIndicators[i] = true; -// } boolean scaleByTipHeight = xo.getAttribute(SCALE_BY_TIP_HEIGHT, false); From 84007da75a97d7caef021d3a3f45dc6709b03496 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 26 Jan 2022 17:21:06 -0800 Subject: [PATCH 053/196] optionally force full precision in RepeatedMeasures --- .../ExtendedLatentLiabilityGibbsOperator.java | 2 +- .../RepeatedMeasuresTraitDataModel.java | 26 ++++++++++++++----- ...eScaledRepeatedMeasuresTraitDataModel.java | 6 +++-- .../ContinuousDataLikelihoodParser.java | 2 +- .../continuous/RepeatedMeasureFactorTest.java | 7 +++-- .../hmc/DiffusionGradientTest.java | 6 +++-- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java b/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java index 68b0b7ec88..060a17d698 100644 --- a/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java +++ b/src/dr/evomodel/operators/ExtendedLatentLiabilityGibbsOperator.java @@ -39,7 +39,7 @@ public class ExtendedLatentLiabilityGibbsOperator extends SimpleMCMCOperator imp this.latentLiabilityLikelihood = latentLiabilityLikelihood; this.dataModel = dataModel; - if (!dataModel.diagonalVariance()) { + if (!dataModel.diagonalVariance() && dataModel.getDataDimension() > 1) { throw new RuntimeException(EXTENDED_LATENT_GIBBS + " is only valid for extended models with diagonal variance."); } diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 64a16a3723..34329cf7b2 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -33,6 +33,7 @@ import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate; import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; +import dr.evomodelxml.treedatalikelihood.ContinuousDataLikelihoodParser; import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; import dr.inference.model.CompoundParameter; import dr.inference.model.MatrixParameterInterface; @@ -77,10 +78,10 @@ public RepeatedMeasuresTraitDataModel(String name, boolean[] missindIndicators, boolean useMissingIndices, final int dimTrait, - MatrixParameterInterface samplingPrecision) { + MatrixParameterInterface samplingPrecision, + PrecisionType precisionType) { - super(name, parameter, missindIndicators, useMissingIndices, dimTrait, - dimTrait == 1 ? PrecisionType.SCALAR : PrecisionType.FULL); //TODO: Not sure this is the best way to do this. + super(name, parameter, missindIndicators, useMissingIndices, dimTrait, precisionType); this.traitName = name; this.samplingPrecisionParameter = samplingPrecision; @@ -309,6 +310,15 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean scaleByTipHeight = xo.getAttribute(SCALE_BY_TIP_HEIGHT, false); + int dimTrait = samplingPrecision.getColumnDimension(); + final PrecisionType precisionType; + if (xo.getAttribute(ContinuousDataLikelihoodParser.FORCE_FULL_PRECISION, false) || + dimTrait > 1) { + precisionType = PrecisionType.FULL; + } else { + precisionType = PrecisionType.SCALAR; + } + if (!scaleByTipHeight) { return new RepeatedMeasuresTraitDataModel( traitName, @@ -316,9 +326,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { missingIndicators, // missingIndicators, true, - samplingPrecision.getColumnDimension(), + dimTrait, // diffusionModel.getPrecisionParameter().getRowDimension(), - samplingPrecision + samplingPrecision, + precisionType ); } else { return new TreeScaledRepeatedMeasuresTraitDataModel( @@ -326,8 +337,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { traitParameter, missingIndicators, true, - samplingPrecision.getColumnDimension(), - samplingPrecision + dimTrait, + samplingPrecision, + precisionType ); } } diff --git a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java index 5bad29d359..5268e7fe4c 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java @@ -27,6 +27,7 @@ import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; +import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.inference.model.CompoundParameter; import dr.inference.model.MatrixParameterInterface; import org.ejml.data.DenseMatrix64F; @@ -48,8 +49,9 @@ public TreeScaledRepeatedMeasuresTraitDataModel(String name, boolean[] missingIndicators, boolean useMissingIndices, final int dimTrait, - MatrixParameterInterface samplingPrecision) { - super(name, parameter, missingIndicators, useMissingIndices, dimTrait, samplingPrecision); + MatrixParameterInterface samplingPrecision, + PrecisionType precisionType) { + super(name, parameter, missingIndicators, useMissingIndices, dimTrait, samplingPrecision, precisionType); } @Override diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index 7c6eef48c3..f7bf151bf7 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -64,7 +64,7 @@ public class ContinuousDataLikelihoodParser extends AbstractXMLObjectParser { private static final String RECONSTRUCT_TRAITS = "reconstructTraits"; private static final String FORCE_COMPLETELY_MISSING = "forceCompletelyMissing"; private static final String ALLOW_SINGULAR = "allowSingular"; - private static final String FORCE_FULL_PRECISION = "forceFullPrecision"; + public static final String FORCE_FULL_PRECISION = "forceFullPrecision"; private static final String FORCE_DRIFT = "forceDrift"; private static final String FORCE_OU = "forceOU"; diff --git a/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java b/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java index b3dedf36e6..76a5ad3422 100644 --- a/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java @@ -31,6 +31,7 @@ import dr.evomodel.continuous.MultivariateElasticModel; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.*; +import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.inference.model.*; import dr.math.MathUtils; import dr.math.matrixAlgebra.Vector; @@ -147,14 +148,16 @@ public void setUp() throws Exception { // new boolean[3], true, dimTrait, - samplingPrecisionParameter); + samplingPrecisionParameter, + PrecisionType.FULL); dataModelRepeatedMeasuresFull = new RepeatedMeasuresTraitDataModel("dataModelRepeatedMeasures", traitParameter, missingIndicators, true, dimTrait, - samplingPrecisionParameterFull); + samplingPrecisionParameterFull, + PrecisionType.FULL); } diff --git a/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java b/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java index 8fb04919bd..4a8ae8dbdd 100644 --- a/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java @@ -207,14 +207,16 @@ public void setUp() throws Exception { missingIndicators, true, dimTrait, - samplingPrecision); + samplingPrecision, + PrecisionType.FULL); dataModelRepeatedMeasuresInv = new RepeatedMeasuresTraitDataModel("dataModelRepeatedMeasuresInv", traitParameter, missingIndicators, true, dimTrait, - samplingPrecisionInv); + samplingPrecisionInv, + PrecisionType.FULL); } From a68ff416d00a278e1d775c1e8a424b7f4d2e819f Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 5 Apr 2022 17:17:34 -0700 Subject: [PATCH 054/196] parser for CorrelationToCholesky --- .../app/beast/development_parsers.properties | 1 + src/dr/util/CorrelationToCholesky.java | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index ade22d8349..243cddfd96 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -322,6 +322,7 @@ dr.evomodel.treedatalikelihood.continuous.TreeTraitProviderTest dr.inference.model.DeterminantStatistic dr.inference.model.MatrixDiagonalLogger dr.inference.operators.factorAnalysis.GeneralizedSampleConstraints +dr.util.CorrelationToCholesky # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser diff --git a/src/dr/util/CorrelationToCholesky.java b/src/dr/util/CorrelationToCholesky.java index 77d5c74cf8..1ca60668a5 100644 --- a/src/dr/util/CorrelationToCholesky.java +++ b/src/dr/util/CorrelationToCholesky.java @@ -29,6 +29,7 @@ import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.SymmetricMatrix; import dr.math.matrixAlgebra.WrappedMatrix; +import dr.xml.*; import static dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix; import static dr.math.matrixAlgebra.SymmetricMatrix.extractUpperTriangular; @@ -153,5 +154,41 @@ private int posStrict(int i, int j) { return i * (2 * dimVector - i - 1) / 2 + (j - i - 1); } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + private static final String DIMENSION = "dimension"; + private static final String CORRELATION_TO_CHOLESKY = "correlationToCholeskyTransform"; + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + int dim = xo.getIntegerAttribute(DIMENSION); + return new CorrelationToCholesky(dim); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + AttributeRule.newIntegerRule(DIMENSION) + }; + } + + @Override + public String getParserDescription() { + return "transforms the off-diagonal elements of a correlation to the off-diagonal elements of its" + + " Cholesky decomposition"; + } + + @Override + public Class getReturnType() { + return CorrelationToCholesky.class; + } + + @Override + public String getParserName() { + return CORRELATION_TO_CHOLESKY; + } + }; + } From 64d0ca1892a5def91e1eaf6ecd0452f521c31596 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 24 May 2022 12:02:57 -0700 Subject: [PATCH 055/196] bug fix --- src/dr/inference/model/TransformedParameter.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dr/inference/model/TransformedParameter.java b/src/dr/inference/model/TransformedParameter.java index 3847b442d6..588858c3ca 100644 --- a/src/dr/inference/model/TransformedParameter.java +++ b/src/dr/inference/model/TransformedParameter.java @@ -201,8 +201,8 @@ public void variableChangedEvent(Variable variable, int index, ChangeType type) public double diffLogJacobian(double[] oldValues, double[] newValues) { // Takes **untransformed** values if (inverse) { - return -transform.getLogJacobian(oldValues, 0, oldValues.length) - + transform.getLogJacobian(newValues, 0, newValues.length); + return -transform.getLogJacobian(transform(oldValues), 0, oldValues.length) + + transform.getLogJacobian(transform(newValues), 0, newValues.length); } else { return transform.getLogJacobian(oldValues, 0, oldValues.length) - transform.getLogJacobian(newValues, 0, newValues.length); From ac55f114da4c3ffcbe8d0bce705f6bbd08567830 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 24 May 2022 13:58:35 -0700 Subject: [PATCH 056/196] hacky way to check bounds in transformed random walk operator --- .../app/beast/development_parsers.properties | 1 + .../model/GeneralParameterBounds.java | 56 +++++++++++++++++++ ...ransformedParameterRandomWalkOperator.java | 18 +++++- .../CorrelationParameterBoundsParser.java | 38 +++++++++++++ ...rmedParameterRandomWalkOperatorParser.java | 7 ++- 5 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 src/dr/inference/model/GeneralParameterBounds.java create mode 100644 src/dr/inferencexml/model/CorrelationParameterBoundsParser.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 243cddfd96..88a399767a 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -323,6 +323,7 @@ dr.inference.model.DeterminantStatistic dr.inference.model.MatrixDiagonalLogger dr.inference.operators.factorAnalysis.GeneralizedSampleConstraints dr.util.CorrelationToCholesky +dr.inferencexml.model.CorrelationParameterBoundsParser # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser diff --git a/src/dr/inference/model/GeneralParameterBounds.java b/src/dr/inference/model/GeneralParameterBounds.java new file mode 100644 index 0000000000..1e4af7a920 --- /dev/null +++ b/src/dr/inference/model/GeneralParameterBounds.java @@ -0,0 +1,56 @@ +package dr.inference.model; + +import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; + +public interface GeneralParameterBounds { + + boolean satisfiesBounds(Parameter parameter); + + + class CorrelationParameterBounds implements GeneralParameterBounds { + + private final int dim; + + public CorrelationParameterBounds(int dim) { + this.dim = dim; + } + + + @Override + public boolean satisfiesBounds(Parameter parameter) { + + DenseMatrix64F C; + double[] c = parameter.getParameterValues(); + + if (c.length == dim * dim) { + C = DenseMatrix64F.wrap(dim, dim, parameter.getParameterValues()); + for (int i = 0; i < dim; i++) { + if (C.get(i, i) != 1.0) { + return false; + } + } + + } else if (c.length == dim * (dim - 1) / 2) { + int ind = 0; + C = new DenseMatrix64F(dim, dim); + for (int i = 0; i < dim; i++) { + C.set(i, i, 1.0); + for (int j = i + 1; j < dim; j++) { + C.set(i, j, c[ind]); + C.set(j, i, c[ind]); + ind++; + } + } + } else { + throw new RuntimeException("incompatible dimensions"); + } + + + double det = CommonOps.det(C); + return det >= 0; // already checked if diagonals were 1 + } + + } + +} diff --git a/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java b/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java index 9d4028b643..41c89e69f1 100644 --- a/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java +++ b/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java @@ -25,6 +25,7 @@ package dr.inference.operators; +import dr.inference.model.GeneralParameterBounds; import dr.inference.model.TransformedParameter; import dr.math.matrixAlgebra.Matrix; @@ -35,10 +36,14 @@ public class TransformedParameterRandomWalkOperator extends RandomWalkOperator { private static boolean DEBUG = false; + private static boolean checkValid = true; + + private final GeneralParameterBounds generalBounds; public TransformedParameterRandomWalkOperator(TransformedParameter parameter, double windowSize, BoundaryCondition bc, double weight, AdaptationMode mode) { super(parameter, windowSize, bc, weight, mode); + this.generalBounds = null; //TODO: implement if needed } public TransformedParameterRandomWalkOperator(TransformedParameter parameter, RandomWalkOperator randomWalkOperator) { @@ -47,14 +52,16 @@ public TransformedParameterRandomWalkOperator(TransformedParameter parameter, Ra randomWalkOperator.getBoundaryCondition(), randomWalkOperator.getWeight(), randomWalkOperator.getMode()); + this.generalBounds = null; //TODO: implement if needed } - public TransformedParameterRandomWalkOperator(RandomWalkOperator randomWalkOperator) { + public TransformedParameterRandomWalkOperator(RandomWalkOperator randomWalkOperator, GeneralParameterBounds bounds) { super((TransformedParameter) randomWalkOperator.getParameter(), randomWalkOperator.getWindowSize(), randomWalkOperator.getBoundaryCondition(), randomWalkOperator.getWeight(), randomWalkOperator.getMode()); + this.generalBounds = bounds; } @Override @@ -73,6 +80,15 @@ public double doOperation() { System.err.println("newValues: " + new Matrix(newValues, newValues.length, 1)); System.err.println("newValuesTrans: " + new Matrix(parameter.getParameterValues(), newValues.length, 1)); } + + if (checkValid) { // GH: below is sloppy, but best I could do without refactoring how Parameter handles bounds + if (generalBounds == null && !parameter.isWithinBounds()) { + return Double.NEGATIVE_INFINITY; + } else if (!generalBounds.satisfiesBounds(parameter)) { + return Double.NEGATIVE_INFINITY; + } + } + // Compute Jacobians ratio += ((TransformedParameter) parameter).diffLogJacobian(oldValues, newValues); if (DEBUG) { diff --git a/src/dr/inferencexml/model/CorrelationParameterBoundsParser.java b/src/dr/inferencexml/model/CorrelationParameterBoundsParser.java new file mode 100644 index 0000000000..64d2068d69 --- /dev/null +++ b/src/dr/inferencexml/model/CorrelationParameterBoundsParser.java @@ -0,0 +1,38 @@ +package dr.inferencexml.model; + +import dr.inference.model.GeneralParameterBounds; +import dr.xml.*; + +public class CorrelationParameterBoundsParser extends AbstractXMLObjectParser { + private static final String CORRELATION_BOUNDS = "correlationBounds"; + private static final String DIMENSION = "dimension"; + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + int dim = xo.getIntegerAttribute(DIMENSION); + return new GeneralParameterBounds.CorrelationParameterBounds(dim); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + AttributeRule.newIntegerRule(DIMENSION) + }; + } + + @Override + public String getParserDescription() { + return "Indicates whether a parameter is a valid correlation matrix or not"; + } + + @Override + public Class getReturnType() { + return GeneralParameterBounds.CorrelationParameterBounds.class; + } + + @Override + public String getParserName() { + return CORRELATION_BOUNDS; + } +} diff --git a/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java index f2d8a5f782..802e4f0e5f 100644 --- a/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java +++ b/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java @@ -25,6 +25,7 @@ package dr.inferencexml.operators; +import dr.inference.model.GeneralParameterBounds; import dr.inference.model.TransformedParameter; import dr.inference.operators.AdaptableMCMCOperator; import dr.inference.operators.MCMCOperator; @@ -47,7 +48,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } catch (XMLParseException e) { throw new XMLParseException("RandomWalkOperatorParser failled in TraansformedParameterRandomWalkOperator."); } - return new TransformedParameterRandomWalkOperator((RandomWalkOperator) randomWalk); + GeneralParameterBounds bounds = (GeneralParameterBounds) xo.getChild(GeneralParameterBounds.class); + return new TransformedParameterRandomWalkOperator((RandomWalkOperator) randomWalk, bounds); } @@ -76,6 +78,7 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(TransformedParameter.class), }, true), new StringAttributeRule(BOUNDARY_CONDITION, null, RandomWalkOperator.BoundaryCondition.values(), true), - new ElementRule(TransformedParameter.class) + new ElementRule(TransformedParameter.class), + new ElementRule(GeneralParameterBounds.class, true) }; } From 2f5d27d0659667370f7a5ceaf7c7a14f4158f05c Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 31 May 2022 11:21:23 -0700 Subject: [PATCH 057/196] RandomWalkOperator constructor w/ updateMap directly --- .../operators/RandomWalkOperator.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/dr/inference/operators/RandomWalkOperator.java b/src/dr/inference/operators/RandomWalkOperator.java index d2135b4430..49fe7d39b7 100644 --- a/src/dr/inference/operators/RandomWalkOperator.java +++ b/src/dr/inference/operators/RandomWalkOperator.java @@ -57,20 +57,34 @@ public RandomWalkOperator(Parameter parameter, double windowSize, BoundaryCondit public RandomWalkOperator(Parameter parameter, Parameter updateIndex, double windowSize, BoundaryCondition boundaryCondition, double weight, AdaptationMode mode) { + + this(parameter, windowSize, boundaryCondition, weight, mode, makeUpdateMap(updateIndex)); + } + + public RandomWalkOperator(Parameter parameter, double windowSize, BoundaryCondition boundaryCondition, + double weight, AdaptationMode mode, List updateMap) { super(mode); + + setWeight(weight); this.parameter = parameter; this.windowSize = windowSize; this.boundaryCondition = boundaryCondition; + this.updateMap = updateMap; + if (updateMap != null) { + updateMapSize = updateMap.size(); + } + } - setWeight(weight); + private static ArrayList makeUpdateMap(Parameter updateIndex) { + ArrayList updateMap = null; if (updateIndex != null) { updateMap = new ArrayList(); for (int i = 0; i < updateIndex.getDimension(); i++) { if (updateIndex.getParameterValue(i) == 1.0) updateMap.add(i); } - updateMapSize=updateMap.size(); } + return updateMap; } /** @@ -265,7 +279,7 @@ public String toString() { protected Parameter parameter = null; private double windowSize = 0.01; - private List updateMap = null; + private List updateMap; private int updateMapSize; private final BoundaryCondition boundaryCondition; } From 1d9e1fb9984b113c2deb2a6267afd2758c8aa656 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 31 May 2022 11:28:44 -0700 Subject: [PATCH 058/196] updateIndex now passed to TransformedParameterRandomWalkOperator --- src/dr/inference/operators/RandomWalkOperator.java | 4 ++++ .../operators/TransformedParameterRandomWalkOperator.java | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/dr/inference/operators/RandomWalkOperator.java b/src/dr/inference/operators/RandomWalkOperator.java index 49fe7d39b7..43ac45b321 100644 --- a/src/dr/inference/operators/RandomWalkOperator.java +++ b/src/dr/inference/operators/RandomWalkOperator.java @@ -102,6 +102,10 @@ public final BoundaryCondition getBoundaryCondition() { return boundaryCondition; } + public final List getUpdateMap() { + return updateMap; + } + /** * change the parameter and return the hastings ratio. */ diff --git a/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java b/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java index 41c89e69f1..0b363bbe38 100644 --- a/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java +++ b/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java @@ -51,7 +51,8 @@ public TransformedParameterRandomWalkOperator(TransformedParameter parameter, Ra randomWalkOperator.getWindowSize(), randomWalkOperator.getBoundaryCondition(), randomWalkOperator.getWeight(), - randomWalkOperator.getMode()); + randomWalkOperator.getMode(), + randomWalkOperator.getUpdateMap()); this.generalBounds = null; //TODO: implement if needed } @@ -60,7 +61,8 @@ public TransformedParameterRandomWalkOperator(RandomWalkOperator randomWalkOpera randomWalkOperator.getWindowSize(), randomWalkOperator.getBoundaryCondition(), randomWalkOperator.getWeight(), - randomWalkOperator.getMode()); + randomWalkOperator.getMode(), + randomWalkOperator.getUpdateMap()); this.generalBounds = bounds; } From 5762b458a86682cfba526ef805a3ca4615801ef9 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 2 Jun 2022 16:16:22 -0700 Subject: [PATCH 059/196] more generalized framework for operations on TransformedParameters --- .../app/beast/development_parsers.properties | 1 + .../operators/SimpleMCMCOperator.java | 5 ++ .../TransformedParameterOperator.java | 87 +++++++++++++++++++ .../TransformedParameterOperatorParser.java | 41 +++++++++ 4 files changed, 134 insertions(+) create mode 100644 src/dr/inference/operators/TransformedParameterOperator.java create mode 100644 src/dr/inferencexml/operators/TransformedParameterOperatorParser.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 88a399767a..4812144e33 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -324,6 +324,7 @@ dr.inference.model.MatrixDiagonalLogger dr.inference.operators.factorAnalysis.GeneralizedSampleConstraints dr.util.CorrelationToCholesky dr.inferencexml.model.CorrelationParameterBoundsParser +dr.inferencexml.operators.TransformedParameterOperatorParser # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser diff --git a/src/dr/inference/operators/SimpleMCMCOperator.java b/src/dr/inference/operators/SimpleMCMCOperator.java index 94a49559c0..7baba1f7ba 100644 --- a/src/dr/inference/operators/SimpleMCMCOperator.java +++ b/src/dr/inference/operators/SimpleMCMCOperator.java @@ -26,6 +26,7 @@ package dr.inference.operators; import dr.inference.model.Likelihood; +import dr.inference.model.Parameter; import java.util.ArrayDeque; import java.util.Deque; @@ -228,6 +229,10 @@ public long getTotalCalculationCount() { */ public abstract double doOperation(); + public Parameter getParameter() { + throw new RuntimeException("not implemented for operator of class " + getOperatorName()); + } + private double weight = 1.0; private long acceptCount = 0; private long rejectCount = 0; diff --git a/src/dr/inference/operators/TransformedParameterOperator.java b/src/dr/inference/operators/TransformedParameterOperator.java new file mode 100644 index 0000000000..766a7770c4 --- /dev/null +++ b/src/dr/inference/operators/TransformedParameterOperator.java @@ -0,0 +1,87 @@ +package dr.inference.operators; + +import dr.inference.model.GeneralParameterBounds; +import dr.inference.model.Parameter; +import dr.inference.model.TransformedParameter; + +public class TransformedParameterOperator extends AbstractAdaptableOperator { + private boolean isAdaptable; + private SimpleMCMCOperator subOperator; + private TransformedParameter parameter; + private boolean checkValid; + private GeneralParameterBounds generalBounds; + + public TransformedParameterOperator(SimpleMCMCOperator operator, GeneralParameterBounds generalBounds) { + + this.subOperator = operator; + setWeight(operator.getWeight()); + this.isAdaptable = operator instanceof AbstractAdaptableOperator; + this.parameter = (TransformedParameter) operator.getParameter(); + + this.generalBounds = generalBounds; + this.checkValid = generalBounds != null; + } + + + @Override + protected void setAdaptableParameterValue(double value) { + if (isAdaptable) { + ((AbstractAdaptableOperator) subOperator).setAdaptableParameterValue(value); + } + } + + @Override + protected double getAdaptableParameterValue() { + if (isAdaptable) { + return ((AbstractAdaptableOperator) subOperator).getAdaptableParameterValue(); + } + return 0; + } + + @Override + public double getRawParameter() { + if (isAdaptable) { + return ((AbstractAdaptableOperator) subOperator).getRawParameter(); + } + throw new RuntimeException("not actually adaptable parameter"); + } + + @Override + public String getAdaptableParameterName() { + if (isAdaptable) { + return ((AbstractAdaptableOperator) subOperator).getAdaptableParameterName(); + } + throw new RuntimeException("not actually adaptable parameter"); + } + + @Override + public String getOperatorName() { + return "transformedParameterOperator." + subOperator.getOperatorName(); + } + + @Override + public double doOperation() { + double[] oldValues = parameter.getParameterUntransformedValues(); + double ratio = subOperator.doOperation(); + double[] newValues = parameter.getParameterUntransformedValues(); + + + if (checkValid) { // GH: below is sloppy, but best I could do without refactoring how Parameter handles bounds + if (generalBounds == null && !parameter.isWithinBounds()) { + return Double.NEGATIVE_INFINITY; + } else if (!generalBounds.satisfiesBounds(parameter)) { + return Double.NEGATIVE_INFINITY; + } + } + + // Compute Jacobians + ratio += parameter.diffLogJacobian(oldValues, newValues); + + return ratio; + } + + @Override + public Parameter getParameter() { + return subOperator.getParameter(); + } +} diff --git a/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java b/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java new file mode 100644 index 0000000000..2717751497 --- /dev/null +++ b/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java @@ -0,0 +1,41 @@ +package dr.inferencexml.operators; + +import dr.inference.model.GeneralParameterBounds; +import dr.inference.operators.SimpleMCMCOperator; +import dr.inference.operators.TransformedParameterOperator; +import dr.xml.*; + +public class TransformedParameterOperatorParser extends AbstractXMLObjectParser { + + private static final String TRANSFORMED_OPERATOR = "transformedParameterOperator"; + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + SimpleMCMCOperator operator = (SimpleMCMCOperator) xo.getChild(SimpleMCMCOperator.class); + GeneralParameterBounds bounds = (GeneralParameterBounds) xo.getChild(GeneralParameterBounds.class); + return new TransformedParameterOperator(operator, bounds); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(SimpleMCMCOperator.class), + new ElementRule(GeneralParameterBounds.class, true) + }; + } + + @Override + public String getParserDescription() { + return "operator that corrects the hastings ratio with appropriate Jacobian term due to parameter transform"; + } + + @Override + public Class getReturnType() { + return TransformedParameterOperator.class; + } + + @Override + public String getParserName() { + return TRANSFORMED_OPERATOR; + } +} From 3ded408ce170e8b3733d3ca4446ec36982de9b61 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 2 Jun 2022 16:19:43 -0700 Subject: [PATCH 060/196] slightly nicer code --- .../operators/TransformedParameterOperator.java | 13 +++++++------ .../TransformedParameterOperatorParser.java | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/dr/inference/operators/TransformedParameterOperator.java b/src/dr/inference/operators/TransformedParameterOperator.java index 766a7770c4..7be3eaff96 100644 --- a/src/dr/inference/operators/TransformedParameterOperator.java +++ b/src/dr/inference/operators/TransformedParameterOperator.java @@ -5,11 +5,12 @@ import dr.inference.model.TransformedParameter; public class TransformedParameterOperator extends AbstractAdaptableOperator { - private boolean isAdaptable; - private SimpleMCMCOperator subOperator; - private TransformedParameter parameter; - private boolean checkValid; - private GeneralParameterBounds generalBounds; + private final boolean isAdaptable; + private final SimpleMCMCOperator subOperator; + private final TransformedParameter parameter; + private final boolean checkValid; + private final GeneralParameterBounds generalBounds; + public static final String TRANSFORMED_OPERATOR = "transformedParameterOperator"; public TransformedParameterOperator(SimpleMCMCOperator operator, GeneralParameterBounds generalBounds) { @@ -56,7 +57,7 @@ public String getAdaptableParameterName() { @Override public String getOperatorName() { - return "transformedParameterOperator." + subOperator.getOperatorName(); + return TRANSFORMED_OPERATOR + "." + subOperator.getOperatorName(); } @Override diff --git a/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java b/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java index 2717751497..244f049c28 100644 --- a/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java +++ b/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java @@ -5,9 +5,10 @@ import dr.inference.operators.TransformedParameterOperator; import dr.xml.*; +import static dr.inference.operators.TransformedParameterOperator.TRANSFORMED_OPERATOR; + public class TransformedParameterOperatorParser extends AbstractXMLObjectParser { - private static final String TRANSFORMED_OPERATOR = "transformedParameterOperator"; @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { From d190b5e25dd7a7624272eb070d685a9edfc4b960 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 2 Jun 2022 17:14:07 -0700 Subject: [PATCH 061/196] 2-part operator for sampling from convex parameter spaces --- .../app/beast/development_parsers.properties | 1 + src/dr/inference/model/Parameter.java | 9 +++ .../TransformedMultivariateParameter.java | 16 +++++ .../ConvexSpaceRandomWalkOperator.java | 68 +++++++++++++++++++ .../ConvexSpaceRandomWalkOperatorParser.java | 52 ++++++++++++++ 5 files changed, 146 insertions(+) create mode 100644 src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java create mode 100644 src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 4812144e33..88091fd8b6 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -325,6 +325,7 @@ dr.inference.operators.factorAnalysis.GeneralizedSampleConstraints dr.util.CorrelationToCholesky dr.inferencexml.model.CorrelationParameterBoundsParser dr.inferencexml.operators.TransformedParameterOperatorParser +dr.inferencexml.operators.ConvexSpaceRandomWalkOperatorParser # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser diff --git a/src/dr/inference/model/Parameter.java b/src/dr/inference/model/Parameter.java index b4e85af93b..7ef1a3fc5b 100644 --- a/src/dr/inference/model/Parameter.java +++ b/src/dr/inference/model/Parameter.java @@ -177,6 +177,15 @@ public interface Parameter extends Statistic, Variable { void setParameterUntransformedValue(int dim, double a); + default void setAllParameterValuesQuietly(double[] values) { + if (values.length != getDimension()) { + throw new IllegalArgumentException("supplied values must be of same dimension as parameter"); + } + for (int i = 0; i < this.getDimension(); i++) { + setParameterValueQuietly(i, values[i]); + } + } + boolean isImmutable(); Set FULL_PARAMETER_SET = new LinkedHashSet(); diff --git a/src/dr/inference/model/TransformedMultivariateParameter.java b/src/dr/inference/model/TransformedMultivariateParameter.java index c9eade96bb..218de2020e 100644 --- a/src/dr/inference/model/TransformedMultivariateParameter.java +++ b/src/dr/inference/model/TransformedMultivariateParameter.java @@ -67,6 +67,22 @@ public void setParameterValue(int dim, double value) { public void setParameterValueQuietly(int dim, double value) { update(); transformedValues[dim] = value; + updateParameterQuietlyFromTransformedValues(); + } + + @Override + public void setAllParameterValuesQuietly(double[] values) { + if (values.length != transformedValues.length) { + throw new IllegalArgumentException("supplied values must be of same dimension as transformed parameter"); + } + + for (int i = 0; i < transformedValues.length; i++) { + transformedValues[i] = values[i]; + } + updateParameterQuietlyFromTransformedValues(); + } + + private void updateParameterQuietlyFromTransformedValues() { unTransformedValues = inverse(transformedValues); // Need to update all values for (int i = 0; i < parameter.getDimension(); i++) { diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java new file mode 100644 index 0000000000..d4510b5a68 --- /dev/null +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -0,0 +1,68 @@ +package dr.inference.operators; + +import dr.inference.model.Parameter; +import dr.math.distributions.RandomGenerator; +import jebl.math.Random; + +public class ConvexSpaceRandomWalkOperator extends AbstractAdaptableOperator { + private double window; + private final RandomGenerator generator; + private final Parameter parameter; + + public static final String CONVEX_RW = "convexSpaceRandomWalkOperator"; + public static final String WINDOW_SIZE = "relativeWindowSize"; + + public ConvexSpaceRandomWalkOperator(Parameter parameter, RandomGenerator generator, double window, double weight) { + setWeight(weight); + + this.parameter = parameter; + this.generator = generator; + this.window = window; + } + + + @Override + public double doOperation() { + double[] sample = (double[]) generator.nextRandom(); + double[] values = parameter.getParameterValues(); + double t = window * Random.nextDouble(); + double oneMinus = 1.0 - t; + + for (int i = 0; i < values.length; i++) { + sample[i] = values[i] * t + sample[i] * oneMinus; + } + + parameter.setAllParameterValuesQuietly(sample); + parameter.fireParameterChangedEvent(); + return 0.0; //TODO: need to check that the prior is uniform + } + + + @Override + protected void setAdaptableParameterValue(double value) { + if (value > 0) value = 0; + window = Math.exp(value); + } + + @Override + protected double getAdaptableParameterValue() { + return Math.log(window); + } + + @Override + public double getRawParameter() { + return window; + } + + @Override + public String getAdaptableParameterName() { + return WINDOW_SIZE; + } + + @Override + public String getOperatorName() { + return CONVEX_RW; + } + + +} diff --git a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java new file mode 100644 index 0000000000..3c3d8a87c5 --- /dev/null +++ b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java @@ -0,0 +1,52 @@ +package dr.inferencexml.operators; + +import dr.inference.model.Parameter; +import dr.inference.operators.ConvexSpaceRandomWalkOperator; +import dr.inference.operators.MCMCOperator; +import dr.math.distributions.RandomGenerator; +import dr.xml.*; + +public class ConvexSpaceRandomWalkOperatorParser extends AbstractXMLObjectParser { + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + Parameter parameter = (Parameter) xo.getChild(Parameter.class); + RandomGenerator generator = (RandomGenerator) xo.getChild(RandomGenerator.class); + + double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT); + + double windowSize = xo.getAttribute(ConvexSpaceRandomWalkOperator.WINDOW_SIZE, 1.0); + if (windowSize > 1.0) { + throw new XMLParseException(ConvexSpaceRandomWalkOperator.WINDOW_SIZE + " must be between 0 and 1"); + } + + return new ConvexSpaceRandomWalkOperator(parameter, generator, windowSize, weight); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(Parameter.class), + new ElementRule(RandomGenerator.class), + AttributeRule.newDoubleRule(MCMCOperator.WEIGHT), + AttributeRule.newDoubleRule(ConvexSpaceRandomWalkOperator.WINDOW_SIZE, true) + }; + } + + @Override + public String getParserDescription() { + return "operator that first samples uniformly from some space then updates the parameter to a point along" + + " the line from its current value to the sampled one"; + } + + @Override + public Class getReturnType() { + return ConvexSpaceRandomWalkOperator.class; + } + + @Override + public String getParserName() { + return ConvexSpaceRandomWalkOperator.CONVEX_RW; + } +} From c8fa6f8bf062b684b03e75ebe8148d105d33cb21 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 10:43:34 -0700 Subject: [PATCH 062/196] LKJ correlation distribution with structural zeros (partial) --- .../app/beast/development_parsers.properties | 1 + ...WithStructuralZerosDistributionParser.java | 57 +++++++++++++++++ ...lationWithStructuralZerosDistribution.java | 64 +++++++++++++++++++ src/dr/xml/XMLObject.java | 11 +++- 4 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java create mode 100644 src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 88091fd8b6..7e791c8605 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -326,6 +326,7 @@ dr.util.CorrelationToCholesky dr.inferencexml.model.CorrelationParameterBoundsParser dr.inferencexml.operators.TransformedParameterOperatorParser dr.inferencexml.operators.ConvexSpaceRandomWalkOperatorParser +dr.inferencexml.distribution.LKJCorrelationWithStructuralZerosDistributionParser # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser diff --git a/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java b/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java new file mode 100644 index 0000000000..70bf1c17d4 --- /dev/null +++ b/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java @@ -0,0 +1,57 @@ +package dr.inferencexml.distribution; + +import dr.inference.model.ParameterParser; +import dr.math.distributions.LKJCorrelationWithStructuralZerosDistribution; +import dr.xml.*; + +import java.util.ArrayList; + +public class LKJCorrelationWithStructuralZerosDistributionParser extends AbstractXMLObjectParser { + + private static final String BLOCKS = "blocks"; + private static final String BLOCK = "block"; //name doesn't actually matter + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + XMLObject bxo = xo.getChild(BLOCKS); + ArrayList blocks = new ArrayList<>(); + + for (int i = 0; i < bxo.getChildCount(); i++) { + XMLObject bcxo = (XMLObject) xo.getChild(i); + blocks.add(bcxo.getIntegerArrayChild(0)); + } + + int dim = xo.getIntegerAttribute(ParameterParser.DIMENSION); + double shape = xo.getDoubleAttribute(PriorParsers.SHAPE); + + return new LKJCorrelationWithStructuralZerosDistribution(dim, shape, blocks); + + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + AttributeRule.newDoubleRule(PriorParsers.SHAPE), + AttributeRule.newIntegerRule(ParameterParser.DIMENSION), + new ElementRule(BLOCKS, new XMLSyntaxRule[]{ + new ElementRule(BLOCK, new XMLSyntaxRule[0], "", 0, Integer.MAX_VALUE) + }) + + }; + } + + @Override + public String getParserDescription() { + return "LKJ correlation distribution with some diagonal blocks fixed at the identity matrix"; + } + + @Override + public Class getReturnType() { + return LKJCorrelationWithStructuralZerosDistribution.class; + } + + @Override + public String getParserName() { + return LKJCorrelationWithStructuralZerosDistribution.LKJ_WITH_ZEROS; + } +} diff --git a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java new file mode 100644 index 0000000000..bf977035d2 --- /dev/null +++ b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java @@ -0,0 +1,64 @@ +package dr.math.distributions; + +import dr.math.MathUtils; + +import java.util.ArrayList; + +public class LKJCorrelationWithStructuralZerosDistribution extends LKJCorrelationDistribution implements RandomGenerator { + + private final int[] blockAssignments; + + public static final String LKJ_WITH_ZEROS = "LKJCorrelationWithZerosDistribution"; + + + public LKJCorrelationWithStructuralZerosDistribution(int dim, double shape, ArrayList zeroBlocks) { + super(dim, shape); + + this.blockAssignments = new int[upperTriangularSize(dim)]; + for (int i = 0; i < zeroBlocks.size(); i++) { + for (int j = 0; j < zeroBlocks.get(i).length; j++) { + blockAssignments[zeroBlocks.get(i)[j]] = i + 1; //want to save 0 for indices that aren't in any block + } + } + + System.err.println("Warning: LKJCorrelationDistribution with structural zeros does not have proper normalization constant"); //TODO + } + + @Override + public double[] nextRandom() { + int ind = 0; + int n = upperTriangularSize(dim); + double[] partialDraw = new double[n]; + + for (int row = 0; row < dim; row++) { + for (int col = row + 1; col < dim; col++) { + + if (blockAssignments[row] == 0 || blockAssignments[row] != blockAssignments[col]) { + int diag = row - col; + double alpha = shape + 0.5 * (dim - 1 - diag); + double beta = MathUtils.nextBeta(alpha, alpha); + beta *= 2; + beta -= 1; + + partialDraw[ind] = beta; + } + + ind++; + } + } + + double[] correlation = new double[n]; + + //convert partials to correlation matrix + for (int diag = 1; diag < dim; diag++) { + //TODO + } + + return correlation; + } + + @Override + public double logPdf(Object x) { + return logPdf((double[]) x); + } +} diff --git a/src/dr/xml/XMLObject.java b/src/dr/xml/XMLObject.java index 571791ea13..d4ceccbbb3 100644 --- a/src/dr/xml/XMLObject.java +++ b/src/dr/xml/XMLObject.java @@ -329,12 +329,21 @@ public double[] getDoubleArrayAttribute(String name) throws XMLParseException { } /** - * @return the named attribute as a double[]. + * @return the named attribute as a int[]. */ public int[] getIntegerArrayAttribute(String name) throws XMLParseException { return getIntegerArray(getAndTest(name)); } + /** + * @param i the index of the child to return + * @return the ith child as a int[]. + * @throws XMLParseException if getChild(i) would + */ + public int[] getIntegerArrayChild(int i) throws XMLParseException { + return getIntegerArray(getChild(i)); + } + /** * @return the named attribute as an integer. */ From c2941d20899ec1f10d9ecf1dbb963ba2c0d81aec Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 11:06:10 -0700 Subject: [PATCH 063/196] starting to refactor parsing for ContinuousTraitDataModel --- .../ContinuousTraitDataModelParser.java | 105 ++++++++++++++++++ .../ContinuousDataLikelihoodParser.java | 4 +- 2 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java diff --git a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java new file mode 100644 index 0000000000..ee0a84ec46 --- /dev/null +++ b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java @@ -0,0 +1,105 @@ +package dr.evomodelxml.continuous; + +import dr.evolution.tree.Tree; +import dr.evomodel.tree.TreeModel; +import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitDataModel; +import dr.evomodel.treedatalikelihood.continuous.IntegratedProcessTraitDataModel; +import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; +import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; +import dr.inference.model.CompoundParameter; +import dr.inference.model.Parameter; +import dr.xml.*; + +import static dr.evomodelxml.treedatalikelihood.ContinuousDataLikelihoodParser.FORCE_FULL_PRECISION; + +public class ContinuousTraitDataModelParser extends AbstractXMLObjectParser { + + private static String CONTINUOUS_TRAITS = "continuousTraitDataModel"; + + public static final String INTEGRATED_PROCESS = "integratedProcess"; + public static final String FORCE_COMPLETELY_MISSING = "forceCompletelyMissing"; + + private static final String NUM_TRAITS = "numTraits"; + + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + Tree treeModel = (Tree) xo.getChild(Tree.class); + boolean[] missingIndicators; + final String traitName; + + boolean useMissingIndices = true; + boolean integratedProcess = xo.getAttribute(INTEGRATED_PROCESS, false); + + + TreeTraitParserUtilities utilities = new TreeTraitParserUtilities(); + + TreeTraitParserUtilities.TraitsAndMissingIndices returnValue = + utilities.parseTraitsFromTaxonAttributes(xo, treeModel, true); + CompoundParameter traitParameter = returnValue.traitParameter; + + int dimAll = traitParameter.getParameter(0).getDimension(); + int numTraits = xo.getAttribute(NUM_TRAITS, 1); + int dim = dimAll / numTraits; //TODO: check that dimAll is a factor of numTraits, also TODO: maybe pass numTraits directly? + + + missingIndicators = returnValue.getMissingIndicators(); +// sampleMissingParameter = returnValue.sampleMissingParameter; + traitName = returnValue.traitName; + useMissingIndices = returnValue.useMissingIndices; + + PrecisionType precisionType = PrecisionType.SCALAR; + + if (xo.getAttribute(FORCE_FULL_PRECISION, false) || + (useMissingIndices && !xo.getAttribute(FORCE_COMPLETELY_MISSING, false))) { + precisionType = PrecisionType.FULL; + } + + if (xo.hasChildNamed(TreeTraitParserUtilities.JITTER)) { + utilities.jitter(xo, dim, missingIndicators); + } + +// System.err.println("Using precisionType == " + precisionType + " for data model."); + + if (integratedProcess) { + return new IntegratedProcessTraitDataModel(traitName, + traitParameter, + missingIndicators, useMissingIndices, + dim, precisionType); + } + + return new ContinuousTraitDataModel(traitName, + traitParameter, + missingIndicators, useMissingIndices, + dim, precisionType); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(Tree.class), + new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + AttributeRule.newBooleanRule(INTEGRATED_PROCESS, true), + AttributeRule.newIntegerRule(NUM_TRAITS, true), + AttributeRule.newBooleanRule(FORCE_COMPLETELY_MISSING, true) + }; + } + + @Override + public String getParserDescription() { + return null; + } + + @Override + public Class getReturnType() { + return ContinuousTraitDataModel.class; + } + + @Override + public String getParserName() { + return CONTINUOUS_TRAITS; + } +} diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index f7bf151bf7..4a001f0d5e 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -46,6 +46,8 @@ import java.util.List; import static dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate.getTipTraitName; +import static dr.evomodelxml.continuous.ContinuousTraitDataModelParser.FORCE_COMPLETELY_MISSING; +import static dr.evomodelxml.continuous.ContinuousTraitDataModelParser.INTEGRATED_PROCESS; /** * @author Andrew Rambaut @@ -62,7 +64,6 @@ public class ContinuousDataLikelihoodParser extends AbstractXMLObjectParser { private static final String OPTIMAL_TRAITS = AbstractMultivariateTraitLikelihood.OPTIMAL_TRAITS; private static final String RECONSTRUCT_TRAITS = "reconstructTraits"; - private static final String FORCE_COMPLETELY_MISSING = "forceCompletelyMissing"; private static final String ALLOW_SINGULAR = "allowSingular"; public static final String FORCE_FULL_PRECISION = "forceFullPrecision"; private static final String FORCE_DRIFT = "forceDrift"; @@ -70,7 +71,6 @@ public class ContinuousDataLikelihoodParser extends AbstractXMLObjectParser { private static final String STRENGTH_OF_SELECTION_MATRIX = "strengthOfSelectionMatrix"; - private static final String INTEGRATED_PROCESS = "integratedProcess"; public static final String CONTINUOUS_DATA_LIKELIHOOD = "traitDataLikelihood"; From bd2fd40ab5f8bc72755d03e440e68115160d638e Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 11:35:13 -0700 Subject: [PATCH 064/196] remove code duplication --- .../ContinuousTraitDataModelParser.java | 28 ++++++----- .../ContinuousDataLikelihoodParser.java | 47 +++---------------- 2 files changed, 23 insertions(+), 52 deletions(-) diff --git a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java index ee0a84ec46..4cddbe5417 100644 --- a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java +++ b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java @@ -24,12 +24,15 @@ public class ContinuousTraitDataModelParser extends AbstractXMLObjectParser { @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { + return parseContinuousTraitDataModel(xo); + } + public static ContinuousTraitDataModel parseContinuousTraitDataModel(XMLObject xo) throws XMLParseException { Tree treeModel = (Tree) xo.getChild(Tree.class); boolean[] missingIndicators; final String traitName; - boolean useMissingIndices = true; + boolean useMissingIndices; boolean integratedProcess = xo.getAttribute(INTEGRATED_PROCESS, false); @@ -75,22 +78,25 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { dim, precisionType); } + public static final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{ + new ElementRule(Tree.class), + new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + AttributeRule.newBooleanRule(INTEGRATED_PROCESS, true), + AttributeRule.newIntegerRule(NUM_TRAITS, true), + AttributeRule.newBooleanRule(FORCE_COMPLETELY_MISSING, true), + AttributeRule.newStringRule(TreeTraitParserUtilities.TRAIT_NAME, true) + }; + @Override public XMLSyntaxRule[] getSyntaxRules() { - return new XMLSyntaxRule[]{ - new ElementRule(Tree.class), - new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{ - new ElementRule(Parameter.class) - }), - AttributeRule.newBooleanRule(INTEGRATED_PROCESS, true), - AttributeRule.newIntegerRule(NUM_TRAITS, true), - AttributeRule.newBooleanRule(FORCE_COMPLETELY_MISSING, true) - }; + return rules; } @Override public String getParserDescription() { - return null; + return "parses continuous traits from a tree"; } @Override diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index 4a001f0d5e..c51a31a2fa 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -37,10 +37,9 @@ import dr.evomodel.treedatalikelihood.continuous.*; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.evomodel.treedatalikelihood.preorder.*; +import dr.evomodelxml.continuous.ContinuousTraitDataModelParser; import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; -import dr.inference.model.CompoundParameter; import dr.inference.model.MatrixParameterInterface; -import dr.inference.model.Parameter; import dr.xml.*; import java.util.List; @@ -111,40 +110,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean integratedProcess = xo.getAttribute(INTEGRATED_PROCESS, false); if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) { - TreeTraitParserUtilities utilities = new TreeTraitParserUtilities(); - - TreeTraitParserUtilities.TraitsAndMissingIndices returnValue = - utilities.parseTraitsFromTaxonAttributes(xo, treeModel, true); - CompoundParameter traitParameter = returnValue.traitParameter; - missingIndicators = returnValue.getMissingIndicators(); -// sampleMissingParameter = returnValue.sampleMissingParameter; - traitName = returnValue.traitName; - useMissingIndices = returnValue.useMissingIndices; - - PrecisionType precisionType = PrecisionType.SCALAR; - - if (xo.getAttribute(FORCE_FULL_PRECISION, false) || - (useMissingIndices && !xo.getAttribute(FORCE_COMPLETELY_MISSING, false))) { - precisionType = PrecisionType.FULL; - } - - if (xo.hasChildNamed(TreeTraitParserUtilities.JITTER)) { - utilities.jitter(xo, diffusionModel.getPrecisionmatrix().length, missingIndicators); - } - -// System.err.println("Using precisionType == " + precisionType + " for data model."); - - if (!integratedProcess) { - dataModel = new ContinuousTraitDataModel(traitName, - traitParameter, - missingIndicators, useMissingIndices, - dim, precisionType); - } else { - dataModel = new IntegratedProcessTraitDataModel(traitName, - traitParameter, - missingIndicators, useMissingIndices, - dim, precisionType); - } + dataModel = ContinuousTraitDataModelParser.parseContinuousTraitDataModel(xo); + traitName = xo.getStringAttribute(TreeTraitParserUtilities.TRAIT_NAME); } else { // Has ContinuousTraitPartialsProvider dataModel = (ContinuousTraitPartialsProvider) xo.getChild(ContinuousTraitPartialsProvider.class); traitName = xo.getAttribute(TreeTraitParserUtilities.TRAIT_NAME, TreeTraitParserUtilities.DEFAULT_TRAIT_NAME); @@ -276,12 +243,10 @@ public Class getReturnType() { new ElementRule(MultivariateDiffusionModel.class), new ElementRule(BranchRateModel.class, true), new ElementRule(CONJUGATE_ROOT_PRIOR, ConjugateRootTraitPrior.rules), - new XORRule( - new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{ - new ElementRule(Parameter.class), - }), + new XORRule(new XMLSyntaxRule[]{ + new AndRule(ContinuousTraitDataModelParser.rules), new ElementRule(ContinuousTraitPartialsProvider.class) - ), + }), new ElementRule(DRIFT_MODELS, new XMLSyntaxRule[]{ new ElementRule(BranchRateModel.class, 1, Integer.MAX_VALUE), }, true), From 108e41f12f553f31e598f00ee82eef07953e8404 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 11:37:33 -0700 Subject: [PATCH 065/196] new parser --- src/dr/app/beast/development_parsers.properties | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 7e791c8605..abad96a3fd 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -327,6 +327,7 @@ dr.inferencexml.model.CorrelationParameterBoundsParser dr.inferencexml.operators.TransformedParameterOperatorParser dr.inferencexml.operators.ConvexSpaceRandomWalkOperatorParser dr.inferencexml.distribution.LKJCorrelationWithStructuralZerosDistributionParser +dr.evomodelxml.continuous.ContinuousTraitDataModelParser # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser From 74c644702d5e33742afde7dd00766ef3573f0964 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 11:40:12 -0700 Subject: [PATCH 066/196] forgot about default --- .../treedatalikelihood/ContinuousDataLikelihoodParser.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index c51a31a2fa..141348dd02 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -111,12 +111,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) { dataModel = ContinuousTraitDataModelParser.parseContinuousTraitDataModel(xo); - traitName = xo.getStringAttribute(TreeTraitParserUtilities.TRAIT_NAME); } else { // Has ContinuousTraitPartialsProvider dataModel = (ContinuousTraitPartialsProvider) xo.getChild(ContinuousTraitPartialsProvider.class); - traitName = xo.getAttribute(TreeTraitParserUtilities.TRAIT_NAME, TreeTraitParserUtilities.DEFAULT_TRAIT_NAME); } + traitName = xo.getAttribute(TreeTraitParserUtilities.TRAIT_NAME, TreeTraitParserUtilities.DEFAULT_TRAIT_NAME); dataModel.setTipTraitName(getTipTraitName(traitName)); // TODO: not an ideal solution as the trait name could be set differently later ConjugateRootTraitPrior rootPrior = ConjugateRootTraitPrior.parseConjugateRootTraitPrior(xo, dataModel.getTraitDimension()); From 4d9b1e2704e447765c42ee7f8d57b6471d19dabc Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 12:08:50 -0700 Subject: [PATCH 067/196] making sure that ContinuousDataLikelihoodParser pulls in the correct values from the data model --- .../continuous/ContinuousTraitDataModel.java | 7 +++++++ .../continuous/ContinuousTraitPartialsProvider.java | 2 ++ .../continuous/ElementaryVectorDataModel.java | 5 +++++ .../continuous/EmptyTraitDataModel.java | 5 +++++ .../continuous/IntegratedFactorAnalysisLikelihood.java | 5 +++++ .../continuous/JointPartialsProvider.java | 9 +++++++++ .../ContinuousDataLikelihoodParser.java | 10 +++------- 7 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java index ef98db36a3..451686bced 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java @@ -48,6 +48,7 @@ public class ContinuousTraitDataModel extends AbstractModel implements Continuou final PrecisionType precisionType; private final boolean[] missingIndicators; + private boolean useMissingIndices; private String tipTraitName = null; @@ -61,6 +62,7 @@ public ContinuousTraitDataModel(String name, addVariable(parameter); this.originalMissingIndicators = missingIndicators; + this.useMissingIndices = true; this.missingIndicators = (useMissingIndices ? missingIndicators : new boolean[missingIndicators.length]); this.dimTrait = dimTrait; @@ -108,6 +110,11 @@ public CompoundParameter getParameter() { return parameter; } + @Override + public boolean usesMissingIndices() { + return useMissingIndices; + } + @Override public List getMissingIndices() { return ContinuousTraitPartialsProvider.indicatorToIndices(missingIndicators); // TODO: finish deprecating diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java index 439a12a324..6b7205e61a 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java @@ -70,6 +70,8 @@ default boolean[] getTraitMissingIndicators() { // returns null for no missing t String getModelName(); + boolean usesMissingIndices(); + default boolean getDefaultAllowSingular() { return false; } diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java index ed3354875a..503b72bb91 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java @@ -102,6 +102,11 @@ public CompoundParameter getParameter() { return traitParameter; } + @Override + public boolean usesMissingIndices() { + return false; + } + public void setTipTraitDimParameters(int tip, int trait, int dim) { tipIndicator.setParameterValue(trait, tip); diff --git a/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java index 5a0861b9c5..c75dedd4b3 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java @@ -93,6 +93,11 @@ public String getModelName() { return name; } + @Override + public boolean usesMissingIndices() { + return false; + } + @Override public List getMissingIndices() { return null; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java index 64890543fa..a782fb6b84 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java @@ -227,6 +227,11 @@ public CompoundParameter getParameter() { return traitParameter; } + @Override + public boolean usesMissingIndices() { + return true; + } + @Override public boolean getDefaultAllowSingular() { return true; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index b4617071a9..b7e76e8faa 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -226,6 +226,15 @@ public CompoundParameter getParameter() { return jointDataParameter; } + @Override + public boolean usesMissingIndices() { + boolean useMissingIndices = false; + for (ContinuousTraitPartialsProvider provider : providers) { + useMissingIndices = useMissingIndices || provider.usesMissingIndices(); + } + return useMissingIndices; + } + @Override protected void handleModelChangedEvent(Model model, Object object, int index) { fireModelChanged(); // sub-providers should handle everything else diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index 141348dd02..c0eb6d55ec 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -100,14 +100,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ContinuousRateTransformation rateTransformation = new ContinuousRateTransformation.Default( treeModel, scaleByTime, useTreeLength); - final int dim = diffusionModel.getPrecisionmatrix().length; - final String traitName; - boolean[] missingIndicators; -// Parameter sampleMissingParameter = null; + ContinuousTraitPartialsProvider dataModel; - boolean useMissingIndices = true; - boolean integratedProcess = xo.getAttribute(INTEGRATED_PROCESS, false); if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) { dataModel = ContinuousTraitDataModelParser.parseContinuousTraitDataModel(xo); @@ -145,6 +140,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } DiffusionProcessDelegate diffusionProcessDelegate; + boolean integratedProcess = dataModel instanceof IntegratedProcessTraitDataModel; //TODO: can add to interface if that would be better if ((optimalTraitsModels != null && elasticModel != null) || xo.getAttribute(FORCE_OU, false)) { if (!integratedProcess) { diffusionProcessDelegate = new OUDiffusionModelDelegate(treeModel, diffusionModel, optimalTraitsModels, elasticModel); @@ -172,7 +168,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (reconstructTraits) { // if (missingIndices != null && missingIndices.size() == 0) { - if (!useMissingIndices) { + if (!dataModel.usesMissingIndices()) { ProcessSimulationDelegate simulationDelegate = delegate.getPrecisionType() == PrecisionType.SCALAR ? From c1aeb0cff91187c61b5a5b34c9bcb9c23c060d7c Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 12:13:20 -0700 Subject: [PATCH 068/196] a little code cleaning --- .../continuous/ContinuousTraitDataModelParser.java | 4 ++-- .../treedatalikelihood/ContinuousDataLikelihoodParser.java | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java index 4cddbe5417..91ca462c77 100644 --- a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java +++ b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java @@ -1,7 +1,6 @@ package dr.evomodelxml.continuous; import dr.evolution.tree.Tree; -import dr.evomodel.tree.TreeModel; import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitDataModel; import dr.evomodel.treedatalikelihood.continuous.IntegratedProcessTraitDataModel; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; @@ -10,7 +9,6 @@ import dr.inference.model.Parameter; import dr.xml.*; -import static dr.evomodelxml.treedatalikelihood.ContinuousDataLikelihoodParser.FORCE_FULL_PRECISION; public class ContinuousTraitDataModelParser extends AbstractXMLObjectParser { @@ -18,6 +16,8 @@ public class ContinuousTraitDataModelParser extends AbstractXMLObjectParser { public static final String INTEGRATED_PROCESS = "integratedProcess"; public static final String FORCE_COMPLETELY_MISSING = "forceCompletelyMissing"; + public static final String FORCE_FULL_PRECISION = "forceFullPrecision"; + private static final String NUM_TRAITS = "numTraits"; diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index c0eb6d55ec..8a096112e4 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -45,8 +45,6 @@ import java.util.List; import static dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate.getTipTraitName; -import static dr.evomodelxml.continuous.ContinuousTraitDataModelParser.FORCE_COMPLETELY_MISSING; -import static dr.evomodelxml.continuous.ContinuousTraitDataModelParser.INTEGRATED_PROCESS; /** * @author Andrew Rambaut @@ -64,7 +62,6 @@ public class ContinuousDataLikelihoodParser extends AbstractXMLObjectParser { private static final String RECONSTRUCT_TRAITS = "reconstructTraits"; private static final String ALLOW_SINGULAR = "allowSingular"; - public static final String FORCE_FULL_PRECISION = "forceFullPrecision"; private static final String FORCE_DRIFT = "forceDrift"; private static final String FORCE_OU = "forceOU"; @@ -254,12 +251,9 @@ public Class getReturnType() { AttributeRule.newBooleanRule(USE_TREE_LENGTH, true), AttributeRule.newBooleanRule(RECIPROCAL_RATES, true), AttributeRule.newBooleanRule(RECONSTRUCT_TRAITS, true), - AttributeRule.newBooleanRule(FORCE_COMPLETELY_MISSING, true), AttributeRule.newBooleanRule(ALLOW_SINGULAR, true), - AttributeRule.newBooleanRule(FORCE_FULL_PRECISION, true), AttributeRule.newBooleanRule(FORCE_DRIFT, true), AttributeRule.newBooleanRule(FORCE_OU, true), - AttributeRule.newBooleanRule(INTEGRATED_PROCESS, true), AttributeRule.newStringRule(TreeTraitParserUtilities.TRAIT_NAME, true), TreeTraitParserUtilities.jitterRules(true), }; From fee412945d98b963e1f4aee4129460b7f5f75f55 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 3 Jun 2022 12:18:35 -0700 Subject: [PATCH 069/196] oops --- .../continuous/RepeatedMeasuresTraitDataModel.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 34329cf7b2..058a93b980 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -33,7 +33,7 @@ import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate; import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; -import dr.evomodelxml.treedatalikelihood.ContinuousDataLikelihoodParser; +import dr.evomodelxml.continuous.ContinuousTraitDataModelParser; import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; import dr.inference.model.CompoundParameter; import dr.inference.model.MatrixParameterInterface; @@ -312,7 +312,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { int dimTrait = samplingPrecision.getColumnDimension(); final PrecisionType precisionType; - if (xo.getAttribute(ContinuousDataLikelihoodParser.FORCE_FULL_PRECISION, false) || + if (xo.getAttribute(ContinuousTraitDataModelParser.FORCE_FULL_PRECISION, false) || dimTrait > 1) { precisionType = PrecisionType.FULL; } else { From aa5fc95a5c960ca1b2f213e091adbe6b436d37fe Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Sun, 5 Jun 2022 16:06:39 -0400 Subject: [PATCH 070/196] starting to refactor RepeatedMeasuresTraitDataModel (probably need to do more) --- .../RepeatedMeasuresTraitDataModel.java | 41 +++++++++++-------- ...eScaledRepeatedMeasuresTraitDataModel.java | 3 +- .../continuous/RepeatedMeasureFactorTest.java | 10 +++++ .../hmc/DiffusionGradientTest.java | 2 + 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 058a93b980..f7ab41d8b7 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -72,8 +72,11 @@ public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel imp private boolean[] missingTraitIndicators = null; + private ContinuousTraitPartialsProvider childModel; + public RepeatedMeasuresTraitDataModel(String name, + ContinuousTraitPartialsProvider childModel, CompoundParameter parameter, boolean[] missindIndicators, boolean useMissingIndices, @@ -83,6 +86,7 @@ public RepeatedMeasuresTraitDataModel(String name, super(name, parameter, missindIndicators, useMissingIndices, dimTrait, precisionType); + this.childModel = childModel; this.traitName = name; this.samplingPrecisionParameter = samplingPrecision; addVariable(samplingPrecision); @@ -110,7 +114,7 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { throw new RuntimeException("Incompatible with this model."); } - double[] partial = super.getTipPartial(taxonIndex, fullyObserved); + double[] partial = childModel.getTipPartial(taxonIndex, fullyObserved); if (precisionType == precisionType.SCALAR) { return partial; //TODO: I don't think this is right, especially given constructor above. } @@ -283,12 +287,14 @@ public void chainRuleWrtVariance(double[] gradient, NodeRef node) { public Object parseXMLObject(XMLObject xo) throws XMLParseException { MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(TreeModel.class); - TreeTraitParserUtilities utilities = new TreeTraitParserUtilities(); + final ContinuousTraitPartialsProvider subModel; + + if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) { + subModel = ContinuousTraitDataModelParser.parseContinuousTraitDataModel(xo); + } else { + subModel = (ContinuousTraitPartialsProvider) xo.getChild(ContinuousTraitPartialsProvider.class); + } - TreeTraitParserUtilities.TraitsAndMissingIndices returnValue = - utilities.parseTraitsFromTaxonAttributes(xo, treeModel, true); - CompoundParameter traitParameter = returnValue.traitParameter; - boolean[] missingIndicators = returnValue.getMissingIndicators(); XMLObject cxo = xo.getChild(PRECISION); MatrixParameterInterface samplingPrecision = (MatrixParameterInterface) @@ -306,7 +312,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } - String traitName = returnValue.traitName; + String modelName = subModel.getModelName(); boolean scaleByTipHeight = xo.getAttribute(SCALE_BY_TIP_HEIGHT, false); @@ -321,9 +327,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (!scaleByTipHeight) { return new RepeatedMeasuresTraitDataModel( - traitName, - traitParameter, - missingIndicators, + modelName, + subModel, + subModel.getParameter(), + subModel.getDataMissingIndicators(), // missingIndicators, true, dimTrait, @@ -333,9 +340,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ); } else { return new TreeScaledRepeatedMeasuresTraitDataModel( - traitName, - traitParameter, - missingIndicators, + modelName, + subModel, + subModel.getParameter(), + subModel.getDataMissingIndicators(), true, dimTrait, samplingPrecision, @@ -372,9 +380,10 @@ public String getParserName() { // Tree trait parser new ElementRule(MutableTreeModel.class), AttributeRule.newStringRule(TreeTraitParserUtilities.TRAIT_NAME), - new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{ - new ElementRule(Parameter.class) - }), + new XORRule( + new ElementRule(ContinuousTraitPartialsProvider.class), + new AndRule(ContinuousTraitDataModelParser.rules) + ), new ElementRule(TreeTraitParserUtilities.MISSING, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) }, true), diff --git a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java index 5268e7fe4c..c24ad64626 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java @@ -45,13 +45,14 @@ public class TreeScaledRepeatedMeasuresTraitDataModel extends RepeatedMeasuresTr private ContinuousRateTransformation rateTransformation; public TreeScaledRepeatedMeasuresTraitDataModel(String name, + ContinuousTraitPartialsProvider childModel, CompoundParameter parameter, boolean[] missingIndicators, boolean useMissingIndices, final int dimTrait, MatrixParameterInterface samplingPrecision, PrecisionType precisionType) { - super(name, parameter, missingIndicators, useMissingIndices, dimTrait, samplingPrecision, precisionType); + super(name, childModel, parameter, missingIndicators, useMissingIndices, dimTrait, samplingPrecision, precisionType); } @Override diff --git a/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java b/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java index 76a5ad3422..0e9db57f05 100644 --- a/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java @@ -133,6 +133,14 @@ public void setUp() throws Exception { loadingsParameters[5] = new Parameter.Default(new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 1.0}); MatrixParameterInterface loadingsMatrixParameters = new MatrixParameter("loadings", loadingsParameters); + dataModel = new ContinuousTraitDataModel("dataModel", + traitParameter, + missingIndicators, + true, + 6, + PrecisionType.FULL + ); + dataModelFactor = new IntegratedFactorAnalysisLikelihood("dataModelFactors", traitParameter, missingIndicators, @@ -143,6 +151,7 @@ public void setUp() throws Exception { //// Repeated Measures Model //// ****************************************************************************** dataModelRepeatedMeasures = new RepeatedMeasuresTraitDataModel("dataModelRepeatedMeasures", + dataModel, traitParameter, missingIndicators, // new boolean[3], @@ -152,6 +161,7 @@ public void setUp() throws Exception { PrecisionType.FULL); dataModelRepeatedMeasuresFull = new RepeatedMeasuresTraitDataModel("dataModelRepeatedMeasures", + dataModel, traitParameter, missingIndicators, true, diff --git a/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java b/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java index 4a8ae8dbdd..5b5870cc36 100644 --- a/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java @@ -203,6 +203,7 @@ public void setUp() throws Exception { new CompoundSymmetricMatrix(diagonalVarSampling, offDiagonalSampling, true, false)); dataModelRepeatedMeasures = new RepeatedMeasuresTraitDataModel("dataModelRepeatedMeasures", + dataModel, traitParameter, missingIndicators, true, @@ -211,6 +212,7 @@ public void setUp() throws Exception { PrecisionType.FULL); dataModelRepeatedMeasuresInv = new RepeatedMeasuresTraitDataModel("dataModelRepeatedMeasuresInv", + dataModel, traitParameter, missingIndicators, true, From ded81ffb0aec94ad58a8f600302997d54b9688a6 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 9 Jun 2022 14:31:00 -0400 Subject: [PATCH 071/196] actually drawing random LKJCorrelation w/ zeros --- ...WithStructuralZerosDistributionParser.java | 8 +- ...lationWithStructuralZerosDistribution.java | 82 ++++++++++++++++--- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java b/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java index 70bf1c17d4..bac0ceb9c6 100644 --- a/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java +++ b/src/dr/inferencexml/distribution/LKJCorrelationWithStructuralZerosDistributionParser.java @@ -17,10 +17,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ArrayList blocks = new ArrayList<>(); for (int i = 0; i < bxo.getChildCount(); i++) { - XMLObject bcxo = (XMLObject) xo.getChild(i); + XMLObject bcxo = (XMLObject) bxo.getChild(i); blocks.add(bcxo.getIntegerArrayChild(0)); } + for (int[] block : blocks) { + for (int i = 0; i < block.length; i++) { + block[i] = block[i] - 1; + } + } + int dim = xo.getIntegerAttribute(ParameterParser.DIMENSION); double shape = xo.getDoubleAttribute(PriorParsers.SHAPE); diff --git a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java index bf977035d2..0f62207580 100644 --- a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java +++ b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java @@ -1,6 +1,8 @@ package dr.math.distributions; import dr.math.MathUtils; +import org.ejml.data.DenseMatrix64F; +import org.ejml.ops.CommonOps; import java.util.ArrayList; @@ -10,11 +12,13 @@ public class LKJCorrelationWithStructuralZerosDistribution extends LKJCorrelatio public static final String LKJ_WITH_ZEROS = "LKJCorrelationWithZerosDistribution"; + public static boolean DEBUG = false; + public LKJCorrelationWithStructuralZerosDistribution(int dim, double shape, ArrayList zeroBlocks) { super(dim, shape); - this.blockAssignments = new int[upperTriangularSize(dim)]; + this.blockAssignments = new int[dim]; for (int i = 0; i < zeroBlocks.size(); i++) { for (int j = 0; j < zeroBlocks.get(i).length; j++) { blockAssignments[zeroBlocks.get(i)[j]] = i + 1; //want to save 0 for indices that aren't in any block @@ -26,9 +30,8 @@ public LKJCorrelationWithStructuralZerosDistribution(int dim, double shape, Arra @Override public double[] nextRandom() { - int ind = 0; - int n = upperTriangularSize(dim); - double[] partialDraw = new double[n]; + + DenseMatrix64F partial = new DenseMatrix64F(dim, dim); for (int row = 0; row < dim; row++) { for (int col = row + 1; col < dim; col++) { @@ -39,19 +42,78 @@ public double[] nextRandom() { double beta = MathUtils.nextBeta(alpha, alpha); beta *= 2; beta -= 1; - - partialDraw[ind] = beta; + partial.set(row, col, beta); + partial.set(col, row, beta); } - ind++; } } - double[] correlation = new double[n]; + if (DEBUG) { + System.out.println(partial); + } + + + for (int i = 0; i < dim; i++) { + partial.set(i, i, 1); + } //convert partials to correlation matrix - for (int diag = 1; diag < dim; diag++) { - //TODO + for (int diag = 2; diag < dim; diag++) { + int dimSub = diag - 1; + DenseMatrix64F Rinv = new DenseMatrix64F(dimSub, dimSub); + DenseMatrix64F R = new DenseMatrix64F(dimSub, dimSub); + + DenseMatrix64F r1 = new DenseMatrix64F(dimSub, 1); + DenseMatrix64F r2 = new DenseMatrix64F(dimSub, 1); + + DenseMatrix64F RInvr1 = new DenseMatrix64F(dimSub, 1); + DenseMatrix64F RInvr2 = new DenseMatrix64F(dimSub, 1); + + for (int row = 0; row < dim - diag; row++) { + int col = row + diag; + + for (int i = 0; i < dimSub; i++) { + int rowi = i + row + 1; + r1.set(i, 0, partial.get(rowi, row)); + r2.set(i, 0, partial.get(rowi, col)); + for (int j = 0; j < dimSub; j++) { + R.set(i, j, partial.get(rowi, j + row + 1)); + } + } + + CommonOps.invert(R, Rinv); + CommonOps.mult(Rinv, r1, RInvr1); + CommonOps.mult(Rinv, r2, RInvr2); + + double r1tRinvr1 = 0; + double r1tRinvr2 = 0; + double r2tRinvr2 = 0; + + for (int i = 0; i < dimSub; i++) { + r1tRinvr1 += RInvr1.get(i, 0) * r1.get(i, 0); + r1tRinvr2 += RInvr2.get(i, 0) * r1.get(i, 0); + r2tRinvr2 += RInvr2.get(i, 0) * r2.get(i, 0); + } + + double d = (1 - r1tRinvr1) * (1 - r2tRinvr2); + double c = r1tRinvr2 + partial.get(row, col) * d; + partial.set(row, col, c); + partial.set(col, row, c); + } + } + + double[] correlation = new double[upperTriangularSize(dim)]; + int ind = 0; + for (int i = 0; i < dim; i++) { + for (int j = (i + 1); j < dim; j++) { + correlation[ind] = partial.get(i, j); + ind++; + } + } + + if (DEBUG) { + System.out.println(partial); } return correlation; From 16a54d1401265d3a5b3a7ac5bacf8dcebfe182fd Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 9 Jun 2022 15:03:40 -0400 Subject: [PATCH 072/196] fireParameterChanged() now works on TransformedParameter --- .../TransformedMultivariateParameter.java | 7 +++++++ .../inference/model/TransformedParameter.java | 20 ++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/dr/inference/model/TransformedMultivariateParameter.java b/src/dr/inference/model/TransformedMultivariateParameter.java index 218de2020e..c208c4a5ac 100644 --- a/src/dr/inference/model/TransformedMultivariateParameter.java +++ b/src/dr/inference/model/TransformedMultivariateParameter.java @@ -121,4 +121,11 @@ private boolean hasChanged() { } return false; } + + @Override + public void variableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + if (!doNotPropagateChangeUp) { + fireParameterChangedEvent(-1, ChangeType.ALL_VALUES_CHANGED); //if one dimension of the untransformed parameter changes, it is very likely that many dimensions of the transformed parameter change + } + } } diff --git a/src/dr/inference/model/TransformedParameter.java b/src/dr/inference/model/TransformedParameter.java index 588858c3ca..8d4a26e5e4 100644 --- a/src/dr/inference/model/TransformedParameter.java +++ b/src/dr/inference/model/TransformedParameter.java @@ -193,9 +193,21 @@ public double removeDimension(int index) { throw new RuntimeException("Not yet implemented."); } - public void variableChangedEvent(Variable variable, int index, ChangeType type) { - // Propogate change up model graph - fireParameterChangedEvent(index, type); + @Override + public void fireParameterChangedEvent() { + + doNotPropagateChangeUp = true; + parameter.fireParameterChangedEvent(); + doNotPropagateChangeUp = false; + + fireParameterChangedEvent(-1, ChangeType.ALL_VALUES_CHANGED); + } + + @Override + public void variableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + if (!doNotPropagateChangeUp) { + fireParameterChangedEvent(index, type); + } } public double diffLogJacobian(double[] oldValues, double[] newValues) { @@ -217,4 +229,6 @@ public Transform getTransform() { protected final Transform transform; protected final boolean inverse; protected Bounds transformedBounds; + + protected boolean doNotPropagateChangeUp = false; } \ No newline at end of file From 9ef2ea63cc681ad778254706d9606541dcf10a5a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 9 Jun 2022 15:05:05 -0400 Subject: [PATCH 073/196] don't need to search through sub-parameters if doNotPropogateChangeUp == true --- src/dr/inference/model/CompoundParameter.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dr/inference/model/CompoundParameter.java b/src/dr/inference/model/CompoundParameter.java index 6e5ea2481a..94c422f11f 100644 --- a/src/dr/inference/model/CompoundParameter.java +++ b/src/dr/inference/model/CompoundParameter.java @@ -274,15 +274,15 @@ protected String toStringCompoundParameter(int dim) { public void variableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { int dim = 0; - for (Parameter parameter1 : uniqueParameters) { - if (variable == parameter1) { - if (!doNotPropagateChangeUp) { + if (!doNotPropagateChangeUp) { + for (Parameter parameter1 : uniqueParameters) { + if (variable == parameter1) { int subparameterIndex = (index == -1) ? -1 : dim + index; fireParameterChangedEvent(subparameterIndex, type); + break; } - break; + dim += parameter1.getDimension(); } - dim += parameter1.getDimension(); } } @@ -375,7 +375,7 @@ public int getBoundsDimension() { } } - protected ArrayList getParameters(){ + protected ArrayList getParameters() { return parameters; } From b097d62d4475b9995045d6eb78b7619d3610453d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 9 Jun 2022 15:05:57 -0400 Subject: [PATCH 074/196] want to stay close to to current position, not random one --- .../inference/operators/ConvexSpaceRandomWalkOperator.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java index d4510b5a68..3e4fcb46c0 100644 --- a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -29,7 +29,7 @@ public double doOperation() { double oneMinus = 1.0 - t; for (int i = 0; i < values.length; i++) { - sample[i] = values[i] * t + sample[i] * oneMinus; + sample[i] = values[i] * oneMinus + sample[i] * t; } parameter.setAllParameterValuesQuietly(sample); @@ -64,5 +64,10 @@ public String getOperatorName() { return CONVEX_RW; } + @Override + public Parameter getParameter() { + return parameter; + } + } From f945259f0f97a339ed8d41af54947a200f81bd0d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 10 Jun 2022 10:52:51 -0400 Subject: [PATCH 075/196] a couple extra methods + fixing return types in SymmetricMatrix --- src/dr/math/matrixAlgebra/SymmetricMatrix.java | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/dr/math/matrixAlgebra/SymmetricMatrix.java b/src/dr/math/matrixAlgebra/SymmetricMatrix.java index 27eeca4424..a8f533fb45 100644 --- a/src/dr/math/matrixAlgebra/SymmetricMatrix.java +++ b/src/dr/math/matrixAlgebra/SymmetricMatrix.java @@ -107,9 +107,13 @@ public static SymmetricMatrix compoundSymmetricMatrix(double[] diag, double[] of * if n <= 0 */ public static SymmetricMatrix compoundCorrelationSymmetricMatrix(double[] offdiag, int n) { - if (n <= 0) - throw new NegativeArraySizeException( - "Requested matrix size: " + n); + return compoundSymmetricMatrix(1.0, offdiag, n); + } + + public static SymmetricMatrix compoundSymmetricMatrix(double diagonal, double[] offdiag, int n) { + if (n <= 0) + throw new NegativeArraySizeException( + "Requested matrix size: " + n); assert n * (n - 1) / 2 == offdiag.length : "Requested matrix size: " + n + " doesn't match off diagonal array size: " + offdiag.length; @@ -117,7 +121,7 @@ public static SymmetricMatrix compoundCorrelationSymmetricMatrix(double[] offdia double[][] a = new double[n][n]; int k = 0; for (int i = 0; i < n; i++) { - a[i][i] = 1.0; + a[i][i] = diagonal; for (int j = i + 1; j < n; j++) { a[i][j] = a[j][i] = offdiag[k]; k++; @@ -221,11 +225,11 @@ public static SymmetricMatrix identityMatrix(int n) { } /** - * @return Matrix inverse of the receiver. + * @return SymmetricMatrix inverse of the receiver. * @throws java.lang.ArithmeticException if the receiver is * a singular matrix. */ - public Matrix inverse() throws ArithmeticException { + public SymmetricMatrix inverse() throws ArithmeticException { return rows() < lupCRLCriticalDimension ? new SymmetricMatrix( (new LUPDecomposition(this)).inverseMatrixComponents()) @@ -315,7 +319,7 @@ public Matrix product(double a) { * the receivers are not equal to the number of rows * of the supplied matrix. */ - public SymmetricMatrix product(SymmetricMatrix a) throws IllegalDimension { + public Matrix product(SymmetricMatrix a) throws IllegalDimension { return new SymmetricMatrix(productComponents(a)); } From 451670dc72b7aeb3979d2c22e012f11fdf0f3064 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 10 Jun 2022 10:54:46 -0400 Subject: [PATCH 076/196] proper hastings ratio for ConvexSpaceRnadomWalkOperator --- .../ConvexSpaceRandomWalkOperator.java | 25 ++++++++++++++++--- .../ConvexSpaceRandomWalkOperatorParser.java | 11 +++++--- .../ConvexSpaceRandomGenerator.java | 21 ++++++++++++++++ 3 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 src/dr/math/distributions/ConvexSpaceRandomGenerator.java diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java index 3e4fcb46c0..1851036f2d 100644 --- a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -1,18 +1,19 @@ package dr.inference.operators; import dr.inference.model.Parameter; -import dr.math.distributions.RandomGenerator; +import dr.math.distributions.ConvexSpaceRandomGenerator; import jebl.math.Random; public class ConvexSpaceRandomWalkOperator extends AbstractAdaptableOperator { private double window; - private final RandomGenerator generator; + private final ConvexSpaceRandomGenerator generator; private final Parameter parameter; public static final String CONVEX_RW = "convexSpaceRandomWalkOperator"; public static final String WINDOW_SIZE = "relativeWindowSize"; - public ConvexSpaceRandomWalkOperator(Parameter parameter, RandomGenerator generator, double window, double weight) { + public ConvexSpaceRandomWalkOperator(Parameter parameter, ConvexSpaceRandomGenerator generator, + double window, double weight) { setWeight(weight); this.parameter = parameter; @@ -25,6 +26,9 @@ public ConvexSpaceRandomWalkOperator(Parameter parameter, RandomGenerator genera public double doOperation() { double[] sample = (double[]) generator.nextRandom(); double[] values = parameter.getParameterValues(); + + ConvexSpaceRandomGenerator.LineThroughPoints distances = generator.distanceToEdge(values, sample); + double t = window * Random.nextDouble(); double oneMinus = 1.0 - t; @@ -34,7 +38,20 @@ public double doOperation() { parameter.setAllParameterValuesQuietly(sample); parameter.fireParameterChangedEvent(); - return 0.0; //TODO: need to check that the prior is uniform + + double tForward = t / (distances.forwardDistance * window); + double forwardLogDensity = uniformProductLogPdf(tForward); + + double backWardDistance = distances.backwardDistance + t; + double tBackward = t / (backWardDistance * window); + double backwardLogDensity = uniformProductLogPdf(tBackward); + + return backwardLogDensity - forwardLogDensity; + } + + private double uniformProductLogPdf(double t) { + double density = -Math.log(t); + return Math.log(density); } diff --git a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java index 3c3d8a87c5..72c7c40b8a 100644 --- a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java +++ b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java @@ -3,7 +3,7 @@ import dr.inference.model.Parameter; import dr.inference.operators.ConvexSpaceRandomWalkOperator; import dr.inference.operators.MCMCOperator; -import dr.math.distributions.RandomGenerator; +import dr.math.distributions.ConvexSpaceRandomGenerator; import dr.xml.*; public class ConvexSpaceRandomWalkOperatorParser extends AbstractXMLObjectParser { @@ -12,7 +12,12 @@ public class ConvexSpaceRandomWalkOperatorParser extends AbstractXMLObjectParser @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { Parameter parameter = (Parameter) xo.getChild(Parameter.class); - RandomGenerator generator = (RandomGenerator) xo.getChild(RandomGenerator.class); + ConvexSpaceRandomGenerator generator = + (ConvexSpaceRandomGenerator) xo.getChild(ConvexSpaceRandomGenerator.class); + + if (!generator.isUniform()) { + throw new XMLParseException("sample distribution must be uniform over its support"); + } double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT); @@ -28,7 +33,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ new ElementRule(Parameter.class), - new ElementRule(RandomGenerator.class), + new ElementRule(ConvexSpaceRandomGenerator.class), AttributeRule.newDoubleRule(MCMCOperator.WEIGHT), AttributeRule.newDoubleRule(ConvexSpaceRandomWalkOperator.WINDOW_SIZE, true) }; diff --git a/src/dr/math/distributions/ConvexSpaceRandomGenerator.java b/src/dr/math/distributions/ConvexSpaceRandomGenerator.java new file mode 100644 index 0000000000..51b947586c --- /dev/null +++ b/src/dr/math/distributions/ConvexSpaceRandomGenerator.java @@ -0,0 +1,21 @@ +package dr.math.distributions; + +public interface ConvexSpaceRandomGenerator extends RandomGenerator { + + LineThroughPoints distanceToEdge(double[] origin, double[] draw); + + boolean isUniform(); + + class LineThroughPoints { + public final double forwardDistance; + public final double backwardDistance; + public final double totalDistance; + + public LineThroughPoints(double forwardDistance, double backwardDistance) { + this.forwardDistance = forwardDistance; + this.backwardDistance = backwardDistance; + this.totalDistance = forwardDistance + backwardDistance; + } + } + +} From 75934e39794fbe551228adb3b84e0e9b2331bfca Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 10 Jun 2022 11:21:12 -0400 Subject: [PATCH 077/196] 'length' of line from point to edge of space --- ...lationWithStructuralZerosDistribution.java | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java index 0f62207580..be54d0c9be 100644 --- a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java +++ b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java @@ -1,12 +1,20 @@ package dr.math.distributions; +import dr.evomodel.substmodel.ColtEigenSystem; +import dr.evomodel.substmodel.EigenDecomposition; import dr.math.MathUtils; +import dr.math.matrixAlgebra.IllegalDimension; +import dr.math.matrixAlgebra.Matrix; +import dr.math.matrixAlgebra.SymmetricMatrix; import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; import java.util.ArrayList; -public class LKJCorrelationWithStructuralZerosDistribution extends LKJCorrelationDistribution implements RandomGenerator { +import static dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix; +import static dr.math.matrixAlgebra.SymmetricMatrix.compoundSymmetricMatrix; + +public class LKJCorrelationWithStructuralZerosDistribution extends LKJCorrelationDistribution implements ConvexSpaceRandomGenerator { private final int[] blockAssignments; @@ -123,4 +131,46 @@ public double[] nextRandom() { public double logPdf(Object x) { return logPdf((double[]) x); } + + @Override + public LineThroughPoints distanceToEdge(double[] origin, double[] draw) { + double[] x = new double[origin.length]; + for (int i = 0; i < origin.length; i++) { + x[i] = origin[i] - draw[i]; + } + + SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); + SymmetricMatrix X = compoundSymmetricMatrix(0.0, x, dim); + + SymmetricMatrix Xinv = X.inverse(); + final Matrix Z; + + try { + Z = Y.product(Xinv); + } catch (IllegalDimension illegalDimension) { + throw new RuntimeException("illegal dimensions"); + } + + ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); + EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need smallest magnitude eigenvalues + double[] values = decomposition.getEigenValues(); + + double minNegative = Double.NEGATIVE_INFINITY; + double minPositive = Double.POSITIVE_INFINITY; + for (int i = 0; i < values.length; i++) { + double value = values[i]; + if (value < 0 && value > minNegative) { + minNegative = value; + } else if (value >= 0 & value < minPositive) { + minPositive = value; + } + } + + return new ConvexSpaceRandomGenerator.LineThroughPoints(minPositive, -minNegative); + } + + @Override + public boolean isUniform() { + return shape == 1; + } } From ca0b17c6bfc1c1218a062dedd1c8458553e7f13d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 10 Jun 2022 17:05:38 -0400 Subject: [PATCH 078/196] bug fix --- .../LKJCorrelationWithStructuralZerosDistribution.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java index be54d0c9be..cfa79e5f14 100644 --- a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java +++ b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java @@ -1,5 +1,6 @@ package dr.math.distributions; +import dr.app.bss.Utils; import dr.evomodel.substmodel.ColtEigenSystem; import dr.evomodel.substmodel.EigenDecomposition; import dr.math.MathUtils; @@ -45,7 +46,7 @@ public double[] nextRandom() { for (int col = row + 1; col < dim; col++) { if (blockAssignments[row] == 0 || blockAssignments[row] != blockAssignments[col]) { - int diag = row - col; + int diag = col - row; double alpha = shape + 0.5 * (dim - 1 - diag); double beta = MathUtils.nextBeta(alpha, alpha); beta *= 2; @@ -105,7 +106,7 @@ public double[] nextRandom() { } double d = (1 - r1tRinvr1) * (1 - r2tRinvr2); - double c = r1tRinvr2 + partial.get(row, col) * d; + double c = r1tRinvr2 + partial.get(row, col) * Math.sqrt(d); partial.set(row, col, c); partial.set(col, row, c); } From 31b5505f1f23e2e7ec6e5d5cf2c3955dc7b7df25 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 13 Jun 2022 15:37:18 -0700 Subject: [PATCH 079/196] messy attempt at slightly modified sampler --- .../ConvexSpaceRandomWalkOperator.java | 50 ++++++++++++----- .../ConvexSpaceRandomWalkOperatorParser.java | 17 +++++- ...lationWithStructuralZerosDistribution.java | 53 +++++++++++++++---- 3 files changed, 95 insertions(+), 25 deletions(-) diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java index 1851036f2d..080e8441b1 100644 --- a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -1,6 +1,7 @@ package dr.inference.operators; import dr.inference.model.Parameter; +import dr.math.MathUtils; import dr.math.distributions.ConvexSpaceRandomGenerator; import jebl.math.Random; @@ -8,14 +9,17 @@ public class ConvexSpaceRandomWalkOperator extends AbstractAdaptableOperator { private double window; private final ConvexSpaceRandomGenerator generator; private final Parameter parameter; + private final Parameter updateIndex; public static final String CONVEX_RW = "convexSpaceRandomWalkOperator"; public static final String WINDOW_SIZE = "relativeWindowSize"; public ConvexSpaceRandomWalkOperator(Parameter parameter, ConvexSpaceRandomGenerator generator, + Parameter updateIndex, double window, double weight) { setWeight(weight); + this.updateIndex = updateIndex; this.parameter = parameter; this.generator = generator; this.window = window; @@ -24,35 +28,53 @@ public ConvexSpaceRandomWalkOperator(Parameter parameter, ConvexSpaceRandomGener @Override public double doOperation() { - double[] sample = (double[]) generator.nextRandom(); +// double[] sample = (double[]) generator.nextRandom(); double[] values = parameter.getParameterValues(); + double[] sample = new double[values.length]; + double sum = 0; + for (int i = 0; i < values.length; i++) { + if (updateIndex == null || updateIndex.getParameterValue(i) == 1) { + sample[i] = MathUtils.nextGaussian(); + sum += sample[i] * sample[i]; + } + } + double norm = Math.sqrt(sum); + for (int i = 0; i < values.length; i++) { + sample[i] = sample[i] / norm; + } + ConvexSpaceRandomGenerator.LineThroughPoints distances = generator.distanceToEdge(values, sample); +// double u1 = Random.nextDouble() * distances.forwardDistance; +// for (int i = 0; i < values.length; i++) { +// sample[i] = values[i] + (sample[i] - values[i]) * u1; +// } + double t = window * Random.nextDouble(); - double oneMinus = 1.0 - t; + t = RandomWalkOperator.reflectValue(t, -distances.backwardDistance, distances.forwardDistance); for (int i = 0; i < values.length; i++) { - sample[i] = values[i] * oneMinus + sample[i] * t; + sample[i] = values[i] - sample[i] * t; } parameter.setAllParameterValuesQuietly(sample); parameter.fireParameterChangedEvent(); - double tForward = t / (distances.forwardDistance * window); - double forwardLogDensity = uniformProductLogPdf(tForward); +// double tForward = t / (distances.forwardDistance * window); +// double forwardLogDensity = uniformProductLogPdf(tForward); +// +// double backWardDistance = distances.backwardDistance + t; +// double tBackward = t / (backWardDistance * window); +// double backwardLogDensity = uniformProductLogPdf(tBackward); - double backWardDistance = distances.backwardDistance + t; - double tBackward = t / (backWardDistance * window); - double backwardLogDensity = uniformProductLogPdf(tBackward); - - return backwardLogDensity - forwardLogDensity; + return 0; } - private double uniformProductLogPdf(double t) { - double density = -Math.log(t); - return Math.log(density); - } +// private double uniformProductLogPdf(double t) { +// double density = -Math.log(t); +// return Math.log(density); +// } @Override diff --git a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java index 72c7c40b8a..42a178eca6 100644 --- a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java +++ b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java @@ -3,6 +3,7 @@ import dr.inference.model.Parameter; import dr.inference.operators.ConvexSpaceRandomWalkOperator; import dr.inference.operators.MCMCOperator; +import dr.inference.operators.RandomWalkOperator; import dr.math.distributions.ConvexSpaceRandomGenerator; import dr.xml.*; @@ -26,7 +27,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { throw new XMLParseException(ConvexSpaceRandomWalkOperator.WINDOW_SIZE + " must be between 0 and 1"); } - return new ConvexSpaceRandomWalkOperator(parameter, generator, windowSize, weight); + final Parameter updateIndex; + + if (xo.hasChildNamed(RandomWalkOperatorParser.UPDATE_INDEX)) { + XMLObject cxo = xo.getChild(RandomWalkOperatorParser.UPDATE_INDEX); + updateIndex = (Parameter) cxo.getChild(Parameter.class); + } else { + updateIndex = null; + } + + return new ConvexSpaceRandomWalkOperator(parameter, generator, updateIndex, windowSize, weight); } @Override @@ -34,6 +44,11 @@ public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ new ElementRule(Parameter.class), new ElementRule(ConvexSpaceRandomGenerator.class), + new ElementRule(RandomWalkOperatorParser.UPDATE_INDEX, + new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }, true + ), AttributeRule.newDoubleRule(MCMCOperator.WEIGHT), AttributeRule.newDoubleRule(ConvexSpaceRandomWalkOperator.WINDOW_SIZE, true) }; diff --git a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java index cfa79e5f14..375968aaa4 100644 --- a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java +++ b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java @@ -137,17 +137,18 @@ public double logPdf(Object x) { public LineThroughPoints distanceToEdge(double[] origin, double[] draw) { double[] x = new double[origin.length]; for (int i = 0; i < origin.length; i++) { - x[i] = origin[i] - draw[i]; + x[i] = draw[i]; } SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); SymmetricMatrix X = compoundSymmetricMatrix(0.0, x, dim); - SymmetricMatrix Xinv = X.inverse(); +// SymmetricMatrix Xinv = X.inverse(); + SymmetricMatrix Yinv = Y.inverse(); final Matrix Z; try { - Z = Y.product(Xinv); + Z = Yinv.product(X); } catch (IllegalDimension illegalDimension) { throw new RuntimeException("illegal dimensions"); } @@ -156,18 +157,50 @@ public LineThroughPoints distanceToEdge(double[] origin, double[] draw) { EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need smallest magnitude eigenvalues double[] values = decomposition.getEigenValues(); - double minNegative = Double.NEGATIVE_INFINITY; - double minPositive = Double.POSITIVE_INFINITY; + double maxNegative = 0; + double maxPositive = 0; for (int i = 0; i < values.length; i++) { double value = values[i]; - if (value < 0 && value > minNegative) { - minNegative = value; - } else if (value >= 0 & value < minPositive) { - minPositive = value; + if (value < 0 && value < maxNegative) { + maxNegative = value; + } else if (value >= 0 & value > maxPositive) { + maxPositive = value; } } - return new ConvexSpaceRandomGenerator.LineThroughPoints(minPositive, -minNegative); + if (DEBUG) { + System.out.print("Eigenvalues: "); + Utils.printArray(values); + + Matrix S = new SymmetricMatrix(dim, dim); + Matrix T = new SymmetricMatrix(dim, dim); + for (int i = 0; i < dim; i++) { + S.set(i, i, 1); + T.set(i, i, 1); + for (int j = (i + 1); j < dim; j++) { + double y = Y.toComponents()[i][j]; + double z = X.toComponents()[i][j]; + double valueS = y - z / maxNegative; + double valueT = y - z / maxPositive; + S.set(i, j, valueS); + S.set(j, i, valueS); + T.set(i, j, valueT); + T.set(j, i, valueT); + } + } + try { + System.out.println("neg: \n\tt = " + maxNegative); + System.out.println("\tdet = " + S.determinant()); + System.out.println(S); + System.out.println("pos: \n\tt = " + maxPositive); + System.out.println("\tdet = " + T.determinant()); + System.out.println(T); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + } + } + + return new ConvexSpaceRandomGenerator.LineThroughPoints(1 / maxPositive, -1 / maxNegative); } @Override From a860dda536563d55c92b97c6e5be6bc78a960e7a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 13 Jun 2022 15:59:55 -0700 Subject: [PATCH 080/196] informative error message --- src/dr/math/matrixAlgebra/WrappedMatrix.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dr/math/matrixAlgebra/WrappedMatrix.java b/src/dr/math/matrixAlgebra/WrappedMatrix.java index 367594188a..644142ee2b 100644 --- a/src/dr/math/matrixAlgebra/WrappedMatrix.java +++ b/src/dr/math/matrixAlgebra/WrappedMatrix.java @@ -492,7 +492,10 @@ public static WrappedMatrix.WrappedUpperTriangularMatrix fillDiagonal(double[] x sum += temp * temp; } if (sum > 1.0) { - assert (Math.abs(sum - 1.0) < 1E-6); + if (Math.abs(sum - 1.0) > 1E-6) { + throw new RuntimeException("Values are not consistent with the cholesky decomposition of " + + "a correlation matrix. Sum of squared values must be less than 1 (got " + sum + ")"); + } sum = 1.0; } W.set(j, j, Math.sqrt(1 - sum)); From 4bb0858f791a99b83b6c0ddd07fc24b221b045c0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 14 Jun 2022 16:48:28 -0700 Subject: [PATCH 081/196] really badly coded (and buggy) adaptive metropolis attempt. promist to fix, but need to do other things so wanted to commit --- .../ConvexSpaceRandomWalkOperator.java | 164 +++++++++++++++++- 1 file changed, 155 insertions(+), 9 deletions(-) diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java index 080e8441b1..d8363aee64 100644 --- a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -1,15 +1,37 @@ package dr.inference.operators; import dr.inference.model.Parameter; -import dr.math.MathUtils; import dr.math.distributions.ConvexSpaceRandomGenerator; +import dr.math.distributions.MultivariateNormalDistribution; +import dr.math.matrixAlgebra.CholeskyDecomposition; import jebl.math.Random; +import org.ejml.data.DenseMatrix64F; + +import java.util.ArrayList; public class ConvexSpaceRandomWalkOperator extends AbstractAdaptableOperator { + + private static final boolean DEBUG = true; + private static final double corDeflator = 0.01; + private final ArrayList sampleList = new ArrayList<>(); + + private double window; private final ConvexSpaceRandomGenerator generator; private final Parameter parameter; private final Parameter updateIndex; + private final boolean ADAPTIVE_COVARIANCE = true; + private final int burnin = 50; + private double[] mean; + private double[] oldMean; + private final DenseMatrix64F cov; + private final int dim; + private final int varDim; + private int iterations = 0; + private int updates = 0; + private int every = 10; + private final ArrayList varInds = new ArrayList<>(); + private double[][] cholesky; public static final String CONVEX_RW = "convexSpaceRandomWalkOperator"; public static final String WINDOW_SIZE = "relativeWindowSize"; @@ -23,25 +45,149 @@ public ConvexSpaceRandomWalkOperator(Parameter parameter, ConvexSpaceRandomGener this.parameter = parameter; this.generator = generator; this.window = window; + + this.dim = parameter.getDimension(); + for (int i = 0; i < dim; i++) { + if (updateIndex == null || updateIndex.getParameterValue(i) == 1) { + varInds.add(i); + } + } + this.varDim = varInds.size(); + + this.cov = new DenseMatrix64F(varDim, varDim); + for (int i = 0; i < varDim; i++) { + cov.set(i, i, 1); + } + cholesky = CholeskyDecomposition.execute(cov.getData(), 0, varDim); + } @Override public double doOperation() { -// double[] sample = (double[]) generator.nextRandom(); + iterations++; double[] values = parameter.getParameterValues(); + double[] varValues = new double[varDim]; + for (int i = 0; i < varDim; i++) { + varValues[i] = values[varInds.get(i)]; + } + + if (ADAPTIVE_COVARIANCE) { + + + if (iterations == burnin) { + sampleList.add(varValues); + mean = varValues; + updates++; + } else if (iterations > burnin && iterations % every == 0) { + sampleList.add(varValues); + updates++; + oldMean = mean; + + for (int i = 0; i < varDim; i++) { + mean[i] = ((updates - 1) * oldMean[i] + varValues[i]) / updates; + } + + for (int i = 0; i < varDim; i++) { + for (int j = i; j < varDim; j++) { + double value = (updates - 1) * (cov.get(i, j) + oldMean[i] * oldMean[j]); + value += varValues[i] * varValues[j] - mean[i] * mean[j]; + value /= updates; + cov.set(i, j, value); + cov.set(j, i, value); + } + } + + DenseMatrix64F cor = new DenseMatrix64F(varDim, varDim); + for (int i = 0; i < varDim; i++) { + for (int j = 0; j < varDim; j++) { + cor.set(i, j, cov.get(i, j) / Math.sqrt(cov.get(i, i) * cov.get(j, j))); + if (i == j) { + cor.set(i, j, cor.get(i, j) + corDeflator); + } + } + } + + if (DEBUG) { + + DenseMatrix64F testCov = new DenseMatrix64F(varDim, varDim); + double[] testMean = new double[varDim]; + + for (int i = 0; i < updates; i++) { + double[] valuesi = sampleList.get(i); + for (int j = 0; j < varDim; j++) { + testMean[j] += valuesi[j]; + testCov.set(j, j, testCov.get(j, j) + valuesi[j] * valuesi[j]); + for (int k = (j + 1); k < varDim; k++) { + testCov.set(j, k, testCov.get(j, k) + valuesi[j] * valuesi[k]); + } + } + } + +// System.out.print("Sum squares:"); +// System.out.print(testCov); + + for (int i = 0; i < varDim; i++) { + testMean[i] /= updates; + } + +// System.out.print("Mean: "); +// Utils.printArray(testMean); + + for (int i = 0; i < varDim; i++) { + testCov.set(i, i, testCov.get(i, i) / updates - testMean[i] * testMean[i]); + for (int j = (i + 1); j < varDim; j++) { + testCov.set(i, j, testCov.get(i, j) / updates - testMean[i] * testMean[j]); + testCov.set(j, i, testCov.get(i, j)); + } + } + + + DenseMatrix64F testCor = new DenseMatrix64F(varDim, varDim); + for (int i = 0; i < varDim; i++) { + for (int j = 0; j < varDim; j++) { + testCor.set(i, j, testCov.get(i, j) / Math.sqrt(testCov.get(i, i) * testCov.get(j, j))); + if (i == j) { + testCor.set(i, j, testCor.get(i, j) + corDeflator); + } + } + } + + +// System.out.println("Cov:"); +// System.out.println(testCov); +// System.out.println(cov); +// +// System.out.println("Cor"); +// System.out.print(testCor); +// System.out.println(cor); +// System.out.println(); + + System.arraycopy(testCor.data, 0, cor.data, 0, cor.data.length); //TODO: remove after fixing + } + + + cholesky = CholeskyDecomposition.execute(cor.getData(), 0, varDim); + } + } + +// double[] sample = (double[]) generator.nextRandom(); double[] sample = new double[values.length]; + double[] varSample = MultivariateNormalDistribution.nextMultivariateNormalCholesky(new double[varDim], cholesky); + double sum = 0; - for (int i = 0; i < values.length; i++) { - if (updateIndex == null || updateIndex.getParameterValue(i) == 1) { - sample[i] = MathUtils.nextGaussian(); - sum += sample[i] * sample[i]; - } + for (int i = 0; i < varDim; i++) { + sum += varSample[i] * varSample[i]; } + double norm = Math.sqrt(sum); - for (int i = 0; i < values.length; i++) { - sample[i] = sample[i] / norm; + for (int i = 0; i < varDim; i++) { + varSample[i] = varSample[i] / norm; + } + + for (int i = 0; i < varDim; i++) { + sample[varInds.get(i)] = varSample[i]; } ConvexSpaceRandomGenerator.LineThroughPoints distances = generator.distanceToEdge(values, sample); From 4cbe1348f42ce77c352197d1ceb598bd493cd4e8 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 14 Jun 2022 18:55:09 -0700 Subject: [PATCH 082/196] untested Blomberg's K --- .../inference/model/BlombergKStatistic.java | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/dr/inference/model/BlombergKStatistic.java diff --git a/src/dr/inference/model/BlombergKStatistic.java b/src/dr/inference/model/BlombergKStatistic.java new file mode 100644 index 0000000000..2bed0db640 --- /dev/null +++ b/src/dr/inference/model/BlombergKStatistic.java @@ -0,0 +1,155 @@ +package dr.inference.model; + +import dr.evolution.tree.TreeTrait; +import dr.evomodel.tree.TreeModel; +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.continuous.MultivariateTraitDebugUtilities; +import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator; +import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; +import dr.math.matrixAlgebra.*; + +import java.util.Arrays; + + +public class BlombergKStatistic extends Statistic.Abstract implements ModelListener { + private final TreeDataLikelihood traitLikelihood; + private final TreeModel tree; + private final TreeTrait treeTrait; + private boolean needToUpdateTree = true; + private final int traitDim; + private Matrix Linv; + private final int treeDim; + private double mseExpected; + private final ContinuousDiffusionIntegrator integrator; + private final ContinuousDataLikelihoodDelegate delegate; + private final double[] k; + + public BlombergKStatistic(TreeDataLikelihood traitLikelihood, String traitName) { + this.traitLikelihood = traitLikelihood; + this.tree = (TreeModel) traitLikelihood.getTree(); + tree.addModelListener(this); + + this.treeTrait = traitLikelihood.getTreeTrait(traitName); + this.traitDim = traitLikelihood.getDataLikelihoodDelegate().getTraitDim(); + this.treeDim = tree.getTaxonCount(); + this.delegate = (ContinuousDataLikelihoodDelegate) traitLikelihood.getDataLikelihoodDelegate(); + this.integrator = delegate.getIntegrator(); + this.k = new double[traitDim]; + } + + + @Override + public int getDimension() { + return traitDim; + } + + @Override + public double getStatisticValue(int dim) { + if (dim == 0) { + computeStatistics(); + } + return k[dim]; + } + + + public double computeStatistics() { + if (needToUpdateTree) { + double[][] treeStructure = MultivariateTraitDebugUtilities.getTreeVariance(tree, + traitLikelihood.getBranchRateModel(), + 1.0, Double.POSITIVE_INFINITY); //TODO: make sure order is right + + + SymmetricMatrix V = new SymmetricMatrix(treeStructure); + Matrix L; + try { + CholeskyDecomposition chol = new CholeskyDecomposition(V); + L = new Matrix(chol.getL()); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + + Linv = L.inverse(); //TODO: need some triangular-matix specific inverse + double[] ones = new double[traitDim]; + Arrays.fill(ones, 1); + + Vector l; + + try { + l = Linv.product(new Vector(ones)); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + + double sumInverse = 0; + for (int i = 0; i < traitDim; i++) { + sumInverse += l.component(i) * l.component(i); + } + + double trace = 0; + for (int i = 0; i < treeDim; i++) { + trace += treeStructure[i][i]; + } + + mseExpected = (trace - treeDim / sumInverse) / (treeDim - 1); + needToUpdateTree = false; + } + + + double[] treeTraits = (double[]) treeTrait.getTrait(tree, tree.getRoot()); + + PrecisionType type = delegate.getDataModel().getPrecisionType(); + + double[] partial = new double[type.getPartialsDimension(traitDim)]; + + integrator.getPostOrderPartial(delegate.getActiveNodeIndex(tree.getRoot().getNumber()), partial); + double mean[] = new double[traitDim]; + System.arraycopy(partial, type.getMeanOffset(traitDim), mean, 0, traitDim); + + + double[] thisTrait = new double[treeDim]; + + for (int trait = 0; trait < traitDim; trait++) { + for (int taxon = 0; taxon < treeDim; taxon++) { + thisTrait[taxon] = treeTraits[taxon * traitDim + trait]; + } + + Vector contrasts; + + try { + contrasts = Linv.product(new Vector(thisTrait)); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + + double ssTrait = sumSquareDiff(thisTrait, mean[trait]); + double ssContrasts = sumSquareDiff(contrasts.toComponents(), mean[trait]); + + k[trait] = (ssTrait / ssContrasts) / mseExpected; + } + + return 0; + } + + private double sumSquareDiff(double[] x, double a) { + double ss = 0; + for (int i = 0; i < x.length; i++) { + double diff = x[i] - a; + ss += diff * diff; + } + return ss; + } + + @Override + public void modelChangedEvent(Model model, Object object, int index) { + needToUpdateTree = true; + } + + @Override + public void modelRestored(Model model) { + needToUpdateTree = true; + } +} From d73c4e07cd4fd99a2a606f761f568fde1dfae32e Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 15 Jun 2022 17:47:31 -0700 Subject: [PATCH 083/196] working Blomberg's K statistic (could definitely be made more efficient) --- ci/TestXML/testBlombergKStatistic.xml | 105 ++++++++++++++++++ .../app/beast/development_parsers.properties | 1 + .../inference/model/BlombergKStatistic.java | 30 ++--- .../model/BlombergKStatisticParser.java | 39 +++++++ 4 files changed, 162 insertions(+), 13 deletions(-) create mode 100644 ci/TestXML/testBlombergKStatistic.xml create mode 100644 src/dr/inferencexml/model/BlombergKStatisticParser.java diff --git a/ci/TestXML/testBlombergKStatistic.xml b/ci/TestXML/testBlombergKStatistic.xml new file mode 100644 index 0000000000..a67f8a655a --- /dev/null +++ b/ci/TestXML/testBlombergKStatistic.xml @@ -0,0 +1,105 @@ + + + + + 1.60411326 -0.09262507 + + + 0.02369208 1.06519200 + + + -1.4149805 -0.6334667 + + + -1.1346685 -0.2332645 + + + -1.3531916 0.8481068 + + + + + (C:0.3213803857,((B:0.1327716981,E:0.9665925966):0.7209950339,(A:0.5044368715,D:0.7394543779):0.9151846841):0.7891438203); + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check Blomberg's K statistic (1) + + + + + + 0.478642 + + + + + + Check Blomberg's K statistic (2) + + + + + + 1.216268 + + + + + diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 7e791c8605..a211a42adf 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -306,6 +306,7 @@ dr.inference.operators.rejection.RejectionOperator dr.inference.operators.rejection.DescendingAndSpacedCondition dr.evomodel.operators.ExtendedLatentLiabilityGibbsOperator dr.inference.model.FactorProportionStatistic +dr.inferencexml.model.BlombergKStatisticParser # Shrinkage dr.inference.model.MaskFromTree diff --git a/src/dr/inference/model/BlombergKStatistic.java b/src/dr/inference/model/BlombergKStatistic.java index 2bed0db640..e5cbf49a3d 100644 --- a/src/dr/inference/model/BlombergKStatistic.java +++ b/src/dr/inference/model/BlombergKStatistic.java @@ -13,6 +13,9 @@ public class BlombergKStatistic extends Statistic.Abstract implements ModelListener { + + public static final String BLOMBERGS_K = "blombergsK"; + private final TreeDataLikelihood traitLikelihood; private final TreeModel tree; private final TreeTrait treeTrait; @@ -20,7 +23,7 @@ public class BlombergKStatistic extends Statistic.Abstract implements ModelListe private final int traitDim; private Matrix Linv; private final int treeDim; - private double mseExpected; + private double expectedRatio; private final ContinuousDiffusionIntegrator integrator; private final ContinuousDataLikelihoodDelegate delegate; private final double[] k; @@ -58,6 +61,7 @@ public double computeStatistics() { double[][] treeStructure = MultivariateTraitDebugUtilities.getTreeVariance(tree, traitLikelihood.getBranchRateModel(), 1.0, Double.POSITIVE_INFINITY); //TODO: make sure order is right + //TODO: don't actually need to construct or invert this. can use Ho & Ane 2014 for all calculations SymmetricMatrix V = new SymmetricMatrix(treeStructure); @@ -70,8 +74,8 @@ public double computeStatistics() { throw new RuntimeException(); } - Linv = L.inverse(); //TODO: need some triangular-matix specific inverse - double[] ones = new double[traitDim]; + Linv = L.inverse().transpose(); + double[] ones = new double[treeDim]; Arrays.fill(ones, 1); Vector l; @@ -84,7 +88,7 @@ public double computeStatistics() { } double sumInverse = 0; - for (int i = 0; i < traitDim; i++) { + for (int i = 0; i < treeDim; i++) { sumInverse += l.component(i) * l.component(i); } @@ -93,12 +97,12 @@ public double computeStatistics() { trace += treeStructure[i][i]; } - mseExpected = (trace - treeDim / sumInverse) / (treeDim - 1); + expectedRatio = (trace - treeDim / sumInverse) / (treeDim - 1); needToUpdateTree = false; } - double[] treeTraits = (double[]) treeTrait.getTrait(tree, tree.getRoot()); + double[] treeTraits = (double[]) treeTrait.getTrait(tree, null); PrecisionType type = delegate.getDataModel().getPrecisionType(); @@ -113,9 +117,10 @@ public double computeStatistics() { for (int trait = 0; trait < traitDim; trait++) { for (int taxon = 0; taxon < treeDim; taxon++) { - thisTrait[taxon] = treeTraits[taxon * traitDim + trait]; + thisTrait[taxon] = treeTraits[taxon * traitDim + trait] - mean[trait]; } + Vector contrasts; try { @@ -125,20 +130,19 @@ public double computeStatistics() { throw new RuntimeException(); } - double ssTrait = sumSquareDiff(thisTrait, mean[trait]); - double ssContrasts = sumSquareDiff(contrasts.toComponents(), mean[trait]); + double mse0 = sumSquares(thisTrait); + double mse = sumSquares(contrasts.toComponents()); - k[trait] = (ssTrait / ssContrasts) / mseExpected; + k[trait] = (mse0 / mse) / expectedRatio; } return 0; } - private double sumSquareDiff(double[] x, double a) { + private double sumSquares(double[] x) { double ss = 0; for (int i = 0; i < x.length; i++) { - double diff = x[i] - a; - ss += diff * diff; + ss += x[i] * x[i]; } return ss; } diff --git a/src/dr/inferencexml/model/BlombergKStatisticParser.java b/src/dr/inferencexml/model/BlombergKStatisticParser.java new file mode 100644 index 0000000000..af26d96957 --- /dev/null +++ b/src/dr/inferencexml/model/BlombergKStatisticParser.java @@ -0,0 +1,39 @@ +package dr.inferencexml.model; + +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.inference.model.BlombergKStatistic; +import dr.xml.*; + +import static dr.evomodelxml.treelikelihood.TreeTraitParserUtilities.TRAIT_NAME; + +public class BlombergKStatisticParser extends AbstractXMLObjectParser { + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class); + String traitName = xo.getStringAttribute(TRAIT_NAME); + return new BlombergKStatistic(treeDataLikelihood, traitName); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[]{ + new ElementRule(TreeDataLikelihood.class), + AttributeRule.newStringRule(TRAIT_NAME) + }; + } + + @Override + public String getParserDescription() { + return "Blomberg's K statistic of phylogenetic signal"; + } + + @Override + public Class getReturnType() { + return BlombergKStatistic.class; + } + + @Override + public String getParserName() { + return BlombergKStatistic.BLOMBERGS_K; + } +} From 162b1768415ff70f118ccf07bbdc64fded0d296b Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 16 Jun 2022 15:53:03 -0700 Subject: [PATCH 084/196] starting to refactor bounded space class for use w/ HMC --- .../app/beast/development_parsers.properties | 2 +- src/dr/inference/model/BoundedSpace.java | 71 +++++++++++++++++++ .../model/GeneralParameterBounds.java | 56 --------------- .../TransformedParameterOperator.java | 8 +-- ...ransformedParameterRandomWalkOperator.java | 8 +-- ...ava => BoundedSpaceCorrelationParser.java} | 8 +-- .../TransformedParameterOperatorParser.java | 6 +- ...rmedParameterRandomWalkOperatorParser.java | 6 +- 8 files changed, 90 insertions(+), 75 deletions(-) create mode 100644 src/dr/inference/model/BoundedSpace.java delete mode 100644 src/dr/inference/model/GeneralParameterBounds.java rename src/dr/inferencexml/model/{CorrelationParameterBoundsParser.java => BoundedSpaceCorrelationParser.java} (74%) diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index a211a42adf..9a07c2106b 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -324,7 +324,7 @@ dr.inference.model.DeterminantStatistic dr.inference.model.MatrixDiagonalLogger dr.inference.operators.factorAnalysis.GeneralizedSampleConstraints dr.util.CorrelationToCholesky -dr.inferencexml.model.CorrelationParameterBoundsParser +dr.inferencexml.model.BoundedSpaceCorrelationParser dr.inferencexml.operators.TransformedParameterOperatorParser dr.inferencexml.operators.ConvexSpaceRandomWalkOperatorParser dr.inferencexml.distribution.LKJCorrelationWithStructuralZerosDistributionParser diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java new file mode 100644 index 0000000000..84ac24ac28 --- /dev/null +++ b/src/dr/inference/model/BoundedSpace.java @@ -0,0 +1,71 @@ +package dr.inference.model; + +import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon_D64; +import org.ejml.data.DenseMatrix64F; +import org.ejml.factory.DecompositionFactory; +import org.ejml.interfaces.decomposition.CholeskyDecomposition; +import org.ejml.ops.CommonOps; + +public interface BoundedSpace { + + boolean isWithinBounds(double[] values); + + + class Correlation implements BoundedSpace { + + private final int dim; + + public Correlation(int dim) { + this.dim = dim; + } + + + @Override + public boolean isWithinBounds(double[] x) { + + DenseMatrix64F C; + double[] values = new double[x.length]; + System.arraycopy(x, 0, values, 0, x.length); + + if (values.length == dim * dim) { + C = DenseMatrix64F.wrap(dim, dim, values); + for (int i = 0; i < dim; i++) { + if (C.get(i, i) != 1.0) { + return false; + } + } + + } else if (values.length == dim * (dim - 1) / 2) { + int ind = 0; + C = new DenseMatrix64F(dim, dim); + for (int i = 0; i < dim; i++) { + C.set(i, i, 1.0); + for (int j = i + 1; j < dim; j++) { + C.set(i, j, values[ind]); + C.set(j, i, values[ind]); + ind++; + } + } + } else { + throw new RuntimeException("incompatible dimensions"); + } + + + CholeskyDecomposition chol = DecompositionFactory.chol(dim, true); + boolean isDecomposable = chol.decompose(C); // in place decomposition + if (!isDecomposable) { + return false; + } + + for (int i = 0; i < dim; i++) { + if (C.get(i, i) <= 0) { + return false; + } + } + + return true; + } + + } + +} diff --git a/src/dr/inference/model/GeneralParameterBounds.java b/src/dr/inference/model/GeneralParameterBounds.java deleted file mode 100644 index 1e4af7a920..0000000000 --- a/src/dr/inference/model/GeneralParameterBounds.java +++ /dev/null @@ -1,56 +0,0 @@ -package dr.inference.model; - -import org.ejml.data.DenseMatrix64F; -import org.ejml.ops.CommonOps; - -public interface GeneralParameterBounds { - - boolean satisfiesBounds(Parameter parameter); - - - class CorrelationParameterBounds implements GeneralParameterBounds { - - private final int dim; - - public CorrelationParameterBounds(int dim) { - this.dim = dim; - } - - - @Override - public boolean satisfiesBounds(Parameter parameter) { - - DenseMatrix64F C; - double[] c = parameter.getParameterValues(); - - if (c.length == dim * dim) { - C = DenseMatrix64F.wrap(dim, dim, parameter.getParameterValues()); - for (int i = 0; i < dim; i++) { - if (C.get(i, i) != 1.0) { - return false; - } - } - - } else if (c.length == dim * (dim - 1) / 2) { - int ind = 0; - C = new DenseMatrix64F(dim, dim); - for (int i = 0; i < dim; i++) { - C.set(i, i, 1.0); - for (int j = i + 1; j < dim; j++) { - C.set(i, j, c[ind]); - C.set(j, i, c[ind]); - ind++; - } - } - } else { - throw new RuntimeException("incompatible dimensions"); - } - - - double det = CommonOps.det(C); - return det >= 0; // already checked if diagonals were 1 - } - - } - -} diff --git a/src/dr/inference/operators/TransformedParameterOperator.java b/src/dr/inference/operators/TransformedParameterOperator.java index 7be3eaff96..cd0003994f 100644 --- a/src/dr/inference/operators/TransformedParameterOperator.java +++ b/src/dr/inference/operators/TransformedParameterOperator.java @@ -1,6 +1,6 @@ package dr.inference.operators; -import dr.inference.model.GeneralParameterBounds; +import dr.inference.model.BoundedSpace; import dr.inference.model.Parameter; import dr.inference.model.TransformedParameter; @@ -9,10 +9,10 @@ public class TransformedParameterOperator extends AbstractAdaptableOperator { private final SimpleMCMCOperator subOperator; private final TransformedParameter parameter; private final boolean checkValid; - private final GeneralParameterBounds generalBounds; + private final BoundedSpace generalBounds; public static final String TRANSFORMED_OPERATOR = "transformedParameterOperator"; - public TransformedParameterOperator(SimpleMCMCOperator operator, GeneralParameterBounds generalBounds) { + public TransformedParameterOperator(SimpleMCMCOperator operator, BoundedSpace generalBounds) { this.subOperator = operator; setWeight(operator.getWeight()); @@ -70,7 +70,7 @@ public double doOperation() { if (checkValid) { // GH: below is sloppy, but best I could do without refactoring how Parameter handles bounds if (generalBounds == null && !parameter.isWithinBounds()) { return Double.NEGATIVE_INFINITY; - } else if (!generalBounds.satisfiesBounds(parameter)) { + } else if (!generalBounds.isWithinBounds(parameter.getParameterValues())) { return Double.NEGATIVE_INFINITY; } } diff --git a/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java b/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java index 0b363bbe38..f5e807e48c 100644 --- a/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java +++ b/src/dr/inference/operators/TransformedParameterRandomWalkOperator.java @@ -25,7 +25,7 @@ package dr.inference.operators; -import dr.inference.model.GeneralParameterBounds; +import dr.inference.model.BoundedSpace; import dr.inference.model.TransformedParameter; import dr.math.matrixAlgebra.Matrix; @@ -38,7 +38,7 @@ public class TransformedParameterRandomWalkOperator extends RandomWalkOperator { private static boolean DEBUG = false; private static boolean checkValid = true; - private final GeneralParameterBounds generalBounds; + private final BoundedSpace generalBounds; public TransformedParameterRandomWalkOperator(TransformedParameter parameter, double windowSize, BoundaryCondition bc, double weight, AdaptationMode mode) { @@ -56,7 +56,7 @@ public TransformedParameterRandomWalkOperator(TransformedParameter parameter, Ra this.generalBounds = null; //TODO: implement if needed } - public TransformedParameterRandomWalkOperator(RandomWalkOperator randomWalkOperator, GeneralParameterBounds bounds) { + public TransformedParameterRandomWalkOperator(RandomWalkOperator randomWalkOperator, BoundedSpace bounds) { super((TransformedParameter) randomWalkOperator.getParameter(), randomWalkOperator.getWindowSize(), randomWalkOperator.getBoundaryCondition(), @@ -86,7 +86,7 @@ public double doOperation() { if (checkValid) { // GH: below is sloppy, but best I could do without refactoring how Parameter handles bounds if (generalBounds == null && !parameter.isWithinBounds()) { return Double.NEGATIVE_INFINITY; - } else if (!generalBounds.satisfiesBounds(parameter)) { + } else if (!generalBounds.isWithinBounds(parameter.getParameterValues())) { return Double.NEGATIVE_INFINITY; } } diff --git a/src/dr/inferencexml/model/CorrelationParameterBoundsParser.java b/src/dr/inferencexml/model/BoundedSpaceCorrelationParser.java similarity index 74% rename from src/dr/inferencexml/model/CorrelationParameterBoundsParser.java rename to src/dr/inferencexml/model/BoundedSpaceCorrelationParser.java index 64d2068d69..cdec19191e 100644 --- a/src/dr/inferencexml/model/CorrelationParameterBoundsParser.java +++ b/src/dr/inferencexml/model/BoundedSpaceCorrelationParser.java @@ -1,9 +1,9 @@ package dr.inferencexml.model; -import dr.inference.model.GeneralParameterBounds; +import dr.inference.model.BoundedSpace; import dr.xml.*; -public class CorrelationParameterBoundsParser extends AbstractXMLObjectParser { +public class BoundedSpaceCorrelationParser extends AbstractXMLObjectParser { private static final String CORRELATION_BOUNDS = "correlationBounds"; private static final String DIMENSION = "dimension"; @@ -11,7 +11,7 @@ public class CorrelationParameterBoundsParser extends AbstractXMLObjectParser { @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { int dim = xo.getIntegerAttribute(DIMENSION); - return new GeneralParameterBounds.CorrelationParameterBounds(dim); + return new BoundedSpace.Correlation(dim); } @Override @@ -28,7 +28,7 @@ public String getParserDescription() { @Override public Class getReturnType() { - return GeneralParameterBounds.CorrelationParameterBounds.class; + return BoundedSpace.Correlation.class; } @Override diff --git a/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java b/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java index 244f049c28..fa3ed338ab 100644 --- a/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java +++ b/src/dr/inferencexml/operators/TransformedParameterOperatorParser.java @@ -1,6 +1,6 @@ package dr.inferencexml.operators; -import dr.inference.model.GeneralParameterBounds; +import dr.inference.model.BoundedSpace; import dr.inference.operators.SimpleMCMCOperator; import dr.inference.operators.TransformedParameterOperator; import dr.xml.*; @@ -13,7 +13,7 @@ public class TransformedParameterOperatorParser extends AbstractXMLObjectParser @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { SimpleMCMCOperator operator = (SimpleMCMCOperator) xo.getChild(SimpleMCMCOperator.class); - GeneralParameterBounds bounds = (GeneralParameterBounds) xo.getChild(GeneralParameterBounds.class); + BoundedSpace bounds = (BoundedSpace) xo.getChild(BoundedSpace.class); return new TransformedParameterOperator(operator, bounds); } @@ -21,7 +21,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ new ElementRule(SimpleMCMCOperator.class), - new ElementRule(GeneralParameterBounds.class, true) + new ElementRule(BoundedSpace.class, true) }; } diff --git a/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java index 802e4f0e5f..023e0097d2 100644 --- a/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java +++ b/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java @@ -25,7 +25,7 @@ package dr.inferencexml.operators; -import dr.inference.model.GeneralParameterBounds; +import dr.inference.model.BoundedSpace; import dr.inference.model.TransformedParameter; import dr.inference.operators.AdaptableMCMCOperator; import dr.inference.operators.MCMCOperator; @@ -48,7 +48,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } catch (XMLParseException e) { throw new XMLParseException("RandomWalkOperatorParser failled in TraansformedParameterRandomWalkOperator."); } - GeneralParameterBounds bounds = (GeneralParameterBounds) xo.getChild(GeneralParameterBounds.class); + BoundedSpace bounds = (BoundedSpace) xo.getChild(BoundedSpace.class); return new TransformedParameterRandomWalkOperator((RandomWalkOperator) randomWalk, bounds); } @@ -79,6 +79,6 @@ public XMLSyntaxRule[] getSyntaxRules() { }, true), new StringAttributeRule(BOUNDARY_CONDITION, null, RandomWalkOperator.BoundaryCondition.values(), true), new ElementRule(TransformedParameter.class), - new ElementRule(GeneralParameterBounds.class, true) + new ElementRule(BoundedSpace.class, true) }; } From b95c6b6a561eda55fe4becf405df47c7ab5d1826 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 16 Jun 2022 16:10:57 -0700 Subject: [PATCH 085/196] deleting unnecessary interface + moving some functionalty to 'BoundedSpace' --- src/dr/inference/model/BoundedSpace.java | 98 ++++++++++++++++++- .../ConvexSpaceRandomWalkOperator.java | 10 +- .../ConvexSpaceRandomWalkOperatorParser.java | 14 +-- .../ConvexSpaceRandomGenerator.java | 21 ---- ...lationWithStructuralZerosDistribution.java | 84 +--------------- 5 files changed, 107 insertions(+), 120 deletions(-) delete mode 100644 src/dr/math/distributions/ConvexSpaceRandomGenerator.java diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 84ac24ac28..8f957cf592 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -1,18 +1,44 @@ package dr.inference.model; -import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon_D64; +import dr.app.bss.Utils; +import dr.evomodel.substmodel.ColtEigenSystem; +import dr.evomodel.substmodel.EigenDecomposition; +import dr.math.matrixAlgebra.IllegalDimension; +import dr.math.matrixAlgebra.Matrix; +import dr.math.matrixAlgebra.SymmetricMatrix; import org.ejml.data.DenseMatrix64F; import org.ejml.factory.DecompositionFactory; import org.ejml.interfaces.decomposition.CholeskyDecomposition; -import org.ejml.ops.CommonOps; + +import static dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix; +import static dr.math.matrixAlgebra.SymmetricMatrix.compoundSymmetricMatrix; public interface BoundedSpace { boolean isWithinBounds(double[] values); + IntersectionDistances distancesToBoundary(double[] origin, double[] direction); + + default double forwardDistanceToBoundary(double[] origin, double[] direction) { + return distancesToBoundary(origin, direction).forwardDistance; + } + + class IntersectionDistances { + public final double forwardDistance; + public final double backwardDistance; + public final double totalDistance; + + public IntersectionDistances(double forwardDistance, double backwardDistance) { + this.forwardDistance = forwardDistance; + this.backwardDistance = backwardDistance; + this.totalDistance = forwardDistance + backwardDistance; + } + } + class Correlation implements BoundedSpace { + private static final boolean DEBUG = false; private final int dim; public Correlation(int dim) { @@ -66,6 +92,74 @@ public boolean isWithinBounds(double[] x) { return true; } + @Override + public IntersectionDistances distancesToBoundary(double[] origin, double[] direction) { + double[] x = new double[origin.length]; + System.arraycopy(direction, 0, x, 0, x.length); + + SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); + SymmetricMatrix X = compoundSymmetricMatrix(0.0, x, dim); + +// SymmetricMatrix Xinv = X.inverse(); + SymmetricMatrix Yinv = Y.inverse(); + final Matrix Z; + + try { + Z = Yinv.product(X); + } catch (IllegalDimension illegalDimension) { + throw new RuntimeException("illegal dimensions"); + } + + ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); + EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need smallest magnitude eigenvalues + double[] values = decomposition.getEigenValues(); + + double maxNegative = 0; + double maxPositive = 0; + for (int i = 0; i < values.length; i++) { + double value = values[i]; + if (value < 0 && value < maxNegative) { + maxNegative = value; + } else if (value >= 0 & value > maxPositive) { + maxPositive = value; + } + } + + if (DEBUG) { + System.out.print("Eigenvalues: "); + Utils.printArray(values); + + Matrix S = new SymmetricMatrix(dim, dim); + Matrix T = new SymmetricMatrix(dim, dim); + for (int i = 0; i < dim; i++) { + S.set(i, i, 1); + T.set(i, i, 1); + for (int j = (i + 1); j < dim; j++) { + double y = Y.toComponents()[i][j]; + double z = X.toComponents()[i][j]; + double valueS = y - z / maxNegative; + double valueT = y - z / maxPositive; + S.set(i, j, valueS); + S.set(j, i, valueS); + T.set(i, j, valueT); + T.set(j, i, valueT); + } + } + try { + System.out.println("neg: \n\tt = " + maxNegative); + System.out.println("\tdet = " + S.determinant()); + System.out.println(S); + System.out.println("pos: \n\tt = " + maxPositive); + System.out.println("\tdet = " + T.determinant()); + System.out.println(T); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + } + } + + return new IntersectionDistances(1 / maxPositive, -1 / maxNegative); + } + } } diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java index d8363aee64..29209ddc4f 100644 --- a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -1,7 +1,7 @@ package dr.inference.operators; +import dr.inference.model.BoundedSpace; import dr.inference.model.Parameter; -import dr.math.distributions.ConvexSpaceRandomGenerator; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.matrixAlgebra.CholeskyDecomposition; import jebl.math.Random; @@ -17,7 +17,7 @@ public class ConvexSpaceRandomWalkOperator extends AbstractAdaptableOperator { private double window; - private final ConvexSpaceRandomGenerator generator; + private final BoundedSpace space; private final Parameter parameter; private final Parameter updateIndex; private final boolean ADAPTIVE_COVARIANCE = true; @@ -36,14 +36,14 @@ public class ConvexSpaceRandomWalkOperator extends AbstractAdaptableOperator { public static final String CONVEX_RW = "convexSpaceRandomWalkOperator"; public static final String WINDOW_SIZE = "relativeWindowSize"; - public ConvexSpaceRandomWalkOperator(Parameter parameter, ConvexSpaceRandomGenerator generator, + public ConvexSpaceRandomWalkOperator(Parameter parameter, BoundedSpace space, Parameter updateIndex, double window, double weight) { setWeight(weight); this.updateIndex = updateIndex; this.parameter = parameter; - this.generator = generator; + this.space = space; this.window = window; this.dim = parameter.getDimension(); @@ -190,7 +190,7 @@ public double doOperation() { sample[varInds.get(i)] = varSample[i]; } - ConvexSpaceRandomGenerator.LineThroughPoints distances = generator.distanceToEdge(values, sample); + BoundedSpace.IntersectionDistances distances = space.distancesToBoundary(values, sample); // double u1 = Random.nextDouble() * distances.forwardDistance; // for (int i = 0; i < values.length; i++) { // sample[i] = values[i] + (sample[i] - values[i]) * u1; diff --git a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java index 42a178eca6..94734b6b91 100644 --- a/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java +++ b/src/dr/inferencexml/operators/ConvexSpaceRandomWalkOperatorParser.java @@ -1,10 +1,9 @@ package dr.inferencexml.operators; +import dr.inference.model.BoundedSpace; import dr.inference.model.Parameter; import dr.inference.operators.ConvexSpaceRandomWalkOperator; import dr.inference.operators.MCMCOperator; -import dr.inference.operators.RandomWalkOperator; -import dr.math.distributions.ConvexSpaceRandomGenerator; import dr.xml.*; public class ConvexSpaceRandomWalkOperatorParser extends AbstractXMLObjectParser { @@ -13,12 +12,9 @@ public class ConvexSpaceRandomWalkOperatorParser extends AbstractXMLObjectParser @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { Parameter parameter = (Parameter) xo.getChild(Parameter.class); - ConvexSpaceRandomGenerator generator = - (ConvexSpaceRandomGenerator) xo.getChild(ConvexSpaceRandomGenerator.class); + BoundedSpace space = + (BoundedSpace) xo.getChild(BoundedSpace.class); - if (!generator.isUniform()) { - throw new XMLParseException("sample distribution must be uniform over its support"); - } double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT); @@ -36,14 +32,14 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { updateIndex = null; } - return new ConvexSpaceRandomWalkOperator(parameter, generator, updateIndex, windowSize, weight); + return new ConvexSpaceRandomWalkOperator(parameter, space, updateIndex, windowSize, weight); } @Override public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ new ElementRule(Parameter.class), - new ElementRule(ConvexSpaceRandomGenerator.class), + new ElementRule(BoundedSpace.class), new ElementRule(RandomWalkOperatorParser.UPDATE_INDEX, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) diff --git a/src/dr/math/distributions/ConvexSpaceRandomGenerator.java b/src/dr/math/distributions/ConvexSpaceRandomGenerator.java deleted file mode 100644 index 51b947586c..0000000000 --- a/src/dr/math/distributions/ConvexSpaceRandomGenerator.java +++ /dev/null @@ -1,21 +0,0 @@ -package dr.math.distributions; - -public interface ConvexSpaceRandomGenerator extends RandomGenerator { - - LineThroughPoints distanceToEdge(double[] origin, double[] draw); - - boolean isUniform(); - - class LineThroughPoints { - public final double forwardDistance; - public final double backwardDistance; - public final double totalDistance; - - public LineThroughPoints(double forwardDistance, double backwardDistance) { - this.forwardDistance = forwardDistance; - this.backwardDistance = backwardDistance; - this.totalDistance = forwardDistance + backwardDistance; - } - } - -} diff --git a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java index 375968aaa4..2bdf271aee 100644 --- a/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java +++ b/src/dr/math/distributions/LKJCorrelationWithStructuralZerosDistribution.java @@ -1,21 +1,13 @@ package dr.math.distributions; -import dr.app.bss.Utils; -import dr.evomodel.substmodel.ColtEigenSystem; -import dr.evomodel.substmodel.EigenDecomposition; import dr.math.MathUtils; -import dr.math.matrixAlgebra.IllegalDimension; -import dr.math.matrixAlgebra.Matrix; -import dr.math.matrixAlgebra.SymmetricMatrix; import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; import java.util.ArrayList; -import static dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix; -import static dr.math.matrixAlgebra.SymmetricMatrix.compoundSymmetricMatrix; -public class LKJCorrelationWithStructuralZerosDistribution extends LKJCorrelationDistribution implements ConvexSpaceRandomGenerator { +public class LKJCorrelationWithStructuralZerosDistribution extends LKJCorrelationDistribution implements RandomGenerator { private final int[] blockAssignments; @@ -133,78 +125,4 @@ public double logPdf(Object x) { return logPdf((double[]) x); } - @Override - public LineThroughPoints distanceToEdge(double[] origin, double[] draw) { - double[] x = new double[origin.length]; - for (int i = 0; i < origin.length; i++) { - x[i] = draw[i]; - } - - SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); - SymmetricMatrix X = compoundSymmetricMatrix(0.0, x, dim); - -// SymmetricMatrix Xinv = X.inverse(); - SymmetricMatrix Yinv = Y.inverse(); - final Matrix Z; - - try { - Z = Yinv.product(X); - } catch (IllegalDimension illegalDimension) { - throw new RuntimeException("illegal dimensions"); - } - - ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); - EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need smallest magnitude eigenvalues - double[] values = decomposition.getEigenValues(); - - double maxNegative = 0; - double maxPositive = 0; - for (int i = 0; i < values.length; i++) { - double value = values[i]; - if (value < 0 && value < maxNegative) { - maxNegative = value; - } else if (value >= 0 & value > maxPositive) { - maxPositive = value; - } - } - - if (DEBUG) { - System.out.print("Eigenvalues: "); - Utils.printArray(values); - - Matrix S = new SymmetricMatrix(dim, dim); - Matrix T = new SymmetricMatrix(dim, dim); - for (int i = 0; i < dim; i++) { - S.set(i, i, 1); - T.set(i, i, 1); - for (int j = (i + 1); j < dim; j++) { - double y = Y.toComponents()[i][j]; - double z = X.toComponents()[i][j]; - double valueS = y - z / maxNegative; - double valueT = y - z / maxPositive; - S.set(i, j, valueS); - S.set(j, i, valueS); - T.set(i, j, valueT); - T.set(j, i, valueT); - } - } - try { - System.out.println("neg: \n\tt = " + maxNegative); - System.out.println("\tdet = " + S.determinant()); - System.out.println(S); - System.out.println("pos: \n\tt = " + maxPositive); - System.out.println("\tdet = " + T.determinant()); - System.out.println(T); - } catch (IllegalDimension illegalDimension) { - illegalDimension.printStackTrace(); - } - } - - return new ConvexSpaceRandomGenerator.LineThroughPoints(1 / maxPositive, -1 / maxNegative); - } - - @Override - public boolean isUniform() { - return shape == 1; - } } From 03c23ea3ca4b0fc9cfe98c705c116870c7e86fd8 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 17 Jun 2022 17:58:11 -0700 Subject: [PATCH 086/196] first attempt at HMC that bounces off curved boundaries in the posterior (currently numerically unstable) --- src/dr/inference/model/BoundedSpace.java | 32 +++- .../model/GeneralBoundsProvider.java | 4 + .../model/GraphicalParameterBound.java | 2 +- .../operators/hmc/MassPreconditioner.java | 8 + ...flectiveHamiltonianMonteCarloOperator.java | 167 ++++++++++++++---- ...veHamiltonianMonteCarloOperatorParser.java | 10 +- src/dr/math/matrixAlgebra/EJMLUtils.java | 73 ++++++++ 7 files changed, 256 insertions(+), 40 deletions(-) create mode 100644 src/dr/inference/model/GeneralBoundsProvider.java diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 8f957cf592..312ca03024 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -3,6 +3,7 @@ import dr.app.bss.Utils; import dr.evomodel.substmodel.ColtEigenSystem; import dr.evomodel.substmodel.EigenDecomposition; +import dr.math.matrixAlgebra.EJMLUtils; import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.Matrix; import dr.math.matrixAlgebra.SymmetricMatrix; @@ -13,12 +14,14 @@ import static dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix; import static dr.math.matrixAlgebra.SymmetricMatrix.compoundSymmetricMatrix; -public interface BoundedSpace { +public interface BoundedSpace extends GeneralBoundsProvider { boolean isWithinBounds(double[] values); IntersectionDistances distancesToBoundary(double[] origin, double[] direction); + double[] getNormalVectorAtBoundary(double[] position); + default double forwardDistanceToBoundary(double[] origin, double[] direction) { return distancesToBoundary(origin, direction).forwardDistance; } @@ -111,7 +114,7 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc } ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); - EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need smallest magnitude eigenvalues + EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need largest magnitude eigenvalues double[] values = decomposition.getEigenValues(); double maxNegative = 0; @@ -160,6 +163,31 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc return new IntersectionDistances(1 / maxPositive, -1 / maxNegative); } + @Override + public double[] getNormalVectorAtBoundary(double[] position) { + Utils.printArray(position); + double[] c = compoundCorrelationSymmetricMatrix(position, dim).toArrayComponents(); + DenseMatrix64F C = DenseMatrix64F.wrap(dim, dim, c); + DenseMatrix64F A; + try { + A = EJMLUtils.computeRobustAdjugate(C); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + + double[] normalVector = new double[position.length]; + int ind = 0; + for (int i = 0; i < dim; i++) { + for (int j = (i + 1); j < dim; j++) { + normalVector[ind] = A.get(i, j); + ind++; + } + } + + return normalVector; + } + } } diff --git a/src/dr/inference/model/GeneralBoundsProvider.java b/src/dr/inference/model/GeneralBoundsProvider.java new file mode 100644 index 0000000000..23982002b2 --- /dev/null +++ b/src/dr/inference/model/GeneralBoundsProvider.java @@ -0,0 +1,4 @@ +package dr.inference.model; + +public interface GeneralBoundsProvider { +} diff --git a/src/dr/inference/model/GraphicalParameterBound.java b/src/dr/inference/model/GraphicalParameterBound.java index ce38762446..6b1def4f5d 100644 --- a/src/dr/inference/model/GraphicalParameterBound.java +++ b/src/dr/inference/model/GraphicalParameterBound.java @@ -29,7 +29,7 @@ * @author Xiang Ji * @author Marc A. Suchard */ -public interface GraphicalParameterBound { +public interface GraphicalParameterBound extends GeneralBoundsProvider { Parameter getParameter(); diff --git a/src/dr/inference/operators/hmc/MassPreconditioner.java b/src/dr/inference/operators/hmc/MassPreconditioner.java index b46f70055c..eb4a9c068e 100644 --- a/src/dr/inference/operators/hmc/MassPreconditioner.java +++ b/src/dr/inference/operators/hmc/MassPreconditioner.java @@ -42,6 +42,14 @@ public interface MassPreconditioner { int getDimension(); + default double[] getVelocity(ReadableVector momentum) { + double[] velocity = new double[momentum.getDim()]; + for (int i = 0; i < velocity.length; i++) { + velocity[i] = getVelocity(i, momentum); + } + return velocity; + } + enum Type { NONE("none") { diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index 564262cf29..0966ba13d4 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -27,8 +27,7 @@ package dr.inference.operators.hmc; import dr.inference.hmc.GradientWrtParameterProvider; -import dr.inference.model.GraphicalParameterBound; -import dr.inference.model.Parameter; +import dr.inference.model.*; import dr.inference.operators.AdaptationMode; import dr.inferencexml.operators.hmc.ReflectiveHamiltonianMonteCarloOperatorParser; import dr.math.matrixAlgebra.ReadableVector; @@ -43,7 +42,8 @@ */ public class ReflectiveHamiltonianMonteCarloOperator extends HamiltonianMonteCarloOperator implements Reportable { - private final GraphicalParameterBound treeParameterBound; + private final GeneralBoundsProvider parameterBound; + private static final boolean DEBUG = true; public ReflectiveHamiltonianMonteCarloOperator(AdaptationMode mode, @@ -54,17 +54,26 @@ public ReflectiveHamiltonianMonteCarloOperator(AdaptationMode mode, Parameter maskParameter, Options runtimeOptions, MassPreconditioner preconditioner, - GraphicalParameterBound graphicalParameterBound) { + GeneralBoundsProvider bounds) { super(mode, weight, gradientProvider, parameter, transform, maskParameter, runtimeOptions, preconditioner); - this.treeParameterBound = graphicalParameterBound; + this.parameterBound = bounds; this.leapFrogEngine = constructLeapFrogEngine(transform); } @Override protected LeapFrogEngine constructLeapFrogEngine(Transform transform) { - return new WithGraphBounds(parameter, getDefaultInstabilityHandler(), preconditioning, mask, treeParameterBound); + if (transform != null) { + throw new RuntimeException("not yet implemented"); + } + + if (parameterBound instanceof GraphicalParameterBound) { //TODO: don't use 'instanceof' to deal with this. + return new WithGraphBounds(parameter, getDefaultInstabilityHandler(), preconditioning, mask, + (GraphicalParameterBound) parameterBound); + } + return new WithMultivariateCurvedBounds(parameter, getDefaultInstabilityHandler(), preconditioning, mask, + (BoundedSpace) parameterBound); } @Override @@ -78,21 +87,17 @@ public String getOperatorName() { } - class WithGraphBounds extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default { - - final private GraphicalParameterBound graphicalParameterBound; - - protected WithGraphBounds(Parameter parameter, - InstabilityHandler instabilityHandler, - MassPreconditioner preconditioning, - double[] mask, - GraphicalParameterBound graphicalParameterBound) { + abstract class WithBounds extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default { + WithBounds(Parameter parameter, + InstabilityHandler instabilityHandler, + MassPreconditioner preconditioning, + double[] mask) { super(parameter, instabilityHandler, preconditioning, mask); + } - this.graphicalParameterBound = graphicalParameterBound; + protected abstract ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength); - } @Override public void updatePosition(double[] position, WrappedVector momentum, @@ -107,7 +112,68 @@ public void updatePosition(double[] position, WrappedVector momentum, setParameter(position); } - private ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength) { + + } + + class WithMultivariateCurvedBounds extends WithBounds { + + private final BoundedSpace space; + + WithMultivariateCurvedBounds(Parameter parameter, + InstabilityHandler instabilityHandler, + MassPreconditioner preconditioning, + double[] mask, + BoundedSpace space) { + super(parameter, instabilityHandler, preconditioning, mask); + this.space = space; + } + + + @Override + protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength) { + double[] velocity = preconditioning.getVelocity(momentum); + double timeToReflection = space.forwardDistanceToBoundary(position, velocity); + + if (DEBUG) { + System.out.println("Time to reflection: " + timeToReflection); + } + + + + if (timeToReflection < intervalLength) { + return new ReflectionEvent(ReflectionType.None, Double.NaN, intervalLength, new int[0]); + } else { + double[] boundaryPosition = new double[position.length]; + for (int i = 0; i < position.length; i++) { + boundaryPosition[i] = position[i] + timeToReflection * velocity[i]; + } + double[] normalVector = space.getNormalVectorAtBoundary(boundaryPosition); + return new ReflectionEvent(ReflectionType.MultivariateReflection, timeToReflection, + boundaryPosition, normalVector); + } + } + } + + + class WithGraphBounds extends WithBounds { + + final private GraphicalParameterBound graphicalParameterBound; + + protected WithGraphBounds(Parameter parameter, + InstabilityHandler instabilityHandler, + MassPreconditioner preconditioning, + double[] mask, + GraphicalParameterBound graphicalParameterBound) { + + super(parameter, instabilityHandler, preconditioning, mask); + + this.graphicalParameterBound = graphicalParameterBound; + + } + + + @Override + protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength) { ReflectionEvent reflectionEventAtFixedBound = firstReflectionAtFixedBounds(position, momentum, intervalLength); ReflectionEvent collisionEvent = firstCollision(position, momentum, intervalLength); return (reflectionEventAtFixedBound.getEventTime() < collisionEvent.getEventTime()) ? reflectionEventAtFixedBound : collisionEvent; @@ -164,7 +230,7 @@ private ReflectionEvent firstCollision(double[] position, ReadableVector momentu } } } - return new ReflectionEvent(type, firstCollisionTime, collisionLocation, intervalLength, new int[]{index1, index2}); + return new ReflectionEvent(type, firstCollisionTime, collisionLocation, new int[]{index1, index2}); } private double[] getIntendedPosition(double[] position, ReadableVector momentum, double intervalLength) { @@ -210,7 +276,7 @@ private ReflectionEvent firstReflectionAtFixedBounds(double[] position, Readable } } } - return new ReflectionEvent(type, firstReflection, reflectionLocation, intervalLength, new int[]{reflectionIndex}); + return new ReflectionEvent(type, firstReflection, reflectionLocation, new int[]{reflectionIndex}); } @@ -219,20 +285,34 @@ private ReflectionEvent firstReflectionAtFixedBounds(double[] position, Readable class ReflectionEvent { private final ReflectionType type; private final double eventTime; - private final double eventLocation; - private final double intervalLength; + private final double[] eventLocation; private final int[] indices; + private final double[] normalVector; ReflectionEvent(ReflectionType type, double eventTime, - double eventLocation, - double intervalLength, + double[] eventLocation, + double[] normalVector, int[] indices) { this.type = type; this.eventTime = eventTime; - this.intervalLength = intervalLength; this.indices = indices; this.eventLocation = eventLocation; + this.normalVector = normalVector; + } + + ReflectionEvent(ReflectionType type, + double eventTime, + double eventLocation, + int[] indices) { + this(type, eventTime, new double[]{eventLocation}, null, indices); + } + + ReflectionEvent(ReflectionType type, + double eventTime, + double[] eventLocation, + double[] normalVector) { + this(type, eventTime, eventLocation, normalVector, new int[0]); } public double getEventTime() { @@ -244,31 +324,54 @@ public ReflectionType getType() { } public void doReflection(double[] position, WrappedVector momentum) { - type.doReflection(position, preconditioning, momentum, eventLocation, indices, eventTime); + type.doReflection(position, preconditioning, momentum, eventLocation, indices, normalVector, eventTime); } } enum ReflectionType { + + MultivariateReflection { + @Override + void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, + double eventLocation[], int[] indices, double[] normalVector, double time) { + + updatePosition(position, preconditioning, momentum, time); + double vn = 0; + double nn = 0; + + for (int i : indices) { + vn += momentum.get(i) * normalVector[i]; + nn += normalVector[i] * normalVector[i]; + } + + double c = 2 * vn / nn; + + for (int i : indices) { + momentum.set(i, momentum.get(i) - c * normalVector[i]); + position[i] = eventLocation[i]; + } + } + }, Reflection { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation, int[] indices, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time) { updatePosition(position, preconditioning, momentum, time); momentum.set(indices[0], -momentum.get(indices[0])); - position[indices[0]] = eventLocation; + position[indices[0]] = eventLocation[0]; } }, Collision { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation, int[] indices, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time) { updatePosition(position, preconditioning, momentum, time); ReadableVector updatedMomentum = preconditioning.doCollision(indices, momentum); for (int index : indices) { momentum.set(index, updatedMomentum.get(index)); - position[index] = eventLocation; + position[index] = eventLocation[0]; } } @@ -276,7 +379,7 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped None { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation, int[] indices, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time) { updatePosition(position, preconditioning, momentum, time); } }; @@ -289,7 +392,7 @@ void updatePosition(double[] position, MassPreconditioner preconditioning, Wrapp } abstract void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation, int[] indices, double time); + double eventLocation[], int[] indices, double[] normalVector, double time); } } diff --git a/src/dr/inferencexml/operators/hmc/ReflectiveHamiltonianMonteCarloOperatorParser.java b/src/dr/inferencexml/operators/hmc/ReflectiveHamiltonianMonteCarloOperatorParser.java index 1a9ef551fe..4bf91c0c81 100644 --- a/src/dr/inferencexml/operators/hmc/ReflectiveHamiltonianMonteCarloOperatorParser.java +++ b/src/dr/inferencexml/operators/hmc/ReflectiveHamiltonianMonteCarloOperatorParser.java @@ -26,7 +26,7 @@ package dr.inferencexml.operators.hmc; import dr.inference.hmc.GradientWrtParameterProvider; -import dr.inference.hmc.ReversibleHMCProvider; +import dr.inference.model.GeneralBoundsProvider; import dr.inference.model.GraphicalParameterBound; import dr.inference.model.Parameter; import dr.inference.operators.AdaptationMode; @@ -47,11 +47,11 @@ public class ReflectiveHamiltonianMonteCarloOperatorParser extends HamiltonianMonteCarloOperatorParser { public final static String OPERATOR_NAME = "reflectiveHamiltonianMonteCarloOperator"; - private GraphicalParameterBound graphicalParameterBound; + private GeneralBoundsProvider bounds; @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { - this.graphicalParameterBound = (GraphicalParameterBound) xo.getChild(GraphicalParameterBound.class); + this.bounds = (GeneralBoundsProvider) xo.getChild(GeneralBoundsProvider.class); return super.parseXMLObject(xo); } @@ -63,13 +63,13 @@ protected HamiltonianMonteCarloOperator factory(AdaptationMode adaptationMode, d return new ReflectiveHamiltonianMonteCarloOperator(adaptationMode, weight, derivative, parameter, transform, mask, - runtimeOptions, preconditioner, graphicalParameterBound); + runtimeOptions, preconditioner, bounds); } @Override public XMLSyntaxRule[] getSyntaxRules() { XMLSyntaxRule[] extendedRules = new XMLSyntaxRule[rules.length + 1]; - extendedRules[0] = new ElementRule(GraphicalParameterBound.class); + extendedRules[0] = new ElementRule(GeneralBoundsProvider.class); for (int i = 0; i < rules.length; i++) { extendedRules[i + 1] = rules[i]; } diff --git a/src/dr/math/matrixAlgebra/EJMLUtils.java b/src/dr/math/matrixAlgebra/EJMLUtils.java index 589ebf589a..4f12aef7ab 100644 --- a/src/dr/math/matrixAlgebra/EJMLUtils.java +++ b/src/dr/math/matrixAlgebra/EJMLUtils.java @@ -1,11 +1,16 @@ package dr.math.matrixAlgebra; import org.ejml.data.DenseMatrix64F; +import org.ejml.factory.DecompositionFactory; +import org.ejml.interfaces.decomposition.SingularValueDecomposition; +import org.ejml.ops.CommonOps; import java.util.Arrays; public class EJMLUtils { + private static final Boolean DEBUG = true; + public static void addWithTransposed(DenseMatrix64F X) { checkSquare(X); @@ -42,4 +47,72 @@ private static void checkSquare(DenseMatrix64F X) { throw new IllegalArgumentException("matrix must be square."); } } + + public static DenseMatrix64F computeRobustAdjugate(DenseMatrix64F X) throws IllegalDimension { + // algorithm taken from: Stewart, G. W. "On the adjugate matrix." Linear Algebra and its Applications 283.1-3 (1998): 151-164. + + if (DEBUG) { + System.out.println(X); + } + + + checkSquare(X); + + int dim = X.numRows; + + SingularValueDecomposition svd = DecompositionFactory.svd(dim, dim, + true, true, true); + + svd.decompose(X); + DenseMatrix64F Ut = new DenseMatrix64F(dim, dim); + DenseMatrix64F V = new DenseMatrix64F(dim, dim); + + svd.getU(Ut, true); + svd.getV(V, false); + double[] svs = svd.getSingularValues(); + + double du = CommonOps.det(Ut); + double dv = CommonOps.det(V); //TODO: should be same as du if matrix is symmetric + + // gamma_i is just the product of the singular values excluding sv_i + double[] gamma = new double[dim]; + double[] forwardProds = new double[dim]; + double[] backwardProds = new double[dim]; + double fp = 1; + double bp = 1; + for (int i = 0; i < dim; i++) { + fp *= svs[i]; + forwardProds[i] = fp; + + int backInd = dim - i - 1; + bp *= svs[backInd]; + backwardProds[backInd] = bp; + } + + gamma[0] = backwardProds[1]; + gamma[dim - 1] = forwardProds[dim - 2]; + + for (int i = 1; i < dim - 1; i++) { + gamma[i] = forwardProds[i - 1] * backwardProds[i + 1]; + } + + + // Ut = Diag(gamma) * Ut + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + Ut.set(i, j, Ut.get(i, j) * gamma[i]); + } + } + + DenseMatrix64F adjugate = new DenseMatrix64F(dim, dim); + CommonOps.mult(V, Ut, adjugate); + CommonOps.scale(du * dv, adjugate); + + if (DEBUG) { + System.out.println(adjugate); + } + + return adjugate; + } + } From be7f56302a5df8c551190b06a84769d316953b4c Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 21 Jun 2022 17:49:37 -0700 Subject: [PATCH 087/196] moving method to ContinuousDataLikelihoodDelegate where it might be more useful --- .../ContinuousDataLikelihoodDelegate.java | 12 ++++++++++++ src/dr/inference/model/BlombergKStatistic.java | 16 ++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java index 7bbe6f0ab8..2885255c0e 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java @@ -1116,4 +1116,16 @@ static ContinuousDataLikelihoodDelegate createWithMissingData(ContinuousDataLike false, likelihoodDelegate.allowSingular); } + + public double[] getPostOrderRootMean() { + PrecisionType type = getDataModel().getPrecisionType(); + + double[] partial = new double[type.getPartialsDimension(dimProcess)]; + + getIntegrator().getPostOrderPartial(getActiveNodeIndex(tree.getRoot().getNumber()), partial); + double mean[] = new double[dimProcess]; + System.arraycopy(partial, type.getMeanOffset(dimProcess), mean, 0, dimProcess); + + return mean; + } } diff --git a/src/dr/inference/model/BlombergKStatistic.java b/src/dr/inference/model/BlombergKStatistic.java index e5cbf49a3d..60990ba3e2 100644 --- a/src/dr/inference/model/BlombergKStatistic.java +++ b/src/dr/inference/model/BlombergKStatistic.java @@ -5,8 +5,6 @@ import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.continuous.MultivariateTraitDebugUtilities; -import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator; -import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.math.matrixAlgebra.*; import java.util.Arrays; @@ -24,7 +22,6 @@ public class BlombergKStatistic extends Statistic.Abstract implements ModelListe private Matrix Linv; private final int treeDim; private double expectedRatio; - private final ContinuousDiffusionIntegrator integrator; private final ContinuousDataLikelihoodDelegate delegate; private final double[] k; @@ -37,7 +34,6 @@ public BlombergKStatistic(TreeDataLikelihood traitLikelihood, String traitName) this.traitDim = traitLikelihood.getDataLikelihoodDelegate().getTraitDim(); this.treeDim = tree.getTaxonCount(); this.delegate = (ContinuousDataLikelihoodDelegate) traitLikelihood.getDataLikelihoodDelegate(); - this.integrator = delegate.getIntegrator(); this.k = new double[traitDim]; } @@ -56,7 +52,7 @@ public double getStatisticValue(int dim) { } - public double computeStatistics() { + public void computeStatistics() { if (needToUpdateTree) { double[][] treeStructure = MultivariateTraitDebugUtilities.getTreeVariance(tree, traitLikelihood.getBranchRateModel(), @@ -104,14 +100,8 @@ public double computeStatistics() { double[] treeTraits = (double[]) treeTrait.getTrait(tree, null); - PrecisionType type = delegate.getDataModel().getPrecisionType(); - - double[] partial = new double[type.getPartialsDimension(traitDim)]; - - integrator.getPostOrderPartial(delegate.getActiveNodeIndex(tree.getRoot().getNumber()), partial); - double mean[] = new double[traitDim]; - System.arraycopy(partial, type.getMeanOffset(traitDim), mean, 0, traitDim); + double[] mean = delegate.getPostOrderRootMean(); double[] thisTrait = new double[treeDim]; @@ -135,8 +125,6 @@ public double computeStatistics() { k[trait] = (mse0 / mse) / expectedRatio; } - - return 0; } private double sumSquares(double[] x) { From fb9beb4701318df6eaeeb735d57c3859d88bf258 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 21 Jun 2022 19:20:36 -0700 Subject: [PATCH 088/196] bug fixes + tweak to increase numeric stability (still not working, but moving in the right direction?) --- src/dr/inference/model/BoundedSpace.java | 75 +++++++++++++++---- ...flectiveHamiltonianMonteCarloOperator.java | 30 +++++--- src/dr/math/matrixAlgebra/EJMLUtils.java | 2 +- 3 files changed, 80 insertions(+), 27 deletions(-) diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 312ca03024..128cb9a786 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -3,6 +3,7 @@ import dr.app.bss.Utils; import dr.evomodel.substmodel.ColtEigenSystem; import dr.evomodel.substmodel.EigenDecomposition; +import dr.math.MathUtils; import dr.math.matrixAlgebra.EJMLUtils; import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.Matrix; @@ -41,7 +42,8 @@ public IntersectionDistances(double forwardDistance, double backwardDistance) { class Correlation implements BoundedSpace { - private static final boolean DEBUG = false; + private static final boolean DEBUG = true; + private static final double TOL = 0; private final int dim; public Correlation(int dim) { @@ -95,10 +97,23 @@ public boolean isWithinBounds(double[] x) { return true; } - @Override - public IntersectionDistances distancesToBoundary(double[] origin, double[] direction) { + private double[] robustTrajectoryEigenValues(double[] origin, double[] direction) { + double t = MathUtils.nextDouble(); + double[] newOrigin = new double[origin.length]; + for (int i = 0; i < origin.length; i++) { + newOrigin[i] = origin[i] + t * direction[i]; + } + + double[] shiftedValues = trajectoryEigenvalues(newOrigin, direction); + for (int i = 0; i < shiftedValues.length; i++) { + shiftedValues[i] -= t; + } + return shiftedValues; + } + + private double[] trajectoryEigenvalues(double[] origin, double[] direction) { double[] x = new double[origin.length]; - System.arraycopy(direction, 0, x, 0, x.length); + System.arraycopy(direction, 0, x, 0, x.length); //TODO: is this necessary? SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); SymmetricMatrix X = compoundSymmetricMatrix(0.0, x, dim); @@ -116,32 +131,51 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need largest magnitude eigenvalues double[] values = decomposition.getEigenValues(); + for (int i = 0; i < values.length; i++) { + values[i] = 1 / values[i]; + } + + return values; - double maxNegative = 0; - double maxPositive = 0; + } + + @Override + public IntersectionDistances distancesToBoundary(double[] origin, double[] direction) { + + + double values[] = robustTrajectoryEigenValues(origin, direction); + + double minNegative = Double.NEGATIVE_INFINITY; + double minPositive = Double.POSITIVE_INFINITY; for (int i = 0; i < values.length; i++) { double value = values[i]; - if (value < 0 && value < maxNegative) { - maxNegative = value; - } else if (value >= 0 & value > maxPositive) { - maxPositive = value; + if (value < -TOL && value > minNegative) { + minNegative = value; + } else if (value >= TOL & value < minPositive) { + minPositive = value; } } if (DEBUG) { + SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); + SymmetricMatrix X = compoundSymmetricMatrix(0.0, direction, dim); System.out.print("Eigenvalues: "); Utils.printArray(values); Matrix S = new SymmetricMatrix(dim, dim); Matrix T = new SymmetricMatrix(dim, dim); + double absMax = 0.0; for (int i = 0; i < dim; i++) { S.set(i, i, 1); T.set(i, i, 1); for (int j = (i + 1); j < dim; j++) { double y = Y.toComponents()[i][j]; double z = X.toComponents()[i][j]; - double valueS = y - z / maxNegative; - double valueT = y - z / maxPositive; + double valueS = y - z * minNegative; + double valueT = y - z * minPositive; + if (Math.abs(valueS) > absMax) absMax = Math.abs(valueS); + if (Math.abs(valueT) > absMax) absMax = Math.abs(valueT); + S.set(i, j, valueS); S.set(j, i, valueS); T.set(i, j, valueT); @@ -149,23 +183,32 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc } } try { - System.out.println("neg: \n\tt = " + maxNegative); + System.out.println("starting position: "); + System.out.println("\tdet = " + Y.determinant()); + System.out.println(Y); + System.out.println("direction:"); + System.out.println(X); + System.out.println(); + System.out.println("neg: \n\tt = " + minNegative); System.out.println("\tdet = " + S.determinant()); System.out.println(S); - System.out.println("pos: \n\tt = " + maxPositive); + System.out.println("pos: \n\tt = " + minPositive); System.out.println("\tdet = " + T.determinant()); System.out.println(T); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); } + + if (absMax > 1.0) { + throw new RuntimeException("Cannot exceed 1"); + } } - return new IntersectionDistances(1 / maxPositive, -1 / maxNegative); + return new IntersectionDistances(minPositive, minNegative); } @Override public double[] getNormalVectorAtBoundary(double[] position) { - Utils.printArray(position); double[] c = compoundCorrelationSymmetricMatrix(position, dim).toArrayComponents(); DenseMatrix64F C = DenseMatrix64F.wrap(dim, dim, c); DenseMatrix64F A; diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index 0966ba13d4..db35ad68a7 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -35,6 +35,8 @@ import dr.util.Transform; import dr.xml.Reportable; +import java.util.ArrayList; + /** * @author Xiang Ji @@ -118,6 +120,7 @@ public void updatePosition(double[] position, WrappedVector momentum, class WithMultivariateCurvedBounds extends WithBounds { private final BoundedSpace space; + public final int[] defaultIndices; WithMultivariateCurvedBounds(Parameter parameter, InstabilityHandler instabilityHandler, @@ -126,6 +129,16 @@ class WithMultivariateCurvedBounds extends WithBounds { BoundedSpace space) { super(parameter, instabilityHandler, preconditioning, mask); this.space = space; + ArrayList inds = new ArrayList<>(); + for (int i = 0; i < mask.length; i++) { + if (mask[i] == 1) { + inds.add(i); + } + } + this.defaultIndices = new int[inds.size()]; + for (int i = 0; i < inds.size(); i++) { + defaultIndices[i] = inds.get(i); + } } @@ -136,20 +149,23 @@ protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, d if (DEBUG) { System.out.println("Time to reflection: " + timeToReflection); + System.out.println("Interval length: " + intervalLength); } - - if (timeToReflection < intervalLength) { - return new ReflectionEvent(ReflectionType.None, Double.NaN, intervalLength, new int[0]); + if (timeToReflection > intervalLength) { + return new ReflectionEvent(ReflectionType.None, intervalLength, Double.NaN, new int[0]); } else { + if (DEBUG) { + System.out.println("!!!!!!!!!!!!!!!REFLECTION!!!!!!!!!!!!!!!"); + } double[] boundaryPosition = new double[position.length]; for (int i = 0; i < position.length; i++) { boundaryPosition[i] = position[i] + timeToReflection * velocity[i]; } double[] normalVector = space.getNormalVectorAtBoundary(boundaryPosition); return new ReflectionEvent(ReflectionType.MultivariateReflection, timeToReflection, - boundaryPosition, normalVector); + boundaryPosition, normalVector, defaultIndices); } } } @@ -308,12 +324,6 @@ class ReflectionEvent { this(type, eventTime, new double[]{eventLocation}, null, indices); } - ReflectionEvent(ReflectionType type, - double eventTime, - double[] eventLocation, - double[] normalVector) { - this(type, eventTime, eventLocation, normalVector, new int[0]); - } public double getEventTime() { return eventTime; diff --git a/src/dr/math/matrixAlgebra/EJMLUtils.java b/src/dr/math/matrixAlgebra/EJMLUtils.java index 4f12aef7ab..994e328e37 100644 --- a/src/dr/math/matrixAlgebra/EJMLUtils.java +++ b/src/dr/math/matrixAlgebra/EJMLUtils.java @@ -9,7 +9,7 @@ public class EJMLUtils { - private static final Boolean DEBUG = true; + private static final Boolean DEBUG = false; public static void addWithTransposed(DenseMatrix64F X) { From 8ea0ed303a22b2bc2e95eb4429decec495a2f4c2 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 22 Jun 2022 15:02:09 -0700 Subject: [PATCH 089/196] working reflective HMC (very messy code but wanted to commit before I accidentally broke anything) --- src/dr/inference/model/BoundedSpace.java | 38 ++++++-- ...flectiveHamiltonianMonteCarloOperator.java | 89 +++++++++++++++++-- 2 files changed, 111 insertions(+), 16 deletions(-) diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 128cb9a786..1cb097b555 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -42,8 +42,8 @@ public IntersectionDistances(double forwardDistance, double backwardDistance) { class Correlation implements BoundedSpace { - private static final boolean DEBUG = true; - private static final double TOL = 0; + private static final boolean DEBUG = false; + private static final double TOL = 1e-10; private final int dim; public Correlation(int dim) { @@ -142,6 +142,18 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { @Override public IntersectionDistances distancesToBoundary(double[] origin, double[] direction) { + if (!isWithinBounds(origin)) { //TODO: make this optional? + SymmetricMatrix C = compoundCorrelationSymmetricMatrix(origin, dim); + System.out.println(C); + try { + System.out.println(C.determinant()); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + } + + throw new RuntimeException("Starting position is outside of bounds"); + } + double values[] = robustTrajectoryEigenValues(origin, direction); @@ -156,6 +168,9 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc } } + minPositive = -minPositive; + minNegative = -minNegative; + if (DEBUG) { SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); SymmetricMatrix X = compoundSymmetricMatrix(0.0, direction, dim); @@ -171,10 +186,10 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc for (int j = (i + 1); j < dim; j++) { double y = Y.toComponents()[i][j]; double z = X.toComponents()[i][j]; - double valueS = y - z * minNegative; - double valueT = y - z * minPositive; + double valueS = y + z * minNegative; + double valueT = y + z * minPositive; if (Math.abs(valueS) > absMax) absMax = Math.abs(valueS); - if (Math.abs(valueT) > absMax) absMax = Math.abs(valueT); +// if (Math.abs(valueT) > absMax) absMax = Math.abs(valueT); S.set(i, j, valueS); S.set(j, i, valueS); @@ -182,9 +197,11 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc T.set(j, i, valueT); } } + double detY; try { System.out.println("starting position: "); - System.out.println("\tdet = " + Y.determinant()); + detY = Y.determinant(); + System.out.println("\tdet = " + detY); System.out.println(Y); System.out.println("direction:"); System.out.println(X); @@ -197,14 +214,19 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc System.out.println(T); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + + if (detY < -TOL || detY > 1) { + throw new RuntimeException("invalid starting position"); } if (absMax > 1.0) { - throw new RuntimeException("Cannot exceed 1"); + throw new RuntimeException("Invalid ending position"); } } - return new IntersectionDistances(minPositive, minNegative); + return new IntersectionDistances(minNegative, minPositive); } @Override diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index db35ad68a7..accf12e1b2 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -26,11 +26,14 @@ package dr.inference.operators.hmc; +import dr.app.bss.Utils; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.*; import dr.inference.operators.AdaptationMode; import dr.inferencexml.operators.hmc.ReflectiveHamiltonianMonteCarloOperatorParser; +import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.ReadableVector; +import dr.math.matrixAlgebra.SymmetricMatrix; import dr.math.matrixAlgebra.WrappedVector; import dr.util.Transform; import dr.xml.Reportable; @@ -45,7 +48,7 @@ public class ReflectiveHamiltonianMonteCarloOperator extends HamiltonianMonteCarloOperator implements Reportable { private final GeneralBoundsProvider parameterBound; - private static final boolean DEBUG = true; + private static final boolean DEBUG = false; public ReflectiveHamiltonianMonteCarloOperator(AdaptationMode mode, @@ -108,7 +111,30 @@ public void updatePosition(double[] position, WrappedVector momentum, double collapsedTime = 0.0; while (collapsedTime < functionalStepSize) { ReflectionEvent event = nextEvent(position, momentum, functionalStepSize - collapsedTime); + + if (DEBUG) { + SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 6); //TODO: remove + try { + System.out.println("starting det: " + C.determinant()); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + } + event.doReflection(position, momentum); + + if (DEBUG) { + System.out.println("event: " + event.getType()); + SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 6); //TODO: remove + try { + System.out.println("ending det: " + C.determinant()); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + throw new RuntimeException(); + } + } + collapsedTime += event.getEventTime(); } setParameter(position); @@ -165,6 +191,7 @@ protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, d } double[] normalVector = space.getNormalVectorAtBoundary(boundaryPosition); return new ReflectionEvent(ReflectionType.MultivariateReflection, timeToReflection, + intervalLength - timeToReflection, boundaryPosition, normalVector, defaultIndices); } } @@ -304,9 +331,11 @@ class ReflectionEvent { private final double[] eventLocation; private final int[] indices; private final double[] normalVector; + private final double remainingTime; ReflectionEvent(ReflectionType type, double eventTime, + double remainingTime, double[] eventLocation, double[] normalVector, int[] indices) { @@ -315,13 +344,14 @@ class ReflectionEvent { this.indices = indices; this.eventLocation = eventLocation; this.normalVector = normalVector; + this.remainingTime = remainingTime; } ReflectionEvent(ReflectionType type, double eventTime, double eventLocation, int[] indices) { - this(type, eventTime, new double[]{eventLocation}, null, indices); + this(type, eventTime, Double.NaN, new double[]{eventLocation}, null, indices); } @@ -334,7 +364,7 @@ public ReflectionType getType() { } public void doReflection(double[] position, WrappedVector momentum) { - type.doReflection(position, preconditioning, momentum, eventLocation, indices, normalVector, eventTime); + type.doReflection(position, preconditioning, momentum, eventLocation, indices, normalVector, eventTime, remainingTime); } } @@ -344,7 +374,16 @@ enum ReflectionType { MultivariateReflection { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation[], int[] indices, double[] normalVector, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time, + double remainingTime) { + + + if (DEBUG) { + System.out.println("time: " + time); + System.out.print("start: "); + Utils.printArray(position); + System.out.println(momentum); + } updatePosition(position, preconditioning, momentum, time); double vn = 0; @@ -361,12 +400,31 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped momentum.set(i, momentum.get(i) - c * normalVector[i]); position[i] = eventLocation[i]; } + + if (DEBUG) { + System.out.print("end: "); + Utils.printArray(position); + System.out.println(momentum); + } + + if (BOUNCE) { + double t = Math.min(remainingTime, 1e-10); //TODO: need to make sure I'm not leaving the space again, also need to update time later + System.out.println("bounce time: " + t); + updatePosition(position, preconditioning, momentum, t); + } + + if (DEBUG) { + System.out.print("bounce: "); + Utils.printArray(position); + System.out.println(momentum); + } } }, Reflection { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation[], int[] indices, double[] normalVector, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time, + double remainingTime) { updatePosition(position, preconditioning, momentum, time); momentum.set(indices[0], -momentum.get(indices[0])); position[indices[0]] = eventLocation[0]; @@ -375,7 +433,8 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped Collision { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation[], int[] indices, double[] normalVector, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time, + double remainingTime) { updatePosition(position, preconditioning, momentum, time); ReadableVector updatedMomentum = preconditioning.doCollision(indices, momentum); @@ -389,8 +448,19 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped None { @Override void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation[], int[] indices, double[] normalVector, double time) { + double eventLocation[], int[] indices, double[] normalVector, double time, + double remainginTime) { + + if (DEBUG) { + System.out.println("time: " + time); + System.out.print("start: "); + Utils.printArray(position); + } updatePosition(position, preconditioning, momentum, time); + if (DEBUG) { + System.out.print("end: "); + Utils.printArray(position); + } } }; @@ -402,7 +472,10 @@ void updatePosition(double[] position, MassPreconditioner preconditioning, Wrapp } abstract void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, - double eventLocation[], int[] indices, double[] normalVector, double time); + double eventLocation[], int[] indices, double[] normalVector, double time, double remainingTime); + + private static final boolean BOUNCE = true; + } } From 6d8c0354c2c29054e144b63bcce363dc6f05aea5 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 24 Jun 2022 11:43:18 -0700 Subject: [PATCH 090/196] attempted fix of reflective HMC error --- .../hmc/ReflectiveHamiltonianMonteCarloOperator.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index accf12e1b2..53058ce314 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -69,6 +69,11 @@ public ReflectiveHamiltonianMonteCarloOperator(AdaptationMode mode, @Override protected LeapFrogEngine constructLeapFrogEngine(Transform transform) { + if (parameterBound == null) { + return null; //will get called again + } + + if (transform != null) { throw new RuntimeException("not yet implemented"); } @@ -156,8 +161,8 @@ class WithMultivariateCurvedBounds extends WithBounds { super(parameter, instabilityHandler, preconditioning, mask); this.space = space; ArrayList inds = new ArrayList<>(); - for (int i = 0; i < mask.length; i++) { - if (mask[i] == 1) { + for (int i = 0; i < parameter.getDimension(); i++) { + if (mask == null || mask[i] == 1) { inds.add(i); } } From e5d54b47010e472bc1ba26ef5088499c3ec4368a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 24 Jun 2022 11:47:11 -0700 Subject: [PATCH 091/196] LatentFactorModel never used rowPrecision element (and allowing it to take a MatrixParameterInterface as the data --- src/dr/inference/model/LatentFactorModel.java | 5 +---- .../model/LatentFactorModelParser.java | 14 +++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/dr/inference/model/LatentFactorModel.java b/src/dr/inference/model/LatentFactorModel.java index 1f1d131e95..193fb8e45c 100644 --- a/src/dr/inference/model/LatentFactorModel.java +++ b/src/dr/inference/model/LatentFactorModel.java @@ -48,7 +48,6 @@ public class LatentFactorModel extends AbstractModelLikelihood implements Citabl private final MatrixParameterInterface factors; private final MatrixParameterInterface loadings; private MatrixParameterInterface sData; - private final DiagonalMatrix rowPrecision; private final DiagonalMatrix colPrecision; private final Parameter continuous; @@ -105,7 +104,7 @@ public class LatentFactorModel extends AbstractModelLikelihood implements Citabl private final int nmeasurements; public LatentFactorModel(MatrixParameterInterface data, MatrixParameterInterface factors, MatrixParameterInterface loadings, - DiagonalMatrix rowPrecision, DiagonalMatrix colPrecision, + DiagonalMatrix colPrecision, Parameter missingIndicator, boolean scaleData, Parameter continuous, boolean newModel, boolean recomputeResiduals, boolean recomputeFactors, boolean recomputeLoadings @@ -182,13 +181,11 @@ public LatentFactorModel(MatrixParameterInterface data, MatrixParameterInterface } } - this.rowPrecision = rowPrecision; this.colPrecision = colPrecision; addVariable(data); addVariable(factors); addVariable(loadings); - addVariable(rowPrecision); addVariable(colPrecision); dimFactors = factors.getRowDimension(); diff --git a/src/dr/inferencexml/model/LatentFactorModelParser.java b/src/dr/inferencexml/model/LatentFactorModelParser.java index 585bba33ef..5d0b5d8b18 100644 --- a/src/dr/inferencexml/model/LatentFactorModelParser.java +++ b/src/dr/inferencexml/model/LatentFactorModelParser.java @@ -41,7 +41,7 @@ public class LatentFactorModelParser extends AbstractXMLObjectParser { public final static String FACTORS = "factors"; public final static String DATA = "data"; public final static String LOADINGS = "loadings"; - public static final String ROW_PRECISION = "rowPrecision"; + // public static final String ROW_PRECISION = "rowPrecision"; public static final String COLUMN_PRECISION = "columnPrecision"; public static final String SCALE_DATA = "scaleData"; public static final String CONTINUOUS = "continuous"; @@ -59,7 +59,7 @@ public String getParserName() { public Object parseXMLObject(XMLObject xo) throws XMLParseException { MatrixParameterInterface factors; - if (xo.getChild(FACTORS).getChild(FastMatrixParameter.class) == null) { + if (xo.getChild(FACTORS).getChild(MatrixParameterInterface.class) == null) { CompoundParameter factorsTemp = (CompoundParameter) xo.getChild(FACTORS).getChild(CompoundParameter.class); factors = MatrixParameter.recast(factorsTemp.getParameterName(), factorsTemp); } else { @@ -103,7 +103,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } - DiagonalMatrix rowPrecision = (DiagonalMatrix) xo.getChild(ROW_PRECISION).getChild(MatrixParameter.class); +// DiagonalMatrix rowPrecision = (DiagonalMatrix) xo.getChild(ROW_PRECISION).getChild(MatrixParameter.class); DiagonalMatrix colPrecision = (DiagonalMatrix) xo.getChild(COLUMN_PRECISION).getChild(MatrixParameter.class); boolean newModel = xo.getAttribute(COMPUTE_RESIDUALS_FOR_DISCRETE, true); boolean computeResiduals = xo.getAttribute(RECOMPUTE_RESIDUALS, true); @@ -126,7 +126,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { // } - return new LatentFactorModel(dataParameter, factors, loadings, rowPrecision, colPrecision, missingIndicator, scaleData, continuous, newModel, computeResiduals, computeFactors, computeLoadings); + return new LatentFactorModel(dataParameter, factors, loadings, colPrecision, missingIndicator, scaleData, continuous, newModel, computeResiduals, computeFactors, computeLoadings); } private static final XMLSyntaxRule[] rules = { @@ -147,9 +147,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { new ElementRule(LOADINGS, new XMLSyntaxRule[]{ new ElementRule(MatrixParameterInterface.class) }), - new ElementRule(ROW_PRECISION, new XMLSyntaxRule[]{ - new ElementRule(DiagonalMatrix.class) - }), +// new ElementRule(ROW_PRECISION, new XMLSyntaxRule[]{ +// new ElementRule(DiagonalMatrix.class) +// }), new ElementRule(COLUMN_PRECISION, new XMLSyntaxRule[]{ new ElementRule(DiagonalMatrix.class) }), From 65da1cb97636363065ed762ce746caba39f3b95b Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 24 Jun 2022 12:15:10 -0700 Subject: [PATCH 092/196] less strict test requirements for multithreading (will hopefully stop errors) --- ci/TestXML/testNewLoadingsGibbsOperator.xml | 53 +++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/ci/TestXML/testNewLoadingsGibbsOperator.xml b/ci/TestXML/testNewLoadingsGibbsOperator.xml index 9168aa2408..23f095f8eb 100644 --- a/ci/TestXML/testNewLoadingsGibbsOperator.xml +++ b/ci/TestXML/testNewLoadingsGibbsOperator.xml @@ -392,6 +392,20 @@ + + + -0.3721665664840462 0.48001788993686695 -0.34886569662988637 0.2915078879991332 -0.10091700602930198 + -0.31415199927621096 -2.3884331810214676 1.4849101564580578 -1.7038907684395337 -1.6692350718773161 + -0.09114521475759652 1.0644914784924318 -1.4629205281009046 -0.10761097309593033 -0.09163219006876919 + -0.10718108191461742 -0.49226000993882457 -0.4356048068181444 + + + + + + Check unconstrained loadings mean (multithreaded) + + @@ -444,6 +458,20 @@ + + + -1.401754437093048 0.46714122549955267 -0.34886569662988637 0.2915078879991332 -0.10091700602930198 + -0.31415199927621096 0.0 1.4871332789573564 -1.7038907684395337 -1.6692350718773161 -0.09114521475759652 + 1.0644914784924318 0.0 0.0 -0.09163219006876919 -0.10718108191461742 -0.49226000993882457 + -0.4356048068181444 + + + + + + Check upper triangular loadings mean (multithreaded) + + @@ -588,6 +616,20 @@ + + + -0.5624212060691973 0.6686957665052548 -0.12883363489801392 -0.028483503802251994 -0.20245004154707416 + 0.5727800946633681 -0.1177869727015475 -0.03557128301763026 -0.3402971236017539 0.30078836060143443 + -0.020476521372923114 0.134604155411238 -0.03372109032875503 -0.16741994314436884 -0.032031678146118236 + 0.34725416185170455 0.049927738575763006 -0.07642808814316454 + + + + + + Check unconstrained loadings mean (multithreaded) + + @@ -636,6 +678,17 @@ + + + -0.47980541485547296 0.6422210311912988 -0.12883363489801392 -0.028483503802251994 -0.20245004154707416 0.5727800946633681 0.0 -0.02101776994119088 -0.3402971236017539 0.30078836060143443 -0.020476521372923114 0.134604155411238 0.0 0.0 -0.032031678146118236 0.34725416185170455 0.049927738575763006 -0.07642808814316454 + + + + + + Check upper triangular loadings mean (multithreaded) + + From 877e31f9a5c9bba91b9fc712c6c7bfecc75075e8 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 24 Jun 2022 12:32:07 -0700 Subject: [PATCH 093/196] adapting MaskedMatrixParameter to be able to handle a FastMtrixParameter --- .../model/MaskedMatrixParameter.java | 132 ++++++++++++++++++ .../model/MaskedMatrixParameterParser.java | 12 +- 2 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 src/dr/inference/model/MaskedMatrixParameter.java diff --git a/src/dr/inference/model/MaskedMatrixParameter.java b/src/dr/inference/model/MaskedMatrixParameter.java new file mode 100644 index 0000000000..751d035bb2 --- /dev/null +++ b/src/dr/inference/model/MaskedMatrixParameter.java @@ -0,0 +1,132 @@ +package dr.inference.model; + +import java.util.ArrayList; + +public class MaskedMatrixParameter extends CompoundParameter implements MatrixParameterInterface, VariableListener { + + private final MatrixParameterInterface matrix; + private final Parameter mask; + private ArrayList rows = new ArrayList<>(); + + + public MaskedMatrixParameter(MatrixParameterInterface matrix, Parameter mask) { + super(matrix.getParameterName() + ".mask"); + this.matrix = matrix; + this.mask = mask; + addParameter(matrix); + addParameter(mask); + this.rows = makeRowsFromMask(); + } + + @Override + public void variableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + if (variable == mask) { + ArrayList oldRows = rows; + this.rows = makeRowsFromMask(); + int ni = rows.size(); + int oi = oldRows.size(); + if (ni == oi) { + type = ChangeType.ALL_VALUES_CHANGED; + } else if (ni < oi) { + type = ChangeType.REMOVED; + } else { + type = ChangeType.ADDED; + } + index = -1; + } + super.variableChangedEvent(variable, index, type); + } + + private ArrayList makeRowsFromMask() { + ArrayList newRows = new ArrayList<>(); + for (int i = 0; i < mask.getDimension(); i++) { + if (mask.getParameterValue(i) == 1) { + newRows.add(i); + } + } + return newRows; + } + + + @Override + public double getParameterValue(int row, int col) { + return matrix.getParameterValue(rows.get(row), col); + } + + + @Override + public void setParameterValue(int row, int col, double value) { + matrix.setParameterValue(rows.get(row), col, value); + } + + @Override + public void setParameterValueQuietly(int row, int col, double value) { + matrix.setParameterValueQuietly(rows.get(row), col, value); + } + + @Override + public void setParameterValueNotifyChangedAll(int row, int col, double value) { + matrix.setParameterValueNotifyChangedAll(rows.get(row), col, value); + } + + @Override + public double[] getColumnValues(int col) { + double[] maskedValues = new double[rows.size()]; + for (int i = 0; i < rows.size(); i++) { + maskedValues[i] = matrix.getParameterValue(rows.get(i), col); + } + return maskedValues; + } + + @Override + public double[][] getParameterAsMatrix() { + double[][] values = new double[matrix.getColumnDimension()][rows.size()]; + for (int i = 0; i < matrix.getColumnDimension(); i++) { + for (int j = 0; j < rows.size(); j++) { + values[i][j] = getParameterValue(i, j); + } + } + return values; + } + + @Override + public int getColumnDimension() { + return matrix.getColumnDimension(); + } + + @Override + public int getRowDimension() { + return rows.size(); + } + + @Override + public int getUniqueParameterCount() { + return matrix.getUniqueParameterCount(); + } + + @Override + public Parameter getUniqueParameter(int index) { + return matrix.getUniqueParameter(index); + } + + @Override + public void copyParameterValues(double[] destination, int offset) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public void setAllParameterValuesQuietly(double[] values, int offset) { + throw new RuntimeException("not yet implemented"); + } + + @Override + public String toSymmetricString() { + throw new RuntimeException("not yet implemented"); + } + + @Override + public boolean isConstrainedSymmetric() { + return false; + } + +} diff --git a/src/dr/inferencexml/model/MaskedMatrixParameterParser.java b/src/dr/inferencexml/model/MaskedMatrixParameterParser.java index 349aee41db..c99dac200f 100644 --- a/src/dr/inferencexml/model/MaskedMatrixParameterParser.java +++ b/src/dr/inferencexml/model/MaskedMatrixParameterParser.java @@ -25,11 +25,11 @@ package dr.inferencexml.model; -import dr.inference.model.MaskedParameter; -import dr.inference.model.MatrixParameter; -import dr.inference.model.Parameter; +import dr.inference.model.*; import dr.xml.*; +import java.util.ArrayList; + /** * @author Marc A. Suchard */ @@ -44,7 +44,7 @@ public class MaskedMatrixParameterParser extends AbstractXMLObjectParser { public Object parseXMLObject(XMLObject xo) throws XMLParseException { - MatrixParameter matrix = (MatrixParameter) xo.getChild(MatrixParameter.class); + MatrixParameterInterface matrix = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class); // System.err.println("colDim " + matrix.getColumnDimension()); @@ -86,7 +86,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { maskedParameters[col].addMask(mask, ones); } - MatrixParameter maskedMatrix = new MatrixParameter(matrix.getId() + ".masked", maskedParameters); + MaskedMatrixParameter maskedMatrix = new MaskedMatrixParameter(matrix, mask); // for (int col = 0; col < matrix.getColumnDimension(); ++col) { // maskedMatrix.addParameter(matrix.getParameter(col)); @@ -100,7 +100,7 @@ public XMLSyntaxRule[] getSyntaxRules() { } private final XMLSyntaxRule[] rules = { - new ElementRule(MatrixParameter.class), + new ElementRule(MatrixParameterInterface.class), new ElementRule(MASKING, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) From c753b9d34432d39b5951ba8b818ccf7bf75e523a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 24 Jun 2022 12:32:52 -0700 Subject: [PATCH 094/196] testing adjugate calculations for singular matrices --- .../missingData/RobustAdjugateTest.java | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 src/test/dr/math/matrixAlgebra/missingData/RobustAdjugateTest.java diff --git a/src/test/dr/math/matrixAlgebra/missingData/RobustAdjugateTest.java b/src/test/dr/math/matrixAlgebra/missingData/RobustAdjugateTest.java new file mode 100644 index 0000000000..6ae8ad9969 --- /dev/null +++ b/src/test/dr/math/matrixAlgebra/missingData/RobustAdjugateTest.java @@ -0,0 +1,86 @@ +package test.dr.math.matrixAlgebra.missingData; + +import dr.math.matrixAlgebra.EJMLUtils; +import org.ejml.data.DenseMatrix64F; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +public class RobustAdjugateTest { + + interface Instance { + DenseMatrix64F getMatrix(); + + double[] getTrueAdjugate(); + + abstract class Basic implements Instance { + + } + } + + Instance test0 = new Instance.Basic() { + + @Override + public DenseMatrix64F getMatrix() { + int dim = 4; + return DenseMatrix64F.wrap( + dim, dim, + new double[]{0.9866964953013155, 0.15545690836486598, -1.1513262141217357, 0.4725542031241179, + -0.19421901601084784, 0.5978801523926415, -0.45827856969399533, -0.09859711476482771, + 0.8510272497600995, 0.46551635730932206, -0.24277738841539084, 0.5870842918793533, + 1.3516144337818743, 0.3483227354656835, 0.3519407851667098, -0.26191847550729874} + ); + } + + @Override + public double[] getTrueAdjugate() { + return new double[]{-0.25959911412615355, 0.2806715010354877, 0.05816932938829323, -0.44364091890937585, + 0.49990493839334016, -1.0147493713240305, -0.6250795751738031, -0.1171771471159761, + 0.7038993830503742, 0.3213216171925368, -0.5630922852614357, -0.11313981508126116, + 0.2710051185801451, 0.5306443430372886, -1.28773681446716, 0.689220622236884}; + } + }; + + Instance test1 = new Instance.Basic() { + + @Override + public DenseMatrix64F getMatrix() { + int dim = 4; + return DenseMatrix64F.wrap(dim, dim, + new double[]{1, 0.0, -0.4564755436471619, 0.20065293735042314, + 0.0, 1, -0.8800002808267111, 0.041910304168150225, + -0.4564755436471619, -0.8800002808267111, 1, 0.0, + 0.20065293735042314, 0.041910304168150225, 0.0, 1.0}); + } + + @Override + public double[] getTrueAdjugate() { + return new double[]{0.22384303214944315, 0.4101080322366171, 0.46307405332955914, -0.06210251427904914, + 0.4101080322366171, 0.7513684767846751, 0.8484087575441298, -0.11377946270368472, + 0.46307405332955914, 0.8484087575441298, 0.9579819251371803, -0.12847423809893943, + -0.06210251427904914, -0.11377946270368472, -0.1284742380989394, 0.017229583796937953}; + } + }; + + private static final double TOL = 1e-12; + + + @Test + public void adjugate() throws Exception { + + Instance[] tests = new Instance[]{test0, test1}; + for (Instance test : tests) { + DenseMatrix64F X = test.getMatrix(); + + DenseMatrix64F A = EJMLUtils.computeRobustAdjugate(X); + + double[] a = test.getTrueAdjugate(); + + for (int i = 0; i < a.length; i++) { + assertEquals(A.get(i), a[i], TOL); + } + + } + } +} From b5817ff1a0151d81aa135b7466d18ae47c52d44d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 24 Jun 2022 12:36:14 -0700 Subject: [PATCH 095/196] Gibbs sampler for using integrated trait likelihood to draw from full conditional of (non integrated) trait tips --- .../app/beast/development_parsers.properties | 1 + .../GaussianTreeTraitGibbsOperator.java | 74 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 9a07c2106b..1ecf8b7e73 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -307,6 +307,7 @@ dr.inference.operators.rejection.DescendingAndSpacedCondition dr.evomodel.operators.ExtendedLatentLiabilityGibbsOperator dr.inference.model.FactorProportionStatistic dr.inferencexml.model.BlombergKStatisticParser +dr.inference.operators.factorAnalysis.GaussianTreeTraitGibbsOperator # Shrinkage dr.inference.model.MaskFromTree diff --git a/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java b/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java new file mode 100644 index 0000000000..6631a5ac55 --- /dev/null +++ b/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java @@ -0,0 +1,74 @@ +package dr.inference.operators.factorAnalysis; + +import dr.evolution.tree.TreeTrait; +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; +import dr.inference.model.Parameter; +import dr.inference.operators.GibbsOperator; +import dr.inference.operators.SimpleMCMCOperator; +import dr.xml.AbstractXMLObjectParser; +import dr.xml.XMLObject; +import dr.xml.XMLParseException; +import dr.xml.XMLSyntaxRule; + +public class GaussianTreeTraitGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { + + public static final String TRAIT_GIBBS = "gaussianTreeTraitGibbsOperator"; + private final TreeTrait treeTrait; + private final Parameter traitParameter; + private final TreeDataLikelihood treeDataLikelihood; + + public GaussianTreeTraitGibbsOperator(TreeDataLikelihood treeDataLikelihood, Parameter parameter, String traitName) { + this.traitParameter = parameter; + this.treeDataLikelihood = treeDataLikelihood; + ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate(); + this.treeTrait = treeDataLikelihood.getTreeTrait(delegate.getDataModel().getTipTraitName()); + } + + + @Override + public String getOperatorName() { + return TRAIT_GIBBS; + } + + @Override + public double doOperation() { + treeDataLikelihood.fireModelChanged(); + double[] traits = treeTrait.getTrait(treeDataLikelihood.getTree(), null); + traitParameter.setAllParameterValuesQuietly(traits); + traitParameter.fireParameterChangedEvent(); + return Double.POSITIVE_INFINITY; + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + String traitName = xo.getStringAttribute(TreeTraitParserUtilities.TRAIT_NAME); + TreeDataLikelihood likelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class); + Parameter parameter = (Parameter) xo.getChild(Parameter.class); + return new GaussianTreeTraitGibbsOperator(likelihood, parameter, traitName); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[0]; + } + + @Override + public String getParserDescription() { + return null; + } + + @Override + public Class getReturnType() { + return GaussianTreeTraitGibbsOperator.class; + } + + @Override + public String getParserName() { + return TRAIT_GIBBS; + } + }; + +} From 127a5bbb792a720d4451bb76e3955c39ba9f2051 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 11:51:19 -0700 Subject: [PATCH 096/196] IndependentNormalDistributionModel also provides gradients --- ...testIndependentNormalDistributionModel.xml | 129 ++++++++++++++++++ .../IndependentNormalDistributionModel.java | 51 ++++++- ...ependentNormalDistributionModelParser.java | 7 +- 3 files changed, 177 insertions(+), 10 deletions(-) create mode 100644 ci/TestXML/testIndependentNormalDistributionModel.xml diff --git a/ci/TestXML/testIndependentNormalDistributionModel.xml b/ci/TestXML/testIndependentNormalDistributionModel.xml new file mode 100644 index 0000000000..0d5607569b --- /dev/null +++ b/ci/TestXML/testIndependentNormalDistributionModel.xml @@ -0,0 +1,129 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood (precision parameterized) + + + + + + -19.93664004895008 + + + + + + Check gradient (precision parameterized) + + + + + + -0.6216961022466229 2.5487182843174216 -0.19499099542972953 0.7222583134782101 -10.07301892746935 + + + + + + + + + + + + + + + Check log likelihood (variance parameterized) + + + + + + -19.93664004895008 + + + + + + Check gradient (variance parameterized) + + + + + + -0.6216961022466229 2.5487182843174216 -0.19499099542972953 0.7222583134782101 -10.07301892746935 + + + + + + + + \ No newline at end of file diff --git a/src/dr/inference/distribution/IndependentNormalDistributionModel.java b/src/dr/inference/distribution/IndependentNormalDistributionModel.java index a1c0bd6f73..d1cc59b707 100644 --- a/src/dr/inference/distribution/IndependentNormalDistributionModel.java +++ b/src/dr/inference/distribution/IndependentNormalDistributionModel.java @@ -4,6 +4,7 @@ import dr.inference.model.*; import dr.inference.operators.repeatedMeasures.MultiplicativeGammaGibbsHelper; import dr.math.distributions.NormalDistribution; +import dr.xml.Reportable; /** * @author Max Tolkoff @@ -12,13 +13,16 @@ */ public class IndependentNormalDistributionModel extends AbstractModelLikelihood implements NormalStatisticsProvider, - MultiplicativeGammaGibbsHelper { + MultiplicativeGammaGibbsHelper, GradientWrtParameterProvider, Reportable { Parameter mean; Parameter variance; Parameter precision; Parameter data; boolean usePrecision; + public static String INDEPENDENT_NORMAL_DISTRIBUTION_MODEL = "independentNormalDistributionModel"; + + public IndependentNormalDistributionModel(String id, Parameter mean, Parameter variance, Parameter precision, Parameter data) { super(id); addVariable(mean); @@ -71,12 +75,7 @@ public Model getModel() { public double getLogLikelihood() { double sum = 0; for (int i = 0; i < data.getDimension(); i++) { - double sd; - if (usePrecision) { - sd = Math.sqrt(1 / precision.getParameterValue(i)); - } else { - sd = Math.sqrt(variance.getParameterValue(i)); - } + double sd = getNormalSD(i); sum += NormalDistribution.logPdf(data.getParameterValue(i), mean.getParameterValue(i), sd); @@ -138,4 +137,42 @@ public int getRowDimension() { public int getColumnDimension() { return data.getDimension(); } + + @Override + public Likelihood getLikelihood() { + return this; + } + + @Override + public Parameter getParameter() { + return data; + } + + @Override + public int getDimension() { + return data.getDimension(); + } + + @Override + public double[] getGradientLogDensity() { + double[] grad = new double[getDimension()]; + for (int i = 0; i < getDimension(); i++) { + double sd = getNormalSD(i); + grad[i] = NormalDistribution.gradLogPdf(data.getParameterValue(i), mean.getParameterValue(i), sd); + } + return grad; + } + + @Override + public String getReport() { + StringBuilder sb = new StringBuilder(INDEPENDENT_NORMAL_DISTRIBUTION_MODEL + " report:\n"); + sb.append("\tlogLikelihood: " + getLogLikelihood() + "\n"); + sb.append("\tgradient: "); + double[] grad = getGradientLogDensity(); + for (int i = 0; i < grad.length; i++) { + sb.append(grad[i] + " "); + } + sb.append("\n\n"); + return sb.toString(); + } } diff --git a/src/dr/inferencexml/distribution/IndependentNormalDistributionModelParser.java b/src/dr/inferencexml/distribution/IndependentNormalDistributionModelParser.java index 97fd6f9028..3846f87f34 100644 --- a/src/dr/inferencexml/distribution/IndependentNormalDistributionModelParser.java +++ b/src/dr/inferencexml/distribution/IndependentNormalDistributionModelParser.java @@ -4,8 +4,9 @@ import dr.inference.model.Parameter; import dr.xml.*; +import static dr.inference.distribution.IndependentNormalDistributionModel.INDEPENDENT_NORMAL_DISTRIBUTION_MODEL; + public class IndependentNormalDistributionModelParser extends AbstractXMLObjectParser { - public static String INDEPENDENT_NORMAL_DISTRIBUTION_MODEL = "independentNormalDistributionModel"; public static String MEAN = "mean"; public static String VARIANCE = "variance"; public static String PRECISION = "precision"; @@ -17,11 +18,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { String id = xo.getStringAttribute(ID); Parameter mean = (Parameter) xo.getChild(MEAN).getChild(Parameter.class); Parameter precision = null; - if(xo.getChild(PRECISION) != null){ + if (xo.getChild(PRECISION) != null) { precision = (Parameter) xo.getChild(PRECISION).getChild(Parameter.class); } Parameter variance = null; - if(xo.getChild(VARIANCE) != null){ + if (xo.getChild(VARIANCE) != null) { variance = (Parameter) xo.getChild(VARIANCE).getChild(Parameter.class); } Parameter data = (Parameter) xo.getChild(DATA).getChild(Parameter.class); From af328e9cf034e58ec97e8f8f0fa3d8af8c4dad41 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 12:02:44 -0700 Subject: [PATCH 097/196] finishing GaussianTreeTraitGibbsOperator parser --- .../GaussianTreeTraitGibbsOperator.java | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java b/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java index 6631a5ac55..f2d7651a33 100644 --- a/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java +++ b/src/dr/inference/operators/factorAnalysis/GaussianTreeTraitGibbsOperator.java @@ -3,14 +3,10 @@ import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; -import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; import dr.inference.model.Parameter; import dr.inference.operators.GibbsOperator; import dr.inference.operators.SimpleMCMCOperator; -import dr.xml.AbstractXMLObjectParser; -import dr.xml.XMLObject; -import dr.xml.XMLParseException; -import dr.xml.XMLSyntaxRule; +import dr.xml.*; public class GaussianTreeTraitGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { @@ -19,7 +15,8 @@ public class GaussianTreeTraitGibbsOperator extends SimpleMCMCOperator implement private final Parameter traitParameter; private final TreeDataLikelihood treeDataLikelihood; - public GaussianTreeTraitGibbsOperator(TreeDataLikelihood treeDataLikelihood, Parameter parameter, String traitName) { + public GaussianTreeTraitGibbsOperator(TreeDataLikelihood treeDataLikelihood, Parameter parameter, double weight) { + setWeight(weight); this.traitParameter = parameter; this.treeDataLikelihood = treeDataLikelihood; ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate(); @@ -44,20 +41,24 @@ public double doOperation() { public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { - String traitName = xo.getStringAttribute(TreeTraitParserUtilities.TRAIT_NAME); TreeDataLikelihood likelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class); Parameter parameter = (Parameter) xo.getChild(Parameter.class); - return new GaussianTreeTraitGibbsOperator(likelihood, parameter, traitName); + double weight = xo.getDoubleAttribute(WEIGHT); + return new GaussianTreeTraitGibbsOperator(likelihood, parameter, weight); } @Override public XMLSyntaxRule[] getSyntaxRules() { - return new XMLSyntaxRule[0]; + return new XMLSyntaxRule[]{ + new ElementRule(TreeDataLikelihood.class), + new ElementRule(Parameter.class), + AttributeRule.newDoubleRule(WEIGHT) + }; } @Override public String getParserDescription() { - return null; + return "samples traits at the tips of the tree from their full conditional distribution"; } @Override From d22a9f5044e106da0c8b81fc89b212cf743888d5 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 14:49:08 -0700 Subject: [PATCH 098/196] very messy code to deal w/ numerical instability related to HMC w/ boundaries --- src/dr/inference/model/BoundedSpace.java | 130 ++++++++++++++---- .../ConvexSpaceRandomWalkOperator.java | 8 +- ...flectiveHamiltonianMonteCarloOperator.java | 44 ++++-- 3 files changed, 141 insertions(+), 41 deletions(-) diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 1cb097b555..4cf1e12655 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -1,13 +1,13 @@ package dr.inference.model; import dr.app.bss.Utils; -import dr.evomodel.substmodel.ColtEigenSystem; -import dr.evomodel.substmodel.EigenDecomposition; +import dr.inference.operators.hmc.HamiltonianMonteCarloOperator; import dr.math.MathUtils; import dr.math.matrixAlgebra.EJMLUtils; import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.Matrix; import dr.math.matrixAlgebra.SymmetricMatrix; +import org.ejml.data.Complex64F; import org.ejml.data.DenseMatrix64F; import org.ejml.factory.DecompositionFactory; import org.ejml.interfaces.decomposition.CholeskyDecomposition; @@ -19,12 +19,12 @@ public interface BoundedSpace extends GeneralBoundsProvider { boolean isWithinBounds(double[] values); - IntersectionDistances distancesToBoundary(double[] origin, double[] direction); + IntersectionDistances distancesToBoundary(double[] origin, double[] direction, boolean isAtBoundary) throws HamiltonianMonteCarloOperator.NumericInstabilityException; double[] getNormalVectorAtBoundary(double[] position); - default double forwardDistanceToBoundary(double[] origin, double[] direction) { - return distancesToBoundary(origin, direction).forwardDistance; + default double forwardDistanceToBoundary(double[] origin, double[] direction, boolean isAtBoundary) throws HamiltonianMonteCarloOperator.NumericInstabilityException { + return distancesToBoundary(origin, direction, isAtBoundary).forwardDistance; } class IntersectionDistances { @@ -43,7 +43,8 @@ public IntersectionDistances(double forwardDistance, double backwardDistance) { class Correlation implements BoundedSpace { private static final boolean DEBUG = false; - private static final double TOL = 1e-10; + private static final double TOL = 0; + private static final double BOUNDARY_TOL = 1e-6; private final int dim; public Correlation(int dim) { @@ -128,9 +129,32 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { throw new RuntimeException("illegal dimensions"); } - ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); - EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need largest magnitude eigenvalues - double[] values = decomposition.getEigenValues(); + double[] z = Z.toArrayComponents(); //TODO: CLEAN THIS UP!!!!! + DenseMatrix64F A = DenseMatrix64F.wrap(dim, dim, z); + org.ejml.interfaces.decomposition.EigenDecomposition factory = new DecompositionFactory().eig(dim, false, false); + if (!factory.decompose(A)) throw new RuntimeException("Eigen decomposition failed."); + double[] allValues = new double[dim]; + int nReal = 0; + for (int i = 0; i < dim; i++) { + Complex64F ev = factory.getEigenvalue(i); + if (ev.isReal()) { + allValues[nReal] = ev.real; + nReal++; + } + } + + double[] values = new double[nReal]; + System.arraycopy(allValues, 0, values, 0, nReal); + +// ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); +// EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need largest magnitude eigenvalues +// double[] values = decomposition.getEigenValues(); + if (DEBUG) { + System.out.println("Raw matrix to decompose: "); + System.out.println(Z); + System.out.print("Raw eigenvalues: "); + Utils.printArray(values); + } for (int i = 0; i < values.length; i++) { values[i] = 1 / values[i]; } @@ -140,31 +164,87 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { } @Override - public IntersectionDistances distancesToBoundary(double[] origin, double[] direction) { + public IntersectionDistances distancesToBoundary(double[] origin, double[] direction, boolean isAtBoundary) throws HamiltonianMonteCarloOperator.NumericInstabilityException { + + if (!isWithinBounds(origin)) { + if (isAtBoundary) { + //TODO: remove below + SymmetricMatrix C = compoundCorrelationSymmetricMatrix(origin, dim); + double det = 0; + try { + det = C.determinant(); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + } - if (!isWithinBounds(origin)) { //TODO: make this optional? - SymmetricMatrix C = compoundCorrelationSymmetricMatrix(origin, dim); - System.out.println(C); - try { - System.out.println(C.determinant()); - } catch (IllegalDimension illegalDimension) { - illegalDimension.printStackTrace(); + if (Math.abs(det) > BOUNDARY_TOL) { + System.out.println(det); + throw new HamiltonianMonteCarloOperator.NumericInstabilityException(); + } + } else { + SymmetricMatrix C = compoundCorrelationSymmetricMatrix(origin, dim); + System.out.println(C); + try { + System.out.println(C.determinant()); + } catch (IllegalDimension illegalDimension) { + illegalDimension.printStackTrace(); + } + + throw new HamiltonianMonteCarloOperator.NumericInstabilityException(); } - throw new RuntimeException("Starting position is outside of bounds"); } double values[] = robustTrajectoryEigenValues(origin, direction); double minNegative = Double.NEGATIVE_INFINITY; + double minNegative2 = Double.NEGATIVE_INFINITY; double minPositive = Double.POSITIVE_INFINITY; + double minPositive2 = Double.POSITIVE_INFINITY; + for (int i = 0; i < values.length; i++) { double value = values[i]; if (value < -TOL && value > minNegative) { + minNegative2 = minNegative; minNegative = value; + } else if (value < -TOL && value > minNegative2) { + minNegative2 = value; } else if (value >= TOL & value < minPositive) { + minPositive2 = minPositive; minPositive = value; + } else if (value >= TOL && value < minPositive2) { + minPositive2 = value; + } + } + + if (isAtBoundary) { + if (DEBUG) { + System.out.println("minNegative: " + minNegative); + System.out.println("minNegative2: " + minNegative2); + System.out.println("minPositive: " + minPositive); + System.out.println("minPositive2: " + minPositive2); + + } + if (Math.abs(minNegative) < minPositive) { + if (Math.abs(minNegative) < BOUNDARY_TOL) { + minNegative = minNegative2; + } else { + throw new RuntimeException("isAtBoundary = true but does not appear to be near boundary"); + } + } else { + if (minPositive < BOUNDARY_TOL) { + minPositive = minPositive2; + } else { + throw new RuntimeException("isAtBoundary = true but does not appear to be near boundary"); + } + } + if (DEBUG) { + System.out.println("minNegative: " + minNegative); + System.out.println("minNegative2: " + minNegative2); + System.out.println("minPositive: " + minPositive); + System.out.println("minPositive2: " + minPositive2); + } } @@ -217,13 +297,13 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc throw new RuntimeException(); } - if (detY < -TOL || detY > 1) { - throw new RuntimeException("invalid starting position"); - } - - if (absMax > 1.0) { - throw new RuntimeException("Invalid ending position"); - } +// if (detY < -TOL || detY > 1) { +// throw new RuntimeException("invalid starting position"); +// } +// +// if (absMax > 1.0) { +// throw new RuntimeException("Invalid ending position"); +// } } return new IntersectionDistances(minNegative, minPositive); diff --git a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java index 29209ddc4f..58ce4b8a5c 100644 --- a/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java +++ b/src/dr/inference/operators/ConvexSpaceRandomWalkOperator.java @@ -2,6 +2,7 @@ import dr.inference.model.BoundedSpace; import dr.inference.model.Parameter; +import dr.inference.operators.hmc.HamiltonianMonteCarloOperator; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.matrixAlgebra.CholeskyDecomposition; import jebl.math.Random; @@ -190,7 +191,12 @@ public double doOperation() { sample[varInds.get(i)] = varSample[i]; } - BoundedSpace.IntersectionDistances distances = space.distancesToBoundary(values, sample); + BoundedSpace.IntersectionDistances distances; + try { + distances = space.distancesToBoundary(values, sample, false); + } catch (HamiltonianMonteCarloOperator.NumericInstabilityException e) { + throw new RuntimeException("position outside of bounded space at beginning of operator move"); + } // double u1 = Random.nextDouble() * distances.forwardDistance; // for (int i = 0; i < values.length; i++) { // sample[i] = values[i] + (sample[i] - values[i]) * u1; diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index 53058ce314..184199c286 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -48,6 +48,7 @@ public class ReflectiveHamiltonianMonteCarloOperator extends HamiltonianMonteCarloOperator implements Reportable { private final GeneralBoundsProvider parameterBound; + private boolean isAtBoundary = false; private static final boolean DEBUG = false; @@ -106,19 +107,20 @@ abstract class WithBounds extends HamiltonianMonteCarloOperator.LeapFrogEngine.D super(parameter, instabilityHandler, preconditioning, mask); } - protected abstract ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength); + protected abstract ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength, + boolean isAtBoundary) throws NumericInstabilityException; @Override public void updatePosition(double[] position, WrappedVector momentum, - double functionalStepSize) { + double functionalStepSize) throws NumericInstabilityException { double collapsedTime = 0.0; while (collapsedTime < functionalStepSize) { - ReflectionEvent event = nextEvent(position, momentum, functionalStepSize - collapsedTime); + ReflectionEvent event = nextEvent(position, momentum, functionalStepSize - collapsedTime, isAtBoundary); if (DEBUG) { - SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 6); //TODO: remove + SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 11); //TODO: remove try { System.out.println("starting det: " + C.determinant()); } catch (IllegalDimension illegalDimension) { @@ -127,11 +129,11 @@ public void updatePosition(double[] position, WrappedVector momentum, } } - event.doReflection(position, momentum); + isAtBoundary = event.doReflection(position, momentum); if (DEBUG) { System.out.println("event: " + event.getType()); - SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 6); //TODO: remove + SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 11); //TODO: remove try { System.out.println("ending det: " + C.determinant()); } catch (IllegalDimension illegalDimension) { @@ -174,9 +176,10 @@ class WithMultivariateCurvedBounds extends WithBounds { @Override - protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength) { + protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength, + boolean isAtBoundary) throws NumericInstabilityException { double[] velocity = preconditioning.getVelocity(momentum); - double timeToReflection = space.forwardDistanceToBoundary(position, velocity); + double timeToReflection = space.forwardDistanceToBoundary(position, velocity, isAtBoundary); if (DEBUG) { System.out.println("Time to reflection: " + timeToReflection); @@ -221,7 +224,8 @@ protected WithGraphBounds(Parameter parameter, @Override - protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength) { + protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double intervalLength, + boolean isAtBoundary) { ReflectionEvent reflectionEventAtFixedBound = firstReflectionAtFixedBounds(position, momentum, intervalLength); ReflectionEvent collisionEvent = firstCollision(position, momentum, intervalLength); return (reflectionEventAtFixedBound.getEventTime() < collisionEvent.getEventTime()) ? reflectionEventAtFixedBound : collisionEvent; @@ -368,8 +372,9 @@ public ReflectionType getType() { return type; } - public void doReflection(double[] position, WrappedVector momentum) { + public boolean doReflection(double[] position, WrappedVector momentum) { type.doReflection(position, preconditioning, momentum, eventLocation, indices, normalVector, eventTime, remainingTime); + return type.isAtBoundary(); } } @@ -412,11 +417,11 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped System.out.println(momentum); } - if (BOUNCE) { - double t = Math.min(remainingTime, 1e-10); //TODO: need to make sure I'm not leaving the space again, also need to update time later - System.out.println("bounce time: " + t); - updatePosition(position, preconditioning, momentum, t); - } +// if (BOUNCE) { +// double t = Math.min(remainingTime, 1e-11); //TODO: need to make sure I'm not leaving the space again, also need to update time later +//// System.out.println("bounce time: " + t); +// updatePosition(position, preconditioning, momentum, t); +// } if (DEBUG) { System.out.print("bounce: "); @@ -467,6 +472,11 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped Utils.printArray(position); } } + + @Override + public boolean isAtBoundary() { + return false; + } }; void updatePosition(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, double time) { @@ -481,6 +491,10 @@ abstract void doReflection(double[] position, MassPreconditioner preconditioning private static final boolean BOUNCE = true; + public boolean isAtBoundary() { + return true; + } + } } From f3ba435e659dd989dcfff6c86166d59425c02e5a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 14:57:07 -0700 Subject: [PATCH 099/196] cleaning code --- src/dr/inference/model/BoundedSpace.java | 22 ------ ...flectiveHamiltonianMonteCarloOperator.java | 76 ++++--------------- 2 files changed, 16 insertions(+), 82 deletions(-) diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 4cf1e12655..a6471385e5 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -146,9 +146,6 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { double[] values = new double[nReal]; System.arraycopy(allValues, 0, values, 0, nReal); -// ColtEigenSystem eigenSystem = new ColtEigenSystem(dim); -// EigenDecomposition decomposition = eigenSystem.decomposeMatrix(Z.toComponents()); //TODO: only need largest magnitude eigenvalues -// double[] values = decomposition.getEigenValues(); if (DEBUG) { System.out.println("Raw matrix to decompose: "); System.out.println(Z); @@ -219,13 +216,7 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc } if (isAtBoundary) { - if (DEBUG) { - System.out.println("minNegative: " + minNegative); - System.out.println("minNegative2: " + minNegative2); - System.out.println("minPositive: " + minPositive); - System.out.println("minPositive2: " + minPositive2); - } if (Math.abs(minNegative) < minPositive) { if (Math.abs(minNegative) < BOUNDARY_TOL) { minNegative = minNegative2; @@ -239,13 +230,7 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc throw new RuntimeException("isAtBoundary = true but does not appear to be near boundary"); } } - if (DEBUG) { - System.out.println("minNegative: " + minNegative); - System.out.println("minNegative2: " + minNegative2); - System.out.println("minPositive: " + minPositive); - System.out.println("minPositive2: " + minPositive2); - } } minPositive = -minPositive; @@ -297,13 +282,6 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc throw new RuntimeException(); } -// if (detY < -TOL || detY > 1) { -// throw new RuntimeException("invalid starting position"); -// } -// -// if (absMax > 1.0) { -// throw new RuntimeException("Invalid ending position"); -// } } return new IntersectionDistances(minNegative, minPositive); diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index 184199c286..752342548d 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -31,9 +31,7 @@ import dr.inference.model.*; import dr.inference.operators.AdaptationMode; import dr.inferencexml.operators.hmc.ReflectiveHamiltonianMonteCarloOperatorParser; -import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.ReadableVector; -import dr.math.matrixAlgebra.SymmetricMatrix; import dr.math.matrixAlgebra.WrappedVector; import dr.util.Transform; import dr.xml.Reportable; @@ -119,29 +117,7 @@ public void updatePosition(double[] position, WrappedVector momentum, while (collapsedTime < functionalStepSize) { ReflectionEvent event = nextEvent(position, momentum, functionalStepSize - collapsedTime, isAtBoundary); - if (DEBUG) { - SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 11); //TODO: remove - try { - System.out.println("starting det: " + C.determinant()); - } catch (IllegalDimension illegalDimension) { - illegalDimension.printStackTrace(); - throw new RuntimeException(); - } - } - isAtBoundary = event.doReflection(position, momentum); - - if (DEBUG) { - System.out.println("event: " + event.getType()); - SymmetricMatrix C = SymmetricMatrix.compoundCorrelationSymmetricMatrix(position, 11); //TODO: remove - try { - System.out.println("ending det: " + C.determinant()); - } catch (IllegalDimension illegalDimension) { - illegalDimension.printStackTrace(); - throw new RuntimeException(); - } - } - collapsedTime += event.getEventTime(); } setParameter(position); @@ -190,9 +166,6 @@ protected ReflectionEvent nextEvent(double[] position, WrappedVector momentum, d if (timeToReflection > intervalLength) { return new ReflectionEvent(ReflectionType.None, intervalLength, Double.NaN, new int[0]); } else { - if (DEBUG) { - System.out.println("!!!!!!!!!!!!!!!REFLECTION!!!!!!!!!!!!!!!"); - } double[] boundaryPosition = new double[position.length]; for (int i = 0; i < position.length; i++) { boundaryPosition[i] = position[i] + timeToReflection * velocity[i]; @@ -373,7 +346,22 @@ public ReflectionType getType() { } public boolean doReflection(double[] position, WrappedVector momentum) { + + if (DEBUG) { + System.out.println("time: " + eventTime); + System.out.print("start: "); + Utils.printArray(position); + System.out.println(momentum); + } + type.doReflection(position, preconditioning, momentum, eventLocation, indices, normalVector, eventTime, remainingTime); + + if (DEBUG) { + System.out.print("end: "); + Utils.printArray(position); + System.out.println(momentum); + } + return type.isAtBoundary(); } @@ -388,13 +376,6 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped double remainingTime) { - if (DEBUG) { - System.out.println("time: " + time); - System.out.print("start: "); - Utils.printArray(position); - System.out.println(momentum); - } - updatePosition(position, preconditioning, momentum, time); double vn = 0; double nn = 0; @@ -411,23 +392,7 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped position[i] = eventLocation[i]; } - if (DEBUG) { - System.out.print("end: "); - Utils.printArray(position); - System.out.println(momentum); - } - -// if (BOUNCE) { -// double t = Math.min(remainingTime, 1e-11); //TODO: need to make sure I'm not leaving the space again, also need to update time later -//// System.out.println("bounce time: " + t); -// updatePosition(position, preconditioning, momentum, t); -// } - if (DEBUG) { - System.out.print("bounce: "); - Utils.printArray(position); - System.out.println(momentum); - } } }, Reflection { @@ -461,16 +426,8 @@ void doReflection(double[] position, MassPreconditioner preconditioning, Wrapped double eventLocation[], int[] indices, double[] normalVector, double time, double remainginTime) { - if (DEBUG) { - System.out.println("time: " + time); - System.out.print("start: "); - Utils.printArray(position); - } updatePosition(position, preconditioning, momentum, time); - if (DEBUG) { - System.out.print("end: "); - Utils.printArray(position); - } + } @Override @@ -489,7 +446,6 @@ void updatePosition(double[] position, MassPreconditioner preconditioning, Wrapp abstract void doReflection(double[] position, MassPreconditioner preconditioning, WrappedVector momentum, double eventLocation[], int[] indices, double[] normalVector, double time, double remainingTime); - private static final boolean BOUNCE = true; public boolean isAtBoundary() { return true; From eec8dbf5ee9268f0b4848fd500f00dfd9d6584e9 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 17:25:52 -0700 Subject: [PATCH 100/196] a bit more code cleaning (not relying on two different matrix implementations) + a unit test --- src/dr/inference/model/BoundedSpace.java | 55 +++++++++----- .../dr/inference/model/BoundedSpaceTest.java | 76 +++++++++++++++++++ 2 files changed, 111 insertions(+), 20 deletions(-) create mode 100644 src/test/dr/inference/model/BoundedSpaceTest.java diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index a6471385e5..d11ce83574 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -11,6 +11,7 @@ import org.ejml.data.DenseMatrix64F; import org.ejml.factory.DecompositionFactory; import org.ejml.interfaces.decomposition.CholeskyDecomposition; +import org.ejml.ops.CommonOps; import static dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix; import static dr.math.matrixAlgebra.SymmetricMatrix.compoundSymmetricMatrix; @@ -43,12 +44,25 @@ public IntersectionDistances(double forwardDistance, double backwardDistance) { class Correlation implements BoundedSpace { private static final boolean DEBUG = false; + private static final boolean CHECK_WITHIN_BOUNDS = false; private static final double TOL = 0; - private static final double BOUNDARY_TOL = 1e-6; + private static final double BOUNDARY_TOL = 1e-9; private final int dim; + private final DenseMatrix64F C; + private final DenseMatrix64F V; + private final DenseMatrix64F Cinv; + private final DenseMatrix64F CinvV; + public Correlation(int dim) { this.dim = dim; + this.V = new DenseMatrix64F(dim, dim); + this.C = new DenseMatrix64F(dim, dim); + for (int i = 0; i < dim; i++) { + C.set(i, i, 1); + } + this.Cinv = new DenseMatrix64F(dim, dim); + this.CinvV = new DenseMatrix64F(dim, dim); } @@ -116,23 +130,24 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { double[] x = new double[origin.length]; System.arraycopy(direction, 0, x, 0, x.length); //TODO: is this necessary? - SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); - SymmetricMatrix X = compoundSymmetricMatrix(0.0, x, dim); - -// SymmetricMatrix Xinv = X.inverse(); - SymmetricMatrix Yinv = Y.inverse(); - final Matrix Z; - - try { - Z = Yinv.product(X); - } catch (IllegalDimension illegalDimension) { - throw new RuntimeException("illegal dimensions"); + int ind = 0; + for (int i = 0; i < dim; i++) { + for (int j = (i + 1); j < dim; j++) { + C.set(i, j, origin[ind]); + C.set(j, i, origin[ind]); + V.set(i, j, direction[ind]); + V.set(j, i, direction[ind]); + ind++; + } } - double[] z = Z.toArrayComponents(); //TODO: CLEAN THIS UP!!!!! - DenseMatrix64F A = DenseMatrix64F.wrap(dim, dim, z); + CommonOps.invert(C, Cinv); + CommonOps.mult(Cinv, V, CinvV); + org.ejml.interfaces.decomposition.EigenDecomposition factory = new DecompositionFactory().eig(dim, false, false); - if (!factory.decompose(A)) throw new RuntimeException("Eigen decomposition failed."); + + if (!factory.decompose(CinvV)) throw new RuntimeException("Eigen decomposition failed."); + double[] allValues = new double[dim]; int nReal = 0; for (int i = 0; i < dim; i++) { @@ -148,7 +163,7 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { if (DEBUG) { System.out.println("Raw matrix to decompose: "); - System.out.println(Z); + System.out.println(CinvV); System.out.print("Raw eigenvalues: "); Utils.printArray(values); } @@ -163,15 +178,15 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { @Override public IntersectionDistances distancesToBoundary(double[] origin, double[] direction, boolean isAtBoundary) throws HamiltonianMonteCarloOperator.NumericInstabilityException { - if (!isWithinBounds(origin)) { + if (CHECK_WITHIN_BOUNDS && !isWithinBounds(origin)) { // don't automatically check that it's inside the boundary if (isAtBoundary) { - //TODO: remove below SymmetricMatrix C = compoundCorrelationSymmetricMatrix(origin, dim); - double det = 0; + double det; try { det = C.determinant(); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); + throw new RuntimeException(); } if (Math.abs(det) > BOUNDARY_TOL) { @@ -185,11 +200,11 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc System.out.println(C.determinant()); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); + throw new RuntimeException(); } throw new HamiltonianMonteCarloOperator.NumericInstabilityException(); } - } diff --git a/src/test/dr/inference/model/BoundedSpaceTest.java b/src/test/dr/inference/model/BoundedSpaceTest.java new file mode 100644 index 0000000000..a1e0ed3cf4 --- /dev/null +++ b/src/test/dr/inference/model/BoundedSpaceTest.java @@ -0,0 +1,76 @@ +package test.dr.inference.model; + + +import dr.inference.model.BoundedSpace; +import dr.inference.operators.hmc.HamiltonianMonteCarloOperator; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + + +public class BoundedSpaceTest { + + private class Instance { + public final double[] position; + public final double[] velocity; + public final boolean isAtBoundary; + public final double actual; + + Instance(double[] position, double[] velocity, boolean isAtBoundary, double actual) { + this.position = position; + this.velocity = velocity; + this.isAtBoundary = isAtBoundary; + this.actual = actual; + } + } + + Instance[] instances = new Instance[]{ + new Instance( // dense velocity, not at boundary + new double[]{0.2822855517829705, 0.40605243926469287, 0.02849762638750524, -0.564844016766597, 0.10245003047216325, 0.0965515090688095, 0.36620919740224617, -0.01787268001498262, 0.6166518433752171, 0.33255922064276344, 0.3486639561922815, 0.20009002909341947, 0.03369031306626756, 0.3913488688811744, 0.6333813479719003, -0.6524997126334955, -0.23103935804969739, 0.5721530659041922, -0.09121186775518811, -0.33094967030159816, 0.10373856188477253, 0.2358517748200947, 0.28136454978838454, 0.38180511299157194, 0.03857632199153963, -0.10247319649770201, -0.4515249611583288, 0.04069735163371401, 0.2739134475778366, 0.20394511869828952, + 0.16987513814725316, 0.0498047833476856, -0.23155751470047312, -0.34084092179183206, -0.5435689844314044, -0.22997407806145242, -0.524907999369761, -0.26982552685916034, 0.013582732323656897, 0.6283381537641279, -0.5262437447150251, -0.24568348204054502, -0.48506559078208455, -0.006540034446863485, 0.487435821074331}, + new double[]{2.0438614777153097, -0.7395282076530152, -0.04671949232593737, -0.05568195314160399, + -1.7503124002123156, 1.438912952418718, -1.2979060424575182, -0.12967355482400317, -0.5877605903799071, 0.4663973722472238, 0.6501194485236997, -0.27427888048772114, 0.7773041621717932, -1.2529452491475837, 0.1264625176370857, -2.184869901041784, -2.3210809198970628, 0.4536691205497068, 0.0040154775744028525, -0.02022797360190504, -0.4122624930974949, -0.8144835256337633, 1.4565197109479235, 1.240849853686729, 0.985253115077047, 0.4962656279813027, -0.21246925243198708, -0.5021739010419594, 1.4565629697944975, -1.1261205391371558, 0.8147966861488231, -0.08483633566338881, -0.7918752798447827, 0.15361236989858948, 0.14789511915893988, -1.6506804194613995, 0.20917813332349894, -1.020547878554704, 0.6933439963746123, -0.8804450733870868, -1.0296658859958183, 0.2777824369980236, 0.429102310041008, -0.1696873347531642, 0.43901426447036324}, + false, + 0.01595563335319449), + new Instance( // dense velocity, at boundary + new double[]{0.11647159249911301, -0.08506425456560686, 0.06052011411111528, 0.6280642539160219, 0.1858869704724907, 0.2707344946140536, -0.20593108085359318, -0.23371769475449147, -0.3535532648113636, 0.2584924019161337, 0.49145781749224526, -0.08545927119819614, 0.0019223081229396148, -0.17576719543866787, 0.5669613122986198, -0.3811300793200186, 0.048178306235236744, -0.24174056602815053, 0.10075016419804625, -0.04283060683464314, 0.24322613111387503, 0.5120677150420093, -0.026830637068174238, 0.004388348867356102, 0.021323360340519286, 0.0675870251067074, -0.673678610283607, 0.20730194234977975, -0.25512501925923997, 0.6381849075861147, -0.07200433671028439, 0.1623547683774016, -0.27416192799573713, -0.10673348908034844, -0.12657024780739462, -0.2944885708858614, -0.22545327796218337, -0.3349602644633396, -0.12418255126124608, -0.0781009754221288, 0.00985739557424605, -0.7906189231671669, 0.042233590443951856, 0.3912512164379495, 0.27717395861287647}, + new double[]{-1.0346921156530475, -0.7834466818276314, -0.6036925610549285, -1.6488512182861115, -0.8023765324230373, -0.36518794172088387, -0.5445949672312552, 0.17745174291760565, 0.41790594874801706, 0.09060918246766107, 0.354260412636501, 1.9062401157318938, 1.140465086875608, -0.13439202235311207, 1.2385557416800101, -0.08380905464237058, 0.599994532149525, -0.6963958583478322, 0.23334316441818698, -1.052992644030231, -0.2593191511965254, 1.3702363112988352, -0.06139381721406518, -1.4380644122199022, -1.7585887877096682, 1.2412102713846938, 0.4758368659296449, -1.3717994142343595, 1.43323783535285, -2.091234335971886, 0.9082181607882182, 0.08267774926196575, -1.3011183073338028, 0.9517052987954928, -2.7855629538249462, -1.5071558646552672, 1.6031446100634232, -0.23005292084728482, 0.5750574858677685, -0.09568198613982926, 0.16991262670991295, -1.194864379072432, 0.24365560223571764, -0.343107465013622, -0.7663578593237961}, + true, + 0.01223608297387474), + new Instance( // sparse velocity, not at boundary + new double[]{0.0670476094576967, 0.4367961067430731, 0.48664883054844366, -0.19274573625080862, -0.2744063095062948, -0.15306145989259248, 0.30721493611428496, 0.16776289055002544, 0.08104118685651174, -0.6227814065271465, 0.0731979822708254, 0.7857019348678728, 0.2677449731175475, -0.5502088508980884, 0.03771846114331344, 0.3190168387554593, -0.22047212328810817, -0.06540149474172938, -0.687608567476592, -0.34528076191263624, 0.22370110435756238, 0.040256367672321235, -0.22740372706234033, 0.2366142305038848, -0.03725095935153343, 0.2510891496166299, -0.03355890231101856, -0.007285698073503861, 0.6427633993305558, -0.21383544251461112, -0.0016195571510596834, -0.46644562664579325, -0.11856508641792099, 0.29565621479439363, -0.30720691413587015, 0.15803683808011923, 0.2828987312846409, 0.4414784551297542, -0.5221583157958694, 0.40676508546537504, -0.4944084288272014, 0.24916361029681994, -0.2094230343411282, -0.0976994550220947, -0.6795147439292425}, + new double[]{0.0, 0.0, -1.0683641929041463, 1.4074769699960024, 2.4076197351280726, -0.7071752529563077, 0.0902219088734521, -1.1701540514148658, -0.8850141813046155, 0.0, 0.43200829972893906, -0.7761314710437377, 0.7440087608644848, -0.024508316794491402, -1.423953745919625, -0.3052756083657273, 1.1379728532070383, 1.812377176684857, -1.5317894306075623, -0.009775591333715548, 1.1282507432924553, 0.7115553368277749, 1.5015660134490751, -0.5569539859571256, 0.0, 0.0, 0.0, 0.0, 0.0, -1.4341355880302382, 0.0, 0.0, 0.0, 0.0, 0.6427033706412396, + 0.0, 0.0, 0.0, -0.5146126961480285, 0.0, 0.0, -0.6411849536801717, 0.0, 0.38494833959549835, -0.4133283160935234}, + false, + 0.024629529545630423), + new Instance( // sparse velocity, at boundary + new double[]{-0.26757589087078193, -0.02938569555892525, 0.18513857054513194, 0.45426488679410915, -0.01763137023312626, -0.12490569292229133, -0.02248584573483657, 0.3105550081285866, 0.22388294813341902, -0.04819405233443179, 0.31075193194590905, 0.4883275878373804, 0.48203706437713173, -0.10747538571923325, 0.08967928867045576, -0.3804060916333354, -0.22723677965430755, -0.05680381914344709, -0.2714994265633495, + 0.6064673183295967, 0.11161324462019362, 0.02432671422275462, 0.3439120417020678, -0.028657224962422713, 0.2664948230582488, 0.40488926421869764, -0.27352248738407703, 0.6467077532653467, 0.40021128203106154, 0.5023778368436919, 0.318781718004652, -0.20214876545585567, -0.32225743555541214, -0.19859068384300757, -0.18995352580103042, 0.11905713166260283, 0.0011462916709671846, 0.09599237129692728, -0.1260666035857841, -0.14407227967969546, 0.09281799565543995, -0.21075798426389972, 0.16880530157106863, 0.38074163001542666, 0.7018785056040518}, + new double[]{-0.0, -0.0, -0.8822912283623365, -0.07850848303667071, 1.7168396136971413, 0.33187232269944494, -0.12134146134082555, -0.35206153660614065, 0.884644612485139, -0.0, -1.9123537150968595, -0.5569287897049415, 0.354481275240167, -1.1684131108537485, 0.4189200083984581, 0.5675955066325348, 1.446092480296222, -0.15266121755447873, -0.32814205291827453, -0.3222217609135461, -0.7644477043670591, -1.9493869595983575, -0.616075655009724, -1.374788715704303, -0.0, -0.0, -0.0, -0.0, -0.0, -1.1129531496787755, -0.0, -0.0, -0.0, -0.0, 0.6765725834261671, -0.0, -0.0, -0.0, 1.2525347271504157, -0.0, -0.0, -0.7057402390019548, -0.0, -0.40903841628897586, -1.2628021481812326}, + true, + 0.014146061076949323) + }; + + + @Test + public void CorrelationTest() { + int ind = 0; + for (Instance instance : instances) { + ind++; + int dim = (1 + (int) Math.round(Math.sqrt(1 + 8 * instance.position.length))) / 2; + BoundedSpace.Correlation correlationBound = new BoundedSpace.Correlation(dim); + + double t; + try { + t = correlationBound.forwardDistanceToBoundary(instance.position, instance.velocity, instance.isAtBoundary); + } catch (HamiltonianMonteCarloOperator.NumericInstabilityException e) { + e.printStackTrace(); + throw new RuntimeException(); + } + + assertEquals("correlation matrix " + ind, instance.actual, t, 1e-10); + + } + + } +} From fec8171f8650b7be2903fefa504b3598f2e0a5e3 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 17:26:34 -0700 Subject: [PATCH 101/196] nit-picking typo in test name --- src/test/dr/inference/model/CompoundEigenMatrixTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/dr/inference/model/CompoundEigenMatrixTest.java b/src/test/dr/inference/model/CompoundEigenMatrixTest.java index 1d0b456fcd..d19cf51af1 100644 --- a/src/test/dr/inference/model/CompoundEigenMatrixTest.java +++ b/src/test/dr/inference/model/CompoundEigenMatrixTest.java @@ -154,7 +154,7 @@ public double[] getSelectionStrength() { Instance[] all = {test0, test1, test2, test3, test4}; @Test - public void CompundEigenMatrix() throws Exception { + public void CompoundEigenMatrix() throws Exception { for (Instance test : all) { Parameter alphaEig = test.getEigenValuesStrengthOfSelection(); From 22d5b62cce9cd2b24f74feb504553fc754b4bf82 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 3 Aug 2022 17:27:10 -0700 Subject: [PATCH 102/196] taking credit (blame?) --- src/test/dr/inference/model/BoundedSpaceTest.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/test/dr/inference/model/BoundedSpaceTest.java b/src/test/dr/inference/model/BoundedSpaceTest.java index a1e0ed3cf4..57226cf0fd 100644 --- a/src/test/dr/inference/model/BoundedSpaceTest.java +++ b/src/test/dr/inference/model/BoundedSpaceTest.java @@ -7,6 +7,10 @@ import static org.junit.Assert.assertEquals; +/** + * @author Gabriel Hassler + */ + public class BoundedSpaceTest { From 3cff4257c0e6df5321ad511bbb52b73df83c4201 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 4 Aug 2022 13:31:30 -0700 Subject: [PATCH 103/196] fixing precisionType bug --- .../RepeatedMeasuresTraitDataModel.java | 21 ++++++++++++------- .../ContinuousTraitDataModelParser.java | 15 +++++++++---- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index f7ab41d8b7..af693c7a93 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -289,12 +289,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(TreeModel.class); final ContinuousTraitPartialsProvider subModel; - if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) { - subModel = ContinuousTraitDataModelParser.parseContinuousTraitDataModel(xo); - } else { - subModel = (ContinuousTraitPartialsProvider) xo.getChild(ContinuousTraitPartialsProvider.class); - } - XMLObject cxo = xo.getChild(PRECISION); MatrixParameterInterface samplingPrecision = (MatrixParameterInterface) @@ -312,8 +306,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } - String modelName = subModel.getModelName(); - boolean scaleByTipHeight = xo.getAttribute(SCALE_BY_TIP_HEIGHT, false); int dimTrait = samplingPrecision.getColumnDimension(); @@ -325,6 +317,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { precisionType = PrecisionType.SCALAR; } + if (xo.hasChildNamed(TreeTraitParserUtilities.TRAIT_PARAMETER)) { + subModel = ContinuousTraitDataModelParser.parseContinuousTraitDataModel(xo, precisionType); + } else { + subModel = (ContinuousTraitPartialsProvider) xo.getChild(ContinuousTraitPartialsProvider.class); + if (subModel.getPrecisionType() != precisionType) { + throw new XMLParseException("Precision type of " + REPEATED_MEASURES_MODEL + " is " + + precisionType.getClass() + ", but the precision type of the child model " + + subModel.getModelName() + " is " + subModel.getPrecisionType().getClass()); + } + } + String modelName = subModel.getModelName(); + + if (!scaleByTipHeight) { return new RepeatedMeasuresTraitDataModel( modelName, diff --git a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java index 91ca462c77..76f8e56ead 100644 --- a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java +++ b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java @@ -28,6 +28,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } public static ContinuousTraitDataModel parseContinuousTraitDataModel(XMLObject xo) throws XMLParseException { + return parseContinuousTraitDataModel(xo, null); + } + + public static ContinuousTraitDataModel parseContinuousTraitDataModel(XMLObject xo, PrecisionType precisionType) throws XMLParseException { Tree treeModel = (Tree) xo.getChild(Tree.class); boolean[] missingIndicators; final String traitName; @@ -52,13 +56,16 @@ public static ContinuousTraitDataModel parseContinuousTraitDataModel(XMLObject x traitName = returnValue.traitName; useMissingIndices = returnValue.useMissingIndices; - PrecisionType precisionType = PrecisionType.SCALAR; + if (precisionType == null) { + precisionType = PrecisionType.SCALAR; - if (xo.getAttribute(FORCE_FULL_PRECISION, false) || - (useMissingIndices && !xo.getAttribute(FORCE_COMPLETELY_MISSING, false))) { - precisionType = PrecisionType.FULL; + if (xo.getAttribute(FORCE_FULL_PRECISION, false) || + (useMissingIndices && !xo.getAttribute(FORCE_COMPLETELY_MISSING, false))) { + precisionType = PrecisionType.FULL; + } } + if (xo.hasChildNamed(TreeTraitParserUtilities.JITTER)) { utilities.jitter(xo, dim, missingIndicators); } From d0803103ef66b09325f408720a92dce1b6459740 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 5 Aug 2022 11:31:58 -0700 Subject: [PATCH 104/196] starting test xml for composable continuous models --- ci/TestXML/testComposableContinuousModel.xml | 345 +++++++++++++++++++ 1 file changed, 345 insertions(+) create mode 100644 ci/TestXML/testComposableContinuousModel.xml diff --git a/ci/TestXML/testComposableContinuousModel.xml b/ci/TestXML/testComposableContinuousModel.xml new file mode 100644 index 0000000000..c6b747b795 --- /dev/null +++ b/ci/TestXML/testComposableContinuousModel.xml @@ -0,0 +1,345 @@ + + + + + 0.6161109857603313 -3.2602569254736204 + 0.42993066951497727 -2.9486654536360777 -4.057033299418535 + 3.4678011726657467 0.8863191402500052 -2.6600383868277264 -2.2183318084380526 + 1.8189980404545938 -2.29133552807135 + + + + 0.14235084027566275 -3.91591655904778 + -0.8637975323369211 -4.259385675424384 -5.678966719263365 + 3.173375670176406 1.2813599910134097 -0.6881012738308594 -3.0867336997208668 + 2.815737176293176 -4.915486206498403 + + + + 2.3449041383664317 0.9722340059227484 + 9.256880465255662 -2.487396604751713 -2.0680874725706184 + 1.283643704876766 2.8513096233737603 0.9708425448487273 -2.441471802594642 + 1.173180678839264 -1.7603750280019597 + + + + 1.813714920924903 0.39674460647288123 + 2.8127444097741146 -1.7841457523163833 0.5750599739693499 + 2.4926684326346047 3.141533052089204 2.1694254223223037 -1.3919262439366116 + -0.28897700463936593 -6.217061635911254 + + + + 1.2805936511229543 1.3638556864102687 + 9.162902712021545 -1.827052256870099 -3.114941272898032 + 0.310485244451206 2.019670079835947 -0.2328560668224443 -1.1680301897828358 + -0.169794880275567 -4.51512265786732 + + + + 0.8142903555974006 -4.817316126456728 + -4.901391200213806 -2.751126928637177 -2.4254554332105376 + 1.345687457452557 -1.213562335693784 -0.8837634132548726 -1.11674616410668 + 2.3567759275445193 -0.269174568289571 + + + + 0.5537342826706193 -1.9601398886629764 + -1.9003929413559715 -1.180721888259613 -0.4348486999592236 + 2.7126597161955437 -0.6289199169529813 -0.09557967938854506 0.20869076878417675 + -0.9389908132916613 -1.1337383511630141 + + + + 1.9283974425711239 -0.5873562093054292 + 3.588560410152809 -2.19448080201238 -2.7158658300416465 + -1.2924283665094718 1.6943907950946782 2.19407882497502 -1.42228247112324 + -0.4484624547540542 1.1444926622987146 + + + + 1.7199674427711251 -4.026480323777455 + 1.9442524369925174 -3.917564256498128 -4.961084588733641 + 3.9004288986948668 2.7387837406049877 -4.615424402780565 -4.971638687502833 + 5.028472260449842 -7.691835666782411 + + + + 0.40940700074267505 -2.2495354378495644 + -0.16784519819182986 -1.1337506486494466 -2.863705683591497 + -3.112332086507164 -2.4616107952948543 0.04796951247778125 2.4409311713200568 + -1.252912212199225 0.6769242891580838 + + + + + (((taxon10:0.0104828508210571,taxon7:0.06945686994340126):0.05472171377912479,(taxon6:0.09666701733874086,(taxon1:0.025715465624764816,(taxon9:0.14154200426014169,taxon2:0.013832276521980338):0.08913272474958198):0.1193630006138079):0.453503452970523):0.1964588174059454,(((taxon3:0.004017979531503662,taxon5:0.0496801467418641):0.01371646355496644,taxon4:0.07279286149575269):0.15588410980923004,taxon8:0.006692284455141036):0.0833467339290923); + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of unextended model + + + + + + -36.27071651148219 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of residual variance model + + + + + + -69.89723893416877 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of latent factor model + + + + + + -122.30489986705037 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of joint model + + + + + + -219.13465006512465 + + + + From 841716d4797079b8ee2b9a9c6c07b1598e4dc6aa Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 5 Aug 2022 15:38:02 -0700 Subject: [PATCH 105/196] fixing numTrait bug --- ci/TestXML/testComposableContinuousModel.xml | 52 +++++++++++++++++++ .../continuous/ContinuousTraitDataModel.java | 14 ++++- .../RepeatedMeasuresTraitDataModel.java | 11 +++- ...eScaledRepeatedMeasuresTraitDataModel.java | 4 +- .../continuous/RepeatedMeasureFactorTest.java | 2 + .../hmc/DiffusionGradientTest.java | 25 ++++----- 6 files changed, 93 insertions(+), 15 deletions(-) diff --git a/ci/TestXML/testComposableContinuousModel.xml b/ci/TestXML/testComposableContinuousModel.xml index c6b747b795..0c45cd1e3a 100644 --- a/ci/TestXML/testComposableContinuousModel.xml +++ b/ci/TestXML/testComposableContinuousModel.xml @@ -342,4 +342,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of joint model + + + + + + -219.43720322982512 + + diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java index 451686bced..0857e21d3d 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java @@ -57,6 +57,18 @@ public ContinuousTraitDataModel(String name, boolean[] missingIndicators, boolean useMissingIndices, final int dimTrait, PrecisionType precisionType) { + + this(name, parameter, missingIndicators, useMissingIndices, dimTrait, + parameter.getParameter(0).getDimension() / dimTrait, + precisionType); + } + + public ContinuousTraitDataModel(String name, + CompoundParameter parameter, + boolean[] missingIndicators, + boolean useMissingIndices, + final int dimTrait, final int numTraits, + PrecisionType precisionType) { super(name); this.parameter = parameter; addVariable(parameter); @@ -66,7 +78,7 @@ public ContinuousTraitDataModel(String name, this.missingIndicators = (useMissingIndices ? missingIndicators : new boolean[missingIndicators.length]); this.dimTrait = dimTrait; - this.numTraits = getParameter().getParameter(0).getDimension() / dimTrait; + this.numTraits = numTraits; this.precisionType = precisionType; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index af693c7a93..71f0b47b78 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -81,10 +81,11 @@ public RepeatedMeasuresTraitDataModel(String name, boolean[] missindIndicators, boolean useMissingIndices, final int dimTrait, + final int numTraits, MatrixParameterInterface samplingPrecision, PrecisionType precisionType) { - super(name, parameter, missindIndicators, useMissingIndices, dimTrait, precisionType); + super(name, parameter, missindIndicators, useMissingIndices, dimTrait, numTraits, precisionType); this.childModel = childModel; this.traitName = name; @@ -329,6 +330,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } String modelName = subModel.getModelName(); + if (subModel.getTraitDimension() != dimTrait) { + throw new XMLParseException("sub-model has trait dimension " + subModel.getTraitDimension() + + ", but sampling precision has dimension " + dimTrait); + } + + int numTraits = subModel.getTraitCount(); if (!scaleByTipHeight) { return new RepeatedMeasuresTraitDataModel( @@ -339,6 +346,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { // missingIndicators, true, dimTrait, + numTraits, // diffusionModel.getPrecisionParameter().getRowDimension(), samplingPrecision, precisionType @@ -351,6 +359,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { subModel.getDataMissingIndicators(), true, dimTrait, + subModel.getTraitCount(), samplingPrecision, precisionType ); diff --git a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java index c24ad64626..7ac427d328 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java @@ -50,9 +50,11 @@ public TreeScaledRepeatedMeasuresTraitDataModel(String name, boolean[] missingIndicators, boolean useMissingIndices, final int dimTrait, + final int numTraits, MatrixParameterInterface samplingPrecision, PrecisionType precisionType) { - super(name, childModel, parameter, missingIndicators, useMissingIndices, dimTrait, samplingPrecision, precisionType); + super(name, childModel, parameter, missingIndicators, useMissingIndices, dimTrait, numTraits, + samplingPrecision, precisionType); } @Override diff --git a/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java b/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java index 0e9db57f05..e9ce234bdd 100644 --- a/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasureFactorTest.java @@ -157,6 +157,7 @@ public void setUp() throws Exception { // new boolean[3], true, dimTrait, + 1, samplingPrecisionParameter, PrecisionType.FULL); @@ -166,6 +167,7 @@ public void setUp() throws Exception { missingIndicators, true, dimTrait, + 1, samplingPrecisionParameterFull, PrecisionType.FULL); diff --git a/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java b/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java index 5b5870cc36..e1fb63da59 100644 --- a/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/hmc/DiffusionGradientTest.java @@ -58,7 +58,7 @@ public class DiffusionGradientTest extends ContinuousTraitTest { private CompoundSymmetricMatrix precisionMatrix; private CachedMatrixInverse precisionMatrixInv; -// protected List missingIndices = new ArrayList(); + // protected List missingIndices = new ArrayList(); protected boolean[] missingIndicators; private MultivariateDiffusionModel diffusionModelVar; @@ -131,7 +131,6 @@ public void setUp() throws Exception { missingIndicators[29] = true; - // Tree createAlignment(PRIMATES_TAXON_SEQUENCE, Nucleotides.INSTANCE); treeModel = createPrimateTreeModel(); @@ -208,6 +207,7 @@ public void setUp() throws Exception { missingIndicators, true, dimTrait, + 1, samplingPrecision, PrecisionType.FULL); @@ -217,6 +217,7 @@ public void setUp() throws Exception { missingIndicators, true, dimTrait, + 1, samplingPrecisionInv, PrecisionType.FULL); @@ -248,8 +249,8 @@ public void testGradientBMWithMissing() { // Repeated Measures Model System.out.println("\nTest gradient precision repeated measures."); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null,null, samplingPrecision); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null,null, samplingPrecisionInv); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null, null, samplingPrecision); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null, null, samplingPrecisionInv); } public void testGradientDriftWithMissing() { @@ -292,8 +293,8 @@ public void testGradientDriftWithMissing() { // Repeated Measures Model System.out.println("\nTest gradient precision repeated measures."); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null,null, samplingPrecision); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null,null, samplingPrecisionInv); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null, null, samplingPrecision); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null, null, samplingPrecisionInv); } public void testGradientSingleDriftWithMissing() { @@ -355,8 +356,8 @@ public void testGradientSingleDriftSameMeanWithMissing() { // Repeated Measures Model System.out.println("\nTest gradient precision repeated measures."); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null,meanRoot, samplingPrecision); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null,meanRoot, samplingPrecisionInv); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null, meanRoot, samplingPrecision); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null, meanRoot, samplingPrecisionInv); } public void testGradientOUWithMissing() { @@ -418,8 +419,8 @@ public void testGradientOUWithMissing() { // Repeated Measures Model System.out.println("\nTest gradient precision repeated measures."); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null,null, samplingPrecision); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null,null, samplingPrecisionInv); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, null, null, samplingPrecision); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, null, null, samplingPrecisionInv); //************// // Single opt @@ -537,8 +538,8 @@ public void testGradientDiagonalOUWithMissing() { // Repeated Measures Model System.out.println("\nTest gradient precision repeated measures."); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, strengthOfSelectionMatrixParam,null, samplingPrecision); - testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, strengthOfSelectionMatrixParam,null, samplingPrecisionInv); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasures, rootPrior, meanRoot, precisionMatrix, false, strengthOfSelectionMatrixParam, null, samplingPrecision); + testGradient(diffusionModel, diffusionProcessDelegate, dataModelRepeatedMeasuresInv, rootPrior, meanRoot, precisionMatrix, false, strengthOfSelectionMatrixParam, null, samplingPrecisionInv); } private void testGradient(MultivariateDiffusionModel diffusionModel, From 5594d5f193c898e885c070243d43a1ac5b3be9b2 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 19 Aug 2022 16:29:44 -0700 Subject: [PATCH 106/196] removing code duplication --- .../ContinuousDataLikelihoodParser.java | 30 ++++++------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index 8a096112e4..47c434619f 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -165,31 +165,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (reconstructTraits) { // if (missingIndices != null && missingIndices.size() == 0) { - if (!dataModel.usesMissingIndices()) { - ProcessSimulationDelegate simulationDelegate = - delegate.getPrecisionType() == PrecisionType.SCALAR ? - new ConditionalOnTipsRealizedDelegate(traitName, treeModel, - diffusionModel, dataModel, rootPrior, rateTransformation, delegate) : - new MultivariateConditionalOnTipsRealizedDelegate(traitName, treeModel, - diffusionModel, dataModel, rootPrior, rateTransformation, delegate); + ProcessSimulationDelegate simulationDelegate = + delegate.getPrecisionType() == PrecisionType.SCALAR ? + new ConditionalOnTipsRealizedDelegate(traitName, treeModel, + diffusionModel, dataModel, rootPrior, rateTransformation, delegate) : + new MultivariateConditionalOnTipsRealizedDelegate(traitName, treeModel, + diffusionModel, dataModel, rootPrior, rateTransformation, delegate); - TreeTraitProvider traitProvider = new ProcessSimulation(treeDataLikelihood, simulationDelegate); + TreeTraitProvider traitProvider = new ProcessSimulation(treeDataLikelihood, simulationDelegate); - treeDataLikelihood.addTraits(traitProvider.getTreeTraits()); + treeDataLikelihood.addTraits(traitProvider.getTreeTraits()); - } else { - - ProcessSimulationDelegate simulationDelegate = - delegate.getPrecisionType() == PrecisionType.SCALAR ? - new ConditionalOnTipsRealizedDelegate(traitName, treeModel, - diffusionModel, dataModel, rootPrior, rateTransformation, delegate) : - new MultivariateConditionalOnTipsRealizedDelegate(traitName, treeModel, - diffusionModel, dataModel, rootPrior, rateTransformation, delegate); - - TreeTraitProvider traitProvider = new ProcessSimulation(treeDataLikelihood, simulationDelegate); - - treeDataLikelihood.addTraits(traitProvider.getTreeTraits()); + if (dataModel.usesMissingIndices()) { ProcessSimulationDelegate fullConditionalDelegate = new TipRealizedValuesViaFullConditionalDelegate( traitName, treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, delegate); From 9ce5367a3494169fb50b2a3d06fd79566331accf Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 19 Aug 2022 16:31:40 -0700 Subject: [PATCH 107/196] this shouldn't be inside the 'if' statement --- .../ContinuousDataLikelihoodParser.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index 47c434619f..e1445c8168 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -183,13 +183,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { traitName, treeModel, diffusionModel, dataModel, rootPrior, rateTransformation, delegate); treeDataLikelihood.addTraits(new ProcessSimulation(treeDataLikelihood, fullConditionalDelegate).getTreeTraits()); - int[] partitionDimensions = dataModel.getPartitionDimensions(); - if (partitionDimensions != null) { - PartitionedTreeTraitProvider partitionedProvider = - new PartitionedTreeTraitProvider(treeDataLikelihood.getTreeTraits(), partitionDimensions); - treeDataLikelihood.addTraits(partitionedProvider.getTreeTraits()); - } - // String partialTraitName = getPartiallyMissingTraitName(traitName); // @@ -201,6 +194,14 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { // // treeDataLikelihood.addTraits(partialTraitProvider.getTreeTraits()); } + + int[] partitionDimensions = dataModel.getPartitionDimensions(); + if (partitionDimensions != null) { + PartitionedTreeTraitProvider partitionedProvider = + new PartitionedTreeTraitProvider(treeDataLikelihood.getTreeTraits(), partitionDimensions); + treeDataLikelihood.addTraits(partitionedProvider.getTreeTraits()); + } + } return treeDataLikelihood; From f1be42293ae1874539afb2e1bbdcf47d3ccbc582 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 23 Aug 2022 11:53:40 -0700 Subject: [PATCH 108/196] new framework that will (hopefully) let operators sample from the traits above a model even if not at the tips of a tree without having to use any special xml elements --- .../ConditionalTraitSimulationHelper.java | 96 +++++++++++++++++++ .../continuous/ContinuousTraitDataModel.java | 5 + .../ContinuousTraitPartialsProvider.java | 8 +- .../continuous/ElementaryVectorDataModel.java | 5 + .../continuous/EmptyTraitDataModel.java | 5 + .../IntegratedFactorAnalysisLikelihood.java | 5 + .../continuous/JointPartialsProvider.java | 79 ++++++++++++--- .../RepeatedMeasuresTraitDataModel.java | 59 +++++++++++- .../ContinuousDataLikelihoodParser.java | 3 +- .../FactorAnalysisOperatorAdaptor.java | 18 ++-- 10 files changed, 255 insertions(+), 28 deletions(-) create mode 100644 src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java new file mode 100644 index 0000000000..9cd0bddec0 --- /dev/null +++ b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java @@ -0,0 +1,96 @@ +package dr.evomodel.treedatalikelihood.continuous; + +import dr.evolution.tree.TreeTrait; +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; + +import java.util.HashMap; + +/** + * @author Gabriel Hassler + * @author Marc Suchard + */ + + +public class ConditionalTraitSimulationHelper { + + private final TreeDataLikelihood treeLikelihood; + private final TreeTrait treeTrait; + private final ContinuousTraitPartialsProvider topDataModel; + private final HashMap parentMap; + + + public ConditionalTraitSimulationHelper(TreeDataLikelihood treeLikelihood) { + this.treeLikelihood = treeLikelihood; + ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate) treeLikelihood.getDataLikelihoodDelegate(); + this.topDataModel = delegate.getDataModel(); + + this.parentMap = new HashMap<>(); + makeParentMap(topDataModel, parentMap); + + this.treeTrait = treeLikelihood.getTreeTrait(topDataModel.getTipTraitName()); + } + + public TreeTrait getTreeTrait() { + return treeTrait; + } + + private class ParentMapHelper { + public final int traitOffset; + public final int traitDimension; + public final ContinuousTraitPartialsProvider parent; + + private ParentMapHelper(ContinuousTraitPartialsProvider parent, int traitOffset, int traitDimension) { + this.traitOffset = traitOffset; + this.traitDimension = traitDimension; + this.parent = parent; + } + } + + private void makeParentMap(ContinuousTraitPartialsProvider model, + HashMap map) { + + int offset = 0; + for (ContinuousTraitPartialsProvider child : model.getChildModels()) { + + map.put(child, new ParentMapHelper(model, offset, child.getTraitDimension())); + makeParentMap(child, map); + offset += child.getTraitDimension(); + } + } + + public double[] drawTraitsAbove(ContinuousTraitPartialsProvider model) { + int dimTrait = model.getTraitDimension(); + + if (model == topDataModel) { + return (double[]) treeTrait.getTrait(treeLikelihood.getTree(), null); + } + + ParentMapHelper helper = parentMap.get(model); + + double[] fullTraitsAbove = drawTraitsBelow(helper.parent); + if (helper.traitOffset == 0 && helper.traitDimension == helper.parent.getDataDimension()) { + return fullTraitsAbove; + } + + int nTaxa = treeLikelihood.getTree().getExternalNodeCount(); + double[] traitsAbove = new double[nTaxa * dimTrait]; + + int fullOffset = helper.traitOffset; + int thisOffset = 0; + int dimAbove = helper.parent.getDataDimension(); + for (int i = 0; i < nTaxa; i++) { + System.arraycopy(fullTraitsAbove, fullOffset, traitsAbove, thisOffset, helper.traitDimension); + fullOffset += dimAbove; + thisOffset += dimTrait; + } + + return traitsAbove; + } + + public double[] drawTraitsBelow(ContinuousTraitPartialsProvider model) { + double[] aboveTraits = drawTraitsAbove(model); + return model.drawTraitsBelowConditionalOnDataAndTraitsAbove(aboveTraits); + } + + +} diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java index 0857e21d3d..3a86655f7f 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java @@ -127,6 +127,11 @@ public boolean usesMissingIndices() { return useMissingIndices; } + @Override + public ContinuousTraitPartialsProvider[] getChildModels() { + return new ContinuousTraitPartialsProvider[0]; + } + @Override public List getMissingIndices() { return ContinuousTraitPartialsProvider.indicatorToIndices(missingIndicators); // TODO: finish deprecating diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java index 6b7205e61a..8ff63ab2bd 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java @@ -72,6 +72,12 @@ default boolean[] getTraitMissingIndicators() { // returns null for no missing t boolean usesMissingIndices(); + ContinuousTraitPartialsProvider[] getChildModels(); + + default double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTraits) { + throw new RuntimeException("Conditional sampling not yet implemented for " + this.getClass()); + } + default boolean getDefaultAllowSingular() { return false; } @@ -81,7 +87,7 @@ default boolean suppliesWishartStatistics() { } default int[] getPartitionDimensions() { - return null; + return new int[]{getTraitDimension()}; } default void addTreeAndRateModel(Tree treeModel, ContinuousRateTransformation rateTransformation) { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java index 503b72bb91..0d344a56ba 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ElementaryVectorDataModel.java @@ -107,6 +107,11 @@ public boolean usesMissingIndices() { return false; } + @Override + public ContinuousTraitPartialsProvider[] getChildModels() { + return new ContinuousTraitPartialsProvider[0]; + } + public void setTipTraitDimParameters(int tip, int trait, int dim) { tipIndicator.setParameterValue(trait, tip); diff --git a/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java index c75dedd4b3..6edcdfbe2d 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/EmptyTraitDataModel.java @@ -98,6 +98,11 @@ public boolean usesMissingIndices() { return false; } + @Override + public ContinuousTraitPartialsProvider[] getChildModels() { + return new ContinuousTraitPartialsProvider[0]; + } + @Override public List getMissingIndices() { return null; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java index a782fb6b84..e1385bc8c4 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java @@ -232,6 +232,11 @@ public boolean usesMissingIndices() { return true; } + @Override + public ContinuousTraitPartialsProvider[] getChildModels() { + return new ContinuousTraitPartialsProvider[0]; // LFM is not currently extendible + } + @Override public boolean getDefaultAllowSingular() { return true; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index b7e76e8faa..3ffc65742f 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -6,7 +6,6 @@ import dr.inference.model.*; import dr.math.matrixAlgebra.WrappedMatrix; import dr.math.matrixAlgebra.WrappedVector; -import dr.math.matrixAlgebra.missingData.MissingOps; import dr.xml.*; import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; @@ -28,7 +27,8 @@ public class JointPartialsProvider extends AbstractModel implements ContinuousTr private final int dataDim; private final List missingIndices; - private final boolean[] missingIndicators; + private final boolean[] missingDataIndicators; + private final boolean[] missingTraitIndicators; private final boolean defaultAllowSingular; private final Boolean computeDeterminant; // TODO: Maybe pass as argument? @@ -56,8 +56,26 @@ public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] prov this.traitDim = traitDim; this.dataDim = dataDim; - this.missingIndicators = setupMissingIndicators(); - this.missingIndices = ContinuousTraitPartialsProvider.indicatorToIndices(missingIndicators); + + boolean[][] subTraitMissingInds = new boolean[providers.length][0]; + boolean[][] subDataMissingInds = new boolean[providers.length][0]; + int[] traitDims = new int[providers.length]; + int[] dataDims = new int[providers.length]; + + for (int i = 0; i < providers.length; i++) { + subTraitMissingInds[i] = providers[i].getTraitMissingIndicators(); + subDataMissingInds[i] = providers[i].getDataMissingIndicators(); + traitDims[i] = providers[i].getDataDimension(); + traitDims[i] = providers[i].getTraitDimension(); + } + + int nTaxa = providers[0].getParameter().getParameterCount(); + + + this.missingDataIndicators = mergeIndicators(subDataMissingInds, dataDims, nTaxa, dataDim); + this.missingTraitIndicators = mergeIndicators(subTraitMissingInds, traitDims, nTaxa, traitDim); + + this.missingIndices = ContinuousTraitPartialsProvider.indicatorToIndices(missingDataIndicators); this.defaultAllowSingular = setDefaultAllowSingular(); this.computeDeterminant = defaultAllowSingular; // TODO: not perfect behavior, should be based on actual value of `allowSingular` @@ -80,18 +98,36 @@ public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] prov } - private boolean[] setupMissingIndicators() { - int nTaxa = providers[0].getParameter().getParameterCount(); - boolean[] indicators = new boolean[dataDim * nTaxa]; - boolean[][] subIndicators = new boolean[providers.length][0]; - for (int i = 0; i < providers.length; i++) { - subIndicators[i] = providers[i].getDataMissingIndicators(); - } +// private boolean[] setupMissingIndicators() { +// int nTaxa = providers[0].getParameter().getParameterCount(); +// boolean[] indicators = new boolean[dataDim * nTaxa]; +// boolean[][] subIndicators = new boolean[providers.length][0]; +// for (int i = 0; i < providers.length; i++) { +// subIndicators[i] = providers[i].getDataMissingIndicators(); +// } +// for (int taxonI = 0; taxonI < nTaxa; taxonI++) { +// int offset = taxonI * dataDim; +// +// for (int providerI = 0; providerI < providers.length; providerI++) { +// int srcDim = providers[providerI].getDataDimension(); +// int srcOffset = taxonI * srcDim; +// System.arraycopy(subIndicators[providerI], srcOffset, indicators, offset, srcDim); +// offset += srcDim; +// } +// } +// +// return indicators; +// } + + private boolean[] mergeIndicators(boolean[][] subIndicators, int[] dims, int nTaxa, int dim) { + + boolean[] indicators = new boolean[dim * nTaxa]; + for (int taxonI = 0; taxonI < nTaxa; taxonI++) { - int offset = taxonI * dataDim; + int offset = taxonI * dim; for (int providerI = 0; providerI < providers.length; providerI++) { - int srcDim = providers[providerI].getDataDimension(); + int srcDim = dims[providerI]; int srcOffset = taxonI * srcDim; System.arraycopy(subIndicators[providerI], srcOffset, indicators, offset, srcDim); offset += srcDim; @@ -101,6 +137,11 @@ private boolean[] setupMissingIndicators() { return indicators; } + @Override + public boolean[] getTraitMissingIndicators() { + return missingTraitIndicators; + } + @Override public boolean bufferTips() { @@ -144,6 +185,11 @@ public int[] getPartitionDimensions() { return dims; } + @Override + public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTraits) { + return aboveTraits; + } + @Override public PrecisionType getPrecisionType() { return precisionType; @@ -218,7 +264,7 @@ public List getMissingIndices() { @Override public boolean[] getDataMissingIndicators() { - return missingIndicators; + return missingDataIndicators; } @Override @@ -235,6 +281,11 @@ public boolean usesMissingIndices() { return useMissingIndices; } + @Override + public ContinuousTraitPartialsProvider[] getChildModels() { + return providers; + } + @Override protected void handleModelChangedEvent(Model model, Object object, int index) { fireModelChanged(); // sub-providers should handle everything else diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 71f0b47b78..b195bc0443 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -39,10 +39,7 @@ import dr.inference.model.MatrixParameterInterface; import dr.inference.model.Parameter; import dr.inference.model.Variable; -import dr.math.matrixAlgebra.CholeskyDecomposition; -import dr.math.matrixAlgebra.IllegalDimension; -import dr.math.matrixAlgebra.Matrix; -import dr.math.matrixAlgebra.WrappedVector; +import dr.math.matrixAlgebra.*; import dr.math.matrixAlgebra.missingData.MissingOps; import dr.xml.*; import org.ejml.data.DenseMatrix64F; @@ -275,6 +272,60 @@ public void chainRuleWrtVariance(double[] gradient, NodeRef node) { // Do nothing } + @Override + public ContinuousTraitPartialsProvider[] getChildModels() { + return new ContinuousTraitPartialsProvider[]{childModel}; + } + + @Override + public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTraits) { + if (numTraits > 1) { + throw new RuntimeException("not yet implemented"); + } + + double[] belowTraits = new double[aboveTraits.length]; + int nTaxa = belowTraits.length / dimTrait; + + DenseMatrix64F P = DenseMatrix64F.wrap(dimTrait, dimTrait, samplingPrecisionParameter.getParameterValues()); + DenseMatrix64F Q = new DenseMatrix64F(dimTrait, dimTrait); + DenseMatrix64F V = new DenseMatrix64F(dimTrait, dimTrait); + + double[] p0 = new double[dimTrait * dimTrait]; + DenseMatrix64F P0 = DenseMatrix64F.wrap(dimTrait, dimTrait, p0); + + int[] wrappedIndices = new int[dimTrait]; + for (int i = 0; i < dimTrait; i++) { + wrappedIndices[i] = i; + } + + WrappedVector n = new WrappedVector.Raw(new double[dimTrait]); + + int offset = 0; + for (int i = 0; i < nTaxa; i++) { + double[] partial = childModel.getTipPartial(i, false); + System.arraycopy(partial, precisionType.getPrecisionOffset(dimTrait), p0, 0, + precisionType.getPrecisionLength(dimTrait)); + + WrappedVector.Indexed m0 = new WrappedVector.Indexed(partial, precisionType.getMeanOffset(dimTrait), wrappedIndices, dimTrait); + WrappedVector.Indexed x = new WrappedVector.Indexed(aboveTraits, offset, wrappedIndices, dimTrait); + + + CommonOps.add(P0, P, Q); + MissingOps.safeInvert2(Q, V, false); + + MissingOps.weightedAverage(m0, P0, x, P, n, V, dimTrait); + + for (int j = 0; j < dimTrait; j++) { + belowTraits[offset + j] = n.get(j); + } + + offset += dimTrait; + + } + + return belowTraits; + } + private static final boolean DEBUG = false; diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index e1445c8168..8c7c28bfcc 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -195,8 +195,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { // treeDataLikelihood.addTraits(partialTraitProvider.getTreeTraits()); } + //TODO: remove below (should let ConditionalTraitSimulationHelper figure everything out) int[] partitionDimensions = dataModel.getPartitionDimensions(); - if (partitionDimensions != null) { + if (partitionDimensions.length > 1) { PartitionedTreeTraitProvider partitionedProvider = new PartitionedTreeTraitProvider(treeDataLikelihood.getTreeTraits(), partitionDimensions); treeDataLikelihood.addTraits(partitionedProvider.getTreeTraits()); diff --git a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java index 1b059a81c6..f2b2bed99c 100644 --- a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java +++ b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java @@ -1,7 +1,7 @@ package dr.inference.operators.factorAnalysis; -import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.continuous.ConditionalTraitSimulationHelper; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; import dr.inference.model.*; import dr.math.matrixAlgebra.Matrix; @@ -10,9 +10,6 @@ import java.util.ArrayList; -import static dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate.REALIZED_TIP_TRAIT; -import static dr.evomodelxml.treedatalikelihood.ContinuousDataLikelihoodParser.FACTOR_NAME; - /** * @author Marc A. Suchard * @author Gabriel Hassler @@ -237,11 +234,11 @@ class IntegratedFactors extends Abstract { private final IntegratedFactorAnalysisLikelihood factorLikelihood; private final TreeDataLikelihood treeLikelihood; + private final ConditionalTraitSimulationHelper factorSimulationHelper; private final Parameter precision; private final CompoundParameter data; - private final TreeTrait factorTrait; private double[] factors; public IntegratedFactors(IntegratedFactorAnalysisLikelihood factorLikelihood, @@ -252,10 +249,15 @@ public IntegratedFactors(IntegratedFactorAnalysisLikelihood factorLikelihood, this.precision = factorLikelihood.getPrecision(); this.data = factorLikelihood.getParameter(); + this.factorSimulationHelper = new ConditionalTraitSimulationHelper(treeLikelihood); - factorTrait = treeLikelihood.getTreeTrait(factorLikelihood.getTipTraitName()); + //TODO: (below) +// if (factorSimulationHelper.getTreeTrait().getTraitName() != factorLikelihood.getTipTraitName()) { +// throw new RuntimeException("Tip trait names must match: '" + +// factorSimulationHelper.getTreeTrait().getTraitName() + "' != '" + +// factorLikelihood.getTipTraitName()); +// } - assert (factorTrait != null); } @Override @@ -290,7 +292,7 @@ public double getColumnPrecision(int index) { @Override public void drawFactors() { - factors = (double[]) factorTrait.getTrait(treeLikelihood.getTree(), null); + factors = factorSimulationHelper.drawTraitsAbove(factorLikelihood); if (DEBUG) { System.err.println("factors: " + new Vector(factors)); From 1dfc69160129c884b00665d0d114f7462a25952e Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 23 Aug 2022 17:21:31 -0700 Subject: [PATCH 109/196] need to actually sample a value, not just use the mean --- .../RepeatedMeasuresTraitDataModel.java | 6 +- src/dr/math/matrixAlgebra/ReadableMatrix.java | 10 ++++ src/dr/math/matrixAlgebra/ReadableVector.java | 9 +++ .../matrixAlgebra/missingData/MissingOps.java | 55 +++++++++++++++++++ 4 files changed, 77 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index b195bc0443..24af4da5bf 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -315,9 +315,9 @@ public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTra MissingOps.weightedAverage(m0, P0, x, P, n, V, dimTrait); - for (int j = 0; j < dimTrait; j++) { - belowTraits[offset + j] = n.get(j); - } + double[] sample = MissingOps.nextPossiblyDegenerateNormal(n, V); + + System.arraycopy(sample, 0, belowTraits, offset, dimTrait); offset += dimTrait; diff --git a/src/dr/math/matrixAlgebra/ReadableMatrix.java b/src/dr/math/matrixAlgebra/ReadableMatrix.java index 32fb9e4945..8d4f90fd3e 100644 --- a/src/dr/math/matrixAlgebra/ReadableMatrix.java +++ b/src/dr/math/matrixAlgebra/ReadableMatrix.java @@ -41,6 +41,16 @@ public interface ReadableMatrix extends ReadableVector { class Utils { + public static double[][] toMatrixArray(ReadableMatrix Y) { // defeats the purpose of wrapping but useful helper function sometimes + double[][] X = new double[Y.getMajorDim()][Y.getMinorDim()]; + for (int i = 0; i < Y.getMajorDim(); i++) { + for (int j = 0; j < Y.getMinorDim(); j++) { + X[i][j] = Y.get(i, j); + } + } + return X; + } + public static double[] toArray(ReadableMatrix matrix) { double[] array = new double[matrix.getDim()]; int offset = 0; diff --git a/src/dr/math/matrixAlgebra/ReadableVector.java b/src/dr/math/matrixAlgebra/ReadableVector.java index 675544e68f..6829952090 100644 --- a/src/dr/math/matrixAlgebra/ReadableVector.java +++ b/src/dr/math/matrixAlgebra/ReadableVector.java @@ -205,5 +205,14 @@ public static double norm(ReadableVector vector) { return Math.sqrt(innerProduct(vector, vector)); } + + public static double[] toArray(ReadableVector v) { + int dim = v.getDim(); + double[] x = new double[dim]; + for (int i = 0; i < dim; i++) { + x[i] = v.get(i); + } + return x; + } } } diff --git a/src/dr/math/matrixAlgebra/missingData/MissingOps.java b/src/dr/math/matrixAlgebra/missingData/MissingOps.java index d401943844..3dffaa9f48 100644 --- a/src/dr/math/matrixAlgebra/missingData/MissingOps.java +++ b/src/dr/math/matrixAlgebra/missingData/MissingOps.java @@ -1,6 +1,7 @@ package dr.math.matrixAlgebra.missingData; import dr.inference.model.MatrixParameterInterface; +import dr.math.distributions.MultivariateNormalDistribution; import dr.math.matrixAlgebra.*; import org.ejml.alg.dense.decomposition.lu.LUDecompositionAlt_D64; import org.ejml.alg.dense.linsol.lu.LinearSolverLu_D64; @@ -1068,6 +1069,60 @@ public static void addTransEquals(DenseMatrix64F M) { } } } + + public static double[] nextPossiblyDegenerateNormal(ReadableVector mean, DenseMatrix64F variance) { + int dim = mean.getDim(); + + if (variance.numCols != dim || variance.numRows != dim) { + throw new RuntimeException("Variance is a " + variance.numRows + "x" + variance.numCols + + " matrix but mean has dimension " + dim); + } + + int zeroCount = countZeroDiagonals(variance); + int nonZeroCount = countFiniteNonZeroDiagonals(variance); + if (zeroCount + nonZeroCount != dim) { + throw new RuntimeException("At least one diagonal element of the variance is infinity. " + + "Cannot sample from distribution with infinite variance"); + } + + + double[] buffer = ReadableVector.Utils.toArray(mean); + + + if (nonZeroCount == dim) { + double[][] cholesky = CholeskyDecomposition.execute(variance.data, 0, dim); + return MultivariateNormalDistribution.nextMultivariateNormalCholesky(buffer, cholesky); + } + + int[] latentIndices = new int[nonZeroCount]; + + int latI = 0; + for (int i = 0; i < dim; i++) { + if (variance.get(i, i) > 0) { + latentIndices[latI] = i; + latI++; + } + } + + WrappedMatrix.Indexed subVar = new WrappedMatrix.Indexed(variance.data, 0, + latentIndices, latentIndices, + dim, dim); + + + WrappedVector.Indexed subMean = new WrappedVector.Indexed(buffer, 0, latentIndices, dim); + + + double[] latentDraw = MultivariateNormalDistribution.nextMultivariateNormalVariance( + ReadableVector.Utils.toArray(subMean), ReadableMatrix.Utils.toMatrixArray(subVar)); + + for (int i = 0; i < latentIndices.length; i++) { + buffer[latentIndices[i]] = latentDraw[i]; + } + + return buffer; + } + + } // public static void safeSolveSymmPosDef(DenseMatrix64F A, From 81a5a301c73a45f5b7c937752de6a08516ec39c5 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 23 Aug 2022 17:22:23 -0700 Subject: [PATCH 110/196] printing new report to make sure factors are getting sample correctly --- .../FactorAnalysisOperatorAdaptor.java | 53 ++++++++++++++++++- .../NewLoadingsGibbsOperator.java | 6 ++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java index f2b2bed99c..c6078a263b 100644 --- a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java +++ b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java @@ -14,7 +14,7 @@ * @author Marc A. Suchard * @author Gabriel Hassler */ -public interface FactorAnalysisOperatorAdaptor { +public interface FactorAnalysisOperatorAdaptor extends Reportable { int getNumberOfTaxa(); @@ -327,6 +327,57 @@ public ArrayList getLikelihoods() { return likelihoods; } + @Override + public String getReport() { + int repeats = 10000; + + int nTaxa = treeLikelihood.getTree().getExternalNodeCount(); + int nFactors = factorLikelihood.getNumberOfFactors(); + int dim = nFactors * nTaxa; + + double[] mean = new double[dim]; + double[][] cov = new double[dim][dim]; + + for (int i = 0; i < repeats; i++) { + drawFactors(); + for (int j = 0; j < dim; j++) { + mean[j] += factors[j]; + cov[j][j] += factors[j] * factors[j]; + + for (int k = (j + 1); k < dim; k++) { + cov[j][k] += factors[j] * factors[k]; + cov[k][j] = cov[j][k]; + } + } + } + + for (int i = 0; i < dim; i++) { + mean[i] /= repeats; + for (int j = 0; j < dim; j++) { + cov[i][j] /= repeats; + } + } + + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + cov[i][j] -= mean[i] * mean[j]; + } + } + + StringBuilder sb = new StringBuilder(this.getClass() + " report:\n"); + sb.append("Factor mean:\n"); + sb.append(new Vector(mean)); + sb.append("\n\n"); + sb.append("Factor covariance:\n"); + sb.append(new Matrix(cov)); + sb.append("\n\nTaxon order:"); + for (int i = 0; i < nTaxa; i++) { + sb.append(" " + treeLikelihood.getTree().getTaxonId(i)); + } + + return sb.toString(); + } + private static final boolean DEBUG = false; } } diff --git a/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java b/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java index e440170df9..b016269e0e 100644 --- a/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java +++ b/src/dr/inference/operators/factorAnalysis/NewLoadingsGibbsOperator.java @@ -374,6 +374,11 @@ public static ConstrainedSampler parse(String name) { @Override public String getReport() { + StringBuilder sb = new StringBuilder(); + sb.append(adaptor.getReport()); + sb.append("\n\n"); + + int repeats = 20000; int nFac = adaptor.getNumberOfFactors(); @@ -419,7 +424,6 @@ public String getReport() { } } - StringBuilder sb = new StringBuilder(); sb.append(getOperatorName() + "Report:\n"); sb.append("Loadings mean:\n"); sb.append(new Vector(loadMean)); From aa100c639bd614c9c5cf97c6f3fd5b56862d5c9e Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 11:09:20 -0700 Subject: [PATCH 111/196] MultivariateConditionalOnTipsRealizedDelegate now samples tips appropriately when some traits are observed with infinite precision and others are observed with non-zero, finite precision --- ...iateConditionalOnTipsRealizedDelegate.java | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java b/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java index 575526f5be..4bce0f7745 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java @@ -8,6 +8,7 @@ import dr.math.matrixAlgebra.WrappedMatrix; import dr.math.matrixAlgebra.WrappedVector; import dr.math.matrixAlgebra.missingData.MissingOps; +import mpi.Comm; import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; @@ -63,7 +64,7 @@ protected void simulateTraitForRoot(final int offsetSample, final int offsetPart new WrappedVector.Raw(mean, 0, dimTrait), totalVar, dimTrait - ); + ); final DenseMatrix64F cholesky = getCholeskyOfVariance(totalVar, dimTrait); @@ -193,13 +194,38 @@ private void simulateTraitForExternalNode(final int nodeIndex, final int[] observed = indices.getComplement(); final int[] missing = indices.getArray(); - final DenseMatrix64F V1 = getVarianceBranch(branchPrecision); + for (int i : observed) { + P0.set(i, i, 0.0); + } + + + //TODO: code below likely has some duplication with other classes + final DenseMatrix64F P1 = getPrecisionBranch(branchPrecision); + final DenseMatrix64F P = new DenseMatrix64F(dimTrait, dimTrait); + + + CommonOps.add(P0, P1, P); + + final DenseMatrix64F V = new DenseMatrix64F(dimTrait, dimTrait); + CommonOps.invert(P, V); + + DenseMatrix64F traitSample = wrap(sample, offsetParent, dimTrait, 1); + DenseMatrix64F tipMean = wrap(partialNodeBuffer, offsetPartial, dimTrait, 1); + + DenseMatrix64F P0x = new DenseMatrix64F(dimTrait, 1); + DenseMatrix64F P1x = new DenseMatrix64F(dimTrait, 1); + + CommonOps.mult(P0, tipMean, P0x); + CommonOps.mult(P1, traitSample, P1x); + CommonOps.addEquals(P1x, P0x); + CommonOps.mult(V, P1x, P0x); + // final DenseMatrix64F V1 = new DenseMatrix64F(dimTrait, dimTrait); // CommonOps.scale(1.0 / branchPrecision, Vd, V1); ConditionalVarianceAndTransform2 transform = new ConditionalVarianceAndTransform2( - V1, missing, observed + V, missing, observed ); // TODO Cache (via delegated function) final DenseMatrix64F cP0 = new DenseMatrix64F(missing.length, missing.length); @@ -207,15 +233,15 @@ private void simulateTraitForExternalNode(final int nodeIndex, final WrappedVector cM2 = transform.getConditionalMean( partialNodeBuffer, offsetPartial, // Tip value - sample, offsetParent); // Parent value + P0x.data, 0); // Parent value - final DenseMatrix64F cP1 = transform.getConditionalPrecision(); + final DenseMatrix64F cV2 = transform.getConditionalVariance(); - final DenseMatrix64F cP2 = new DenseMatrix64F(missing.length, missing.length); - final DenseMatrix64F cV2 = new DenseMatrix64F(missing.length, missing.length); - CommonOps.add(cP0, cP1, cP2); //TODO: Shouldn't P0 = 0 always in this situation ? +// final DenseMatrix64F cP2 = new DenseMatrix64F(missing.length, missing.length); +// final DenseMatrix64F cV2 = new DenseMatrix64F(missing.length, missing.length); +// CommonOps.add(cP0, cP1, cP2); //TODO: Shouldn't P0 = 0 always in this situation ? - safeInvert2(cP2, cV2, false); +// safeInvert2(cP2, cV2, false); // TODO Drift? // assert (!likelihoodDelegate.getDiffusionProcessDelegate().hasDrift()); @@ -243,7 +269,7 @@ private void simulateTraitForExternalNode(final int nodeIndex, final WrappedVector M0 = new WrappedVector.Raw(partialNodeBuffer, offsetPartial, dimTrait); final WrappedVector M1 = new WrappedVector.Raw(sample, offsetParent, dimTrait); - final DenseMatrix64F P1 = new DenseMatrix64F(dimTrait, dimTrait); +// final DenseMatrix64F P1 = new DenseMatrix64F(dimTrait, dimTrait); CommonOps.scale(branchPrecision, Pd, P1); final WrappedVector newSample = new WrappedVector.Raw(sample, offsetSample, dimTrait); @@ -257,8 +283,8 @@ private void simulateTraitForExternalNode(final int nodeIndex, System.err.println(""); System.err.println("cP0: " + cP0); System.err.println("cM2: " + cM2); - System.err.println("cP1: " + cP1); - System.err.println("cP2: " + cP2); +// System.err.println("cP1: " + cP1); +// System.err.println("cP2: " + cP2); System.err.println("cV2: " + cV2); // System.err.println("cC2: " + new Matrix(cC2)); System.err.println("SS: " + newSample); @@ -366,7 +392,7 @@ private void simulateTraitForInternalNode(final int offsetSample, System.err.println("SS: " + new WrappedVector.Raw(sample, offsetSample, dimTrait)); System.err.println(""); - if (!check(M2)) { + if (!check(M2)) { System.exit(-1); } } @@ -391,7 +417,7 @@ private boolean check(ReadableVector m2) { return true; } - DenseMatrix64F getPrecisionBranch(double branchPrecision){ + DenseMatrix64F getPrecisionBranch(double branchPrecision) { if (!hasDrift) { DenseMatrix64F P1 = new DenseMatrix64F(dimTrait, dimTrait); CommonOps.scale(branchPrecision, Pd, P1); @@ -401,7 +427,7 @@ DenseMatrix64F getPrecisionBranch(double branchPrecision){ } } - DenseMatrix64F getVarianceBranch(double branchPrecision){ + DenseMatrix64F getVarianceBranch(double branchPrecision) { if (!hasDrift) { final DenseMatrix64F V1 = new DenseMatrix64F(dimTrait, dimTrait); CommonOps.scale(1.0 / branchPrecision, Vd, V1); From 9fa71df168b1da90c55970bf2d248aad16b8c7f3 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 11:12:55 -0700 Subject: [PATCH 112/196] report + bug fix in ConditionalTraitSimulationHelper --- .../ConditionalTraitSimulationHelper.java | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java index 9cd0bddec0..f544211bb9 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java @@ -2,6 +2,8 @@ import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.math.matrixAlgebra.Vector; +import dr.xml.Reportable; import java.util.HashMap; @@ -11,7 +13,7 @@ */ -public class ConditionalTraitSimulationHelper { +public class ConditionalTraitSimulationHelper implements Reportable { private final TreeDataLikelihood treeLikelihood; private final TreeTrait treeTrait; @@ -62,6 +64,7 @@ public double[] drawTraitsAbove(ContinuousTraitPartialsProvider model) { int dimTrait = model.getTraitDimension(); if (model == topDataModel) { + treeLikelihood.fireModelChanged(); return (double[]) treeTrait.getTrait(treeLikelihood.getTree(), null); } @@ -77,7 +80,7 @@ public double[] drawTraitsAbove(ContinuousTraitPartialsProvider model) { int fullOffset = helper.traitOffset; int thisOffset = 0; - int dimAbove = helper.parent.getDataDimension(); + int dimAbove = helper.parent.getTraitDimension(); for (int i = 0; i < nTaxa; i++) { System.arraycopy(fullTraitsAbove, fullOffset, traitsAbove, thisOffset, helper.traitDimension); fullOffset += dimAbove; @@ -93,4 +96,28 @@ public double[] drawTraitsBelow(ContinuousTraitPartialsProvider model) { } + @Override + public String getReport() { + int repeats = 10000; + + double[] mean = drawTraitsAbove(topDataModel); + for (int i = 1; i < repeats; i++) { + double[] draw = drawTraitsAbove(topDataModel); + for (int j = 0; j < draw.length; j++) { + mean[j] += draw[j]; + } + } + + for (int i = 0; i < mean.length; i++) { + mean[i] /= repeats; + } + + StringBuilder sb = new StringBuilder("Trait simulation report:\n\ttree trait mean: "); + sb.append(new Vector(mean)); + sb.append("\n"); + + return sb.toString(); + } + + } From 3b04a195cfa0104c3c766ad52abd261ed8ae9018 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 11:13:38 -0700 Subject: [PATCH 113/196] bug fix in JointPartialsProvider --- .../treedatalikelihood/continuous/JointPartialsProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index 3ffc65742f..dde03a299c 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -65,7 +65,7 @@ public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] prov for (int i = 0; i < providers.length; i++) { subTraitMissingInds[i] = providers[i].getTraitMissingIndicators(); subDataMissingInds[i] = providers[i].getDataMissingIndicators(); - traitDims[i] = providers[i].getDataDimension(); + dataDims[i] = providers[i].getDataDimension(); traitDims[i] = providers[i].getTraitDimension(); } From 8709cd42443a78050f8125857c2c3fbf9f5fdf19 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 11:14:49 -0700 Subject: [PATCH 114/196] small report edits --- .../ContinuousDataLikelihoodDelegate.java | 2 +- .../FactorAnalysisOperatorAdaptor.java | 44 ++++++++++--------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java index 7bbe6f0ab8..c0f23223f5 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java @@ -602,7 +602,7 @@ public String getReport() { Matrix cVar = cVariance.getConditionalVariance(); sb.append("cMean #").append(tip).append(" ").append(new dr.math.matrixAlgebra.Vector(cMean)) - .append(" cVar [").append(cVar).append("]\n"); + .append("\ncVar [").append(cVar).append("]\n\n"); } } diff --git a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java index c6078a263b..868ead2762 100644 --- a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java +++ b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java @@ -329,6 +329,10 @@ public ArrayList getLikelihoods() { @Override public String getReport() { + + StringBuilder sb = new StringBuilder(factorSimulationHelper.getReport()); + sb.append("\n\n"); + int repeats = 10000; int nTaxa = treeLikelihood.getTree().getExternalNodeCount(); @@ -336,40 +340,40 @@ public String getReport() { int dim = nFactors * nTaxa; double[] mean = new double[dim]; - double[][] cov = new double[dim][dim]; +// double[][] cov = new double[dim][dim]; for (int i = 0; i < repeats; i++) { drawFactors(); for (int j = 0; j < dim; j++) { mean[j] += factors[j]; - cov[j][j] += factors[j] * factors[j]; - - for (int k = (j + 1); k < dim; k++) { - cov[j][k] += factors[j] * factors[k]; - cov[k][j] = cov[j][k]; - } +// cov[j][j] += factors[j] * factors[j]; +// +// for (int k = (j + 1); k < dim; k++) { +// cov[j][k] += factors[j] * factors[k]; +// cov[k][j] = cov[j][k]; +// } } } for (int i = 0; i < dim; i++) { mean[i] /= repeats; - for (int j = 0; j < dim; j++) { - cov[i][j] /= repeats; - } +// for (int j = 0; j < dim; j++) { +// cov[i][j] /= repeats; +// } } - for (int i = 0; i < dim; i++) { - for (int j = 0; j < dim; j++) { - cov[i][j] -= mean[i] * mean[j]; - } - } +// for (int i = 0; i < dim; i++) { +// for (int j = 0; j < dim; j++) { +// cov[i][j] -= mean[i] * mean[j]; +// } +// } - StringBuilder sb = new StringBuilder(this.getClass() + " report:\n"); - sb.append("Factor mean:\n"); + sb.append(this.getClass() + " report:\n"); + sb.append("\tfactor mean: "); sb.append(new Vector(mean)); - sb.append("\n\n"); - sb.append("Factor covariance:\n"); - sb.append(new Matrix(cov)); +// sb.append("\n\n"); +// sb.append("\tFactor covariance:\n"); +// sb.append(new Matrix(cov)); sb.append("\n\nTaxon order:"); for (int i = 0; i < nTaxa; i++) { sb.append(" " + treeLikelihood.getTree().getTaxonId(i)); From d75ee03ae1237df6dd00e7818c583e0e3f9c09c0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 11:15:55 -0700 Subject: [PATCH 115/196] more tests in joint models --- ci/TestXML/testComposableContinuousModel.xml | 166 ++++++++++++++++--- 1 file changed, 142 insertions(+), 24 deletions(-) diff --git a/ci/TestXML/testComposableContinuousModel.xml b/ci/TestXML/testComposableContinuousModel.xml index 0c45cd1e3a..1d4062dab8 100644 --- a/ci/TestXML/testComposableContinuousModel.xml +++ b/ci/TestXML/testComposableContinuousModel.xml @@ -72,6 +72,7 @@ + (((taxon10:0.0104828508210571,taxon7:0.06945686994340126):0.05472171377912479,(taxon6:0.09666701733874086,(taxon1:0.025715465624764816,(taxon9:0.14154200426014169,taxon2:0.013832276521980338):0.08913272474958198):0.1193630006138079):0.453503452970523):0.1964588174059454,(((taxon3:0.004017979531503662,taxon5:0.0496801467418641):0.01371646355496644,taxon4:0.07279286149575269):0.15588410980923004,taxon8:0.006692284455141036):0.0833467339290923); @@ -101,7 +102,7 @@ - + @@ -150,12 +151,7 @@ - - - - - - + @@ -221,9 +217,7 @@ - - - + @@ -238,8 +232,10 @@ - - + + @@ -247,7 +243,9 @@ - + @@ -286,25 +284,32 @@ - + - - - - - - - + + + + + + + + - @@ -343,7 +348,7 @@ - + @@ -357,7 +362,6 @@ - @@ -394,4 +398,118 @@ -219.43720322982512 + + + + + + + + + + + + + + + + + + + + + + Check tree traits of joint model + + + + + + 0.4094070007299706 -2.24953543781794 0.48055688839804134 -1.2450733457121714 -3.2469439563030846 + 0.04928935034422466 0.22661236170041593 0.5537342826605709 -1.96013988863524 1.0399035619529968 + -1.5077368382590066 -2.871273657193001 -0.15273913720193377 0.40122987143877253 0.8142903555954035 + -4.817316126423066 -1.8698289223184474 -3.114350566098892 -4.835016486088989 -0.6659232857946336 + -2.7635628060625095 0.6161109857603151 -3.26025692543908 0.057446483772309875 -2.8555356817882966 + -3.9554763023260193 -1.3399296398856677 -2.074714013598168 1.7199674427752143 -4.026480323740543 + 0.39830231666701366 -3.646974806747039 -4.073446946916874 -2.6010057516632514 -3.692093879588291 + 0.14235084027268385 -3.915916559009929 -1.3290168738908505 -4.363954789791933 -4.668275958335926 + -1.323223546578447 -2.7321226935628147 2.3449041383646545 0.972234005959308 6.199575018956239 + -1.9123076901469176 -0.46811846986497585 -2.4041753233595955 1.9848993152781986 1.2805936511058462 + 1.3638556864384555 6.036626207796871 -1.5105776917480398 -0.3996418639328567 -2.059326195034373 + 2.0011912748568648 1.813714920917846 0.39674460651667687 5.219790433608523 -2.1343417424236577 + -1.0921154829916304 -2.0660124489186273 1.646180446420658 1.928397442565256 -0.5873562092792781 + 3.0817066766585413 -2.1874097633690326 -2.099117291701873 -1.1988262214827046 1.7053062655149915 + + + + + + Check extended traits of joint model + + + + + + 1.0020035507295688 0.47666956608333655 -0.27842820595651574 0.5312273751569592 -0.170931044867757 + -2.571386657932635 -1.1714343741350604 -2.0624271010428856 -3.2932663317787956 -4.346789729563105 + -2.0307293191999634 -2.4245227236415303 -2.1588074254268577 1.447402369415613 -2.144887085395567 + 1.6037674333356335 -2.7958635434342796 1.9993932831866914 -0.7148946458337377 2.046801494311694 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 4a5acde7d368c6ee3fb6355606a9bb472e627d71 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 12:34:19 -0700 Subject: [PATCH 116/196] removing cod duplication, consistent ordering between two test xml, fewer iterations in report to decrease run time --- ci/TestXML/testComposableContinuousModel.xml | 2 +- ci/TestXML/testIntegratedFactors.xml | 137 ++++++++++++++---- .../FactorAnalysisOperatorAdaptor.java | 53 +------ 3 files changed, 116 insertions(+), 76 deletions(-) diff --git a/ci/TestXML/testComposableContinuousModel.xml b/ci/TestXML/testComposableContinuousModel.xml index 1d4062dab8..f0b79871b6 100644 --- a/ci/TestXML/testComposableContinuousModel.xml +++ b/ci/TestXML/testComposableContinuousModel.xml @@ -446,7 +446,7 @@ Check extended traits of joint model - + diff --git a/ci/TestXML/testIntegratedFactors.xml b/ci/TestXML/testIntegratedFactors.xml index b467d6c96f..01f19463df 100644 --- a/ci/TestXML/testIntegratedFactors.xml +++ b/ci/TestXML/testIntegratedFactors.xml @@ -112,7 +112,7 @@ - + Check factor mean @@ -120,12 +120,12 @@ - -0.3947245651739801 0.1985702448066101 -0.777982500792344 -0.41804381033712273 -1.014750914187345 - -0.9192096089079205 -0.48414198124749513 -0.7396077600665577 -0.9261028229528564 -0.4575182320243145 + -0.3947245651737603 -0.9192096089078916 0.1985702448067741 -0.4841419812475749 -0.7779825007922909 + -0.7396077600664626 -0.4180438103370534 -0.926102822952771 -1.0147509141874191 -0.4575182320241688 - + Check factor covariance @@ -133,16 +133,26 @@ - 1.2972036001808647 0.19015817719503048 0.3487316076738125 0.2858103753935666 0.17589493286629931 -0.003931068145064925 0.08383442865198204 -0.06429482382503635 -0.03983414016969422 0.15631096492075208 - 0.19015817719503048 0.3343286695497909 0.10664056032828739 0.08941973812500237 0.08740991375793783 0.0876863389617858 0.17659157367275852 0.036564579742915854 0.05143388834937501 0.12857285548407787 - 0.3487316076738125 0.10664056032828739 1.1076707304614501 0.33156914017149575 0.11780777675141962 -0.08971919095980539 -0.01389611400147473 -0.325481582692576 -0.16672658561923384 0.02128571831581275 - 0.2858103753935666 0.08941973812500237 0.33156914017149575 0.8409916340671089 0.10193659338381167 -0.06000449361990779 -0.0021171943879493846 -0.13465404059388714 -0.21805747714952525 0.03290918091321065 - 0.17589493286629931 0.08740991375793783 0.11780777675141962 0.10193659338381167 0.5430539640740335 0.1798710707481486 0.147297909491424 0.1281458985254797 0.15610496804384696 0.5826228771676469 - -0.003931068145064925 0.0876863389617858 -0.08971919095980539 -0.06000449361990779 0.1798710707481486 1.2023691485002246 0.3830034818964805 0.39301310229132014 0.4361970693373678 0.5006241090696903 - 0.08383442865198204 0.17659157367275852 -0.01389611400147473 -0.0021171943879493846 0.147297909491424 0.3830034818964805 0.7145656540824348 0.2606072303159408 0.293545600810603 0.36447870216954925 - -0.06429482382503635 0.036564579742915854 -0.325481582692576 -0.13465404059388714 0.1281458985254797 0.39301310229132014 0.2606072303159408 1.0369990420276736 0.46448001347243917 0.426464176683794 - -0.03983414016969422 0.05143388834937501 -0.16672658561923384 -0.21805747714952525 0.15610496804384696 0.4361970693373678 0.293545600810603 0.46448001347243917 1.187877199346758 0.4845201385827522 - 0.15631096492075208 0.12857285548407787 0.02128571831581275 0.03290918091321065 0.5826228771676469 0.5006241090696903 0.36447870216954925 0.426464176683794 0.4845201385827522 1.532781923611651 + 1.2972036002614686 -0.003931068249860936 0.19015817731099105 0.08383442855821996 0.34873160779625323 + -0.06429482391932775 0.28581037551464306 -0.03983414026802958 0.1758949328560675 0.15631096484987364 + -0.003931068127824541 1.2023691485002246 0.08768633893445184 0.38300348187874533 -0.08971919096179816 + 0.39301310227051545 -0.06000449367408245 0.43619706930758184 0.179871070805078 0.5006241090202366 + 0.19015817727563444 0.0876863388568921 0.33432866966597885 0.1765915735789609 0.1066405604508418 + 0.03656457964849655 0.08941973824607885 0.05143388825100945 0.08740991374770601 0.12857285541327582 + 0.08383442866923682 0.38300348189670785 0.17659157364551534 0.714565654064927 -0.013896114003468919 + 0.2606072302951361 -0.0021171944421178885 0.29354560078070335 0.14729790954835695 0.3644787021200955 + 0.3487316077541891 -0.08971919106456598 0.10664056044436165 -0.013896114095174576 1.1076707305837772 + -0.32548158278685163 0.33156914029257223 -0.16672658571750046 0.11780777674118781 0.021285718245004945 + -0.0642948238078255 0.3930131022914338 0.03656457971560386 0.26060723029820565 -0.32548158269454686 + 1.0369990420068689 -0.13465404064804098 0.46448001344242584 0.12814589858233236 0.4264641766344539 + 0.28581037547405685 -0.06000449372469634 0.08941973824096294 -0.0021171944816269746 0.33156914029405016 + -0.134654040688206 0.8409916341880717 -0.21805747724772798 0.10193659337357985 0.03290918084230424 + -0.0398341401526047 0.43619706933759517 0.051433888322007425 0.29354560079286784 -0.16672658562125886 + 0.4644800134515208 -0.21805747720369717 1.1878771993167447 0.1561049681005929 0.48452013853341214 + 0.1758949329467896 0.17987107064339897 0.08740991387401209 0.1472979093976951 0.11780777687397403 + 0.12814589843122082 0.10193659350488815 0.15610496794560452 0.543053964063688 0.5826228770968216 + 0.15631096493788027 0.5006241090699177 0.1285728554567449 0.3644787021519278 0.02128571831372105 + 0.426464176663103 0.032909180858936044 0.4845201385528526 0.5826228772243653 1.5327819235620836 @@ -180,7 +190,7 @@ - + Check factor mean @@ -188,11 +198,12 @@ - 0.5233207543393005 -0.538249994667213 -0.4247255851472689 0.3086433100521155 -0.5988289029481765 -1.1242975431518119 -0.6170405358071159 -1.4121915793691633 -1.2474809620748961 -1.3179237818525087 + 0.5233207543401477 -1.1242975431519824 -0.5382499946666712 -0.6170405358072628 -0.4247255851464047 + -1.4121915793695337 0.30864331005284384 -1.247480962075059 -0.5988289029476606 -1.3179237818527425 - + Check factor covariance @@ -200,16 +211,86 @@ - 0.2313753599643178 0.02036904751514612 0.012781121018292652 -0.00847539497476646 0.041642182957730256 0.18141157005509603 0.09525849914876114 0.08222019385847586 0.07633649746079686 0.10024747950167873 - 0.02036904751514612 0.19722129507636055 0.010734117255537967 0.010817376445970694 0.04580057517887326 0.016315505819980897 0.017023654188099556 0.006106516666946906 -0.004169117082629558 0.004111393662544254 - 0.012781121018292652 0.010734117255537967 0.20778311312778897 0.015934108432702487 0.050591093359003025 0.015085269638653915 0.009893827850611731 0.0289167165787882 -0.004697547545279489 0.00736801020587002 - -0.00847539497476646 0.010817376445970694 0.015934108432702487 0.3130285365958798 0.09274722444433792 -0.29918766112374706 -0.24684404981937946 -0.26492270867851114 -0.47625564618851135 -0.32143158682861483 - 0.041642182957730256 0.04580057517887326 0.050591093359003025 0.09274722444433792 1.1244501132567848 -0.12543378515183917 -0.11282602229871691 -0.1177920117121793 -0.17864341055358782 -0.23020208477243148 - 0.18141157005509603 0.016315505819980897 0.015085269638653915 -0.29918766112374706 -0.12543378515183917 2.283853666395089 1.2516293021919864 1.111955365298627 1.1910189558418551 1.4095150799435032 - 0.09525849914876114 0.017023654188099556 0.009893827850611731 -0.24684404981937946 -0.11282602229871691 1.2516293021919864 1.6288447037411515 0.9069373907785803 0.973846873952084 1.1504554017409419 - 0.08222019385847586 0.006106516666946906 0.0289167165787882 -0.26492270867851114 -0.1177920117121793 1.111955365298627 0.9069373907785803 1.655153892022554 1.0454996644940593 1.1583373788746485 - 0.07633649746079686 -0.004169117082629558 -0.004697547545279489 -0.47625564618851135 -0.17864341055358782 1.1910189558418551 0.973846873952084 1.0454996644940593 1.8596663578088055 1.2490912602875142 - 0.10024747950167873 0.004111393662544254 0.00736801020587002 -0.32143158682861483 -0.23020208477243148 1.4095150799435032 1.1504554017409419 1.1583373788746485 1.2490912602875142 2.502584859999388 + 0.23137535954606392 0.18141156998244035 0.020369047277085883 0.09525849908504917 0.012781121026137043 + 0.0822201938226649 -0.008475395022401244 0.07633649741308023 0.041642182791292726 0.10024747948488635 + 0.18141157000091523 2.2838536663604145 0.016315505777469843 1.2516293021578804 0.015085269642444112 + 1.1119553652793002 -0.2991876611452641 1.1910189558228694 -0.12543378517530276 1.4095150799303156 + 0.020369047097346993 0.016315505747438886 0.19722129483864137 0.01702365412433076 0.010734117263837106 + 0.006106516631022265 0.01081737639867697 -0.004169117130289343 0.04580057501300416 0.0041113936458655725 + 0.09525849909482942 1.251629302157312 0.01702365414568482 1.6288447037069318 0.009893827854550787 + 0.9069373907593672 -0.24684404984058567 0.9738468739333257 -0.11282602232218072 1.1504554017278679 + 0.012781120600152462 0.015085269566181605 0.010734117017705103 0.009893827786890567 0.20778311313563336 + 0.02891671654305208 0.01593410838518139 -0.0046975475928412525 0.050591093192906555 0.007368010189190781 + 0.08222019380442591 1.1119553652639524 0.006106516624373085 0.9069373907444742 0.02891671658262379 + 1.6551538920033408 -0.2649227086999372 1.0454996644751873 -0.11779201173568685 1.1583373788615745 + -0.008475395392906648 -0.29918766119631734 0.010817376207910456 -0.24684404988300818 0.015934108440433192 + -0.2649227087142939 0.3130285365484724 -0.4762556462361409 0.09274722427835513 -0.32143158684536977 + 0.07633649740678161 1.191018955807067 -0.0041691171250287954 0.9738468739178643 -0.0046975475413340555 + 1.0454996644748462 -0.4762556462098519 1.8596663577901609 -0.17864341057690955 1.2490912602744402 + 0.04164218253959007 -0.1254337852244051 0.04580057494115408 -0.11282602236261154 0.050591093366961104 + -0.11779201174809127 0.09274722439704419 -0.17864341060123753 1.124450113090802 -0.23020208478901008 + 0.10024747944757395 1.4095150799089424 0.004111393620069575 1.1504554017068358 0.007368010209730752 + 1.1583373788553217 -0.32143158685012363 1.2490912602687558 -0.23020208479577428 2.5025848599860865 + - + + diff --git a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java index 868ead2762..836a1256e2 100644 --- a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java +++ b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java @@ -91,7 +91,7 @@ public double getLoadingsValue(int dim) { @Override public String getReport() { - int repeats = 1000000; + int repeats = 10000; int nFac = getNumberOfFactors(); int nTaxa = getNumberOfTaxa(); int dim = nFac * nTaxa; @@ -107,12 +107,12 @@ public String getReport() { for (int j = 0; j < nTaxa; j++) { for (int k = 0; k < nFac; k++) { double x = getFactorValue(k, j); - sums[k * nTaxa + j] += x; + sums[k + j * nFac] += x; for (int l = 0; l < nTaxa; l++) { for (int m = 0; m < nFac; m++) { double y = getFactorValue(m, l); - sumSquares[k * nTaxa + j][m * nTaxa + l] += x * y; + sumSquares[k + j * nFac][m + l * nFac] += x * y; } } } @@ -332,50 +332,9 @@ public String getReport() { StringBuilder sb = new StringBuilder(factorSimulationHelper.getReport()); sb.append("\n\n"); - - int repeats = 10000; - - int nTaxa = treeLikelihood.getTree().getExternalNodeCount(); - int nFactors = factorLikelihood.getNumberOfFactors(); - int dim = nFactors * nTaxa; - - double[] mean = new double[dim]; -// double[][] cov = new double[dim][dim]; - - for (int i = 0; i < repeats; i++) { - drawFactors(); - for (int j = 0; j < dim; j++) { - mean[j] += factors[j]; -// cov[j][j] += factors[j] * factors[j]; -// -// for (int k = (j + 1); k < dim; k++) { -// cov[j][k] += factors[j] * factors[k]; -// cov[k][j] = cov[j][k]; -// } - } - } - - for (int i = 0; i < dim; i++) { - mean[i] /= repeats; -// for (int j = 0; j < dim; j++) { -// cov[i][j] /= repeats; -// } - } - -// for (int i = 0; i < dim; i++) { -// for (int j = 0; j < dim; j++) { -// cov[i][j] -= mean[i] * mean[j]; -// } -// } - - sb.append(this.getClass() + " report:\n"); - sb.append("\tfactor mean: "); - sb.append(new Vector(mean)); -// sb.append("\n\n"); -// sb.append("\tFactor covariance:\n"); -// sb.append(new Matrix(cov)); - sb.append("\n\nTaxon order:"); - for (int i = 0; i < nTaxa; i++) { + sb.append(super.getReport()); + sb.append("Taxon order:"); + for (int i = 0; i < getNumberOfTaxa(); i++) { sb.append(" " + treeLikelihood.getTree().getTaxonId(i)); } From 390536716eba4d63e7e8ac8288c0a01dc78db363 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 15:33:41 -0700 Subject: [PATCH 117/196] forgot to switch the seed back before testing. this should work (but still be a lot faster than before) --- .../operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java index 836a1256e2..e29130952a 100644 --- a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java +++ b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java @@ -91,7 +91,7 @@ public double getLoadingsValue(int dim) { @Override public String getReport() { - int repeats = 10000; + int repeats = 20000; int nFac = getNumberOfFactors(); int nTaxa = getNumberOfTaxa(); int dim = nFac * nTaxa; From 30c1a0d91036c3650ead1c9407d0274140b74f64 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Thu, 1 Sep 2022 16:34:14 -0700 Subject: [PATCH 118/196] delegate now stores model extension helper --- .../continuous/ContinuousDataLikelihoodDelegate.java | 9 +++++++++ .../ContinuousDataLikelihoodParser.java | 2 ++ .../factorAnalysis/FactorAnalysisOperatorAdaptor.java | 4 +++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java index c0f23223f5..31d5ad588b 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegate.java @@ -83,6 +83,7 @@ public class ContinuousDataLikelihoodDelegate extends AbstractModel implements D private boolean allowSingular = false; private TreeDataLikelihood callbackLikelihood = null; + private ConditionalTraitSimulationHelper extensionHelper = null; public ContinuousDataLikelihoodDelegate(Tree tree, DiffusionProcessDelegate diffusionProcessDelegate, @@ -305,6 +306,10 @@ public TreeDataLikelihood getCallbackLikelihood() { return callbackLikelihood; } + public ConditionalTraitSimulationHelper getExtensionHelper() { + return extensionHelper; + } + public PrecisionType getPrecisionType() { return precisionType; } @@ -655,6 +660,10 @@ public void setCallback(TreeDataLikelihood treeDataLikelihood) { this.callbackLikelihood = treeDataLikelihood; } + public void setExtensionHelper() { + this.extensionHelper = new ConditionalTraitSimulationHelper(callbackLikelihood); + } + @Override public void setComputePostOrderStatisticsOnly(boolean computePostOrderStatistic) { this.computeRemainders = !computePostOrderStatistic; diff --git a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java index 8c7c28bfcc..eb39510a03 100644 --- a/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/ContinuousDataLikelihoodParser.java @@ -205,6 +205,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } + delegate.setExtensionHelper(); + return treeDataLikelihood; } diff --git a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java index e29130952a..f73984451b 100644 --- a/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java +++ b/src/dr/inference/operators/factorAnalysis/FactorAnalysisOperatorAdaptor.java @@ -2,6 +2,7 @@ import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ConditionalTraitSimulationHelper; +import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; import dr.inference.model.*; import dr.math.matrixAlgebra.Matrix; @@ -249,7 +250,8 @@ public IntegratedFactors(IntegratedFactorAnalysisLikelihood factorLikelihood, this.precision = factorLikelihood.getPrecision(); this.data = factorLikelihood.getParameter(); - this.factorSimulationHelper = new ConditionalTraitSimulationHelper(treeLikelihood); + this.factorSimulationHelper = + ((ContinuousDataLikelihoodDelegate) treeLikelihood.getDataLikelihoodDelegate()).getExtensionHelper(); //TODO: (below) // if (factorSimulationHelper.getTreeTrait().getTraitName() != factorLikelihood.getTipTraitName()) { From b9ed778351576c598cc079e6b0bca9eced564417 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 2 Sep 2022 17:01:30 -0700 Subject: [PATCH 119/196] bug fix in WrappedMatrix.Indexed + bug fixes in RepeatedMeasures --- .../continuous/RepeatedMeasuresTraitDataModel.java | 11 ++++++++--- src/dr/math/matrixAlgebra/WrappedMatrix.java | 10 +++++++--- src/dr/math/matrixAlgebra/missingData/MissingOps.java | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 24af4da5bf..798bfa3ae7 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -51,7 +51,7 @@ * @author Marc A. Suchard * @author Gabriel Hassler */ -public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel implements ContinuousTraitPartialsProvider, +public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel implements FullPrecisionContinuousTraitPartialsProvider, ModelExtensionProvider.NormalExtensionProvider { private final String traitName; @@ -247,12 +247,17 @@ public DenseMatrix64F getExtensionVariance(NodeRef node) { return getExtensionVariance(); } + @Override + public MatrixParameterInterface getExtensionPrecision() { + return getExtensionPrecisionParameter(); //TODO: deprecate + } + public void getMeanTipVariances(DenseMatrix64F samplingVariance, DenseMatrix64F samplingComponent) { CommonOps.scale(1.0, samplingVariance, samplingComponent); } @Override - public MatrixParameterInterface getExtensionPrecision() { + public MatrixParameterInterface getExtensionPrecisionParameter() { checkVariableChanged(); return samplingPrecisionParameter; } @@ -313,7 +318,7 @@ public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTra CommonOps.add(P0, P, Q); MissingOps.safeInvert2(Q, V, false); - MissingOps.weightedAverage(m0, P0, x, P, n, V, dimTrait); + MissingOps.safeWeightedAverage(m0, P0, x, P, n, V, dimTrait); double[] sample = MissingOps.nextPossiblyDegenerateNormal(n, V); diff --git a/src/dr/math/matrixAlgebra/WrappedMatrix.java b/src/dr/math/matrixAlgebra/WrappedMatrix.java index 367594188a..206e74b36c 100644 --- a/src/dr/math/matrixAlgebra/WrappedMatrix.java +++ b/src/dr/math/matrixAlgebra/WrappedMatrix.java @@ -301,11 +301,15 @@ final class Indexed extends Abstract { final private int[] indicesMajor; final private int[] indicesMinor; + final int dimMajorFull; + final int dimMinorFull; - public Indexed(double[] buffer, int offset, int[] indicesMajor, int[] indicesMinor, int dimMajor, int dimMinor) { - super(buffer, offset, dimMajor, dimMinor); + public Indexed(double[] buffer, int offset, int[] indicesMajor, int[] indicesMinor, int dimMajorFull, int dimMinorFull) { + super(buffer, offset, indicesMajor.length, indicesMinor.length); this.indicesMajor = indicesMajor; this.indicesMinor = indicesMinor; + this.dimMajorFull = dimMajorFull; + this.dimMinorFull = dimMinorFull; } @Override @@ -329,7 +333,7 @@ final public void set(int i, double x) { } private int getIndex(final int i, final int j) { - return offset + indicesMajor[i] * dimMajor + indicesMinor[j]; + return offset + indicesMajor[i] * dimMajorFull + indicesMinor[j]; } } diff --git a/src/dr/math/matrixAlgebra/missingData/MissingOps.java b/src/dr/math/matrixAlgebra/missingData/MissingOps.java index 3dffaa9f48..77cd0e2f0d 100644 --- a/src/dr/math/matrixAlgebra/missingData/MissingOps.java +++ b/src/dr/math/matrixAlgebra/missingData/MissingOps.java @@ -1109,7 +1109,7 @@ public static double[] nextPossiblyDegenerateNormal(ReadableVector mean, DenseMa dim, dim); - WrappedVector.Indexed subMean = new WrappedVector.Indexed(buffer, 0, latentIndices, dim); + WrappedVector.Indexed subMean = new WrappedVector.Indexed(buffer, 0, latentIndices, nonZeroCount); double[] latentDraw = MultivariateNormalDistribution.nextMultivariateNormalVariance( From 082906ff080a408556f501961c6c3b4ff5ed9e27 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 2 Sep 2022 17:06:19 -0700 Subject: [PATCH 120/196] removing unnecessary argument from WrappedVector --- .../continuous/RepeatedMeasuresTraitDataModel.java | 4 ++-- .../preorder/AbstractValuesViaFullConditionalDelegate.java | 2 +- .../MultivariateConditionalOnTipsRealizedDelegate.java | 4 ++-- src/dr/math/matrixAlgebra/WrappedVector.java | 4 ++-- src/dr/math/matrixAlgebra/missingData/MissingOps.java | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 798bfa3ae7..cfc73b2661 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -311,8 +311,8 @@ public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTra System.arraycopy(partial, precisionType.getPrecisionOffset(dimTrait), p0, 0, precisionType.getPrecisionLength(dimTrait)); - WrappedVector.Indexed m0 = new WrappedVector.Indexed(partial, precisionType.getMeanOffset(dimTrait), wrappedIndices, dimTrait); - WrappedVector.Indexed x = new WrappedVector.Indexed(aboveTraits, offset, wrappedIndices, dimTrait); + WrappedVector.Indexed m0 = new WrappedVector.Indexed(partial, precisionType.getMeanOffset(dimTrait), wrappedIndices); + WrappedVector.Indexed x = new WrappedVector.Indexed(aboveTraits, offset, wrappedIndices); CommonOps.add(P0, P, Q); diff --git a/src/dr/evomodel/treedatalikelihood/preorder/AbstractValuesViaFullConditionalDelegate.java b/src/dr/evomodel/treedatalikelihood/preorder/AbstractValuesViaFullConditionalDelegate.java index e344810c56..30437a7c81 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/AbstractValuesViaFullConditionalDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/AbstractValuesViaFullConditionalDelegate.java @@ -90,7 +90,7 @@ protected double[] getTraitForNode(NodeRef node) { computeValueWithMissing(cM, // input mean transform.getConditionalCholesky(), // input variance, - new WrappedVector.Indexed(sample, sampleOffset, missing, missing.length), // output sample + new WrappedVector.Indexed(sample, sampleOffset, missing), // output sample transform.getTemporaryStorage()); System.err.println("cM: " + cM); diff --git a/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java b/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java index 4bce0f7745..bf8a5fb85d 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.java @@ -252,7 +252,7 @@ private void simulateTraitForExternalNode(final int nodeIndex, MultivariateNormalDistribution.nextMultivariateNormalCholesky( cM2, // input mean new WrappedMatrix.WrappedDenseMatrix(cC2), 1.0, // input variance - new WrappedVector.Indexed(sample, offsetSample, missing, missing.length), // output sample + new WrappedVector.Indexed(sample, offsetSample, missing), // output sample tmpEpsilon); } else { double[][] cC2 = getCholeskyOfVariance(cV2.getData(), missing.length); @@ -261,7 +261,7 @@ private void simulateTraitForExternalNode(final int nodeIndex, MultivariateNormalDistribution.nextMultivariateNormalCholesky( cM2, // input mean new WrappedMatrix.ArrayOfArray(cC2), 1.0, // input variance - new WrappedVector.Indexed(sample, offsetSample, missing, missing.length), // output sample + new WrappedVector.Indexed(sample, offsetSample, missing), // output sample tmpEpsilon); } diff --git a/src/dr/math/matrixAlgebra/WrappedVector.java b/src/dr/math/matrixAlgebra/WrappedVector.java index 4181912b36..8701e6e6fa 100644 --- a/src/dr/math/matrixAlgebra/WrappedVector.java +++ b/src/dr/math/matrixAlgebra/WrappedVector.java @@ -130,8 +130,8 @@ final class Indexed extends Abstract { final private int[] indices; - public Indexed(double[] buffer, int offset, int[] indices, int dim) { - super(buffer, offset, dim); + public Indexed(double[] buffer, int offset, int[] indices) { + super(buffer, offset, indices.length); this.indices = indices; } diff --git a/src/dr/math/matrixAlgebra/missingData/MissingOps.java b/src/dr/math/matrixAlgebra/missingData/MissingOps.java index 77cd0e2f0d..fb14251e55 100644 --- a/src/dr/math/matrixAlgebra/missingData/MissingOps.java +++ b/src/dr/math/matrixAlgebra/missingData/MissingOps.java @@ -1109,7 +1109,7 @@ public static double[] nextPossiblyDegenerateNormal(ReadableVector mean, DenseMa dim, dim); - WrappedVector.Indexed subMean = new WrappedVector.Indexed(buffer, 0, latentIndices, nonZeroCount); + WrappedVector.Indexed subMean = new WrappedVector.Indexed(buffer, 0, latentIndices); double[] latentDraw = MultivariateNormalDistribution.nextMultivariateNormalVariance( From 1bdd732b0a774d9efd77e6a48316a560db248e6f Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 2 Sep 2022 17:26:46 -0700 Subject: [PATCH 121/196] RepeatedMeasuresWishartStatistics now uses new trait sampling framework --- .../ConditionalTraitSimulationHelper.java | 25 ++++++++++++++++++ .../RepeatedMeasuresWishartStatistics.java | 26 +++++++++---------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java index f544211bb9..d3bfff112c 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java @@ -95,6 +95,31 @@ public double[] drawTraitsBelow(ContinuousTraitPartialsProvider model) { return model.drawTraitsBelowConditionalOnDataAndTraitsAbove(aboveTraits); } + public class JointSamples { + + private final double[] traitsAbove; + private final double[] traitsBelow; + + public JointSamples(double[] traitsAbove, double[] traitsBelow) { + this.traitsAbove = traitsAbove; + this.traitsBelow = traitsBelow; + } + + public double[] getTraitsAbove() { + return traitsAbove; + } + + public double[] getTraitsBelow() { + return traitsBelow; + } + } + + public JointSamples drawTraitsAboveAndBelow(ContinuousTraitPartialsProvider model) { + double[] aboveTraits = drawTraitsAbove(model); + double[] belowTraits = model.drawTraitsBelowConditionalOnDataAndTraitsAbove(aboveTraits); + return new JointSamples(aboveTraits, belowTraits); + } + @Override public String getReport() { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java index d69c6ed64b..9b61898d14 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java @@ -1,10 +1,7 @@ package dr.evomodel.treedatalikelihood.continuous; import dr.evolution.tree.Tree; -import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; -import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate; -import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; import dr.inference.model.MatrixParameterInterface; import dr.math.distributions.WishartSufficientStatistics; import dr.math.interfaces.ConjugateWishartStatisticsProvider; @@ -19,10 +16,9 @@ public class RepeatedMeasuresWishartStatistics implements ConjugateWishartStatisticsProvider { - private final ModelExtensionProvider.NormalExtensionProvider traitModel; + private final FullPrecisionContinuousTraitPartialsProvider traitModel; private final Tree tree; - private final TreeTrait tipTrait; - private final ContinuousExtensionDelegate extensionDelegate; + private final ConditionalTraitSimulationHelper extensionHelper; private final ContinuousDataLikelihoodDelegate likelihoodDelegate; private final double[] outerProduct; private final int dimTrait; @@ -30,15 +26,14 @@ public class RepeatedMeasuresWishartStatistics implements ConjugateWishartStatis private final double[] buffer; private boolean forceResample; - public RepeatedMeasuresWishartStatistics(ModelExtensionProvider.NormalExtensionProvider traitModel, + public RepeatedMeasuresWishartStatistics(FullPrecisionContinuousTraitPartialsProvider traitModel, TreeDataLikelihood treeLikelihood, boolean forceResample) { this.traitModel = traitModel; this.tree = treeLikelihood.getTree(); - this.tipTrait = treeLikelihood.getTreeTrait(traitModel.getTipTraitName()); this.likelihoodDelegate = (ContinuousDataLikelihoodDelegate) treeLikelihood.getDataLikelihoodDelegate(); - this.extensionDelegate = traitModel.getExtensionDelegate(likelihoodDelegate, tipTrait, tree); + this.extensionHelper = likelihoodDelegate.getExtensionHelper(); this.dimTrait = traitModel.getTraitDimension(); this.nTaxa = tree.getExternalNodeCount(); @@ -53,7 +48,7 @@ public RepeatedMeasuresWishartStatistics(ModelExtensionProvider.NormalExtensionP @Override public MatrixParameterInterface getPrecisionParameter() { - return traitModel.getExtensionPrecision(); + return traitModel.getExtensionPrecisionParameter(); } @Override @@ -62,12 +57,15 @@ public WishartSufficientStatistics getWishartStatistics() { if (forceResample) { likelihoodDelegate.fireModelChanged(); } - double[] treeValues = traitModel.transformTreeTraits((double[]) tipTrait.getTrait(tree, null)); - double[] dataValues = extensionDelegate.getExtendedValues(treeValues); + + ConditionalTraitSimulationHelper.JointSamples traits = extensionHelper.drawTraitsAboveAndBelow(traitModel); + + double[] valuesAbove = traits.getTraitsAbove(); + double[] valuesBelow = traits.getTraitsBelow(); DenseMatrix64F XminusY = DenseMatrix64F.wrap(nTaxa, dimTrait, buffer); - DenseMatrix64F X = DenseMatrix64F.wrap(nTaxa, dimTrait, treeValues); - DenseMatrix64F Y = DenseMatrix64F.wrap(nTaxa, dimTrait, dataValues); + DenseMatrix64F X = DenseMatrix64F.wrap(nTaxa, dimTrait, valuesAbove); + DenseMatrix64F Y = DenseMatrix64F.wrap(nTaxa, dimTrait, valuesBelow); CommonOps.subtract(X, Y, XminusY); From 6f5263c41bb4d663cdf22a933ad5b2741d5be64d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 6 Sep 2022 11:27:06 -0700 Subject: [PATCH 122/196] forgot to commit new interface --- .../FullPrecisionContinuousTraitPartialsProvider.java | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 src/dr/evomodel/treedatalikelihood/continuous/FullPrecisionContinuousTraitPartialsProvider.java diff --git a/src/dr/evomodel/treedatalikelihood/continuous/FullPrecisionContinuousTraitPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/FullPrecisionContinuousTraitPartialsProvider.java new file mode 100644 index 0000000000..d47998c02e --- /dev/null +++ b/src/dr/evomodel/treedatalikelihood/continuous/FullPrecisionContinuousTraitPartialsProvider.java @@ -0,0 +1,9 @@ +package dr.evomodel.treedatalikelihood.continuous; + +import dr.inference.model.MatrixParameterInterface; + +public interface FullPrecisionContinuousTraitPartialsProvider extends ContinuousTraitPartialsProvider { + + MatrixParameterInterface getExtensionPrecisionParameter(); + +} From 02fd5c1aa3ea7beb327b7124c15246858671e6fb Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 12 Sep 2022 10:52:12 -0700 Subject: [PATCH 123/196] gamma-precision gibbs now works with new extendable interface --- .../repeatedMeasures/GammaGibbsProvider.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java b/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java index 6121270485..ba897d3b57 100644 --- a/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java +++ b/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java @@ -1,8 +1,8 @@ package dr.inference.operators.repeatedMeasures; -import dr.evolution.tree.TreeTrait; -import dr.evomodel.continuous.MatrixShrinkageLikelihood; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.continuous.ConditionalTraitSimulationHelper; +import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.LogNormalDistributionModel; @@ -17,8 +17,6 @@ import java.util.List; -import static dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate.REALIZED_TIP_TRAIT; - /** * @author Marc A. Suchard * @author Gabriel Hassler @@ -110,9 +108,9 @@ class NormalExtensionGibbsProvider implements GammaGibbsProvider { private final ModelExtensionProvider.NormalExtensionProvider dataModel; private final TreeDataLikelihood treeLikelihood; + private final ConditionalTraitSimulationHelper traitProvider; private final CompoundParameter traitParameter; private final Parameter precisionParameter; - private final TreeTrait tipTrait; private final boolean[] missingVector; private double[] tipValues; @@ -123,8 +121,10 @@ public NormalExtensionGibbsProvider(ModelExtensionProvider.NormalExtensionProvid this.dataModel = dataModel; this.treeLikelihood = treeLikelihood; this.traitParameter = dataModel.getParameter(); - this.tipTrait = treeLikelihood.getTreeTrait(dataModel.getTipTraitName()); this.missingVector = dataModel.getDataMissingIndicators(); + this.traitProvider = ((ContinuousDataLikelihoodDelegate) + treeLikelihood.getDataLikelihoodDelegate()).getExtensionHelper(); + MatrixParameterInterface matrixParameter = dataModel.getExtensionPrecision(); @@ -182,7 +182,7 @@ public Parameter getPrecisionParameter() { @Override public void drawValues() { - double[] tipTraits = (double[]) tipTrait.getTrait(treeLikelihood.getTree(), null); + double[] tipTraits = traitProvider.drawTraitsAbove(dataModel); tipValues = dataModel.transformTreeTraits(tipTraits); if (DEBUG) { System.err.println("tipValues: " + new WrappedVector.Raw(tipValues)); From 92811c5c9812ce4b78ae59302175eb4bf27223b0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Mon, 12 Sep 2022 14:12:53 -0700 Subject: [PATCH 124/196] can supply trait name to NormalExtensionProvider (mostly only useful for report, but also preserves some (probably bad) backward compatibility --- .../repeatedMeasures/GammaGibbsProvider.java | 21 +++++++++++++++---- .../NormalExtensionGibbsProviderParser.java | 11 ++++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java b/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java index ba897d3b57..c979e7c215 100644 --- a/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java +++ b/src/dr/inference/operators/repeatedMeasures/GammaGibbsProvider.java @@ -1,5 +1,6 @@ package dr.inference.operators.repeatedMeasures; +import dr.evolution.tree.TreeTrait; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treedatalikelihood.continuous.ConditionalTraitSimulationHelper; import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; @@ -17,6 +18,8 @@ import java.util.List; +import static dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate.getTipTraitName; + /** * @author Marc A. Suchard * @author Gabriel Hassler @@ -109,6 +112,7 @@ class NormalExtensionGibbsProvider implements GammaGibbsProvider { private final ModelExtensionProvider.NormalExtensionProvider dataModel; private final TreeDataLikelihood treeLikelihood; private final ConditionalTraitSimulationHelper traitProvider; + private final TreeTrait tipTrait; private final CompoundParameter traitParameter; private final Parameter precisionParameter; private final boolean[] missingVector; @@ -117,13 +121,21 @@ class NormalExtensionGibbsProvider implements GammaGibbsProvider { private boolean hasCheckedDimension = false; public NormalExtensionGibbsProvider(ModelExtensionProvider.NormalExtensionProvider dataModel, - TreeDataLikelihood treeLikelihood) { + TreeDataLikelihood treeLikelihood, + String traitName) { this.dataModel = dataModel; this.treeLikelihood = treeLikelihood; this.traitParameter = dataModel.getParameter(); this.missingVector = dataModel.getDataMissingIndicators(); - this.traitProvider = ((ContinuousDataLikelihoodDelegate) - treeLikelihood.getDataLikelihoodDelegate()).getExtensionHelper(); + + if (traitName == null) { + this.tipTrait = null; + this.traitProvider = ((ContinuousDataLikelihoodDelegate) + treeLikelihood.getDataLikelihoodDelegate()).getExtensionHelper(); + } else { + this.tipTrait = treeLikelihood.getTreeTrait(getTipTraitName(traitName)); + this.traitProvider = null; + } MatrixParameterInterface matrixParameter = dataModel.getExtensionPrecision(); @@ -182,7 +194,8 @@ public Parameter getPrecisionParameter() { @Override public void drawValues() { - double[] tipTraits = traitProvider.drawTraitsAbove(dataModel); + double[] tipTraits = tipTrait == null ? traitProvider.drawTraitsAbove(dataModel) : + (double[]) tipTrait.getTrait(treeLikelihood.getTree(), null); tipValues = dataModel.transformTreeTraits(tipTraits); if (DEBUG) { System.err.println("tipValues: " + new WrappedVector.Raw(tipValues)); diff --git a/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java b/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java index 9729abe77b..555852d41e 100644 --- a/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java +++ b/src/dr/inferencexml/operators/NormalExtensionGibbsProviderParser.java @@ -13,6 +13,7 @@ public class NormalExtensionGibbsProviderParser extends AbstractXMLObjectParser { private static final String NORMAL_EXTENSION = "normalExtension"; + private static final String TREE_TRAIT = "treeTraitName"; @Override @@ -22,14 +23,20 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { TreeDataLikelihood likelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class); - return new GammaGibbsProvider.NormalExtensionGibbsProvider(dataModel, likelihood); + String traitName = null; + if (xo.hasAttribute(TREE_TRAIT)) { + traitName = xo.getStringAttribute(TREE_TRAIT); + } + + return new GammaGibbsProvider.NormalExtensionGibbsProvider(dataModel, likelihood, traitName); } @Override public XMLSyntaxRule[] getSyntaxRules() { return new XMLSyntaxRule[]{ new ElementRule(ModelExtensionProvider.NormalExtensionProvider.class), - new ElementRule(TreeDataLikelihood.class) + new ElementRule(TreeDataLikelihood.class), + AttributeRule.newStringRule(TREE_TRAIT, true) }; } From bb926e1dc2ceebd5f10219bbb0eb76dd4db6c2eb Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 14:33:48 -0700 Subject: [PATCH 125/196] original report used alternative method for calculating the likelihood --- ci/TestXML/testRepeatedMeasures.xml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/TestXML/testRepeatedMeasures.xml b/ci/TestXML/testRepeatedMeasures.xml index f1df667660..0472ee834f 100644 --- a/ci/TestXML/testRepeatedMeasures.xml +++ b/ci/TestXML/testRepeatedMeasures.xml @@ -156,14 +156,14 @@ - - Check log likelihood of observed data - - + + + + -52.895139448197 From 51b4627818c98b653ee75e024ca4c05c28990514 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 14:37:11 -0700 Subject: [PATCH 126/196] starting to deal with actual repeated measures in RepeatedMeasuresTraitDataModel --- .../RepeatedMeasuresTraitDataModel.java | 224 ++++++++++++++---- ...eScaledRepeatedMeasuresTraitDataModel.java | 7 +- .../ContinuousTraitDataModelParser.java | 2 +- 3 files changed, 185 insertions(+), 48 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index cfc73b2661..c3b5a37a0f 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -25,28 +25,26 @@ package dr.evomodel.treedatalikelihood.continuous; -import dr.evolution.tree.MutableTreeModel; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; -import dr.evomodel.tree.TreeModel; import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType; import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate; import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider; import dr.evomodelxml.continuous.ContinuousTraitDataModelParser; import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; -import dr.inference.model.CompoundParameter; -import dr.inference.model.MatrixParameterInterface; -import dr.inference.model.Parameter; -import dr.inference.model.Variable; +import dr.inference.model.*; import dr.math.matrixAlgebra.*; import dr.math.matrixAlgebra.missingData.MissingOps; import dr.xml.*; import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; +import java.util.ArrayList; import java.util.Arrays; +import static dr.evomodelxml.continuous.ContinuousTraitDataModelParser.NUM_TRAITS; + /** * @author Marc A. Suchard * @author Gabriel Hassler @@ -71,6 +69,10 @@ public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel imp private ContinuousTraitPartialsProvider childModel; + private final int nRepeats; + private ArrayList[] relevantRepeats; + private final int nObservedTips; + public RepeatedMeasuresTraitDataModel(String name, ContinuousTraitPartialsProvider childModel, @@ -83,10 +85,14 @@ public RepeatedMeasuresTraitDataModel(String name, PrecisionType precisionType) { super(name, parameter, missindIndicators, useMissingIndices, dimTrait, numTraits, precisionType); + if (numTraits > 1) { + throw new RuntimeException("not currently implemented"); + } this.childModel = childModel; this.traitName = name; this.samplingPrecisionParameter = samplingPrecision; + this.nRepeats = childModel.getTraitCount() / numTraits; addVariable(samplingPrecision); calculatePrecisionInfo(); @@ -98,6 +104,40 @@ public RepeatedMeasuresTraitDataModel(String name, samplingPrecisionParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, samplingPrecision.getDimension())); + + int offsetInc = precisionType.getPartialsDimension(this.dimTrait); + + int nTaxa = getParameter().getParameterCount(); + int nObservedTips = 0; + relevantRepeats = new ArrayList[nTaxa]; + + for (int i = 0; i < nTaxa; i++) { + int precisionOffset = precisionType.getPrecisionOffset(this.dimTrait); + + relevantRepeats[i] = new ArrayList<>(); + double[] partial = childModel.getTipPartial(i, false); + for (int r = 0; r < nRepeats; r++) { + boolean isAtLeastPartiallyObserved = false; + DenseMatrix64F P = MissingOps.wrap(partial, precisionOffset, this.dimTrait, this.dimTrait); + + for (int j = 0; j < this.dimTrait; j++) { + if (P.get(j, j) > 0) { + isAtLeastPartiallyObserved = true; + break; + } + } + + if (isAtLeastPartiallyObserved) { + relevantRepeats[i].add(r); + nObservedTips++; + } + + precisionOffset += offsetInc; + + } + } + + this.nObservedTips = nObservedTips; } @Override @@ -113,40 +153,96 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { } double[] partial = childModel.getTipPartial(taxonIndex, fullyObserved); - if (precisionType == precisionType.SCALAR) { - return partial; //TODO: I don't think this is right, especially given constructor above. - } - DenseMatrix64F V = MissingOps.wrap(partial, dimTrait + dimTrait * dimTrait, dimTrait, dimTrait); - //TODO: remove diagonalOnly part - if (diagonalOnly) { - for (int index = 0; index < dimTrait; index++) { - V.set(index, index, V.get(index, index) + 1 / samplingPrecision.component(index, index)); + if (nRepeats == 1) { + if (precisionType == precisionType.SCALAR) { + return partial; //TODO: I don't think this is right, especially given constructor above. } - } else { - for (int i = 0; i < dimTrait; i++) { - for (int j = 0; j < dimTrait; j++) { - V.set(i, j, V.get(i, j) + samplingVariance.component(i, j)); + DenseMatrix64F V = MissingOps.wrap(partial, dimTrait + dimTrait * dimTrait, dimTrait, dimTrait); + + //TODO: remove diagonalOnly part + if (diagonalOnly) { + for (int index = 0; index < dimTrait; index++) { + V.set(index, index, V.get(index, index) + 1 / samplingPrecision.component(index, index)); } + } else { + for (int i = 0; i < dimTrait; i++) { + for (int j = 0; j < dimTrait; j++) { + V.set(i, j, V.get(i, j) + samplingVariance.component(i, j)); + } + } + } + + + DenseMatrix64F P = new DenseMatrix64F(dimTrait, dimTrait); + MissingOps.safeInvert2(V, P, false); //TODO this isn't necessary when this is fully observed + + MissingOps.unwrap(P, partial, dimTrait); + MissingOps.unwrap(V, partial, dimTrait + dimTrait * dimTrait); + + if (DEBUG) { + System.err.println("taxon " + taxonIndex); + System.err.println("\tprecision: " + P); + System.err.println("\tmean: " + new WrappedVector.Raw(partial, 0, dimTrait)); } + + return partial; } + int offsetInc = precisionType.getPartialsDimension(dimTrait); + int varOffset = precisionType.getVarianceOffset(dimTrait); + int meanOffset = precisionType.getMeanOffset(dimTrait); + int varDim = precisionType.getVarianceLength(dimTrait); + DenseMatrix64F Pi = new DenseMatrix64F(dimTrait, dimTrait); + DenseMatrix64F Vi = new DenseMatrix64F(dimTrait, dimTrait); DenseMatrix64F P = new DenseMatrix64F(dimTrait, dimTrait); - MissingOps.safeInvert2(V, P, false); //TODO this isn't necessary when this is fully observed + DenseMatrix64F V = new DenseMatrix64F(dimTrait, dimTrait); + DenseMatrix64F Pm = new DenseMatrix64F(dimTrait, 1); + DenseMatrix64F m = new DenseMatrix64F(dimTrait, 1); + + + for (int i : relevantRepeats[taxonIndex]) { + + System.arraycopy(partial, offsetInc * i + varOffset, Vi.data, 0, varDim); + for (int row = 0; row < dimTrait; row++) { + if (Vi.get(row, row) < Double.POSITIVE_INFINITY) { + Vi.set(row, row, Vi.get(row, row) + samplingVariance.component(row, row)); + for (int col = 0; col < row; col++) { + if (Vi.get(col, col) < Double.POSITIVE_INFINITY) { + Vi.set(row, col, Vi.get(row, col) + samplingVariance.component(row, col)); + Vi.set(col, row, Vi.get(row, col)); + } + } + } + } - MissingOps.unwrap(P, partial, dimTrait); - MissingOps.unwrap(V, partial, dimTrait + dimTrait * dimTrait); + MissingOps.safeInvert2(Vi, Pi, false); + CommonOps.addEquals(P, Pi); - if (DEBUG) { - System.err.println("taxon " + taxonIndex); - System.err.println("\tprecision: " + P); - System.err.println("\tmean: " + new WrappedVector.Raw(partial, 0, dimTrait)); +// System.arraycopy(partial, meanOffset, mi.data, 0, dimTrait); + for (int row = 0; row < dimTrait; row++) { + double value = 0; + for (int col = 0; col < dimTrait; col++) { + value += Pi.get(row, col) * partial[offsetInc * i + meanOffset + col]; + } + Pm.add(row, 0, value); + } } + MissingOps.safeSolve(P, Pm, m, false); + MissingOps.safeInvert2(P, V, false); //TODO: don't invert twice + + partial = new double[offsetInc]; + + System.arraycopy(m.data, 0, partial, precisionType.getMeanOffset(dimTrait), dimTrait); + System.arraycopy(P.data, 0, partial, precisionType.getPrecisionOffset(dimTrait), varDim); + System.arraycopy(V.data, 0, partial, precisionType.getVarianceOffset(dimTrait), varDim); + return partial; } + @Override public boolean[] getTraitMissingIndicators() { if (getDataMissingIndicators() == null) { @@ -288,15 +384,14 @@ public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTra throw new RuntimeException("not yet implemented"); } - double[] belowTraits = new double[aboveTraits.length]; - int nTaxa = belowTraits.length / dimTrait; + double[] belowTraits = new double[nObservedTips * dimTrait]; + int nTaxa = getParameter().getParameterCount(); DenseMatrix64F P = DenseMatrix64F.wrap(dimTrait, dimTrait, samplingPrecisionParameter.getParameterValues()); DenseMatrix64F Q = new DenseMatrix64F(dimTrait, dimTrait); DenseMatrix64F V = new DenseMatrix64F(dimTrait, dimTrait); - double[] p0 = new double[dimTrait * dimTrait]; - DenseMatrix64F P0 = DenseMatrix64F.wrap(dimTrait, dimTrait, p0); + DenseMatrix64F P0 = new DenseMatrix64F(dimTrait, dimTrait); int[] wrappedIndices = new int[dimTrait]; for (int i = 0; i < dimTrait; i++) { @@ -305,32 +400,72 @@ public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTra WrappedVector n = new WrappedVector.Raw(new double[dimTrait]); - int offset = 0; + int precisionOffset = precisionType.getPrecisionOffset(dimTrait); + int meanOffset = precisionType.getMeanOffset(dimTrait); + int repOffset = precisionType.getPartialsDimension(dimTrait); + int dimPrecision = precisionType.getPrecisionLength(dimTrait); + + int aboveOffset = 0; + int belowOffset = 0; for (int i = 0; i < nTaxa; i++) { double[] partial = childModel.getTipPartial(i, false); - System.arraycopy(partial, precisionType.getPrecisionOffset(dimTrait), p0, 0, - precisionType.getPrecisionLength(dimTrait)); + WrappedVector.Indexed x = new WrappedVector.Indexed(aboveTraits, aboveOffset, wrappedIndices); + + for (int j : relevantRepeats[i]) { + System.arraycopy(partial, j * repOffset + precisionOffset, P0.data, 0, dimPrecision); + WrappedVector.Indexed m0 = new WrappedVector.Indexed(partial, j * repOffset + meanOffset, wrappedIndices); - WrappedVector.Indexed m0 = new WrappedVector.Indexed(partial, precisionType.getMeanOffset(dimTrait), wrappedIndices); - WrappedVector.Indexed x = new WrappedVector.Indexed(aboveTraits, offset, wrappedIndices); + boolean completelyObserved = true; + for (int k = 0; k < dimTrait; k++) { + if (P0.get(k, k) < Double.POSITIVE_INFINITY) { + completelyObserved = false; + break; + } + } + + if (completelyObserved) { + for (int k = 0; k < dimTrait; k++) { + belowTraits[belowOffset + k] = m0.get(k); + } + } else { + CommonOps.add(P0, P, Q); + MissingOps.safeInvert2(Q, V, false); - CommonOps.add(P0, P, Q); - MissingOps.safeInvert2(Q, V, false); + MissingOps.safeWeightedAverage(m0, P0, x, P, n, V, dimTrait); - MissingOps.safeWeightedAverage(m0, P0, x, P, n, V, dimTrait); + double[] sample = MissingOps.nextPossiblyDegenerateNormal(n, V); - double[] sample = MissingOps.nextPossiblyDegenerateNormal(n, V); + System.arraycopy(sample, 0, belowTraits, belowOffset, dimTrait); + } - System.arraycopy(sample, 0, belowTraits, offset, dimTrait); + belowOffset += dimTrait; + + } + + aboveOffset += dimTrait; - offset += dimTrait; } return belowTraits; } + @Override + public double[] transformTreeTraits(double[] treeTraits) { + double[] repeatedTraits = new double[dimTrait * nObservedTips]; + int originalOffset = 0; + int expandedOffset = 0; + for (ArrayList repeats : relevantRepeats) { + for (int i : repeats) { + System.arraycopy(treeTraits, originalOffset, repeatedTraits, expandedOffset, dimTrait); + expandedOffset += dimTrait; + } + originalOffset += dimTrait; + } + + return repeatedTraits; + } private static final boolean DEBUG = false; @@ -343,7 +478,6 @@ public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTra @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { - MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(TreeModel.class); final ContinuousTraitPartialsProvider subModel; @@ -386,13 +520,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } String modelName = subModel.getModelName(); + int numTraits = xo.getAttribute(NUM_TRAITS, subModel.getTraitCount()); + if (subModel.getTraitDimension() != dimTrait) { throw new XMLParseException("sub-model has trait dimension " + subModel.getTraitDimension() + ", but sampling precision has dimension " + dimTrait); } - int numTraits = subModel.getTraitCount(); - if (!scaleByTipHeight) { return new RepeatedMeasuresTraitDataModel( modelName, @@ -448,8 +582,8 @@ public String getParserName() { new ElementRule(Parameter.class), }), // Tree trait parser - new ElementRule(MutableTreeModel.class), - AttributeRule.newStringRule(TreeTraitParserUtilities.TRAIT_NAME), +// new ElementRule(MutableTreeModel.class), +// AttributeRule.newStringRule(TreeTraitParserUtilities.TRAIT_NAME), new XORRule( new ElementRule(ContinuousTraitPartialsProvider.class), new AndRule(ContinuousTraitDataModelParser.rules) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java index 7ac427d328..eaab1cd276 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/TreeScaledRepeatedMeasuresTraitDataModel.java @@ -33,8 +33,6 @@ import org.ejml.data.DenseMatrix64F; import org.ejml.ops.CommonOps; -import java.util.List; - /** * @author Marc A. Suchard * @author Paul Bastide @@ -55,6 +53,11 @@ public TreeScaledRepeatedMeasuresTraitDataModel(String name, PrecisionType precisionType) { super(name, childModel, parameter, missingIndicators, useMissingIndices, dimTrait, numTraits, samplingPrecision, precisionType); + + if (!(childModel instanceof ContinuousTraitDataModel)) { + throw new RuntimeException("not yet implemented for alternative child models. " + + "(can't just scale the partial in super.getTipPartial)"); + } } @Override diff --git a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java index 76f8e56ead..581b1061c2 100644 --- a/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java +++ b/src/dr/evomodelxml/continuous/ContinuousTraitDataModelParser.java @@ -19,7 +19,7 @@ public class ContinuousTraitDataModelParser extends AbstractXMLObjectParser { public static final String FORCE_FULL_PRECISION = "forceFullPrecision"; - private static final String NUM_TRAITS = "numTraits"; + public static final String NUM_TRAITS = "numTraits"; @Override From 183e8de4463270c115123d3e069fdb5c4b355f74 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 14:42:47 -0700 Subject: [PATCH 127/196] continuing deprecation of ModelExtensionProvider + small change to how drawing traits can work --- .../continuous/ConditionalTraitSimulationHelper.java | 7 +++++++ .../continuous/ContinuousTraitPartialsProvider.java | 4 ++++ .../preorder/ModelExtensionProvider.java | 2 -- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java index d3bfff112c..17f25f14e8 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ConditionalTraitSimulationHelper.java @@ -115,8 +115,15 @@ public double[] getTraitsBelow() { } public JointSamples drawTraitsAboveAndBelow(ContinuousTraitPartialsProvider model) { + return drawTraitsAboveAndBelow(model, false); + } + + public JointSamples drawTraitsAboveAndBelow(ContinuousTraitPartialsProvider model, boolean transformAbove) { double[] aboveTraits = drawTraitsAbove(model); double[] belowTraits = model.drawTraitsBelowConditionalOnDataAndTraitsAbove(aboveTraits); + if (transformAbove) { + aboveTraits = model.transformTreeTraits(aboveTraits); //TODO: this is probably done twice for something like a latent factor model. can be more efficient + } return new JointSamples(aboveTraits, belowTraits); } diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java index 8ff63ab2bd..8bcfa31571 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitPartialsProvider.java @@ -78,6 +78,10 @@ default double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] aboveTr throw new RuntimeException("Conditional sampling not yet implemented for " + this.getClass()); } + default double[] transformTreeTraits(double[] traits) { + return traits; + } + default boolean getDefaultAllowSingular() { return false; } diff --git a/src/dr/evomodel/treedatalikelihood/preorder/ModelExtensionProvider.java b/src/dr/evomodel/treedatalikelihood/preorder/ModelExtensionProvider.java index 078f79d06f..58cfba52fb 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/ModelExtensionProvider.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/ModelExtensionProvider.java @@ -14,8 +14,6 @@ ContinuousExtensionDelegate getExtensionDelegate(ContinuousDataLikelihoodDelegat TreeTrait treeTrait, Tree tree); - double[] transformTreeTraits(double[] treeTraits); - interface NormalExtensionProvider extends ModelExtensionProvider { From b22af39f5adbaf5dc3a902e671ed0a2ec856a11a Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 14:46:31 -0700 Subject: [PATCH 128/196] wishart statistics can now handle repeated measurements --- .../RepeatedMeasuresWishartStatistics.java | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java index 9b61898d14..8f7447b40a 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresWishartStatistics.java @@ -1,6 +1,5 @@ package dr.evomodel.treedatalikelihood.continuous; -import dr.evolution.tree.Tree; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.inference.model.MatrixParameterInterface; import dr.math.distributions.WishartSufficientStatistics; @@ -17,29 +16,24 @@ public class RepeatedMeasuresWishartStatistics implements ConjugateWishartStatis private final FullPrecisionContinuousTraitPartialsProvider traitModel; - private final Tree tree; private final ConditionalTraitSimulationHelper extensionHelper; private final ContinuousDataLikelihoodDelegate likelihoodDelegate; private final double[] outerProduct; private final int dimTrait; - private final int nTaxa; - private final double[] buffer; + private double[] buffer; private boolean forceResample; public RepeatedMeasuresWishartStatistics(FullPrecisionContinuousTraitPartialsProvider traitModel, TreeDataLikelihood treeLikelihood, boolean forceResample) { this.traitModel = traitModel; - this.tree = treeLikelihood.getTree(); this.likelihoodDelegate = (ContinuousDataLikelihoodDelegate) treeLikelihood.getDataLikelihoodDelegate(); this.extensionHelper = likelihoodDelegate.getExtensionHelper(); this.dimTrait = traitModel.getTraitDimension(); - this.nTaxa = tree.getExternalNodeCount(); this.outerProduct = new double[dimTrait * dimTrait]; - this.buffer = new double[nTaxa * dimTrait]; this.forceResample = forceResample; @@ -58,14 +52,21 @@ public WishartSufficientStatistics getWishartStatistics() { likelihoodDelegate.fireModelChanged(); } - ConditionalTraitSimulationHelper.JointSamples traits = extensionHelper.drawTraitsAboveAndBelow(traitModel); + ConditionalTraitSimulationHelper.JointSamples traits = extensionHelper.drawTraitsAboveAndBelow(traitModel, true); double[] valuesAbove = traits.getTraitsAbove(); double[] valuesBelow = traits.getTraitsBelow(); - DenseMatrix64F XminusY = DenseMatrix64F.wrap(nTaxa, dimTrait, buffer); - DenseMatrix64F X = DenseMatrix64F.wrap(nTaxa, dimTrait, valuesAbove); - DenseMatrix64F Y = DenseMatrix64F.wrap(nTaxa, dimTrait, valuesBelow); + int nTipsTotal = valuesAbove.length / dimTrait; + + if (buffer == null) { + buffer = new double[dimTrait * nTipsTotal]; + } + + + DenseMatrix64F XminusY = DenseMatrix64F.wrap(nTipsTotal, dimTrait, buffer); + DenseMatrix64F X = DenseMatrix64F.wrap(nTipsTotal, dimTrait, valuesAbove); + DenseMatrix64F Y = DenseMatrix64F.wrap(nTipsTotal, dimTrait, valuesBelow); CommonOps.subtract(X, Y, XminusY); @@ -74,7 +75,7 @@ public WishartSufficientStatistics getWishartStatistics() { CommonOps.multTransA(XminusY, XminusY, outerProductMat); - return new WishartSufficientStatistics(nTaxa, outerProduct); + return new WishartSufficientStatistics(nTipsTotal, outerProduct); } public void setForceResample(Boolean b) { From 3f611ec54d1ccdfb76c4f19dfa5c32650d026d72 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 17:33:34 -0700 Subject: [PATCH 129/196] getting rid of use of 'getMatrixDimension()' outside of PrecisionType --- .../continuous/ContinuousTraitDataModel.java | 4 ++-- .../IntegratedProcessTraitDataModel.java | 2 +- .../continuous/RootProcessDelegate.java | 2 +- .../cdi/ContinuousDiffusionIntegrator.java | 22 +++++++++---------- .../cdi/MultivariateIntegrator.java | 8 +++---- .../continuous/cdi/PrecisionType.java | 22 +++++++++---------- ...ctFullConditionalDistributionDelegate.java | 2 +- .../ConditionalOnTipsRealizedDelegate.java | 2 +- .../preorder/NormalSufficientStatistics.java | 2 +- .../WrappedNormalSufficientStatistics.java | 2 +- .../ContinuousDataLikelihoodDelegateTest.java | 4 ++-- 11 files changed, 35 insertions(+), 37 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java index 3a86655f7f..ad80c9f161 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/ContinuousTraitDataModel.java @@ -217,7 +217,7 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { if (fullyObserved) { final PrecisionType precisionType = PrecisionType.SCALAR; - final int offsetInc = dimTrait + precisionType.getMatrixLength(dimTrait); + final int offsetInc = precisionType.getPartialsDimension(dimTrait); final double precision = PrecisionType.getObservedPrecisionValue(false); double[] tipPartial = getTipPartial(taxonIndex, precisionType); @@ -272,7 +272,7 @@ private double[] getTipPartial(int taxonIndex, final PrecisionType precisionType } double[] getTipObservation(int taxonIndex, final PrecisionType precisionType) { - final int offsetInc = dimTrait + precisionType.getMatrixLength(dimTrait); + final int offsetInc = precisionType.getPartialsDimension(dimTrait); final double[] partial = getTipPartial(taxonIndex, precisionType); final double[] data = new double[numTraits * dimTrait]; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedProcessTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedProcessTraitDataModel.java index da6db1c2b9..4621839405 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedProcessTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedProcessTraitDataModel.java @@ -69,7 +69,7 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { double[] partial = super.getTipPartial(taxonIndex, fullyObserved); int dimTraitDouble = 2 * dimTrait; - int dimPartialDouble = dimTraitDouble + precisionType.getMatrixLength(dimTraitDouble); + int dimPartialDouble = precisionType.getPartialsDimension(dimTraitDouble); double[] partialDouble = new double[dimPartialDouble]; // Traits [0, traitsPosition] diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RootProcessDelegate.java b/src/dr/evomodel/treedatalikelihood/continuous/RootProcessDelegate.java index 9836c18ba7..3e52938f4a 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RootProcessDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RootProcessDelegate.java @@ -112,7 +112,7 @@ private void setRootPartial(ContinuousDiffusionIntegrator cdi) { double[] mean = prior.getMean(); final int dimTrait = mean.length; - final int length = dimTrait + precisionType.getMatrixLength(dimTrait); + final int length = precisionType.getPartialsDimension(dimTrait); double[] partial = new double[length * numTraits]; int offset = 0; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/ContinuousDiffusionIntegrator.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/ContinuousDiffusionIntegrator.java index 28753d4301..9d36fd4727 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/ContinuousDiffusionIntegrator.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/ContinuousDiffusionIntegrator.java @@ -148,7 +148,7 @@ class Basic implements ContinuousDiffusionIntegrator { final int bufferCount; final int diffusionCount; - final int dimMatrix; +// final int dimMatrix; final int dimPartialForTrait; final int dimPartial; @@ -189,15 +189,14 @@ public Basic( this.bufferCount = bufferCount; this.diffusionCount = diffusionCount; - this.dimMatrix = precisionType.getMatrixLength(dimTrait); - this.dimPartialForTrait = dimTrait + dimMatrix; + this.dimPartialForTrait = precisionType.getPartialsDimension(dimTrait); this.dimPartial = numTraits * dimPartialForTrait; if (DEBUG) { System.err.println("numTraits: " + numTraits); System.err.println("dimTrait: " + dimTrait); System.err.println("dimProcess: " + dimProcess); - System.err.println("dimMatrix: " + dimMatrix); +// System.err.println("dimMatrix: " + dimMatrix); System.err.println("dimPartialForTrait: " + dimPartialForTrait); System.err.println("dimPartial: " + dimPartial); } @@ -242,7 +241,7 @@ public void getPostOrderPartial(int bufferIndex, final double[] partial) { @Override public double getBranchLength(int bufferIndex) { - return branchLengths[bufferIndex * dimMatrix]; + return branchLengths[bufferIndex]; } @Override @@ -598,8 +597,7 @@ private void updateBranchLengthsAndDet(int precisionIndex, final int[] probabili System.err.println("\t" + probabilityIndices[up] + " <- " + edgeLengths[up]); } - // TODO Currently only written for SCALAR model - branchLengths[dimMatrix * probabilityIndices[up]] = edgeLengths[up]; // TODO Remove dimMatrix + branchLengths[probabilityIndices[up]] = edgeLengths[up]; } updatePrecisionOffsetAndDeterminant(precisionIndex); @@ -667,8 +665,8 @@ public void updatePreOrderPartial( int jbo = dimPartial * jBuffer; // Determine matrix offsets - final int imo = dimMatrix * iMatrix; - final int jmo = dimMatrix * jMatrix; + final int imo = iMatrix; //TODO: not sure why we need iMatrix & jMatrix to begin with? + final int jmo = jMatrix; // Read variance increments along descendant branches of k final double vi = branchLengths[imo]; @@ -809,8 +807,8 @@ protected void updatePartial( int jbo = dimPartial * jBuffer; // Determine matrix offsets - final int imo = dimMatrix * iMatrix; - final int jmo = dimMatrix * jMatrix; + final int imo = iMatrix; //TODO: just use iMatrix * jMatrix? (also, why do we need these?) + final int jmo = jMatrix; // Read variance increments along descendant branches of k final double vi = branchLengths[imo]; @@ -996,7 +994,7 @@ private static void updateMean(final double[] partials, private void allocateStorage() { partials = new double[dimPartial * bufferCount]; - branchLengths = new double[dimMatrix * bufferCount]; // TODO Should be just bufferCount + branchLengths = new double[bufferCount]; // variances = new double[dimMatrix * bufferCount]; // TODO Should be dimTrait * dimTrait remainders = new double[numTraits * bufferCount]; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java index af93ec1e3a..65c4d002ee 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java @@ -152,8 +152,8 @@ public void updatePreOrderPartial( int jbo = dimPartial * jBuffer; // Determine matrix offsets - final int imo = dimMatrix * iMatrix; - final int jmo = dimMatrix * jMatrix; + final int imo = iMatrix; + final int jmo = jMatrix; // Read variance increments along descendant branches of k final double vi = branchLengths[imo]; @@ -284,8 +284,8 @@ protected void updatePartial( int jbo = dimPartial * jBuffer; // Determine matrix offsets - final int imo = dimMatrix * iMatrix; - final int jmo = dimMatrix * jMatrix; + final int imo = iMatrix; + final int jmo = jMatrix; // Read variance increments along descendant branches of k final double vi = branchLengths[imo]; diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java index a9c303e9d6..cdc24106af 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java @@ -159,7 +159,7 @@ public double[] getScaledPrecision(double[] partial, int offset, double[] diffus FULL("full precision matrix per branch", "full", 2) { // partial structure: - // [mean (p), precision (p^2), variance (p^2), fullPrecision (1), effective dimension (1), determinant (1)] + // [mean (p), precision (p^2), variance (p^2), fullPrecision (1), effective dimension (1), determinant (1), remainder (1)] @Override public void fillPrecisionInPartials(double[] partial, int offset, int index, double precision, @@ -195,19 +195,14 @@ public void copyObservation(double[] partial, int pOffset, double[] data, int dO } } - @Override - public int getMatrixLength(int dimTrait) { - return 2 * super.getMatrixLength(dimTrait) + 3; - } - @Override public int getPrecisionLength(int dimTrait) { - return super.getMatrixLength(dimTrait); + return super.getSingleMatrixLength(dimTrait); } @Override public int getVarianceLength(int dimTrait) { - return super.getMatrixLength(dimTrait); + return super.getSingleMatrixLength(dimTrait); } @Override @@ -250,6 +245,11 @@ public boolean hasEffectiveDimension() { public boolean hasDeterminant() { return true; } + + @Override + public int getPartialsDimension(int dimTrait) { + return dimTrait + 2 * getSingleMatrixLength(dimTrait) + 4; + } }; private final int power; @@ -270,7 +270,7 @@ public int getPower() { return power; } - public int getMatrixLength(int dimTrait) { + public int getSingleMatrixLength(int dimTrait) { int length = 1; final int pow = getPower(); for (int i = 0; i < pow; ++i) { @@ -280,7 +280,7 @@ public int getMatrixLength(int dimTrait) { } public int getPrecisionLength(int dimTrait) { - return getMatrixLength(dimTrait); + return getSingleMatrixLength(dimTrait); } public int getVarianceLength(int dimTrait) { @@ -334,7 +334,7 @@ public int getDeterminantOffset(int dimTrait) { abstract public double[] getScaledPrecision(double[] partial, int offset, double[] diffusionPrecision, int dimTrait); public int getPartialsDimension(int dimTrait) { - return this.getMatrixLength(dimTrait) + dimTrait; + return this.getSingleMatrixLength(dimTrait) + dimTrait; } public boolean hasEffectiveDimension() { diff --git a/src/dr/evomodel/treedatalikelihood/preorder/AbstractFullConditionalDistributionDelegate.java b/src/dr/evomodel/treedatalikelihood/preorder/AbstractFullConditionalDistributionDelegate.java index ed1aeb1b89..da983cb513 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/AbstractFullConditionalDistributionDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/AbstractFullConditionalDistributionDelegate.java @@ -27,7 +27,7 @@ public abstract class AbstractFullConditionalDistributionDelegate super(name, tree, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate); this.likelihoodDelegate = likelihoodDelegate; this.cdi = likelihoodDelegate.getIntegrator(); - this.dimPartial = dimTrait + likelihoodDelegate.getPrecisionType().getMatrixLength(dimTrait); + this.dimPartial = likelihoodDelegate.getPrecisionType().getPartialsDimension(dimTrait); this.partialNodeBuffer = new double[numTraits * dimPartial]; this.partialRootBuffer = new double[numTraits * dimPartial]; } diff --git a/src/dr/evomodel/treedatalikelihood/preorder/ConditionalOnTipsRealizedDelegate.java b/src/dr/evomodel/treedatalikelihood/preorder/ConditionalOnTipsRealizedDelegate.java index c62dc1eb66..92a30d5c0c 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/ConditionalOnTipsRealizedDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/ConditionalOnTipsRealizedDelegate.java @@ -36,7 +36,7 @@ public ConditionalOnTipsRealizedDelegate(String name, this.likelihoodDelegate = likelihoodDelegate; this.cdi = likelihoodDelegate.getIntegrator(); - this.dimPartial = dimTrait + likelihoodDelegate.getPrecisionType().getMatrixLength(dimTrait); + this.dimPartial = likelihoodDelegate.getPrecisionType().getPartialsDimension(dimTrait); partialNodeBuffer = new double[numTraits * dimPartial]; partialPriorBuffer = new double[numTraits * dimPartial]; diff --git a/src/dr/evomodel/treedatalikelihood/preorder/NormalSufficientStatistics.java b/src/dr/evomodel/treedatalikelihood/preorder/NormalSufficientStatistics.java index c470a85ef7..f75223c27b 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/NormalSufficientStatistics.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/NormalSufficientStatistics.java @@ -22,7 +22,7 @@ public class NormalSufficientStatistics { DenseMatrix64F Pd, PrecisionType precisionType) { - int partialOffset = (dim + precisionType.getMatrixLength(dim)) * index; + int partialOffset = (precisionType.getPartialsDimension(dim)) * index; this.mean = MissingOps.wrap(buffer, partialOffset, dim, 1); this.precision = DenseMatrix64F.wrap(dim, dim, precisionType.getScaledPrecision(buffer, partialOffset, Pd.data, dim)); diff --git a/src/dr/evomodel/treedatalikelihood/preorder/WrappedNormalSufficientStatistics.java b/src/dr/evomodel/treedatalikelihood/preorder/WrappedNormalSufficientStatistics.java index b9303db9bb..797af3d162 100644 --- a/src/dr/evomodel/treedatalikelihood/preorder/WrappedNormalSufficientStatistics.java +++ b/src/dr/evomodel/treedatalikelihood/preorder/WrappedNormalSufficientStatistics.java @@ -32,7 +32,7 @@ public WrappedNormalSufficientStatistics(double[] buffer, DenseMatrix64F Pd, PrecisionType precisionType) { - int partialOffset = (dim + precisionType.getMatrixLength(dim)) * index; + int partialOffset = (precisionType.getPartialsDimension(dim)) * index; this.mean = new WrappedVector.Raw(buffer, partialOffset, dim); if (precisionType == PrecisionType.SCALAR) { this.precision = new WrappedMatrix.Raw(Pd.getData(), 0, dim, dim); diff --git a/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java b/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java index 9fab5081b9..582bb9ba46 100644 --- a/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java @@ -945,7 +945,7 @@ private void testCMeans(TreeDataLikelihood dataLikelihood, String name, double[] format.format(partials[offset + i]), format.format(vector[i])); } - offset += dimTrait + PrecisionType.FULL.getMatrixLength(dimTrait); + offset += PrecisionType.FULL.getPartialsDimension(dimTrait); } } @@ -964,7 +964,7 @@ private void testCVariances(TreeDataLikelihood dataLikelihood, String name, doub format.format(partials[offset + dimTrait + dimTrait * dimTrait + i]), format.format(vector[i])); } - offset += dimTrait + PrecisionType.FULL.getMatrixLength(dimTrait); + offset += PrecisionType.FULL.getPartialsDimension(dimTrait); } } From b2846c3e0742f1727c077ebad42e1e8baa5cf796 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 17:34:53 -0700 Subject: [PATCH 130/196] reverting old method name but making it private --- .../continuous/cdi/PrecisionType.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java index cdc24106af..0a8245f065 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java @@ -197,12 +197,12 @@ public void copyObservation(double[] partial, int pOffset, double[] data, int dO @Override public int getPrecisionLength(int dimTrait) { - return super.getSingleMatrixLength(dimTrait); + return super.getMatrixLength(dimTrait); } @Override public int getVarianceLength(int dimTrait) { - return super.getSingleMatrixLength(dimTrait); + return super.getMatrixLength(dimTrait); } @Override @@ -248,7 +248,7 @@ public boolean hasDeterminant() { @Override public int getPartialsDimension(int dimTrait) { - return dimTrait + 2 * getSingleMatrixLength(dimTrait) + 4; + return dimTrait + 2 * getMatrixLength(dimTrait) + 4; } }; @@ -270,7 +270,7 @@ public int getPower() { return power; } - public int getSingleMatrixLength(int dimTrait) { + protected int getMatrixLength(int dimTrait) { int length = 1; final int pow = getPower(); for (int i = 0; i < pow; ++i) { @@ -280,7 +280,7 @@ public int getSingleMatrixLength(int dimTrait) { } public int getPrecisionLength(int dimTrait) { - return getSingleMatrixLength(dimTrait); + return getMatrixLength(dimTrait); } public int getVarianceLength(int dimTrait) { @@ -334,7 +334,7 @@ public int getDeterminantOffset(int dimTrait) { abstract public double[] getScaledPrecision(double[] partial, int offset, double[] diffusionPrecision, int dimTrait); public int getPartialsDimension(int dimTrait) { - return this.getSingleMatrixLength(dimTrait) + dimTrait; + return this.getMatrixLength(dimTrait) + dimTrait; } public boolean hasEffectiveDimension() { From e984230a8e2abc433ae2a00371fe63a41758d6f9 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 14 Sep 2022 18:12:55 -0700 Subject: [PATCH 131/196] partials providers should not also be likelihoods, the likelihood should be calculated by traitDataLikelihood regardless of the extension --- .../IntegratedFactorAnalysisLikelihood.java | 3 +- .../continuous/JointPartialsProvider.java | 5 ++++ .../cdi/MultivariateIntegrator.java | 15 ++++++++-- .../continuous/cdi/PrecisionType.java | 28 +++++++++++++++++++ 4 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java index e1385bc8c4..ef903f1fc7 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/IntegratedFactorAnalysisLikelihood.java @@ -258,7 +258,7 @@ public double getLogLikelihood() { logLikelihood = calculateLogLikelihood(); likelihoodKnown = true; } - return logLikelihood; + return 0; } @Override @@ -704,6 +704,7 @@ private void computePartialAndRemainderForOneTaxon(int taxon, unwrap(precision, partials, partialsOffset + numFactors); //TODO: use PrecisionType.fillPrecisionInPartials() precisionType.fillEffDimInPartials(partials, partialsOffset, effDim, numFactors); precisionType.fillDeterminantInPartials(partials, partialsOffset, factorLogDeterminant, numFactors); + precisionType.fillRemainderInPartials(partials, partialsOffset, constant, numFactors); if (STORE_VARIANCE) { safeInvert2(precision, variance, true); diff --git a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java index dde03a299c..3bb6ad5c2f 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/JointPartialsProvider.java @@ -204,6 +204,7 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { int varOffset = precisionType.getVarianceOffset(traitDim); int effDimDim = precisionType.getEffectiveDimensionOffset(traitDim); int detDim = precisionType.getDeterminantOffset(traitDim); + int remDim = precisionType.getRemainderOffset(traitDim); WrappedMatrix.Indexed precWrap = wrapBlockDiagonalMatrix(partial, precOffset, 0, traitDim); //TODO: this only works for precisionType.FULL, make general WrappedMatrix.Indexed varWrap = wrapBlockDiagonalMatrix(partial, varOffset, 0, traitDim); //TODO: see above @@ -246,6 +247,10 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { partial[detDim] += subDet; } + + if (precisionType.hasRemainder()) { + partial[remDim] += subPartial[precisionType.getRemainderOffset(subDim)]; + } } if (!computeDeterminant) { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java index 65c4d002ee..b06e7f2681 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java @@ -78,6 +78,17 @@ private void allocateStorage() { matrix6 = new DenseMatrix64F(dimTrait, dimTrait); } + @Override + public void setPostOrderPartial(int bufferIndex, double[] partial) { + super.setPostOrderPartial(bufferIndex, partial); + + int remOffset = PrecisionType.FULL.getRemainderOffset(dimTrait); + for (int trait = 0; trait < numTraits; trait++) { + remainders[bufferIndex * numTraits + trait] = partial[dimPartialForTrait * trait + remOffset]; + } + + } + @Override public void setDiffusionPrecision(int precisionIndex, final double[] matrix, double logDeterminant) { super.setDiffusionPrecision(precisionIndex, matrix, logDeterminant); @@ -695,10 +706,10 @@ public void calculatePreOrderRoot(int priorBufferIndex, int rootNodeIndex, int p public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex, int precisionIndex, final double[] logLikelihoods, boolean incrementOuterProducts, boolean isIntegratedProcess) { - assert(logLikelihoods.length == numTraits); + assert (logLikelihoods.length == numTraits); assert (!incrementOuterProducts); - assert(!isIntegratedProcess); + assert (!isIntegratedProcess); if (DEBUG) { System.err.println("Root calculation for " + rootBufferIndex); diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java index 0a8245f065..dfab4f4b50 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/PrecisionType.java @@ -187,6 +187,12 @@ public void fillNoDeterminantInPartials(double[] partial, int offset, int dimTra fillDeterminantInPartials(partial, offset, Double.NaN, dimTrait); //TODO: is it bad to assume NaN is missing? } + @Override + public void fillRemainderInPartials(double[] partials, int offset, double remainder, int dimTrait) { + int remOffset = getRemainderOffset(dimTrait); + partials[offset + remOffset] = remainder; + } + @Override public void copyObservation(double[] partial, int pOffset, double[] data, int dOffset, int dimTrait) { for (int i = 0; i < dimTrait; ++i) { @@ -220,6 +226,11 @@ public int getDeterminantOffset(int dimTrait) { return dimTrait + dimTrait * dimTrait * 2 + 2; } + @Override + public int getRemainderOffset(int dimTrait) { + return dimTrait + dimTrait * dimTrait * 2 + 3; + } + @Override public int getVarianceOffset(int dimTrait) { return dimTrait + dimTrait * dimTrait; @@ -246,6 +257,11 @@ public boolean hasDeterminant() { return true; } + @Override + public boolean hasRemainder() { + return true; + } + @Override public int getPartialsDimension(int dimTrait) { return dimTrait + 2 * getMatrixLength(dimTrait) + 4; @@ -331,6 +347,14 @@ public int getDeterminantOffset(int dimTrait) { return -1; } + public int getRemainderOffset(int dimTrait) { + return -1; + } + + public void fillRemainderInPartials(double[] partials, int offset, double remainder, int dimTrait) { + throw new RuntimeException("precision type " + tag + " does not store remainders"); + } + abstract public double[] getScaledPrecision(double[] partial, int offset, double[] diffusionPrecision, int dimTrait); public int getPartialsDimension(int dimTrait) { @@ -345,6 +369,10 @@ public boolean hasDeterminant() { return false; } + public boolean hasRemainder() { + return false; + } + private static double[] scale(double[] in, double scalar) { double[] out = new double[in.length]; From 2f7eb64da8a7d07a7661c3281b7860a916419b94 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 16 Sep 2022 10:53:15 -0700 Subject: [PATCH 132/196] pretty sure error was just floating point error --- ci/TestXML/testScaledLoadingsGradient.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/TestXML/testScaledLoadingsGradient.xml b/ci/TestXML/testScaledLoadingsGradient.xml index b67c5c4982..17ba21a66c 100644 --- a/ci/TestXML/testScaledLoadingsGradient.xml +++ b/ci/TestXML/testScaledLoadingsGradient.xml @@ -176,7 +176,7 @@ - + From d7cc3536860103c9a6f0a27f806bd209d56b6557 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 16 Sep 2022 11:56:03 -0700 Subject: [PATCH 133/196] cleaning code --- .../matrixAlgebra/missingData/MissingOps.java | 40 ------------------- 1 file changed, 40 deletions(-) diff --git a/src/dr/math/matrixAlgebra/missingData/MissingOps.java b/src/dr/math/matrixAlgebra/missingData/MissingOps.java index fb14251e55..4559b90aed 100644 --- a/src/dr/math/matrixAlgebra/missingData/MissingOps.java +++ b/src/dr/math/matrixAlgebra/missingData/MissingOps.java @@ -496,47 +496,7 @@ public static InversionResult safeSolve(DenseMatrix64F A, DenseMatrix64F B, Dens return result; } -// public static InversionResult safeInvert(DenseMatrix64F source, DenseMatrix64F destination, boolean getDeterminant) { -// -// final int dim = source.getNumCols(); -// final int finiteCount = countFiniteNonZeroDiagonals(source); -// double logDet = 0; -// -// if (finiteCount == dim) { -// if (getDeterminant) { -// logDet = invertAndGetDeterminant(source, destination, true); -// } else { -//// CommonOps.invert(copyOfSource, result); -// symmPosDefInvert(source, destination); -// } -// return new InversionResult(FULLY_OBSERVED, dim, logDet, true); -// } else { -// if (finiteCount == 0) { -// Arrays.fill(destination.getData(), 0); -// return new InversionResult(NOT_OBSERVED, 0, 0); -// } else { -// final int[] finiteIndices = new int[finiteCount]; -// getFiniteNonZeroDiagonalIndices(source, finiteIndices); -// -// final DenseMatrix64F subSource = new DenseMatrix64F(finiteCount, finiteCount); -// gatherRowsAndColumns(source, subSource, finiteIndices, finiteIndices); -// -// final DenseMatrix64F inverseSubSource = new DenseMatrix64F(finiteCount, finiteCount); -// if (getDeterminant) { -// logDet = invertAndGetDeterminant(subSource, inverseSubSource, true); -// } else { -// CommonOps.invert(subSource, inverseSubSource); -// } -// -// scatterRowsAndColumns(inverseSubSource, destination, finiteIndices, finiteIndices, true); -// -// return new InversionResult(PARTIALLY_OBSERVED, finiteCount, logDet, true); -// } -// } -// } - //TODO: Just have one safeInvert function after checking to make sure it doesn't break anything - // TODO: change all inversion to return logDeterminant public static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F destination, boolean getLogDeterminant) { final int dim = source.getNumCols(); From 88f48caa8a679eb83ef135da0219da805cb9783d Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 16 Sep 2022 15:12:10 -0700 Subject: [PATCH 134/196] safeInvert now know whether it's dealing with a precision or variance matrix (inverse is the same, but effective dimension is different --- .../matrixAlgebra/missingData/MissingOps.java | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/src/dr/math/matrixAlgebra/missingData/MissingOps.java b/src/dr/math/matrixAlgebra/missingData/MissingOps.java index 4559b90aed..3d45b494b9 100644 --- a/src/dr/math/matrixAlgebra/missingData/MissingOps.java +++ b/src/dr/math/matrixAlgebra/missingData/MissingOps.java @@ -496,8 +496,24 @@ public static InversionResult safeSolve(DenseMatrix64F A, DenseMatrix64F B, Dens return result; } + public static InversionResult safeInvertPrecision(DenseMatrix64F source, DenseMatrix64F destination, + boolean getLogDeterminant) { + return safeInvert2(source, destination, getLogDeterminant, false); + } + + public static InversionResult safeInvertVariance(DenseMatrix64F source, DenseMatrix64F destination, + boolean getLogDeterminant) { + return safeInvert2(source, destination, getLogDeterminant, true); + } + + public static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F destination, + boolean getLogDeterminant) { + return safeInvert2(source, destination, getLogDeterminant, true); + } - public static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F destination, boolean getLogDeterminant) { + private static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F destination, + boolean getLogDeterminant, + boolean isInputVariance) { final int dim = source.getNumCols(); final PermutationIndices permutationIndices = new PermutationIndices(source); @@ -521,7 +537,9 @@ public static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F if (infCount == dim) { //All infinity on diagonals of original matrix - return new InversionResult(NOT_OBSERVED, 0, Double.NEGATIVE_INFINITY, true); + return isInputVariance ? + new InversionResult(NOT_OBSERVED, 0, Double.NEGATIVE_INFINITY, true) : + new InversionResult(FULLY_OBSERVED, dim, Double.POSITIVE_INFINITY, true); } else { @@ -531,19 +549,19 @@ public static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F for (int i = 0; i < dim; i++) { destination.set(i, i, Double.POSITIVE_INFINITY); } - return new InversionResult(FULLY_OBSERVED, dim, Double.POSITIVE_INFINITY, true); + return isInputVariance ? + new InversionResult(FULLY_OBSERVED, dim, Double.POSITIVE_INFINITY, true) : + new InversionResult(NOT_OBSERVED, 0, Double.NEGATIVE_INFINITY, true); } else { //Both zeros and infinities (but no non-zero finite entries) on diagonal int[] zeroInds = permutationIndices.getZeroIndices(); - int[] infInds = permutationIndices.getInfiniteIndices(); for (int i : zeroInds) { destination.set(i, i, Double.POSITIVE_INFINITY); } - //TODO: not sure what to do here with regard to dimension (it could be zeroCount or infCount - //TODO: depending on whether this is a variance or precision matrix respectively. - System.err.println("Warning: safeInvert2 in MissingOps is not designed to invert matrices " + - "with both zero and infinite diagonal entries."); - return new InversionResult(PARTIALLY_OBSERVED, zeroCount, Double.POSITIVE_INFINITY, true); + + return isInputVariance ? + new InversionResult(PARTIALLY_OBSERVED, zeroCount, Double.POSITIVE_INFINITY, true) : + new InversionResult(PARTIALLY_OBSERVED, infCount, Double.POSITIVE_INFINITY, true); } } @@ -570,7 +588,12 @@ public static InversionResult safeInvert2(DenseMatrix64F source, DenseMatrix64F destination.set(index, index, Double.POSITIVE_INFINITY); } - return new InversionResult(PARTIALLY_OBSERVED, finiteNonZeroCount, logDet, true); + int fullyObsCount = isInputVariance ? permutationIndices.getNumberOfZeroDiagonals() : + permutationIndices.getNumberOfInfiniteDiagonals(); + logDet = fullyObsCount == 0 ? logDet : Double.POSITIVE_INFINITY; + + + return new InversionResult(PARTIALLY_OBSERVED, fullyObsCount + finiteNonZeroCount, logDet, true); } } } From d1c82e50f508e453ff0efff955d5cee2f620eba8 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Fri, 16 Sep 2022 15:32:52 -0700 Subject: [PATCH 135/196] bug fixes --- .../continuous/cdi/MultivariateIntegrator.java | 2 +- .../continuous/cdi/SafeMultivariateIntegrator.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java index b06e7f2681..c30afa8c46 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/MultivariateIntegrator.java @@ -385,7 +385,7 @@ protected void updatePartial( // final DenseMatrix64F Vk = new DenseMatrix64F(dimTrait, dimTrait); final DenseMatrix64F Vk = matrix5; //TODO: should saveInvert put an infinity on the diagonal of Vk? - InversionResult ck = safeInvert2(Pk, Vk, true); + InversionResult ck = safeInvertPrecision(Pk, Vk, true); // B. Partial mean // for (int g = 0; g < dimTrait; ++g) { diff --git a/src/dr/evomodel/treedatalikelihood/continuous/cdi/SafeMultivariateIntegrator.java b/src/dr/evomodel/treedatalikelihood/continuous/cdi/SafeMultivariateIntegrator.java index 44c8a31811..e51299b031 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/cdi/SafeMultivariateIntegrator.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/cdi/SafeMultivariateIntegrator.java @@ -226,7 +226,7 @@ public void updatePreOrderPartial( CommonOps.add(Pk, Pjp, Pip); final DenseMatrix64F Vip = matrix1; - safeInvert2(Pip, Vip, false); + safeInvertPrecision(Pip, Vip, false); final double[] delta = vectorDelta; computeDelta(jbo, jdo, delta); @@ -517,7 +517,7 @@ private InversionResult increaseVariances(int ibo, final DenseMatrix64F tmp1 = matrix0; CommonOps.add(Pi, Pdi, tmp1); final DenseMatrix64F tmp2 = matrix1; - safeInvert2(tmp1, tmp2, false); + safeInvertPrecision(tmp1, tmp2, false); CommonOps.mult(tmp2, Pi, tmp1); idMinusA(tmp1); if (getDeterminant) ci = safeDeterminant(tmp1, true); From cf163cc0564b10afeecb0cea9e02b0d5c030f6ef Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 27 Sep 2022 10:49:46 -0700 Subject: [PATCH 136/196] repeated measures model computes remainder when > 1 measurements per tip --- .../RepeatedMeasuresTraitDataModel.java | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index c3b5a37a0f..89b6992daf 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -35,6 +35,7 @@ import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities; import dr.inference.model.*; import dr.math.matrixAlgebra.*; +import dr.math.matrixAlgebra.missingData.InversionResult; import dr.math.matrixAlgebra.missingData.MissingOps; import dr.xml.*; import org.ejml.data.DenseMatrix64F; @@ -73,6 +74,8 @@ public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel imp private ArrayList[] relevantRepeats; private final int nObservedTips; + private final static double LOG2PI = Math.log(Math.PI * 2); + public RepeatedMeasuresTraitDataModel(String name, ContinuousTraitPartialsProvider childModel, @@ -201,6 +204,8 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { DenseMatrix64F Pm = new DenseMatrix64F(dimTrait, 1); DenseMatrix64F m = new DenseMatrix64F(dimTrait, 1); + double remainder = 0; + for (int i : relevantRepeats[taxonIndex]) { @@ -217,27 +222,55 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { } } - MissingOps.safeInvert2(Vi, Pi, false); + InversionResult result = MissingOps.safeInvert2(Vi, Pi, true); + CommonOps.addEquals(P, Pi); // System.arraycopy(partial, meanOffset, mi.data, 0, dimTrait); + double sumSquares = 0; + for (int row = 0; row < dimTrait; row++) { + + int offset = offsetInc * i + meanOffset; + double mr = partial[offset + row]; + double value = 0; for (int col = 0; col < dimTrait; col++) { - value += Pi.get(row, col) * partial[offsetInc * i + meanOffset + col]; + double mc = partial[offset + col]; + double x = Pi.get(row, col) * mc; + value += x; + sumSquares += x * mr; } Pm.add(row, 0, value); } + + remainder -= result.getEffectiveDimension() * LOG2PI + sumSquares + result.getLogDeterminant(); + } + MissingOps.safeSolve(P, Pm, m, false); - MissingOps.safeInvert2(P, V, false); //TODO: don't invert twice + InversionResult result = MissingOps.safeInvertPrecision(P, V, true); //TODO: don't invert twice + if (result.getReturnCode() == InversionResult.Code.NOT_OBSERVED) { + remainder = 0; + } else { + double sumSquares = 0; + for (int row = 0; row < dimTrait; row++) { + for (int col = 0; col < dimTrait; col++) { + sumSquares += m.get(row, 0) * m.get(col, 0) * P.get(row, col); + } + } + + remainder += result.getEffectiveDimension() * LOG2PI + sumSquares - result.getLogDeterminant(); + } + partial = new double[offsetInc]; System.arraycopy(m.data, 0, partial, precisionType.getMeanOffset(dimTrait), dimTrait); System.arraycopy(P.data, 0, partial, precisionType.getPrecisionOffset(dimTrait), varDim); System.arraycopy(V.data, 0, partial, precisionType.getVarianceOffset(dimTrait), varDim); + precisionType.fillRemainderInPartials(partial, 0, 0.5 * remainder, dimTrait); return partial; } From b4ff3ed2f97489773c5fec84e5d7871710bcaa66 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 27 Sep 2022 14:40:35 -0700 Subject: [PATCH 137/196] test xml --- ci/TestXML/testActualRepeatedMeasures.xml | 215 ++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 ci/TestXML/testActualRepeatedMeasures.xml diff --git a/ci/TestXML/testActualRepeatedMeasures.xml b/ci/TestXML/testActualRepeatedMeasures.xml new file mode 100644 index 0000000000..f481b9cb51 --- /dev/null +++ b/ci/TestXML/testActualRepeatedMeasures.xml @@ -0,0 +1,215 @@ + + + + + 0.6208852031301316 -1.2991371493463548 -2.6910514790332054 1.109417180690863 + 1.099166772266214 NA + + + + -0.7057152253938193 NA NA NA NA NA + + + 0.8225821150789747 0.2138202366939677 -1.5229118225528515 -0.8078400648319927 + -1.1594623872491492 0.17048275938533758 + + + + 1.2335230762207288 0.5070164014346852 -0.8270985241304889 NA NA NA + + + NA NA NA NA NA NA + + + + + (taxon1:0.034093632223924954,((taxon2:1.5104298872950768,(taxon3:0.39480816393853035,taxon4:2.492954320135449):2.920160090709784):0.7911157035133156,taxon5:3.8691942903231844):0.9572420218437911); + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Check log likelihood of observed data + + + + + + + + -42.01874315757184 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From db626169e1f56c1a1ff7c4ea867d456db3b9b336 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 28 Sep 2022 10:38:06 -0700 Subject: [PATCH 138/196] remainders should be propogated up --- .../continuous/RepeatedMeasuresTraitDataModel.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java index 89b6992daf..6d1211a54f 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.java @@ -196,6 +196,7 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { int varOffset = precisionType.getVarianceOffset(dimTrait); int meanOffset = precisionType.getMeanOffset(dimTrait); int varDim = precisionType.getVarianceLength(dimTrait); + int remOffset = precisionType.getRemainderOffset(dimTrait); DenseMatrix64F Pi = new DenseMatrix64F(dimTrait, dimTrait); DenseMatrix64F Vi = new DenseMatrix64F(dimTrait, dimTrait); @@ -244,6 +245,8 @@ public double[] getTipPartial(int taxonIndex, boolean fullyObserved) { Pm.add(row, 0, value); } + remainder += partial[offsetInc * i + remOffset]; + remainder -= result.getEffectiveDimension() * LOG2PI + sumSquares + result.getLogDeterminant(); } From 100d1d4991ddec5bfc710bc6b2621c2149919f40 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sat, 22 Apr 2023 17:13:35 +0200 Subject: [PATCH 139/196] Implemented Felsenstein's method for picking discrete category rates for discrete gamma distribution using GeneralisedGaussLaguerreQuadrature. --- src/dr/app/bss/PartitionData.java | 4 +- .../app/tools/AncestralSequenceAnnotator.java | 2 +- .../siteratemodel/GammaSiteRateModel.java | 279 +++++++++++------- .../DataLikelihoodTester.java | 4 +- .../DataLikelihoodTester2.java | 4 +- .../siteratemodel/GammaSiteModelParser.java | 48 ++- .../GeneralisedGaussLaguerreQuadrature.java | 7 + ...ncestralStateBeagleTreeLikelihoodTest.java | 2 +- src/test/dr/app/beagle/MarkovJumpsTest.java | 2 +- src/test/dr/app/beagle/TinyTest.java | 2 +- .../CompleteHistorySimulatorTest.java | 4 +- .../ProductChainSubstitutionModelTest.java | 6 +- ...bleBranchCompleteHistorySimulatorTest.java | 2 +- 13 files changed, 223 insertions(+), 143 deletions(-) diff --git a/src/dr/app/bss/PartitionData.java b/src/dr/app/bss/PartitionData.java index 655307cad1..12f6a3ae5f 100644 --- a/src/dr/app/bss/PartitionData.java +++ b/src/dr/app/bss/PartitionData.java @@ -1079,7 +1079,9 @@ public GammaSiteRateModel createSiteRateModel() { siteModel = new GammaSiteRateModel(name, siteRateModelParameterValues[1], - (int) siteRateModelParameterValues[0], siteRateModelParameterValues[2]); + (int) siteRateModelParameterValues[0], + GammaSiteRateModel.DiscretizationType.EVEN, + siteRateModelParameterValues[2]); } else { diff --git a/src/dr/app/tools/AncestralSequenceAnnotator.java b/src/dr/app/tools/AncestralSequenceAnnotator.java index 6dcc849f9d..d5488dfdba 100644 --- a/src/dr/app/tools/AncestralSequenceAnnotator.java +++ b/src/dr/app/tools/AncestralSequenceAnnotator.java @@ -808,7 +808,7 @@ else if(siteRatesModels.indexOf("+GAMMA(") >= 0) { /* For BEAST output */ //System.out.println("alpha and pinv parameters: " + alphaParameter.getParameterValue(0) + "\t" + pInvParameter.getParameterValue(0)); //GammaSiteRateModel siteModel = new GammaSiteRateModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), alphaParameter, categories, pInvParameter); - GammaSiteRateModel siteModel = new GammaSiteRateModel(GammaSiteModelParser.SITE_MODEL, new Parameter.Default(1.0), alphaParameter, categories, pInvParameter); + GammaSiteRateModel siteModel = new GammaSiteRateModel(GammaSiteModelParser.SITE_MODEL, new Parameter.Default(1.0), 1.0, alphaParameter, categories, GammaSiteRateModel.DiscretizationType.EVEN, pInvParameter); siteModel.setSubstitutionModel(sml.getSubstitutionModel()); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), new Parameter.Default(1.0), 1, new Parameter.Default(0.5)); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), null, null, 0, null); diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 9cbcc42d27..faa86bdf4f 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -26,12 +26,16 @@ package dr.evomodel.siteratemodel; import dr.inference.model.*; +import dr.math.GammaFunction; +import dr.math.GeneralisedGaussLaguerreQuadrature; +import dr.math.UnivariateFunction; import dr.math.distributions.GammaDistribution; import dr.evomodel.substmodel.SubstitutionModel; import dr.util.Author; import dr.util.Citable; import dr.util.Citation; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -43,70 +47,64 @@ */ public class GammaSiteRateModel extends AbstractModel implements SiteRateModel, Citable { - public enum CategoryWidthType { - FASTEST, - GEOMETRIC - }; + + public enum DiscretizationType { + EVEN, + QUADRATURE + } + + ; public GammaSiteRateModel(String name) { - this( name, + this(name, null, 1.0, null, - 0, - null, null, null); + 0, DiscretizationType.EVEN, + null); } - public GammaSiteRateModel(String name, double alpha, int categoryCount) { - this( name, + public GammaSiteRateModel(String name, double alpha, int categoryCount, DiscretizationType discretizationType) { + this(name, null, 1.0, new Parameter.Default(alpha), categoryCount, - null, null, null); + discretizationType, + null); } - public GammaSiteRateModel(String name, double alpha, int categoryCount, double pInvar) { - this( name, + public GammaSiteRateModel(String name, double alpha, int categoryCount, DiscretizationType discretizationType, double pInvar) { + this(name, null, 1.0, new Parameter.Default(alpha), categoryCount, - new Parameter.Default(pInvar), - null, null); + discretizationType, + new Parameter.Default(pInvar)); } - public GammaSiteRateModel(String name, double alpha, int categoryCount, double pInvar, double catWidth, CategoryWidthType categoryWidthType) { - this( name, - null, + public GammaSiteRateModel(String name, Parameter nu) { + this(name, + nu, 1.0, - new Parameter.Default(alpha), - categoryCount, - new Parameter.Default(pInvar), - new Parameter.Default(catWidth), - categoryWidthType); - } - - public GammaSiteRateModel( - String name, - Parameter nuParameter, - Parameter shapeParameter, int gammaCategoryCount, - Parameter invarParameter) { - this(name, nuParameter, 1.0, shapeParameter, gammaCategoryCount, invarParameter, null, null); + null, + -1, + null, + null); } - /** - * Constructor for gamma+invar distributed sites. Either shapeParameter or - * invarParameter (or both) can be null to turn off that feature. - */ + /** + * Constructor for gamma+invar distributed sites. Either shapeParameter or + * invarParameter (or both) can be null to turn off that feature. + */ public GammaSiteRateModel( String name, Parameter nuParameter, double muWeight, Parameter shapeParameter, int gammaCategoryCount, - Parameter invarParameter, - Parameter categoryWidthParameter, - CategoryWidthType categoryWidthType) { + DiscretizationType discretizationType, + Parameter invarParameter) { super(name); @@ -122,7 +120,6 @@ public GammaSiteRateModel( this.shapeParameter = shapeParameter; if (shapeParameter != null) { this.categoryCount = gammaCategoryCount; - addVariable(shapeParameter); // shapeParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 1E-3, 1)); // removing the bounds on the alpha parameter - to make the prior more explicit @@ -139,14 +136,7 @@ public GammaSiteRateModel( invarParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); } - this.categoryWidthParameter = categoryWidthParameter; - this.categoryWidthType = categoryWidthType; - if (categoryWidthParameter != null) { - this.categoryCount += 1; - - addVariable(categoryWidthParameter); - categoryWidthParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); - } + this.discretizationType = discretizationType; categoryRates = new double[this.categoryCount]; categoryProportions = new double[this.categoryCount]; @@ -165,7 +155,7 @@ public void setMu(double mu) { * @return mu */ public final double getMu() { - return nuParameter.getParameterValue(0) * muWeight; + return nuParameter.getParameterValue(0) * muWeight; } /** @@ -249,60 +239,35 @@ public double getProportionForCategory(int category) { */ private void calculateCategoryRates() { - double propVariable = 1.0; - int cat = 0; + int offset = 0; if (invarParameter != null) { categoryRates[0] = 0.0; categoryProportions[0] = invarParameter.getParameterValue(0); - - propVariable = 1.0 - categoryProportions[0]; - cat = 1; + offset = 1; } if (shapeParameter != null) { + double alpha = shapeParameter.getParameterValue(0); + final int gammaCatCount = categoryCount - offset; - final double a = shapeParameter.getParameterValue(0); - double mean = 0.0; - double sum = 0.0; - final int gammaCatCount = categoryCount - cat; - - for (int i = 0; i < gammaCatCount; i++) { - - categoryRates[i + cat] = GammaDistribution.quantile((2.0 * i + 1.0) / (2.0 * gammaCatCount), a, 1.0 / a); - -// if (categoryRates[i + cat] == 0.0) { -// throw new RuntimeException("Alpha parameter for discrete gamma distribution is too small and causing numerical errors."); -// } - - mean += categoryRates[i + cat]; - - if (categoryWidthParameter != null && categoryWidthType == CategoryWidthType.GEOMETRIC && i > 0) { - categoryProportions[i + cat] = categoryProportions[i + cat - 1] * (1.0 + categoryWidthParameter.getParameterValue(0)); - } else if (categoryWidthParameter != null && categoryWidthType == CategoryWidthType.FASTEST && - i == (gammaCatCount - 1)) { - categoryProportions[i + cat] = (1.0 + categoryWidthParameter.getParameterValue(0)); - } else { - categoryProportions[i + cat] = 1.0; - } - sum += categoryProportions[i + cat]; - } - - mean = (propVariable * mean) / gammaCatCount; - - for (int i = 0; i < gammaCatCount; i++) { - categoryRates[i + cat] /= mean; - categoryProportions[i + cat] /= sum; + if (discretizationType == DiscretizationType.QUADRATURE) { + setQuatratureRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); + } else { + setEqualRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); } + } else if (offset > 0) { + // just the invariant rate and variant rate + categoryRates[offset] = 2.0; + categoryProportions[offset] = 1.0 - categoryProportions[0]; } else { - categoryRates[cat] = 1.0 / propVariable; - categoryProportions[cat] = propVariable; + categoryRates[0] = 1.0; + categoryProportions[0] = 1.0; } - if (nuParameter != null) { // Moved multiplication by mu to here; it also - // needed by double[] getCategoryRates() -- previously ignored + if (nuParameter != null) { double mu = getMu(); - for (int i=0; i < categoryCount; i++) + for (int i = 0; i < categoryCount; i++) categoryRates[i] *= mu; } @@ -323,12 +288,10 @@ protected final void handleVariableChangedEvent(Variable variable, int index, Pa ratesKnown = false; } else if (variable == invarParameter) { ratesKnown = false; - } else if (variable == categoryWidthParameter) { - ratesKnown = false; } else if (variable == nuParameter) { ratesKnown = false; // MAS: I changed this because the rate parameter can affect the categories if the parameter is in siteModel and not clockModel } else { - throw new RuntimeException("Unknown variable in GammaSiteRateModel.handleVariableChangedEvent"); + throw new RuntimeException("Unknown variable in GammaSiteRateModel.handleVariableChangedEvent"); } listenerHelper.fireModelChanged(this, variable, index); } @@ -382,9 +345,7 @@ public double getStatisticValue(int dim) { */ private Parameter invarParameter; - private Parameter categoryWidthParameter; - - private CategoryWidthType categoryWidthType = null; + private DiscretizationType discretizationType; private boolean ratesKnown; @@ -395,7 +356,6 @@ public double getStatisticValue(int dim) { private double[] categoryProportions; - // This is here solely to allow the GammaSiteModelParser to pass on the substitution model to the // HomogenousBranchSubstitutionModel so that the XML will be compatible with older BEAST versions. To be removed // at some point. @@ -419,14 +379,17 @@ public String getDescription() { } public List getCitations() { + List citations = new ArrayList<>(); if (shapeParameter != null) { - return Collections.singletonList(CITATION); - } else { - return Collections.emptyList(); + citations.add(CITATION_YANG94); + if (discretizationType == DiscretizationType.QUADRATURE) { + citations.add(CITATION_FELSENSTEIN01); + } } + return citations; } - public final static Citation CITATION = new Citation( + public final static Citation CITATION_YANG94 = new Citation( new Author[]{ new Author("Z", "Yang") }, @@ -438,6 +401,124 @@ public List getCitations() { Citation.Status.PUBLISHED ); + public final static Citation CITATION_FELSENSTEIN01 = new Citation( + new Author[]{ + new Author("J", "Felsenstein ") + }, + "Taking Variation of Evolutionary Rates Between Sites into Account in Inferring Phylogenies", + 2001, + "J. Mol. Evol.", + 53, + 447, 455, + Citation.Status.PUBLISHED + ); + private SubstitutionModel substitutionModel; + private static GeneralisedGaussLaguerreQuadrature quadrature = null; + + public static void setQuatratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { + if (quadrature == null) { + quadrature = new GeneralisedGaussLaguerreQuadrature(catCount); + } + quadrature.setAlpha(alpha); + + double[] abscissae = quadrature.getAbscissae(); + double[] coefficients = quadrature.getCoefficients(); + + for (int i = 0; i < catCount; i++) { + categoryRates[i + offset] = abscissae[i] / (alpha + 1); + categoryProportions[i + offset] = coefficients[i]; + } + normalize(categoryRates, categoryProportions); + } + + public static void setEqualRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { + for (int i = 0; i < catCount; i++) { + categoryRates[i + offset] = GammaDistribution.quantile((2.0 * i + 1.0) / (2.0 * catCount), alpha, 1.0 / alpha); + } + + normalize(categoryRates, categoryProportions); + } + + public static void normalize(double[] categoryRates, double[] categoryProportions) { + double mean = 0.0; + double sum = 0.0; + for (int i = 0; i < categoryRates.length; i++) { + mean += categoryRates[i]; + sum += categoryProportions[i]; + } + mean /= categoryRates.length; + + for(int i = 0; i < categoryRates.length; i++) { + categoryRates[i] /= mean; + categoryProportions[i] /= sum; + } + } + + public static void main(String[] argv) { + final int catCount = 6; + + double[] categoryRates = new double[catCount]; + double[] categoryProportions = new double[catCount]; + + setEqualRates(categoryRates, categoryProportions, 1.0, catCount, 0); + System.out.println(); + System.out.println("Equal, alpha = 1.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setQuatratureRates(categoryRates, categoryProportions, 1.0, catCount, 0); + System.out.println(); + System.out.println("Quadrature, alpha = 1.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + // Table 3 from Felsenstein 2001, JME + // Rates and probabilities chosen by the quadrature method for six rates and coefficient of + // variation of rates among sites 1 (a 4 1) + // Probability Rate + // 0.278 0.264 + // 0.494 0.898 + // 0.203 1.938 + // 0.025 3.459 + // 0.00076 5.617 + // 0.000003 8.823 + + setEqualRates(categoryRates, categoryProportions, 0.1, catCount, 0); + System.out.println(); + System.out.println("Equal, alpha = 0.1"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setQuatratureRates(categoryRates, categoryProportions, 0.1, catCount, 0); + System.out.println(); + System.out.println("Quadrature, alpha = 0.1"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setEqualRates(categoryRates, categoryProportions, 10.0, catCount, 0); + System.out.println(); + System.out.println("Equal, alpha = 10.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setQuatratureRates(categoryRates, categoryProportions, 10.0, catCount, 0); + System.out.println(); + System.out.println("Quadrature, alpha = 10.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + } } \ No newline at end of file diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java index 1dc72c94de..43befa7b9c 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java @@ -87,7 +87,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); @@ -98,7 +98,7 @@ public static void main(String[] args) { HKY hky2 = new HKY(kappa2, f2) ; - GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4); + GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); siteRateModel2.setSubstitutionModel(hky2); siteRateModel2.setRelativeRateParameter(mu); diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java index 3625650d56..c3e4b54324 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java @@ -87,7 +87,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); @@ -97,7 +97,7 @@ public static void main(String[] args) { Parameter kappa2 = new Parameter.Default(HKYParser.KAPPA, 10.0, 0, 100); HKY hky2 = new HKY(kappa2, f2); - GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4); + GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); siteRateModel2.setSubstitutionModel(hky2); siteRateModel2.setRelativeRateParameter(mu); diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java index 721f9d564c..e603d64bff 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java @@ -25,6 +25,7 @@ package dr.evomodelxml.siteratemodel; +import java.util.Locale; import java.util.logging.Logger; import dr.evomodel.siteratemodel.GammaSiteRateModel; @@ -55,10 +56,9 @@ public class GammaSiteModelParser extends AbstractXMLObjectParser { public static final String GAMMA_SHAPE = "gammaShape"; public static final String GAMMA_CATEGORIES = "gammaCategories"; public static final String PROPORTION_INVARIANT = "proportionInvariant"; - public static final String CATEGORY_WIDTH = "categoryWidth"; - public static final String TYPE = "type"; - public static final String FASTEST = "fastest"; - public static final String GEOMETRIC = "geometric"; + public static final String DISCRETIZATION = "discretization"; + public static final String EVEN = "even"; + public static final String QUADRATURE = "quadrature"; public String getParserName() { return SITE_MODEL; @@ -90,11 +90,25 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } + GammaSiteRateModel.DiscretizationType type = GammaSiteRateModel.DiscretizationType.EVEN; + Parameter shapeParam = null; int catCount = 4; if (xo.hasChildNamed(GAMMA_SHAPE)) { XMLObject cxo = xo.getChild(GAMMA_SHAPE); catCount = cxo.getIntegerAttribute(GAMMA_CATEGORIES); + + try { + type = GammaSiteRateModel.DiscretizationType.valueOf( + cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); + if (type == GammaSiteRateModel.DiscretizationType.EVEN) { + msg += "\n even discretization of gamma distribution"; + } else { + msg += "\n quadrature discretization of gamma distribution"; + } + } catch (IllegalArgumentException eae) { + throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); + } shapeParam = (Parameter) cxo.getChild(Parameter.class); msg += "\n " + catCount + " category discrete gamma with initial shape = " + shapeParam.getParameterValue(0); @@ -106,30 +120,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { msg += "\n initial proportion of invariant sites = " + invarParam.getParameterValue(0); } - Parameter categoryWidthParameter = null; - GammaSiteRateModel.CategoryWidthType type = null; - if (xo.hasChildNamed(CATEGORY_WIDTH)) { - categoryWidthParameter = (Parameter) xo.getElementFirstChild(CATEGORY_WIDTH); - String typeString = xo.getChild(CATEGORY_WIDTH).getStringAttribute(TYPE); - try { - type = GammaSiteRateModel.CategoryWidthType.valueOf(typeString.toUpperCase()); - if (type == GammaSiteRateModel.CategoryWidthType.FASTEST) { - msg += "\n initial proportion of fastest sites = " + categoryWidthParameter.getParameterValue(0); - } else { - msg += "\n initial factor for increasing category width = " + categoryWidthParameter.getParameterValue(0); - } - } catch (IllegalArgumentException eae) { - throw new XMLParseException("Unknown category width type: " + typeString); - } - } - if (msg.length() > 0) { Logger.getLogger("dr.evomodel").info("\nCreating site rate model: " + msg); } else { Logger.getLogger("dr.evomodel").info("\nCreating site rate model."); } - GammaSiteRateModel siteRateModel = new GammaSiteRateModel(SITE_MODEL, muParam, muWeight, shapeParam, catCount, invarParam, categoryWidthParameter, type); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel(SITE_MODEL, muParam, muWeight, shapeParam, catCount, type, invarParam); if (xo.hasChildNamed(SUBSTITUTION_MODEL)) { @@ -183,18 +180,13 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(GAMMA_SHAPE, new XMLSyntaxRule[]{ AttributeRule.newIntegerRule(GAMMA_CATEGORIES, true), + AttributeRule.newStringRule(DISCRETIZATION, true), new ElementRule(Parameter.class) }, true), new ElementRule(PROPORTION_INVARIANT, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) - }, true), - - new ElementRule(CATEGORY_WIDTH, new XMLSyntaxRule[]{ - AttributeRule.newStringRule(TYPE, false), - new ElementRule(Parameter.class) }, true) - }; }//END: class diff --git a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java index 9cbfc830da..05f4a60e27 100644 --- a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java +++ b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java @@ -168,5 +168,12 @@ public double logIntegrate(UnivariateFunction f, double min){ } + public double[] getAbscissae() { + return abscissae; + } + + public double[] getCoefficients() { + return coefficients; + } } diff --git a/src/test/dr/app/beagle/AncestralStateBeagleTreeLikelihoodTest.java b/src/test/dr/app/beagle/AncestralStateBeagleTreeLikelihoodTest.java index c28f5b192e..2b41906509 100644 --- a/src/test/dr/app/beagle/AncestralStateBeagleTreeLikelihoodTest.java +++ b/src/test/dr/app/beagle/AncestralStateBeagleTreeLikelihoodTest.java @@ -83,7 +83,7 @@ public void testJointLikelihood() { FrequencyModel f = new FrequencyModel(Nucleotides.INSTANCE, freqs); HKY hky = new HKY(kappa, f); - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", mu, null, -1, null); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", mu); siteRateModel.setSubstitutionModel(hky); BranchModel branchModel = new HomogeneousBranchModel( diff --git a/src/test/dr/app/beagle/MarkovJumpsTest.java b/src/test/dr/app/beagle/MarkovJumpsTest.java index b293ae4580..d5a6f0d064 100644 --- a/src/test/dr/app/beagle/MarkovJumpsTest.java +++ b/src/test/dr/app/beagle/MarkovJumpsTest.java @@ -51,7 +51,7 @@ public void testMarkovJumps() { Parameter mu = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 0.5, 0, Double.POSITIVE_INFINITY); // Parameter pInv = new Parameter.Default("pInv", 0.5, 0, 1); Parameter pInv = null; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", mu, null, -1, pInv); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", mu, 1.0, null, -1, null, pInv); siteRateModel.setSubstitutionModel(hky); //treeLikelihood diff --git a/src/test/dr/app/beagle/TinyTest.java b/src/test/dr/app/beagle/TinyTest.java index abe2493464..91b59ec52f 100644 --- a/src/test/dr/app/beagle/TinyTest.java +++ b/src/test/dr/app/beagle/TinyTest.java @@ -38,7 +38,7 @@ public void testTiny() { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); diff --git a/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java b/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java index 1af233c0a7..34806ead49 100644 --- a/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java +++ b/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java @@ -51,7 +51,7 @@ public void testHKYSimulation() { Parameter mu = new Parameter.Default(1, 0.5); Parameter alpha = new Parameter.Default(1, 0.5); - GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, alpha, 4, null); + GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN, null); siteModel.setSubstitutionModel(hky); BranchRateModel branchRateModel = new DefaultBranchRateModel(); @@ -93,7 +93,7 @@ public void testCodonSimulation() { Parameter mu = new Parameter.Default(1, 0.5); Parameter alpha = new Parameter.Default(1, 0.5); - GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, alpha, 4, null); + GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN, null); siteModel.setSubstitutionModel(codonModel); BranchRateModel branchRateModel = new DefaultBranchRateModel(); diff --git a/src/test/dr/evomodel/substmodel/ProductChainSubstitutionModelTest.java b/src/test/dr/evomodel/substmodel/ProductChainSubstitutionModelTest.java index 648584b340..15f69e417b 100644 --- a/src/test/dr/evomodel/substmodel/ProductChainSubstitutionModelTest.java +++ b/src/test/dr/evomodel/substmodel/ProductChainSubstitutionModelTest.java @@ -78,12 +78,10 @@ private void setUpTwoStatesUnequalRate() { baseModels.add(substModel1); SiteRateModel rateModel0 = new GammaSiteRateModel("rate0", - new Parameter.Default(new double[]{0.5}), - null, -1, null); + new Parameter.Default(new double[]{0.5})); SiteRateModel rateModel1 = new GammaSiteRateModel("rate0", - new Parameter.Default(new double[]{2}), // Runs twice as fast - null, -1, null); + new Parameter.Default(new double[]{2})); // Runs twice as fast List rateModels = new ArrayList(); rateModels.add(rateModel0); diff --git a/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java b/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java index 4101538965..cad222adf2 100644 --- a/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java +++ b/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java @@ -33,7 +33,7 @@ public void testHKYVariableSimulation() { Parameter mu = new Parameter.Default(1, 0.5); Parameter alpha = new Parameter.Default(1, 0.5); - GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, alpha, 4, null); + GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN, null); siteModel.setSubstitutionModel(hky); BranchRateModel branchRateModel = new DefaultBranchRateModel(); From ee14561a51e08920f74aacae96bd01958b861bb0 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sat, 22 Apr 2023 17:15:45 +0200 Subject: [PATCH 140/196] Fixed equal weight discretization --- src/dr/evomodel/siteratemodel/GammaSiteRateModel.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index faa86bdf4f..2f9f88a04d 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -436,6 +436,7 @@ public static void setQuatratureRates(double[] categoryRates, double[] categoryP public static void setEqualRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { for (int i = 0; i < catCount; i++) { categoryRates[i + offset] = GammaDistribution.quantile((2.0 * i + 1.0) / (2.0 * catCount), alpha, 1.0 / alpha); + categoryProportions[i + offset] = 1.0; } normalize(categoryRates, categoryProportions); From 91e4bd0807281bbdd7f0fd3f2d3362da47f73aaa Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 24 Apr 2023 15:19:49 +0100 Subject: [PATCH 141/196] Updated quadrature rate heterogeneity --- .../siteratemodel/GammaSiteRateModel.java | 31 ++++++++++++++----- .../siteratemodel/GammaSiteModelParser.java | 2 -- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 2f9f88a04d..5b3c20a168 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -442,6 +442,11 @@ public static void setEqualRates(double[] categoryRates, double[] categoryPropor normalize(categoryRates, categoryProportions); } + /** + * Gives the category rates a mean of 1.0 and the proportions sum to 1.0 + * @param categoryRates + * @param categoryProportions + */ public static void normalize(double[] categoryRates, double[] categoryProportions) { double mean = 0.0; double sum = 0.0; @@ -452,7 +457,7 @@ public static void normalize(double[] categoryRates, double[] categoryProportion mean /= categoryRates.length; for(int i = 0; i < categoryRates.length; i++) { - categoryRates[i] /= mean; + //categoryRates[i] /= mean; categoryProportions[i] /= sum; } } @@ -482,13 +487,23 @@ public static void main(String[] argv) { // Table 3 from Felsenstein 2001, JME // Rates and probabilities chosen by the quadrature method for six rates and coefficient of // variation of rates among sites 1 (a 4 1) - // Probability Rate - // 0.278 0.264 - // 0.494 0.898 - // 0.203 1.938 - // 0.025 3.459 - // 0.00076 5.617 - // 0.000003 8.823 + // Rate Probability + // 0.264 0.278 + // 0.898 0.494 + // 1.938 0.203 + // 3.459 0.025 + // 5.617 0.00076 + // 8.823 0.000003 + + // Output (without setting rates to mean of 1) + // Quadrature, alpha = 1.0 + // cat rate proportion + // 0 0.26383406085556455 0.27765014202987454 + // 1 0.8981499048217043 0.49391058305035496 + // 2 1.938320760238456 0.20300429674372977 + // 3 3.459408283352361 0.02466882036918974 + // 4 5.617305214541558 7.6304276746353E-4 + // 5 8.822981776190357 3.1150393875275343E-6 setEqualRates(categoryRates, categoryProportions, 0.1, catCount, 0); System.out.println(); diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java index e603d64bff..09b31b9aa8 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java @@ -57,8 +57,6 @@ public class GammaSiteModelParser extends AbstractXMLObjectParser { public static final String GAMMA_CATEGORIES = "gammaCategories"; public static final String PROPORTION_INVARIANT = "proportionInvariant"; public static final String DISCRETIZATION = "discretization"; - public static final String EVEN = "even"; - public static final String QUADRATURE = "quadrature"; public String getParserName() { return SITE_MODEL; From a6d7044517fdf44d4fcc247fbe74ba1432ae827a Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 24 Apr 2023 15:29:51 +0100 Subject: [PATCH 142/196] Updating reporting --- src/dr/app/bss/PartitionData.java | 2 +- src/dr/app/tools/AncestralSequenceAnnotator.java | 2 +- .../evomodel/siteratemodel/GammaSiteRateModel.java | 7 ++----- .../treedatalikelihood/DataLikelihoodTester.java | 4 ++-- .../treedatalikelihood/DataLikelihoodTester2.java | 4 ++-- .../siteratemodel/GammaSiteModelParser.java | 13 ++++++------- src/test/dr/app/beagle/TinyTest.java | 2 +- .../substmodel/CompleteHistorySimulatorTest.java | 4 ++-- .../VariableBranchCompleteHistorySimulatorTest.java | 2 +- 9 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/dr/app/bss/PartitionData.java b/src/dr/app/bss/PartitionData.java index 12f6a3ae5f..4f68fa387a 100644 --- a/src/dr/app/bss/PartitionData.java +++ b/src/dr/app/bss/PartitionData.java @@ -1080,7 +1080,7 @@ public GammaSiteRateModel createSiteRateModel() { siteModel = new GammaSiteRateModel(name, siteRateModelParameterValues[1], (int) siteRateModelParameterValues[0], - GammaSiteRateModel.DiscretizationType.EVEN, + GammaSiteRateModel.DiscretizationType.EQUAL, siteRateModelParameterValues[2]); } else { diff --git a/src/dr/app/tools/AncestralSequenceAnnotator.java b/src/dr/app/tools/AncestralSequenceAnnotator.java index d5488dfdba..f0697b5366 100644 --- a/src/dr/app/tools/AncestralSequenceAnnotator.java +++ b/src/dr/app/tools/AncestralSequenceAnnotator.java @@ -808,7 +808,7 @@ else if(siteRatesModels.indexOf("+GAMMA(") >= 0) { /* For BEAST output */ //System.out.println("alpha and pinv parameters: " + alphaParameter.getParameterValue(0) + "\t" + pInvParameter.getParameterValue(0)); //GammaSiteRateModel siteModel = new GammaSiteRateModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), alphaParameter, categories, pInvParameter); - GammaSiteRateModel siteModel = new GammaSiteRateModel(GammaSiteModelParser.SITE_MODEL, new Parameter.Default(1.0), 1.0, alphaParameter, categories, GammaSiteRateModel.DiscretizationType.EVEN, pInvParameter); + GammaSiteRateModel siteModel = new GammaSiteRateModel(GammaSiteModelParser.SITE_MODEL, new Parameter.Default(1.0), 1.0, alphaParameter, categories, GammaSiteRateModel.DiscretizationType.EQUAL, pInvParameter); siteModel.setSubstitutionModel(sml.getSubstitutionModel()); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), new Parameter.Default(1.0), 1, new Parameter.Default(0.5)); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), null, null, 0, null); diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 5b3c20a168..82af24ec61 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -26,9 +26,7 @@ package dr.evomodel.siteratemodel; import dr.inference.model.*; -import dr.math.GammaFunction; import dr.math.GeneralisedGaussLaguerreQuadrature; -import dr.math.UnivariateFunction; import dr.math.distributions.GammaDistribution; import dr.evomodel.substmodel.SubstitutionModel; import dr.util.Author; @@ -36,7 +34,6 @@ import dr.util.Citation; import java.util.ArrayList; -import java.util.Collections; import java.util.List; /** @@ -49,7 +46,7 @@ public class GammaSiteRateModel extends AbstractModel implements SiteRateModel, Citable { public enum DiscretizationType { - EVEN, + EQUAL, QUADRATURE } @@ -60,7 +57,7 @@ public GammaSiteRateModel(String name) { null, 1.0, null, - 0, DiscretizationType.EVEN, + 0, DiscretizationType.EQUAL, null); } diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java index 43befa7b9c..990dad2c84 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java @@ -87,7 +87,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); @@ -98,7 +98,7 @@ public static void main(String[] args) { HKY hky2 = new HKY(kappa2, f2) ; - GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); + GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); siteRateModel2.setSubstitutionModel(hky2); siteRateModel2.setRelativeRateParameter(mu); diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java index c3e4b54324..0a72e4f8b2 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java @@ -87,7 +87,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); @@ -97,7 +97,7 @@ public static void main(String[] args) { Parameter kappa2 = new Parameter.Default(HKYParser.KAPPA, 10.0, 0, 100); HKY hky2 = new HKY(kappa2, f2); - GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); + GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); siteRateModel2.setSubstitutionModel(hky2); siteRateModel2.setRelativeRateParameter(mu); diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java index 09b31b9aa8..a9103bcee0 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java @@ -25,7 +25,6 @@ package dr.evomodelxml.siteratemodel; -import java.util.Locale; import java.util.logging.Logger; import dr.evomodel.siteratemodel.GammaSiteRateModel; @@ -88,7 +87,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } - GammaSiteRateModel.DiscretizationType type = GammaSiteRateModel.DiscretizationType.EVEN; + GammaSiteRateModel.DiscretizationType type = GammaSiteRateModel.DiscretizationType.EQUAL; Parameter shapeParam = null; int catCount = 4; @@ -99,17 +98,17 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { try { type = GammaSiteRateModel.DiscretizationType.valueOf( cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); - if (type == GammaSiteRateModel.DiscretizationType.EVEN) { - msg += "\n even discretization of gamma distribution"; - } else { - msg += "\n quadrature discretization of gamma distribution"; - } } catch (IllegalArgumentException eae) { throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); } shapeParam = (Parameter) cxo.getChild(Parameter.class); msg += "\n " + catCount + " category discrete gamma with initial shape = " + shapeParam.getParameterValue(0); + if (type == GammaSiteRateModel.DiscretizationType.EQUAL) { + msg += "\n using equal weight discretization of gamma distribution"; + } else { + msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution"; + } } Parameter invarParam = null; diff --git a/src/test/dr/app/beagle/TinyTest.java b/src/test/dr/app/beagle/TinyTest.java index 91b59ec52f..a733d0e896 100644 --- a/src/test/dr/app/beagle/TinyTest.java +++ b/src/test/dr/app/beagle/TinyTest.java @@ -38,7 +38,7 @@ public void testTiny() { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); diff --git a/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java b/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java index 34806ead49..f072a7ac31 100644 --- a/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java +++ b/src/test/dr/evomodel/substmodel/CompleteHistorySimulatorTest.java @@ -51,7 +51,7 @@ public void testHKYSimulation() { Parameter mu = new Parameter.Default(1, 0.5); Parameter alpha = new Parameter.Default(1, 0.5); - GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN, null); + GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL, null); siteModel.setSubstitutionModel(hky); BranchRateModel branchRateModel = new DefaultBranchRateModel(); @@ -93,7 +93,7 @@ public void testCodonSimulation() { Parameter mu = new Parameter.Default(1, 0.5); Parameter alpha = new Parameter.Default(1, 0.5); - GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN, null); + GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL, null); siteModel.setSubstitutionModel(codonModel); BranchRateModel branchRateModel = new DefaultBranchRateModel(); diff --git a/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java b/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java index cad222adf2..df049d341c 100644 --- a/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java +++ b/src/test/dr/evomodel/substmodel/VariableBranchCompleteHistorySimulatorTest.java @@ -33,7 +33,7 @@ public void testHKYVariableSimulation() { Parameter mu = new Parameter.Default(1, 0.5); Parameter alpha = new Parameter.Default(1, 0.5); - GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EVEN, null); + GammaSiteRateModel siteModel = new GammaSiteRateModel("gammaModel", mu, 1.0, alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL, null); siteModel.setSubstitutionModel(hky); BranchRateModel branchRateModel = new DefaultBranchRateModel(); From f75e82f8f5c2f96dfc49e7084017f789351fd932 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 24 Apr 2023 15:34:40 +0100 Subject: [PATCH 143/196] Updating reporting --- src/dr/app/bss/PartitionData.java | 1 - .../evomodel/siteratemodel/GammaSiteRateModel.java | 14 +++++++------- .../treedatalikelihood/DataLikelihoodTester.java | 4 ++-- .../treedatalikelihood/DataLikelihoodTester2.java | 4 ++-- src/test/dr/app/beagle/TinyTest.java | 2 +- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/dr/app/bss/PartitionData.java b/src/dr/app/bss/PartitionData.java index 4f68fa387a..6389ed55cc 100644 --- a/src/dr/app/bss/PartitionData.java +++ b/src/dr/app/bss/PartitionData.java @@ -1080,7 +1080,6 @@ public GammaSiteRateModel createSiteRateModel() { siteModel = new GammaSiteRateModel(name, siteRateModelParameterValues[1], (int) siteRateModelParameterValues[0], - GammaSiteRateModel.DiscretizationType.EQUAL, siteRateModelParameterValues[2]); } else { diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 82af24ec61..555932dced 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -45,12 +45,12 @@ public class GammaSiteRateModel extends AbstractModel implements SiteRateModel, Citable { + private static final DiscretizationType DEFAULT_DISCRETIZATION = DiscretizationType.EQUAL; + public enum DiscretizationType { EQUAL, QUADRATURE - } - - ; + }; public GammaSiteRateModel(String name) { this(name, @@ -61,23 +61,23 @@ public GammaSiteRateModel(String name) { null); } - public GammaSiteRateModel(String name, double alpha, int categoryCount, DiscretizationType discretizationType) { + public GammaSiteRateModel(String name, double alpha, int categoryCount) { this(name, null, 1.0, new Parameter.Default(alpha), categoryCount, - discretizationType, + DEFAULT_DISCRETIZATION, null); } - public GammaSiteRateModel(String name, double alpha, int categoryCount, DiscretizationType discretizationType, double pInvar) { + public GammaSiteRateModel(String name, double alpha, int categoryCount, double pInvar) { this(name, null, 1.0, new Parameter.Default(alpha), categoryCount, - discretizationType, + DEFAULT_DISCRETIZATION, new Parameter.Default(pInvar)); } diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java index 990dad2c84..1dc72c94de 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java @@ -87,7 +87,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); @@ -98,7 +98,7 @@ public static void main(String[] args) { HKY hky2 = new HKY(kappa2, f2) ; - GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); + GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4); siteRateModel2.setSubstitutionModel(hky2); siteRateModel2.setRelativeRateParameter(mu); diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java index 0a72e4f8b2..3625650d56 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java @@ -87,7 +87,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); @@ -97,7 +97,7 @@ public static void main(String[] args) { Parameter kappa2 = new Parameter.Default(HKYParser.KAPPA, 10.0, 0, 100); HKY hky2 = new HKY(kappa2, f2); - GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); + GammaSiteRateModel siteRateModel2 = new GammaSiteRateModel("gammaModel", alpha, 4); siteRateModel2.setSubstitutionModel(hky2); siteRateModel2.setRelativeRateParameter(mu); diff --git a/src/test/dr/app/beagle/TinyTest.java b/src/test/dr/app/beagle/TinyTest.java index a733d0e896..abe2493464 100644 --- a/src/test/dr/app/beagle/TinyTest.java +++ b/src/test/dr/app/beagle/TinyTest.java @@ -38,7 +38,7 @@ public void testTiny() { //siteModel double alpha = 0.5; - GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4, GammaSiteRateModel.DiscretizationType.EQUAL); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); From 746b0c4186d14c4f5e95cf97888d256285a7cf96 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 24 Apr 2023 16:42:06 +0100 Subject: [PATCH 144/196] rates with mean of 1 --- src/dr/evomodel/siteratemodel/GammaSiteRateModel.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 555932dced..3baa565e0f 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -454,7 +454,7 @@ public static void normalize(double[] categoryRates, double[] categoryProportion mean /= categoryRates.length; for(int i = 0; i < categoryRates.length; i++) { - //categoryRates[i] /= mean; + categoryRates[i] /= mean; categoryProportions[i] /= sum; } } From cb3c5c38d1990f9bc7f688bcd1cc29efdbc4eb00 Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 25 Apr 2023 17:45:40 +0100 Subject: [PATCH 145/196] Trying to get it to work... --- .../siteratemodel/GammaSiteRateModel.java | 19 ++++++++++++++++++- .../siteratemodel/GammaSiteModelParser.java | 14 ++++++++------ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 3baa565e0f..317a989941 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -45,7 +45,7 @@ public class GammaSiteRateModel extends AbstractModel implements SiteRateModel, Citable { - private static final DiscretizationType DEFAULT_DISCRETIZATION = DiscretizationType.EQUAL; + public static final DiscretizationType DEFAULT_DISCRETIZATION = DiscretizationType.EQUAL; public enum DiscretizationType { EQUAL, @@ -414,6 +414,15 @@ public List getCitations() { private static GeneralisedGaussLaguerreQuadrature quadrature = null; + /** + * Set the rates and proportions using a Gauss-Laguerre Quadrature, as proposed by Felsenstein 2001, JME + * + * @param categoryRates + * @param categoryProportions + * @param alpha + * @param catCount + * @param offset + */ public static void setQuatratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { if (quadrature == null) { quadrature = new GeneralisedGaussLaguerreQuadrature(catCount); @@ -430,6 +439,14 @@ public static void setQuatratureRates(double[] categoryRates, double[] categoryP normalize(categoryRates, categoryProportions); } + /** + * set the rates as equally spaced quantiles represented by the mean as proposed by Yang 1994 + * @param categoryRates + * @param categoryProportions + * @param alpha + * @param catCount + * @param offset + */ public static void setEqualRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { for (int i = 0; i < catCount; i++) { categoryRates[i + offset] = GammaDistribution.quantile((2.0 * i + 1.0) / (2.0 * catCount), alpha, 1.0 / alpha); diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java index a9103bcee0..f241099b54 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java @@ -87,7 +87,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } - GammaSiteRateModel.DiscretizationType type = GammaSiteRateModel.DiscretizationType.EQUAL; + GammaSiteRateModel.DiscretizationType type = GammaSiteRateModel.DEFAULT_DISCRETIZATION; Parameter shapeParam = null; int catCount = 4; @@ -95,11 +95,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(GAMMA_SHAPE); catCount = cxo.getIntegerAttribute(GAMMA_CATEGORIES); - try { - type = GammaSiteRateModel.DiscretizationType.valueOf( - cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); - } catch (IllegalArgumentException eae) { - throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); + if ( cxo.hasAttribute(DISCRETIZATION)) { + try { + type = GammaSiteRateModel.DiscretizationType.valueOf( + cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); + } catch (IllegalArgumentException eae) { + throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); + } } shapeParam = (Parameter) cxo.getChild(Parameter.class); From 11f20761d076d4ee7d0877e3633d7f1a5647b0c8 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 10 May 2023 14:25:52 -0700 Subject: [PATCH 146/196] starting to get gradient for both loadings & precision in integrated factor model --- .../hmc/IntegratedLoadingsGradient.java | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index c447cdf942..659d81117e 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -253,19 +253,19 @@ public double[] getGradientLogDensity() { return join(gradients); } + protected class MeanAndMoment { + public final ReadableVector mean; + public final double[] moment; - private void computeGradientForOneTaxon(final int index, - final int taxon, - final ReadableMatrix loadings, - final double[] transposedLoadings, - final ReadableVector gamma, - final double[] rawGamma, - final WrappedNormalSufficientStatistics statistic, - final double[][] gradArray) { - - if (TIMING) { - stopWatches[0].start(); + public MeanAndMoment(ReadableVector mean, double[] moment) { + this.mean = mean; + this.moment = moment; } + } + + protected MeanAndMoment getMeanAndMoment(final int taxon, + final WrappedNormalSufficientStatistics statistic) { + // final WrappedVector y = getTipData(taxon); final WrappedNormalSufficientStatistics dataKernel = getTipKernel(taxon); @@ -319,6 +319,28 @@ private void computeGradientForOneTaxon(final int index, double[] moment = ReadableMatrix.Utils.toArray(secondMoment); + return new MeanAndMoment(mean, moment); + } + + + private void computeGradientForOneTaxon(final int index, + final int taxon, + final ReadableMatrix loadings, + final double[] transposedLoadings, + final ReadableVector gamma, + final double[] rawGamma, + final WrappedNormalSufficientStatistics statistic, + final double[][] gradArray) { + + if (TIMING) { + stopWatches[0].start(); + } + + final MeanAndMoment meanAndMoment = getMeanAndMoment(taxon, statistic); + final ReadableVector mean = meanAndMoment.mean; + final double[] moment = meanAndMoment.moment; + + if (TIMING) { stopWatches[0].stop(); stopWatches[1].start(); From 51d6e00919f9b325e33c4f35e9ed98b0ac239cab Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 10 May 2023 14:34:13 -0700 Subject: [PATCH 147/196] trying to avoid as much code duplication as possible --- .../hmc/IntegratedLoadingsGradient.java | 54 +++++++++++-------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index 659d81117e..8f13be1b6a 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -322,6 +322,37 @@ protected MeanAndMoment getMeanAndMoment(final int taxon, return new MeanAndMoment(mean, moment); } + private void computeLoadingsGradientForOneTaxon(final int index, + final int taxon, + final double[] transposedLoadings, + final double[] rawGamma, + final double[][] gradArray, + ReadableVector mean, + double[] moment) { + for (int factor = 0; factor < dimFactors; ++factor) { + double factorMean = mean.get(factor); + + for (int trait = 0; trait < dimTrait; ++trait) { + if (!missing[taxon * dimTrait + trait]) { + + double product = 0.0; + for (int k = 0; k < dimFactors; ++k) { + product += moment[factor * dimFactors + k] // secondMoment.get(factor, k) + * transposedLoadings[trait * dimFactors + k]; // loadings.get(k, trait); + } + + gradArray[index][factor * dimTrait + trait] += + (factorMean // mean.get(factor) + * data[taxon * dimTrait + trait] //y.get(trait) + - product) +// - product.get(factor, trait)) + * rawGamma[trait]; // gamma.get(trait); + + } + } + } + } + private void computeGradientForOneTaxon(final int index, final int taxon, @@ -346,28 +377,7 @@ private void computeGradientForOneTaxon(final int index, stopWatches[1].start(); } - for (int factor = 0; factor < dimFactors; ++factor) { - double factorMean = mean.get(factor); - - for (int trait = 0; trait < dimTrait; ++trait) { - if (!missing[taxon * dimTrait + trait]) { - - double product = 0.0; - for (int k = 0; k < dimFactors; ++k) { - product += moment[factor * dimFactors + k] // secondMoment.get(factor, k) - * transposedLoadings[trait * dimFactors + k]; // loadings.get(k, trait); - } - - gradArray[index][factor * dimTrait + trait] += - (factorMean // mean.get(factor) - * data[taxon * dimTrait + trait] //y.get(trait) - - product) -// - product.get(factor, trait)) - * rawGamma[trait]; // gamma.get(trait); - - } - } - } + computeLoadingsGradientForOneTaxon(index, taxon, transposedLoadings, rawGamma, gradArray, mean, moment); if (TIMING) { stopWatches[1].stop(); From 926e56a31b96fdf84ff4ce78134ad8861665b649 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 10 May 2023 14:59:30 -0700 Subject: [PATCH 148/196] more generalization --- .../hmc/IntegratedLoadingsGradient.java | 61 +++++++++++++------ 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index 8f13be1b6a..df8e5a799d 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -322,13 +322,25 @@ protected MeanAndMoment getMeanAndMoment(final int taxon, return new MeanAndMoment(mean, moment); } - private void computeLoadingsGradientForOneTaxon(final int index, - final int taxon, - final double[] transposedLoadings, - final double[] rawGamma, - final double[][] gradArray, - ReadableVector mean, - double[] moment) { + protected class GradientComponents { + public final double[] fty; + public final double[] ftfl; + + public GradientComponents(double[] fty, double[] ftfl) { + this.fty = fty; + this.ftfl = ftfl; + } + } + + protected GradientComponents computeGradientComponents(final int taxon, + final double[] transposedLoadings, + final MeanAndMoment meanAndMoment) { + + final ReadableVector mean = meanAndMoment.mean; + final double[] moment = meanAndMoment.moment; + + double fty[] = new double[dimFactors * dimTrait]; + double ftfl[] = new double[dimFactors * dimTrait]; for (int factor = 0; factor < dimFactors; ++factor) { double factorMean = mean.get(factor); @@ -341,16 +353,33 @@ private void computeLoadingsGradientForOneTaxon(final int index, * transposedLoadings[trait * dimFactors + k]; // loadings.get(k, trait); } - gradArray[index][factor * dimTrait + trait] += - (factorMean // mean.get(factor) - * data[taxon * dimTrait + trait] //y.get(trait) - - product) -// - product.get(factor, trait)) - * rawGamma[trait]; // gamma.get(trait); + int ind = factor * dimTrait + trait; + + fty[ind] += factorMean * data[taxon * dimTrait + trait]; + ftfl[ind] += product; } } } + + return new GradientComponents(fty, ftfl); + } + + protected void computeLoadingsGradientForOneTaxon(int index, + GradientComponents components, + double[] rawGamma, + double[][] gradArray) { + + double[] fty = components.fty; + double[] ftfl = components.ftfl; + for (int factor = 0; factor < dimFactors; ++factor) { + for (int trait = 0; trait < dimTrait; ++trait) { + int ind = factor * dimTrait + trait; + gradArray[index][factor * dimTrait + trait] += + (fty[ind] - ftfl[ind]) * rawGamma[trait]; + } + } + } @@ -368,16 +397,14 @@ private void computeGradientForOneTaxon(final int index, } final MeanAndMoment meanAndMoment = getMeanAndMoment(taxon, statistic); - final ReadableVector mean = meanAndMoment.mean; - final double[] moment = meanAndMoment.moment; - if (TIMING) { stopWatches[0].stop(); stopWatches[1].start(); } - computeLoadingsGradientForOneTaxon(index, taxon, transposedLoadings, rawGamma, gradArray, mean, moment); + GradientComponents components = computeGradientComponents(taxon, transposedLoadings, meanAndMoment); + computeLoadingsGradientForOneTaxon(index, components, rawGamma, gradArray); if (TIMING) { stopWatches[1].stop(); From 9e0e9118cf7eaae827d5576f2411da10fcbcdb8c Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 10 May 2023 15:56:07 -0700 Subject: [PATCH 149/196] (broken) gradient for loadings & precision --- .../app/beast/development_parsers.properties | 1 + ...ntegratedLoadingsAndPrecisionGradient.java | 102 ++++++++++++++++++ .../hmc/IntegratedLoadingsGradient.java | 26 ++--- ...tedLoadingsAndPrecisionGradientParser.java | 64 +++++++++++ .../hmc/IntegratedLoadingsGradientParser.java | 22 +++- 5 files changed, 200 insertions(+), 15 deletions(-) create mode 100644 src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java create mode 100644 src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsAndPrecisionGradientParser.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index f091ef9921..886e2487f2 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -308,6 +308,7 @@ dr.evomodel.operators.ExtendedLatentLiabilityGibbsOperator dr.inference.model.FactorProportionStatistic dr.inferencexml.model.BlombergKStatisticParser dr.inference.operators.factorAnalysis.GaussianTreeTraitGibbsOperator +dr.evomodelxml.continuous.hmc.IntegratedLoadingsAndPrecisionGradientParser # Shrinkage dr.inference.model.MaskFromTree diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java new file mode 100644 index 0000000000..470feeddf3 --- /dev/null +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java @@ -0,0 +1,102 @@ +package dr.evomodel.continuous.hmc; + +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider; +import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; +import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics; +import dr.inference.model.CompoundParameter; +import dr.inference.model.Parameter; +import dr.math.matrixAlgebra.ReadableMatrix; +import dr.math.matrixAlgebra.ReadableVector; +import dr.util.TaskPool; + +public class IntegratedLoadingsAndPrecisionGradient extends IntegratedLoadingsGradient { + + CompoundParameter jointParameter; + + public IntegratedLoadingsAndPrecisionGradient(CompoundParameter jointParameter, + TreeDataLikelihood treeDataLikelihood, + ContinuousDataLikelihoodDelegate likelihoodDelegate, + IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood, + ContinuousTraitPartialsProvider partialsProvider, + TaskPool taskPool, + ThreadUseProvider threadUseProvider, + RemainderCompProvider remainderCompProvider) { + super(treeDataLikelihood, likelihoodDelegate, factorAnalysisLikelihood, partialsProvider, taskPool, threadUseProvider, remainderCompProvider); + this.jointParameter = jointParameter; + } + + @Override + public Parameter getParameter() { + return jointParameter; + } + + @Override + protected int getGradientDimension() { + return dimFactors * dimTrait + dimTrait; + } + + @Override + public int getDimension() { + return getGradientDimension(); + } + + private void computePrecisionGradientForOneTaxon(int index, + int taxon, + GradientComponents components, + double[] transposedLoadings, + double[] rawGamma, + double[][] gradArray, + int offset) { + double[] fty = components.fty; + double[] ftfl = components.ftfl; + + + for (int factor = 0; factor < dimFactors; ++factor) { + for (int trait = 0; trait < dimTrait; ++trait) { + int ind = factor * dimTrait + trait; + gradArray[index][offset + trait] += + (2 * fty[ind] - ftfl[ind]) * transposedLoadings[ind]; + } + } + for (int trait = 0; trait < dimTrait; ++trait) { + double dat = data[taxon * dimTrait + trait]; + gradArray[index][offset + trait] += dat * dat + 1 / rawGamma[trait]; //TODO: need to deal w/ missing data + } + } + + @Override + protected void computeGradientForOneTaxon(final int index, + final int taxon, + final ReadableMatrix loadings, + final double[] transposedLoadings, + final ReadableVector gamma, + final double[] rawGamma, + final WrappedNormalSufficientStatistics statistic, + final double[][] gradArray) { + + if (TIMING) { + stopWatches[0].start(); + } + + final MeanAndMoment meanAndMoment = getMeanAndMoment(taxon, statistic); + + if (TIMING) { + stopWatches[0].stop(); + stopWatches[1].start(); + } + + GradientComponents components = computeGradientComponents(taxon, transposedLoadings, meanAndMoment); + computeLoadingsGradientForOneTaxon(index, components, rawGamma, gradArray); + computePrecisionGradientForOneTaxon(index, taxon, components, transposedLoadings, rawGamma, gradArray, dimFactors * dimTrait); + + + if (TIMING) { + stopWatches[1].stop(); + } +// } + } + + +} diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index df8e5a799d..daee17761e 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -40,7 +40,7 @@ public class IntegratedLoadingsGradient implements GradientWrtParameterProvider, protected final int dimPartials; private final Tree tree; private final Likelihood likelihood; - private final double[] data; + protected final double[] data; private final boolean[] missing; private final ThreadUseProvider threadUseProvider; private final RemainderCompProvider remainderCompProvider; @@ -132,7 +132,7 @@ public int getDimension() { return dimFactors * dimTrait; } - private int getGradientDimension() { + protected int getGradientDimension() { return dimFactors * dimTrait; } @@ -375,7 +375,7 @@ protected void computeLoadingsGradientForOneTaxon(int index, for (int factor = 0; factor < dimFactors; ++factor) { for (int trait = 0; trait < dimTrait; ++trait) { int ind = factor * dimTrait + trait; - gradArray[index][factor * dimTrait + trait] += + gradArray[index][ind] += (fty[ind] - ftfl[ind]) * rawGamma[trait]; } } @@ -383,14 +383,14 @@ protected void computeLoadingsGradientForOneTaxon(int index, } - private void computeGradientForOneTaxon(final int index, - final int taxon, - final ReadableMatrix loadings, - final double[] transposedLoadings, - final ReadableVector gamma, - final double[] rawGamma, - final WrappedNormalSufficientStatistics statistic, - final double[][] gradArray) { + protected void computeGradientForOneTaxon(final int index, + final int taxon, + final ReadableMatrix loadings, + final double[] transposedLoadings, + final ReadableVector gamma, + final double[] rawGamma, + final WrappedNormalSufficientStatistics statistic, + final double[][] gradArray) { if (TIMING) { stopWatches[0].start(); @@ -507,8 +507,8 @@ boolean computeRemainder() { } - private StopWatch[] stopWatches; - private static final boolean TIMING = false; + protected StopWatch[] stopWatches; + protected static final boolean TIMING = false; private static final boolean DEBUG = false; diff --git a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsAndPrecisionGradientParser.java b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsAndPrecisionGradientParser.java new file mode 100644 index 0000000000..4d0cc1ec38 --- /dev/null +++ b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsAndPrecisionGradientParser.java @@ -0,0 +1,64 @@ +package dr.evomodelxml.continuous.hmc; + +import dr.evomodel.continuous.hmc.IntegratedLoadingsAndPrecisionGradient; +import dr.evomodel.continuous.hmc.IntegratedLoadingsGradient; +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider; +import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; +import dr.inference.model.CompoundParameter; +import dr.util.TaskPool; +import dr.xml.ElementRule; +import dr.xml.XMLParseException; +import dr.xml.XMLSyntaxRule; + +public class IntegratedLoadingsAndPrecisionGradientParser extends IntegratedLoadingsGradientParser { + + public static final String PARSER_NAME = "integratedFactorAnalysisLoadingsAndPrecisionGradient"; + + protected IntegratedLoadingsGradient factory(TreeDataLikelihood treeDataLikelihood, + ContinuousDataLikelihoodDelegate likelihoodDelegate, + IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood, + ContinuousTraitPartialsProvider jointPartialsProvider, + TaskPool taskPool, + IntegratedLoadingsGradient.ThreadUseProvider threadUseProvider, + IntegratedLoadingsGradient.RemainderCompProvider remainderCompProvider, + CompoundParameter parameter) + throws XMLParseException { + + return new IntegratedLoadingsAndPrecisionGradient( + parameter, + treeDataLikelihood, + likelihoodDelegate, + factorAnalysisLikelihood, + jointPartialsProvider, + taskPool, + threadUseProvider, + remainderCompProvider); + + } + + + @Override + public String getParserDescription() { + return "Generates a gradient provider for the loadings matrix & precision when factors are integrated out"; + } + + @Override + public Class getReturnType() { + return IntegratedLoadingsAndPrecisionGradient.class; + } + + @Override + public String getParserName() { + return PARSER_NAME; + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + XMLSyntaxRule[] newRules = new XMLSyntaxRule[rules.length + 1]; + newRules[0] = new ElementRule(CompoundParameter.class); + System.arraycopy(rules, 0, newRules, 1, rules.length); + return newRules; + } +} diff --git a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java index cc635b80ff..ab7635bd92 100644 --- a/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java +++ b/src/dr/evomodelxml/continuous/hmc/IntegratedLoadingsGradientParser.java @@ -7,6 +7,7 @@ import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider; import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood; import dr.evomodel.treedatalikelihood.continuous.JointPartialsProvider; +import dr.inference.model.CompoundParameter; import dr.util.TaskPool; import dr.xml.*; @@ -73,6 +74,21 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ContinuousTraitPartialsProvider partialsProvider = (JointPartialsProvider) xo.getChild(JointPartialsProvider.class); if (partialsProvider == null) partialsProvider = factorAnalysis; + // for IntegratedLoadingsAndPrecisionGradient + CompoundParameter parameter = (CompoundParameter) xo.getChild(CompoundParameter.class); + if (parameter != null) { + if (parameter.getParameterCount() != 2) { + throw new XMLParseException("The parameter must have two elements, " + + "the first being the loadings matrix and the second being the precision matrix."); + } + if (parameter.getParameter(0) != factorAnalysis.getLoadings()) { + throw new XMLParseException("The first element of the parameter must be the loadings matrix."); + } + if (parameter.getParameter(1) != factorAnalysis.getPrecision()) { + throw new XMLParseException("The second element of the parameter must be the precision matrix."); + } + } + // TODO Check dimensions, parameters, etc. @@ -83,7 +99,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { partialsProvider, taskPool, threadProvider, - remainderCompProvider); + remainderCompProvider, + parameter); } @@ -93,7 +110,8 @@ protected IntegratedLoadingsGradient factory(TreeDataLikelihood treeDataLikeliho ContinuousTraitPartialsProvider jointPartialsProvider, TaskPool taskPool, IntegratedLoadingsGradient.ThreadUseProvider threadUseProvider, - IntegratedLoadingsGradient.RemainderCompProvider remainderCompProvider) + IntegratedLoadingsGradient.RemainderCompProvider remainderCompProvider, + CompoundParameter parameter) throws XMLParseException { return new IntegratedLoadingsGradient( From ed06961870ee4cd8247e0623c0dbb691c7862d7c Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 10 May 2023 16:04:08 -0700 Subject: [PATCH 150/196] test xml --- .../testLoadingsAndPrecisionGradient.xml | 186 ++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 ci/TestXML/testLoadingsAndPrecisionGradient.xml diff --git a/ci/TestXML/testLoadingsAndPrecisionGradient.xml b/ci/TestXML/testLoadingsAndPrecisionGradient.xml new file mode 100644 index 0000000000..fcee6bae49 --- /dev/null +++ b/ci/TestXML/testLoadingsAndPrecisionGradient.xml @@ -0,0 +1,186 @@ + + + + + -0.8161262390971513 0.4270454816694809 -0.6230781939110783 -0.16450480870919917 + -0.4456292313690999 + + + + -1.0144436604293505 -0.8920557774375898 1.478607059857378 -0.16240458275087258 + -0.6908795858187 + + + + 1.056873733132347 NA NA NA 1.066150616149452 + + + + -0.2221684856569227 -0.11560476062111587 0.41571346596533654 0.8655080462439241 + 0.10898633050895852 + + + + -0.17438164408909365 -0.7122666957434696 -0.44738267923159963 -0.17865180808229236 + 0.23315524522158199 + + + + NA -0.8666572905938715 -0.454549337998158 -0.5001246338040305 + -0.5397266841036715 + + + + 1.0015698840090395 2.3897801345872987 -1.5146279580725215 0.8960512885879615 + -2.2012739127775762 + + + + -0.42898479922720356 NA 0.6436374468251037 -0.09729958495856225 + -1.396413239367357 + + + + 0.09104086929644847 0.38443026317349643 -1.426378176957026 -0.7790116303715932 + 0.2981854605237854 + + + + NA NA NA NA 1.6135287964494376 + + + + + (taxon_6:0.6599920953,((((taxon_8:0.08050441416,taxon_4:0.1993587138):0.06120663346,taxon_7:0.4556825075):0.1431224649,taxon_10:0.5471037512):0.8189234324,((taxon_1:0.739103453,(taxon_2:0.3068418624,taxon_5:0.7002265998):0.6723836821):0.4476448677,(taxon_9:0.16993984,taxon_3:0.2669664454):0.9823498076):0.9481884362):0.3653106997);; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From dcf0aa60674ba90509a3c57d6bb525de63c2b097 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Wed, 10 May 2023 16:37:12 -0700 Subject: [PATCH 151/196] works now (was an indexing issue) --- .../testLoadingsAndPrecisionGradient.xml | 2 +- ...ntegratedLoadingsAndPrecisionGradient.java | 32 +++++++++++++++---- .../hmc/IntegratedLoadingsGradient.java | 2 +- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/ci/TestXML/testLoadingsAndPrecisionGradient.xml b/ci/TestXML/testLoadingsAndPrecisionGradient.xml index fcee6bae49..87061711c1 100644 --- a/ci/TestXML/testLoadingsAndPrecisionGradient.xml +++ b/ci/TestXML/testLoadingsAndPrecisionGradient.xml @@ -137,7 +137,7 @@ + gradientCheckCount="100" gradientCheckTolerance="5e-3"> diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java index 470feeddf3..21f22b9538 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsAndPrecisionGradient.java @@ -53,17 +53,35 @@ private void computePrecisionGradientForOneTaxon(int index, double[] ftfl = components.ftfl; - for (int factor = 0; factor < dimFactors; ++factor) { - for (int trait = 0; trait < dimTrait; ++trait) { + for (int trait = 0; trait < dimTrait; ++trait) { + int dataInd = taxon * dimTrait + trait; + if (factorAnalysisLikelihood.getDataMissingIndicators()[dataInd]) { + continue; + } + double dat = data[dataInd]; + gradArray[index][offset + trait] += 0.5 * (1 / rawGamma[trait] - dat * dat); + + for (int factor = 0; factor < dimFactors; ++factor) { + int loadingsInd = trait * dimFactors + factor; int ind = factor * dimTrait + trait; + gradArray[index][offset + trait] += - (2 * fty[ind] - ftfl[ind]) * transposedLoadings[ind]; + (fty[ind] - 0.5 * ftfl[ind]) * transposedLoadings[loadingsInd]; } +// gradArray[index][offset + trait] += +// (fty[ind] - 0.5 * ftfl[ind]) * transposedLoadings[ind]; } - for (int trait = 0; trait < dimTrait; ++trait) { - double dat = data[taxon * dimTrait + trait]; - gradArray[index][offset + trait] += dat * dat + 1 / rawGamma[trait]; //TODO: need to deal w/ missing data - } +// for (int factor = 0; factor < dimFactors; ++factor) { +// for (int trait = 0; trait < dimTrait; ++trait) { +// int ind = factor * dimTrait + trait; +// gradArray[index][offset + trait] += +// (fty[ind] - 0.5 * ftfl[ind]) * transposedLoadings[ind]; +// } +// } +// for (int trait = 0; trait < dimTrait; ++trait) { +// double dat = data[taxon * dimTrait + trait]; +// gradArray[index][offset + trait] += 0.5 * (1 / rawGamma[trait] - dat * dat); //TODO: need to deal w/ missing data +// } } @Override diff --git a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java index daee17761e..ec34da2f1f 100644 --- a/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java +++ b/src/dr/evomodel/continuous/hmc/IntegratedLoadingsGradient.java @@ -33,7 +33,7 @@ public class IntegratedLoadingsGradient implements GradientWrtParameterProvider, VariableListener, Reportable { private final TreeTrait> fullConditionalDensity; - private final IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood; + protected final IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood; private final ContinuousTraitPartialsProvider partialsProvider; protected final int dimTrait; protected final int dimFactors; From 9193f5befbe20c93066be23e1c35fb4f38405909 Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Thu, 29 Jun 2023 11:18:36 +0100 Subject: [PATCH 152/196] Dividing weights by Gamma(alpha + 1) will auto-normalise --- .../evomodel/siteratemodel/GammaSiteRateModel.java | 4 ++-- .../math/GeneralisedGaussLaguerreQuadrature.java | 14 ++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 317a989941..d9a975856a 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -29,6 +29,7 @@ import dr.math.GeneralisedGaussLaguerreQuadrature; import dr.math.distributions.GammaDistribution; import dr.evomodel.substmodel.SubstitutionModel; +import dr.math.functionEval.GammaFunction; import dr.util.Author; import dr.util.Citable; import dr.util.Citation; @@ -434,9 +435,8 @@ public static void setQuatratureRates(double[] categoryRates, double[] categoryP for (int i = 0; i < catCount; i++) { categoryRates[i + offset] = abscissae[i] / (alpha + 1); - categoryProportions[i + offset] = coefficients[i]; + categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha + 1); } - normalize(categoryRates, categoryProportions); } /** diff --git a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java index 05f4a60e27..0c2cc3d383 100644 --- a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java +++ b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java @@ -80,6 +80,7 @@ private void setupArrays(){ for(int i=0; i Date: Thu, 29 Jun 2023 12:47:05 +0100 Subject: [PATCH 153/196] Implemented Felsenstein's quadrature weights as an option in BEAUti --- .../generator/SubstitutionModelGenerator.java | 40 ++++++++++++------- .../options/PartitionSubstitutionModel.java | 11 +++++ .../sitemodelspanel/PartitionModelPanel.java | 19 +++++---- .../siteratemodel/GammaSiteRateModel.java | 4 +- .../siteratemodel/GammaSiteModelParser.java | 2 +- 5 files changed, 52 insertions(+), 24 deletions(-) diff --git a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java index 837fa4001b..6f24720a0a 100644 --- a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java +++ b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java @@ -347,15 +347,15 @@ private void writeAlignmentRefInFrequencies(XMLWriter writer, PartitionSubstitut writeCodonPatternsRef(prefix, num, model.getCodonPartitionCount(), writer); // get the data partition for this substitution model. - AbstractPartitionData partition = options.getDataPartitions(model).get(0); - - // for empirical frequencies use the entire alignment - if (partition instanceof PartitionData) { - Alignment alignment = ((PartitionData)partition).getAlignment(); - writer.writeIDref(AlignmentParser.ALIGNMENT, alignment.getId()); - } else { - throw new IllegalArgumentException("Partition is missing a data partition"); - } + AbstractPartitionData partition = options.getDataPartitions(model).get(0); + + // for empirical frequencies use the entire alignment + if (partition instanceof PartitionData) { + Alignment alignment = ((PartitionData)partition).getAlignment(); + writer.writeIDref(AlignmentParser.ALIGNMENT, alignment.getId()); + } else { + throw new IllegalArgumentException("Partition is missing a data partition"); + } } else { for (AbstractPartitionData partition : options.getDataPartitions(model)) { //? writer.writeIDref(AlignmentParser.ALIGNMENT, partition.getTaxonList().getId()); @@ -696,8 +696,13 @@ private void writeNucSiteModel(int num, XMLWriter writer, PartitionSubstitutionM if (model.isGammaHetero()) { - writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, new Attribute.Default( - GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories())); + writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, + new Attribute[] { + new Attribute.Default<>(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), + new Attribute.Default<>(GammaSiteModelParser.DISCRETIZATION, + (model.isGammaHeteroEqualWeights() ? "equal" : "quadrature")), + }); + if (num == -1 || model.isUnlinkedHeterogeneityModel()) { // writeParameter(prefix + "alpha", model, writer); writeParameter(num, "alpha", model, writer); @@ -783,7 +788,11 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel if (model.isGammaHetero()) { writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, - new Attribute.Default(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories())); + new Attribute[] { + new Attribute.Default<>(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), + new Attribute.Default<>(GammaSiteModelParser.DISCRETIZATION, + (model.isGammaHeteroEqualWeights() ? "equal" : "quadrature")), + }); writeParameter(prefix + "alpha", model, writer); writer.writeCloseTag(GammaSiteModelParser.GAMMA_SHAPE); } @@ -832,8 +841,11 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model if (model.isGammaHetero()) { writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, - new Attribute.Default( - GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories())); + new Attribute[] { + new Attribute.Default<>(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), + new Attribute.Default<>(GammaSiteModelParser.DISCRETIZATION, + (model.isGammaHeteroEqualWeights() ? "equal" : "quadrature")), + }); writeParameter("alpha", model, writer); writer.writeCloseTag(GammaSiteModelParser.GAMMA_SHAPE); } diff --git a/src/dr/app/beauti/options/PartitionSubstitutionModel.java b/src/dr/app/beauti/options/PartitionSubstitutionModel.java index 5a0a0332d7..81d3249562 100644 --- a/src/dr/app/beauti/options/PartitionSubstitutionModel.java +++ b/src/dr/app/beauti/options/PartitionSubstitutionModel.java @@ -71,6 +71,7 @@ public class PartitionSubstitutionModel extends PartitionOptions { private boolean gammaHetero = false; private int gammaCategories = 4; private boolean invarHetero = false; + private boolean equalWeights = false; private String codonHeteroPattern = null; private boolean unlinkedSubstitutionModel = true; private boolean unlinkedHeterogeneityModel = true; @@ -133,6 +134,7 @@ public PartitionSubstitutionModel(BeautiOptions options, String name, PartitionS frequencyPolicy = source.frequencyPolicy; gammaHetero = source.gammaHetero; gammaCategories = source.gammaCategories; + equalWeights = source.equalWeights; invarHetero = source.invarHetero; codonHeteroPattern = source.codonHeteroPattern; unlinkedSubstitutionModel = source.unlinkedSubstitutionModel; @@ -1105,6 +1107,14 @@ public void setInvarHetero(boolean invarHetero) { this.invarHetero = invarHetero; } + public boolean isGammaHeteroEqualWeights() { + return equalWeights; + } + + public void setGammaHeteroEqualWeights(boolean equalWeights) { + this.equalWeights = equalWeights; + } + public String getCodonHeteroPattern() { return codonHeteroPattern; } @@ -1229,6 +1239,7 @@ public void copyFrom(PartitionSubstitutionModel source) { frequencyPolicy = source.frequencyPolicy; gammaHetero = source.gammaHetero; gammaCategories = source.gammaCategories; + equalWeights = source.equalWeights; invarHetero = source.invarHetero; codonHeteroPattern = source.codonHeteroPattern; unlinkedSubstitutionModel = source.unlinkedSubstitutionModel; diff --git a/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java b/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java index 98363fca73..d815726051 100644 --- a/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java +++ b/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java @@ -74,10 +74,10 @@ public class PartitionModelPanel extends OptionsPanel { .values()); private JComboBox heteroCombo = new JComboBox(new String[] { "None", - "Gamma", "Invariant Sites", "Gamma + Invariant Sites" }); + "Gamma (Felsenstein weights)", "Gamma (equal weights)", "Invariant Sites", "Gamma (equal weights) + Invariant Sites" }); private JComboBox gammaCatCombo = new JComboBox(new String[] { "4", "5", - "6", "7", "8", "9", "10" }); + "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16" }); private JLabel gammaCatLabel; private JComboBox codingCombo = new JComboBox(new String[] { "Off", @@ -216,16 +216,21 @@ public void itemStateChanged(ItemEvent ev) { PanelUtils.setupComponent(heteroCombo); heteroCombo - .setToolTipText("Select the type of site-specific rate
heterogeneity model."); + .setToolTipText("Select the type of site-specific rate
heterogeneity model.
" + + "\"Felsenstein weights\" uses the quadrature method to calculate the category weights described in
" + + "Felsenstein (2001) J Mol Evol 53: 447-455."); heteroCombo.addItemListener(new ItemListener() { public void itemStateChanged(ItemEvent ev) { - boolean gammaHetero = heteroCombo.getSelectedIndex() == 1 - || heteroCombo.getSelectedIndex() == 3; + boolean gammaHetero = heteroCombo.getSelectedIndex() == 1 || + heteroCombo.getSelectedIndex() == 2 + || heteroCombo.getSelectedIndex() == 4; model.setGammaHetero(gammaHetero); - model.setInvarHetero(heteroCombo.getSelectedIndex() == 2 - || heteroCombo.getSelectedIndex() == 3); + model.setInvarHetero(heteroCombo.getSelectedIndex() == 3 + || heteroCombo.getSelectedIndex() == 4); + model.setGammaHeteroEqualWeights(heteroCombo.getSelectedIndex() == 2 + || heteroCombo.getSelectedIndex() == 4); if (gammaHetero) { gammaCatLabel.setEnabled(true); diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index d9a975856a..e5b06f3723 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -401,7 +401,7 @@ public List getCitations() { public final static Citation CITATION_FELSENSTEIN01 = new Citation( new Author[]{ - new Author("J", "Felsenstein ") + new Author("J", "Felsenstein") }, "Taking Variation of Evolutionary Rates Between Sites into Account in Inferring Phylogenies", 2001, @@ -509,7 +509,7 @@ public static void main(String[] argv) { // 5.617 0.00076 // 8.823 0.000003 - // Output (without setting rates to mean of 1) + // Output // Quadrature, alpha = 1.0 // cat rate proportion // 0 0.26383406085556455 0.27765014202987454 diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java index f241099b54..2b677dd81b 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java @@ -109,7 +109,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (type == GammaSiteRateModel.DiscretizationType.EQUAL) { msg += "\n using equal weight discretization of gamma distribution"; } else { - msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution"; + msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution (Felsenstein, 2012)"; } } From d9e77253e5307064c93a6a1b981cacb6c440a41d Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Thu, 29 Jun 2023 16:19:59 +0100 Subject: [PATCH 154/196] Added an alternative FreeRate model whose rates will be in ascending order. --- .../evomodel/siteratemodel/FreeRateModel.java | 233 ++++++++++++++++++ .../siteratemodel/FreeRateModelParser.java | 93 +++++++ 2 files changed, 326 insertions(+) create mode 100644 src/dr/evomodel/siteratemodel/FreeRateModel.java create mode 100644 src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java diff --git a/src/dr/evomodel/siteratemodel/FreeRateModel.java b/src/dr/evomodel/siteratemodel/FreeRateModel.java new file mode 100644 index 0000000000..fb07efaa38 --- /dev/null +++ b/src/dr/evomodel/siteratemodel/FreeRateModel.java @@ -0,0 +1,233 @@ +/* + * PdfSiteRateModel.java + * + * Copyright (c) 2002-2020 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.siteratemodel; + +import dr.evomodel.substmodel.SubstitutionModel; +import dr.inference.model.AbstractModel; +import dr.inference.model.Model; +import dr.inference.model.Parameter; +import dr.inference.model.Variable; +import dr.util.Author; +import dr.util.Citable; +import dr.util.Citation; + +import java.util.Collections; +import java.util.List; + +/** + * @author Marc A. Suchard + * @author Matthew Hall + * + * The parameters of this version are the differences between rates, not the rates themselves. This allows the actual + * rates to be in ascending order. + */ + +public class FreeRateModel extends AbstractModel implements SiteRateModel, Citable { + + public FreeRateModel( + String name, + Parameter rateDifferences, + Parameter weights) { + + super(name); + + this.rateDifferences = rateDifferences; + this.weights = weights; + this.dim = Math.min(rateDifferences.getDimension(), weights.getDimension()); + + addVariable(rateDifferences); + addVariable(weights); + + ratesKnown = false; + } + + // ***************************************************************** + // Interface SiteRateModel + // ***************************************************************** + + @Override + public int getCategoryCount() { return dim; } + + @Override + public double[] getCategoryRates() { + synchronized (this) { + if (!ratesKnown) { + calculateCategoryRates(); + } + } + return categoryRates; + } + + @Override + public double[] getCategoryProportions() { + synchronized (this) { + if (!ratesKnown) { + calculateCategoryRates(); + } + } + return categoryProportions; + } + + @Override + public double getRateForCategory(int category) { + synchronized (this) { + if (!ratesKnown) { + calculateCategoryRates(); + } + } + return categoryRates[category]; + } + + @Override + public double getProportionForCategory(int category) { + synchronized (this) { + if (!ratesKnown) { + calculateCategoryRates(); + } + } + return categoryProportions[category]; + } + + + private void calculateCategoryRates() { + + double scale = 0.0; + double sum = 0.0; + double[] unnormalisedRates = new double[getCategoryCount()]; + + unnormalisedRates[0] = rateDifferences.getParameterValue(0); + + for (int i = 1; i < dim; i++) { + unnormalisedRates[i] = unnormalisedRates[i-1] + rateDifferences.getParameterValue(i); + } + + if (categoryRates == null) { + categoryRates = new double[dim]; + } + + if (categoryProportions == null) { + categoryProportions = new double[dim]; + } + + for (int i = 0; i < dim; i++) { + sum += weights.getParameterValue(i); + } + + for (int i = 0; i < dim; ++i) { + categoryProportions[i] = weights.getParameterValue(i) / sum; + } + + for (int i = 0; i < dim; i++) { + scale += categoryProportions[i] * unnormalisedRates[i]; + } + + for (int i = 0; i < dim; ++i) { + categoryRates[i] = unnormalisedRates[i] / scale; + } + +// double checker = 0; +// for (int i = 0; i < dim; ++i) { +// checker += categoryRates[i] * categoryProportions[i]; +// } + + ratesKnown = true; + } + + @Override + protected void handleModelChangedEvent(Model model, Object object, int index) { + // Substitution model has changed so fire model changed event + listenerHelper.fireModelChanged(this, object, index); + } + + @Override + protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + if (variable == rateDifferences || variable == weights) { + ratesKnown = false; + } else { + throw new RuntimeException("Unknown variable in PdfSiteRateModel.handleVariableChangedEvent"); + } + listenerHelper.fireModelChanged(this, variable, index); + } + + @Override + protected void storeState() { } + + @Override + protected void restoreState() { ratesKnown = false; } + + @Override + protected void acceptState() { } + + @Override + public Citation.Category getCategory() { + return Citation.Category.SUBSTITUTION_MODELS; + } + + @Override + public String getDescription() { + return "Discrete probability distribution free rate heterogeneity model"; + } + + @Override + public List getCitations() { + return Collections.singletonList(CITATION); + } + + public final static Citation CITATION = new Citation( // TODO Update + new Author[]{ + new Author("Z", "Yang") + }, + "A space-time process model for the evolution of DNA Sequences", + 1995, + "Genetics", + 139, + 993, 1005, + Citation.Status.PUBLISHED + ); + + private final Parameter rateDifferences; + private final Parameter weights; + private final int dim; + + private double[] categoryRates; + private double[] categoryProportions; + private boolean ratesKnown; + + // This is here solely to allow the PdfSiteModelParser to pass on the substitution model to the + // HomogenousBranchSubstitutionModel so that the XML will be compatible with older BEAST versions. To be removed + // at some point. + public SubstitutionModel getSubstitutionModel() { + + return substitutionModel; + } + + public void setSubstitutionModel(SubstitutionModel substitutionModel) { + this.substitutionModel = substitutionModel; + } + + + private SubstitutionModel substitutionModel; +} \ No newline at end of file diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java new file mode 100644 index 0000000000..444d48348e --- /dev/null +++ b/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java @@ -0,0 +1,93 @@ +/* + * PdfSiteModelParser.java + * + * Copyright (c) 2002-2020 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodelxml.siteratemodel; + +import dr.evomodel.siteratemodel.FreeRateModel; +import dr.evomodel.siteratemodel.GammaSiteRateModel; +import dr.evomodel.siteratemodel.PdfSiteRateModel; +import dr.evomodel.substmodel.SubstitutionModel; +import dr.inference.model.Parameter; +import dr.xml.*; + +/** + * @author Marc A. Suchard + * @author Matthew Hall + */ +public class FreeRateModelParser extends AbstractXMLObjectParser { + + private static final String SITE_MODEL = "freeRateModel"; + private static final String SUBSTITUTION_MODEL = "substitutionModel"; + private static final String RATEDIFFERENCES = "rateDifferences"; + private static final String WEIGHTS = "weights"; + + public String getParserName() { + return SITE_MODEL; + } + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + Parameter rateDifferences = (Parameter) xo.getElementFirstChild(RATEDIFFERENCES); + Parameter weights = (Parameter) xo.getElementFirstChild(WEIGHTS); + + PdfSiteRateModel siteRateModel = new PdfSiteRateModel(SITE_MODEL, rateDifferences, weights); + + if (xo.hasChildNamed(SUBSTITUTION_MODEL)) { + siteRateModel.setSubstitutionModel( + (SubstitutionModel) xo.getElementFirstChild(SUBSTITUTION_MODEL)); + + } + + return siteRateModel; + } + + //************************************************************************ + // AbstractXMLObjectParser implementation + //************************************************************************ + + public String getParserDescription() { + return "A SiteRateModel that has probability free distributed rates across sites"; + } + + public Class getReturnType() { + return FreeRateModel.class; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + new ElementRule(RATEDIFFERENCES, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + new ElementRule(WEIGHTS, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + new ElementRule(SUBSTITUTION_MODEL, new XMLSyntaxRule[]{ + new ElementRule(SubstitutionModel.class) + }, true), + }; +} From 593e47b952b76ec2f262b424aea1c91514f5f100 Mon Sep 17 00:00:00 2001 From: rambaut Date: Thu, 29 Jun 2023 18:12:44 +0100 Subject: [PATCH 155/196] Refactoring to a delegate model for discretized site rate models for more flexibility. --- .../DiscretizedSiteRateModel.java | 218 +++++++++++ .../siteratemodel/GammaSiteRateDelegate.java | 352 ++++++++++++++++++ .../siteratemodel/SiteRateDelegate.java | 13 + .../siteratemodel/FreeRateModelParser.java | 7 +- 4 files changed, 586 insertions(+), 4 deletions(-) create mode 100644 src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java create mode 100644 src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java create mode 100644 src/dr/evomodel/siteratemodel/SiteRateDelegate.java diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java new file mode 100644 index 0000000000..626eefcc57 --- /dev/null +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -0,0 +1,218 @@ +/* + * DiscretizedSiteRateModel.java + * + * Copyright (c) 2002-2023 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.siteratemodel; + +import dr.inference.model.*; +import dr.evomodel.substmodel.SubstitutionModel; + +/** + * DiscretizedSiteRateModel - A SiteModel that has a discrete categories of rates across sites. + * + * @author Andrew Rambaut + * @version $Id: $ + */ + +public class DiscretizedSiteRateModel extends AbstractModel implements SiteRateModel { + + /** + * Constructor for gamma+invar distributed sites. Either shapeParameter or + * invarParameter (or both) can be null to turn off that feature. + */ + public DiscretizedSiteRateModel( + String name, + Parameter nuParameter, + SiteRateDelegate delegate) { + + super(name); + + this.nuParameter = nuParameter; + if (nuParameter != null) { + addVariable(nuParameter); + nuParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); + } + this.muWeight = muWeight; + + addStatistic(muStatistic); + + this.delegate = delegate; + addModel(delegate); + + categoryRates = new double[delegate.getCategoryCount()]; + categoryProportions = new double[delegate.getCategoryCount()]; + + ratesKnown = false; + } + + /** + * set mu + */ + public void setMu(double mu) { + nuParameter.setParameterValue(0, mu / muWeight); + } + + /** + * @return mu + */ + public final double getMu() { + return nuParameter.getParameterValue(0) * muWeight; + } + + + public void setRelativeRateParameter(Parameter nu) { + this.nuParameter = nu; + } + + // ***************************************************************** + // Interface SiteRateModel + // ***************************************************************** + + public int getCategoryCount() { + return delegate.getCategoryCount(); + } + + public double[] getCategoryRates() { + synchronized (this) { + if (!ratesKnown) { + delegate.getCategories(categoryRates, categoryProportions); + ratesKnown = true; + } + } + + return categoryRates; + } + + public double[] getCategoryProportions() { + synchronized (this) { + if (!ratesKnown) { + delegate.getCategories(categoryRates, categoryProportions); + ratesKnown = true; + } + } + + return categoryProportions; + } + + public double getRateForCategory(int category) { + synchronized (this) { + if (!ratesKnown) { + delegate.getCategories(categoryRates, categoryProportions); + ratesKnown = true; + } + } + + return categoryRates[category]; + } + + public double getProportionForCategory(int category) { + synchronized (this) { + if (!ratesKnown) { + delegate.getCategories(categoryRates, categoryProportions); + ratesKnown = true; + } + } + + return categoryProportions[category]; + } + + // ***************************************************************** + // Interface ModelComponent + // ***************************************************************** + + protected void handleModelChangedEvent(Model model, Object object, int index) { + // delegate has changed so fire model changed event + listenerHelper.fireModelChanged(this, object, index); + } + + protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + if (variable == nuParameter) { + ratesKnown = false; // MAS: I changed this because the rate parameter can affect the categories if the parameter is in siteModel and not clockModel + } else { + throw new RuntimeException("Unknown variable in DiscretizedSiteRateModel.handleVariableChangedEvent"); + } + listenerHelper.fireModelChanged(this, variable, index); + } + + protected void storeState() { + } // no additional state needs storing + + protected void restoreState() { + ratesKnown = false; + } + + protected void acceptState() { + } // no additional state needs accepting + + + private Statistic muStatistic = new Statistic.Abstract() { + + public String getStatisticName() { + return "mu"; + } + + public int getDimension() { + return 1; + } + + public String getDimensionName(int dim) { + return getId(); + } + + public double getStatisticValue(int dim) { + return getMu(); + } + + }; + + + /** + * mutation rate parameter + */ + private Parameter nuParameter; + + private double muWeight; + + private boolean ratesKnown; + + private final double[] categoryRates; + + private final double[] categoryProportions; + + private final SiteRateDelegate delegate; + + // This is here solely to allow the GammaSiteModelParser to pass on the substitution model to the + // HomogenousBranchSubstitutionModel so that the XML will be compatible with older BEAST versions. To be removed + // at some point. + public SubstitutionModel getSubstitutionModel() { + return substitutionModel; + } + + public void setSubstitutionModel(SubstitutionModel substitutionModel) { + this.substitutionModel = substitutionModel; + } + + private SubstitutionModel substitutionModel; + +} \ No newline at end of file diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java new file mode 100644 index 0000000000..11fae733a5 --- /dev/null +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java @@ -0,0 +1,352 @@ +/* + * GammaSiteRateModel.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.siteratemodel; + +import dr.evomodel.substmodel.SubstitutionModel; +import dr.inference.model.*; +import dr.math.GeneralisedGaussLaguerreQuadrature; +import dr.math.distributions.GammaDistribution; +import dr.math.functionEval.GammaFunction; +import dr.util.Author; +import dr.util.Citable; +import dr.util.Citation; + +import java.util.ArrayList; +import java.util.List; + +/** + * GammaSiteModel - A SiteModel that has a gamma distributed rates across sites. + * + * @author Andrew Rambaut + * @version $Id: GammaSiteModel.java,v 1.31 2005/09/26 14:27:38 rambaut Exp $ + */ + +public class GammaSiteRateDelegate extends AbstractModel implements SiteRateDelegate, Citable { + + public static final DiscretizationType DEFAULT_DISCRETIZATION = DiscretizationType.EQUAL; + + public enum DiscretizationType { + EQUAL, + QUADRATURE + }; + + + /** + * Constructor for gamma+invar distributed sites. Either shapeParameter or + * invarParameter (or both) can be null to turn off that feature. + */ + public GammaSiteRateDelegate( + String name, + Parameter shapeParameter, int gammaCategoryCount, + DiscretizationType discretizationType, + Parameter invarParameter) { + + super(name); + + this.shapeParameter = shapeParameter; + if (shapeParameter != null) { + this.categoryCount = gammaCategoryCount; + addVariable(shapeParameter); +// shapeParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 1E-3, 1)); + // removing the bounds on the alpha parameter - to make the prior more explicit + shapeParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); + } else { + this.categoryCount = 1; + } + + this.invarParameter = invarParameter; + if (invarParameter != null) { + this.categoryCount += 1; + + addVariable(invarParameter); + invarParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); + } + + this.discretizationType = discretizationType; + } + + // ***************************************************************** + // Interface SiteRateModel + // ***************************************************************** + + public int getCategoryCount() { + return categoryCount; + } + + public void getCategories(double[] categoryRates, double[] categoryProportions) { + assert categoryRates != null && categoryRates.length == categoryCount; + assert categoryProportions != null && categoryProportions.length == categoryCount; + + int offset = 0; + + if (invarParameter != null) { + categoryRates[0] = 0.0; + categoryProportions[0] = invarParameter.getParameterValue(0); + offset = 1; + } + + if (shapeParameter != null) { + double alpha = shapeParameter.getParameterValue(0); + final int gammaCatCount = categoryCount - offset; + + if (discretizationType == DiscretizationType.QUADRATURE) { + setQuatratureRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); + } else { + setEqualRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); + } + } else if (offset > 0) { + // just the invariant rate and variant rate + categoryRates[offset] = 2.0; + categoryProportions[offset] = 1.0 - categoryProportions[0]; + } else { + categoryRates[0] = 1.0; + categoryProportions[0] = 1.0; + } + } + + // ***************************************************************** + // Interface ModelComponent + // ***************************************************************** + + protected void handleModelChangedEvent(Model model, Object object, int index) { + listenerHelper.fireModelChanged(this, object, index); + } + + protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + listenerHelper.fireModelChanged(this, variable, index); + } + + protected void storeState() { + } // no additional state needs storing + + protected void restoreState() { + } + + protected void acceptState() { + } // no additional state needs accepting + + + /** + * shape parameter + */ + private Parameter shapeParameter; + + /** + * invariant sites parameter + */ + private Parameter invarParameter; + + private DiscretizationType discretizationType; + + private int categoryCount; + + + @Override + public Citation.Category getCategory() { + return Citation.Category.SUBSTITUTION_MODELS; + } + + @Override + public String getDescription() { + return "Discrete gamma-distributed rate heterogeneity model"; + } + + public List getCitations() { + List citations = new ArrayList<>(); + if (shapeParameter != null) { + citations.add(CITATION_YANG94); + if (discretizationType == DiscretizationType.QUADRATURE) { + citations.add(CITATION_FELSENSTEIN01); + } + } + return citations; + } + + public final static Citation CITATION_YANG94 = new Citation( + new Author[]{ + new Author("Z", "Yang") + }, + "Maximum likelihood phylogenetic estimation from DNA sequences with variable rates over sites: approximate methods", + 1994, + "J. Mol. Evol.", + 39, + 306, 314, + Citation.Status.PUBLISHED + ); + + public final static Citation CITATION_FELSENSTEIN01 = new Citation( + new Author[]{ + new Author("J", "Felsenstein") + }, + "Taking Variation of Evolutionary Rates Between Sites into Account in Inferring Phylogenies", + 2001, + "J. Mol. Evol.", + 53, + 447, 455, + Citation.Status.PUBLISHED + ); + + private SubstitutionModel substitutionModel; + + private static GeneralisedGaussLaguerreQuadrature quadrature = null; + + /** + * Set the rates and proportions using a Gauss-Laguerre Quadrature, as proposed by Felsenstein 2001, JME + * + * @param categoryRates + * @param categoryProportions + * @param alpha + * @param catCount + * @param offset + */ + public static void setQuatratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { + if (quadrature == null) { + quadrature = new GeneralisedGaussLaguerreQuadrature(catCount); + } + quadrature.setAlpha(alpha); + + double[] abscissae = quadrature.getAbscissae(); + double[] coefficients = quadrature.getCoefficients(); + + for (int i = 0; i < catCount; i++) { + categoryRates[i + offset] = abscissae[i] / (alpha + 1); + categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha + 1); + } + } + + /** + * set the rates as equally spaced quantiles represented by the mean as proposed by Yang 1994 + * @param categoryRates + * @param categoryProportions + * @param alpha + * @param catCount + * @param offset + */ + public static void setEqualRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { + for (int i = 0; i < catCount; i++) { + categoryRates[i + offset] = GammaDistribution.quantile((2.0 * i + 1.0) / (2.0 * catCount), alpha, 1.0 / alpha); + categoryProportions[i + offset] = 1.0; + } + + normalize(categoryRates, categoryProportions); + } + + /** + * Gives the category rates a mean of 1.0 and the proportions sum to 1.0 + * @param categoryRates + * @param categoryProportions + */ + public static void normalize(double[] categoryRates, double[] categoryProportions) { + double mean = 0.0; + double sum = 0.0; + for (int i = 0; i < categoryRates.length; i++) { + mean += categoryRates[i]; + sum += categoryProportions[i]; + } + mean /= categoryRates.length; + + for(int i = 0; i < categoryRates.length; i++) { + categoryRates[i] /= mean; + categoryProportions[i] /= sum; + } + } + + public static void main(String[] argv) { + final int catCount = 6; + + double[] categoryRates = new double[catCount]; + double[] categoryProportions = new double[catCount]; + + setEqualRates(categoryRates, categoryProportions, 1.0, catCount, 0); + System.out.println(); + System.out.println("Equal, alpha = 1.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setQuatratureRates(categoryRates, categoryProportions, 1.0, catCount, 0); + System.out.println(); + System.out.println("Quadrature, alpha = 1.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + // Table 3 from Felsenstein 2001, JME + // Rates and probabilities chosen by the quadrature method for six rates and coefficient of + // variation of rates among sites 1 (a 4 1) + // Rate Probability + // 0.264 0.278 + // 0.898 0.494 + // 1.938 0.203 + // 3.459 0.025 + // 5.617 0.00076 + // 8.823 0.000003 + + // Output + // Quadrature, alpha = 1.0 + // cat rate proportion + // 0 0.26383406085556455 0.27765014202987454 + // 1 0.8981499048217043 0.49391058305035496 + // 2 1.938320760238456 0.20300429674372977 + // 3 3.459408283352361 0.02466882036918974 + // 4 5.617305214541558 7.6304276746353E-4 + // 5 8.822981776190357 3.1150393875275343E-6 + + setEqualRates(categoryRates, categoryProportions, 0.1, catCount, 0); + System.out.println(); + System.out.println("Equal, alpha = 0.1"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setQuatratureRates(categoryRates, categoryProportions, 0.1, catCount, 0); + System.out.println(); + System.out.println("Quadrature, alpha = 0.1"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setEqualRates(categoryRates, categoryProportions, 10.0, catCount, 0); + System.out.println(); + System.out.println("Equal, alpha = 10.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + + setQuatratureRates(categoryRates, categoryProportions, 10.0, catCount, 0); + System.out.println(); + System.out.println("Quadrature, alpha = 10.0"); + System.out.println("cat\trate\tproportion"); + for (int i = 0; i < catCount; i++) { + System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); + } + } +} \ No newline at end of file diff --git a/src/dr/evomodel/siteratemodel/SiteRateDelegate.java b/src/dr/evomodel/siteratemodel/SiteRateDelegate.java new file mode 100644 index 0000000000..3a35e4f510 --- /dev/null +++ b/src/dr/evomodel/siteratemodel/SiteRateDelegate.java @@ -0,0 +1,13 @@ +package dr.evomodel.siteratemodel; + +import dr.inference.model.Model; + +/** + * @author Andrew Rambaut + * @version $ + */ +public interface SiteRateDelegate extends Model { + int getCategoryCount(); + + void getCategories(double[] categoryRates, double[] categoryProportions); +} diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java index 444d48348e..1617f5971d 100644 --- a/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/FreeRateModelParser.java @@ -26,7 +26,6 @@ package dr.evomodelxml.siteratemodel; import dr.evomodel.siteratemodel.FreeRateModel; -import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.siteratemodel.PdfSiteRateModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.inference.model.Parameter; @@ -40,7 +39,7 @@ public class FreeRateModelParser extends AbstractXMLObjectParser { private static final String SITE_MODEL = "freeRateModel"; private static final String SUBSTITUTION_MODEL = "substitutionModel"; - private static final String RATEDIFFERENCES = "rateDifferences"; + private static final String RATE_DIFFERENCES = "rateDifferences"; private static final String WEIGHTS = "weights"; public String getParserName() { @@ -49,7 +48,7 @@ public String getParserName() { public Object parseXMLObject(XMLObject xo) throws XMLParseException { - Parameter rateDifferences = (Parameter) xo.getElementFirstChild(RATEDIFFERENCES); + Parameter rateDifferences = (Parameter) xo.getElementFirstChild(RATE_DIFFERENCES); Parameter weights = (Parameter) xo.getElementFirstChild(WEIGHTS); PdfSiteRateModel siteRateModel = new PdfSiteRateModel(SITE_MODEL, rateDifferences, weights); @@ -80,7 +79,7 @@ public XMLSyntaxRule[] getSyntaxRules() { } private final XMLSyntaxRule[] rules = { - new ElementRule(RATEDIFFERENCES, new XMLSyntaxRule[]{ + new ElementRule(RATE_DIFFERENCES, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) }), new ElementRule(WEIGHTS, new XMLSyntaxRule[]{ From e8b134e96b423feeb86eeef259d074668d14f1cd Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 30 Jun 2023 15:10:13 +0100 Subject: [PATCH 156/196] Implementing a more generalised rate heterogeneity system --- src/dr/app/beast/release_parsers.properties | 2 +- .../generator/SubstitutionModelGenerator.java | 81 ++++--- .../app/tools/AncestralSequenceAnnotator.java | 6 +- .../DiscretizedSiteRateModel.java | 2 + .../siteratemodel/GammaSiteRateDelegate.java | 2 +- .../DataLikelihoodTester.java | 4 +- .../DataLikelihoodTester2.java | 4 +- .../siteratemodel/GammaSiteRateModel.java | 194 +++++++++++++++++ ...rser.java => OldGammaSiteModelParser.java} | 18 +- .../siteratemodel/SiteModelParser.java | 200 ++++++++++++++++++ .../MultiPartitionDataLikelihoodParser.java | 9 +- .../TreeDataLikelihoodParser.java | 32 +-- .../sitemodel/GammaSiteModelParser.java | 2 +- 13 files changed, 478 insertions(+), 78 deletions(-) create mode 100644 src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java rename src/dr/evomodelxml/siteratemodel/{GammaSiteModelParser.java => OldGammaSiteModelParser.java} (96%) create mode 100644 src/dr/evomodelxml/siteratemodel/SiteModelParser.java diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index ba5a7ceeb3..9977c5d513 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -74,7 +74,7 @@ dr.evomodelxml.substmodel.InfinitesimalRatesLoggerParser dr.evomodelxml.substmodel.LewisMkSubstitutionModelParser # SITE RATE MODELS -dr.evomodelxml.siteratemodel.GammaSiteModelParser +dr.evomodelxml.siteratemodel.SiteModelParser dr.evomodelxml.siteratemodel.PdfSiteModelParser # BRANCH MODELS diff --git a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java index 6f24720a0a..1d2a235ff7 100644 --- a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java +++ b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java @@ -26,7 +26,6 @@ package dr.app.beauti.generator; import dr.app.beauti.options.*; -import dr.evomodel.substmodel.nucleotide.GTR; import dr.evomodel.substmodel.nucleotide.NucModelType; import dr.app.beauti.components.ComponentFactory; import dr.app.beauti.types.FrequencyPolicyType; @@ -34,6 +33,7 @@ import dr.app.beauti.util.XMLWriter; import dr.evolution.alignment.Alignment; import dr.evolution.datatype.DataType; +import dr.evomodelxml.siteratemodel.SiteModelParser; import dr.evomodelxml.substmodel.BinaryCovarionModelParser; import dr.evomodelxml.substmodel.BinarySubstitutionModelParser; import dr.evomodelxml.substmodel.EmpiricalAminoAcidModelParser; @@ -42,7 +42,6 @@ import dr.evomodelxml.substmodel.GeneralSubstitutionModelParser; import dr.evomodelxml.substmodel.HKYParser; import dr.evomodelxml.substmodel.TN93Parser; -import dr.evomodelxml.siteratemodel.GammaSiteModelParser; import dr.inference.model.StatisticParser; import dr.oldevomodel.substmodel.AsymmetricQuadraticModel; import dr.oldevomodel.substmodel.LinearBiasModel; @@ -628,11 +627,11 @@ private void writeNucSiteModel(int num, XMLWriter writer, PartitionSubstitutionM String prefix2 = model.getPrefix(); writer.writeComment("site model"); - writer.writeOpenTag(GammaSiteModelParser.SITE_MODEL, - new Attribute[]{new Attribute.Default(XMLParser.ID, prefix + GammaSiteModelParser.SITE_MODEL)}); + writer.writeOpenTag(SiteModelParser.SITE_MODEL, + new Attribute[]{new Attribute.Default(XMLParser.ID, prefix + SiteModelParser.SITE_MODEL)}); - writer.writeOpenTag(GammaSiteModelParser.SUBSTITUTION_MODEL); + writer.writeOpenTag(SiteModelParser.SUBSTITUTION_MODEL); if (model.isUnlinkedSubstitutionModel()) { switch (model.getNucSubstitutionModel()) { @@ -674,7 +673,7 @@ private void writeNucSiteModel(int num, XMLWriter writer, PartitionSubstitutionM } } - writer.writeCloseTag(GammaSiteModelParser.SUBSTITUTION_MODEL); + writer.writeCloseTag(SiteModelParser.SUBSTITUTION_MODEL); if (options.useNuRelativeRates()) { Parameter parameter; @@ -696,10 +695,10 @@ private void writeNucSiteModel(int num, XMLWriter writer, PartitionSubstitutionM if (model.isGammaHetero()) { - writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, + writer.writeOpenTag(SiteModelParser.GAMMA_SHAPE, new Attribute[] { - new Attribute.Default<>(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), - new Attribute.Default<>(GammaSiteModelParser.DISCRETIZATION, + new Attribute.Default<>(SiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), + new Attribute.Default<>(SiteModelParser.DISCRETIZATION, (model.isGammaHeteroEqualWeights() ? "equal" : "quadrature")), }); @@ -715,11 +714,11 @@ private void writeNucSiteModel(int num, XMLWriter writer, PartitionSubstitutionM writer.writeIDref(ParameterParser.PARAMETER, prefix2 + "alpha"); } } - writer.writeCloseTag(GammaSiteModelParser.GAMMA_SHAPE); + writer.writeCloseTag(SiteModelParser.GAMMA_SHAPE); } if (model.isInvarHetero()) { - writer.writeOpenTag(GammaSiteModelParser.PROPORTION_INVARIANT); + writer.writeOpenTag(SiteModelParser.PROPORTION_INVARIANT); if (num == -1 || model.isUnlinkedHeterogeneityModel()) { // writeParameter(prefix + "pInv", model, writer); writeParameter(num, "pInv", model, writer); @@ -732,13 +731,13 @@ private void writeNucSiteModel(int num, XMLWriter writer, PartitionSubstitutionM writer.writeIDref(ParameterParser.PARAMETER, prefix2 + "pInv"); } } - writer.writeCloseTag(GammaSiteModelParser.PROPORTION_INVARIANT); + writer.writeCloseTag(SiteModelParser.PROPORTION_INVARIANT); } - writer.writeCloseTag(GammaSiteModelParser.SITE_MODEL); + writer.writeCloseTag(SiteModelParser.SITE_MODEL); if (options.useNuRelativeRates()) { - writeMuStatistic(writer, prefix, GammaSiteModelParser.SITE_MODEL); + writeMuStatistic(writer, prefix, SiteModelParser.SITE_MODEL); } } @@ -754,11 +753,11 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel String prefix = model.getPrefix(); writer.writeComment("site model"); - writer.writeOpenTag(GammaSiteModelParser.SITE_MODEL, - new Attribute[]{new Attribute.Default(XMLParser.ID, prefix + GammaSiteModelParser.SITE_MODEL)}); + writer.writeOpenTag(SiteModelParser.SITE_MODEL, + new Attribute[]{new Attribute.Default(XMLParser.ID, prefix + SiteModelParser.SITE_MODEL)}); - writer.writeOpenTag(GammaSiteModelParser.SUBSTITUTION_MODEL); + writer.writeOpenTag(SiteModelParser.SUBSTITUTION_MODEL); switch (model.getBinarySubstitutionModel()) { case BIN_SIMPLE: @@ -774,7 +773,7 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel throw new IllegalArgumentException("Unknown substitution model."); } - writer.writeCloseTag(GammaSiteModelParser.SUBSTITUTION_MODEL); + writer.writeCloseTag(SiteModelParser.SUBSTITUTION_MODEL); if (options.useNuRelativeRates()) { Parameter parameter = model.getParameter("nu"); @@ -783,28 +782,28 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel writeNuRelativeRateBlock(writer, prefix1, parameter); } } else { - writeParameter(GammaSiteModelParser.RELATIVE_RATE, "mu", model, writer); + writeParameter(SiteModelParser.RELATIVE_RATE, "mu", model, writer); } if (model.isGammaHetero()) { - writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, + writer.writeOpenTag(SiteModelParser.GAMMA_SHAPE, new Attribute[] { - new Attribute.Default<>(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), - new Attribute.Default<>(GammaSiteModelParser.DISCRETIZATION, + new Attribute.Default<>(SiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), + new Attribute.Default<>(SiteModelParser.DISCRETIZATION, (model.isGammaHeteroEqualWeights() ? "equal" : "quadrature")), }); writeParameter(prefix + "alpha", model, writer); - writer.writeCloseTag(GammaSiteModelParser.GAMMA_SHAPE); + writer.writeCloseTag(SiteModelParser.GAMMA_SHAPE); } if (model.isInvarHetero()) { - writeParameter(GammaSiteModelParser.PROPORTION_INVARIANT, "pInv", model, writer); + writeParameter(SiteModelParser.PROPORTION_INVARIANT, "pInv", model, writer); } - writer.writeCloseTag(GammaSiteModelParser.SITE_MODEL); + writer.writeCloseTag(SiteModelParser.SITE_MODEL); if (options.useNuRelativeRates()) { - writeMuStatistic(writer, prefix, GammaSiteModelParser.SITE_MODEL); + writeMuStatistic(writer, prefix, SiteModelParser.SITE_MODEL); } } @@ -820,13 +819,13 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model String prefix = model.getPrefix(); writer.writeComment("site model"); - writer.writeOpenTag(GammaSiteModelParser.SITE_MODEL, new Attribute[]{ - new Attribute.Default(XMLParser.ID, prefix + GammaSiteModelParser.SITE_MODEL)}); + writer.writeOpenTag(SiteModelParser.SITE_MODEL, new Attribute[]{ + new Attribute.Default(XMLParser.ID, prefix + SiteModelParser.SITE_MODEL)}); - writer.writeOpenTag(GammaSiteModelParser.SUBSTITUTION_MODEL); + writer.writeOpenTag(SiteModelParser.SUBSTITUTION_MODEL); writer.writeIDref(EmpiricalAminoAcidModelParser.EMPIRICAL_AMINO_ACID_MODEL, prefix + "aa"); - writer.writeCloseTag(GammaSiteModelParser.SUBSTITUTION_MODEL); + writer.writeCloseTag(SiteModelParser.SUBSTITUTION_MODEL); if (options.useNuRelativeRates()) { Parameter parameter = model.getParameter("nu"); @@ -836,28 +835,28 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model } } else { - writeParameter(GammaSiteModelParser.RELATIVE_RATE, "mu", model, writer); + writeParameter(SiteModelParser.RELATIVE_RATE, "mu", model, writer); } if (model.isGammaHetero()) { - writer.writeOpenTag(GammaSiteModelParser.GAMMA_SHAPE, + writer.writeOpenTag(SiteModelParser.GAMMA_SHAPE, new Attribute[] { - new Attribute.Default<>(GammaSiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), - new Attribute.Default<>(GammaSiteModelParser.DISCRETIZATION, + new Attribute.Default<>(SiteModelParser.GAMMA_CATEGORIES, "" + model.getGammaCategories()), + new Attribute.Default<>(SiteModelParser.DISCRETIZATION, (model.isGammaHeteroEqualWeights() ? "equal" : "quadrature")), }); writeParameter("alpha", model, writer); - writer.writeCloseTag(GammaSiteModelParser.GAMMA_SHAPE); + writer.writeCloseTag(SiteModelParser.GAMMA_SHAPE); } if (model.isInvarHetero()) { - writeParameter(GammaSiteModelParser.PROPORTION_INVARIANT, "pInv", model, writer); + writeParameter(SiteModelParser.PROPORTION_INVARIANT, "pInv", model, writer); } - writer.writeCloseTag(GammaSiteModelParser.SITE_MODEL); + writer.writeCloseTag(SiteModelParser.SITE_MODEL); if (options.useNuRelativeRates()) { - writeMuStatistic(writer, prefix, GammaSiteModelParser.SITE_MODEL); + writeMuStatistic(writer, prefix, SiteModelParser.SITE_MODEL); } @@ -870,12 +869,12 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model */ private void writeNuRelativeRateBlock(XMLWriter writer, String prefix, Parameter parameter) { double weight = ((double) parameter.getParent().getDimensionWeight()) / parameter.getDimensionWeight(); - writer.writeOpenTag(GammaSiteModelParser.RELATIVE_RATE, - new Attribute.Default(GammaSiteModelParser.WEIGHT, "" + weight)); + writer.writeOpenTag(SiteModelParser.RELATIVE_RATE, + new Attribute.Default(SiteModelParser.WEIGHT, "" + weight)); // Initial values must sum to 1.0 double initial = 1.0 / parameter.getParent().getSubParameters().size(); writeParameter(prefix + "nu", 1, initial, 0.0, 1.0, writer); - writer.writeCloseTag(GammaSiteModelParser.RELATIVE_RATE); + writer.writeCloseTag(SiteModelParser.RELATIVE_RATE); } /** diff --git a/src/dr/app/tools/AncestralSequenceAnnotator.java b/src/dr/app/tools/AncestralSequenceAnnotator.java index f0697b5366..2d7c6f33c2 100644 --- a/src/dr/app/tools/AncestralSequenceAnnotator.java +++ b/src/dr/app/tools/AncestralSequenceAnnotator.java @@ -28,7 +28,7 @@ import dr.evomodel.branchmodel.HomogeneousBranchModel; import dr.evomodel.tree.DefaultTreeModel; -import dr.evomodelxml.siteratemodel.GammaSiteModelParser; +import dr.evomodelxml.siteratemodel.SiteModelParser; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.GeneralSubstitutionModel; import dr.evomodel.substmodel.aminoacid.EmpiricalAminoAcidModel; @@ -808,7 +808,7 @@ else if(siteRatesModels.indexOf("+GAMMA(") >= 0) { /* For BEAST output */ //System.out.println("alpha and pinv parameters: " + alphaParameter.getParameterValue(0) + "\t" + pInvParameter.getParameterValue(0)); //GammaSiteRateModel siteModel = new GammaSiteRateModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), alphaParameter, categories, pInvParameter); - GammaSiteRateModel siteModel = new GammaSiteRateModel(GammaSiteModelParser.SITE_MODEL, new Parameter.Default(1.0), 1.0, alphaParameter, categories, GammaSiteRateModel.DiscretizationType.EQUAL, pInvParameter); + GammaSiteRateModel siteModel = new GammaSiteRateModel(SiteModelParser.SITE_MODEL, new Parameter.Default(1.0), 1.0, alphaParameter, categories, GammaSiteRateModel.DiscretizationType.EQUAL, pInvParameter); siteModel.setSubstitutionModel(sml.getSubstitutionModel()); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), new Parameter.Default(1.0), 1, new Parameter.Default(0.5)); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), null, null, 0, null); @@ -817,7 +817,7 @@ else if(siteRatesModels.indexOf("+GAMMA(") >= 0) { /* For BEAST output */ /* Default with no gamma or pinv */ //SiteRateModel siteModel = new GammaSiteRateModel(sml.getSubstitutionModel()); - GammaSiteRateModel siteModel = new GammaSiteRateModel(GammaSiteModelParser.SITE_MODEL); + GammaSiteRateModel siteModel = new GammaSiteRateModel(SiteModelParser.SITE_MODEL); siteModel.setSubstitutionModel(sml.getSubstitutionModel()); return siteModel; diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index 626eefcc57..e3d3d77605 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -37,6 +37,7 @@ public class DiscretizedSiteRateModel extends AbstractModel implements SiteRateModel { + /** * Constructor for gamma+invar distributed sites. Either shapeParameter or * invarParameter (or both) can be null to turn off that feature. @@ -44,6 +45,7 @@ public class DiscretizedSiteRateModel extends AbstractModel implements SiteRateM public DiscretizedSiteRateModel( String name, Parameter nuParameter, + double muWeight, SiteRateDelegate delegate) { super(name); diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java index 11fae733a5..0e4a56f252 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java @@ -46,7 +46,7 @@ public class GammaSiteRateDelegate extends AbstractModel implements SiteRateDelegate, Citable { - public static final DiscretizationType DEFAULT_DISCRETIZATION = DiscretizationType.EQUAL; + public static final DiscretizationType DEFAULT_DISCRETIZATION = DiscretizationType.QUADRATURE; public enum DiscretizationType { EQUAL, diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java index 1dc72c94de..f39d6aaa56 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java @@ -30,7 +30,7 @@ import dr.evomodel.branchmodel.HomogeneousBranchModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.evomodel.tree.DefaultTreeModel; -import dr.evomodelxml.siteratemodel.GammaSiteModelParser; +import dr.evomodelxml.siteratemodel.SiteModelParser; import dr.evomodelxml.substmodel.HKYParser; import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.siteratemodel.SiteRateModel; @@ -90,7 +90,7 @@ public static void main(String[] args) { GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); - Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); + Parameter mu = new Parameter.Default(SiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); FrequencyModel f2 = new FrequencyModel(Nucleotides.INSTANCE, freqs); diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java index 3625650d56..1f77087508 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java @@ -49,7 +49,7 @@ import dr.evomodel.tree.TreeModel; import dr.evomodel.treelikelihood.BeagleTreeLikelihood; import dr.evomodel.treelikelihood.PartialsRescalingScheme; -import dr.evomodelxml.siteratemodel.GammaSiteModelParser; +import dr.evomodelxml.siteratemodel.SiteModelParser; import dr.evomodelxml.substmodel.HKYParser; import dr.inference.model.Parameter; @@ -90,7 +90,7 @@ public static void main(String[] args) { GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); // GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); siteRateModel.setSubstitutionModel(hky); - Parameter mu = new Parameter.Default(GammaSiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); + Parameter mu = new Parameter.Default(SiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); FrequencyModel f2 = new FrequencyModel(Nucleotides.INSTANCE, freqs); diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java new file mode 100644 index 0000000000..ed29c87c4a --- /dev/null +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java @@ -0,0 +1,194 @@ +/* + * SiteModelParser.java + * + * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodelxml.siteratemodel; + +import dr.evomodel.siteratemodel.DiscretizedSiteRateModel; +import dr.evomodel.siteratemodel.GammaSiteRateDelegate; +import dr.evomodel.substmodel.SubstitutionModel; +import dr.inference.model.Parameter; +import dr.oldevomodel.sitemodel.SiteModel; +import dr.xml.*; + +import java.util.logging.Logger; + +/** + * This is a replacement to GammaSiteModelParser that uses the modular + * DiscretizedSiteRateModel with a Gamma delegate. + * @author Andrew Rambaut + * @version $Id$ + */ +public class GammaSiteRateModel extends AbstractXMLObjectParser { + + public static final String GAMMA_SITE_RATE_MODEL = "GammaSiteRateModel"; + public static final String SUBSTITUTION_MODEL = "substitutionModel"; + public static final String MUTATION_RATE = "mutationRate"; + public static final String SUBSTITUTION_RATE = "substitutionRate"; + public static final String RELATIVE_RATE = "relativeRate"; + public static final String WEIGHT = "weight"; + public static final String GAMMA_SHAPE = "gammaShape"; + public static final String GAMMA_CATEGORIES = "gammaCategories"; + public static final String PROPORTION_INVARIANT = "proportionInvariant"; + public static final String DISCRETIZATION = "discretization"; + + public String getParserName() { + return GAMMA_SITE_RATE_MODEL; + } + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + String msg = ""; + SubstitutionModel substitutionModel; + + double muWeight = 1.0; + + Parameter muParam = null; + if (xo.hasChildNamed(SUBSTITUTION_RATE)) { + muParam = (Parameter) xo.getElementFirstChild(SUBSTITUTION_RATE); + + msg += "\n with initial substitution rate = " + muParam.getParameterValue(0); + } else if (xo.hasChildNamed(MUTATION_RATE)) { + muParam = (Parameter) xo.getElementFirstChild(MUTATION_RATE); + + msg += "\n with initial substitution rate = " + muParam.getParameterValue(0); + } else if (xo.hasChildNamed(RELATIVE_RATE)) { + XMLObject cxo = xo.getChild(RELATIVE_RATE); + muParam = (Parameter) cxo.getChild(Parameter.class); + msg += "\n with initial relative rate = " + muParam.getParameterValue(0); + if (cxo.hasAttribute(WEIGHT)) { + muWeight = cxo.getDoubleAttribute(WEIGHT); + msg += " with weight: " + muWeight; + } + } + + GammaSiteRateDelegate.DiscretizationType type = GammaSiteRateDelegate.DEFAULT_DISCRETIZATION; + + Parameter shapeParam = null; + int catCount = 4; + if (xo.hasChildNamed(GAMMA_SHAPE)) { + XMLObject cxo = xo.getChild(GAMMA_SHAPE); + catCount = cxo.getIntegerAttribute(GAMMA_CATEGORIES); + + if ( cxo.hasAttribute(DISCRETIZATION)) { + try { + type = GammaSiteRateDelegate.DiscretizationType.valueOf( + cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); + } catch (IllegalArgumentException eae) { + throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); + } + } + shapeParam = (Parameter) cxo.getChild(Parameter.class); + + msg += "\n " + catCount + " category discrete gamma with initial shape = " + shapeParam.getParameterValue(0); + if (type == GammaSiteRateDelegate.DiscretizationType.EQUAL) { + msg += "\n using equal weight discretization of gamma distribution"; + } else { + msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution (Felsenstein, 2012)"; + } + } + + Parameter invarParam = null; + if (xo.hasChildNamed(PROPORTION_INVARIANT)) { + invarParam = (Parameter) xo.getElementFirstChild(PROPORTION_INVARIANT); + msg += "\n initial proportion of invariant sites = " + invarParam.getParameterValue(0); + } + + if (msg.length() > 0) { + Logger.getLogger("dr.evomodel").info("\nCreating site rate model: " + msg); + } else { + Logger.getLogger("dr.evomodel").info("\nCreating site rate model."); + } + + GammaSiteRateDelegate delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); + + DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + + if (xo.hasChildNamed(SUBSTITUTION_MODEL)) { + +// System.err.println("Doing the substitution model stuff"); + + // set this to pass it along to the OldTreeLikelihoodParser... + substitutionModel = (SubstitutionModel) xo.getElementFirstChild(SUBSTITUTION_MODEL); + siteRateModel.setSubstitutionModel(substitutionModel); + + } + + return siteRateModel; + } + + //************************************************************************ + // AbstractXMLObjectParser implementation + //************************************************************************ + + public String getParserDescription() { + return "A DiscretizedSiteRateModel that has a gamma distributed rates across sites"; + } + + @Override + public String[] getParserNames() { + return super.getParserNames(); + } + + public Class getReturnType() { + return DiscretizedSiteRateModel.class; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + + new ElementRule(SUBSTITUTION_MODEL, new XMLSyntaxRule[]{ + new ElementRule(SubstitutionModel.class) + }, true), + + new XORRule( + new XORRule( + new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + new ElementRule(MUTATION_RATE, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }) + ), + new ElementRule(RELATIVE_RATE, new XMLSyntaxRule[]{ + AttributeRule.newDoubleRule(WEIGHT, true), + new ElementRule(Parameter.class) + }), true + ), + + new ElementRule(GAMMA_SHAPE, new XMLSyntaxRule[]{ + AttributeRule.newIntegerRule(GAMMA_CATEGORIES, true), + AttributeRule.newStringRule(DISCRETIZATION, true), + new ElementRule(Parameter.class) + }, true), + + new ElementRule(PROPORTION_INVARIANT, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }, true) + }; + +}//END: class diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/OldGammaSiteModelParser.java similarity index 96% rename from src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java rename to src/dr/evomodelxml/siteratemodel/OldGammaSiteModelParser.java index 2b677dd81b..3bae49a126 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/OldGammaSiteModelParser.java @@ -1,5 +1,5 @@ /* - * GammaSiteModelParser.java + * OldGammaSiteModelParser.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * @@ -25,26 +25,20 @@ package dr.evomodelxml.siteratemodel; -import java.util.logging.Logger; - import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.substmodel.SubstitutionModel; -import dr.oldevomodel.sitemodel.SiteModel; import dr.inference.model.Parameter; -import dr.xml.AbstractXMLObjectParser; -import dr.xml.AttributeRule; -import dr.xml.ElementRule; -import dr.xml.XMLObject; -import dr.xml.XMLParseException; -import dr.xml.XMLSyntaxRule; -import dr.xml.XORRule; +import dr.oldevomodel.sitemodel.SiteModel; +import dr.xml.*; + +import java.util.logging.Logger; /** * @author Andrew Rambaut * @author Alexei Drummond * @version $Id$ */ -public class GammaSiteModelParser extends AbstractXMLObjectParser { +public class OldGammaSiteModelParser extends AbstractXMLObjectParser { public static final String SITE_MODEL = SiteModel.SITE_MODEL; public static final String SUBSTITUTION_MODEL = "substitutionModel"; diff --git a/src/dr/evomodelxml/siteratemodel/SiteModelParser.java b/src/dr/evomodelxml/siteratemodel/SiteModelParser.java new file mode 100644 index 0000000000..685ed31062 --- /dev/null +++ b/src/dr/evomodelxml/siteratemodel/SiteModelParser.java @@ -0,0 +1,200 @@ +/* + * SiteModelParser.java + * + * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodelxml.siteratemodel; + +import java.util.logging.Logger; + +import dr.evomodel.siteratemodel.DiscretizedSiteRateModel; +import dr.evomodel.siteratemodel.GammaSiteRateDelegate; +import dr.evomodel.substmodel.SubstitutionModel; +import dr.oldevomodel.sitemodel.SiteModel; +import dr.inference.model.Parameter; +import dr.xml.AbstractXMLObjectParser; +import dr.xml.AttributeRule; +import dr.xml.ElementRule; +import dr.xml.XMLObject; +import dr.xml.XMLParseException; +import dr.xml.XMLSyntaxRule; +import dr.xml.XORRule; + +/** + * This is a replacement to GammaSiteModelParser to keep old XML that used + * the element working. + * @author Andrew Rambaut + * @version $Id$ + */ +public class SiteModelParser extends AbstractXMLObjectParser { + + public static final String SITE_MODEL = "SiteModel"; + public static final String SUBSTITUTION_MODEL = "substitutionModel"; + public static final String MUTATION_RATE = "mutationRate"; + public static final String SUBSTITUTION_RATE = "substitutionRate"; + public static final String RELATIVE_RATE = "relativeRate"; + public static final String WEIGHT = "weight"; + public static final String GAMMA_SHAPE = "gammaShape"; + public static final String GAMMA_CATEGORIES = "gammaCategories"; + public static final String PROPORTION_INVARIANT = "proportionInvariant"; + public static final String DISCRETIZATION = "discretization"; + + public String getParserName() { + return SITE_MODEL; + } + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + String msg = ""; + SubstitutionModel substitutionModel; + + double muWeight = 1.0; + + Parameter muParam = null; + if (xo.hasChildNamed(SUBSTITUTION_RATE)) { + muParam = (Parameter) xo.getElementFirstChild(SUBSTITUTION_RATE); + + msg += "\n with initial substitution rate = " + muParam.getParameterValue(0); + } else if (xo.hasChildNamed(MUTATION_RATE)) { + muParam = (Parameter) xo.getElementFirstChild(MUTATION_RATE); + + msg += "\n with initial substitution rate = " + muParam.getParameterValue(0); + } else if (xo.hasChildNamed(RELATIVE_RATE)) { + XMLObject cxo = xo.getChild(RELATIVE_RATE); + muParam = (Parameter) cxo.getChild(Parameter.class); + msg += "\n with initial relative rate = " + muParam.getParameterValue(0); + if (cxo.hasAttribute(WEIGHT)) { + muWeight = cxo.getDoubleAttribute(WEIGHT); + msg += " with weight: " + muWeight; + } + } + + GammaSiteRateDelegate.DiscretizationType type = GammaSiteRateDelegate.DEFAULT_DISCRETIZATION; + + Parameter shapeParam = null; + int catCount = 4; + if (xo.hasChildNamed(GAMMA_SHAPE)) { + XMLObject cxo = xo.getChild(GAMMA_SHAPE); + catCount = cxo.getIntegerAttribute(GAMMA_CATEGORIES); + + if ( cxo.hasAttribute(DISCRETIZATION)) { + try { + type = GammaSiteRateDelegate.DiscretizationType.valueOf( + cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); + } catch (IllegalArgumentException eae) { + throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); + } + } + shapeParam = (Parameter) cxo.getChild(Parameter.class); + + msg += "\n " + catCount + " category discrete gamma with initial shape = " + shapeParam.getParameterValue(0); + if (type == GammaSiteRateDelegate.DiscretizationType.EQUAL) { + msg += "\n using equal weight discretization of gamma distribution"; + } else { + msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution (Felsenstein, 2012)"; + } + } + + Parameter invarParam = null; + if (xo.hasChildNamed(PROPORTION_INVARIANT)) { + invarParam = (Parameter) xo.getElementFirstChild(PROPORTION_INVARIANT); + msg += "\n initial proportion of invariant sites = " + invarParam.getParameterValue(0); + } + + if (msg.length() > 0) { + Logger.getLogger("dr.evomodel").info("\nCreating site rate model: " + msg); + } else { + Logger.getLogger("dr.evomodel").info("\nCreating site rate model."); + } + + GammaSiteRateDelegate delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); + + DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + + if (xo.hasChildNamed(SUBSTITUTION_MODEL)) { + +// System.err.println("Doing the substitution model stuff"); + + // set this to pass it along to the OldTreeLikelihoodParser... + substitutionModel = (SubstitutionModel) xo.getElementFirstChild(SUBSTITUTION_MODEL); + siteRateModel.setSubstitutionModel(substitutionModel); + + } + + return siteRateModel; + } + + //************************************************************************ + // AbstractXMLObjectParser implementation + //************************************************************************ + + public String getParserDescription() { + return "A DiscretizedSiteRateModel that has a gamma distributed rates across sites"; + } + + @Override + public String[] getParserNames() { + return super.getParserNames(); + } + + public Class getReturnType() { + return DiscretizedSiteRateModel.class; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + + new ElementRule(SUBSTITUTION_MODEL, new XMLSyntaxRule[]{ + new ElementRule(SubstitutionModel.class) + }, true), + + new XORRule( + new XORRule( + new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + new ElementRule(MUTATION_RATE, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }) + ), + new ElementRule(RELATIVE_RATE, new XMLSyntaxRule[]{ + AttributeRule.newDoubleRule(WEIGHT, true), + new ElementRule(Parameter.class) + }), true + ), + + new ElementRule(GAMMA_SHAPE, new XMLSyntaxRule[]{ + AttributeRule.newIntegerRule(GAMMA_CATEGORIES, true), + AttributeRule.newStringRule(DISCRETIZATION, true), + new ElementRule(Parameter.class) + }, true), + + new ElementRule(PROPORTION_INVARIANT, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }, true) + }; + +}//END: class diff --git a/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java index 9080ce86cc..03c8f8fdaf 100644 --- a/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java @@ -141,14 +141,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (DEBUG) { System.out.println("branchModels == null"); } - branchModels = new ArrayList(); + branchModels = new ArrayList<>(); List substitutionModels = xo.getAllChildren(SubstitutionModel.class); if (substitutionModels.size() == 0) { + // no explicitly defined BranchModels so create one if (DEBUG) { System.out.println("substitutionModels == null"); } for (SiteRateModel siteRateModel : siteRateModels) { - SubstitutionModel substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); + SubstitutionModel substitutionModel = null; + if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { + // for backwards compatibility the old GammaSiteRateModel can provide the substitution model... + substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); + } if (substitutionModel == null) { throw new XMLParseException("No substitution model available for TreeDataLikelihood: "+xo.getId()); } diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 53f51ee387..fc0ed446cf 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -43,6 +43,7 @@ import dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treelikelihood.PartialsRescalingScheme; +import dr.evomodelxml.siteratemodel.OldGammaSiteModelParser; import dr.inference.model.CompoundLikelihood; import dr.inference.model.Likelihood; import dr.xml.*; @@ -110,7 +111,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, if (patternLists.size() > 1) { // will currently recommend true if using GPU, CUDA or OpenCL. useBeagle3MultiPartition = MultiPartitionDataLikelihoodDelegate.IS_MULTI_PARTITION_RECOMMENDED(); - + if (System.getProperty("USE_BEAGLE3_EXTENSIONS") != null) { useBeagle3MultiPartition = Boolean.parseBoolean(System.getProperty("USE_BEAGLE3_EXTENSIONS")); } @@ -154,7 +155,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, useAmbiguities, scalingScheme, delayRescalingUntilUnderflow - ); + ); return new TreeDataLikelihood( dataLikelihoodDelegate, @@ -164,7 +165,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, useBeagle3MultiPartition = false; } - } + } // The multipartition data likelihood isn't available so make a set of single partition data likelihoods List treeDataLikelihoods = new ArrayList(); @@ -200,7 +201,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } return new CompoundLikelihood(treeDataLikelihoods); - + } public Object parseXMLObject(XMLObject xo) throws XMLParseException { @@ -236,7 +237,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { hasSinglePartition = true; patternLists.add(patternList); - GammaSiteRateModel siteRateModel = (GammaSiteRateModel) xo.getChild(GammaSiteRateModel.class); + SiteRateModel siteRateModel = (SiteRateModel) xo.getChild(SiteRateModel.class); siteRateModels.add(siteRateModel); FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class); @@ -244,8 +245,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { BranchModel branchModel = (BranchModel) xo.getChild(BranchModel.class); if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) xo.getChild(SubstitutionModel.class); - if (substitutionModel == null) { - substitutionModel = siteRateModel.getSubstitutionModel(); + if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { + // for backwards compatibility the old GammaSiteRateModel can provide the substitution model... + substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); } if (substitutionModel == null) { throw new XMLParseException("No substitution model available for partition in DataTreeLikelihood: "+xo.getId()); @@ -276,6 +278,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) cxo.getChild(SubstitutionModel.class); if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { + // for backwards compatibility the old GammaSiteRateModel can provide the substitution model... substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); } if (substitutionModel == null) { @@ -355,14 +358,17 @@ public Class getReturnType() { AttributeRule.newStringRule(SCALING_SCHEME,true), // really it should be this set of elements or the PARTITION elements - new OrRule(new AndRule(new XMLSyntaxRule[]{ - new ElementRule(PatternList.class, true), - new ElementRule(SiteRateModel.class, true), - new ElementRule(FrequencyModel.class, true), - new ElementRule(BranchModel.class, true)}) - , + new OrRule( + new AndRule(new XMLSyntaxRule[]{ + new ElementRule(PatternList.class, true), + new ElementRule(SubstitutionModel.class, true), + new ElementRule(SiteRateModel.class, true), + new ElementRule(FrequencyModel.class, true), + new ElementRule(BranchModel.class, true)} + ), new ElementRule(PARTITION, new XMLSyntaxRule[] { new ElementRule(PatternList.class), + new ElementRule(SubstitutionModel.class, true), new ElementRule(SiteRateModel.class), new ElementRule(FrequencyModel.class, true), new ElementRule(BranchModel.class, true) diff --git a/src/dr/oldevomodelxml/sitemodel/GammaSiteModelParser.java b/src/dr/oldevomodelxml/sitemodel/GammaSiteModelParser.java index bdad5d35a3..6d1719b695 100644 --- a/src/dr/oldevomodelxml/sitemodel/GammaSiteModelParser.java +++ b/src/dr/oldevomodelxml/sitemodel/GammaSiteModelParser.java @@ -1,5 +1,5 @@ /* - * GammaSiteModelParser.java + * SiteModelParser.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * From c4a903b379b4ab3880c68c8b383d1409cc10ec25 Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 30 Jun 2023 17:54:50 +0100 Subject: [PATCH 157/196] Not quite there yet - --- .../tools/CompleteHistorySimulator.java | 4 +-- src/dr/app/beast/release_parsers.properties | 4 ++- src/dr/app/bss/PartitionData.java | 4 +-- src/dr/app/bss/XMLGenerator.java | 2 +- .../app/tools/AncestralSequenceAnnotator.java | 4 +-- .../BeagleBranchLikelihood.java | 4 +-- .../DiscretizedSiteRateModel.java | 34 +++++++++++-------- .../siteratemodel/GammaSiteRateDelegate.java | 2 +- .../siteratemodel/GammaSiteRateModel.java | 4 +-- .../DataLikelihoodTester.java | 2 +- .../DataLikelihoodTester2.java | 2 +- .../treelikelihood/BeagleTreeLikelihood.java | 4 +-- ...del.java => GammaSiteRateModelParser.java} | 22 ++---------- .../MultiPartitionDataLikelihoodParser.java | 2 +- .../TreeDataLikelihoodParser.java | 5 +-- 15 files changed, 45 insertions(+), 54 deletions(-) rename src/dr/evomodelxml/siteratemodel/{GammaSiteRateModel.java => GammaSiteRateModelParser.java} (89%) diff --git a/src/dr/app/beagle/tools/CompleteHistorySimulator.java b/src/dr/app/beagle/tools/CompleteHistorySimulator.java index 3e73d92b59..7970bad832 100644 --- a/src/dr/app/beagle/tools/CompleteHistorySimulator.java +++ b/src/dr/app/beagle/tools/CompleteHistorySimulator.java @@ -114,12 +114,12 @@ public class CompleteHistorySimulator extends SimpleAlignment * @param branchRateModel * @param nReplications: nr of samples to generate */ -// public CompleteHistorySimulator(Tree tree, GammaSiteRateModel siteModel, BranchRateModel branchRateModel, +// public CompleteHistorySimulator(Tree tree, GammaSiteRateModelParser siteModel, BranchRateModel branchRateModel, // int nReplications) { // this(tree, siteModel, branchRateModel, nReplications, false); // } // -// public CompleteHistorySimulator(Tree tree, GammaSiteRateModel siteModel, BranchRateModel branchRateModel, +// public CompleteHistorySimulator(Tree tree, GammaSiteRateModelParser siteModel, BranchRateModel branchRateModel, // int nReplications, boolean sumAcrossSites) { // this(tree, siteModel, branchRateModel, nReplications, sumAcrossSites, null, null); // diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index 9977c5d513..7d34c3fd64 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -74,7 +74,9 @@ dr.evomodelxml.substmodel.InfinitesimalRatesLoggerParser dr.evomodelxml.substmodel.LewisMkSubstitutionModelParser # SITE RATE MODELS -dr.evomodelxml.siteratemodel.SiteModelParser +#dr.evomodelxml.siteratemodel.SiteModelParser +dr.evomodelxml.siteratemodel.OldGammaSiteModelParser +dr.evomodelxml.siteratemodel.GammaSiteRateModelParser dr.evomodelxml.siteratemodel.PdfSiteModelParser # BRANCH MODELS diff --git a/src/dr/app/bss/PartitionData.java b/src/dr/app/bss/PartitionData.java index 6389ed55cc..8469a69658 100644 --- a/src/dr/app/bss/PartitionData.java +++ b/src/dr/app/bss/PartitionData.java @@ -1058,7 +1058,7 @@ public void resetSiteRateModelIdref() { }; public static int[][] siteRateModelParameterIndices = { {}, // NoModel - { 0, 1, 2 }, // GammaSiteRateModel + { 0, 1, 2 }, // GammaSiteRateModelParser }; public double[] siteRateModelParameterValues = new double[] { 4.0, // GammaCategories @@ -1075,7 +1075,7 @@ public GammaSiteRateModel createSiteRateModel() { siteModel = new GammaSiteRateModel(name); - } else if (this.siteRateModelIndex == 1) { // GammaSiteRateModel + } else if (this.siteRateModelIndex == 1) { // GammaSiteRateModelParser siteModel = new GammaSiteRateModel(name, siteRateModelParameterValues[1], diff --git a/src/dr/app/bss/XMLGenerator.java b/src/dr/app/bss/XMLGenerator.java index 1faaf7da08..7bc0cc24dd 100644 --- a/src/dr/app/bss/XMLGenerator.java +++ b/src/dr/app/bss/XMLGenerator.java @@ -1123,7 +1123,7 @@ private void writeSiteRateModel(PartitionData data, XMLWriter writer, int suffix break; - case 1: // GammaSiteRateModel + case 1: // GammaSiteRateModelParser writer.writeOpenTag( GammaSiteModelParser.GAMMA_SHAPE, diff --git a/src/dr/app/tools/AncestralSequenceAnnotator.java b/src/dr/app/tools/AncestralSequenceAnnotator.java index 2d7c6f33c2..fafa3a6704 100644 --- a/src/dr/app/tools/AncestralSequenceAnnotator.java +++ b/src/dr/app/tools/AncestralSequenceAnnotator.java @@ -807,7 +807,7 @@ else if(siteRatesModels.indexOf("+GAMMA(") >= 0) { /* For BEAST output */ } //System.out.println("alpha and pinv parameters: " + alphaParameter.getParameterValue(0) + "\t" + pInvParameter.getParameterValue(0)); - //GammaSiteRateModel siteModel = new GammaSiteRateModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), alphaParameter, categories, pInvParameter); + //GammaSiteRateModelParser siteModel = new GammaSiteRateModelParser(sml.getSubstitutionModel(), new Parameter.Default(1.0), alphaParameter, categories, pInvParameter); GammaSiteRateModel siteModel = new GammaSiteRateModel(SiteModelParser.SITE_MODEL, new Parameter.Default(1.0), 1.0, alphaParameter, categories, GammaSiteRateModel.DiscretizationType.EQUAL, pInvParameter); siteModel.setSubstitutionModel(sml.getSubstitutionModel()); //SiteModel siteModel = new GammaSiteModel(sml.getSubstitutionModel(), new Parameter.Default(1.0), new Parameter.Default(1.0), 1, new Parameter.Default(0.5)); @@ -816,7 +816,7 @@ else if(siteRatesModels.indexOf("+GAMMA(") >= 0) { /* For BEAST output */ } /* Default with no gamma or pinv */ - //SiteRateModel siteModel = new GammaSiteRateModel(sml.getSubstitutionModel()); + //SiteRateModel siteModel = new GammaSiteRateModelParser(sml.getSubstitutionModel()); GammaSiteRateModel siteModel = new GammaSiteRateModel(SiteModelParser.SITE_MODEL); siteModel.setSubstitutionModel(sml.getSubstitutionModel()); return siteModel; diff --git a/src/dr/evomodel/branchmodel/lineagespecific/BeagleBranchLikelihood.java b/src/dr/evomodel/branchmodel/lineagespecific/BeagleBranchLikelihood.java index 472ea1c1a9..34cd5ab8f0 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/BeagleBranchLikelihood.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/BeagleBranchLikelihood.java @@ -31,7 +31,7 @@ import beagle.BeagleFactory; import dr.evomodel.branchmodel.BranchModel; //import dr.evomodel.branchmodel.HomogeneousBranchModel; -//import dr.evomodel.siteratemodel.GammaSiteRateModel; +//import dr.evomodel.siteratemodel.GammaSiteRateModelParser; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evomodel.substmodel.FrequencyModel; //import dr.evomodel.substmodel.nucleotide.HKY; @@ -484,7 +484,7 @@ public double getDoubleValue() { // BranchRateModel branchRateModel = new StrictClockBranchRates(rate); // // // create site model -// GammaSiteRateModel siteRateModel = new GammaSiteRateModel( +// GammaSiteRateModelParser siteRateModel = new GammaSiteRateModelParser( // "siteModel"); // // // create partition diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index e3d3d77605..1e5e0eadb9 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -82,11 +82,6 @@ public final double getMu() { return nuParameter.getParameterValue(0) * muWeight; } - - public void setRelativeRateParameter(Parameter nu) { - this.nuParameter = nu; - } - // ***************************************************************** // Interface SiteRateModel // ***************************************************************** @@ -98,8 +93,7 @@ public int getCategoryCount() { public double[] getCategoryRates() { synchronized (this) { if (!ratesKnown) { - delegate.getCategories(categoryRates, categoryProportions); - ratesKnown = true; + calculateCategoryRates(); } } @@ -109,8 +103,7 @@ public double[] getCategoryRates() { public double[] getCategoryProportions() { synchronized (this) { if (!ratesKnown) { - delegate.getCategories(categoryRates, categoryProportions); - ratesKnown = true; + calculateCategoryRates(); } } @@ -120,8 +113,7 @@ public double[] getCategoryProportions() { public double getRateForCategory(int category) { synchronized (this) { if (!ratesKnown) { - delegate.getCategories(categoryRates, categoryProportions); - ratesKnown = true; + calculateCategoryRates(); } } @@ -131,14 +123,26 @@ public double getRateForCategory(int category) { public double getProportionForCategory(int category) { synchronized (this) { if (!ratesKnown) { - delegate.getCategories(categoryRates, categoryProportions); - ratesKnown = true; + calculateCategoryRates(); } } return categoryProportions[category]; } + private void calculateCategoryRates() { + + delegate.getCategories(categoryRates, categoryProportions); + + if (nuParameter != null) { + double mu = getMu(); + for (int i = 0; i < getCategoryCount(); i++) + categoryRates[i] *= mu; + } + + ratesKnown = true; + } + // ***************************************************************** // Interface ModelComponent // ***************************************************************** @@ -192,9 +196,9 @@ public double getStatisticValue(int dim) { /** * mutation rate parameter */ - private Parameter nuParameter; + private final Parameter nuParameter; - private double muWeight; + private final double muWeight; private boolean ratesKnown; diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java index 0e4a56f252..60280cf70a 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java @@ -1,5 +1,5 @@ /* - * GammaSiteRateModel.java + * GammaSiteRateModelParser.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index e5b06f3723..75df1cca45 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -1,5 +1,5 @@ /* - * GammaSiteRateModel.java + * GammaSiteRateModelParser.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * @@ -289,7 +289,7 @@ protected final void handleVariableChangedEvent(Variable variable, int index, Pa } else if (variable == nuParameter) { ratesKnown = false; // MAS: I changed this because the rate parameter can affect the categories if the parameter is in siteModel and not clockModel } else { - throw new RuntimeException("Unknown variable in GammaSiteRateModel.handleVariableChangedEvent"); + throw new RuntimeException("Unknown variable in GammaSiteRateModelParser.handleVariableChangedEvent"); } listenerHelper.fireModelChanged(this, variable, index); } diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java index f39d6aaa56..7b7087dd22 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester.java @@ -88,7 +88,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); -// GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); +// GammaSiteRateModelParser siteRateModel = new GammaSiteRateModelParser("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(SiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); diff --git a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java index 1f77087508..7fb3efc8f6 100644 --- a/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java +++ b/src/dr/evomodel/treedatalikelihood/DataLikelihoodTester2.java @@ -88,7 +88,7 @@ public static void main(String[] args) { //siteModel double alpha = 0.5; GammaSiteRateModel siteRateModel = new GammaSiteRateModel("gammaModel", alpha, 4); -// GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteRateModel"); +// GammaSiteRateModelParser siteRateModel = new GammaSiteRateModelParser("siteRateModel"); siteRateModel.setSubstitutionModel(hky); Parameter mu = new Parameter.Default(SiteModelParser.SUBSTITUTION_RATE, 1.0, 0, Double.POSITIVE_INFINITY); siteRateModel.setRelativeRateParameter(mu); diff --git a/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java b/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java index 15e2afd514..6dd751d9f4 100644 --- a/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java +++ b/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java @@ -37,7 +37,7 @@ import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate; import dr.evomodel.treedatalikelihood.BufferIndexHelper; import dr.evomodelxml.treelikelihood.BeagleTreeLikelihoodParser; -//import dr.evomodel.siteratemodel.GammaSiteRateModel; +//import dr.evomodel.siteratemodel.GammaSiteRateModelParser; import dr.evomodel.siteratemodel.SiteRateModel; //import dr.evomodel.substmodel.FrequencyModel; //import dr.evomodel.substmodel.nucleotide.HKY; @@ -1464,7 +1464,7 @@ private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber, boolean // BranchRateModel branchRateModel = new StrictClockBranchRates(rate); // // // create site model -// GammaSiteRateModel siteRateModel = new GammaSiteRateModel( +// GammaSiteRateModelParser siteRateModel = new GammaSiteRateModelParser( // "siteModel"); // // BranchModel homogeneousBranchModel = new HomogeneousBranchModel(hky1); diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java similarity index 89% rename from src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java rename to src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java index ed29c87c4a..d616ed91cb 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java @@ -40,9 +40,9 @@ * @author Andrew Rambaut * @version $Id$ */ -public class GammaSiteRateModel extends AbstractXMLObjectParser { +public class GammaSiteRateModelParser extends AbstractXMLObjectParser { - public static final String GAMMA_SITE_RATE_MODEL = "GammaSiteRateModel"; + public static final String GAMMA_SITE_RATE_MODEL = "gammaSiteRateModel"; public static final String SUBSTITUTION_MODEL = "substitutionModel"; public static final String MUTATION_RATE = "mutationRate"; public static final String SUBSTITUTION_RATE = "substitutionRate"; @@ -123,19 +123,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { GammaSiteRateDelegate delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); - DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); - - if (xo.hasChildNamed(SUBSTITUTION_MODEL)) { - -// System.err.println("Doing the substitution model stuff"); - - // set this to pass it along to the OldTreeLikelihoodParser... - substitutionModel = (SubstitutionModel) xo.getElementFirstChild(SUBSTITUTION_MODEL); - siteRateModel.setSubstitutionModel(substitutionModel); - - } - - return siteRateModel; + return new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); } //************************************************************************ @@ -161,10 +149,6 @@ public XMLSyntaxRule[] getSyntaxRules() { private final XMLSyntaxRule[] rules = { - new ElementRule(SUBSTITUTION_MODEL, new XMLSyntaxRule[]{ - new ElementRule(SubstitutionModel.class) - }, true), - new XORRule( new XORRule( new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ diff --git a/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java index 03c8f8fdaf..c44c9a4745 100644 --- a/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/MultiPartitionDataLikelihoodParser.java @@ -151,7 +151,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { for (SiteRateModel siteRateModel : siteRateModels) { SubstitutionModel substitutionModel = null; if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { - // for backwards compatibility the old GammaSiteRateModel can provide the substitution model... + // for backwards compatibility the old GammaSiteRateModelParser can provide the substitution model... substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); } if (substitutionModel == null) { diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index fc0ed446cf..d0b5fc0dd3 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -31,6 +31,7 @@ import dr.evomodel.branchmodel.BranchModel; import dr.evomodel.branchmodel.HomogeneousBranchModel; import dr.evomodel.branchratemodel.BranchRateModel; +import dr.evomodel.siteratemodel.DiscretizedSiteRateModel; import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evomodel.substmodel.FrequencyModel; @@ -246,7 +247,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) xo.getChild(SubstitutionModel.class); if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { - // for backwards compatibility the old GammaSiteRateModel can provide the substitution model... + // for backwards compatibility the old GammaSiteRateModelParser can provide the substitution model... substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); } if (substitutionModel == null) { @@ -278,7 +279,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) cxo.getChild(SubstitutionModel.class); if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { - // for backwards compatibility the old GammaSiteRateModel can provide the substitution model... + // for backwards compatibility the old GammaSiteRateModelParser can provide the substitution model... substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); } if (substitutionModel == null) { From b02fe753fb84a61c523ecc9978a004454d481127 Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 30 Jun 2023 18:02:04 +0100 Subject: [PATCH 158/196] Starting a FreeRate delegate --- .../siteratemodel/FreeRateDelegate.java | 179 ++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 src/dr/evomodel/siteratemodel/FreeRateDelegate.java diff --git a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java new file mode 100644 index 0000000000..e8228a6884 --- /dev/null +++ b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java @@ -0,0 +1,179 @@ +/* + * GammaSiteRateModelParser.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.siteratemodel; + +import dr.evomodel.substmodel.SubstitutionModel; +import dr.inference.model.AbstractModel; +import dr.inference.model.Model; +import dr.inference.model.Parameter; +import dr.inference.model.Variable; +import dr.math.GeneralisedGaussLaguerreQuadrature; +import dr.math.distributions.GammaDistribution; +import dr.math.functionEval.GammaFunction; +import dr.util.Author; +import dr.util.Citable; +import dr.util.Citation; + +import java.util.ArrayList; +import java.util.List; + +/** + * FreeRateDelegate - A SiteModel delegate that implements the 'FreeRate' model. + * + * @author Andrew Rambaut + * @version $Id: GammaSiteModel.java,v 1.31 2005/09/26 14:27:38 rambaut Exp $ + */ + +public class FreeRateDelegate extends AbstractModel implements SiteRateDelegate, Citable { + + + + /** + * Constructor for gamma+invar distributed sites. Either shapeParameter or + * invarParameter (or both) can be null to turn off that feature. + */ + public FreeRateDelegate( + String name, + Parameter rateParameter, + Parameter weightParameter) { + + super(name); + + this.rateParameter = rateParameter; + this.categoryCount = rateParameter.getDimension() + 1; + addVariable(rateParameter); + + rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); + + this.weightParameter = weightParameter; + assert categoryCount == weightParameter.getDimension() + 1; + + addVariable(weightParameter); + weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); + } + + // ***************************************************************** + // Interface SiteRateModel + // ***************************************************************** + + public int getCategoryCount() { + return categoryCount; + } + + public void getCategories(double[] categoryRates, double[] categoryProportions) { + assert categoryRates != null && categoryRates.length == categoryCount; + assert categoryProportions != null && categoryProportions.length == categoryCount; + + + categoryRates[0] = 1.0; + categoryProportions[0] = 1.0; + } + + // ***************************************************************** + // Interface ModelComponent + // ***************************************************************** + + protected void handleModelChangedEvent(Model model, Object object, int index) { + listenerHelper.fireModelChanged(this, object, index); + } + + protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + listenerHelper.fireModelChanged(this, variable, index); + } + + protected void storeState() { + } // no additional state needs storing + + protected void restoreState() { + } + + protected void acceptState() { + } // no additional state needs accepting + + + /** + * rate parameter + */ + private final Parameter rateParameter; + + /** + * weights parameter + */ + private final Parameter weightParameter; + + + private final int categoryCount; + + + @Override + public Citation.Category getCategory() { + return Citation.Category.SUBSTITUTION_MODELS; + } + + @Override + public String getDescription() { + return "Discrete free-rate heterogeneity model"; + } + + public List getCitations() { + List citations = new ArrayList<>(); + return citations; + } + + public final static Citation CITATION_YANG94 = new Citation( + new Author[]{ + new Author("", "") + }, + "", + 1994, + "J. Mol. Evol.", + 39, + 306, 314, + Citation.Status.PUBLISHED + ); + + /** + * Gives the category rates a mean of 1.0 and the proportions sum to 1.0 + * @param categoryRates + * @param categoryProportions + */ + public static void normalize(double[] categoryRates, double[] categoryProportions) { + double mean = 0.0; + double sum = 0.0; + for (int i = 0; i < categoryRates.length; i++) { + mean += categoryRates[i]; + sum += categoryProportions[i]; + } + mean /= categoryRates.length; + + for(int i = 0; i < categoryRates.length; i++) { + categoryRates[i] /= mean; + categoryProportions[i] /= sum; + } + } + + +} \ No newline at end of file From cf57048d7705ad9cbc76e072591175f94b1976af Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 30 Jun 2023 18:02:59 +0100 Subject: [PATCH 159/196] Starting a FreeRate delegate --- src/dr/evomodel/siteratemodel/FreeRateDelegate.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java index e8228a6884..9c8fa587b9 100644 --- a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java @@ -1,7 +1,7 @@ /* - * GammaSiteRateModelParser.java + * FreeRateDelegate.java * - * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * Copyright (c) 2002-2023 BEAST Developer Team * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional From 8132f618b218d8dec2e65d53992d69853c82a425 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 3 Jul 2023 15:22:42 +0100 Subject: [PATCH 160/196] Implementing a more generalised rate heterogeneity system --- .../siteratemodel/FreeRateDelegate.java | 46 ++++-- .../siteratemodel/GammaSiteRateModel.java | 2 +- .../FreeRateSiteRateModelParser.java | 142 ++++++++++++++++++ .../siteratemodel/SiteModelParser.java | 2 +- 4 files changed, 179 insertions(+), 13 deletions(-) create mode 100644 src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java diff --git a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java index 9c8fa587b9..c6a3807d17 100644 --- a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java @@ -57,22 +57,32 @@ public class FreeRateDelegate extends AbstractModel implements SiteRateDelegate, */ public FreeRateDelegate( String name, + int categoryCount, Parameter rateParameter, Parameter weightParameter) { super(name); - - this.rateParameter = rateParameter; - this.categoryCount = rateParameter.getDimension() + 1; - addVariable(rateParameter); - rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); + this.categoryCount = categoryCount; + + this.rateParameter = rateParameter; + if (this.rateParameter.getDimension() == 1) { + this.rateParameter.setDimension(categoryCount - 1); + } else if (this.rateParameter.getDimension() != categoryCount - 1) { + throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1"); + } + addVariable(this.rateParameter); + this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); this.weightParameter = weightParameter; - assert categoryCount == weightParameter.getDimension() + 1; + if (this.weightParameter.getDimension() == 1) { + this.weightParameter.setDimension(categoryCount - 1); + } else if (this.weightParameter.getDimension() != categoryCount - 1) { + throw new IllegalArgumentException("Weight parameter should have have an initial dimension of one or category count - 1"); + } - addVariable(weightParameter); - weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); + addVariable(this.weightParameter); + this.weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); } // ***************************************************************** @@ -87,9 +97,23 @@ public void getCategories(double[] categoryRates, double[] categoryProportions) assert categoryRates != null && categoryRates.length == categoryCount; assert categoryProportions != null && categoryProportions.length == categoryCount; - categoryRates[0] = 1.0; - categoryProportions[0] = 1.0; + double sumRate = 1.0; + double sumWeight = 0.0; + for (int i = 0; i < categoryCount - 1; i++) { + categoryRates[i + 1] = categoryRates[i] * rateParameter.getParameterValue(i); + sumRate += categoryRates[i + 1]; + + categoryProportions[i] = weightParameter.getParameterValue(i); + sumWeight += categoryProportions[i]; + } + // calculate the last value + categoryProportions[categoryCount - 1] = 1.0 - sumWeight; + + // scale so their mean is 1 + for (int i = 0; i < categoryCount; i++) { + categoryRates[i] = categoryCount * categoryRates[i] / sumRate; + } } // ***************************************************************** @@ -126,7 +150,7 @@ protected void acceptState() { private final int categoryCount; - + @Override public Citation.Category getCategory() { diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 75df1cca45..9a6f32bbb9 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -435,7 +435,7 @@ public static void setQuatratureRates(double[] categoryRates, double[] categoryP for (int i = 0; i < catCount; i++) { categoryRates[i + offset] = abscissae[i] / (alpha + 1); - categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha + 1); + categoryProportions[i + offset] = coefficients[i] / GammaFunction.gamma(alpha + 1); } } diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java new file mode 100644 index 0000000000..bc188bd4ad --- /dev/null +++ b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java @@ -0,0 +1,142 @@ +/* + * FreeRateSiteRateModelParser.java + * + * Copyright (c) 2002-2023 BEAST Development Team + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodelxml.siteratemodel; + +import dr.evomodel.siteratemodel.DiscretizedSiteRateModel; +import dr.evomodel.siteratemodel.FreeRateDelegate; +import dr.evomodel.substmodel.SubstitutionModel; +import dr.inference.model.Parameter; +import dr.oldevomodel.sitemodel.SiteModel; +import dr.xml.*; + +import java.util.logging.Logger; + +/** + * This is a FreeRateSiteRateModelParser that uses the modular + * DiscretizedSiteRateModel with a FreeRates delegate. + * @author Andrew Rambaut + * @version $Id$ + */ +public class FreeRateSiteRateModelParser extends AbstractXMLObjectParser { + + public static final String FREE_RATE_SITE_RATE_MODEL = "freeRateSiteRateModel"; + public static final String MUTATION_RATE = "mutationRate"; + public static final String SUBSTITUTION_RATE = "substitutionRate"; + public static final String RELATIVE_RATE = "relativeRate"; + public static final String WEIGHT = "weight"; + public static final String RATES = "relativeRates"; + public static final String RATE_CATEGORIES = "rateCategories"; + public static final String WEIGHTS = "weights"; + + public String getParserName() { + return FREE_RATE_SITE_RATE_MODEL; + } + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + String msg = ""; + SubstitutionModel substitutionModel; + + double muWeight = 1.0; + + Parameter muParam = null; + if (xo.hasChildNamed(SUBSTITUTION_RATE)) { + muParam = (Parameter) xo.getElementFirstChild(SUBSTITUTION_RATE); + + msg += "\n with initial substitution rate = " + muParam.getParameterValue(0); + } else if (xo.hasChildNamed(RELATIVE_RATE)) { + XMLObject cxo = xo.getChild(RELATIVE_RATE); + muParam = (Parameter) cxo.getChild(Parameter.class); + msg += "\n with initial relative rate = " + muParam.getParameterValue(0); + if (cxo.hasAttribute(WEIGHT)) { + muWeight = cxo.getDoubleAttribute(WEIGHT); + msg += " with weight: " + muWeight; + } + } + + int catCount = 4; + XMLObject cxo = xo.getChild(RATES); + catCount = cxo.getIntegerAttribute(RATE_CATEGORIES); + Parameter ratesParameter = (Parameter)xo.getChild(Parameter.class); + + Parameter weightsParameter = (Parameter)xo.getElementFirstChild(WEIGHTS); + + msg += "\n " + catCount + " category discrete free rate site rate heterogeneity model)"; + if (msg.length() > 0) { + Logger.getLogger("dr.evomodel").info("\nCreating free rate site rate model: " + msg); + } else { + Logger.getLogger("dr.evomodel").info("\nCreating free rate site rate model."); + } + + FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount, ratesParameter, weightsParameter); + + return new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + } + + //************************************************************************ + // AbstractXMLObjectParser implementation + //************************************************************************ + + public String getParserDescription() { + return "A DiscretizedSiteRateModel that has freely distributed rates across sites"; + } + + @Override + public String[] getParserNames() { + return super.getParserNames(); + } + + public Class getReturnType() { + return DiscretizedSiteRateModel.class; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + AttributeRule.newIntegerRule(RATE_CATEGORIES, true), + new XORRule( + new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }), + new ElementRule(RELATIVE_RATE, new XMLSyntaxRule[]{ + AttributeRule.newDoubleRule(WEIGHT, true), + new ElementRule(Parameter.class) + }), true + ), + + new ElementRule(RATES, new XMLSyntaxRule[]{ + AttributeRule.newIntegerRule(RATE_CATEGORIES, false), + new ElementRule(Parameter.class) + }, false), + + new ElementRule(WEIGHTS, new XMLSyntaxRule[]{ + new ElementRule(Parameter.class) + }, false) + }; + +}//END: class diff --git a/src/dr/evomodelxml/siteratemodel/SiteModelParser.java b/src/dr/evomodelxml/siteratemodel/SiteModelParser.java index 685ed31062..a3067887be 100644 --- a/src/dr/evomodelxml/siteratemodel/SiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/SiteModelParser.java @@ -48,7 +48,7 @@ */ public class SiteModelParser extends AbstractXMLObjectParser { - public static final String SITE_MODEL = "SiteModel"; + public static final String SITE_MODEL = "siteModel"; public static final String SUBSTITUTION_MODEL = "substitutionModel"; public static final String MUTATION_RATE = "mutationRate"; public static final String SUBSTITUTION_RATE = "substitutionRate"; From ce3fbaea6eb07a246ea2d1e3bf2d580909facac0 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Tue, 4 Jul 2023 09:01:08 +0100 Subject: [PATCH 161/196] New parsers for discretised Rate heterogeneity models --- src/dr/app/beast/release_parsers.properties | 1 + .../siteratemodel/FreeRateDelegate.java | 10 +++--- .../FreeRateSiteRateModelParser.java | 10 +++--- .../GammaSiteRateModelParser.java | 31 ++++++++++--------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index 7d34c3fd64..4bba1b4173 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -77,6 +77,7 @@ dr.evomodelxml.substmodel.LewisMkSubstitutionModelParser #dr.evomodelxml.siteratemodel.SiteModelParser dr.evomodelxml.siteratemodel.OldGammaSiteModelParser dr.evomodelxml.siteratemodel.GammaSiteRateModelParser +dr.evomodelxml.siteratemodel.FreeRateSiteRateModelParser dr.evomodelxml.siteratemodel.PdfSiteModelParser # BRANCH MODELS diff --git a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java index c6a3807d17..3f548bdec9 100644 --- a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java @@ -72,17 +72,17 @@ public FreeRateDelegate( throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1"); } addVariable(this.rateParameter); - this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); + this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1)); this.weightParameter = weightParameter; if (this.weightParameter.getDimension() == 1) { - this.weightParameter.setDimension(categoryCount - 1); - } else if (this.weightParameter.getDimension() != categoryCount - 1) { - throw new IllegalArgumentException("Weight parameter should have have an initial dimension of one or category count - 1"); + this.weightParameter.setDimension(categoryCount); + } else if (this.weightParameter.getDimension() != categoryCount) { + throw new IllegalArgumentException("Weight parameter should have have an initial dimension of one or category count"); } addVariable(this.weightParameter); - this.weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); + this.weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, categoryCount)); } // ***************************************************************** diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java index bc188bd4ad..3c7eaad55c 100644 --- a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java @@ -48,7 +48,7 @@ public class FreeRateSiteRateModelParser extends AbstractXMLObjectParser { public static final String RELATIVE_RATE = "relativeRate"; public static final String WEIGHT = "weight"; public static final String RATES = "relativeRates"; - public static final String RATE_CATEGORIES = "rateCategories"; + public static final String CATEGORIES = "categories"; public static final String WEIGHTS = "weights"; public String getParserName() { @@ -78,9 +78,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } int catCount = 4; - XMLObject cxo = xo.getChild(RATES); - catCount = cxo.getIntegerAttribute(RATE_CATEGORIES); - Parameter ratesParameter = (Parameter)xo.getChild(Parameter.class); + catCount = xo.getIntegerAttribute(CATEGORIES); + Parameter ratesParameter = (Parameter)xo.getElementFirstChild(RATES); Parameter weightsParameter = (Parameter)xo.getElementFirstChild(WEIGHTS); @@ -118,7 +117,7 @@ public XMLSyntaxRule[] getSyntaxRules() { } private final XMLSyntaxRule[] rules = { - AttributeRule.newIntegerRule(RATE_CATEGORIES, true), + AttributeRule.newIntegerRule(CATEGORIES, true), new XORRule( new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) @@ -130,7 +129,6 @@ public XMLSyntaxRule[] getSyntaxRules() { ), new ElementRule(RATES, new XMLSyntaxRule[]{ - AttributeRule.newIntegerRule(RATE_CATEGORIES, false), new ElementRule(Parameter.class) }, false), diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java index d616ed91cb..fdb6c54ad1 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java @@ -49,7 +49,7 @@ public class GammaSiteRateModelParser extends AbstractXMLObjectParser { public static final String RELATIVE_RATE = "relativeRate"; public static final String WEIGHT = "weight"; public static final String GAMMA_SHAPE = "gammaShape"; - public static final String GAMMA_CATEGORIES = "gammaCategories"; + public static final String CATEGORIES = "categories"; public static final String PROPORTION_INVARIANT = "proportionInvariant"; public static final String DISCRETIZATION = "discretization"; @@ -83,29 +83,30 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } + int catCount = 4; + catCount = xo.getIntegerAttribute(CATEGORIES); + GammaSiteRateDelegate.DiscretizationType type = GammaSiteRateDelegate.DEFAULT_DISCRETIZATION; + if ( xo.hasAttribute(DISCRETIZATION)) { + try { + type = GammaSiteRateDelegate.DiscretizationType.valueOf( + xo.getStringAttribute(DISCRETIZATION).toUpperCase()); + } catch (IllegalArgumentException eae) { + throw new XMLParseException("Unknown category width type: " + xo.getStringAttribute(DISCRETIZATION)); + } + } Parameter shapeParam = null; - int catCount = 4; if (xo.hasChildNamed(GAMMA_SHAPE)) { XMLObject cxo = xo.getChild(GAMMA_SHAPE); - catCount = cxo.getIntegerAttribute(GAMMA_CATEGORIES); - - if ( cxo.hasAttribute(DISCRETIZATION)) { - try { - type = GammaSiteRateDelegate.DiscretizationType.valueOf( - cxo.getStringAttribute(DISCRETIZATION).toUpperCase()); - } catch (IllegalArgumentException eae) { - throw new XMLParseException("Unknown category width type: " + cxo.getStringAttribute(DISCRETIZATION)); - } - } + shapeParam = (Parameter) cxo.getChild(Parameter.class); msg += "\n " + catCount + " category discrete gamma with initial shape = " + shapeParam.getParameterValue(0); if (type == GammaSiteRateDelegate.DiscretizationType.EQUAL) { msg += "\n using equal weight discretization of gamma distribution"; } else { - msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution (Felsenstein, 2012)"; + msg += "\n using Gauss-Laguerre quadrature discretization of gamma distribution (Felsenstein, 2001)"; } } @@ -149,6 +150,8 @@ public XMLSyntaxRule[] getSyntaxRules() { private final XMLSyntaxRule[] rules = { + AttributeRule.newIntegerRule(CATEGORIES, false), + AttributeRule.newStringRule(DISCRETIZATION, true), new XORRule( new XORRule( new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ @@ -165,8 +168,6 @@ public XMLSyntaxRule[] getSyntaxRules() { ), new ElementRule(GAMMA_SHAPE, new XMLSyntaxRule[]{ - AttributeRule.newIntegerRule(GAMMA_CATEGORIES, true), - AttributeRule.newStringRule(DISCRETIZATION, true), new ElementRule(Parameter.class) }, true), From d253cfd2ef0008cce69e79a39400e05cf3c5af3a Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 4 Jul 2023 17:31:28 +0100 Subject: [PATCH 162/196] FreeRate parameterization options. --- .../DiscretizedSiteRateModel.java | 1 + .../siteratemodel/FreeRateDelegate.java | 82 ++++++++++++++----- .../FreeRateSiteRateModelParser.java | 12 ++- 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index 1e5e0eadb9..d8e4d4dac0 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -149,6 +149,7 @@ private void calculateCategoryRates() { protected void handleModelChangedEvent(Model model, Object object, int index) { // delegate has changed so fire model changed event + ratesKnown = false; listenerHelper.fireModelChanged(this, object, index); } diff --git a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java index 3f548bdec9..6b224e1fdc 100644 --- a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java @@ -49,6 +49,14 @@ public class FreeRateDelegate extends AbstractModel implements SiteRateDelegate, Citable { + public static final Parameterization DEFAULT_PARAMETERIZATION = Parameterization.ABSOLUTE; + + public enum Parameterization { + ABSOLUTE, + RATIOS, + DIFFERENCES + }; + /** @@ -58,21 +66,32 @@ public class FreeRateDelegate extends AbstractModel implements SiteRateDelegate, public FreeRateDelegate( String name, int categoryCount, + Parameterization parameterization, Parameter rateParameter, Parameter weightParameter) { super(name); this.categoryCount = categoryCount; + this.parameterization = parameterization; this.rateParameter = rateParameter; - if (this.rateParameter.getDimension() == 1) { - this.rateParameter.setDimension(categoryCount - 1); - } else if (this.rateParameter.getDimension() != categoryCount - 1) { - throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1"); + if (parameterization == Parameterization.ABSOLUTE) { + if (this.rateParameter.getDimension() == 1) { + this.rateParameter.setDimension(categoryCount); + } else if (this.rateParameter.getDimension() != categoryCount) { + throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count"); + } + this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount)); + } else { + if (this.rateParameter.getDimension() == 1) { + this.rateParameter.setDimension(categoryCount - 1); + } else if (this.rateParameter.getDimension() != categoryCount - 1) { + throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1"); + } + this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1)); } addVariable(this.rateParameter); - this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1)); this.weightParameter = weightParameter; if (this.weightParameter.getDimension() == 1) { @@ -97,25 +116,43 @@ public void getCategories(double[] categoryRates, double[] categoryProportions) assert categoryRates != null && categoryRates.length == categoryCount; assert categoryProportions != null && categoryProportions.length == categoryCount; - categoryRates[0] = 1.0; - double sumRate = 1.0; - double sumWeight = 0.0; - for (int i = 0; i < categoryCount - 1; i++) { - categoryRates[i + 1] = categoryRates[i] * rateParameter.getParameterValue(i); - sumRate += categoryRates[i + 1]; - - categoryProportions[i] = weightParameter.getParameterValue(i); - sumWeight += categoryProportions[i]; - } - // calculate the last value - categoryProportions[categoryCount - 1] = 1.0 - sumWeight; - - // scale so their mean is 1 - for (int i = 0; i < categoryCount; i++) { - categoryRates[i] = categoryCount * categoryRates[i] / sumRate; + if (parameterization == Parameterization.ABSOLUTE) { + double sumRates = 0.0; + double sumWeights = 0.0; + for (int i = 0; i < categoryCount; i++) { + categoryRates[i] = rateParameter.getParameterValue(i); + sumRates += categoryRates[i]; + categoryProportions[i] = weightParameter.getParameterValue(i); + sumWeights += categoryProportions[i]; + } + assert Math.abs(sumRates - categoryCount) < 1E-10; + assert Math.abs(sumWeights - 1.0) < 1E-10; + } else { + categoryRates[0] = 1.0; + double sumRates = 0.0; + double sumWeights = 0.0; + for (int i = 0; i < categoryCount; i++) { + if (parameterization == Parameterization.RATIOS) { + if (i > 0) { + categoryRates[i] = categoryRates[i - 1] * rateParameter.getParameterValue(i); + } + } else { // Parameterization.DIFFERENCES + categoryRates[i] = categoryRates[i - 1] + rateParameter.getParameterValue(i); + } + sumRates += categoryRates[i + 1]; + + categoryProportions[i] = weightParameter.getParameterValue(i); + sumWeights += categoryProportions[i]; + } + assert Math.abs(sumWeights - 1.0) < 1E-10; + + // scale so their mean is 1 + for (int i = 0; i < categoryCount; i++) { + categoryRates[i] = categoryCount * categoryRates[i] / sumRates; + } } } - + // ***************************************************************** // Interface ModelComponent // ***************************************************************** @@ -151,6 +188,7 @@ protected void acceptState() { private final int categoryCount; + private final Parameterization parameterization; @Override public Citation.Category getCategory() { diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java index 3c7eaad55c..c6257f4eb7 100644 --- a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java @@ -47,8 +47,9 @@ public class FreeRateSiteRateModelParser extends AbstractXMLObjectParser { public static final String SUBSTITUTION_RATE = "substitutionRate"; public static final String RELATIVE_RATE = "relativeRate"; public static final String WEIGHT = "weight"; - public static final String RATES = "relativeRates"; + public static final String RATES = "rates"; public static final String CATEGORIES = "categories"; + public static final String PARAMETERIZATION = "parameterization"; public static final String WEIGHTS = "weights"; public String getParserName() { @@ -79,6 +80,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { int catCount = 4; catCount = xo.getIntegerAttribute(CATEGORIES); + + FreeRateDelegate.Parameterization parameterization = FreeRateDelegate.Parameterization.ABSOLUTE; + if (xo.hasAttribute(PARAMETERIZATION)) { + parameterization = FreeRateDelegate.Parameterization.valueOf(xo.getStringAttribute(PARAMETERIZATION)); + } + Parameter ratesParameter = (Parameter)xo.getElementFirstChild(RATES); Parameter weightsParameter = (Parameter)xo.getElementFirstChild(WEIGHTS); @@ -90,7 +97,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Logger.getLogger("dr.evomodel").info("\nCreating free rate site rate model."); } - FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount, ratesParameter, weightsParameter); + FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount, parameterization, ratesParameter, weightsParameter); return new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); } @@ -118,6 +125,7 @@ public XMLSyntaxRule[] getSyntaxRules() { private final XMLSyntaxRule[] rules = { AttributeRule.newIntegerRule(CATEGORIES, true), + AttributeRule.newStringRule(PARAMETERIZATION, true), new XORRule( new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) From 308e714614d9e50d0b7f02603d88226ab2b9d73c Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Tue, 4 Jul 2023 22:59:45 +0100 Subject: [PATCH 163/196] Added an ordered rates and weights statistic for logging (lazy calculations) --- .../DiscretizedSiteRateModel.java | 58 ++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index 1e5e0eadb9..0a9667682d 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -28,6 +28,9 @@ import dr.inference.model.*; import dr.evomodel.substmodel.SubstitutionModel; +import java.util.Arrays; +import java.util.Comparator; + /** * DiscretizedSiteRateModel - A SiteModel that has a discrete categories of rates across sites. * @@ -58,14 +61,18 @@ public DiscretizedSiteRateModel( this.muWeight = muWeight; addStatistic(muStatistic); + addStatistic(ratesStatistic); + addStatistic(weightsStatistic); this.delegate = delegate; addModel(delegate); categoryRates = new double[delegate.getCategoryCount()]; categoryProportions = new double[delegate.getCategoryCount()]; + orderedCategories = new double[delegate.getCategoryCount()][2]; // for storing ordered rate/weight pairs ratesKnown = false; + orderedRatesKnown = false; } /** @@ -141,8 +148,18 @@ private void calculateCategoryRates() { } ratesKnown = true; + + orderedRatesKnown = false; } + private void calculateOrderedCategories() { + for (int i = 0; i < categoryRates.length; i++) { + orderedCategories[i][0] = categoryRates[i]; + orderedCategories[i][1] = categoryProportions[i]; + } + Arrays.sort(orderedCategories, Comparator.comparingDouble(a -> a[1])); + orderedRatesKnown = true; + } // ***************************************************************** // Interface ModelComponent // ***************************************************************** @@ -172,7 +189,7 @@ protected void acceptState() { } // no additional state needs accepting - private Statistic muStatistic = new Statistic.Abstract() { + private final Statistic muStatistic = new Statistic.Abstract() { public String getStatisticName() { return "mu"; @@ -192,6 +209,42 @@ public double getStatisticValue(int dim) { }; + private final Statistic ratesStatistic = new Statistic.Abstract() { + + public String getStatisticName() { + return "rates"; + } + + public int getDimension() { + return getCategoryCount(); + } + + public double getStatisticValue(int dim) { + if (!orderedRatesKnown) { + calculateOrderedCategories(); + } + return orderedCategories[dim][0]; + } + }; + + private final Statistic weightsStatistic = new Statistic.Abstract() { + + public String getStatisticName() { + return "weights"; + } + + public int getDimension() { + return getCategoryCount(); + } + + public double getStatisticValue(int dim) { + if (!orderedRatesKnown) { + calculateOrderedCategories(); + } + return orderedCategories[dim][1]; + } + + }; /** * mutation rate parameter @@ -200,9 +253,10 @@ public double getStatisticValue(int dim) { private final double muWeight; - private boolean ratesKnown; + private boolean ratesKnown, orderedRatesKnown; private final double[] categoryRates; + private final double[][] orderedCategories; private final double[] categoryProportions; From 0a0baf42f83969bda09516604c6772e2fef80981 Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Fri, 14 Jul 2023 14:59:47 +0100 Subject: [PATCH 164/196] Fixed typos in error messages --- .../operators/TransformedParameterRandomWalkOperatorParser.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java b/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java index f2d8a5f782..fb83725bcb 100644 --- a/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java +++ b/src/dr/inferencexml/operators/TransformedParameterRandomWalkOperatorParser.java @@ -45,7 +45,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { try { randomWalk = super.parseXMLObject(xo); } catch (XMLParseException e) { - throw new XMLParseException("RandomWalkOperatorParser failled in TraansformedParameterRandomWalkOperator."); + throw new XMLParseException("RandomWalkOperatorParser failed in TransformedParameterRandomWalkOperator."); } return new TransformedParameterRandomWalkOperator((RandomWalkOperator) randomWalk); From f8f656983c23276399b7655d72b398da1c2bf3c8 Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Fri, 14 Jul 2023 15:00:16 +0100 Subject: [PATCH 165/196] Answered a very very old question in comments --- src/dr/inference/operators/ScaleOperator.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dr/inference/operators/ScaleOperator.java b/src/dr/inference/operators/ScaleOperator.java index 70a12bd00e..8a0897220d 100644 --- a/src/dr/inference/operators/ScaleOperator.java +++ b/src/dr/inference/operators/ScaleOperator.java @@ -116,8 +116,8 @@ public final double doOperation() { } } else if (scaleAll) { // update all dimensions - // hasting ratio is dim-2 times of 1dim case. would be nice to have a reference here - // for the proof. It is supposed to be somewhere in an Alexei/Nicholes article. + // hasting ratio is dim-2 times of 1dim case. This can be derived easily from section 2.1 of + // https://people.maths.bris.ac.uk/~mapjg/papers/rjmcmc_20090613.pdf, ignoring the rjMCMC context if (degreesOfFreedom > 0) // For parameters with non-uniform prior on only one dimension logq = -degreesOfFreedom * Math.log(scale); From 16f75c764a73606b4204a3d0442dc665395640de Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Fri, 14 Jul 2023 15:01:27 +0100 Subject: [PATCH 166/196] Added a transform to allow standard moves to work on parameters to be normalised --- .../RealDifferencesToSimplexTransform.java | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 src/dr/util/RealDifferencesToSimplexTransform.java diff --git a/src/dr/util/RealDifferencesToSimplexTransform.java b/src/dr/util/RealDifferencesToSimplexTransform.java new file mode 100644 index 0000000000..2e5edcac06 --- /dev/null +++ b/src/dr/util/RealDifferencesToSimplexTransform.java @@ -0,0 +1,184 @@ +package dr.util; + +import dr.inference.model.Parameter; +import dr.math.matrixAlgebra.IllegalDimension; +import dr.math.matrixAlgebra.Matrix; +import dr.xml.*; + +public class RealDifferencesToSimplexTransform extends Transform.MultivariateTransform { + + private Parameter weights; + + public RealDifferencesToSimplexTransform(int dim, Parameter weights) { + super(dim); + this.weights = weights; + } + + public RealDifferencesToSimplexTransform(int dim) { + super(dim); + weights = new Parameter.Default(dim, (double) 1 /dim); + } + + @Override + public double[] inverse(double[] values, int from, int to, double sum) { + throw new RuntimeException("not implemented"); + } + + @Override + public double[] gradient(double[] values, int from, int to) { + throw new RuntimeException("not implemented"); + } + + @Override + public double[] gradientInverse(double[] values, int from, int to) { + throw new RuntimeException("not implemented"); + } + + @Override + public String getTransformName() { + return "realDifferencesToSimplex"; + } + + @Override + protected double[] transform(double[] values) { + + // This is a transformation of an n-dimensional vector to another n-dimensional vector but what comes _out_ is a + // simplex of one greater dimension. The weights parameter has dimension n+1. + + double[] out = new double[values.length + 1]; + + double denominator = 0; + + for(int i=0; i Date: Fri, 14 Jul 2023 19:24:36 +0100 Subject: [PATCH 167/196] Fixed GGLQ typos I inadvertently added the other week --- .../GeneralisedGaussLaguerreQuadrature.java | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java index 0c2cc3d383..69e73be490 100644 --- a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java +++ b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java @@ -70,8 +70,8 @@ public void setAlphaAndB(double alpha, double B){ } private void setupArrays(){ - final int maxIterations = 10; - final double eps = 1E-14; + final int maxIterations = 110; + final double eps = 3E-14; double z = 0; @@ -79,8 +79,7 @@ private void setupArrays(){ for(int i=0; i Date: Fri, 14 Jul 2023 19:24:59 +0100 Subject: [PATCH 168/196] Fixed typos in function names --- .../siteratemodel/GammaSiteRateDelegate.java | 19 +++++++++++-------- .../siteratemodel/GammaSiteRateModel.java | 16 ++++++++-------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java index 60280cf70a..68586e4c47 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java @@ -113,7 +113,7 @@ public void getCategories(double[] categoryRates, double[] categoryProportions) final int gammaCatCount = categoryCount - offset; if (discretizationType == DiscretizationType.QUADRATURE) { - setQuatratureRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); + setQuadratureRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); } else { setEqualRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); } @@ -222,19 +222,22 @@ public List getCitations() { * @param catCount * @param offset */ - public static void setQuatratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { + public static void setQuadratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { if (quadrature == null) { quadrature = new GeneralisedGaussLaguerreQuadrature(catCount); } - quadrature.setAlpha(alpha); + quadrature.setAlpha(alpha-1); double[] abscissae = quadrature.getAbscissae(); double[] coefficients = quadrature.getCoefficients(); for (int i = 0; i < catCount; i++) { - categoryRates[i + offset] = abscissae[i] / (alpha + 1); - categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha + 1); + categoryRates[i + offset] = abscissae[i] / (alpha); + categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha); } + + double nothing; + } /** @@ -288,7 +291,7 @@ public static void main(String[] argv) { System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); } - setQuatratureRates(categoryRates, categoryProportions, 1.0, catCount, 0); + setQuadratureRates(categoryRates, categoryProportions, 1.0, catCount, 0); System.out.println(); System.out.println("Quadrature, alpha = 1.0"); System.out.println("cat\trate\tproportion"); @@ -325,7 +328,7 @@ public static void main(String[] argv) { System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); } - setQuatratureRates(categoryRates, categoryProportions, 0.1, catCount, 0); + setQuadratureRates(categoryRates, categoryProportions, 0.1, catCount, 0); System.out.println(); System.out.println("Quadrature, alpha = 0.1"); System.out.println("cat\trate\tproportion"); @@ -341,7 +344,7 @@ public static void main(String[] argv) { System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); } - setQuatratureRates(categoryRates, categoryProportions, 10.0, catCount, 0); + setQuadratureRates(categoryRates, categoryProportions, 10.0, catCount, 0); System.out.println(); System.out.println("Quadrature, alpha = 10.0"); System.out.println("cat\trate\tproportion"); diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 9a6f32bbb9..baf84f8089 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -250,7 +250,7 @@ private void calculateCategoryRates() { final int gammaCatCount = categoryCount - offset; if (discretizationType == DiscretizationType.QUADRATURE) { - setQuatratureRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); + setQuadratureRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); } else { setEqualRates(categoryRates, categoryProportions, alpha, gammaCatCount, offset); } @@ -424,18 +424,18 @@ public List getCitations() { * @param catCount * @param offset */ - public static void setQuatratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { + public static void setQuadratureRates(double[] categoryRates, double[] categoryProportions, double alpha, int catCount, int offset) { if (quadrature == null) { quadrature = new GeneralisedGaussLaguerreQuadrature(catCount); } - quadrature.setAlpha(alpha); + quadrature.setAlpha(alpha-1); double[] abscissae = quadrature.getAbscissae(); double[] coefficients = quadrature.getCoefficients(); for (int i = 0; i < catCount; i++) { - categoryRates[i + offset] = abscissae[i] / (alpha + 1); - categoryProportions[i + offset] = coefficients[i] / GammaFunction.gamma(alpha + 1); + categoryRates[i + offset] = abscissae[i] / alpha; + categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha); } } @@ -490,7 +490,7 @@ public static void main(String[] argv) { System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); } - setQuatratureRates(categoryRates, categoryProportions, 1.0, catCount, 0); + setQuadratureRates(categoryRates, categoryProportions, 1.0, catCount, 0); System.out.println(); System.out.println("Quadrature, alpha = 1.0"); System.out.println("cat\trate\tproportion"); @@ -527,7 +527,7 @@ public static void main(String[] argv) { System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); } - setQuatratureRates(categoryRates, categoryProportions, 0.1, catCount, 0); + setQuadratureRates(categoryRates, categoryProportions, 0.1, catCount, 0); System.out.println(); System.out.println("Quadrature, alpha = 0.1"); System.out.println("cat\trate\tproportion"); @@ -543,7 +543,7 @@ public static void main(String[] argv) { System.out.println(i + "\t"+ categoryRates[i] +"\t" + categoryProportions[i]); } - setQuatratureRates(categoryRates, categoryProportions, 10.0, catCount, 0); + setQuadratureRates(categoryRates, categoryProportions, 10.0, catCount, 0); System.out.println(); System.out.println("Quadrature, alpha = 10.0"); System.out.println("cat\trate\tproportion"); From 39bafa4b4491eb3f413087622398b9c748e14f6e Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Fri, 14 Jul 2023 19:25:24 +0100 Subject: [PATCH 169/196] Fixed parsers to work properly --- .../siteratemodel/GammaSiteRateModelParser.java | 12 ++++++++++-- .../treedatalikelihood/TreeDataLikelihoodParser.java | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java index fdb6c54ad1..9f86bc4f2b 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java @@ -60,7 +60,7 @@ public String getParserName() { public Object parseXMLObject(XMLObject xo) throws XMLParseException { String msg = ""; - SubstitutionModel substitutionModel; + SubstitutionModel substitutionModel = null; double muWeight = 1.0; @@ -82,6 +82,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { msg += " with weight: " + muWeight; } } + + if(xo.hasChildNamed(SUBSTITUTION_MODEL)){ + substitutionModel = (SubstitutionModel)xo.getElementFirstChild(SUBSTITUTION_MODEL); + } int catCount = 4; catCount = xo.getIntegerAttribute(CATEGORIES); @@ -124,7 +128,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { GammaSiteRateDelegate delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); - return new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + + siteRateModel.setSubstitutionModel(substitutionModel); + + return siteRateModel; } //************************************************************************ diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index d0b5fc0dd3..aa1c8f3593 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -278,9 +278,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { BranchModel branchModel = (BranchModel) cxo.getChild(BranchModel.class); if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) cxo.getChild(SubstitutionModel.class); - if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { + if (substitutionModel == null && siteRateModel instanceof DiscretizedSiteRateModel) { // for backwards compatibility the old GammaSiteRateModelParser can provide the substitution model... - substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); + substitutionModel = ((DiscretizedSiteRateModel)siteRateModel).getSubstitutionModel(); } if (substitutionModel == null) { throw new XMLParseException("No substitution model available for partition " + k + " in DataTreeLikelihood: "+xo.getId()); From e436a421d561684d02d28a8da999c126783c15d3 Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Mon, 17 Jul 2023 12:45:13 +0100 Subject: [PATCH 170/196] A few trivial changes --- src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java | 2 -- src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java | 1 + src/dr/math/GeneralisedGaussLaguerreQuadrature.java | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java index 68586e4c47..4eea862668 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateDelegate.java @@ -236,8 +236,6 @@ public static void setQuadratureRates(double[] categoryRates, double[] categoryP categoryProportions[i + offset] = coefficients[i]/GammaFunction.gamma(alpha); } - double nothing; - } /** diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java index 9f86bc4f2b..19de14142d 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java @@ -131,6 +131,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); siteRateModel.setSubstitutionModel(substitutionModel); + siteRateModel.addModel(substitutionModel); return siteRateModel; } diff --git a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java index 69e73be490..e975a27e39 100644 --- a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java +++ b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java @@ -109,7 +109,7 @@ private void setupArrays(){ throw new RuntimeException("Too many iterations"); } abscissae[i] = z; - coefficients[i] = -Math.exp(GammaFunction.lnGamma(alpha+noPoints) - GammaFunction.lnGamma((double)noPoints))/ + coefficients[i] = -Math.exp(GammaFunction.lnGamma(alpha+noPoints) - GammaFunction.lnGamma(noPoints))/ (pp*noPoints*p2); } From c422163f5a93b596172731463390a5240fe56508 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 17 Jul 2023 13:09:30 +0100 Subject: [PATCH 171/196] Adding a substitution model is optional in the new DiscretizedSiteRateModel --- .../siteratemodel/GammaSiteRateModelParser.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java index 19de14142d..264f9824c2 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteRateModelParser.java @@ -129,9 +129,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { GammaSiteRateDelegate delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); - - siteRateModel.setSubstitutionModel(substitutionModel); - siteRateModel.addModel(substitutionModel); + + if (substitutionModel != null) { + siteRateModel.setSubstitutionModel(substitutionModel); + siteRateModel.addModel(substitutionModel); + } return siteRateModel; } From 854f2842bf6fd24262e8c61204add76d388fcf22 Mon Sep 17 00:00:00 2001 From: twoseventwo Date: Wed, 19 Jul 2023 14:27:49 +0100 Subject: [PATCH 172/196] Working FreeRate completed? --- .../RealDifferencesToSimplexTransform.java | 44 ++++++++++++------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/src/dr/util/RealDifferencesToSimplexTransform.java b/src/dr/util/RealDifferencesToSimplexTransform.java index 2e5edcac06..33c66a10a0 100644 --- a/src/dr/util/RealDifferencesToSimplexTransform.java +++ b/src/dr/util/RealDifferencesToSimplexTransform.java @@ -49,27 +49,35 @@ protected double[] transform(double[] values) { double denominator = 0; - for(int i=0; i Date: Wed, 19 Jul 2023 14:31:24 +0100 Subject: [PATCH 173/196] Working FreeRate prototype missing commits --- .../app/beast/development_parsers.properties | 3 + .../DiscretizedSiteRateModel.java | 10 +- .../siteratemodel/FreeRateDelegate.java | 118 ++++++++++-------- .../FreeRateSiteRateModelParser.java | 33 +++-- .../TransformedMultivariateParameter.java | 22 +++- .../inference/model/TransformedParameter.java | 1 + 6 files changed, 110 insertions(+), 77 deletions(-) diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index f0c17bfb5b..fe41471e48 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -182,6 +182,7 @@ dr.inference.regression.SelfControlledCaseSeries # SITE PATTERNS dr.evomodelxml.operators.PatternWeightIncrementOperatorParser + # BRANCH SPECIFIC STUFF dr.evomodel.branchmodel.lineagespecific.CountableRealizationsParameterParser dr.evomodel.branchmodel.lineagespecific.DirichletProcessPriorParser @@ -315,6 +316,8 @@ dr.inferencexml.distribution.shrinkage.JointBayesianBridgeStatisticsParser dr.inferencexml.hmc.CompoundPriorPreconditionerParser dr.inferencexml.hmc.NumericalGradientParser +# SIMPLEX TRANSFORM +dr.util.RealDifferencesToSimplexTransform # SMOOTH SKYGRID dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index 600584d32f..0eb21d1a5b 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -24,10 +24,8 @@ */ package dr.evomodel.siteratemodel; - import dr.inference.model.*; import dr.evomodel.substmodel.SubstitutionModel; - import java.util.Arrays; import java.util.Comparator; @@ -171,11 +169,11 @@ protected void handleModelChangedEvent(Model model, Object object, int index) { } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { - if (variable == nuParameter) { +// if (variable == nuParameter) { ratesKnown = false; // MAS: I changed this because the rate parameter can affect the categories if the parameter is in siteModel and not clockModel - } else { - throw new RuntimeException("Unknown variable in DiscretizedSiteRateModel.handleVariableChangedEvent"); - } +// } else { +// throw new RuntimeException("Unknown variable in DiscretizedSiteRateModel.handleVariableChangedEvent"); +// } listenerHelper.fireModelChanged(this, variable, index); } diff --git a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java index 6b224e1fdc..06d4e37757 100644 --- a/src/dr/evomodel/siteratemodel/FreeRateDelegate.java +++ b/src/dr/evomodel/siteratemodel/FreeRateDelegate.java @@ -49,13 +49,13 @@ public class FreeRateDelegate extends AbstractModel implements SiteRateDelegate, Citable { - public static final Parameterization DEFAULT_PARAMETERIZATION = Parameterization.ABSOLUTE; +/* public static final Parameterization DEFAULT_PARAMETERIZATION = Parameterization.ABSOLUTE; public enum Parameterization { ABSOLUTE, RATIOS, DIFFERENCES - }; + };*/ @@ -66,31 +66,31 @@ public enum Parameterization { public FreeRateDelegate( String name, int categoryCount, - Parameterization parameterization, + /* Parameterization parameterization,*/ Parameter rateParameter, Parameter weightParameter) { super(name); this.categoryCount = categoryCount; - this.parameterization = parameterization; +// this.parameterization = parameterization; this.rateParameter = rateParameter; - if (parameterization == Parameterization.ABSOLUTE) { - if (this.rateParameter.getDimension() == 1) { - this.rateParameter.setDimension(categoryCount); - } else if (this.rateParameter.getDimension() != categoryCount) { - throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count"); - } - this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount)); - } else { - if (this.rateParameter.getDimension() == 1) { - this.rateParameter.setDimension(categoryCount - 1); - } else if (this.rateParameter.getDimension() != categoryCount - 1) { - throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1"); - } - this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1)); - } +// if (parameterization == Parameterization.ABSOLUTE) { +// if (this.rateParameter.getDimension() == 1) { +// this.rateParameter.setDimension(categoryCount); +// } else if (this.rateParameter.getDimension() != categoryCount) { +// throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count"); +// } +// this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount)); +// } else { +// if (this.rateParameter.getDimension() == 1) { +// this.rateParameter.setDimension(categoryCount - 1); +// } else if (this.rateParameter.getDimension() != categoryCount - 1) { +// throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1"); +// } +// this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1)); +// } addVariable(this.rateParameter); this.weightParameter = weightParameter; @@ -101,6 +101,7 @@ public FreeRateDelegate( } addVariable(this.weightParameter); + this.weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, categoryCount)); } @@ -116,43 +117,44 @@ public void getCategories(double[] categoryRates, double[] categoryProportions) assert categoryRates != null && categoryRates.length == categoryCount; assert categoryProportions != null && categoryProportions.length == categoryCount; - if (parameterization == Parameterization.ABSOLUTE) { - double sumRates = 0.0; - double sumWeights = 0.0; - for (int i = 0; i < categoryCount; i++) { - categoryRates[i] = rateParameter.getParameterValue(i); - sumRates += categoryRates[i]; - categoryProportions[i] = weightParameter.getParameterValue(i); - sumWeights += categoryProportions[i]; - } - assert Math.abs(sumRates - categoryCount) < 1E-10; - assert Math.abs(sumWeights - 1.0) < 1E-10; - } else { - categoryRates[0] = 1.0; - double sumRates = 0.0; - double sumWeights = 0.0; - for (int i = 0; i < categoryCount; i++) { - if (parameterization == Parameterization.RATIOS) { - if (i > 0) { - categoryRates[i] = categoryRates[i - 1] * rateParameter.getParameterValue(i); - } - } else { // Parameterization.DIFFERENCES - categoryRates[i] = categoryRates[i - 1] + rateParameter.getParameterValue(i); - } - sumRates += categoryRates[i + 1]; - - categoryProportions[i] = weightParameter.getParameterValue(i); - sumWeights += categoryProportions[i]; - } - assert Math.abs(sumWeights - 1.0) < 1E-10; - - // scale so their mean is 1 - for (int i = 0; i < categoryCount; i++) { - categoryRates[i] = categoryCount * categoryRates[i] / sumRates; - } + +// if (parameterization == Parameterization.ABSOLUTE) { + double meanRate = 0.0; + double sumWeights = 0.0; + for (int i = 0; i < categoryCount; i++) { + categoryRates[i] = rateParameter.getParameterValue(i); + categoryProportions[i] = weightParameter.getParameterValue(i); + meanRate += categoryRates[i]*categoryProportions[i]; + sumWeights += categoryProportions[i]; } + assert Math.abs(meanRate - 1.0) < 1E-10; + assert Math.abs(sumWeights - 1.0) < 1E-10; +// } else { +// categoryRates[0] = 1.0; +// double sumRates = 0.0; +// double sumWeights = 0.0; +// for (int i = 0; i < categoryCount; i++) { +// if (parameterization == Parameterization.RATIOS) { +// if (i > 0) { +// categoryRates[i] = categoryRates[i - 1] * rateParameter.getParameterValue(i); +// } +// } else { // Parameterization.DIFFERENCES +// categoryRates[i] = categoryRates[i - 1] + rateParameter.getParameterValue(i); +// } +// sumRates += categoryRates[i + 1]; +// +// categoryProportions[i] = weightParameter.getParameterValue(i); +// sumWeights += categoryProportions[i]; +// } +// assert Math.abs(sumWeights - 1.0) < 1E-10; +// +// // scale so their mean is 1 +// for (int i = 0; i < categoryCount; i++) { +// categoryRates[i] = categoryCount * categoryRates[i] / sumRates; +// } +// } } - + // ***************************************************************** // Interface ModelComponent // ***************************************************************** @@ -162,6 +164,12 @@ protected void handleModelChangedEvent(Model model, Object object, int index) { } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + + + if(variable==weightParameter){ + rateParameter.fireParameterChangedEvent(); + } + listenerHelper.fireModelChanged(this, variable, index); } @@ -188,7 +196,7 @@ protected void acceptState() { private final int categoryCount; - private final Parameterization parameterization; +// private final Parameterization parameterization; @Override public Citation.Category getCategory() { diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java index c6257f4eb7..9b10b33d3c 100644 --- a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java @@ -29,6 +29,8 @@ import dr.evomodel.siteratemodel.FreeRateDelegate; import dr.evomodel.substmodel.SubstitutionModel; import dr.inference.model.Parameter; +import dr.inference.model.Variable; +import dr.inference.model.VariableListener; import dr.oldevomodel.sitemodel.SiteModel; import dr.xml.*; @@ -43,13 +45,14 @@ public class FreeRateSiteRateModelParser extends AbstractXMLObjectParser { public static final String FREE_RATE_SITE_RATE_MODEL = "freeRateSiteRateModel"; + public static final String SUBSTITUTION_MODEL = "substitutionModel"; public static final String MUTATION_RATE = "mutationRate"; public static final String SUBSTITUTION_RATE = "substitutionRate"; public static final String RELATIVE_RATE = "relativeRate"; public static final String WEIGHT = "weight"; public static final String RATES = "rates"; public static final String CATEGORIES = "categories"; - public static final String PARAMETERIZATION = "parameterization"; +// public static final String PARAMETERIZATION = "parameterization"; public static final String WEIGHTS = "weights"; public String getParserName() { @@ -59,7 +62,7 @@ public String getParserName() { public Object parseXMLObject(XMLObject xo) throws XMLParseException { String msg = ""; - SubstitutionModel substitutionModel; + SubstitutionModel substitutionModel = null; double muWeight = 1.0; @@ -78,16 +81,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } + if(xo.hasChildNamed(SUBSTITUTION_MODEL)){ + substitutionModel = (SubstitutionModel)xo.getElementFirstChild(SUBSTITUTION_MODEL); + } + int catCount = 4; catCount = xo.getIntegerAttribute(CATEGORIES); - FreeRateDelegate.Parameterization parameterization = FreeRateDelegate.Parameterization.ABSOLUTE; - if (xo.hasAttribute(PARAMETERIZATION)) { - parameterization = FreeRateDelegate.Parameterization.valueOf(xo.getStringAttribute(PARAMETERIZATION)); - } +// FreeRateDelegate.Parameterization parameterization = FreeRateDelegate.Parameterization.ABSOLUTE; +// if (xo.hasAttribute(PARAMETERIZATION)) { +// parameterization = FreeRateDelegate.Parameterization.valueOf(xo.getStringAttribute(PARAMETERIZATION)); +// } Parameter ratesParameter = (Parameter)xo.getElementFirstChild(RATES); - Parameter weightsParameter = (Parameter)xo.getElementFirstChild(WEIGHTS); msg += "\n " + catCount + " category discrete free rate site rate heterogeneity model)"; @@ -97,9 +103,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Logger.getLogger("dr.evomodel").info("\nCreating free rate site rate model."); } - FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount, parameterization, ratesParameter, weightsParameter); + FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount, +// parameterization, + ratesParameter, weightsParameter); + + DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + + siteRateModel.setSubstitutionModel(substitutionModel); + siteRateModel.addModel(substitutionModel); - return new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); + return siteRateModel; } //************************************************************************ @@ -125,7 +138,7 @@ public XMLSyntaxRule[] getSyntaxRules() { private final XMLSyntaxRule[] rules = { AttributeRule.newIntegerRule(CATEGORIES, true), - AttributeRule.newStringRule(PARAMETERIZATION, true), +// AttributeRule.newStringRule(PARAMETERIZATION, true), new XORRule( new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) diff --git a/src/dr/inference/model/TransformedMultivariateParameter.java b/src/dr/inference/model/TransformedMultivariateParameter.java index c9eade96bb..490de1d19a 100644 --- a/src/dr/inference/model/TransformedMultivariateParameter.java +++ b/src/dr/inference/model/TransformedMultivariateParameter.java @@ -55,19 +55,24 @@ public double getParameterValue(int dim) { public void setParameterValue(int dim, double value) { update(); - transformedValues[dim] = value; - unTransformedValues = inverse(transformedValues); + unTransformedValues[dim] = value; +/* transformedValues[dim] = value; + unTransformedValues = inverse(transformedValues);*/ // Need to update all values parameter.setParameterValueNotifyChangedAll(0, unTransformedValues[0]); // Warn everyone is changed for (int i = 1; i < parameter.getDimension(); i++) { parameter.setParameterValueQuietly(i, unTransformedValues[i]); // Do the rest quietly } + transformedValues = transform(unTransformedValues); } public void setParameterValueQuietly(int dim, double value) { update(); - transformedValues[dim] = value; - unTransformedValues = inverse(transformedValues); + unTransformedValues[dim] = value; + transformedValues = transform(unTransformedValues); + +/* transformedValues[dim] = value; + unTransformedValues = inverse(transformedValues);*/ // Need to update all values for (int i = 0; i < parameter.getDimension(); i++) { parameter.setParameterValueQuietly(i, unTransformedValues[i]); @@ -91,18 +96,23 @@ public void addBounds(Bounds bounds) { // } private void update() { - if (hasChanged()) { + +// if (hasChanged()) { unTransformedValues = parameter.getParameterValues(); transformedValues = transform(unTransformedValues); - } +// } } private boolean hasChanged() { + + for (int i = 0; i < unTransformedValues.length; i++) { if (parameter.getParameterValue(i) != unTransformedValues[i]) { return true; } } + + return false; } } diff --git a/src/dr/inference/model/TransformedParameter.java b/src/dr/inference/model/TransformedParameter.java index 8dacd02984..b9e8106887 100644 --- a/src/dr/inference/model/TransformedParameter.java +++ b/src/dr/inference/model/TransformedParameter.java @@ -39,6 +39,7 @@ public TransformedParameter(Parameter parameter, Transform transform) { } public TransformedParameter(Parameter parameter, Transform transform, boolean inverse) { + this.parameter = parameter; this.transform = transform; this.inverse = inverse; From 90edd1741d10744be2bf8ed11b08c94e0ae31311 Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 15 Aug 2023 10:09:17 +0100 Subject: [PATCH 174/196] New Tree operator mix is now default with the old one religated to 'Classic' --- src/dr/app/beauti/operatorspanel/OperatorsPanel.java | 2 +- src/dr/app/beauti/options/PartitionTreeModel.java | 4 ++-- src/dr/app/beauti/types/OperatorSetType.java | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dr/app/beauti/operatorspanel/OperatorsPanel.java b/src/dr/app/beauti/operatorspanel/OperatorsPanel.java index 12ac35d0c3..537f605dc0 100644 --- a/src/dr/app/beauti/operatorspanel/OperatorsPanel.java +++ b/src/dr/app/beauti/operatorspanel/OperatorsPanel.java @@ -64,7 +64,7 @@ public class OperatorsPanel extends BeautiPanel implements Exportable { JComboBox operatorSetCombo = new JComboBox(new OperatorSetType[] { OperatorSetType.DEFAULT, OperatorSetType.FIXED_TREE_TOPOLOGY, - OperatorSetType.NEW_TREE_MIX, + OperatorSetType.CLASSIC, OperatorSetType.ADAPTIVE_MULTIVARIATE, OperatorSetType.CUSTOM, OperatorSetType.HMC diff --git a/src/dr/app/beauti/options/PartitionTreeModel.java b/src/dr/app/beauti/options/PartitionTreeModel.java index 26e8ea1cfd..2284ddec62 100644 --- a/src/dr/app/beauti/options/PartitionTreeModel.java +++ b/src/dr/app/beauti/options/PartitionTreeModel.java @@ -177,10 +177,10 @@ public List selectOperators(List operators) { // if not a fixed tree then sample tree space if (options.operatorSetType == OperatorSetType.DEFAULT) { + newTreeOperatorsInUse = true; // default is now the new tree operators + } else if (options.operatorSetType == OperatorSetType.CLASSIC) { defaultInUse = true; branchesInUse = true; - } else if (options.operatorSetType == OperatorSetType.NEW_TREE_MIX) { - newTreeOperatorsInUse = true; } else if (options.operatorSetType == OperatorSetType.FIXED_TREE_TOPOLOGY) { branchesInUse = true; } else if (options.operatorSetType == OperatorSetType.ADAPTIVE_MULTIVARIATE) { diff --git a/src/dr/app/beauti/types/OperatorSetType.java b/src/dr/app/beauti/types/OperatorSetType.java index 6ddf6fac82..00c824861b 100644 --- a/src/dr/app/beauti/types/OperatorSetType.java +++ b/src/dr/app/beauti/types/OperatorSetType.java @@ -30,9 +30,9 @@ */ public enum OperatorSetType { - DEFAULT("classic operator mix"), + DEFAULT("new tree operator mix"), FIXED_TREE_TOPOLOGY("fixed tree topology"), - NEW_TREE_MIX("new tree operator mix"), + CLASSIC("classic tree operator mix"), ADAPTIVE_MULTIVARIATE("adaptive multivariate"), CUSTOM("custom operator mix"), HMC("Hamiltonian Monte Carlo"); From d7ed663dd98c3200e87407aeda97ffb3ad9e1009 Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 15 Aug 2023 15:07:50 +0100 Subject: [PATCH 175/196] Removing a non utf8 character in a comment --- src/dr/app/checkpoint/BeastCheckpointer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/app/checkpoint/BeastCheckpointer.java b/src/dr/app/checkpoint/BeastCheckpointer.java index d171fc8b13..d085eabbc0 100644 --- a/src/dr/app/checkpoint/BeastCheckpointer.java +++ b/src/dr/app/checkpoint/BeastCheckpointer.java @@ -639,7 +639,7 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln System.out.println("adopting tree structure"); } - //adopt the loaded tree structure; + //adopt the loaded tree structure ((TreeModel) model).beginTreeEdit(); ((TreeModel) model).adoptTreeStructure(parents, nodeHeights, childOrder, taxaNames); if (traitModels.size() > 0) { From cb4192a19ef5159fa5e5f5c8f22a329f32521a2e Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Tue, 15 Aug 2023 21:49:42 +0100 Subject: [PATCH 176/196] Fixed a few parser warnings --- src/dr/app/beast/development_parsers.properties | 4 ++-- src/dr/app/beast/release_parsers.properties | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index eeb4fe2bf3..4a690fb8d8 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -122,9 +122,9 @@ dr.inferencexml.distribution.FactorTreeGibbsOperatorParser dr.inferencexml.operators.factorAnalysis.LatentFactorLiabilityGibbsOperatorParser dr.inferencexml.operators.JointGibbsOperatorParser dr.inferencexml.operators.factorAnalysis.LFMSplitMergeOperatorParser -dr.evomodel.continuous.hmc.IntegratedLoadingsGradient +#dr.evomodel.continuous.hmc.IntegratedLoadingsGradient ## doesn't seem to exist? dr.inferencexml.hmc.LoadingsTransformParser -dr.evomodel.continuous.hmc.TaskPool +#dr.evomodel.continuous.hmc.TaskPool ## doesn't seem to exist? dr.inference.model.LogOrderedMatrix dr.inference.operators.factorAnalysis.LoadingsRotationOperator dr.inference.model.SVDStatistic diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index cef4cdd2f0..7492a1ac7b 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -458,13 +458,12 @@ dr.inferencexml.trace.ArithmeticMeanAnalysisParser #GMRF dr.evomodelxml.coalescent.operators.GMRFSkyrideFixedEffectsGibbsOperatorParser dr.evomodelxml.coalescent.operators.GMRFSkyrideBlockUpdateOperatorParser -dr.evomodelxml.coalescent.operators.GMRFSkygridBlockUpdateOperatorParser dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser dr.evomodelxml.coalescent.GMRFSkyrideGradientParser dr.evomodelxml.coalescent.BayesianSkylineGradientParser dr.evomodelxml.coalescent.CoalescentGradientParser dr.evomodelxml.coalescent.GMRFIntervalHeightsStatisticParser -dr.evomodelxml.coalescent.GMRFTestLikelihoodParser +#dr.evomodelxml.coalescent.GMRFTestLikelihoodParser dr.evomodelxml.coalescent.GMRFPopSizeStatisticParser dr.evomodelxml.coalescent.GMRFBivariateCurveAnalysisParser @@ -605,9 +604,9 @@ dr.evomodel.arg.ARGLogger dr.evoxml.MicrosatellitePatternStatisticParser # N & S counting -dr.inference.trace.DnDsPerSiteAnalysis -dr.inference.trace.CnCsPerSiteAnalysis -dr.inference.trace.CnCsToDnDsPerSiteAnalysis +dr.evomodel.trace.DnDsPerSiteAnalysis +dr.evomodel.trace.CnCsPerSiteAnalysis +dr.evomodel.trace.CnCsToDnDsPerSiteAnalysis # MARGINAL LIKELIHOOD ESTIMATION dr.inference.mcmc.MarginalLikelihoodEstimator From 6ded872292e0136d4d8fa136a9471c6a0583dec6 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Tue, 15 Aug 2023 21:50:17 +0100 Subject: [PATCH 177/196] Added a check for an object with an id but no parser (likely a missing parser). --- src/dr/xml/XMLParser.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/dr/xml/XMLParser.java b/src/dr/xml/XMLParser.java index 9a83e4a73c..ff909d736b 100644 --- a/src/dr/xml/XMLParser.java +++ b/src/dr/xml/XMLParser.java @@ -341,6 +341,14 @@ private Object convert(Element e, Class target, XMLObject parent, boolean run, b } xo.setNativeObject(obj); + } else { + // The element doesn't have a specific parser so is likely to be an internal + // element to another parser. However, it has an ID then it is likely to be + // something that was intended to parse so give a warning. + if (e.hasAttribute(ID)) { // object has ID + java.util.logging.Logger.getLogger("dr.xml").warning("Element called, " + xo.getName() + + ", has an ID, " + e.getAttribute(ID) + ", but no parser."); + } } if (id != null) { From 6bc17280c99dfa2fb86feef212dcec5647fe086d Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 10:26:30 +0100 Subject: [PATCH 178/196] Renaming `-beagle_thread_count` to `-beagle_threads` for consistency (old option still recognised) --- src/dr/app/beast/BeastMain.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/dr/app/beast/BeastMain.java b/src/dr/app/beast/BeastMain.java index a76194eb68..6dc18a10cc 100644 --- a/src/dr/app/beast/BeastMain.java +++ b/src/dr/app/beast/BeastMain.java @@ -352,7 +352,7 @@ public static void main(String[] args) throws java.io.IOException { new Arguments.StringOption("prefix", "PREFIX", "Specify a prefix for all output log filenames"), new Arguments.Option("overwrite", "Allow overwriting of log files"), new Arguments.IntegerOption("errors", "Specify maximum number of numerical errors before stopping"), - new Arguments.IntegerOption("threads", "The number of computational threads to use (default auto)"), + new Arguments.IntegerOption("threads", "The maximum number of computational threads to use (default auto)"), new Arguments.Option("fail_threads", "Exit with error on uncaught exception in thread."), new Arguments.Option("java", "Use Java only, no native implementations"), new Arguments.LongOption("tests", "The number of full evaluation tests to perform (default 1000)"), @@ -371,8 +371,8 @@ public static void main(String[] args) throws java.io.IOException { new Arguments.Option("beagle_GPU", "BEAGLE: use GPU instance if available"), new Arguments.Option("beagle_SSE", "BEAGLE: use SSE extensions if available"), new Arguments.Option("beagle_SSE_off", "BEAGLE: turn off use of SSE extensions"), - new Arguments.Option("beagle_threading_off", "BEAGLE: turn off auto threading for a CPU instance"), - new Arguments.IntegerOption("beagle_thread_count", 1, Integer.MAX_VALUE, "BEAGLE: manually set number of threads for a CPU instance"), + new Arguments.Option("beagle_threading_off", "BEAGLE: turn off multi-threading for a CPU instance"), + new Arguments.IntegerOption("beagle_threads", 1, Integer.MAX_VALUE, "BEAGLE: manually set number of threads per CPU instance (default auto)"), new Arguments.Option("beagle_cuda", "BEAGLE: use CUDA parallization if available"), new Arguments.Option("beagle_opencl", "BEAGLE: use OpenCL parallization if available"), new Arguments.Option("beagle_single", "BEAGLE: use single precision if available"), @@ -575,6 +575,9 @@ public static void main(String[] args) throws java.io.IOException { if (arguments.hasOption("beagle_thread_count")) { System.setProperty("beagle.thread.count", Integer.toString(arguments.getIntegerOption("beagle_thread_count"))); } + if (arguments.hasOption("beagle_threads")) { + System.setProperty("beagle.thread.count", Integer.toString(arguments.getIntegerOption("beagle_threads"))); + } if (arguments.hasOption("beagle_threading_off")) { System.setProperty("beagle.thread.count", Integer.toString(1)); } From bd0a5b664fcb425ae0d5a7e14dbc14d574d708f2 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 12:18:24 +0100 Subject: [PATCH 179/196] Removing use of beagle_instance for old and deprecated 'BeagleTreeLikelihood'. --- .../BeagleTreeLikelihoodParser.java | 85 ++++--------------- 1 file changed, 17 insertions(+), 68 deletions(-) diff --git a/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java b/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java index 5c0204b8a6..6f78678bde 100644 --- a/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java +++ b/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java @@ -25,8 +25,6 @@ package dr.evomodelxml.treelikelihood; -//import dr.app.beagle.evomodel.treelikelihood.RestrictedPartialsSequenceLikelihood; - import dr.evolution.tree.MutableTreeModel; import dr.evolution.tree.TreeUtils; import dr.evomodel.branchmodel.BranchModel; @@ -106,15 +104,6 @@ protected BeagleTreeLikelihood createTreeLikelihood(PatternList patternList, Mut public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean useAmbiguities = xo.getAttribute(USE_AMBIGUITIES, false); - int instanceCount = xo.getAttribute(INSTANCE_COUNT, 1); - if (instanceCount < 1) { - instanceCount = 1; - } - - String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); - if (ic != null && ic.length() > 0) { - instanceCount = Integer.parseInt(ic); - } PatternList patternList = (PatternList) xo.getChild(PatternList.class); MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(MutableTreeModel.class); @@ -180,67 +169,27 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } if (beagleThreadCount == -1) { - // the default is -1 threads (automatic thread pool size) but an XML attribute can override it - int threadCount = xo.getAttribute(THREADS, -1); - - if (System.getProperty(THREAD_COUNT) != null) { - threadCount = Integer.parseInt(System.getProperty(THREAD_COUNT)); - } - - // Todo: allow for different number of threads per beagle instance according to pattern counts - if (threadCount >= 0) { - System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / instanceCount)); - } - } - - if (instanceCount == 1 || patternList.getPatternCount() < instanceCount) { - return createTreeLikelihood( - patternList, - treeModel, - branchModel, - siteRateModel, - branchRateModel, - tipStatesModel, - useAmbiguities, - scalingScheme, - delayScaling, - partialsRestrictions, - xo - ); - } - - // using multiple instances of BEAGLE... - -// if (!(patternList instanceof SitePatterns)) { -// throw new XMLParseException("BEAGLE_INSTANCES option cannot be used with BEAUti-selected codon partitioning."); -// } + // no beagle_thread_count is given so use the number of available processors + // (actually logical threads - so 2 x number of cores when hyperthreads are used). - if (tipStatesModel != null) { - throw new XMLParseException("BEAGLE_INSTANCES option cannot be used with a TipStateModel (i.e., a sequence error model)."); + beagleThreadCount = Runtime.getRuntime().availableProcessors(); } - List likelihoods = new ArrayList(); - for (int i = 0; i < instanceCount; i++) { + System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(beagleThreadCount)); - Patterns subPatterns = new Patterns(patternList, i, instanceCount); - - AbstractTreeLikelihood treeLikelihood = createTreeLikelihood( - subPatterns, - treeModel, - branchModel, - siteRateModel, - branchRateModel, - null, - useAmbiguities, - scalingScheme, - delayScaling, - partialsRestrictions, - xo); - treeLikelihood.setId(xo.getId() + "_" + instanceCount); - likelihoods.add(treeLikelihood); - } - - return new CompoundLikelihood(likelihoods); + return createTreeLikelihood( + patternList, + treeModel, + branchModel, + siteRateModel, + branchRateModel, + tipStatesModel, + useAmbiguities, + scalingScheme, + delayScaling, + partialsRestrictions, + xo + ); } //************************************************************************ From 405247e5627b45c776ed3225853be958fc8876e5 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 12:19:33 +0100 Subject: [PATCH 180/196] Changing logic for beagle_thread_count and beagle_instances and adding better reporting --- .../BeagleDataLikelihoodDelegate.java | 7 +- .../TreeDataLikelihoodParser.java | 80 +++++++++---------- 2 files changed, 42 insertions(+), 45 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java index cea0c55d72..a727647257 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java @@ -355,7 +355,7 @@ public BeagleDataLikelihoodDelegate(Tree tree, // start auto resource selection String resourceAuto = System.getProperty(RESOURCE_AUTO_PROPERTY); - if (resourceAuto != null && Boolean.parseBoolean(resourceAuto)) { + if (Boolean.parseBoolean(resourceAuto)) { long benchmarkFlags = 0; @@ -443,7 +443,10 @@ public BeagleDataLikelihoodDelegate(Tree tree, if (threadCount > 0) { beagle.setCPUThreadCount(threadCount); logger.info(" Using " + threadCount + " threads for CPU."); - } else { // if no thread_count is specified then this will be -1 so put no upper bound on threads + } else { + // if no thread_count is specified then this will be -1 so put no upper bound on threads + // currently the parser provides a default based on the number of cores as BEAGLE's + // default is suboptimal logger.info(" Using default thread count for CPU."); // this is just intended to remove the cap on number of threads so BEAGLE will // make its own decision (for better or worse). diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index b87a347610..b45e444f49 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -122,35 +122,44 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } boolean useJava = Boolean.parseBoolean(System.getProperty("java.only", "false")); + if (useJava) { + logger.warning(" Java-only computation is not available - ignoring this option."); + } - int threadCount = -1; int beagleThreadCount = -1; if (System.getProperty(BEAGLE_THREAD_COUNT) != null) { + // if beagle_thread_count is set then use that - this is a per-instance thread count beagleThreadCount = Integer.parseInt(System.getProperty(BEAGLE_THREAD_COUNT)); } + int beagleInstanceCount = 1; + String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); + if (ic != null && ic.length() > 0) { + beagleInstanceCount = Math.max(1, Integer.parseInt(ic)); + } + if (beagleThreadCount == -1) { - // Todo: can't access XML object here, perhaps need to refactor - // the default is -1 threads (automatic thread pool size) but an XML attribute can override it - // int threadCount = xo.getAttribute(THREADS, -1); + // no beagle_thread_count is given so use the number of available processors + // (actually logical threads - so 2 x number of cores when hyperthreads are used). - if (System.getProperty(THREAD_COUNT) != null) { - threadCount = Integer.parseInt(System.getProperty(THREAD_COUNT)); - } - } + beagleThreadCount = Runtime.getRuntime().availableProcessors(); - int instanceCount = 0; - String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); - if (ic != null && ic.length() > 0) { - instanceCount = Integer.parseInt(ic); + // 'threadCount' controls the top level number of Java threads holding the + // likelihood/prior evaluations. Shouldn't be considered here - by default + // this will use an autosizing thread pool so should probably be left alone. + // if (System.getProperty(THREAD_COUNT) != null) { + // threadCount = Integer.parseInt(System.getProperty(THREAD_COUNT)); + // } } - if ( useBeagle3MultiPartition && instanceCount == 0 && !useJava) { + if ( useBeagle3MultiPartition) { - if (beagleThreadCount == -1 && threadCount >= 0) { - System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount)); + if ( beagleInstanceCount > 1) { + logger.warning(" BEAGLE multi-partition extensions are not compatible with -beagle_instances option"); } + System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(beagleThreadCount)); + try { DataLikelihoodDelegate dataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate( treeModel, @@ -176,38 +185,24 @@ protected Likelihood createTreeDataLikelihood(List patternLists, List treeDataLikelihoods = new ArrayList(); // Todo: allow for different number of threads per beagle instance according to pattern counts - if (beagleThreadCount == -1 && threadCount >= 0) { - System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size())); - } +// if (beagleThreadCount == -1 && threadCount >= 0) { +// System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size())); +// } - if (instanceCount > 1) { - logger.info(" Dividing each partition amongst " + instanceCount + " BEAGLE instances:"); + if (beagleInstanceCount > 1) { + logger.info(" Dividing each partition amongst " + beagleInstanceCount + " BEAGLE instances:"); } + for (int i = 0; i < patternLists.size(); i++) { - if (instanceCount > 1) { - for (int j = 0; j < instanceCount; j++) { - PatternList patterns = new Patterns(patternLists.get(i), j, instanceCount); - DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( - treeModel, - patterns, - branchModels.get(i), - siteRateModels.get(i), - useAmbiguities, - preferGPU, - scalingScheme, - delayRescalingUntilUnderflow, - settings); - - treeDataLikelihoods.add( - new TreeDataLikelihood( - dataLikelihoodDelegate, - treeModel, - branchRateModel)); - } - } else { + PatternList partitionPatterns = patternLists.get(i); + // can't divide up a partition by more than the number of patterns... + int bic = Math.min(partitionPatterns.getPatternCount(), beagleInstanceCount); + + for (int j = 0; j < bic; j++) { + PatternList patterns = new Patterns(partitionPatterns, j, bic); DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( treeModel, - patternLists.get(i), + patterns, branchModels.get(i), siteRateModels.get(i), useAmbiguities, @@ -221,7 +216,6 @@ protected Likelihood createTreeDataLikelihood(List patternLists, dataLikelihoodDelegate, treeModel, branchRateModel)); - } } From 85269136af34c086b73968d514ba9716ca5fff92 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 12:31:30 +0100 Subject: [PATCH 181/196] Cleaning up deprecated BeagleTreeLikelihoodParser --- .../treelikelihood/BeagleTreeLikelihoodParser.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java b/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java index 6f78678bde..79c884d0db 100644 --- a/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java +++ b/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java @@ -60,16 +60,11 @@ */ public class BeagleTreeLikelihoodParser extends AbstractXMLObjectParser { - public static final String BEAGLE_INSTANCE_COUNT = "beagle.instance.count"; public static final String BEAGLE_THREAD_COUNT = "beagle.thread.count"; public static final String THREAD_COUNT = "thread.count"; - public static final String THREADS = "threads"; public static final String TREE_LIKELIHOOD = "treeLikelihood"; public static final String USE_AMBIGUITIES = "useAmbiguities"; - public static final String INSTANCE_COUNT = "instanceCount"; - // public static final String DEVICE_NUMBER = "deviceNumber"; -// public static final String PREFER_SINGLE_PRECISION = "preferSinglePrecision"; public static final String SCALING_SCHEME = "scalingScheme"; public static final String DELAY_SCALING = "delayScaling"; public static final String PARTIALS_RESTRICTION = "partialsRestriction"; From 423360b88e41aa7bf3e5c4548f72984c41a52f24 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 13:47:37 +0100 Subject: [PATCH 182/196] Adding compoundlikelihood's likelihoods to the likelihood list rather than the compoundlikelihood to avoid a spurious warning --- src/dr/xml/XMLParser.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dr/xml/XMLParser.java b/src/dr/xml/XMLParser.java index 9a83e4a73c..38b9588e9d 100644 --- a/src/dr/xml/XMLParser.java +++ b/src/dr/xml/XMLParser.java @@ -25,6 +25,7 @@ package dr.xml; +import dr.inference.model.CompoundLikelihood; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; @@ -332,7 +333,9 @@ private Object convert(Element e, Class target, XMLObject parent, boolean run, b addCitable((Citable)obj); } - if (obj instanceof Likelihood) { + if (obj instanceof CompoundLikelihood) { + Likelihood.FULL_LIKELIHOOD_SET.addAll(((CompoundLikelihood) obj).getLikelihoods()); + } else if (obj instanceof Likelihood) { Likelihood.FULL_LIKELIHOOD_SET.add((Likelihood) obj); } else if (obj instanceof Model) { Model.FULL_MODEL_SET.add((Model) obj); From 5fca6cec85219e7158775074f723eaebcc08fdd0 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 13:48:17 +0100 Subject: [PATCH 183/196] Cleaning reporting --- .../treedatalikelihood/BeagleDataLikelihoodDelegate.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java index a727647257..fc77dbc1a5 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java @@ -119,7 +119,7 @@ public BeagleDataLikelihoodDelegate(Tree tree, super("BeagleDataLikelihoodDelegate"); final Logger logger = Logger.getLogger("dr.evomodel"); - logger.info("\nUsing BEAGLE DataLikelihood Delegate"); + logger.info("\nCreating BEAGLE DataLikelihood Delegate"); setId(patternList.getId()); this.dataType = patternList.getDataType(); From 73bc470dba56cf90edee815f280eb155e0773e7d Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 13:48:36 +0100 Subject: [PATCH 184/196] Cleaning reporting --- .../TreeDataLikelihoodParser.java | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index ec12247c43..0bc249faea 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -73,7 +73,8 @@ public String getParserName() { return TREE_DATA_LIKELIHOOD; } - protected Likelihood createTreeDataLikelihood(List patternLists, + protected Likelihood createTreeDataLikelihood(String id, + List patternLists, List branchModels, List siteRateModels, Tree treeModel, @@ -86,7 +87,6 @@ protected Likelihood createTreeDataLikelihood(List patternLists, PreOrderSettings settings) throws XMLParseException { final Logger logger = Logger.getLogger("dr.evomodel"); - logger.info("\nCreating tree data likelihoods for " + patternLists.size() + " partitions"); if (tipStatesModel != null) { throw new XMLParseException("Tip State Error models are not supported yet with TreeDataLikelihood"); @@ -121,7 +121,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, boolean useJava = Boolean.parseBoolean(System.getProperty("java.only", "false")); if (useJava) { - logger.warning(" Java-only computation is not available - ignoring this option."); + logger.warning(" Java-only computation is not available - ignoring this option."); } int beagleThreadCount = -1; @@ -141,7 +141,8 @@ protected Likelihood createTreeDataLikelihood(List patternLists, // (actually logical threads - so 2 x number of cores when hyperthreads are used). beagleThreadCount = Runtime.getRuntime().availableProcessors(); - + System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(beagleThreadCount)); + // 'threadCount' controls the top level number of Java threads holding the // likelihood/prior evaluations. Shouldn't be considered here - by default // this will use an autosizing thread pool so should probably be left alone. @@ -150,14 +151,15 @@ protected Likelihood createTreeDataLikelihood(List patternLists, // } } + String plural = (patternLists.size() > 1 ? "s": ""); if ( useBeagle3MultiPartition) { + logger.info("\nCreating multi-partition tree data likelihood for " + patternLists.size() + " partition" + plural); + if ( beagleInstanceCount > 1) { - logger.warning(" BEAGLE multi-partition extensions are not compatible with -beagle_instances option"); + logger.warning(" -beagle_instances option is not compatible with BEAGLE multi-partition extensions"); } - System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(beagleThreadCount)); - try { DataLikelihoodDelegate dataLikelihoodDelegate = new MultiPartitionDataLikelihoodDelegate( treeModel, @@ -187,8 +189,10 @@ protected Likelihood createTreeDataLikelihood(List patternLists, // System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size())); // } + logger.info("\nCreating tree data likelihood" + plural + " for " + patternLists.size() + " partition" + plural); + if (beagleInstanceCount > 1) { - logger.info(" Dividing each partition amongst " + beagleInstanceCount + " BEAGLE instances:"); + logger.info(" dividing each partition between " + beagleInstanceCount + " BEAGLE instances:"); } for (int i = 0; i < patternLists.size(); i++) { @@ -209,11 +213,14 @@ protected Likelihood createTreeDataLikelihood(List patternLists, delayRescalingUntilUnderflow, settings); - treeDataLikelihoods.add( - new TreeDataLikelihood( - dataLikelihoodDelegate, - treeModel, - branchRateModel)); + TreeDataLikelihood treeDataLikelihood = new TreeDataLikelihood( + dataLikelihoodDelegate, + treeModel, + branchRateModel); + + treeDataLikelihood.setId(id + "_" + (j + 1)); + + treeDataLikelihoods.add(treeDataLikelihood); } } @@ -236,17 +243,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } PreOrderSettings settings = new PreOrderSettings(usePreOrder, branchRateDerivative, branchInfinitesimalDerivative); - // TreeDataLikelihood doesn't currently support Instances defined from the command line -// int instanceCount = xo.getAttribute(INSTANCE_COUNT, 1); -// if (instanceCount < 1) { -// instanceCount = 1; -// } -// -// String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); -// if (ic != null && ic.length() > 0) { -// instanceCount = Integer.parseInt(ic); -// } - List patternLists = new ArrayList<>(); List siteRateModels = new ArrayList<>(); List branchModels = new ArrayList<>(); @@ -346,6 +342,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } return createTreeDataLikelihood( + xo.getId(), patternLists, branchModels, siteRateModels, From d86969636929ac6c99cd2ea861a3ddf03ed3be54 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 14:26:08 +0100 Subject: [PATCH 185/196] Final tweaks (-beagle_threads 0 is equivalent of -beagle_threading_off) --- src/dr/app/beast/BeastMain.java | 2 +- .../treedatalikelihood/BeagleDataLikelihoodDelegate.java | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dr/app/beast/BeastMain.java b/src/dr/app/beast/BeastMain.java index 6dc18a10cc..0b7d90961b 100644 --- a/src/dr/app/beast/BeastMain.java +++ b/src/dr/app/beast/BeastMain.java @@ -372,7 +372,7 @@ public static void main(String[] args) throws java.io.IOException { new Arguments.Option("beagle_SSE", "BEAGLE: use SSE extensions if available"), new Arguments.Option("beagle_SSE_off", "BEAGLE: turn off use of SSE extensions"), new Arguments.Option("beagle_threading_off", "BEAGLE: turn off multi-threading for a CPU instance"), - new Arguments.IntegerOption("beagle_threads", 1, Integer.MAX_VALUE, "BEAGLE: manually set number of threads per CPU instance (default auto)"), + new Arguments.IntegerOption("beagle_threads", 0, Integer.MAX_VALUE, "BEAGLE: manually set number of threads per CPU instance (default auto)"), new Arguments.Option("beagle_cuda", "BEAGLE: use CUDA parallization if available"), new Arguments.Option("beagle_opencl", "BEAGLE: use OpenCL parallization if available"), new Arguments.Option("beagle_single", "BEAGLE: use single precision if available"), diff --git a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java index fc77dbc1a5..c79f08d840 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java @@ -305,7 +305,7 @@ public BeagleDataLikelihoodDelegate(Tree tree, threadCount = Integer.parseInt(tc); } - if (threadCount == 0 || threadCount == 1) { + if (threadCount < 1) { preferenceFlags &= ~BeagleFlag.THREADING_CPP.getMask(); preferenceFlags |= BeagleFlag.THREADING_NONE.getMask(); } else { @@ -439,10 +439,10 @@ public BeagleDataLikelihoodDelegate(Tree tree, instanceFlags = instanceDetails.getFlags(); if ((instanceFlags & BeagleFlag.THREADING_CPP.getMask()) != 0) { - if (IS_THREAD_COUNT_COMPATIBLE() || threadCount != 0) { + if (IS_THREAD_COUNT_COMPATIBLE() && threadCount != 0) { if (threadCount > 0) { beagle.setCPUThreadCount(threadCount); - logger.info(" Using " + threadCount + " threads for CPU."); + logger.info(" Using " + threadCount + " thread" + (threadCount > 1 ? "s" : "") + " for CPU."); } else { // if no thread_count is specified then this will be -1 so put no upper bound on threads // currently the parser provides a default based on the number of cores as BEAGLE's From 65016f7a994ef6f7a3e0df84fe80bf5cc05fd792 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 14:47:40 +0100 Subject: [PATCH 186/196] Making more compatible with TDLP in hmc_clock --- .../TreeDataLikelihoodParser.java | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 0bc249faea..58f9026aed 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -60,6 +60,7 @@ public class TreeDataLikelihoodParser extends AbstractXMLObjectParser { public static final String TREE_DATA_LIKELIHOOD = "treeDataLikelihood"; public static final String USE_AMBIGUITIES = "useAmbiguities"; + public static final String INSTANCE_COUNT = "instanceCount"; public static final String PREFER_GPU = "preferGPU"; public static final String SCALING_SCHEME = "scalingScheme"; public static final String DELAY_SCALING = "delayScaling"; @@ -74,6 +75,7 @@ public String getParserName() { } protected Likelihood createTreeDataLikelihood(String id, + int beagleInstanceCount, List patternLists, List branchModels, List siteRateModels, @@ -130,12 +132,6 @@ protected Likelihood createTreeDataLikelihood(String id, beagleThreadCount = Integer.parseInt(System.getProperty(BEAGLE_THREAD_COUNT)); } - int beagleInstanceCount = 1; - String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); - if (ic != null && ic.length() > 0) { - beagleInstanceCount = Math.max(1, Integer.parseInt(ic)); - } - if (beagleThreadCount == -1) { // no beagle_thread_count is given so use the number of available processors // (actually logical threads - so 2 x number of cores when hyperthreads are used). @@ -201,10 +197,10 @@ protected Likelihood createTreeDataLikelihood(String id, int bic = Math.min(partitionPatterns.getPatternCount(), beagleInstanceCount); for (int j = 0; j < bic; j++) { - PatternList patterns = new Patterns(partitionPatterns, j, bic); + PatternList subPatterns = new Patterns(partitionPatterns, j, bic); DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( treeModel, - patterns, + subPatterns, branchModels.get(i), siteRateModels.get(i), useAmbiguities, @@ -243,9 +239,15 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } PreOrderSettings settings = new PreOrderSettings(usePreOrder, branchRateDerivative, branchInfinitesimalDerivative); - List patternLists = new ArrayList<>(); - List siteRateModels = new ArrayList<>(); - List branchModels = new ArrayList<>(); + int beagleInstanceCount = xo.getAttribute(INSTANCE_COUNT, 1); + String bic = System.getProperty(BEAGLE_INSTANCE_COUNT); + if (bic != null && bic.length() > 0) { + beagleInstanceCount = Math.max(1, Integer.parseInt(bic)); + } + + List patternLists = new ArrayList(); + List siteRateModels = new ArrayList(); + List branchModels = new ArrayList(); boolean hasSinglePartition = false; @@ -338,11 +340,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { final boolean delayScaling = xo.getAttribute(DELAY_SCALING, true); if (tipStatesModel != null) { - throw new XMLParseException("BEAGLE_INSTANCES option cannot be used with a TipStateModel (i.e., a sequence error model)."); + throw new XMLParseException("TreeDataLikelihood is not currently compatible with TipStateModel (i.e., a sequence error model)."); } return createTreeDataLikelihood( xo.getId(), + beagleInstanceCount, patternLists, branchModels, siteRateModels, From e34ad89af42e6b092a639af1e34df36907b0c382 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 14:55:59 +0100 Subject: [PATCH 187/196] fixing parser syntax --- .../evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 58f9026aed..b126cd6e20 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -375,6 +375,7 @@ public Class getReturnType() { AttributeRule.newBooleanRule(USE_AMBIGUITIES, true), AttributeRule.newBooleanRule(PREFER_GPU, true), AttributeRule.newStringRule(SCALING_SCHEME,true), + AttributeRule.newIntegerRule(INSTANCE_COUNT, true), // really it should be this set of elements or the PARTITION elements new OrRule(new AndRule(new XMLSyntaxRule[]{ From 5e1aa87da72b89831515c2426e2a315d8b349d00 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 16 Aug 2023 15:19:53 +0100 Subject: [PATCH 188/196] more small tweaks to reporting --- src/dr/evomodelxml/tree/CTMCScalePriorParser.java | 4 ++-- .../inferencexml/operators/SimpleOperatorScheduleParser.java | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodelxml/tree/CTMCScalePriorParser.java b/src/dr/evomodelxml/tree/CTMCScalePriorParser.java index 8c4c4a1880..9ced2a1395 100644 --- a/src/dr/evomodelxml/tree/CTMCScalePriorParser.java +++ b/src/dr/evomodelxml/tree/CTMCScalePriorParser.java @@ -57,9 +57,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean trial = xo.getAttribute(TRIAL, false); SubstitutionModel substitutionModel = (SubstitutionModel) xo.getChild(SubstitutionModel.class); - Logger.getLogger("dr.evolution").info("Creating CTMC Scale Reference Prior model."); + Logger.getLogger("dr.evolution").info("\nCreating CTMC Scale Reference Prior model"); if (taxa != null) { - Logger.getLogger("dr.evolution").info("Acting on subtree of size " + taxa.getTaxonCount()); + Logger.getLogger("dr.evolution").info(" Acting on subtree of size " + taxa.getTaxonCount()); } return new CTMCScalePrior(MODEL_NAME, ctmcScale, treeModel, taxa, reciprocal, substitutionModel, trial); } diff --git a/src/dr/inferencexml/operators/SimpleOperatorScheduleParser.java b/src/dr/inferencexml/operators/SimpleOperatorScheduleParser.java index 895c2182d8..234972a2b8 100644 --- a/src/dr/inferencexml/operators/SimpleOperatorScheduleParser.java +++ b/src/dr/inferencexml/operators/SimpleOperatorScheduleParser.java @@ -59,10 +59,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { schedule.setSequential(xo.getBooleanAttribute(SEQUENTIAL)); } + Logger.getLogger("dr.inference").info("\nCreating operator scheduler"); if (xo.hasAttribute(OPTIMIZATION_SCHEDULE)) { String type = xo.getStringAttribute(OPTIMIZATION_SCHEDULE); - Logger.getLogger("dr.inference").info("Optimization Schedule: " + type); + Logger.getLogger("dr.inference").info(" Optimization schedule: " + type); try { if (type.equalsIgnoreCase("default")) { From 3a04257d1eb716c769b03b735ce86552b9519816 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 21 Aug 2023 22:58:31 +0100 Subject: [PATCH 189/196] Retain more backwards XML compatibility --- .../treedatalikelihood/TreeDataLikelihoodParser.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 4eb77f4859..4087ee618d 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -291,8 +291,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { patternLists.add(patternList); SiteRateModel siteRateModel = (SiteRateModel) cxo.getChild(SiteRateModel.class); +// if (siteRateModel == null) { +// siteRateModel = new +// } siteRateModels.add(siteRateModel); + FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class); BranchModel branchModel = (BranchModel) cxo.getChild(BranchModel.class); @@ -302,6 +306,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { // for backwards compatibility the old GammaSiteRateModelParser can provide the substitution model... substitutionModel = ((DiscretizedSiteRateModel)siteRateModel).getSubstitutionModel(); } + if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { + // for backwards compatibility the old GammaSiteRateModelParser can provide the substitution model... + substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); + } if (substitutionModel == null) { throw new XMLParseException("No substitution model available for partition " + k + " in DataTreeLikelihood: "+xo.getId()); } From 886775f3843b23507bd4fb65fe52888c0da5e529 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 21 Aug 2023 23:13:34 +0100 Subject: [PATCH 190/196] Implemented a 'null' homogeneous-rates delegate for DiscretizedSiteRateModel --- .../DiscretizedSiteRateModel.java | 4 +- .../HomogeneousRateDelegate.java | 46 +++++++++++++++++++ .../TreeDataLikelihoodParser.java | 4 -- 3 files changed, 48 insertions(+), 6 deletions(-) create mode 100644 src/dr/evomodel/siteratemodel/HomogeneousRateDelegate.java diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index 0eb21d1a5b..30501e347c 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -40,8 +40,8 @@ public class DiscretizedSiteRateModel extends AbstractModel implements SiteRateM /** - * Constructor for gamma+invar distributed sites. Either shapeParameter or - * invarParameter (or both) can be null to turn off that feature. + * Constructor for a discretized site rate model that uses a delegate to set + * the category rates. */ public DiscretizedSiteRateModel( String name, diff --git a/src/dr/evomodel/siteratemodel/HomogeneousRateDelegate.java b/src/dr/evomodel/siteratemodel/HomogeneousRateDelegate.java new file mode 100644 index 0000000000..1084a799f6 --- /dev/null +++ b/src/dr/evomodel/siteratemodel/HomogeneousRateDelegate.java @@ -0,0 +1,46 @@ +package dr.evomodel.siteratemodel; + +import dr.inference.model.AbstractModel; +import dr.inference.model.Model; +import dr.inference.model.Parameter; +import dr.inference.model.Variable; + +public class HomogeneousRateDelegate extends AbstractModel implements SiteRateDelegate { + + public HomogeneousRateDelegate(String name) { + super(name); + } + + @Override + public int getCategoryCount() { + return 1; + } + + @Override + public void getCategories(double[] categoryRates, double[] categoryProportions) { + categoryRates[0] = 1.0; + categoryProportions[0] = 1.0; + } + + // ***************************************************************** + // Interface ModelComponent + // ***************************************************************** + + protected void handleModelChangedEvent(Model model, Object object, int index) { + // do nothing + } + + protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + // do nothing + } + + protected void storeState() { + } // no additional state needs storing + + protected void restoreState() { + } + + protected void acceptState() { + } // no additional state needs accepting + +} diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 4087ee618d..caf33262af 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -291,12 +291,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { patternLists.add(patternList); SiteRateModel siteRateModel = (SiteRateModel) cxo.getChild(SiteRateModel.class); -// if (siteRateModel == null) { -// siteRateModel = new -// } siteRateModels.add(siteRateModel); - FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class); BranchModel branchModel = (BranchModel) cxo.getChild(BranchModel.class); From aa0051251ab77285ec2ff28de505c730aae6a738 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 21 Aug 2023 23:19:25 +0100 Subject: [PATCH 191/196] More backwards compatibility --- src/dr/evomodelxml/siteratemodel/SiteModelParser.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/dr/evomodelxml/siteratemodel/SiteModelParser.java b/src/dr/evomodelxml/siteratemodel/SiteModelParser.java index a3067887be..d4d0d9e384 100644 --- a/src/dr/evomodelxml/siteratemodel/SiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/SiteModelParser.java @@ -29,6 +29,8 @@ import dr.evomodel.siteratemodel.DiscretizedSiteRateModel; import dr.evomodel.siteratemodel.GammaSiteRateDelegate; +import dr.evomodel.siteratemodel.HomogeneousRateDelegate; +import dr.evomodel.siteratemodel.SiteRateDelegate; import dr.evomodel.substmodel.SubstitutionModel; import dr.oldevomodel.sitemodel.SiteModel; import dr.inference.model.Parameter; @@ -127,7 +129,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Logger.getLogger("dr.evomodel").info("\nCreating site rate model."); } - GammaSiteRateDelegate delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); + SiteRateDelegate delegate; + if (shapeParam != null || invarParam != null) { + delegate = new GammaSiteRateDelegate("GammaSiteRateDelegate", shapeParam, catCount, type, invarParam); + } else { + delegate = new HomogeneousRateDelegate("HomogeneousRateDelegate"); + } DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate); From 586143e51f5b7dd10b73ca88817fed67c3572571 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 21 Aug 2023 23:49:16 +0100 Subject: [PATCH 192/196] Allowing for no siteRateModel in a partition --- .../siteratemodel/DiscretizedSiteRateModel.java | 16 ++++++++++++++++ .../TreeDataLikelihoodParser.java | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java index 30501e347c..0f64b81645 100644 --- a/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java @@ -38,6 +38,22 @@ public class DiscretizedSiteRateModel extends AbstractModel implements SiteRateModel { + /** + * Constructor for a rate homogenous (single category) SiteRateModel. + */ + public DiscretizedSiteRateModel(String name) { + this(name, null, 0.0, new HomogeneousRateDelegate(null)); + } + + /** + * Constructor for a rate homogenous (single category) SiteRateModel. + */ + public DiscretizedSiteRateModel( + String name, + Parameter nuParameter, + double muWeight) { + this(name, nuParameter, muWeight, new HomogeneousRateDelegate(null)); + } /** * Constructor for a discretized site rate model that uses a delegate to set diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index caf33262af..546e7c89f1 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -34,6 +34,7 @@ import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.siteratemodel.DiscretizedSiteRateModel; import dr.evomodel.siteratemodel.GammaSiteRateModel; +import dr.evomodel.siteratemodel.HomogeneousRateDelegate; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.SubstitutionModel; @@ -259,6 +260,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { patternLists.add(patternList); SiteRateModel siteRateModel = (SiteRateModel) xo.getChild(SiteRateModel.class); + if (siteRateModel == null) { + siteRateModel = new DiscretizedSiteRateModel("SiteRateModel"); + } siteRateModels.add(siteRateModel); FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class); @@ -291,6 +295,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { patternLists.add(patternList); SiteRateModel siteRateModel = (SiteRateModel) cxo.getChild(SiteRateModel.class); + if (siteRateModel == null) { + siteRateModel = new DiscretizedSiteRateModel("SiteRateModel"); + } siteRateModels.add(siteRateModel); FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class); From a0384d0ce84991facc91c0800692d326a15013ad Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 1 Sep 2023 13:56:19 +0100 Subject: [PATCH 193/196] Fixed an issue with a scale move on node heights was producing negative branch lengths. Now mirrors up/down operator with one parameter and rejects the move if out of bounds. --- src/dr/inference/operators/ScaleOperator.java | 28 +++++++++++++++---- .../operators/ScaleOperatorParser.java | 3 +- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/dr/inference/operators/ScaleOperator.java b/src/dr/inference/operators/ScaleOperator.java index 8a0897220d..3bdd243a65 100644 --- a/src/dr/inference/operators/ScaleOperator.java +++ b/src/dr/inference/operators/ScaleOperator.java @@ -44,6 +44,7 @@ * @version $Id: ScaleOperator.java,v 1.20 2005/06/14 10:40:34 rambaut Exp $ */ public class ScaleOperator extends AbstractAdaptableOperator { + private final boolean REJECT_IF_OUT_OF_BOUNDS = true; private Parameter indicator; private double indicatorOnProb; @@ -55,15 +56,16 @@ public ScaleOperator(Variable variable, double scale) { public ScaleOperator(Variable variable, double scale, AdaptationMode mode, double weight) { - this(variable, false, 0, scale, mode, null, 1.0, false); - setWeight(weight); + this(variable, false, 0, scale, mode, weight, null, 1.0, false); } public ScaleOperator(Variable variable, boolean scaleAll, int degreesOfFreedom, double scale, - AdaptationMode mode, Parameter indicator, double indicatorOnProb, boolean scaleAllInd) { + AdaptationMode mode, double weight, Parameter indicator, double indicatorOnProb, boolean scaleAllInd) { super(mode); + setWeight(weight); + this.variable = variable; this.indicator = indicator; this.indicatorOnProb = indicatorOnProb; @@ -132,9 +134,23 @@ public final double doOperation() { variable.setValue(i, variable.getValue(i) * scale); } - for (int i = 0; i < dim; i++) { - if (variable.getValue(i) > variable.getBounds().getUpperLimit(i)) { - throw new RuntimeException("proposed value greater than upper bound"); + if (REJECT_IF_OUT_OF_BOUNDS) { + // when scaling all parameter dimensions with different bounds (i.e., node heights + // where nodes below may bound a height) if the proposed scale will put any + // of the dimensions out of bounds then reject the move. + for (int i = 0; i < dim; i++) { + if (variable.getValue(i) > variable.getBounds().getUpperLimit(i) || + variable.getValue(i) < variable.getBounds().getLowerLimit(i)) { + return Double.NEGATIVE_INFINITY; + } + } + } else { + for (int i = 0; i < dim; i++) { + if (variable.getValue(i) > variable.getBounds().getUpperLimit(i)) { + throw new RuntimeException("proposed value greater than upper bound"); + } else if (variable.getValue(i) < variable.getBounds().getLowerLimit(i)) { + throw new RuntimeException("proposed value less than lower bound"); + } } } } else { diff --git a/src/dr/inferencexml/operators/ScaleOperatorParser.java b/src/dr/inferencexml/operators/ScaleOperatorParser.java index 4f75454f1d..7e0aa245d0 100644 --- a/src/dr/inferencexml/operators/ScaleOperatorParser.java +++ b/src/dr/inferencexml/operators/ScaleOperatorParser.java @@ -139,9 +139,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ScaleOperator operator = new ScaleOperator(parameter, scaleAll, degreesOfFreedom, scaleFactor, - mode, indicator, indicatorOnProb, + mode, weight, indicator, indicatorOnProb, scaleAllInd); - operator.setWeight(weight); return operator; } From a3ff593ca70d3d09b0703718a1595cfdb12a77b3 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 1 Sep 2023 13:57:07 +0100 Subject: [PATCH 194/196] Hiding Felsenstein weights in Gamma site model for BEAUti for the moment. --- .../sitemodelspanel/PartitionModelPanel.java | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java b/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java index d815726051..75219dbab2 100644 --- a/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java +++ b/src/dr/app/beauti/sitemodelspanel/PartitionModelPanel.java @@ -74,7 +74,7 @@ public class PartitionModelPanel extends OptionsPanel { .values()); private JComboBox heteroCombo = new JComboBox(new String[] { "None", - "Gamma (Felsenstein weights)", "Gamma (equal weights)", "Invariant Sites", "Gamma (equal weights) + Invariant Sites" }); + /*"Gamma (Felsenstein weights)", */ "Gamma (equal weights)", "Invariant Sites", "Gamma (equal weights) + Invariant Sites" }); private JComboBox gammaCatCombo = new JComboBox(new String[] { "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16" }); @@ -216,21 +216,17 @@ public void itemStateChanged(ItemEvent ev) { PanelUtils.setupComponent(heteroCombo); heteroCombo - .setToolTipText("Select the type of site-specific rate
heterogeneity model.
" - + "\"Felsenstein weights\" uses the quadrature method to calculate the category weights described in
" + - "Felsenstein (2001) J Mol Evol 53: 447-455."); + .setToolTipText("Select the type of site-specific rate
heterogeneity model." + +// "
\"Felsenstein weights\" uses the quadrature method to calculate the category weights described in
" + +// "Felsenstein (2001) J Mol Evol 53: 447-455." + + ""); heteroCombo.addItemListener(new ItemListener() { public void itemStateChanged(ItemEvent ev) { - boolean gammaHetero = heteroCombo.getSelectedIndex() == 1 || - heteroCombo.getSelectedIndex() == 2 - || heteroCombo.getSelectedIndex() == 4; - + boolean gammaHetero = heteroCombo.getSelectedItem().toString().contains("Gamma"); model.setGammaHetero(gammaHetero); - model.setInvarHetero(heteroCombo.getSelectedIndex() == 3 - || heteroCombo.getSelectedIndex() == 4); - model.setGammaHeteroEqualWeights(heteroCombo.getSelectedIndex() == 2 - || heteroCombo.getSelectedIndex() == 4); + model.setInvarHetero(heteroCombo.getSelectedItem().toString().contains("Invariant")); + model.setGammaHeteroEqualWeights(heteroCombo.getSelectedItem().toString().contains("equal")); if (gammaHetero) { gammaCatLabel.setEnabled(true); From 185cfc9487a7d3c70bf7081c0223737f9f84134c Mon Sep 17 00:00:00 2001 From: ghassler Date: Tue, 5 Sep 2023 12:48:25 -0700 Subject: [PATCH 195/196] reverting abd52ff for TransformedMultivariateParameter --- .../TransformedMultivariateParameter.java | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/dr/inference/model/TransformedMultivariateParameter.java b/src/dr/inference/model/TransformedMultivariateParameter.java index 490de1d19a..c9eade96bb 100644 --- a/src/dr/inference/model/TransformedMultivariateParameter.java +++ b/src/dr/inference/model/TransformedMultivariateParameter.java @@ -55,24 +55,19 @@ public double getParameterValue(int dim) { public void setParameterValue(int dim, double value) { update(); - unTransformedValues[dim] = value; -/* transformedValues[dim] = value; - unTransformedValues = inverse(transformedValues);*/ + transformedValues[dim] = value; + unTransformedValues = inverse(transformedValues); // Need to update all values parameter.setParameterValueNotifyChangedAll(0, unTransformedValues[0]); // Warn everyone is changed for (int i = 1; i < parameter.getDimension(); i++) { parameter.setParameterValueQuietly(i, unTransformedValues[i]); // Do the rest quietly } - transformedValues = transform(unTransformedValues); } public void setParameterValueQuietly(int dim, double value) { update(); - unTransformedValues[dim] = value; - transformedValues = transform(unTransformedValues); - -/* transformedValues[dim] = value; - unTransformedValues = inverse(transformedValues);*/ + transformedValues[dim] = value; + unTransformedValues = inverse(transformedValues); // Need to update all values for (int i = 0; i < parameter.getDimension(); i++) { parameter.setParameterValueQuietly(i, unTransformedValues[i]); @@ -96,23 +91,18 @@ public void addBounds(Bounds bounds) { // } private void update() { - -// if (hasChanged()) { + if (hasChanged()) { unTransformedValues = parameter.getParameterValues(); transformedValues = transform(unTransformedValues); -// } + } } private boolean hasChanged() { - - for (int i = 0; i < unTransformedValues.length; i++) { if (parameter.getParameterValue(i) != unTransformedValues[i]) { return true; } } - - return false; } } From c890dc531c810445a66a68f62c1b4ef826dbff60 Mon Sep 17 00:00:00 2001 From: ghassler Date: Tue, 19 Sep 2023 15:01:30 -0700 Subject: [PATCH 196/196] bug fix in parser --- .../model/CompoundSymmetricMatrixParser.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java b/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java index 230ac30ef3..4ee267b681 100644 --- a/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java +++ b/src/dr/inferencexml/model/CompoundSymmetricMatrixParser.java @@ -57,10 +57,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean isCholesky = xo.getAttribute(IS_CHOLESKY, false); - int dimOff = diagonalParameter.getDimension() * (diagonalParameter.getDimension() - 1) / 2; - if (dimOff != offDiagonalParameter.getDimension()) { - throw new XMLParseException("The vector '" + OFF_DIAGONAL + "' must be of dimension n*(n-1)/2 = " + dimOff + ", where n=" + diagonalParameter.getDimension() + " is the dimension of the vector '" + DIAGONAL + "'."); - } boolean isStrictlyUpperTriangular = xo.getAttribute(IS_STRICTLY_UPPER, true); @@ -73,6 +69,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { compoundSymmetricMatrix.setStrictlyUpperTriangular(false); } + int dimOff = diagonalParameter.getDimension() * (diagonalParameter.getDimension() - 1) / 2; + if (!isStrictlyUpperTriangular) dimOff = dimOff + diagonalParameter.getDimension(); + + if (dimOff != offDiagonalParameter.getDimension()) { + throw new XMLParseException("The vector '" + OFF_DIAGONAL + "' must be of dimension n*(n-1)/2 = " + dimOff + ", where n=" + diagonalParameter.getDimension() + " is the dimension of the vector '" + DIAGONAL + "'."); + } + return compoundSymmetricMatrix; }