Skip to content

Commit

Permalink
Merge pull request #245 from stan-dev/sem-check-invariant-test
Browse files Browse the repository at this point in the history
Add check to semantic check that original AST is not modified (only decorated)
  • Loading branch information
VMatthijs authored Aug 9, 2019
2 parents 1a19c30 + 6b3b116 commit 74282da
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 36 deletions.
34 changes: 20 additions & 14 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ type fun_kind = StanLib | UserDefined [@@deriving compare, sexp, hash]

(** Expression shapes (used for both typed and untyped expressions, where we
substitute untyped_expression or typed_expression for 'e *)
type 'e expression =
type ('e, 'f) expression =
| TernaryIf of 'e * 'e * 'e
| BinOp of 'e * Middle.operator * 'e
| PrefixOp of Middle.operator * 'e
| PostfixOp of 'e * Middle.operator
| Variable of identifier
| IntNumeral of string
| RealNumeral of string
| FunApp of fun_kind * identifier * 'e list
| CondDistApp of fun_kind * identifier * 'e list
| FunApp of 'f * identifier * 'e list
| CondDistApp of 'f * identifier * 'e list
(* GetLP is deprecated *)
| GetLP
| GetTarget
Expand All @@ -39,14 +39,14 @@ type 'e expression =
| Indexed of 'e * 'e index list
[@@deriving sexp, hash, compare, map]

type 'm expr_with = {expr: 'm expr_with expression; emeta: 'm}
type ('m, 'f) expr_with = {expr: (('m, 'f) expr_with, 'f) expression; emeta: 'm}
[@@deriving sexp, compare, map, hash]

(** Untyped expressions, which have location_spans as meta-data *)
type located_meta = {loc: Middle.location_span sexp_opaque [@compare.ignore]}
[@@deriving sexp, compare, map, hash]

type untyped_expression = located_meta expr_with
type untyped_expression = (located_meta, unit) expr_with
[@@deriving sexp, compare, map, hash]

(** Typed expressions also have meta-data after type checking: a location_span, as well as a type
Expand All @@ -57,7 +57,7 @@ type typed_expr_meta =
; type_: Middle.unsizedtype }
[@@deriving sexp, compare, map, hash]

type typed_expression = typed_expr_meta expr_with
type typed_expression = (typed_expr_meta, fun_kind) expr_with
[@@deriving sexp, compare, map, hash]

let mk_untyped_expression ~expr ~loc = {expr; emeta= {loc}}
Expand Down Expand Up @@ -133,12 +133,12 @@ type typed_lval = (typed_expression, typed_expr_meta) lval_with
(** Statement shapes, where we substitute untyped_expression and untyped_statement
for 'e and 's respectively to get untyped_statement and typed_expression and
typed_statement to get typed_statement *)
type ('e, 's, 'l) statement =
type ('e, 's, 'l, 'f) statement =
| Assignment of
{ assign_lhs: 'l
; assign_op: assignmentoperator
; assign_rhs: 'e }
| NRFunApp of fun_kind * identifier * 'e list
| NRFunApp of 'f * identifier * 'e list
| TargetPE of 'e
(* IncrementLogProb is deprecated *)
| IncrementLogProb of 'e
Expand Down Expand Up @@ -191,13 +191,13 @@ type statement_returntype =
| AnyReturnType
[@@deriving sexp, hash, compare]

type ('e, 'm, 'l) statement_with =
{stmt: ('e, ('e, 'm, 'l) statement_with, 'l) statement; smeta: 'm}
type ('e, 'm, 'l, 'f) statement_with =
{stmt: ('e, ('e, 'm, 'l, 'f) statement_with, 'l, 'f) statement; smeta: 'm}
[@@deriving sexp, compare, map, hash]

(** Untyped statements, which have location_spans as meta-data *)
type untyped_statement =
(untyped_expression, located_meta, untyped_lval) statement_with
(untyped_expression, located_meta, untyped_lval, unit) statement_with
[@@deriving sexp, compare, map, hash]

let mk_untyped_statement ~stmt ~loc : untyped_statement = {stmt; smeta= {loc}}
Expand All @@ -210,7 +210,11 @@ type stmt_typed_located_meta =
(** Typed statements also have meta-data after type checking: a location_span, as well as a statement returntype
to check that function bodies have the right return type*)
type typed_statement =
(typed_expression, stmt_typed_located_meta, typed_lval) statement_with
( typed_expression
, stmt_typed_located_meta
, typed_lval
, fun_kind )
statement_with
[@@deriving sexp, compare, map, hash]

let mk_typed_statement ~stmt ~loc ~return_type =
Expand Down Expand Up @@ -240,7 +244,8 @@ type typed_program = typed_statement program [@@deriving sexp, compare, map]
(** Forgetful function from typed to untyped expressions *)
let rec untyped_expression_of_typed_expression
({expr; emeta} : typed_expression) : untyped_expression =
{ expr= map_expression untyped_expression_of_typed_expression expr
{ expr=
map_expression untyped_expression_of_typed_expression (fun _ -> ()) expr
; emeta= {loc= emeta.loc} }

let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
Expand All @@ -255,11 +260,12 @@ let rec untyped_statement_of_typed_statement {stmt; smeta} =
{ stmt=
map_statement untyped_expression_of_typed_expression
untyped_statement_of_typed_statement untyped_lvalue_of_typed_lvalue
(fun _ -> ())
stmt
; smeta= {loc= smeta.loc} }

(** Forgetful function from typed to untyped programs *)
let untyped_program_of_typed_program =
let untyped_program_of_typed_program : typed_program -> untyped_program =
map_program untyped_statement_of_typed_statement

let rec expr_of_lvalue {lval; lmeta} =
Expand Down
4 changes: 1 addition & 3 deletions src/frontend/Ast_to_Mir.mli
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
(** Translate from the AST to the MIR *)

val trans_prog : string -> Ast.typed_program -> Middle.typed_prog

val trans_expr :
Ast.typed_expr_meta Ast.expr_with -> Middle.mtype_loc_ad Middle.with_expr
val trans_expr : Ast.typed_expression -> Middle.mtype_loc_ad Middle.with_expr
52 changes: 36 additions & 16 deletions src/frontend/Semantic_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,11 @@ let semantic_check_fn_rng cf ~loc id =
then Semantic_error.invalid_rng_fn loc |> error
else ok ())

let mk_fun_app ~is_cond_dist (x, y, z) =
if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z)

(* Regular function application *)
let semantic_check_fn_normal ~loc id es =
let semantic_check_fn_normal ~is_cond_dist ~loc id es =
Validate.(
match Symbol_table.look vm id.name with
| Some (_, UFun (_, Void)) ->
Expand All @@ -277,7 +280,7 @@ let semantic_check_fn_normal ~loc id es =
|> error
| Some (_, UFun (_, ReturnType ut)) ->
mk_typed_expression
~expr:(FunApp (UserDefined, id, es))
~expr:(mk_fun_app ~is_cond_dist (UserDefined, id, es))
~ad_level:(lub_ad_e es) ~type_:ut ~loc
|> ok
| Some _ ->
Expand All @@ -288,14 +291,14 @@ let semantic_check_fn_normal ~loc id es =
|> error)

(* Stan-Math function application *)
let semantic_check_fn_stan_math ~loc id es =
let semantic_check_fn_stan_math ~is_cond_dist ~loc id es =
match stan_math_returntype id.name (get_arg_types es) with
| Some Void ->
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
|> Validate.error
| Some (ReturnType ut) ->
mk_typed_expression
~expr:(FunApp (StanLib, id, es))
~expr:(mk_fun_app ~is_cond_dist (StanLib, id, es))
~ad_level:(lub_ad_e es) ~type_:ut ~loc
|> Validate.ok
| _ ->
Expand All @@ -320,10 +323,10 @@ let fn_kind_from_application id es =
(** Determines the function kind based on the identifier and performs the
corresponding semantic check
*)
let semantic_check_fn ~loc id es =
let semantic_check_fn ~is_cond_dist ~loc id es =
match fn_kind_from_application id es with
| StanLib -> semantic_check_fn_stan_math ~loc id es
| UserDefined -> semantic_check_fn_normal ~loc id es
| StanLib -> semantic_check_fn_stan_math ~is_cond_dist ~loc id es
| UserDefined -> semantic_check_fn_normal ~is_cond_dist ~loc id es

(* -- Ternary If ------------------------------------------------------------ *)

Expand Down Expand Up @@ -659,19 +662,12 @@ and semantic_check_funapp ~is_cond_dist id es cf emeta =
|> List.map ~f:(semantic_check_expression cf)
|> sequence
>>= fun ues ->
semantic_check_fn ~loc:emeta.loc id ues
semantic_check_fn ~is_cond_dist ~loc:emeta.loc id ues
|> apply_const (semantic_check_identifier id)
|> apply_const (semantic_check_fn_map_rect ~loc:emeta.loc id ues)
|> apply_const (name_check ~loc:emeta.loc id)
|> apply_const (semantic_check_fn_target_plus_equals cf ~loc:emeta.loc id)
|> apply_const (semantic_check_fn_rng cf ~loc:emeta.loc id)
>>= fun e ->
ok
{ e with
expr=
( match e.expr with
| FunApp (fun_kind, id, ues) -> CondDistApp (fun_kind, id, ues)
| _ -> raise_s [%sexp ("This should never happen!" : string)] ) })
|> apply_const (semantic_check_fn_rng cf ~loc:emeta.loc id))

and semantic_check_expression_of_int_type cf e name =
Validate.(
Expand Down Expand Up @@ -1676,9 +1672,33 @@ let semantic_check_program
; generatedquantitiesblock= ugb }
in
let apply_to x f = Validate.apply ~f x in
let check_correctness_invariant (decorated_ast : typed_program) :
typed_program =
if
compare_untyped_program
{ functionblock= fb
; datablock= db
; transformeddatablock= tdb
; parametersblock= pb
; transformedparametersblock= tpb
; modelblock= mb
; generatedquantitiesblock= gb }
(untyped_program_of_typed_program decorated_ast)
= 0
then decorated_ast
else
raise_s
[%message
"Type checked AST does not match original AST. Please file a bug!"
(decorated_ast : typed_program)]
in
let check_correctness_invariant_validate =
Validate.map ~f:check_correctness_invariant
in
Validate.(
ok mk_typed_prog |> apply_to ufb |> apply_to udb |> apply_to utdb
|> apply_to upb |> apply_to utpb |> apply_to umb |> apply_to ugb
|> check_correctness_invariant_validate
|> get_with
~with_ok:(fun ok -> Result.Ok ok)
~with_errors:(fun errs -> Result.Error errs))
6 changes: 3 additions & 3 deletions src/frontend/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,14 @@ common_expression:
| LBRACK xs=separated_nonempty_list(COMMA, expression) RBRACK
{ grammar_logger "row_vector_expression" ; RowVectorExpr xs }
| id=identifier LPAREN args=separated_list(COMMA, expression) RPAREN
{ grammar_logger "fun_app" ; FunApp (UserDefined, id, args) }
{ grammar_logger "fun_app" ; FunApp ((), id, args) }
| TARGET LPAREN RPAREN
{ grammar_logger "target_read" ; GetTarget }
| GETLP LPAREN RPAREN
{ grammar_logger "get_lp" ; GetLP } (* deprecated *)
| id=identifier LPAREN e=expression BAR args=separated_list(COMMA, expression)
RPAREN
{ grammar_logger "conditional_dist_app" ; CondDistApp (UserDefined, id, e :: args) }
{ grammar_logger "conditional_dist_app" ; CondDistApp ((), id, e :: args) }
| LPAREN e=expression RPAREN
{ grammar_logger "extra_paren" ; Paren e }

Expand Down Expand Up @@ -500,7 +500,7 @@ atomic_statement:
assign_op=op;
assign_rhs=e} }
| id=identifier LPAREN args=separated_list(COMMA, expression) RPAREN SEMICOLON
{ grammar_logger "funapp_statement" ; NRFunApp (UserDefined,id, args) }
{ grammar_logger "funapp_statement" ; NRFunApp ((),id, args) }
| INCREMENTLOGPROB LPAREN e=expression RPAREN SEMICOLON
{ grammar_logger "incrementlogprob_statement" ; IncrementLogProb e } (* deprecated *)
| e=expression TILDE id=identifier LPAREN es=separated_list(COMMA, expression)
Expand Down

0 comments on commit 74282da

Please sign in to comment.