Skip to content

Commit

Permalink
Merge branch 'hmc-clock' of https://github.com/beast-dev/beast-mcmc i…
Browse files Browse the repository at this point in the history
…nto hmc-clock
  • Loading branch information
afmagee committed Aug 1, 2023
2 parents 30a6644 + c2ecacb commit 8c88b8b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 70 deletions.
16 changes: 8 additions & 8 deletions src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ public final double processModelSegmentBreakPoint(int model, double intervalStar
return lnL;
}

final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
final double[] partialQ_all_young, final double Q_young) {
void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
final double[] partialQ_all_young, final double Q_young) {

for (int k = 0; k <= currentModelSegment; k++) {
gradient[k * 5 + 0] += nLineages * (partialQ_all_old[k * 4 + 0] / Q_Old
Expand All @@ -61,7 +61,7 @@ final void accumulateGradientForInterval(final double[] gradient, final int curr
}
}

final void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1,
void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k <= currentModelSegment; k++) {
Expand All @@ -71,7 +71,7 @@ final void accumulateGradientForSerialSampling(double[] gradient, int currentMod
}
}

final void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1,
void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k < currentModelSegment; k++) {
Expand All @@ -81,7 +81,7 @@ final void accumulateGradientForIntensiveSampling(double[] gradient, int current
}
}

final void dBCompute(int model, double[] dB) {
void dBCompute(int model, double[] dB) {

for (int k = 0; k < model; ++k) {
for (int p = 0; p < 4; p++) {
Expand All @@ -96,7 +96,7 @@ final void dBCompute(int model, double[] dB) {
dB[model * 4 + 2] = (A - dA[2] * (term1 * lambda + mu + psi)) / (A * A);
}

final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) {
void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) {

double G1 = g1(eAt);

Expand All @@ -121,7 +121,7 @@ final void dPCompute(int model, double t, double intervalStart, double eAt, doub
}


final void dQCompute(int model, double t, double[] dQ, double eAt) {
void dQCompute(int model, double t, double[] dQ, double eAt) {

double dwell = t - modelStartTimes[model];
double G1 = g1(eAt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,12 @@

import dr.inference.model.Parameter;

public class TwoParamBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel {
public class TwoParamBirthDeathSerialSamplingModel extends MasBirthDeathSerialSamplingModel {

public TwoParamBirthDeathSerialSamplingModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, boolean condition, int numIntervals, double gridEnd, Type units) {
super(birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, condition, numIntervals, gridEnd, units);
}

@Override
public final double processModelSegmentBreakPoint(int model, double intervalStart, double intervalEnd, int nLineages) {
// double lnL = nLineages * (logQ(model, intervalEnd) - logQ(model, intervalStart));
double lnL = nLineages * Math.log(Q(model, intervalEnd) / Q(model, intervalStart));
if ( samplingProbability.getValue(model + 1) > 0.0 && samplingProbability.getValue(model + 1) < 1.0) {
// Add in probability of un-sampled lineages
// We don't need this at t=0 because all lineages in the tree are sampled
// TODO: check if we're right about how many lineages are actually alive at this time. Are we inadvertently over-counting or under-counting due to samples added at this _exact_ time?
lnL += nLineages * Math.log(1.0 - samplingProbability.getValue(model + 1));
}
this.savedLogQ = Double.NaN;
return lnL;
}

final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
Expand Down Expand Up @@ -140,54 +127,6 @@ final void dQCompute(int model, double t, double[] dQ, double eAt) {
dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3);
}


final double Q(int model, double time) {
double At = A * (time - modelStartTimes[model]);
double eAt = Math.exp(At);
double sqrtDenominator = g1(eAt);
return eAt / (sqrtDenominator * sqrtDenominator);
}

final double logQ(int model, double time) {
double At = A * (time - modelStartTimes[model]);
double eAt = Math.exp(At);
double sqrtDenominator = g1(eAt);
return At - 2 * Math.log(sqrtDenominator); // TODO log4 (additive constant) is not needed since we always see logQ(a) - logQ(b)
}

@Override
public double processInterval(int model, double tYoung, double tOld, int nLineages) {
double logQ_young;
double logQ_old = Q(model, tOld);
if (!Double.isNaN(this.savedLogQ)) {
logQ_young = this.savedLogQ;
} else {
logQ_young = Q(model, tYoung);
}
this.savedLogQ = logQ_old;
return nLineages * Math.log(logQ_old / logQ_young);
}

@Override
public double processSampling(int model, double tOld) {

double logSampProb;

boolean sampleIsAtEventTime = tOld == modelStartTimes[model];
boolean samplesTakenAtEventTime = rho > 0;

if (sampleIsAtEventTime && samplesTakenAtEventTime) {
logSampProb = Math.log(rho);
if (model > 0) {
logSampProb += Math.log(r + ((1.0 - r) * previousP));
}
} else {
double logPsi = Math.log(psi);
logSampProb = logPsi + Math.log(r + (1.0 - r) * p(model,tOld));
}

return logSampProb;
}
}

/*
Expand Down

0 comments on commit 8c88b8b

Please sign in to comment.