Skip to content

Commit

Permalink
thinking about how to clean up Transform
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Oct 10, 2024
1 parent 538e916 commit 48a9cfb
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,6 @@ public static ApproximationMode factory(String label) {
}
}

private static final ApproximationMode DEFAULT_MODE = ApproximationMode.FIRST_ORDER;

public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,6 @@ public String getDescription() {
@Override
public List<Citation> getCitations() {
// TODO Update
return Collections.singletonList(CommonCitations.LEMEY_2014_UNIFYING);
return Collections.singletonList(CommonCitations.MONTI_GENERIC_RATES_2024);
}
}
18 changes: 18 additions & 0 deletions src/dr/inference/distribution/CauchyDistribution.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,21 @@ public double[] getGradientLogDensity(Object input) {
return new double[] { gradLogPdf(x, median, scale) };
}
}

/*
log-Cauchy: y = e^x, such that x ~ Cauchy. Then
p(y) = s / pi / ((log y - m)^2 + s^2)) / y
\frac{d p(y)}{d y} = -1 *
\frac{
(log y - m)^2 - 2m + s^2 + 2log y
}{
y^2 [(log y - m)^2 + s^2]^2
}
this is decreasing when f(y) = (log y -m)^2 - 2m + s^2 +2log y > 0
f(y) has roots at x = log(y) = (m - 1) +/- sqrt(1 - s^2)
these are only real for s < 1, in which case
x = log(y) > m - 1 - sqrt(1 - s^2) is decreasing.
*/
4 changes: 2 additions & 2 deletions src/dr/inference/model/WeightedMixtureModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public String getParserName() {
public Object parseXMLObject(XMLObject xo) throws XMLParseException {

Parameter weights = (Parameter) xo.getChild(Parameter.class);
List<AbstractModelLikelihood> likelihoodList = new ArrayList<AbstractModelLikelihood>();
List<AbstractModelLikelihood> likelihoodList = new ArrayList<>();

for (int i = 0; i < xo.getChildCount(); i++) {
if (xo.getChild(i) instanceof Likelihood)
Expand Down Expand Up @@ -296,7 +296,7 @@ public void setId(String id) {
}
};

List<AbstractModelLikelihood> likelihoodList = new ArrayList<AbstractModelLikelihood>();
List<AbstractModelLikelihood> likelihoodList = new ArrayList<>();
likelihoodList.add(like1);
likelihoodList.add(like2);

Expand Down
29 changes: 7 additions & 22 deletions src/dr/util/CommonCitations.java
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,13 @@ public class CommonCitations {
Citation.Status.PUBLISHED
);

// Minin VN, Suchard MA (2008) . Philos Trans R Soc Lond B Biol Sci 363(1512):3985-3995.

// public static Citation LEMEY_2012 = new Citation(
// new Author[]{
// new Author("P", "Lemey"),
// new Author("T", "Bedford"),
// new Author("A", "Rambaut"),
// new Author("MA", "Suchard"),
// },
// Citation.Status.IN_PREPARATION
// );
public static Citation MONTI_GENERIC_RATES_2024 = new Citation(
new Author[]{
new Author("F", "Monti"),
new Author("MA", "Suchard"),
},
Citation.Status.IN_PREPARATION
);

public static Citation LEMEY_MIXTURE_2012 = new Citation(
new Author[]{
Expand All @@ -273,17 +269,6 @@ public class CommonCitations {
"e00631"
);


// Gong LI, Suchard MA, Bloom JD. Stability-mediated epistasis constrains the evolution of an influenza protein. eLife, 2, e00631, 2013.

public static Citation SUCHARD_2012_LATENT = new Citation(
new Author[]{
new Author("MA", "Suchard"),
new Author("J", "Felsenstein"),
},
Citation.Status.IN_PREPARATION
);

public static Citation SUCHARD_GENERIC = new Citation(
new Author[]{
new Author("MA", "Suchard"),
Expand Down
106 changes: 89 additions & 17 deletions src/dr/util/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public interface Transform {
// Transform: y = f(x)

/**
* @param value evaluation point
* @return the transformed value
* @param x evaluation point
* @return y transformed value
*/
double transform(double x);

Expand Down Expand Up @@ -309,7 +309,7 @@ public boolean isInInteriorDomain(double[] values, int from, int to) {
@Deprecated
public double logGradientInverse(double value) {
throw new RuntimeException("Not yet implemented.");
};
}

@Deprecated
public double[] logGradientInverse(double[] values, int from, int to) {
Expand All @@ -323,7 +323,7 @@ public double[] logGradientInverse(double[] values, int from, int to) {
@Deprecated
public double derivativeOfTransformWrtValue(double value) {
throw new RuntimeException("Not yet implemented.");
};
}

@Deprecated
public double[] derivativeOfTransformWrtValue(double[] values, int from, int to) {
Expand All @@ -337,7 +337,7 @@ public double[] derivativeOfTransformWrtValue(double[] values, int from, int to)
@Deprecated
public double secondDerivativeOfTransformWrtValue(double value) {
throw new RuntimeException("Not yet implemented.");
};
}

@Deprecated
public double[] secondDerivativeOfTransformWrtValue(double[] values, int from, int to) {
Expand Down Expand Up @@ -446,7 +446,7 @@ public double[] logGradientInverse(double[] values, int from, int to) {
@Deprecated
public double derivativeOfTransformWrtValue(double value) {
throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ...");
};
}

@Deprecated
public double[] derivativeOfTransformWrtValue(double[] values, int from, int to) {
Expand All @@ -456,7 +456,7 @@ public double[] derivativeOfTransformWrtValue(double[] values, int from, int to)
@Deprecated
public double secondDerivativeOfTransformWrtValue(double value) {
throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ...");
};
}

@Deprecated
public double[] secondDerivativeOfTransformWrtValue(double[] values, int from, int to) {
Expand Down Expand Up @@ -698,9 +698,15 @@ public double gradient(double value) {

// y = x^2
class SquaredTransform extends UnivariableTransform {

Transform inverse;

@Override
public Transform inverseTransform() {
throw new RuntimeException("Not yet implemented");
if (inverse == null) {
inverse = new PowerTransform(1/2);
}
return inverse;
}

public double transform(double x) {
Expand Down Expand Up @@ -1002,13 +1008,74 @@ public static void main(String[] args) {

}

class SigmoidTransform extends UnivariableTransform {

public SigmoidTransform() { }

@Override
public Transform inverseTransform() {
return LOGIT;
}

public double transform(double value) { return 1.0 / (1.0 + Math.exp(-value)); }

public double inverse(double value) { return Math.log(value / (1.0 - value)); }

public boolean isInInteriorDomain(double value) {
return true;
}

public double gradientInverse(double value) {
return gradient(inverse(value));
}

public double updateGradientLogDensity(double gradient, double value) {
throw new RuntimeException("Not yet implemented");
// return gradient * value * (1.0 - value) - (2.0 * value - 1.0);
}

public double gradientLogJacobianInverse(double value) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double updateDiagonalHessianLogDensity(double diagonalHessian, double gradient, double value) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double updateOffdiagonalHessianLogDensity(double offdiagonalHessian, double transformationHessian, double gradientI, double gradientJ, double valueI, double valueJ) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double gradient(double value) {
throw new RuntimeException("Not yet implemented"); // TODO appears to be dx / dy evaluated with x (which is gradientInverse, no?)
// return value * (1.0 - value);
}

public String getTransformName() {
return "sigmoid";
}

public double logJacobian(double value) {
throw new RuntimeException("Not yet implemented");
// return -Math.log(1.0 - value) - Math.log(value);
}
}

class LogitTransform extends UnivariableTransform {

public LogitTransform() {
range = 1.0;
lower = 0.0;
}

@Override
public Transform inverseTransform() {
return SIGMOID;
}

public double transform(double value) {
return Math.log(value / (1.0 - value));
}
Expand Down Expand Up @@ -1181,6 +1248,11 @@ public double logJacobian(double value) {

class NegateTransform extends UnivariableTransform {

@Override
public Transform inverseTransform() {
return NEGATE;
}

public double transform(double value) {
return -value;
}
Expand Down Expand Up @@ -1230,7 +1302,7 @@ public double logJacobian(double value) {
}

class PowerTransform extends UnivariableTransform{
private double power;
private final double power;

PowerTransform(){
this.power = 2;
Expand Down Expand Up @@ -1359,7 +1431,7 @@ public boolean isInInteriorDomain(double value) {
}

class InverseSumTransform extends UnivariableTransform {
private double sum;
private final double sum;

InverseSumTransform() {
this.sum = 1;
Expand Down Expand Up @@ -1424,6 +1496,11 @@ public boolean isInInteriorDomain(double value) {

class NoTransform extends UnivariableTransform {

@Override
public Transform inverseTransform() {
return NONE;
}

public double transform(double value) {
return value;
}
Expand Down Expand Up @@ -2665,6 +2742,7 @@ public static MultivariableTransform parseMultivariableTransform(Object obj) {
Compose LOG_NEGATE = new Compose(new LogTransform(), new NegateTransform());
LogConstrainedSumTransform LOG_CONSTRAINED_SUM = new LogConstrainedSumTransform();
LogitTransform LOGIT = new LogitTransform();
SigmoidTransform SIGMOID = new SigmoidTransform();
FisherZTransform FISHER_Z = new FisherZTransform();

enum Type {
Expand All @@ -2675,6 +2753,7 @@ enum Type {
LOG_NEGATE("log-negate", new Compose(new LogTransform(), new NegateTransform())),
LOG_CONSTRAINED_SUM("logConstrainedSum", new LogConstrainedSumTransform()),
LOGIT("logit", new LogitTransform()),
SIGMOID("sigmoid", new SigmoidTransform()),
FISHER_Z("fisherZ",new FisherZTransform()),
INVERSE_SUM("inverseSum", new InverseSumTransform()),
SQUARED("squared", new SquaredTransform()),
Expand All @@ -2696,11 +2775,4 @@ public String getName() {
private Transform transform;
private String name;
}
// String TRANSFORM = "transform";
// String TYPE = "type";
// String START = "start";
// String END = "end";
// String EVERY = "every";
// String INVERSE = "inverse";

}

0 comments on commit 48a9cfb

Please sign in to comment.