Skip to content

Commit

Permalink
Merge branch 'hmc-clock' of https://github.com/beast-dev/beast-mcmc i…
Browse files Browse the repository at this point in the history
…nto hmc-clock
  • Loading branch information
afmagee committed Jul 25, 2023
2 parents 99d28e7 + 4a5d462 commit b03acba
Show file tree
Hide file tree
Showing 7 changed files with 529 additions and 114 deletions.
344 changes: 268 additions & 76 deletions src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@ public void addTrait(TreeTrait trait) {

public String getReport() {
String message = super.getReport();
message += "\n";
// add likelihood calculation time
message += "Likelihood calculation time is " + likelihoodTime / likelihoodCounts + " nanoseconds.\n";
if (MEASURE_RUN_TIME) {
message += "\n";
// add likelihood calculation time
message += "Likelihood calculation time is " + likelihoodTime / likelihoodCounts + " nanoseconds.\n";
}
return message;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import dr.util.Timer;
import dr.xml.Reportable;

import static dr.evomodel.speciation.CachedGradientDelegate.MEASURE_RUN_TIME;

/**
* @author Andy Magee
* @author Yucai Shao
Expand Down Expand Up @@ -136,7 +138,7 @@ public LogColumn[] getColumns() {
@Override
public String getReport() {
String message = GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 1E-3);
if (gradientProvider instanceof CachedGradientDelegate) {
if (gradientProvider instanceof CachedGradientDelegate && MEASURE_RUN_TIME) {
message += "\n";
message += "Gradient calculation time is " + ((CachedGradientDelegate) gradientProvider).getGradientTime() + " nanoseconds.\n";
}
Expand Down
181 changes: 181 additions & 0 deletions src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* NewBirthDeathSerialSamplingModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/

package dr.evomodel.speciation;

import dr.inference.model.Parameter;

public class MasBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel {

public MasBirthDeathSerialSamplingModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, boolean condition, int numIntervals, double gridEnd, Type units) {
super(birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, condition, numIntervals, gridEnd, units);
}

@Override
public final double processModelSegmentBreakPoint(int model, double intervalStart, double intervalEnd, int nLineages) {
// double lnL = nLineages * (logQ(model, intervalEnd) - logQ(model, intervalStart));
double lnL = nLineages * Math.log(Q(model, intervalEnd) / Q(model, intervalStart));
if ( samplingProbability.getValue(model + 1) > 0.0 && samplingProbability.getValue(model + 1) < 1.0) {
// Add in probability of un-sampled lineages
// We don't need this at t=0 because all lineages in the tree are sampled
// TODO: check if we're right about how many lineages are actually alive at this time. Are we inadvertently over-counting or under-counting due to samples added at this _exact_ time?
lnL += nLineages * Math.log(1.0 - samplingProbability.getValue(model + 1));
}
this.savedLogQ = Double.NaN;
return lnL;
}

final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
final double[] partialQ_all_young, final double Q_young) {

for (int k = 0; k <= currentModelSegment; k++) {
for (int p = 0; p < 4; ++p) {
gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old
- partialQ_all_young[k * 4 + p] / Q_young);
}
}
}

final void accumulateGradientForSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k <= currentModelSegment; k++) {
for (int p = 0; p < 4; ++p) {
gradient[k * 5 + p] += term1 * intermediate[k * 4 + p];
}
}

}

final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) {

double G1 = g1(eAt);

double term1 = -A / lambda * ((1 - B) * (eAt - 1) + G1) / (G1 * G1);

for (int k = 0; k < model; k ++) {
for (int p = 0; p < 4; p++) {
dP[k * 4 + p] = term1 * dB[k * 4 + p];
}
}

for (int p = 0; p < 3; ++p) {
double term2 = eAt * (1 + B) * dA[p] * (t - intervalStart) + (eAt - 1) * dB[model * 4 + p];
dG2[p] = dA[p] - 2 * (G1 * (dA[p] * (1 - B) - dB[model * 4 + p] * A) - (1 - B) * term2 * A) / (G1 * G1);
}

double G2 = g2(G1);

dP[model * 4 + 0] = (-mu - psi - lambda * dG2[0] + G2) / (2 * lambda * lambda);
dP[model * 4 + 1] = (1 - dG2[1]) / (2 * lambda);
dP[model * 4 + 2] = (1 - dG2[2]) / (2 * lambda);
dP[model * 4 + 3] = -A / lambda * ((1 - B) * (eAt - 1) + G1) * dB[model * 4 + 3] / (G1 * G1);
}


final void dQCompute(int model, double t, double[] dQ, double eAt) {

double dwell = t - modelStartTimes[model];
double G1 = g1(eAt);

double term1 = 8 * eAt;
double term2 = G1 / 2 - eAt * (1 + B);
double term3 = eAt - 1;
double term4 = G1 * G1 * G1;
double term5 = -term1 * term3 / term4;

for (int k = 0; k < model; ++k) {
dQ[k * 4 + 0] = term5 * dB[k * 4 + 0];
dQ[k * 4 + 1] = term5 * dB[k * 4 + 1];
dQ[k * 4 + 2] = term5 * dB[k * 4 + 2];
dQ[k * 4 + 3] = term5 * dB[k * 4 + 3];
}

double term6 = term1 / term4;
double term7 = dwell * term2;

dQ[model * 4 + 0] = term6 * (dA[0] * term7 - dB[model * 4 + 0] * term3);
dQ[model * 4 + 1] = term6 * (dA[1] * term7 - dB[model * 4 + 1] * term3);
dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3);
dQ[model * 4 + 3] = term5 * dB[model * 4 + 3];
}


final double Q(int model, double time) {
double At = A * (time - modelStartTimes[model]);
double eAt = Math.exp(At);
double sqrtDenominator = g1(eAt);
return eAt / (sqrtDenominator * sqrtDenominator);
}

final double logQ(int model, double time) {
double At = A * (time - modelStartTimes[model]);
double eAt = Math.exp(At);
double sqrtDenominator = g1(eAt);
return At - 2 * Math.log(sqrtDenominator); // TODO log4 (additive constant) is not needed since we always see logQ(a) - logQ(b)
}

@Override
public double processInterval(int model, double tYoung, double tOld, int nLineages) {
double logQ_young;
double logQ_old = Q(model, tOld);
if (!Double.isNaN(this.savedLogQ)) {
logQ_young = this.savedLogQ;
} else {
logQ_young = Q(model, tYoung);
}
this.savedLogQ = logQ_old;
return nLineages * Math.log(logQ_old / logQ_young);
}

@Override
public double processSampling(int model, double tOld) {

double logSampProb;

boolean sampleIsAtEventTime = tOld == modelStartTimes[model];
boolean samplesTakenAtEventTime = rho > 0;

if (sampleIsAtEventTime && samplesTakenAtEventTime) {
logSampProb = Math.log(rho);
if (model > 0) {
logSampProb += Math.log(r + ((1.0 - r) * previousP));
}
} else {
double logPsi = Math.log(psi);
logSampProb = logPsi + Math.log(r + (1.0 - r) * p(model,tOld));
}

return logSampProb;
}
}

/*
* Notes on inlining:
* https://www.baeldung.com/jvm-method-inlining#:~:text=Essentially%2C%20the%20JIT%20compiler%20tries,times%20we%20invoke%20the%20method.
* https://miuv.blog/2018/02/25/jit-optimizations-method-inlining/
* static, private, final
*/
Loading

0 comments on commit b03acba

Please sign in to comment.