diff --git a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java index 599da4697e..0d5080bbd5 100644 --- a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java +++ b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java @@ -31,6 +31,7 @@ import dr.evomodel.bigfasttree.BigFastTreeIntervals; import dr.evomodel.coalescent.AbstractCoalescentLikelihood; import dr.evomodel.tree.TreeModel; +import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.util.Author; import dr.util.Citable; @@ -76,6 +77,21 @@ public SmoothSkygridLikelihood(String name, this.populationSizeInverse = new SmoothSkygridPopulationSizeInverse(logPopSizeParameter, gridPointParameter, smoothFunction, smoothRate); this.lineageCount = new OldSmoothLineageCount(trees.get(0), smoothFunction, smoothRate); intervalsList = new ArrayList<>(); + + this.tmpA = new double[trees.get(0).getNodeCount()]; + this.tmpB = new double[trees.get(0).getNodeCount()]; + this.tmpC = new double[trees.get(0).getNodeCount()]; + this.tmpADerivOverS = new double[trees.get(0).getNodeCount()]; + this.tmpBDerivOverS = new double[trees.get(0).getNodeCount()]; + this.tmpCDerivOverS = new double[trees.get(0).getNodeCount()]; + this.tmpD = new double[gridPointParameter.getDimension()]; + this.tmpE = new double[gridPointParameter.getDimension()]; + this.tmpF = new double[gridPointParameter.getDimension()]; + this.tmpLineageEffect = new double[trees.get(0).getNodeCount()]; + this.tmpTimes = new double[trees.get(0).getNodeCount()]; + this.tmpCounts = new int[trees.get(0).getNodeCount()]; + this.tmpSumsKnown = false; + for (int i = 0; i < trees.size(); i++) { intervalsList.add(new BigFastTreeIntervals(trees.get(i))); addModel(intervalsList.get(i)); @@ -256,24 +272,28 @@ private double getLineageCountDifference(int intervalIndex, BigFastTreeIntervals } } - protected double calculateLogLikelihood() { - assert(trees.size() == 1); - if (!likelihoodKnown) { + private double[] tmpA; + private double[] tmpADerivOverS; + private double[] tmpB; + private double[] tmpBDerivOverS; + private double[] tmpC; + private double[] tmpCDerivOverS; + private double[] tmpD; + private double[] tmpE; + private double[] tmpF; + private double[] tmpLineageEffect; + private double[] tmpTimes; + private int[] tmpCounts; + private int uniqueTimes; + private boolean tmpSumsKnown; + + private void calculateTmpSums() { + if (!tmpSumsKnown) { TreeModel tree = trees.get(0); final double startTime = 0; final double endTime = tree.getNodeHeight(tree.getRoot()); final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime); - double[] tmpA = new double[tree.getNodeCount()]; - double[] tmpB = new double[tree.getNodeCount()]; - double[] tmpC = new double[tree.getNodeCount()]; - double[] tmpD = new double[maxGridIndex]; - double[] tmpE = new double[maxGridIndex]; - double[] tmpF = new double[maxGridIndex]; - double[] tmpLineageEffect = new double[tree.getNodeCount()]; - double[] tmpTimes = new double[tree.getNodeCount()]; - int[] tmpCounts = new int[tree.getNodeCount()]; - NodeRef[] nodes = new NodeRef[tree.getNodeCount()]; System.arraycopy(tree.getNodes(), 0, nodes, 0, tree.getNodeCount()); Arrays.parallelSort(nodes, (a, b) -> Double.compare(tree.getNodeHeight(a), tree.getNodeHeight(b))); @@ -301,7 +321,7 @@ protected double calculateLogLikelihood() { } tmpLineageEffect[index] = currentLineageEffect; tmpCounts[index] = currentCount; - final int uniqueTimes = index + 1; + uniqueTimes 
= index + 1; for (int i = 0; i < uniqueTimes; i++) { final double timeI = tmpTimes[i]; @@ -354,101 +374,273 @@ protected double calculateLogLikelihood() { tmpE[k] = sum; tmpF[k] = sum * sum - quadraticSum; } + tmpSumsKnown = true; + } + } - double tripleIntegrationSum = 0; - double lineageEffectSqaredSum = 0; - for (int i = 0; i < uniqueTimes; i++) { - final double lineageCountEffect = tmpLineageEffect[i]; - lineageEffectSqaredSum += lineageCountEffect * lineageCountEffect; - tripleIntegrationSum += lineageCountEffect * tmpA[i] * tmpB[i] * tmpC[i]; + private void calculateTmpSumDerivatives() { + if (!tmpSumsKnown) { + calculateTmpSums(); + } + + TreeModel tree = trees.get(0); + final double startTime = 0; + final double endTime = tree.getNodeHeight(tree.getRoot()); + final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime); + + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + double sum = 0; + for (int j = 0; j < uniqueTimes; j++) { + if (j != i) { + final double timeJ = tmpTimes[j]; + final double lineageCountEffect = tmpLineageEffect[j]; + final double thisInverse = smoothFunction.getInverseOneMinusExponential(timeJ - timeI, smoothRate.getParameterValue(0)); + sum += lineageCountEffect * thisInverse * (1 - thisInverse); + } } - tripleIntegrationSum *= 2; + tmpADerivOverS[i] = - sum; + } + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + double sum = 0; for (int k = 0; k < maxGridIndex; k++) { final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); - tripleIntegrationSum += (nextPopSizeInverse - currentPopSizeInverse) * tmpF[k] * tmpD[k]; + final double gridTime = gridPointParameter.getParameterValue(k); + final double thisInverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); + sum += (nextPopSizeInverse - currentPopSizeInverse) * thisInverse * (1 - thisInverse); } + tmpBDerivOverS[i] = -sum; + } - tripleIntegrationSum /= -smoothRate.getParameterValue(0) * 2; - tripleIntegrationSum += -0.5 * (1 - lineageEffectSqaredSum) - * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) - * (endTime - startTime); + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + tmpCDerivOverS[i] = smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + } + } - double tripleWithQuadraticIntegrationSum = 0; - final double commonFirstTermMultiplier = (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) * (endTime - startTime); + protected double calculateLogLikelihood() { + assert(trees.size() == 1); + if (!likelihoodKnown) { + TreeModel tree = trees.get(0); + final double startTime = 0; + final double endTime = tree.getNodeHeight(tree.getRoot()); + final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime); + + calculateTmpSums(); + + double lineageEffectSquaredSum = 0; for (int i = 0; i < uniqueTimes; i++) { - final double lineageCountEffect = tmpLineageEffect[i] * tmpLineageEffect[i]; - final double timeI = tmpTimes[i]; - double thisResult = commonFirstTermMultiplier; - final double commonSecondTermMultiplier = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)) - - 
smoothFunction.getInverseOnePlusExponential(timeI - endTime, smoothRate.getParameterValue(0)); - for (int k = 0; k < maxGridIndex; k++) { - final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); - final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); - final double gridTime = gridPointParameter.getParameterValue(k); - final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); - thisResult += (nextPopSizeInverse - currentPopSizeInverse) / smoothRate.getParameterValue(0) - * (inverse * commonSecondTermMultiplier + (2.0 - inverse) * inverse * tmpC[i] + - (1 - inverse) * (1 - inverse) * tmpD[k]); - } - thisResult *= lineageCountEffect; - tripleWithQuadraticIntegrationSum += thisResult; + lineageEffectSquaredSum += tmpLineageEffect[i] * tmpLineageEffect[i]; } - tripleWithQuadraticIntegrationSum *= -0.5; - double firstDoubleIntegrationOffDiagonalSum = 0; - double firstDoubleIntegrationDiagonalSum = 0; - for (int i = 0; i < uniqueTimes; i++) { - final double lineageCountEffect = tmpLineageEffect[i]; - final double timeI = tmpTimes[i]; - firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i]; - firstDoubleIntegrationDiagonalSum += lineageCountEffect * lineageCountEffect - * smoothFunction.getQuadraticIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + double tripleIntegrationSum = getTripleIntegration(startTime, endTime, maxGridIndex, lineageEffectSquaredSum); + + double doubleIntegrationSum = getDoubleIntegration(startTime, endTime, maxGridIndex, lineageEffectSquaredSum); + + final double singleIntegration = getSingleIntegration(startTime, endTime); + + double logPopulationSizeInverse = 0; + for (int i = 0; i < tree.getInternalNodeCount(); i++) { + NodeRef node = tree.getNode(tree.getExternalNodeCount() + i); + logPopulationSizeInverse += Math.log(getSmoothPopulationSizeInverse(tree.getNodeHeight(node), tree.getNodeHeight(tree.getRoot()))); } - firstDoubleIntegrationOffDiagonalSum /= smoothRate.getParameterValue(0); - firstDoubleIntegrationOffDiagonalSum += 0.5 * (1 - lineageEffectSqaredSum) * (endTime - startTime); - final double firstDoubleIntegrationSum = -(firstDoubleIntegrationDiagonalSum * 0.5 + firstDoubleIntegrationOffDiagonalSum) * Math.exp(-logPopSizeParameter.getParameterValue(0)); + logLikelihood = logPopulationSizeInverse + singleIntegration + doubleIntegrationSum + tripleIntegrationSum; - double secondDoubleIntegrationSum = 0; - for (int i = 0; i < uniqueTimes; i++) { - secondDoubleIntegrationSum += 0.5 * tmpB[i] * tmpC[i] * tmpLineageEffect[i]; + likelihoodKnown = true; + } + return logLikelihood; + } + + private double[] getGradientWrtNodeHeightNew() { + if (!likelihoodKnown) { + calculateLogLikelihood(); + } + TreeModel tree = trees.get(0); + final double startTime = 0; + final double endTime = tree.getNodeHeight(tree.getRoot()); + double[] gradient = new double[tree.getInternalNodeCount()]; + getGradientWrtNodeHeightFromSingleIntegration(startTime, endTime, gradient); + + double lineageEffectSquaredSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + lineageEffectSquaredSum += tmpLineageEffect[i] * tmpLineageEffect[i]; + } + getGradientWrtNodeHeightFromDoubleIntegration(startTime, endTime, getMaxGridIndex(gridPointParameter, endTime), gradient); + + getGradientWrtNodeHeightFromTripleIntegration(startTime, endTime, getMaxGridIndex(gridPointParameter, endTime), gradient); + return gradient; + } + + 
double getTripleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) { + double tripleIntegrationSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + tripleIntegrationSum += lineageCountEffect * tmpA[i] * tmpB[i] * tmpC[i]; + } + tripleIntegrationSum *= 2; + + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + tripleIntegrationSum += (nextPopSizeInverse - currentPopSizeInverse) * tmpF[k] * tmpD[k]; + } + + tripleIntegrationSum /= -smoothRate.getParameterValue(0) * 2; + tripleIntegrationSum += -0.5 * (1 - lineageEffectSquaredSum) + * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) + * (endTime - startTime); + + double tripleWithQuadraticIntegrationSum = 0; + final double commonFirstTermMultiplier = (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) * (endTime - startTime); + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i] * tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + double thisResult = commonFirstTermMultiplier; + final double commonSecondTermMultiplier = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)) + - smoothFunction.getInverseOnePlusExponential(timeI - endTime, smoothRate.getParameterValue(0)); + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); + thisResult += (nextPopSizeInverse - currentPopSizeInverse) / smoothRate.getParameterValue(0) + * (inverse * commonSecondTermMultiplier + (2.0 - inverse) * inverse * tmpC[i] + + (1 - inverse) * (1 - inverse) * tmpD[k]); } + thisResult *= lineageCountEffect; + tripleWithQuadraticIntegrationSum += thisResult; + } + tripleWithQuadraticIntegrationSum *= -0.5; + return tripleIntegrationSum + tripleWithQuadraticIntegrationSum; + } + private void getGradientWrtNodeHeightFromTripleIntegration(double startTime, double endTime, int maxGridIndex, + double[] gradient) { + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + gradient[i] += lineageCountEffect * (tmpADerivOverS[i] * tmpB[i] * tmpC[i] + tmpA[i] * tmpBDerivOverS[i] * tmpC[i] + tmpA[i] * tmpB[i] * tmpCDerivOverS[i]); for (int k = 0; k < maxGridIndex; k++) { final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); - secondDoubleIntegrationSum += 0.5 * tmpE[k] * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse); + final double gridTime = gridPointParameter.getParameterValue(k); + final double tmpEInverse = smoothFunction.getInverseOneMinusExponential(timeI - gridTime, smoothRate.getParameterValue(0)); + + gradient[i] += (nextPopSizeInverse - currentPopSizeInverse) * tmpD[k] * (tmpE[k] - lineageCountEffect * 
tmpEInverse ) * tmpEInverse * (1 - tmpEInverse) * lineageCountEffect; } - secondDoubleIntegrationSum /= smoothRate.getParameterValue(0); - secondDoubleIntegrationSum += 0.5 * (endTime - startTime) * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))); + final double startTimeInverse = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)); + final double endTimeInverse = smoothFunction.getInverseOnePlusExponential(timeI - endTime, smoothRate.getParameterValue(0)); + final double commonSecondTermMultiplier = startTimeInverse - endTimeInverse; + final double commonSecondTermMultiplierDerivativeOverS = - startTimeInverse * (1 - startTimeInverse) + endTimeInverse * (1 - endTimeInverse); - double singleIntegration = 0; - for (int i = 0; i < uniqueTimes; i++) { - final double timeI = tmpTimes[i]; - final double lineageCountEffectI = tmpLineageEffect[i]; - singleIntegration += lineageCountEffectI * smoothFunction.getSingleIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); + final double inverseDerivativeOverS = -inverse * (1 - inverse); + gradient[i] += (nextPopSizeInverse - currentPopSizeInverse) + * (inverseDerivativeOverS * commonSecondTermMultiplier + inverse * commonSecondTermMultiplierDerivativeOverS + + 2 * (1 - inverse) * inverseDerivativeOverS * tmpC[i] + (2.0 - inverse) * inverse * tmpCDerivOverS[i] + + 2 * (1 - inverse) * (-inverseDerivativeOverS) * tmpD[k]); } - singleIntegration *= 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0)); + } - double logPopulationSizeInverse = 0; - for (int i = 0; i < tree.getInternalNodeCount(); i++) { - NodeRef node = tree.getNode(tree.getExternalNodeCount() + i); - logPopulationSizeInverse += Math.log(getSmoothPopulationSizeInverse(tree.getNodeHeight(node), tree.getNodeHeight(tree.getRoot()))); + } + + double getDoubleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) { + double firstDoubleIntegrationOffDiagonalSum = 0; + double firstDoubleIntegrationDiagonalSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i]; + firstDoubleIntegrationDiagonalSum += lineageCountEffect * lineageCountEffect + * smoothFunction.getQuadraticIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + } + firstDoubleIntegrationOffDiagonalSum /= smoothRate.getParameterValue(0); + firstDoubleIntegrationOffDiagonalSum += 0.5 * (1 - lineageEffectSquaredSum) * (endTime - startTime); + + final double firstDoubleIntegrationSum = -(firstDoubleIntegrationDiagonalSum * 0.5 + firstDoubleIntegrationOffDiagonalSum) * Math.exp(-logPopSizeParameter.getParameterValue(0)); + + double secondDoubleIntegrationSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + secondDoubleIntegrationSum += 0.5 * tmpB[i] * tmpC[i] * tmpLineageEffect[i]; + } + + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse =
Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + secondDoubleIntegrationSum += 0.5 * tmpE[k] * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse); + } + + secondDoubleIntegrationSum /= smoothRate.getParameterValue(0); + secondDoubleIntegrationSum += 0.5 * (endTime - startTime) * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))); + + return firstDoubleIntegrationSum + secondDoubleIntegrationSum; + + } + + private void getGradientWrtNodeHeightFromDoubleIntegration(double startTime, double endTime, int maxGridIndex, + double[] gradient) { + final double firstPopSize = Math.exp(-logPopSizeParameter.getParameterValue(0)); + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + //firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i]; + gradient[i] += -lineageCountEffect * (tmpA[i] * tmpCDerivOverS[i] + tmpADerivOverS[i] * tmpC[i]) * firstPopSize; + + //firstDoubleIntegrationDiagonalSum + gradient[i] += lineageCountEffect * lineageCountEffect + * (smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0)) + + (smoothFunction.getDerivative(timeI, endTime, 0, 1, smoothRate.getParameterValue(0)) + - smoothFunction.getDerivative(timeI, startTime, 0, 1, smoothRate.getParameterValue(0)) / smoothRate.getParameterValue(0)) + ) * -0.5 * firstPopSize; + + gradient[i] += 0.5 * tmpLineageEffect[i] * (tmpB[i] * tmpCDerivOverS[i] + tmpBDerivOverS[i] * tmpC[i]); + + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double tmpEInverse = smoothFunction.getInverseOneMinusExponential(timeI - gridTime, smoothRate.getParameterValue(0)); + gradient[i] += 0.5 * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse) * tmpEInverse * (1 - tmpEInverse) * lineageCountEffect; } + } + } - logLikelihood = logPopulationSizeInverse + singleIntegration + firstDoubleIntegrationSum + secondDoubleIntegrationSum + tripleIntegrationSum + tripleWithQuadraticIntegrationSum; + private double getSingleIntegration(double startTime, double endTime) { + double singleIntegration = 0; + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + final double lineageCountEffectI = tmpLineageEffect[i]; + singleIntegration += lineageCountEffectI * smoothFunction.getSingleIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + } + singleIntegration *= 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0)); + return singleIntegration; + } - likelihoodKnown = true; + private void getGradientWrtNodeHeightFromSingleIntegration(double startTime, double endTime, double[] gradient) { + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + final double lineageCountEffectI = tmpLineageEffect[i]; + gradient[i] += lineageCountEffectI * smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0)) + * 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0)); } - return logLikelihood; } + protected void handleModelChangedEvent(Model model, Object object, int index) { + 
super.handleModelChangedEvent(model, object, index); + tmpSumsKnown = false; + } - private double getLineageCountEffect(Tree tree, int node) { + private double getLineageCountEffect(Tree tree, int node) { if (tree.isExternal(tree.getNode(node))) { return 1; } else { diff --git a/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java b/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java index bdf455d23c..95a6036e0d 100644 --- a/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java +++ b/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java @@ -237,9 +237,11 @@ public void addTrait(TreeTrait trait) { public String getReport() { String message = super.getReport(); - message += "\n"; - // add likelihood calculation time - message += "Likelihood calculation time is " + likelihoodTime / likelihoodCounts + " nanoseconds.\n"; + if (MEASURE_RUN_TIME) { + message += "\n"; + // add likelihood calculation time + message += "Likelihood calculation time is " + likelihoodTime / likelihoodCounts + " nanoseconds.\n"; + } return message; } } \ No newline at end of file diff --git a/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java b/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java index 512ceaa87e..ed73a82165 100644 --- a/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java +++ b/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java @@ -35,6 +35,8 @@ import dr.util.Timer; import dr.xml.Reportable; +import static dr.evomodel.speciation.CachedGradientDelegate.MEASURE_RUN_TIME; + /** * @author Andy Magee * @author Yucai Shao @@ -136,7 +138,7 @@ public LogColumn[] getColumns() { @Override public String getReport() { String message = GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 1E-3); - if (gradientProvider instanceof CachedGradientDelegate) { + if (gradientProvider instanceof CachedGradientDelegate && MEASURE_RUN_TIME) { message += "\n"; message += "Gradient calculation time is " + ((CachedGradientDelegate) gradientProvider).getGradientTime() + " nanoseconds.\n"; } diff --git a/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java b/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java new file mode 100644 index 0000000000..a8ccbe2090 --- /dev/null +++ b/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java @@ -0,0 +1,181 @@ +/* + * NewBirthDeathSerialSamplingModel.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.speciation; + +import dr.inference.model.Parameter; + +public class MasBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel { + + public MasBirthDeathSerialSamplingModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, boolean condition, int numIntervals, double gridEnd, Type units) { + super(birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, condition, numIntervals, gridEnd, units); + } + + @Override + public final double processModelSegmentBreakPoint(int model, double intervalStart, double intervalEnd, int nLineages) { +// double lnL = nLineages * (logQ(model, intervalEnd) - logQ(model, intervalStart)); + double lnL = nLineages * Math.log(Q(model, intervalEnd) / Q(model, intervalStart)); + if ( samplingProbability.getValue(model + 1) > 0.0 && samplingProbability.getValue(model + 1) < 1.0) { + // Add in probability of un-sampled lineages + // We don't need this at t=0 because all lineages in the tree are sampled + // TODO: check if we're right about how many lineages are actually alive at this time. Are we inadvertently over-counting or under-counting due to samples added at this _exact_ time? + lnL += nLineages * Math.log(1.0 - samplingProbability.getValue(model + 1)); + } + this.savedLogQ = Double.NaN; + return lnL; + } + + final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages, + final double[] partialQ_all_old, final double Q_Old, + final double[] partialQ_all_young, final double Q_young) { + + for (int k = 0; k <= currentModelSegment; k++) { + for (int p = 0; p < 4; ++p) { + gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old + - partialQ_all_young[k * 4 + p] / Q_young); + } + } + } + + final void accumulateGradientForSampling(double[] gradient, int currentModelSegment, double term1, + double[] intermediate) { + + for (int k = 0; k <= currentModelSegment; k++) { + for (int p = 0; p < 4; ++p) { + gradient[k * 5 + p] += term1 * intermediate[k * 4 + p]; + } + } + + } + + final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) { + + double G1 = g1(eAt); + + double term1 = -A / lambda * ((1 - B) * (eAt - 1) + G1) / (G1 * G1); + + for (int k = 0; k < model; k ++) { + for (int p = 0; p < 4; p++) { + dP[k * 4 + p] = term1 * dB[k * 4 + p]; + } + } + + for (int p = 0; p < 3; ++p) { + double term2 = eAt * (1 + B) * dA[p] * (t - intervalStart) + (eAt - 1) * dB[model * 4 + p]; + dG2[p] = dA[p] - 2 * (G1 * (dA[p] * (1 - B) - dB[model * 4 + p] * A) - (1 - B) * term2 * A) / (G1 * G1); + } + + double G2 = g2(G1); + + dP[model * 4 + 0] = (-mu - psi - lambda * dG2[0] + G2) / (2 * lambda * lambda); + dP[model * 4 + 1] = (1 - dG2[1]) / (2 * lambda); + dP[model * 4 + 2] = (1 - dG2[2]) / (2 * lambda); + dP[model * 4 + 3] = -A / lambda * ((1 - B) * (eAt - 1) + G1) * dB[model * 4 + 3] / (G1 * G1); + } + + + final void dQCompute(int model, double t, double[] dQ, double eAt) { + + double dwell = t - modelStartTimes[model]; + double G1 = g1(eAt); + + double term1 = 8 * eAt; + double term2 = G1 / 2 - eAt * (1 + B); + double term3 = eAt - 1; + double term4 = G1 * G1 * G1; + 
double term5 = -term1 * term3 / term4; + + for (int k = 0; k < model; ++k) { + dQ[k * 4 + 0] = term5 * dB[k * 4 + 0]; + dQ[k * 4 + 1] = term5 * dB[k * 4 + 1]; + dQ[k * 4 + 2] = term5 * dB[k * 4 + 2]; + dQ[k * 4 + 3] = term5 * dB[k * 4 + 3]; + } + + double term6 = term1 / term4; + double term7 = dwell * term2; + + dQ[model * 4 + 0] = term6 * (dA[0] * term7 - dB[model * 4 + 0] * term3); + dQ[model * 4 + 1] = term6 * (dA[1] * term7 - dB[model * 4 + 1] * term3); + dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3); + dQ[model * 4 + 3] = term5 * dB[model * 4 + 3]; + } + + + final double Q(int model, double time) { + double At = A * (time - modelStartTimes[model]); + double eAt = Math.exp(At); + double sqrtDenominator = g1(eAt); + return eAt / (sqrtDenominator * sqrtDenominator); + } + + final double logQ(int model, double time) { + double At = A * (time - modelStartTimes[model]); + double eAt = Math.exp(At); + double sqrtDenominator = g1(eAt); + return At - 2 * Math.log(sqrtDenominator); // TODO log4 (additive constant) is not needed since we always see logQ(a) - logQ(b) + } + + @Override + public double processInterval(int model, double tYoung, double tOld, int nLineages) { + double logQ_young; + double logQ_old = Q(model, tOld); + if (!Double.isNaN(this.savedLogQ)) { + logQ_young = this.savedLogQ; + } else { + logQ_young = Q(model, tYoung); + } + this.savedLogQ = logQ_old; + return nLineages * Math.log(logQ_old / logQ_young); + } + + @Override + public double processSampling(int model, double tOld) { + + double logSampProb; + + boolean sampleIsAtEventTime = tOld == modelStartTimes[model]; + boolean samplesTakenAtEventTime = rho > 0; + + if (sampleIsAtEventTime && samplesTakenAtEventTime) { + logSampProb = Math.log(rho); + if (model > 0) { + logSampProb += Math.log(r + ((1.0 - r) * previousP)); + } + } else { + double logPsi = Math.log(psi); + logSampProb = logPsi + Math.log(r + (1.0 - r) * p(model,tOld)); + } + + return logSampProb; + } +} + +/* + * Notes on inlining: + * https://www.baeldung.com/jvm-method-inlining#:~:text=Essentially%2C%20the%20JIT%20compiler%20tries,times%20we%20invoke%20the%20method. 
+ * https://miuv.blog/2018/02/25/jit-optimizations-method-inlining/ + * static, private, final + */ diff --git a/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java b/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java index 84d0aad51a..e7000c86ff 100644 --- a/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java +++ b/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java @@ -82,11 +82,12 @@ public class NewBirthDeathSerialSamplingModel extends SpeciationModel implements double psi; double r; double rho; + double logRho; //double rho0; // TODO remove private boolean[] gradientFlags; - private double savedLogQ; + double savedLogQ; private double savedQ; private double[] partialQ; @@ -101,8 +102,8 @@ public class NewBirthDeathSerialSamplingModel extends SpeciationModel implements private double eAt_Old; private double eAt_End; - private final double[] dA; - private final double[] dB; + final double[] dA; + final double[] dB; private final double[] dG2; boolean computedBCurrent; @@ -257,7 +258,7 @@ protected void handleVariableChangedEvent(Variable variable, int index, Paramete // Do nothing } - private double p(int model, double t) { + final double p(int model, double t) { double eAt = Math.exp(A * (t - modelStartTimes[model])); return p(eAt); } @@ -267,7 +268,7 @@ private double p(double eAt) { return (lambda + mu + psi - A * ((eAt1B - (1.0 - B)) / (eAt1B + (1.0 - B)))) / (2.0 * lambda); } - private double logQ(int model, double time) { + double logQ(int model, double time) { double At = A * (time - modelStartTimes[model]); double eAt = Math.exp(At); double sqrtDenominator = g1(eAt); @@ -339,6 +340,7 @@ private void updateParameterValues(int model) { psi = serialSamplingRate.getParameterValue(model); r = treatmentProbability.getParameterValue(model); rho = samplingProbability.getParameterValue(model); +// logRho = Math.log(rho); this.savedLogQ = Double.NaN; } @@ -390,7 +392,7 @@ public void updateGradientModelValues(int model) { double end = modelStartTimes[model + 1]; double start = modelStartTimes[model]; eAt_End = Math.exp(A * (end - start)); - dPCompute(model, end, start, eAt_End, dPModelEnd); + dPCompute(model, end, start, eAt_End, dPModelEnd, dG2); } computedBCurrent = true; @@ -492,20 +494,20 @@ public List getCitations() { )); } - private double g1(double eAt) { + final double g1(double eAt) { return (1 + B) * eAt + (1 - B); } - private double g2(double G1) { + final double g2(double G1) { return A * (1 - 2 * (1 - B) / G1); } - public double q(int model, double t) { + public final double q(int model, double t) { double eAt = Math.exp(A * (t - modelStartTimes[model])); return q(eAt); } - public double q(double eAt) { + public final double q(double eAt) { double sqrtDenominator = g1(eAt); return 4 * eAt / (sqrtDenominator * sqrtDenominator); } @@ -606,7 +608,7 @@ private void dBCompute(int model, double[] dB) { } } - private void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP) { + void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) { double G1 = g1(eAt); @@ -689,7 +691,6 @@ private void dPCompute(int model, double t, double intervalStart, double eAt, do } dP[model * 4 + 3] = -A / lambda * ((1 - B) * (eAt - 1) + G1) * dB[model * 4 + 3] / (G1 * G1); } - } private void dQCompute(int model, double t, double[] dQ) { @@ -698,7 +699,7 @@ private void dQCompute(int model, double t, double[] dQ) { dQCompute(model, t, dQ, eAt); } - private void dQCompute(int model, double 
t, double[] dQ, double eAt) { + void dQCompute(int model, double t, double[] dQ, double eAt) { double dwell = t - modelStartTimes[model]; double G1 = g1(eAt); @@ -848,10 +849,38 @@ public void processGradientInterval(double[] gradient, int currentModelSegment, // gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old - partialQ_all_young[k * 4 + p] / Q_young); // } // } + +// for (int p = 0; p < 4; ++p) { +// if (gradientFlags[p]) { +// for (int k = 0; k <= currentModelSegment; k++) { +// gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old - partialQ_all_young[k * 4 + p] / Q_young); +// } +// } +// } + + accumulateGradientForInterval(gradient, currentModelSegment, nLineages, + partialQ_all_old, Q_Old, partialQ_all_young, Q_young); + } + + void accumulateGradientForInterval(double[] gradient, int currentModelSegment, int nLineages, + double[] partialQ_all_old, double Q_Old, + double[] partialQ_all_young, double Q_young) { for (int p = 0; p < 4; ++p) { if (gradientFlags[p]) { for (int k = 0; k <= currentModelSegment; k++) { - gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old - partialQ_all_young[k * 4 + p] / Q_young); + gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old + - partialQ_all_young[k * 4 + p] / Q_young); + } + } + } + } + + void accumulateGradientForSampling(double[] gradient, int currentModelSegment, double term1, + double[] intermediate) { + for (int p = 0; p < 4; p++) { + if (gradientFlags[p]) { + for (int k = 0; k <= currentModelSegment; k++) { + gradient[k * 5 + p] += term1 * intermediate[k * 4 + p]; } } } @@ -887,7 +916,7 @@ public void processGradientSampling(double[] gradient, int currentModelSegment, // double eAt = Math.exp(A * (intervalEnd - modelStartTimes[currentModelSegment])); - dPCompute(currentModelSegment, intervalEnd, modelStartTimes[currentModelSegment], eAt_Old, this.dPIntervalEnd); + dPCompute(currentModelSegment, intervalEnd, modelStartTimes[currentModelSegment], eAt_Old, dPIntervalEnd, dG2); double term1 = (1 - r) / ((1 - r) * p_it + r); @@ -898,13 +927,15 @@ public void processGradientSampling(double[] gradient, int currentModelSegment, // gradient[fractionIndex(k, numIntervals)] += term1 * dPIntervalEnd[k * 4 + 3]; // } - for (int p = 0; p < 4; p++) { - if (gradientFlags[p]) { - for (int k = 0; k <= currentModelSegment; k++) { - gradient[genericIndex(k, p, numIntervals)] += term1 * dPIntervalEnd[k * 4 + p]; - } - } - } +// for (int p = 0; p < 4; p++) { +// if (gradientFlags[p]) { +// for (int k = 0; k <= currentModelSegment; k++) { +// gradient[genericIndex(k, p, numIntervals)] += term1 * dPIntervalEnd[k * 4 + p]; +// } +// } +// } + + accumulateGradientForSampling(gradient, currentModelSegment, term1, dPIntervalEnd); } if (sampleIsAtEventTime && currentModelSegment > 0) { @@ -919,13 +950,16 @@ public void processGradientSampling(double[] gradient, int currentModelSegment, // gradient[samplingIndex(k, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + 2]; // gradient[fractionIndex(k, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + 3]; // } - for (int p = 0; p < 4; p++) { - if (gradientFlags[p]) { - for (int k = 0; k < currentModelSegment; k++) { - gradient[genericIndex(k, p, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + p]; - } - } - } + +// for (int p = 0; p < 4; p++) { +// if (gradientFlags[p]) { +// for (int k = 0; k < currentModelSegment; k++) { +// gradient[genericIndex(k, p, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + p]; +// } +// } +// } + + 
accumulateGradientForSampling(gradient, currentModelSegment, term1, dPModelEnd_prev); } } diff --git a/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java b/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java index 9dde88edb8..e1231a5fb3 100644 --- a/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java +++ b/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java @@ -26,6 +26,7 @@ package dr.evomodelxml.speciation; import dr.evolution.util.Units; +import dr.evomodel.speciation.MasBirthDeathSerialSamplingModel; import dr.evomodel.speciation.NewBirthDeathSerialSamplingModel; import dr.evoxml.util.XMLUnits; import dr.inference.model.Parameter; @@ -98,13 +99,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Logger.getLogger("dr.evomodel").info(citeThisModel); - // return new NewBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units); - NewBirthDeathSerialSamplingModel model = new NewBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units); + NewBirthDeathSerialSamplingModel model = MAS_TEST ? + new MasBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units) : + new NewBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units); model.setupGradientFlags(gradientFlags); model.setupTimeline(grids != null ? grids.getParameterValues(): null); return model; } + private static final boolean MAS_TEST = false; + //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ diff --git a/src/dr/util/Timer.java b/src/dr/util/Timer.java index fea98c1863..5ba1f409b3 100644 --- a/src/dr/util/Timer.java +++ b/src/dr/util/Timer.java @@ -31,18 +31,18 @@ public class Timer { private long nanoStart = 0, nanoStop = 0; public void start() { + nanoStart = System.nanoTime(); // One wants the highest precision first. TODO Do we really need this? start = System.currentTimeMillis(); - nanoStart = System.nanoTime(); } public void stop() { - stop = System.currentTimeMillis(); nanoStop = System.nanoTime(); + stop = System.currentTimeMillis(); } public void update() { - stop = System.currentTimeMillis(); nanoStop = System.nanoTime(); + stop = System.currentTimeMillis(); } /**