Skip to content

Commit

Permalink
Merge pull request #1347 from stan-dev/fix/soa-ad-scalar-data-matrix-…
Browse files Browse the repository at this point in the history
…checks

Fix: Logic for demotion rules around scalar ad types and data matrices
  • Loading branch information
WardBrian authored Aug 21, 2023
2 parents 69edf41 + 870a5f4 commit d5c25b6
Show file tree
Hide file tree
Showing 8 changed files with 2,894 additions and 145 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ test/*.log

# .hpp files in test folder
test/**/*.hpp

# .omir files in test folder
test/**/*.omir
246 changes: 109 additions & 137 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -195,111 +195,74 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
Set.Poly.union acc demoted_and_top_level_names

(**
Check whether any functions in the right hand side expression of an assignment
support SoA. If so then return true, otherwise return false.
*)
let rec is_any_soa_supported_expr
Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; type_; _}} : bool =
(**
* Recurse through subexpressions and return a list of Unsized types.
* Recursion continues until
* 1. A non-autodiffable type is found
* 2. An autodiffable scalar is found
* 3. A `Var` type is found that is an autodiffable matrix
*)
let rec extract_nonderived_admatrix_types
Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; type_; _}} =
if
UnsizedType.is_dataonlytype adlevel
|| not (UnsizedType.contains_eigen_type type_)
then true
else
UnsizedType.is_autodifftype adlevel && UnsizedType.contains_eigen_type type_
then
match pattern with
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
is_any_soa_supported_fun_expr kind exprs
| Indexed (expr, (_ : Expr.Typed.t Index.t list))
|Promotion (expr, _, _)
|TupleProjection (expr, _) ->
is_any_soa_supported_expr expr
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
true
| TernaryIf (_, texpr, fexpr) ->
is_any_soa_supported_expr texpr && is_any_soa_supported_expr fexpr
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
is_any_soa_supported_expr lhs && is_any_soa_supported_expr rhs

(**
Return false if the [Fun_kind.t] does not support [SoA]
*)
and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t)
(exprs : Expr.Typed.t list) : bool =
match kind with
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> false
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false
| CompilerInternal (_ : 'a Internal_fun.t) -> true
| Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> (
match name with
| "check_matching_dims" -> true
| _ ->
is_fun_soa_supported name exprs
&& List.exists ~f:is_any_soa_supported_expr exprs )

(**
Return true if the rhs expression of an assignment contains only
combinations of AutoDiffable Reals and Data Matrices
*)
let rec is_any_ad_real_data_matrix_expr
Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; _}} : bool =
if UnsizedType.is_dataonlytype adlevel then false
else
match pattern with
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
is_any_ad_real_data_matrix_expr_fun kind exprs
extract_nonderived_admatrix_types_fun kind exprs
| Indexed (expr, _) | Promotion (expr, _, _) | TupleProjection (expr, _) ->
is_any_ad_real_data_matrix_expr expr
extract_nonderived_admatrix_types expr
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
false
[(adlevel, type_)]
| TernaryIf (_, texpr, fexpr) ->
is_any_ad_real_data_matrix_expr texpr
|| is_any_ad_real_data_matrix_expr fexpr
List.concat
[ extract_nonderived_admatrix_types texpr
; extract_nonderived_admatrix_types fexpr ]
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
is_any_ad_real_data_matrix_expr lhs
&& is_any_ad_real_data_matrix_expr rhs
List.concat
[ extract_nonderived_admatrix_types lhs
; extract_nonderived_admatrix_types rhs ]
else [(adlevel, type_)]

(**
Return true if the expressions in a function call are all
combinations of AutoDiffable Reals and Data Matrices
* Recurse through functions to find nonderived ad matrix types.
* Special cases for StanLib functions are for
* - `check_matching_dims`: compiler function that has no effect on optimization
* - `rep_*vector` These are templated in the C++ to cast up to `Var<Matrix>` types
* - `rep_matrix`. When it's only a scalar being propogated an math library overload can upcast to `Var<Matrix>`
*)
and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
(exprs : Expr.Typed.t list) : bool =
and extract_nonderived_admatrix_types_fun (kind : 'a Fun_kind.t)
(exprs : Expr.Typed.t list) =
match kind with
| Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> (
match name with
| "check_matching_dims" -> false
| _ -> (
let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in
(*Right now we can't handle AD real and data matrix funcs
that return a matrix :-/*)
let is_args_autodiff_real_data_matrix =
(*If there are any autodiffable vars*)
List.exists
~f:(fun (x, y) ->
match (x, y) with
| UnsizedType.AutoDiffable, UnsizedType.UReal -> true
| _ -> false )
fun_args
(*And there are any data matrices*)
&& List.exists
~f:(fun (x, y) ->
match (x, UnsizedType.is_container y) with
| UnsizedType.DataOnly, true -> true
| _ -> false )
fun_args
(*And there are no Autodiffable matrices*)
&& List.exists
~f:(fun (x, y) ->
match (x, UnsizedType.contains_eigen_type y) with
| UnsizedType.AutoDiffable, true -> false
| _ -> true )
fun_args in
match is_args_autodiff_real_data_matrix with
| true -> true
| false -> List.exists ~f:is_any_ad_real_data_matrix_expr exprs ) )
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> true
| CompilerInternal (_ : 'a Internal_fun.t) -> false
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false
| "check_matching_dims" -> []
| "rep_vector" -> [(UnsizedType.AutoDiffable, UnsizedType.UVector)]
| "rep_row_vector" -> [(UnsizedType.AutoDiffable, UnsizedType.URowVector)]
| "rep_matrix"
when match List.map ~f:Expr.Typed.fun_arg exprs with
| [(_, UnsizedType.UReal); _; _] -> true
| _ -> false ->
[(UnsizedType.AutoDiffable, UnsizedType.UMatrix)]
| _ -> List.concat_map ~f:extract_nonderived_admatrix_types exprs )
(*While not "true", we need to tell the optimizer these are danger functions*)
| CompilerInternal Internal_fun.FnMakeArray ->
[(AutoDiffable, UReal); (DataOnly, UArray UReal)]
| CompilerInternal Internal_fun.FnMakeRowVec ->
[(AutoDiffable, UReal); (DataOnly, URowVector)]
| CompilerInternal (_ : 'a Internal_fun.t) -> []
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> []

(**Checks if a list of types contains at least on ad matrix or if everything is derived from data*)
let contains_at_least_one_ad_matrix_or_all_data
(fun_args : (UnsizedType.autodifftype * UnsizedType.t) list) =
List.is_empty fun_args
|| List.exists
~f:(fun x ->
UnsizedType.is_autodifftype (fst x)
&& UnsizedType.is_eigen_type (snd x) )
fun_args
|| List.for_all ~f:(fun x -> UnsizedType.is_dataonlytype (fst x)) fun_args

(**
Query to find the initial set of objects in statements that cannot be SoA.
Expand All @@ -308,13 +271,13 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
*
For assignments:
We demote the LHS variable if any of the following are true:
1. None of the RHS's functions are able to accept SoA matrices
and the rhs is not an internal compiler function.
2. A single cell of the LHS is being assigned within a loop.
3. The top level expression on the RHS is a combination of only
1. A single cell of the LHS is being assigned within a loop.
2. The top level expression on the RHS is a combination of only
data matrices and scalar types. Operations on data matrix and
scalar values in Stan math will return a AoS matrix. We currently
have no way to tell Stan math to return a SoA matrix.
3. None of the RHS's functions are able to accept SoA matrices
and the rhs is not an internal compiler function.
*
We demote RHS variables if any of the following are true:
1. The LHS variable has previously or through this iteration
Expand All @@ -336,57 +299,66 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
, (ut : UnsizedType.t)
, (Expr.Fixed.{meta= Expr.Typed.Meta.{type_; adlevel; _}; _} as rhs) ) ->
let name = Stmt.Helpers.lhs_variable lval in
let idx = Stmt.Helpers.lhs_indices lval in
let idx_list =
List.fold ~init:acc
~f:(fun accum x ->
Index.folder accum
(fun acc -> query_initial_demotable_expr in_loop ~acc)
x )
idx in
(* LHS (1)*)
let idx_demotable =
(* RHS (2)*)
let idx = Stmt.Helpers.lhs_indices lval in
let idx_list =
List.fold ~init:acc
~f:(fun accum x ->
Index.folder accum
(fun acc -> query_initial_demotable_expr in_loop ~acc)
x )
idx in
match is_uni_eigen_loop_indexing in_loop ut idx with
| true -> Set.Poly.add idx_list name
| false -> idx_list in
let rhs_demotable_names = query_expr acc rhs in
(* RHS (3)*)
let check_if_rhs_ad_real_data_matrix_expr =
match (UnsizedType.contains_eigen_type type_, adlevel) with
| true, UnsizedType.AutoDiffable ->
is_any_ad_real_data_matrix_expr rhs
|| not (is_any_soa_supported_expr rhs)
| _ -> false in
let rhs_and_idx_demotions =
Set.Poly.union idx_demotable rhs_demotable_names in
(* RHS (1)*)
let is_all_rhs_aos =
let all_rhs_eigen_names = query_var_eigen_names rhs in
is_nonzero_subset ~subset:all_rhs_eigen_names ~set:rhs_demotable_names
in
let is_not_supported_func =
match rhs.pattern with
| FunApp (CompilerInternal _, _) -> false
| FunApp (UserDefined _, _) -> true
| _ -> false in
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
let assign_demotes =
if
is_eigen_stmt
&& ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr
|| is_not_supported_func )
then
let base_set = Set.Poly.union idx_demotable rhs_demotable_names in
Set.Poly.add
(Set.Poly.union base_set (query_var_eigen_names rhs))
name
else Set.Poly.union idx_demotable rhs_demotable_names in
let tuple_demotes =
let tuple_demotions =
match lval with
| LTupleProjection _, _ ->
Set.Poly.add
(Set.Poly.union assign_demotes (query_var_eigen_names rhs))
(Set.Poly.union rhs_and_idx_demotions (query_var_eigen_names rhs))
name
| _ -> rhs_and_idx_demotions in
let assign_demotions =
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
if is_eigen_stmt then
(* LHS (2)*)
let is_rhs_not_promoteable_to_soa =
match (UnsizedType.contains_eigen_type type_, adlevel) with
| true, UnsizedType.AutoDiffable ->
not
(contains_at_least_one_ad_matrix_or_all_data
(extract_nonderived_admatrix_types rhs) )
| _ -> false in
(* LHS (3) rhs unsupported function*)
let is_not_supported_func =
match rhs.pattern with
| FunApp (UserDefined _, _) -> true
| FunApp (CompilerInternal _, _) -> false
| FunApp (StanLib (name, _, _), exprs) ->
not
(query_stan_math_mem_pattern_support name
(List.map ~f:Expr.Typed.fun_arg exprs) )
| _ -> false in
(* LHS (3) all rhs aos*)
let is_all_rhs_aos =
is_nonzero_subset
~subset:(query_var_eigen_names rhs)
~set:rhs_demotable_names in
if
is_all_rhs_aos || is_rhs_not_promoteable_to_soa
|| is_not_supported_func
then
Set.Poly.add
(Set.Poly.union tuple_demotions (query_var_eigen_names rhs))
name
| _ -> assign_demotes in
Set.Poly.union acc tuple_demotes
else tuple_demotions
else tuple_demotions in
Set.Poly.union acc assign_demotions
| NRFunApp (kind, exprs) ->
query_initial_demotable_funs in_loop acc kind exprs
| IfElse (predicate, true_stmt, op_false_stmt) ->
Expand Down
5 changes: 0 additions & 5 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1228,11 +1228,6 @@ let optimize_soa (mir : Program.Typed.t) =
List.fold ~init:Set.Poly.empty
~f:(Memory_patterns.query_initial_demotable_stmt false)
mir.reverse_mode_log_prob in
(*
let print_set s =
Set.Poly.iter ~f:print_endline s in
let () = print_set initial_variables in
*)
let mod_exprs aos_exits mod_expr =
Mir_utils.map_rec_expr
(Memory_patterns.modify_expr_pattern aos_exits)
Expand Down
4 changes: 2 additions & 2 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2204,8 +2204,8 @@ let () =
(List.range 1 3) )
bare_types ;
add_unqualified ("rep_matrix", ReturnType UMatrix, [UReal; UInt; UInt], SoA) ;
add_unqualified ("rep_matrix", ReturnType UMatrix, [UVector; UInt], AoS) ;
add_unqualified ("rep_matrix", ReturnType UMatrix, [URowVector; UInt], AoS) ;
add_unqualified ("rep_matrix", ReturnType UMatrix, [UVector; UInt], SoA) ;
add_unqualified ("rep_matrix", ReturnType UMatrix, [URowVector; UInt], SoA) ;
add_unqualified
("rep_matrix", ReturnType UComplexMatrix, [UComplex; UInt; UInt], AoS) ;
add_unqualified
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
data {
int N;
matrix[N, N] X_data;
vector[N] y_data;
}
parameters {
real alpha;
real sigma;
vector[N] beta;
}
model{
vector[N] soa_simple= alpha + rep_vector(0.0, N) + beta;
vector[N] aos_deep = 2 * Phi(y_data / sigma) - 1;
vector[N] soa_dual_rep = transpose(rep_row_vector(0.0, N)) + rep_vector(sigma, N);
vector[N] soa_data_rep = rep_vector(0.0, N) + rep_vector(N, N);
vector[N] soa_mix = Phi(y_data / sigma) + soa_simple;
vector[N] aos_from_data = alpha + sigma * alpha + y_data - alpha - sigma * alpha;
matrix[N, N] soa_mat_rep = transpose(rep_matrix(2.0, N, N)) + rep_matrix(sum(rep_vector(sigma, N)), N, N);
matrix[N, N] soa_mat_rep_vec = transpose(rep_matrix(2.0, N, N)) + rep_matrix(rep_vector(sigma, N), N);
matrix[N, N] aos_mat_rep = transpose(rep_matrix(aos_deep, N)) + rep_matrix(aos_deep, N);
matrix[2, 2] aos_mat_from_vecs = [[alpha ^ 2, sigma], [alpha, sigma ^ 2]];
y_data ~ normal(soa_simple, aos_deep);
target += sum(soa_dual_rep);
target += sum(aos_from_data);
target += sum(soa_data_rep);
target += sum(soa_mix);
target += sum(soa_mat_rep);
target += sum(soa_mat_rep_vec);
target += sum(aos_mat_rep);
target += sum(aos_mat_from_vecs);
}
Loading

0 comments on commit d5c25b6

Please sign in to comment.