Skip to content

Commit

Permalink
subclasses for gradient calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
yucais committed Jul 31, 2023
1 parent c7f3fcd commit c21fd5b
Show file tree
Hide file tree
Showing 5 changed files with 404 additions and 344 deletions.
200 changes: 4 additions & 196 deletions ci/TestXML/testEssbdpGradient.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@
<samplingRate>
<parameter id="bdss1.samplingRate" value ="3.42 2.95 1.21" lower="0.0"/>
</samplingRate>
<samplingProbability>
<samplingProbability gradientFlag = "false">
<parameter id="bdss1.samplingProbability" value ="0.2 0.8 0.6" lower="0.0" upper="1.0"/>
</samplingProbability>
<treatmentProbability>
<treatmentProbability gradientFlag = "false">
<parameter id="bdss1.treatmentProbability" value ="0.4 0.3 0.1" lower="0.0" upper="1.0"/>
</treatmentProbability>
<origin>
Expand Down Expand Up @@ -113,9 +113,6 @@
<speciationLikelihoodGradient idref="grad.birthRate1"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.deathRate1"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.samplingRate1"/>
Expand All @@ -137,10 +134,10 @@
<samplingRate>
<parameter id="bdss2.samplingRate" value ="3.42 2.95 1.21" lower="0.0"/>
</samplingRate>
<samplingProbability>
<samplingProbability gradientFlag = "false">
<parameter id="bdss2.samplingProbability" value ="0.07 0.22 0.18" lower="0.0" upper="1.0"/>
</samplingProbability>
<treatmentProbability>
<treatmentProbability gradientFlag = "false">
<parameter id="bdss2.treatmentProbability" value ="0.89 0.93 0.79" lower="0.0" upper="1.0"/>
</treatmentProbability>
<origin>
Expand Down Expand Up @@ -193,200 +190,11 @@
<speciationLikelihoodGradient idref="grad.birthRate2"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.deathRate2"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.samplingRate2"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.samplingProbability2"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.treatmentProbability2"/>
</report>

<!-- ########## -->
<!-- # CASE 3 # -->
<!-- ########## -->

<!-- macroevolution case -->

<newick id="startingTree2" usingHeights="true" usingDates="false">
((D:0.353,C:0.353):0.303,(B:0.471,A:0.471):0.185)
</newick>

<!-- Generate a tree model -->
<treeModel id="treeModel2">
<newick idref="startingTree2"/>
<rootHeight>
<parameter id="treeModel2.rootHeight"/>
</rootHeight>
<nodeHeights internalNodes="true">
<parameter id="treeModel2.internalNodeHeights"/>
</nodeHeights>
<nodeHeights internalNodes="true" rootNode="true">
<parameter id="treeModel2.allInternalNodeHeights"/>
</nodeHeights>
</treeModel>


<newBirthDeathSerialSampling id="bdss3" units="years" conditionOnSurvival="false">
<birthRate>
<parameter id="bdss3.birthRate" value ="8.47 4.66 1.82" lower="0.0"/>
</birthRate>
<deathRate>
<parameter id="bdss3.deathRate" value ="4.89 3.67 3.23" lower="0.0"/>
</deathRate>
<samplingRate>
<parameter id="bdss3.samplingRate" value ="3.42 2.95 1.21" lower="0.0"/>
</samplingRate>
<samplingProbability>
<parameter id="bdss3.samplingProbability" value ="1 0 0" lower="0.0" upper="1.0"/>
</samplingProbability>
<treatmentProbability>
<parameter id="bdss3.treatmentProbability" value ="0 0 0" lower="0.0" upper="1.0"/>
</treatmentProbability>
<origin>
<parameter id="bdss3.origin" value="0.656" lower="0.0"/>
</origin>
<cutOff>
<parameter value="0.656"/>
</cutOff>
<numGridPoints>
<parameter value="3"/>
</numGridPoints>
</newBirthDeathSerialSampling>

<!-- Generate a speciation likelihood for Yule or Birth Death -->
<speciationLikelihood id="speciation3" useNewLoop="true">
<model>
<newBirthDeathSerialSampling idref="bdss3"/>
</model>
<speciesTree>
<treeModel idref="treeModel2"/>
</speciesTree>
</speciationLikelihood>

<speciationLikelihoodGradient id="grad.birthRate3" wrtParameter="birthRate" useNewLoop="true">
<speciationLikelihood idref="speciation3"/>
<treeModel idref="treeModel2"/>
</speciationLikelihoodGradient>


<speciationLikelihoodGradient id="grad.deathRate3" wrtParameter="deathRate" useNewLoop="true">
<speciationLikelihood idref="speciation3"/>
<treeModel idref="treeModel2"/>
</speciationLikelihoodGradient>

<speciationLikelihoodGradient id="grad.samplingRate3" wrtParameter="samplingRate" useNewLoop="true">
<speciationLikelihood idref="speciation3"/>
<treeModel idref="treeModel2"/>
</speciationLikelihoodGradient>



<report>
<speciationLikelihoodGradient idref="grad.birthRate3"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.deathRate3"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.samplingRate3"/>
</report>


<!-- ########## -->
<!-- # CASE 4 # -->
<!-- ########## -->

<!-- Origin is in different interval from rest of tree -->
<newBirthDeathSerialSampling id="new.bdss4" units="years" hasFinalSample="false" conditionOnSurvival="false">
<birthRate>
<parameter id="bdss4.birthRate" value ="8.47 2.88 1.82" lower="0.0"/>
</birthRate>
<deathRate>
<parameter id="bdss4.deathRate" value ="4.89 2.67 3.23" lower="0.0"/>
</deathRate>
<samplingRate>
<parameter id="bdss4.samplingRate" value ="3.42 2.95 1.21" lower="0.0"/>
</samplingRate>
<samplingProbability>
<parameter id="bdss4.samplingProbability" value ="0.07 0.22 0.18" lower="0.0" upper="1.0"/>
</samplingProbability>
<treatmentProbability>
<parameter id="bdss4.treatmentProbability" value ="0.89 0.93 0.79" lower="0.0" upper="1.0"/>
</treatmentProbability>
<origin>
<parameter id="bdss4.origin" value="20.0" lower="0.0"/>
</origin>
<cutOff>
<parameter value="18.0"/>
</cutOff>
<numGridPoints>
<parameter value="3"/>
</numGridPoints>
</newBirthDeathSerialSampling>

<speciationLikelihood id="speciation4" useNewLoop="true">
<model>
<newBirthDeathSerialSampling idref="new.bdss4"/>
</model>
<speciesTree>
<treeModel idref="treeModel"/>
</speciesTree>
</speciationLikelihood>

<speciationLikelihoodGradient id="grad.birthRate4" wrtParameter="birthRate" useNewLoop="true">
<speciationLikelihood idref="speciation4"/>
<treeModel idref="treeModel"/>
</speciationLikelihoodGradient>

<speciationLikelihoodGradient id="grad.deathRate4" wrtParameter="deathRate" useNewLoop="true">
<speciationLikelihood idref="speciation4"/>
<treeModel idref="treeModel"/>
</speciationLikelihoodGradient>

<speciationLikelihoodGradient id="grad.samplingRate4" wrtParameter="samplingRate" useNewLoop="true">
<speciationLikelihood idref="speciation4"/>
<treeModel idref="treeModel"/>
</speciationLikelihoodGradient>

<speciationLikelihoodGradient id="grad.samplingProbability4" wrtParameter="samplingProbability" useNewLoop="true">
<speciationLikelihood idref="speciation4"/>
<treeModel idref="treeModel"/>
</speciationLikelihoodGradient>

<speciationLikelihoodGradient id="grad.treatmentProbability4" wrtParameter="treatmentProbability" useNewLoop="true">
<speciationLikelihood idref="speciation4"/>
<treeModel idref="treeModel"/>
</speciationLikelihoodGradient>


<report>
<speciationLikelihoodGradient idref="grad.birthRate4"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.deathRate4"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.samplingRate4"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.samplingProbability4"/>
</report>

<report>
<speciationLikelihoodGradient idref="grad.treatmentProbability4"/>
</report>

</beast>
38 changes: 25 additions & 13 deletions src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,48 @@ final void accumulateGradientForInterval(final double[] gradient, final int curr
final double[] partialQ_all_young, final double Q_young) {

for (int k = 0; k <= currentModelSegment; k++) {
for (int p = 0; p < 4; ++p) {
gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old
- partialQ_all_young[k * 4 + p] / Q_young);
}
gradient[k * 5 + 0] += nLineages * (partialQ_all_old[k * 4 + 0] / Q_Old
- partialQ_all_young[k * 4 + 0] / Q_young);
gradient[k * 5 + 1] += nLineages * (partialQ_all_old[k * 4 + 1] / Q_Old
- partialQ_all_young[k * 4 + 1] / Q_young);
gradient[k * 5 + 2] += nLineages * (partialQ_all_old[k * 4 + 2] / Q_Old
- partialQ_all_young[k * 4 + 2] / Q_young);
}
}

final void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k <= currentModelSegment; k++) {
for (int p = 0; p < 4; ++p) {
gradient[k * 5 + p] += term1 * intermediate[k * 4 + p];
}
gradient[k * 5 + 0] += term1 * intermediate[k * 4 + 0];
gradient[k * 5 + 1] += term1 * intermediate[k * 4 + 1];
gradient[k * 5 + 2] += term1 * intermediate[k * 4 + 2];
}

}

final void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k < currentModelSegment; k++) {
for (int p = 0; p < 4; ++p) {
gradient[k * 5 + p] += term1 * intermediate[k * 4 + p];
gradient[k * 5 + 0] += term1 * intermediate[k * 4 + 0];
gradient[k * 5 + 1] += term1 * intermediate[k * 4 + 1];
gradient[k * 5 + 2] += term1 * intermediate[k * 4 + 2];
}
}

final void dBCompute(int model, double[] dB) {

for (int k = 0; k < model; ++k) {
for (int p = 0; p < 4; p++) {
dB[k * 4 + p] = -2 * (1 - rho) * lambda / A * dPModelEnd[k * 4 + p];
}
}

double term1 = 1 - 2 * (1 - rho) * previousP;

dB[model * 4 + 0] = (A * term1 - dA[0] * (term1 * lambda + mu + psi)) / (A * A);
dB[model * 4 + 1] = (A - dA[1] * (term1 * lambda + mu + psi)) / (A * A);
dB[model * 4 + 2] = (A - dA[2] * (term1 * lambda + mu + psi)) / (A * A);
}

final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) {
Expand All @@ -103,7 +118,6 @@ final void dPCompute(int model, double t, double intervalStart, double eAt, doub
dP[model * 4 + 0] = (-mu - psi - lambda * dG2[0] + G2) / (2 * lambda * lambda);
dP[model * 4 + 1] = (1 - dG2[1]) / (2 * lambda);
dP[model * 4 + 2] = (1 - dG2[2]) / (2 * lambda);
dP[model * 4 + 3] = -A / lambda * ((1 - B) * (eAt - 1) + G1) * dB[model * 4 + 3] / (G1 * G1);
}


Expand All @@ -122,7 +136,6 @@ final void dQCompute(int model, double t, double[] dQ, double eAt) {
dQ[k * 4 + 0] = term5 * dB[k * 4 + 0];
dQ[k * 4 + 1] = term5 * dB[k * 4 + 1];
dQ[k * 4 + 2] = term5 * dB[k * 4 + 2];
dQ[k * 4 + 3] = term5 * dB[k * 4 + 3];
}

double term6 = term1 / term4;
Expand All @@ -131,7 +144,6 @@ final void dQCompute(int model, double t, double[] dQ, double eAt) {
dQ[model * 4 + 0] = term6 * (dA[0] * term7 - dB[model * 4 + 0] * term3);
dQ[model * 4 + 1] = term6 * (dA[1] * term7 - dB[model * 4 + 1] * term3);
dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3);
dQ[model * 4 + 3] = term5 * dB[model * 4 + 3];
}


Expand Down
Loading

0 comments on commit c21fd5b

Please sign in to comment.