Skip to content

Commit

Permalink
For linking stemWeight with clocks
Browse files Browse the repository at this point in the history
  • Loading branch information
afmagee committed Jul 25, 2023
1 parent 1b19880 commit 99d28e7
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/dr/app/beast/development_parsers.properties
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ dr.evomodelxml.speciation.NewBDSSHistorySimulatorParser
dr.inferencexml.distribution.BayesianBridgeMarkovRandomFieldLikelihoodParser
dr.inferencexml.operators.shrinkage.DimensionMismatchedBayesianBridgeShrinkageOperatorParser
dr.evomodelxml.branchmodel.EstimableStemWeightBranchSpecificBranchModelParser
dr.inferencexml.model.DimensionAlteredTransformedMultivariateParameterParser
dr.util.TimeToDistanceProportionTransform
dr.util.TimeProportionToFixedEffectTransform

# GLM covariate importance
dr.evomodelxml.substmodel.GlmCovariateImportanceParser
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dr.inference.model;

import dr.util.Transform;

public class DimensionAlteredTransformedMultivariateParameter extends TransformedMultivariateParameter {

public DimensionAlteredTransformedMultivariateParameter(Parameter parameter, Transform.MultivariableTransform transform) {
this(parameter, transform, false);
}

public DimensionAlteredTransformedMultivariateParameter(Parameter parameter, Transform.MultivariableTransform transform, boolean inverse) {
super(parameter, transform, inverse);
}

public int getDimension() {
return ((Transform.MultivariableTransform) transform).getDimension();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* DimensionAlteredTransformedMultivariateParameterParser.java
*
* Copyright (c) 2002-2018 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/

package dr.inferencexml.model;

import dr.inference.model.Bounds;
import dr.inference.model.Parameter;
import dr.inference.model.DimensionAlteredTransformedMultivariateParameter;
import dr.util.Transform;
import dr.xml.*;

import java.awt.*;

public class DimensionAlteredTransformedMultivariateParameterParser extends AbstractXMLObjectParser {

private static final String TRANSFORMED_MULTIVARIATE_PARAMETER = "dimensionAlteredTransformedMultivariateParameter";
public static final String INVERSE = "inverse";
private static final String BOUNDS = "bounds";

public Object parseXMLObject(XMLObject xo) throws XMLParseException {

final Parameter parameter = (Parameter) xo.getChild(Parameter.class);
Transform.MultivariableTransform transform = (Transform.MultivariableTransform)
xo.getChild(Transform.MultivariableTransform.class);
final boolean inverse = xo.getAttribute(INVERSE, false);

DimensionAlteredTransformedMultivariateParameter transformedParameter = new DimensionAlteredTransformedMultivariateParameter(parameter, transform, inverse);
if (xo.hasChildNamed(BOUNDS)) {
Bounds<Double> bounds = ((Parameter) xo.getElementFirstChild(BOUNDS)).getBounds();
transformedParameter.addBounds(bounds);
} else {
transformedParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY,
Double.NEGATIVE_INFINITY, parameter.getDimension()));
}
return transformedParameter;
}

public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}

private final XMLSyntaxRule[] rules = {
new ElementRule(Parameter.class),
new ElementRule(Transform.MultivariableTransform.class),
AttributeRule.newBooleanRule(INVERSE, true),

};

public String getParserDescription() {
return "A transformed multivariate parameter whose dimension is not the same as the original parameter's dimension.";
}

public Class getReturnType() {
return DimensionAlteredTransformedMultivariateParameter.class;
}

public String getParserName() {
return TRANSFORMED_MULTIVARIATE_PARAMETER;
}
}
107 changes: 107 additions & 0 deletions src/dr/util/TimeProportionToFixedEffectTransform.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package dr.util;

import dr.xml.AbstractXMLObjectParser;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

public class TimeProportionToFixedEffectTransform extends Transform.MultivariateTransform {
public static String NAME = "TimeProportionToFixedEffectTransform";

// The variables are assumed to be in the following order:
// 0: the proportion in time for forward transform, the fixed-effect for reverse transform
// 1: the log-scale rate in the ancestral-model portion of the branch
// 2: the log-scale rate in the descendant-model portion of the branch
// The input is these as a length-3 vector, the output is a length-1 vector of the transformed value
TimeProportionToFixedEffectTransform() {
super(3, 1);
}

@Override
public double[] inverse(double[] values, int from, int to, double sum) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double[] gradient(double[] values, int from, int to) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double[] gradientInverse(double[] values, int from, int to) {
throw new RuntimeException("Not yet implemented");
}

@Override
public String getTransformName() {
return NAME;
}

@Override
protected double[] transform(double[] values) {
double propTime = values[0];
double rateAncestral = Math.exp(values[1]);
double rateDescendant = Math.exp(values[2]);
double[] transformed = new double[1];
transformed[0] = Math.log((rateDescendant * propTime + rateAncestral * (1.0 - propTime)) / (rateDescendant * propTime));
return transformed;
}

@Override
protected double[] inverse(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
protected double getLogJacobian(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
protected double[] getGradientLogJacobianInverse(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double[][] computeJacobianMatrixInverse(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
protected boolean isInInteriorDomain(double[] values) {
// Only the proportion is bounded
return values[0] >= 0.0 && values[0] <= 1.0;
}

public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() {

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
final String name = xo.hasId() ? xo.getId() : null;

TimeProportionToFixedEffectTransform transform = new TimeProportionToFixedEffectTransform();

return transform;
}

@Override
public XMLSyntaxRule[] getSyntaxRules() {
return new XMLSyntaxRule[0];
}

@Override
public String getParserDescription() {
return null;
}

@Override
public Class getReturnType() {
return Transform.MultivariateTransform.class;
}

@Override
public String getParserName() {
return NAME;
}
};
}
107 changes: 107 additions & 0 deletions src/dr/util/TimeToDistanceProportionTransform.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package dr.util;

import dr.xml.AbstractXMLObjectParser;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

public class TimeToDistanceProportionTransform extends Transform.MultivariateTransform {
public static String NAME = "TimeToDistanceProportionTransform";

// The variables are assumed to be in the following order:
// 0: the proportion (in time for forward transform, in distance for reverse transform)
// 1: the log-scale rate in the ancestral-model portion of the branch
// 2: the log-scale rate in the descendant-model portion of the branch
// The input is these as a length-3 vector, the output is a length-1 vector of the transformed proportion
TimeToDistanceProportionTransform() {
super(3, 1);
}

@Override
public double[] inverse(double[] values, int from, int to, double sum) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double[] gradient(double[] values, int from, int to) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double[] gradientInverse(double[] values, int from, int to) {
throw new RuntimeException("Not yet implemented");
}

@Override
public String getTransformName() {
return NAME;
}

@Override
protected double[] transform(double[] values) {
double propTime = values[0];
double rateAncestral = Math.exp(values[1]);
double rateDescendant = Math.exp(values[2]);
double[] transformed = new double[1];
transformed[0] = (rateDescendant * propTime) / (rateDescendant * propTime + rateAncestral * (1.0 - propTime));
return transformed;
}

@Override
protected double[] inverse(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
protected double getLogJacobian(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
protected double[] getGradientLogJacobianInverse(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double[][] computeJacobianMatrixInverse(double[] values) {
throw new RuntimeException("Not yet implemented");
}

@Override
protected boolean isInInteriorDomain(double[] values) {
// Only the proportion is bounded
return values[0] >= 0.0 && values[0] <= 1.0;
}

public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() {

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
final String name = xo.hasId() ? xo.getId() : null;

TimeToDistanceProportionTransform transform = new TimeToDistanceProportionTransform();

return transform;
}

@Override
public XMLSyntaxRule[] getSyntaxRules() {
return new XMLSyntaxRule[0];
}

@Override
public String getParserDescription() {
return null;
}

@Override
public Class getReturnType() {
return Transform.MultivariateTransform.class;
}

@Override
public String getParserName() {
return NAME;
}
};
}

0 comments on commit 99d28e7

Please sign in to comment.