From 99d28e77a4a63fa55978511da42e6c446d3807f8 Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Tue, 25 Jul 2023 11:01:59 -0700 Subject: [PATCH] For linking stemWeight with clocks --- .../app/beast/development_parsers.properties | 3 + ...teredTransformedMultivariateParameter.java | 19 ++++ ...ransformedMultivariateParameterParser.java | 82 ++++++++++++++ .../TimeProportionToFixedEffectTransform.java | 107 ++++++++++++++++++ .../TimeToDistanceProportionTransform.java | 107 ++++++++++++++++++ 5 files changed, 318 insertions(+) create mode 100644 src/dr/inference/model/DimensionAlteredTransformedMultivariateParameter.java create mode 100644 src/dr/inferencexml/model/DimensionAlteredTransformedMultivariateParameterParser.java create mode 100644 src/dr/util/TimeProportionToFixedEffectTransform.java create mode 100644 src/dr/util/TimeToDistanceProportionTransform.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index 52054b39df..4ff2a620f9 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -361,6 +361,9 @@ dr.evomodelxml.speciation.NewBDSSHistorySimulatorParser dr.inferencexml.distribution.BayesianBridgeMarkovRandomFieldLikelihoodParser dr.inferencexml.operators.shrinkage.DimensionMismatchedBayesianBridgeShrinkageOperatorParser dr.evomodelxml.branchmodel.EstimableStemWeightBranchSpecificBranchModelParser +dr.inferencexml.model.DimensionAlteredTransformedMultivariateParameterParser +dr.util.TimeToDistanceProportionTransform +dr.util.TimeProportionToFixedEffectTransform # GLM covariate importance dr.evomodelxml.substmodel.GlmCovariateImportanceParser diff --git a/src/dr/inference/model/DimensionAlteredTransformedMultivariateParameter.java b/src/dr/inference/model/DimensionAlteredTransformedMultivariateParameter.java new file mode 100644 index 0000000000..4352713dbc --- /dev/null +++ b/src/dr/inference/model/DimensionAlteredTransformedMultivariateParameter.java @@ -0,0 +1,19 @@ +package dr.inference.model; + +import dr.util.Transform; + +public class DimensionAlteredTransformedMultivariateParameter extends TransformedMultivariateParameter { + + public DimensionAlteredTransformedMultivariateParameter(Parameter parameter, Transform.MultivariableTransform transform) { + this(parameter, transform, false); + } + + public DimensionAlteredTransformedMultivariateParameter(Parameter parameter, Transform.MultivariableTransform transform, boolean inverse) { + super(parameter, transform, inverse); + } + + public int getDimension() { + return ((Transform.MultivariableTransform) transform).getDimension(); + } + +} diff --git a/src/dr/inferencexml/model/DimensionAlteredTransformedMultivariateParameterParser.java b/src/dr/inferencexml/model/DimensionAlteredTransformedMultivariateParameterParser.java new file mode 100644 index 0000000000..8a3a21fd59 --- /dev/null +++ b/src/dr/inferencexml/model/DimensionAlteredTransformedMultivariateParameterParser.java @@ -0,0 +1,82 @@ +/* + * DimensionAlteredTransformedMultivariateParameterParser.java + * + * Copyright (c) 2002-2018 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.inferencexml.model; + + import dr.inference.model.Bounds; + import dr.inference.model.Parameter; + import dr.inference.model.DimensionAlteredTransformedMultivariateParameter; + import dr.util.Transform; + import dr.xml.*; + + import java.awt.*; + +public class DimensionAlteredTransformedMultivariateParameterParser extends AbstractXMLObjectParser { + + private static final String TRANSFORMED_MULTIVARIATE_PARAMETER = "dimensionAlteredTransformedMultivariateParameter"; + public static final String INVERSE = "inverse"; + private static final String BOUNDS = "bounds"; + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + final Parameter parameter = (Parameter) xo.getChild(Parameter.class); + Transform.MultivariableTransform transform = (Transform.MultivariableTransform) + xo.getChild(Transform.MultivariableTransform.class); + final boolean inverse = xo.getAttribute(INVERSE, false); + + DimensionAlteredTransformedMultivariateParameter transformedParameter = new DimensionAlteredTransformedMultivariateParameter(parameter, transform, inverse); + if (xo.hasChildNamed(BOUNDS)) { + Bounds bounds = ((Parameter) xo.getElementFirstChild(BOUNDS)).getBounds(); + transformedParameter.addBounds(bounds); + } else { + transformedParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, + Double.NEGATIVE_INFINITY, parameter.getDimension())); + } + return transformedParameter; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + new ElementRule(Parameter.class), + new ElementRule(Transform.MultivariableTransform.class), + AttributeRule.newBooleanRule(INVERSE, true), + + }; + + public String getParserDescription() { + return "A transformed multivariate parameter whose dimension is not the same as the original parameter's dimension."; + } + + public Class getReturnType() { + return DimensionAlteredTransformedMultivariateParameter.class; + } + + public String getParserName() { + return TRANSFORMED_MULTIVARIATE_PARAMETER; + } +} diff --git a/src/dr/util/TimeProportionToFixedEffectTransform.java b/src/dr/util/TimeProportionToFixedEffectTransform.java new file mode 100644 index 0000000000..776d9e0cea --- /dev/null +++ b/src/dr/util/TimeProportionToFixedEffectTransform.java @@ -0,0 +1,107 @@ +package dr.util; + +import dr.xml.AbstractXMLObjectParser; +import dr.xml.XMLObject; +import dr.xml.XMLParseException; +import dr.xml.XMLSyntaxRule; + +public class TimeProportionToFixedEffectTransform extends Transform.MultivariateTransform { + public static String NAME = "TimeProportionToFixedEffectTransform"; + + // The variables are assumed to be in the following order: + // 0: the proportion in time for forward transform, the fixed-effect for reverse transform + // 1: the log-scale rate in the ancestral-model portion of the branch + // 2: the log-scale rate in the descendant-model portion of the branch + // The input is these as a length-3 vector, the output is a length-1 vector of the transformed value + TimeProportionToFixedEffectTransform() { + super(3, 1); + } + + @Override + public double[] inverse(double[] values, int from, int to, double sum) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double[] gradient(double[] values, int from, int to) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double[] gradientInverse(double[] values, int from, int to) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public String getTransformName() { + return NAME; + } + + @Override + protected double[] transform(double[] values) { + double propTime = values[0]; + double rateAncestral = Math.exp(values[1]); + double rateDescendant = Math.exp(values[2]); + double[] transformed = new double[1]; + transformed[0] = Math.log((rateDescendant * propTime + rateAncestral * (1.0 - propTime)) / (rateDescendant * propTime)); + return transformed; + } + + @Override + protected double[] inverse(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + protected double getLogJacobian(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + protected double[] getGradientLogJacobianInverse(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double[][] computeJacobianMatrixInverse(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + protected boolean isInInteriorDomain(double[] values) { + // Only the proportion is bounded + return values[0] >= 0.0 && values[0] <= 1.0; + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + final String name = xo.hasId() ? xo.getId() : null; + + TimeProportionToFixedEffectTransform transform = new TimeProportionToFixedEffectTransform(); + + return transform; + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[0]; + } + + @Override + public String getParserDescription() { + return null; + } + + @Override + public Class getReturnType() { + return Transform.MultivariateTransform.class; + } + + @Override + public String getParserName() { + return NAME; + } + }; +} diff --git a/src/dr/util/TimeToDistanceProportionTransform.java b/src/dr/util/TimeToDistanceProportionTransform.java new file mode 100644 index 0000000000..86d0c0dd6f --- /dev/null +++ b/src/dr/util/TimeToDistanceProportionTransform.java @@ -0,0 +1,107 @@ +package dr.util; + +import dr.xml.AbstractXMLObjectParser; +import dr.xml.XMLObject; +import dr.xml.XMLParseException; +import dr.xml.XMLSyntaxRule; + +public class TimeToDistanceProportionTransform extends Transform.MultivariateTransform { + public static String NAME = "TimeToDistanceProportionTransform"; + + // The variables are assumed to be in the following order: + // 0: the proportion (in time for forward transform, in distance for reverse transform) + // 1: the log-scale rate in the ancestral-model portion of the branch + // 2: the log-scale rate in the descendant-model portion of the branch + // The input is these as a length-3 vector, the output is a length-1 vector of the transformed proportion + TimeToDistanceProportionTransform() { + super(3, 1); + } + + @Override + public double[] inverse(double[] values, int from, int to, double sum) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double[] gradient(double[] values, int from, int to) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double[] gradientInverse(double[] values, int from, int to) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public String getTransformName() { + return NAME; + } + + @Override + protected double[] transform(double[] values) { + double propTime = values[0]; + double rateAncestral = Math.exp(values[1]); + double rateDescendant = Math.exp(values[2]); + double[] transformed = new double[1]; + transformed[0] = (rateDescendant * propTime) / (rateDescendant * propTime + rateAncestral * (1.0 - propTime)); + return transformed; + } + + @Override + protected double[] inverse(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + protected double getLogJacobian(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + protected double[] getGradientLogJacobianInverse(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double[][] computeJacobianMatrixInverse(double[] values) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + protected boolean isInInteriorDomain(double[] values) { + // Only the proportion is bounded + return values[0] >= 0.0 && values[0] <= 1.0; + } + + public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() { + + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + final String name = xo.hasId() ? xo.getId() : null; + + TimeToDistanceProportionTransform transform = new TimeToDistanceProportionTransform(); + + return transform; + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[0]; + } + + @Override + public String getParserDescription() { + return null; + } + + @Override + public Class getReturnType() { + return Transform.MultivariateTransform.class; + } + + @Override + public String getParserName() { + return NAME; + } + }; +}