Skip to content

Commit

Permalink
gradients for conditioning on survival
Browse files Browse the repository at this point in the history
  • Loading branch information
yucais committed Oct 3, 2023
1 parent 359de1b commit 1593fd8
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 30 deletions.
16 changes: 8 additions & 8 deletions ci/TestXML/testCrssbdpGradient.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
<!-- ########## -->

<!--full model-->
<newBirthDeathSerialSampling id="bdss1" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss1" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss1.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -130,7 +130,7 @@
<!-- ########## -->

<!-- rho = 1 case -->
<newBirthDeathSerialSampling id="bdss2" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss2" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss2.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -211,7 +211,7 @@
<!-- ########## -->

<!-- rho = 0 case -->
<newBirthDeathSerialSampling id="bdss3" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss3" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss3.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -283,7 +283,7 @@
<!-- ########## -->

<!-- r = 1 case -->
<newBirthDeathSerialSampling id="bdss4" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss4" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss4.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -364,7 +364,7 @@
<!-- ########## -->

<!-- r = 0 case -->
<newBirthDeathSerialSampling id="bdss5" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss5" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss5.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -435,7 +435,7 @@
<!-- ########## -->

<!-- mu = 0, r > 0 case -->
<newBirthDeathSerialSampling id="bdss6" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss6" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss6.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -525,7 +525,7 @@
</treeModel>

<!-- psi = 0, 0 < rho < 1 case -->
<newBirthDeathSerialSampling id="bdss7" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss7" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss7.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -590,7 +590,7 @@
<!-- ########## -->

<!-- psi = 0, rho = 1 case -->
<newBirthDeathSerialSampling id="bdss8" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss8" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss8.birthRate" value="8.472" lower="0.0"/>
</birthRate>
Expand Down
8 changes: 4 additions & 4 deletions ci/TestXML/testEssbdpGradient.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
<!-- ########## -->

<!-- multiple rates model with three intervals -->
<newBirthDeathSerialSampling id="new.bdss1" units="years" hasFinalSample="false" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="new.bdss1" units="years" hasFinalSample="false" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss1.birthRate" value ="8.47 2.88 1.82" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -127,7 +127,7 @@
<!-- ########## -->

<!-- multiple rates model with three intervals and one intensive sampling event at t1 = 2 -->
<newBirthDeathSerialSampling id="new.bdss2" units="years" hasFinalSample="false" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="new.bdss2" units="years" hasFinalSample="false" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss2.birthRate" value ="8.47 2.88 1.82" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -234,7 +234,7 @@
</treeModel>


<newBirthDeathSerialSampling id="bdss3" units="years" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="bdss3" units="years" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss3.birthRate" value ="8.47 4.66 1.82" lower="0.0"/>
</birthRate>
Expand Down Expand Up @@ -307,7 +307,7 @@
<!-- ########## -->

<!-- Origin is in different interval from rest of tree -->
<newBirthDeathSerialSampling id="new.bdss4" units="years" hasFinalSample="false" conditionOnSurvival="false">
<newBirthDeathSerialSampling id="new.bdss4" units="years" hasFinalSample="false" conditionOnSurvival="true">
<birthRate>
<parameter id="bdss4.birthRate" value ="8.47 2.88 1.82" lower="0.0"/>
</birthRate>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,7 @@ private double ti(int i) {
}

@Override
public double logConditioningProbability() {
return 0;
}
public double logConditioningProbability(int model) {return 0;}

private double partialApartialLambda(int i) {
return (lambda(i) - mu(i) + psi(i)) / Ai[i];
Expand Down Expand Up @@ -733,7 +731,7 @@ public void processGradientSampling(double[] gradient, int currentModelSegment,
}

@Override
public void logConditioningProbability(double[] gradient) {
public void logConditioningProbability(int currentModelSegment, double[] gradient) {
return;
}
@Override
Expand Down
2 changes: 1 addition & 1 deletion src/dr/evomodel/speciation/CachedGradientDelegate.java
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ private double[] getGradientLogDensityImpl() {
// origin branch is a fake branch that doesn't exist in the tree, now compute its contribution
provider.processGradientOrigin(gradient, currentModelSegment, treeIntervals.getTotalDuration());

provider.logConditioningProbability(gradient);
provider.logConditioningProbability(currentModelSegment,gradient);

if (MEASURE_RUN_TIME) {
timer.stop();
Expand Down
2 changes: 1 addition & 1 deletion src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ final BigFastTreeIntervals getTreeIntervals() {
// origin branch is a fake branch that doesn't exist in the tree, now compute its contribution
logL += speciationModel.processOrigin(currentModelSegment, treeIntervals.getTotalDuration());

logL += speciationModel.logConditioningProbability();
logL += speciationModel.logConditioningProbability(currentModelSegment);

if (MEASURE_RUN_TIME) {
timer.stop();
Expand Down
42 changes: 32 additions & 10 deletions src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,21 @@ private static double computeB(double lambda, double mu, double psi, double rho,
}

@Override
public double logConditioningProbability() {
public double logConditioningProbability(int model) {
double logP = 0.0;
if ( conditionOnSurvival ) {
double origin = originTime.getParameterValue(0);
double[] modelBreakPoints = getBreakPoints();
int idx = modelBreakPoints.length - 2;
double intervalStart = modelBreakPoints[idx];
// Origin is probably near the last index, so we do a linear search forwards in time from there
while ( origin < intervalStart) {
--idx;
intervalStart = modelBreakPoints[idx];

double segmentIntervalEnd = modelBreakPoints[model];

while (origin >= segmentIntervalEnd) { // TODO Maybe it's >= ?
++model;
updateLikelihoodModelValues(model);
segmentIntervalEnd = modelBreakPoints[model];
}
logP -= Math.log(1.0 - p(idx, origin));

logP -= Math.log(1.0 - p(model, origin));
}
return logP;
}
Expand Down Expand Up @@ -1021,9 +1023,29 @@ public void processGradientOrigin(double[] gradient, int currentModelSegment, do
}

@Override
public void logConditioningProbability(double[] gradient) {
public void logConditioningProbability(int model, double[] gradient) {
double grad = 0.0;
double[] dPOrigin = new double[numIntervals * 4];
if ( conditionOnSurvival ) {
throw new RuntimeException("Cannot yet condition ESSBDP for gradient.");
double origin = originTime.getParameterValue(0);
double[] modelBreakPoints = getBreakPoints();
double intervalStart = model > 0? modelBreakPoints[model-1]:0;
double segmentIntervalEnd = modelBreakPoints[model];

while (origin >= segmentIntervalEnd) { // TODO Maybe it's >= ?
intervalStart = segmentIntervalEnd;
++model;
updateGradientModelValues(model);
segmentIntervalEnd = modelBreakPoints[model];
}
double eAt_Origin = Math.exp(A * (origin - intervalStart));
grad += 1 /(1.0 - p(eAt_Origin));
dPCompute(model, origin, intervalStart, eAt_Origin, dPOrigin, dG2);
for (int p = 0; p < 4; p++) {
for (int k = 0; k <= model; k++) {
gradient[genericIndex(k, p, numIntervals)] += grad * dPOrigin[k * 4 + p];
}
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/dr/evomodel/speciation/SpeciationModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public void updateModelValues() {
throw new RuntimeException("Not implemented");
}

public double logConditioningProbability() {
public double logConditioningProbability(int model) {
throw new RuntimeException("Not implemented");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ default void processGradientOrigin(double[] gradient,
int currentModelSegment,
double totalDuration) { throw new RuntimeException("Not yet implemented"); }

default void logConditioningProbability(double[] gradient) { throw new RuntimeException("Not yet implemented"); }
default void logConditioningProbability(int currentModelSegment, double[] gradient) { throw new RuntimeException("Not yet implemented"); }

default void updateGradientModelValues(int currentModelSegment) { throw new RuntimeException("Not yet implemented"); }

Expand Down

0 comments on commit 1593fd8

Please sign in to comment.