diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index cc278044006e..44aae02e42cc 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -168,13 +168,12 @@ mutual /- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/ return erasedExpr else - matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal => - let n := structVal.numParams - let structParams := structType.getAppArgs - if n != structParams.size then + matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal => + let structTypeArgs := structType.getAppArgs + if structVal.numParams + structVal.numIndices != structTypeArgs.size then failed () else do - let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams) + let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structTypeArgs[:structVal.numParams]) for _ in [:idx] do match ctorType with | .forallE _ _ body _ => diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 40dd40ffa12e..7752ac82024f 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -1188,19 +1188,19 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L if idx == 0 then throwError "invalid projection, index must be greater than 0" let env ← getEnv - unless isStructureLike env structName do - throwLValError e eType "invalid projection, structure expected" - let numFields := getStructureLikeNumFields env structName - if idx - 1 < numFields then - if isStructure env structName then - let fieldNames := getStructureFields env structName - return LValResolution.projFn structName structName fieldNames[idx - 1]! + let failK _ := throwLValError e eType "invalid projection, structure expected" + matchConstStructure eType.getAppFn failK fun _ _ ctorVal => do + let numFields := ctorVal.numFields + if idx - 1 < numFields then + if isStructure env structName then + let fieldNames := getStructureFields env structName + return LValResolution.projFn structName structName fieldNames[idx - 1]! + else + /- `structName` was declared using `inductive` command. + So, we don't projection functions for it. Thus, we use `Expr.proj` -/ + return LValResolution.projIdx structName (idx - 1) else - /- `structName` was declared using `inductive` command. - So, we don't projection functions for it. Thus, we use `Expr.proj` -/ - return LValResolution.projIdx structName (idx - 1) - else - throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)" + throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)" | some structName, LVal.fieldName _ fieldName _ _ => let env ← getEnv let searchEnv : Unit → TermElabM LValResolution := fun _ => do diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index a35adad96726..de3bf2daf9aa 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -1975,7 +1975,7 @@ where assign `?m`. -/ return false - let ctorVal := getStructureCtor (← getEnv) structName + let some ctorVal := getStructureLikeCtor? (← getEnv) structName | return false if ctorVal.numFields != 1 then return false -- It is not a structure with a single field. let sType ← whnf (← inferType s) @@ -2013,7 +2013,7 @@ private def isDefEqApp (t s : Expr) : MetaM Bool := do /-- Return `true` if the type of the given expression is an inductive datatype with a single constructor with no fields. -/ private def isDefEqUnitLike (t : Expr) (s : Expr) : MetaM Bool := do let tType ← whnf (← inferType t) - matchConstStruct tType.getAppFn (fun _ => return false) fun _ _ ctorVal => do + matchConstStructureLike tType.getAppFn (fun _ => return false) fun _ _ ctorVal => do if ctorVal.numFields != 0 then return false else if (← useEtaStruct ctorVal.induct) then diff --git a/src/Lean/Meta/InferType.lean b/src/Lean/Meta/InferType.lean index f4424d048a89..6c0ff0e370ff 100644 --- a/src/Lean/Meta/InferType.lean +++ b/src/Lean/Meta/InferType.lean @@ -99,13 +99,12 @@ private def inferProjType (structName : Name) (idx : Nat) (e : Expr) : MetaM Exp let structType ← whnf structType let failed {α} : Unit → MetaM α := fun _ => throwError "invalid projection{indentExpr (mkProj structName idx e)} from type {structType}" - matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal => - let n := structVal.numParams - let structParams := structType.getAppArgs - if n != structParams.size then + matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal => + let structTypeArgs := structType.getAppArgs + if structVal.numParams + structVal.numIndices != structTypeArgs.size then failed () else do - let mut ctorType ← inferAppType (mkConst ctorVal.name structLvls) structParams + let mut ctorType ← inferAppType (mkConst ctorVal.name structLvls) structTypeArgs[:structVal.numParams] for i in [:idx] do ctorType ← whnf ctorType match ctorType with diff --git a/src/Lean/Meta/Tactic/Constructor.lean b/src/Lean/Meta/Tactic/Constructor.lean index 57034e1e3fdd..da98ce00acb7 100644 --- a/src/Lean/Meta/Tactic/Constructor.lean +++ b/src/Lean/Meta/Tactic/Constructor.lean @@ -32,7 +32,7 @@ def _root_.Lean.MVarId.existsIntro (mvarId : MVarId) (w : Expr) : MetaM MVarId : mvarId.withContext do mvarId.checkNotAssigned `exists let target ← mvarId.getType' - matchConstStruct target.getAppFn + matchConstStructure target.getAppFn (fun _ => throwTacticEx `exists mvarId "target is not an inductive datatype with one constructor") fun _ us cval => do if cval.numFields < 2 then diff --git a/src/Lean/MonadEnv.lean b/src/Lean/MonadEnv.lean index 82c529cb6eb1..61831dc60a4a 100644 --- a/src/Lean/MonadEnv.lean +++ b/src/Lean/MonadEnv.lean @@ -118,7 +118,26 @@ def getConstInfoRec [Monad m] [MonadEnv m] [MonadError m] (constName : Name) : m | ConstantInfo.recInfo v => pure v | _ => throwError "'{mkConst constName}' is not a recursor" -@[inline] def matchConstStruct [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α := +/-- +Matches if `e` is a constant that is an inductive type with one constructor. +Such types can be used with primitive projections. +See also `Lean.matchConstStructLike` for a more restrictive version. +-/ +@[inline] def matchConstStructure [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α := + matchConstInduct e failK fun ival us => do + match ival.ctors with + | [ctor] => + match (← getConstInfo ctor) with + | ConstantInfo.ctorInfo cval => k ival us cval + | _ => failK () + | _ => failK () + +/-- +Matches if `e` is a constant that is an non-recursive inductive type with no indices and with one constructor. +Such a type satisfies `Lean.isStructureLike`. +See also `Lean.matchConstStructure` for a less restrictive version. +-/ +@[inline] def matchConstStructureLike [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α := matchConstInduct e failK fun ival us => do if ival.isRec || ival.numIndices != 0 then failK () else match ival.ctors with diff --git a/src/Lean/Structure.lean b/src/Lean/Structure.lean index 235bf55b6ef9..288615d6094c 100644 --- a/src/Lean/Structure.lean +++ b/src/Lean/Structure.lean @@ -133,9 +133,17 @@ def getStructureInfo (env : Environment) (structName : Name) : StructureInfo := else panic! "structure expected" +/-- +Gets the constructor of an inductive type that has exactly one constructor. +This is meant to be used with types that have had been registered as a structure by `registerStructure`, +but this is not checked. + +Warning: these do *not* need to be "structure-likes". A structure-like is non-recursive, +and structure-likes have special kernel support. +-/ def getStructureCtor (env : Environment) (constName : Name) : ConstructorVal := match env.find? constName with - | some (.inductInfo { isRec := false, ctors := [ctorName], .. }) => + | some (.inductInfo { ctors := [ctorName], .. }) => match env.find? ctorName with | some (ConstantInfo.ctorInfo val) => val | _ => panic! "ill-formed environment" @@ -222,9 +230,10 @@ def getStructureFieldsFlattened (env : Environment) (structName : Name) (include getStructureFieldsFlattenedAux env structName #[] includeSubobjectFields /-- -Return true if `constName` is the name of an inductive datatype +Returns true if `constName` is the name of an inductive datatype created using the `structure` or `class` commands. +These are inductive types for which structure information has been registered with `registerStructure`. See also `Lean.getStructureInfo?`. -/ def isStructure (env : Environment) (constName : Name) : Bool := @@ -269,18 +278,33 @@ partial def getPathToBaseStructureAux (env : Environment) (baseStructName : Name | some projFn => getPathToBaseStructureAux env baseStructName parentStructName (projFn :: path) /-- -If `baseStructName` is an ancestor structure for `structName`, then return a sequence of projection functions +If `baseStructName` is an ancestor structure for `structName`, then returns a sequence of projection functions to go from `structName` to `baseStructName`. -/ def getPathToBaseStructure? (env : Environment) (baseStructName : Name) (structName : Name) : Option (List Name) := getPathToBaseStructureAux env baseStructName structName [] -/-- Return true iff `constName` is the a non-recursive inductive datatype that has only one constructor. -/ +/-- +Returns true iff `constName` is a non-recursive inductive datatype that has only one constructor and no indices. + +Such types have special kernel support. This must be in sync with `is_structure_like`. +-/ def isStructureLike (env : Environment) (constName : Name) : Bool := match env.find? constName with | some (.inductInfo { isRec := false, ctors := [_], numIndices := 0, .. }) => true | _ => false +/-- +Returns the constructor of the structure named `constName` if it is a non-recursive single-constructor inductive type with no indices. +-/ +def getStructureLikeCtor? (env : Environment) (constName : Name) : Option ConstructorVal := + match env.find? constName with + | some (.inductInfo { isRec := false, ctors := [ctorName], numIndices := 0, .. }) => + match env.find? ctorName with + | some (ConstantInfo.ctorInfo val) => val + | _ => panic! "ill-formed environment" + | _ => none + /-- Return number of fields for a structure-like type -/ def getStructureLikeNumFields (env : Environment) (constName : Name) : Nat := match env.find? constName with diff --git a/tests/lean/run/inductive_rec_proj.lean b/tests/lean/run/inductive_rec_proj.lean new file mode 100644 index 000000000000..8f63a5d09bbb --- /dev/null +++ b/tests/lean/run/inductive_rec_proj.lean @@ -0,0 +1,56 @@ +/-! +# Tests for numeric projections of inductive types +-/ + +/-! +Non-recursive, no indices. +-/ +inductive I0 where + | mk (x : Nat) (xs : List Nat) +/-- info: fun v => v.1 : I0 → Nat -/ +#guard_msgs in #check fun (v : I0) => v.1 +/-- info: fun v => v.2 : I0 → List Nat -/ +#guard_msgs in #check fun (v : I0) => v.2 + +/-! +Recursive, no indices. +-/ +inductive I1 where + | mk (x : Nat) (xs : I1) +/-- info: fun v => v.1 : I1 → Nat -/ +#guard_msgs in #check fun (v : I1) => v.1 +/-- info: fun v => v.2 : I1 → I1 -/ +#guard_msgs in #check fun (v : I1) => v.2 + +/-! +Non-recursive, indices. +-/ +inductive I2 : Nat → Type where + | mk (x : Nat) (xs : List (Fin x)) : I2 (x + 1) +/-- info: fun v => v.1 : I2 2 → Nat -/ +#guard_msgs in #check fun (v : I2 2) => v.1 +/-- info: fun v => v.2 : (v : I2 2) → List (Fin v.1) -/ +#guard_msgs in #check fun (v : I2 2) => v.2 + +/-! +Recursive, indices. +-/ +inductive I3 : Nat → Type where + | mk (x : Nat) (xs : I3 (x + 1)) : I3 x +/-- info: fun v => v.1 : I3 2 → Nat -/ +#guard_msgs in #check fun (v : I3 2) => v.1 +/-- info: fun v => v.2 : (v : I3 2) → I3 (v.1 + 1) -/ +#guard_msgs in #check fun (v : I3 2) => v.2 + + +/-! +Make sure these can be compiled. +-/ +def f0_1 (v : I0) : Nat := v.1 +def f0_2 (v : I0) : List Nat := v.2 +def f1_1 (v : I1) : Nat := v.1 +def f1_2 (v : I1) : I1 := v.2 +def f2_1 (v : I2 n) : Nat := v.1 +def f2_2 (v : I2 n) : List (Fin (f2_1 v)) := v.2 +def f3_1 (v : I3 n) : Nat := v.1 +def f3_2 (v : I3 n) : I3 (f3_1 v + 1) := v.2