diff --git a/Strata/DDM/Format.lean b/Strata/DDM/Format.lean index 7a62d4ebc..c7487012e 100644 --- a/Strata/DDM/Format.lean +++ b/Strata/DDM/Format.lean @@ -111,7 +111,7 @@ def fvarName (ctx : FormatContext) (idx : FreeVarIndex) : String := else s!"fvar!{idx}" -protected def ofDialects (dialects : DialectMap) (globalContext : GlobalContext) (opts : FormatOptions) : FormatContext where +protected def ofDialects (dialects : DialectMap) (globalContext : GlobalContext := {}) (opts : FormatOptions := {}) : FormatContext where opts := opts getFnDecl sym := Id.run do let .function f := dialects.decl! sym @@ -441,13 +441,13 @@ private partial def OperationF.mformatM (op : OperationF α) : FormatM PrecForma end -instance Expr.instToStrataFormat : ToStrataFormat Expr where +instance Expr.instToStrataFormat : ToStrataFormat (ExprF α) where mformat e c s := e.mformatM #[] c s |>.fst -instance Arg.instToStrataFormat : ToStrataFormat Arg where +instance Arg.instToStrataFormat : ToStrataFormat (ArgF α) where mformat a c s := a.mformatM c s |>.fst -instance Operation.instToStrataFormat : ToStrataFormat Operation where +instance Operation.instToStrataFormat : ToStrataFormat (OperationF α) where mformat o c s := o.mformatM c s |>.fst namespace MetadataArg diff --git a/Strata/DDM/Integration/Lean/BoolConv.lean b/Strata/DDM/Integration/Lean/BoolConv.lean index 8e169b397..e6fde0bba 100644 --- a/Strata/DDM/Integration/Lean/BoolConv.lean +++ b/Strata/DDM/Integration/Lean/BoolConv.lean @@ -8,27 +8,29 @@ import Strata.DDM.Integration.Lean.OfAstM namespace Strata + /-- Convert Init.Bool inductive to OperationF -/ -def Bool.toAst {α} [Inhabited α] (v : Ann Bool α) : OperationF α := - if v.val then - ⟨v.ann, q`Init.boolTrue, #[]⟩ +def OperationF.ofBool {α} (ann : α) (b : Bool) : OperationF α := + if b then + { ann := ann, name := q`Init.boolTrue, args := #[] } else - ⟨v.ann, q`Init.boolFalse, #[]⟩ + { ann := ann, name := q`Init.boolFalse, args := #[] } /-- Convert OperationF to Init.Bool -/ -def Bool.ofAst {α} [Inhabited α] [Repr α] (op : OperationF α) : OfAstM (Ann Bool α) := - match op.name with - | q`Init.boolTrue => - if op.args.size = 0 then - pure ⟨op.ann, true⟩ - else - .error s!"boolTrue expects 0 arguments, got {op.args.size}" - | q`Init.boolFalse => - if op.args.size = 0 then - pure ⟨op.ann, false⟩ - else - .error s!"boolFalse expects 0 arguments, got {op.args.size}" - | _ => - .error s!"Unknown Bool operator: {op.name}" +def Bool.ofAst {α} [Inhabited α] [Repr α] (arg : ArgF α) : OfAstM Bool := do + match arg with + | .op op => + match op.name with + | q`Init.boolTrue => + if op.args.size ≠ 0 then + .error s!"boolTrue expects 0 arguments, got {op.args.size}" + pure true + | q`Init.boolFalse => + if op.args.size ≠ 0 then + .error s!"boolFalse expects 0 arguments, got {op.args.size}" + pure false + | _ => + .error s!"Unknown Bool operator: {op.name}" + | _ => .throwExpected "boolean" arg end Strata diff --git a/Strata/DDM/Integration/Lean/Gen.lean b/Strata/DDM/Integration/Lean/Gen.lean index ad90d42a5..4c359b809 100644 --- a/Strata/DDM/Integration/Lean/Gen.lean +++ b/Strata/DDM/Integration/Lean/Gen.lean @@ -11,9 +11,11 @@ import Strata.DDM.Integration.Lean.OfAstM import Strata.DDM.Integration.Lean.BoolConv import Strata.DDM.Util.Graph.Tarjan -open Lean (Command Name Ident Term TSyntax getEnv logError profileitM quote withTraceNode mkIdentFrom) +open Lean (Command Name Ident Term TSyntax addAndCompile getEnv logError) +open Lean (mkApp2 mkApp3 mkAppN mkCIdent mkConst mkIdentFrom) +open Lean (profileitM quote withTraceNode) open Lean.Elab (throwUnsupportedSyntax) -open Lean.Elab.Command (CommandElab CommandElabM elabCommand) +open Lean.Elab.Command (CommandElab CommandElabM elabCommand liftCoreM) open Lean.MonadOptions (getOptions) open Lean.MonadResolveName (getCurrNamespace) open Lean.Parser.Command (ctor) @@ -28,7 +30,7 @@ namespace Lean /-- Prepend the current namespace to the Lean name and convert to an identifier. -/ -def mkScopedIdent (scope : Name) (subName : Lean.Name) : Ident := +private def mkScopedIdent (scope : Name) (subName : Lean.Name) : Ident := let fullName := scope ++ subName let nameStr := toString subName .mk (.ident .none nameStr.toSubstring subName [.decl fullName []]) @@ -36,7 +38,7 @@ def mkScopedIdent (scope : Name) (subName : Lean.Name) : Ident := /-- Prepend the current namespace to the Lean name and convert to an identifier. -/ -def currScopedIdent {m} [Monad m] [Lean.MonadResolveName m] (subName : Lean.Name) : m Ident := do +private def currScopedIdent {m} [Monad m] [Lean.MonadResolveName m] (subName : Lean.Name) : m Ident := do (mkScopedIdent · subName) <$> getCurrNamespace end Lean @@ -54,7 +56,10 @@ abbrev LeanCategoryName := Lean.Name structure GenContext where -- Syntax for #strata_gen for source location purposes. src : Lean.Syntax - categoryNameMap : Std.HashMap QualifiedIdent String + /-- + Maps category identifiers to their relative Lean name. + -/ + categoryNameMap : Std.HashMap QualifiedIdent Name exprHasEta : Bool abbrev GenM := ReaderT GenContext CommandElabM @@ -71,7 +76,7 @@ private def genFreshLeanName (s : String) : GenM Name := do private def genFreshIdentPair (s : String) : GenM (Ident × Ident) := do let name ← genFreshLeanName s let src := (←read).src - return (mkIdentFrom src name true, mkIdentFrom src name) + return (mkIdentFrom (canonical := true) src name, mkIdentFrom src name) /-- Create a canonical identifier. -/ def mkCanIdent (src : Lean.Syntax) (val : Name) : Ident := @@ -179,8 +184,23 @@ structure DefaultCtor where The name in the Strata dialect for this constructor. If `none`, then this must be an auto generated constructor. -/ - strataName : Option QualifiedIdent + strataName : Option QualifiedIdent := none + /-- Whether the generated constructor should add the annotation. -/ + includeAnn : Bool := true + /-- + Argument declarations + -/ argDecls : Array GenArgDecl + /-- + Either annotations are included or there is a single argument we can get + the annotation from. + -/ + includeAnnInvariant : + includeAnn ∨ + if p : argDecls.size = 1 then + ¬(argDecls[0]'(p ▸ Nat.zero_lt_one)).unwrap + else + false := by simp def DefaultCtor.leanName (c : DefaultCtor) : Name := .str .anonymous c.leanNameStr @@ -249,8 +269,8 @@ def mkRootIdent (name : Name) : Ident := .mk (.ident .none name.toString.toSubstring rootName [.decl name []]) /-- -This maps category names in the Init that are already declared to their -representation. +This maps category names in the Init dialect that are already declared +to their fully qualified Lean name. -/ def declaredCategories : Std.HashMap CategoryName Name := .ofList [ (q`Init.Ident, ``String), @@ -383,16 +403,12 @@ partial def mkUsedCategories.aux (m : CatOpMap) (s : WorkSet CategoryName) : Cat match s.pop with | none => s.set | some (s, c) => - match c with - | q`Init.TypeP => - mkUsedCategories.aux m (s.add q`Init.Type) - | _ => - let ops := m.getD c #[] - let addArgs {α:Type} (f : α → CategoryName → α) (a : α) (op : CatOp) := - op.argDecls.foldl (init := a) fun r arg => arg.cat.foldOverAtomicCategories (init := r) f - let addName (pa : WorkSet CategoryName) (c : CategoryName) := pa.add c - let s := ops.foldl (init := s) (addArgs addName) - mkUsedCategories.aux m s + let ops := m.getD c #[] + let addArgs {α:Type} (f : α → CategoryName → α) (a : α) (op : CatOp) := + op.argDecls.foldl (init := a) fun r arg => arg.cat.foldOverAtomicCategories (init := r) f + let addName (pa : WorkSet CategoryName) (c : CategoryName) := pa.add c + let s := ops.foldl (init := s) (addArgs addName) + mkUsedCategories.aux m s def mkUsedCategories (m : CatOpMap) (d : Dialect) : CategorySet := let dname := d.name @@ -415,17 +431,34 @@ def mkUsedCategories (m : CatOpMap) (d : Dialect) : CategorySet := def mkStandardCtors (exprHasEta : Bool) (cat : QualifiedIdent) : Array DefaultCtor := match cat with | q`Init.Expr => + let fvar := { + leanNameStr := "fvar" + argDecls := #[{ name := "idx", cat := .atom .none q`Init.Num, unwrap := true}] + } if exprHasEta then - #[ - .mk "bvar" none #[{ name := "idx", cat := .atom .none q`Init.Num }], - .mk "lambda" none #[ - { name := "var", cat := .atom .none q`Init.Str }, - { name := "type", cat := .atom .none q`Init.Type }, - { name := "fn", cat := .atom .none cat } - ] + #[fvar, + { leanNameStr := "bvar" + argDecls := #[{ name := "idx", cat := .atom .none q`Init.Num, unwrap := true }] + }, + { + leanNameStr := "lambda" + argDecls := #[ + { name := "var", cat := .atom .none q`Init.Str }, + { name := "type", cat := .atom .none q`Init.Type }, + { name := "fn", cat := .atom .none cat } + ] + } ] else - #[] + #[fvar] + | q`Init.TypeP => + #[ + { leanNameStr := "expr", + includeAnn := false, + argDecls := #[{ name := "tp", cat := .atom .none q`Init.Type }] + }, + { leanNameStr := "type", argDecls := #[] } + ] | _ => #[] @@ -457,10 +490,7 @@ def CatOpMap.onlyUsedCategories (m : CatOpMap) (d : Dialect) (exprHasEta : Bool) let usedSet := mkUsedCategories m d m.fold (init := #[]) fun a cat ops => if cat ∉ declaredCategories ∧ cat ∈ usedSet then - let usedNames : Std.HashSet String := - match cat with - | q`Init.Expr => { "fvar" } - | _ => {} + let usedNames : Std.HashSet String := {} let standardCtors := mkStandardCtors exprHasEta cat let usedNames : Std.HashSet String := standardCtors.foldl (init := usedNames) fun m c => @@ -498,15 +528,9 @@ def orderedSyncatGroups (categories : Array (QualifiedIdent × Array DefaultCtor categories.foldl (init := g) fun g (cat, ops) => Id.run do let some resIdx := getIndex cat | panic! s!"Unknown category {cat}" - match cat with - | q`Init.TypeP => - let some typeIdx := getIndex q`Init.Type - | panic! s!"Unknown category Init.Type." - g.addEdge typeIdx resIdx - | _ => - ops.foldl (init := g) fun g op => - op.argDecls.foldl (init := g) fun g arg => - addArgIndices cat op.leanNameStr arg.cat g resIdx + ops.foldl (init := g) fun g op => + op.argDecls.foldl (init := g) fun g arg => + addArgIndices cat op.leanNameStr arg.cat g resIdx let indices := OutGraph.tarjan g indices.map (·.map (categories[·])) @@ -544,11 +568,11 @@ Prepend the current namespace to the Lean name and convert to an identifier. def mkScopedIdent {m} [Monad m] [Lean.MonadResolveName m] (subName : Lean.Name) : m Ident := (scopedIdent · subName) <$> getCurrNamespace -/-- Return identifier for operator with given name to suport category. -/ +/-- Return identifier for operator with given name to Lean name. -/ def getCategoryScopedName (cat : QualifiedIdent) : GenM Name := do match (←read).categoryNameMap[cat]? with | some catName => - return .mkSimple catName + return catName | none => return panic! s!"getCategoryScopedName given {cat}" @@ -558,37 +582,82 @@ def getCategoryIdent (cat : QualifiedIdent) : GenM Ident := do return mkRootIdent nm currScopedIdent (← getCategoryScopedName cat) +/-- +`getCategoryTerm cat annType` returns +-/ def getCategoryTerm (cat : QualifiedIdent) (annType : Ident) : GenM Term := do let catIdent ← mkScopedIdent (← getCategoryScopedName cat) - return Lean.Syntax.mkApp catIdent #[annType] + return mkApp catIdent #[annType] /-- Return identifier for operator with given name to suport category. -/ def getCategoryOpIdent (cat : QualifiedIdent) (name : Name) : GenM Ident := do currScopedIdent <| (← getCategoryScopedName cat) ++ name -partial def ppCatWithUnwrap (annType : Ident) (c : SyntaxCat) (unwrap : Bool) : GenM Term := do - let args ← c.args.mapM (ppCatWithUnwrap annType · false) - match c.name, eq : args.size with - | q`Init.CommaSepBy, 1 => - return mkCApp ``Ann #[mkCApp ``Array #[args[0]], annType] - | q`Init.Option, 1 => - return mkCApp ``Ann #[mkCApp ``Option #[args[0]], annType] - | q`Init.Seq, 1 => - return mkCApp ``Ann #[mkCApp ``Array #[args[0]], annType] - | cat, 0 => - match declaredCategories[cat]? with - | some nm => - -- Check if unwrap is specified - if unwrap && cat ∈ declaredCategories then - pure <| mkRootIdent nm -- Return unwrapped type - else - pure <| mkCApp ``Ann #[mkRootIdent nm, annType] - | none => do - getCategoryTerm cat annType - | f, _ => throwError "Unsupported {f.fullName}" +/-- +Maps builtin polymorphic categories to their Lean representation +-/ +def polymorphicBuiltinCategories : Std.HashMap QualifiedIdent Name := + .ofList [ + (q`Init.CommaSepBy, `Array), + (q`Init.Option, ``Option), + (q`Init.Seq, `Array), + ] + + +def polyCatMap : Std.HashMap QualifiedIdent Lean.Expr := .ofList [ + (q`Init.CommaSepBy, .const ``Array [0]), + (q`Init.Option, .const ``Option [0]), + (q`Init.Seq, .const ``Array [0]), +] + +private def annTypeExpr (base ann : Lean.Expr) := mkApp2 (mkConst ``Ann) base ann -partial def ppCat (annType : Ident) (c : SyntaxCat) : GenM Term := do - ppCatWithUnwrap annType c false +/-- +`getCategoryTerm cat annType` returns +-/ +def getCategoryExpr (cat : QualifiedIdent) (annType : Lean.Expr) : GenM Lean.Expr := do + let relName ← getCategoryScopedName cat + let catName := (← getCurrNamespace) ++ relName + let catType : Lean.Expr := mkConst catName + return .app catType annType + +def mkCatExpr (annType : Lean.Expr) (c : SyntaxCat) (unwrap : Bool) : GenM Lean.Expr := do + let args ← c.args.attach.mapM (fun ⟨sc, _⟩ => mkCatExpr annType sc false) + if let some nm := polymorphicBuiltinCategories[c.name]? then + assert! args.size == 1 + return annTypeExpr (mkAppN (.const nm [0]) args) annType + assert! args.size == 0 + match declaredCategories[c.name]? with + | some nm => + -- Check if unwrap is specified + if unwrap then + return mkConst nm -- Return unwrapped type + else + return annTypeExpr (mkConst nm) annType + | none => do + getCategoryExpr c.name annType +termination_by c +decreasing_by + cases c + decreasing_tactic + +/-- +Convert a category to a Lean term. +-/ +partial def ppCat (annType : Ident) (c : SyntaxCat) (wrap : Bool) : GenM Term := do + let args ← c.args.mapM (ppCat annType (wrap := true)) + let cat := c.name + if let some tp := polymorphicBuiltinCategories[cat]? then + let isTrue _ := inferInstanceAs (Decidable (args.size = 1)) + | throwError s!"internal: {cat} expects a single argument." + return mkCApp ``Ann #[mkCApp tp #[args[0]], annType] + if args.size ≠ 0 then + throwError "internal: Expected no arguments to {cat}." + if let some nm := declaredCategories[cat]? then + -- Check if unwrap is specified + let t := mkRootIdent nm + return if wrap then mkCApp ``Ann #[t, annType] else t + getCategoryTerm cat annType def elabCommands (commands : Array Command) : CommandElabM Unit := do let messageCount := (← get).messages.unreported.size @@ -625,31 +694,21 @@ def explicitBinder (name : String) (typeStx : Term) : CommandElabM BracketedBind def genCtor (annType : Ident) (op : DefaultCtor) : GenM (TSyntax ``ctor) := do let ctorId : Ident := localIdent op.leanNameStr + let ann ← + if op.includeAnn then do + pure #[← `(bracketedBinder| (ann : $annType))] + else + pure #[] let binders ← op.argDecls.mapM fun arg => do - explicitBinder arg.name (← ppCatWithUnwrap annType arg.cat arg.unwrap) - `(ctor| | $ctorId:ident (ann : $annType) $binders:bracketedBinder* ) + explicitBinder arg.name (← ppCat annType arg.cat (wrap := !arg.unwrap)) + `(ctor| | $ctorId:ident $ann:bracketedBinder* $binders:bracketedBinder*) def mkInductive (cat : QualifiedIdent) (ctors : Array DefaultCtor) : GenM Command := do assert! cat ∉ declaredCategories let ident ← mkScopedIdent (← getCategoryScopedName cat) trace[Strata.generator] "Generating {ident}" let annType := localIdent "α" - let builtinCtors : Array (TSyntax ``ctor) ← - match cat with - | q`Init.Expr => do - pure #[ - ← `(ctor| | $(localIdent "fvar"):ident (ann : $annType) (idx : Nat)) - ] - | q`Init.TypeP => do - let typeIdent ← getCategoryTerm q`Init.Type annType - pure #[ - ← `(ctor| | $(localIdent "expr"):ident (tp : $typeIdent)), - ← `(ctor| | $(localIdent "type"):ident (tp : $annType)) - ] - | _ => - pure #[] `(inductive $ident ($annType : Type) : Type where - $builtinCtors:ctor* $(← ctors.mapM (genCtor annType)):ctor* deriving Repr) @@ -660,7 +719,7 @@ def categoryToAstTypeIdent (cat : QualifiedIdent) (annType : Term) : Term := | q`Init.Type => ``Strata.TypeExprF | q`Init.TypeP => ``Strata.ArgF | _ => ``Strata.OperationF - Lean.Syntax.mkApp (mkRootIdent ident) #[annType] + mkApp (mkRootIdent ident) #[annType] structure ToOp where name : String @@ -673,57 +732,42 @@ def ofAstIdentM (cat : QualifiedIdent) : GenM Ident := do currScopedIdent <| (← getCategoryScopedName cat) ++ `ofAst def mkAnnWithTerm (argCtor : Name) (annTerm v : Term) : Term := - mkCApp argCtor #[mkCApp ``Ann.ann #[annTerm], v] + mkApp (mkCIdent argCtor) #[mkCApp ``Ann.ann #[annTerm], v] -def annToAst (argCtor : Name) (annTerm : Term) : Term := - mkCApp argCtor #[mkCApp ``Ann.ann #[annTerm], mkCApp ``Ann.val #[annTerm]] +def annToAst' (argCtor : Name) (term : Term) (unwrap : Bool) : Term := + if unwrap then + mkApp (mkCIdent argCtor) #[mkCApp ``default #[], term] + else + mkAnnWithTerm argCtor term (mkCApp ``Ann.val #[term]) + +partial def annArg (c : SyntaxCat) (unwrap : Bool) : GenM Ident := do + let cat := c.name + if cat ∈ polyCatMap then + assert! c.args.size == 1 + return mkIdentFrom (←read).src ``Ann.ann + assert! c.args.size == 0 + if cat ∈ declaredCategories then + assert! not unwrap + return mkIdentFrom (←read).src ``Ann.ann + getCategoryOpIdent cat `ann mutual -partial def toAstApplyArg (vn : Name) (cat : SyntaxCat) : GenM Term := do - toAstApplyArgWithUnwrap vn cat false - -partial def toAstApplyArgWithUnwrap (vn : Name) (cat : SyntaxCat) (unwrap : Bool) : GenM Term := do +partial def toAstApplyArg (vn : Name) (cat : SyntaxCat) (unwrap : Bool := false) : GenM Term := do let v := mkIdentFrom (←read).src vn match cat.name with | q`Init.Num => - if unwrap then - ``(ArgF.num default $v) - else - return annToAst ``ArgF.num v + return annToAst' ``ArgF.num v unwrap | q`Init.Bool => do - if unwrap then - -- When unwrapped, v is a plain Bool. Create OperationF directly based on the value. - let defaultAnn ← ``(default) - let emptyArray ← ``(#[]) - let trueOp := mkCApp ``OperationF.mk #[defaultAnn, quote q`Init.boolTrue, emptyArray] - let falseOp := mkCApp ``OperationF.mk #[defaultAnn, quote q`Init.boolFalse, emptyArray] - let opExpr ← ``(if $v then $trueOp else $falseOp) - ``(ArgF.op $opExpr) - else - -- When wrapped, v is already Ann Bool α - let boolToAst := mkCApp ``Strata.Bool.toAst #[v] - return mkCApp ``ArgF.op #[boolToAst] + return mkCApp ``ArgF.op #[annToAst' ``OperationF.ofBool v unwrap] | q`Init.Ident => - if unwrap then - ``(ArgF.ident default $v) - else - return annToAst ``ArgF.ident v + return annToAst' ``ArgF.ident v unwrap | q`Init.Str => - if unwrap then - ``(ArgF.strlit default $v) - else - return annToAst ``ArgF.strlit v + return annToAst' ``ArgF.strlit v unwrap | q`Init.Decimal => - if unwrap then - ``(ArgF.decimal default $v) - else - return annToAst ``ArgF.decimal v + return annToAst' ``ArgF.decimal v unwrap | q`Init.ByteArray => - if unwrap then - ``(ArgF.bytes default $v) - else - return annToAst ``ArgF.bytes v + return annToAst' ``ArgF.bytes v unwrap | cid@q`Init.Expr => do let toAst ← toAstIdentM cid return mkCApp ``ArgF.expr #[mkApp toAst #[v]] @@ -775,73 +819,86 @@ end abbrev MatchAlt := TSyntax ``Lean.Parser.Term.matchAlt -def toAstBuiltinMatches (cat : QualifiedIdent) : GenM (Array MatchAlt) := do - let src := (←read).src - match cat with - | q`Init.Expr => - let (annC, annI) ← genFreshIdentPair "ann" - let ctor ← getCategoryOpIdent cat `fvar - let pat : Term := mkApp ctor #[annC, mkCanIdent src `idx] - let rhs := mkCApp ``ExprF.fvar #[annI, mkIdentFrom src `idx] - return #[← `(matchAltExpr| | $pat => $rhs)] - | q`Init.TypeP => do - let (annC, annI) ← genFreshIdentPair "ann" - let typeC ← getCategoryOpIdent cat `type - let typeP : Term := mkApp typeC #[annC] - let typeCat := Lean.Syntax.mkCApp ``SyntaxCatF.atom #[annI, quote q`Init.Type] - let typeRhs := Lean.Syntax.mkCApp ``ArgF.cat #[typeCat] - let typeN ← genFreshLeanName "type" - let exprP := mkApp (← getCategoryOpIdent cat `expr) #[mkCanIdent src typeN] - let exprRhs ← toAstApplyArg typeN (.atom .none q`Init.Type) - return #[ - ← `(matchAltExpr| | $typeP => $typeRhs), - ← `(matchAltExpr| | $exprP => $exprRhs) - ] - | _ => - return #[] +def toAstExprMatch (op : DefaultCtor) (annT : Term) (args : Array GenArgDecl) (names : Vector Name args.size) : GenM Term := do + let lname := op.leanNameStr + if lname == "fvar" then + let .isTrue arg_size_eq := inferInstanceAs (Decidable (args.size = 1)) + | return panic! s!"fvar expected 1 argument" + let src := (←read).src + return mkCApp ``ExprF.fvar #[annT, mkIdentFrom src names[0]] + let some nm := op.strataName + | return panic! s!"Unexpected builtin expression {lname}" + let init := mkCApp ``ExprF.fn #[annT, quote nm] + Fin.foldlM args.size (init := init) fun a i => do + let nm := names[i] + let d := args[i] + let e ← toAstApplyArg nm d.cat d.unwrap + return Lean.Syntax.mkCApp ``ExprF.app #[annT, a, e] def toAstMatch (cat : QualifiedIdent) (op : DefaultCtor) : GenM MatchAlt := do let src := (←read).src let argDecls := op.argDecls - let (annC, annI) ← genFreshIdentPair "ann" let ctor : Ident ← getCategoryOpIdent cat op.leanName - let args ← argDecls.mapM fun arg => do - return (← genFreshLeanName arg.name, arg.cat, arg.unwrap) - let argTerms : Array Term := args.map fun p => mkCanIdent src p.fst - let pat : Term ← ``($ctor $annC $argTerms:term*) + let argc := argDecls.size + let argNames : Vector Name argc ← Vector.ofFnM fun (i : Fin argc) => + genFreshLeanName argDecls[i].name + let ((patArgs, annT) : Array Term × Term) ← + if h : op.includeAnn then + let (annC, annI) ← genFreshIdentPair "ann" + pure (#[(annC : Term)], (annI : Term)) + else + let argc1 : op.argDecls.size = 1 := by + have inv := op.includeAnnInvariant + grind + let d : GenArgDecl := op.argDecls[0] + let annF : Ident ← annArg d.cat d.unwrap + pure (#[], mkApp annF #[mkIdentFrom src argNames[0]]) + let pat := + let argTerms : Array Ident := argNames.map (mkCanIdent src) |>.toArray + mkApp ctor (patArgs ++ argTerms) let rhs : Term ← match cat with | q`Init.Expr => - let lname := op.leanNameStr - let some nm := op.strataName - | return panic! s!"Unexpected builtin expression {lname}" - let init := mkCApp ``ExprF.fn #[annI, quote nm] - args.foldlM (init := init) fun a (nm, tp, unwrap) => do - let e ← toAstApplyArgWithUnwrap nm tp unwrap - return Lean.Syntax.mkCApp ``ExprF.app #[annI, a, e] + toAstExprMatch op annT argDecls argNames | q`Init.Type => do let some nm := op.strataName | return panic! "Expected type name" let toAst ← toAstIdentM cat - let argTerms ← arrayLit <| args.map fun (v, c, _unwrap) => - assert! c.isType - Lean.Syntax.mkApp toAst #[mkIdentFrom src v] - pure <| Lean.Syntax.mkCApp ``TypeExprF.ident #[annI, quote nm, argTerms] + let argTerms ← arrayLit <| Array.ofFn fun (i : Fin argc) => + assert! argDecls[i].cat.isType + mkApp toAst #[mkIdentFrom src argNames[i]] + pure <| mkApp (mkCIdent ``TypeExprF.ident) #[annT, quote nm, argTerms] + | q`Init.TypeP => do + match op.leanNameStr with + | "expr" => + let toAst ← toAstIdentM q`Init.Type + let .isTrue p := inferInstanceAs (Decidable (argc = 1)) + | return panic! "Expected one argument." + assert! argDecls[0].cat.isType + let a := mkApp toAst #[mkIdentFrom src argNames[0]] + pure <| mkCApp ``ArgF.type #[a] + | "type" => + let c := mkCApp ``SyntaxCatF.atom #[annT, quote q`Init.Type] + pure <| mkCApp ``ArgF.cat #[c] + | _ => + return panic! "Unknown typeP op" | _ => let mName ← match op.strataName with | some n => pure n - | none => throwError s!"Internal: Operation requires strata name" - let argTerms : Array Term ← args.mapM fun (nm, tp, unwrap) => toAstApplyArgWithUnwrap nm tp unwrap - pure <| mkCApp ``OperationF.mk #[annI, quote mName, ← arrayLit argTerms] + | none => throwError s!"Internal: Operation {op.leanName} in {cat} requires strata name" + let argTerms : Array Term ← Array.ofFnM fun (i : Fin argc) => + let nm := argNames[i] + let d := argDecls[i] + toAstApplyArg nm d.cat d.unwrap + pure <| mkCApp ``OperationF.mk #[annT, quote mName, ← arrayLit argTerms] `(matchAltExpr| | $pat => $rhs) def genToAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM Command := do let annType := localIdent "α" let catTerm ← getCategoryTerm cat annType let astType : Term := categoryToAstTypeIdent cat annType - let cases ← toAstBuiltinMatches cat - let cases : Array MatchAlt ← ops.mapM_off (init := cases) (toAstMatch cat) + let cases : Array MatchAlt ← ops.mapM_off (toAstMatch cat) let toAst ← toAstIdentM cat trace[Strata.generator] "Generating {toAst}" let src := (←read).src @@ -849,69 +906,35 @@ def genToAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM Command := `(partial def $toAst {$annType : Type} [Inhabited $annType] ($(mkCanIdent src v) : $catTerm) : $astType := match $(mkIdentFrom src v):ident with $cases:matchAlt*) -mutual - -partial def getOfIdentArg (varName : String) (cat : SyntaxCat) (e : Term) : GenM Term := do - getOfIdentArgWithUnwrap varName cat false e +private def addAnn (act : Name) (e : Term) (unwrap : Bool) : Term := + let t := mkApp (mkCIdent act) #[e] + if unwrap then + t + else + mkCApp ``Functor.map #[mkCApp ``Ann.mk #[mkCApp ``ArgF.ann #[e]], t] -partial def getOfIdentArgWithUnwrap (varName : String) (cat : SyntaxCat) (unwrap : Bool) (e : Term) : GenM Term := do +partial def getOfIdentArg (varName : String) (cat : SyntaxCat) (e : Term) (unwrap : Bool := false) : GenM Term := do match cat.name with | q`Init.Num => - if unwrap then - ``((fun arg => match arg with - | ArgF.num _ val => pure val - | a => OfAstM.throwExpected "numeric literal" a) $e) - else - ``(OfAstM.ofNumM $e) + return addAnn ``OfAstM.ofNumM e unwrap | q`Init.Ident => - if unwrap then - ``((fun arg => match arg with - | ArgF.ident _ val => pure val - | a => OfAstM.throwExpected "identifier" a) $e) - else - ``(OfAstM.ofIdentM $e) + return addAnn ``OfAstM.ofIdentM e unwrap | q`Init.Str => - if unwrap then - ``((fun arg => match arg with - | ArgF.strlit _ val => pure val - | a => OfAstM.throwExpected "string literal" a) $e) - else - ``(OfAstM.ofStrlitM $e) + return addAnn ``OfAstM.ofStrlitM e unwrap | q`Init.Decimal => - if unwrap then - ``((fun arg => match arg with - | ArgF.decimal _ val => pure val - | a => OfAstM.throwExpected "decimal literal" a) $e) - else - ``(OfAstM.ofDecimalM $e) + return addAnn ``OfAstM.ofDecimalM e unwrap | q`Init.ByteArray => - if unwrap then - ``((fun arg => match arg with - | ArgF.bytes _ val => pure val - | a => OfAstM.throwExpected "byte array" a) $e) - else - ``(OfAstM.ofBytesM $e) + return addAnn ``OfAstM.ofBytesM e unwrap | q`Init.Bool => do - if unwrap then - -- When unwrapped, extract just the Bool value from Ann Bool α - ``((fun arg => match arg with - | ArgF.op op => Functor.map Ann.val (Strata.Bool.ofAst op) - | a => OfAstM.throwExpected "boolean" a) $e) - else - let (vc, vi) ← genFreshIdentPair varName - let boolOfAst := mkCApp ``Strata.Bool.ofAst #[vi] - ``(OfAstM.ofOperationM $e fun $vc _ => $boolOfAst) + return addAnn ``Strata.Bool.ofAst e unwrap | cid@q`Init.Expr => do - let (vc, vi) ← genFreshIdentPair <| varName ++ "_inner" let ofAst ← ofAstIdentM cid - ``(OfAstM.ofExpressionM $e fun $vc _ => $ofAst $vi) + let (vc, vi) ← genFreshIdentPair <| varName ++ "_inner" + return mkCApp ``OfAstM.ofExpressionM #[e, ←``(fun $vc _ => $ofAst $vi)] | cid@q`Init.Type => do - let (vc, vi) ← genFreshIdentPair varName let ofAst ← ofAstIdentM cid - ``(OfAstM.ofTypeM $e fun $vc _ => $ofAst $vi) - | cid@q`Init.TypeP => do - let ofAst ← ofAstIdentM cid - pure <| mkApp ofAst #[e] + let (vc, vi) ← genFreshIdentPair varName + return mkCApp ``OfAstM.ofTypeM #[e, ←``(fun $vc _ => $ofAst $vi)] | q`Init.CommaSepBy => do let c := cat.args[0]! let (vc, vi) ← genFreshIdentPair varName @@ -927,21 +950,22 @@ partial def getOfIdentArgWithUnwrap (varName : String) (cat : SyntaxCat) (unwrap let (vc, vi) ← genFreshIdentPair varName let body ← getOfIdentArg "e" c vi ``(OfAstM.ofSeqM $e fun $vc _ => $body) + | cid@q`Init.TypeP => do + let ofAst ← ofAstIdentM cid + pure <| mkApp ofAst #[e] | cid => do assert! cat.args.isEmpty let (vc, vi) ← genFreshIdentPair varName let ofAst ← ofAstIdentM cid ``(OfAstM.ofOperationM $e fun $vc _ => $ofAst $vi) -end - def ofAstArgs (argDecls : Array GenArgDecl) (argsVar : Ident) : GenM (Array Ident × Array (TSyntax ``doSeqItem)) := do let argCount := argDecls.size let args ← Array.ofFnM (n := argCount) fun ⟨i, _isLt⟩ => do let arg := argDecls[i] let (vc, vi) ← genFreshIdentPair <| arg.name ++ "_bind" let av ← ``($argsVar[$(quote i)]) - let rhs ← getOfIdentArgWithUnwrap arg.name arg.cat arg.unwrap av + let rhs ← getOfIdentArg arg.name arg.cat av (unwrap := arg.unwrap) let stmt ← `(doSeqItem| let $vc ← $rhs:term) return (vi, stmt) return args.unzip @@ -953,7 +977,7 @@ def ofAstMatch (nameIndexMap : Std.HashMap QualifiedIdent Nat) (op : DefaultCtor | return panic! s!"Unbound operator name {name}" `(matchAltExpr| | Option.some $(quote nameIndex) => $rhs) -def ofAstExprMatchRhs (cat : QualifiedIdent) (annI argsVar : Ident) (op : DefaultCtor) : GenM Term:= do +def ofAstExprMatchRhs (cat : QualifiedIdent) (annI argsVar : Ident) (op : DefaultCtor) : GenM Term := do let ctorIdent ← getCategoryOpIdent cat op.leanName let some nm := op.strataName | return panic! s!"Missing name for {op.leanName}" @@ -1042,7 +1066,11 @@ def genOfAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM (Array Comm let (annC, annI) ← genFreshIdentPair "ann" let (nameIndexMap, ofAstNameMap, cmd) ← createNameIndexMap cat ops let fvarCtorIdent ← getCategoryOpIdent cat `fvar - let cases : Array MatchAlt ← ops.mapM (ofAstExprMatch nameIndexMap cat annI (mkIdentFrom src argsVar)) + let cases : Array MatchAlt ← ops.filterMapM fun op => + if op.leanNameStr == "fvar" then + pure none + else + some <$> ofAstExprMatch nameIndexMap cat annI (mkIdentFrom src argsVar) op let rhs ← `(let vnf := ($(mkIdentFrom src v)).hnf let $(mkCanIdent src argsVar) := vnf.args.val @@ -1113,7 +1141,9 @@ def checkInhabited (cat : QualifiedIdent) (ops : Array DefaultCtor) : StateT Inh continue let ctor : Term ← getCategoryOpIdent cat op.leanName let d := Lean.mkCIdent ``default - let e := Lean.Syntax.mkApp ctor (Array.replicate (op.argDecls.size + 1) d) + let argc := if op.includeAnn then 1 else 0 + let argc := argc + op.argDecls.size + let e := mkApp ctor (Array.replicate argc d) StateT.lift <| runCmd <| elabCommand =<< `(instance [Inhabited $annType] : Inhabited $catTerm where default := $e) modify (·.insert cat) @@ -1130,6 +1160,74 @@ partial def addInhabited (group : Array (QualifiedIdent × Array DefaultCtor)) ( else pure sm +partial def annExpr (c : SyntaxCat) (unwrap : Bool) : GenM Name := do + let cat := c.name + if cat ∈ polyCatMap then + assert! c.args.size == 1 + return ``Ann.ann + if cat ∈ declaredCategories then + assert! c.args.size == 0 + assert! not unwrap + return ``Ann.ann + + assert! c.args.size == 0 + match (←read).categoryNameMap[cat]? with + | some catName => + return (← getCurrNamespace) ++ catName ++ `ann + | none => + return panic! s!"annExpr given {cat}" + + +def annRecursor (c : DefaultCtor) : GenM Lean.Expr := do + let argc := c.argDecls.size + let (inner_off, ann) ← + if h : c.includeAnn then + pure (2, Lean.Expr.bvar argc) + else + have ne : c.argDecls.size > 0 := by + have p := c.includeAnnInvariant + grind + let d := c.argDecls[0] + let annFn ← annExpr d.cat d.unwrap + pure (1, Lean.mkApp2 (.const annFn []) (.bvar (argc+1)) (.bvar (argc - 1))) + let inner : Lean.Expr ← Fin.foldrM argc (init := ann) fun i e => do + let a := c.argDecls[i] + let argType ← mkCatExpr (.bvar (inner_off + i)) a.cat a.unwrap + return .lam (.mkSimple a.name) argType (binderInfo := .default) e + if c.includeAnn then + return .lam `ann (.bvar 1) (binderInfo := .default) inner + else + return inner + +def genAnnFunctions (cat : QualifiedIdent) (ctors : Array DefaultCtor) : GenM Unit := do + let relName ← getCategoryScopedName cat + + let catName := (← getCurrNamespace) ++ relName + let catType : Lean.Expr := mkConst catName + let defName := catName ++ `ann + let type : Lean.Expr := + .forallE `α (.sort 1) (binderInfo := .implicit) <| + .forallE `_ (.app catType (.bvar 0)) (binderInfo := .default) <| + .bvar 1 + let motive : Lean.Expr := .lam `_ (.app catType (.bvar 1)) (binderInfo := .default) (.bvar 2) + let term : Lean.Expr := mkApp3 (.const (catName ++ `casesOn) [1]) (.bvar 1) motive (.bvar 0) + let term : Lean.Expr ← ctors.foldlM (init := term) fun f c => + return .app f (← annRecursor c) + --let term := mkApp2 (mkConst ``sorryAx [1]) (.bvar 1) (mkConst ``true) + let value : Lean.Expr := + .lam `α (.sort 1) (binderInfo := .implicit) <| + .lam `a (.app catType (.bvar 0)) (binderInfo := .default) <| + term + liftCoreM <| addAndCompile <| .defnDecl { + name := defName + levelParams := [] + type := type + value := value + hints := .opaque + safety := .safe + all := [defName] + } + def gen (categories : Array (QualifiedIdent × Array DefaultCtor)) : GenM Unit := do let mut inhabitedCats : InhabitedSet := Std.HashSet.ofArray @@ -1159,6 +1257,9 @@ def gen (categories : Array (QualifiedIdent × Array DefaultCtor)) : GenM Unit : profileitM Lean.Exception s!"Generating inhabited {cats}" (← getOptions) do addInhabited allCtors inhabitedCats let inhabitedCats := inhabitedCats2 + profileitM Lean.Exception s!"Generating ann functions {cats}" (← getOptions) do + allCtors.forM fun (cat, ctors) => do + genAnnFunctions cat ctors profileitM Lean.Exception s!"Generating toAstDefs {cats}" (← getOptions) do let toAstDefs ← allCtors.mapM fun (cat, ctors) => do genToAst cat ctors @@ -1172,18 +1273,18 @@ def gen (categories : Array (QualifiedIdent × Array DefaultCtor)) : GenM Unit : pure inhabitedCats inhabitedCats := s -def runGenM (src : Lean.Syntax) (pref : String) (catNames : Array QualifiedIdent) (exprHasEta : Bool) (m : GenM α) : CommandElabM α := do +def runGenM {α} (src : Lean.Syntax) (pref : String) (catNames : Array QualifiedIdent) (exprHasEta : Bool) (m : GenM α) : CommandElabM α := do let catNameCounts : Std.HashMap String Nat := catNames.foldl (init := {}) fun m k => m.alter k.name (fun v => some (v.getD 0 + 1)) let categoryNameMap := catNames.foldl (init := {}) fun m i => let name := if catNameCounts.getD i.name 0 > 1 then - s!"{i.dialect}_{i.name}" + .mkSimple s!"{i.dialect}_{i.name}" else if i.name ∈ reservedCats then - s!"{pref}{i.name}" + .mkSimple s!"{pref}{i.name}" else - i.name + .mkSimple i.name m.insert i name let ctx : GenContext := { src := src diff --git a/Strata/DDM/Integration/Lean/OfAstM.lean b/Strata/DDM/Integration/Lean/OfAstM.lean index 9ded96f52..02304eb06 100644 --- a/Strata/DDM/Integration/Lean/OfAstM.lean +++ b/Strata/DDM/Integration/Lean/OfAstM.lean @@ -130,24 +130,28 @@ def ofOperationM {α β} [Repr α] [SizeOf α] | .op a1 => act a1 (by decreasing_tactic) | a => .throwExpected "operation" a -def ofIdentM {α} [Repr α] : ArgF α → OfAstM (Ann String α) -| .ident ann val => pure { ann := ann, val := val } +@[inline] +def ofDecimalM {α} [Repr α] : ArgF α → OfAstM Decimal +| .decimal _ val => pure val +| a => .throwExpected "scientific literal" a + +@[inline] +def ofIdentM {α} [Repr α] : ArgF α → OfAstM String +| .ident _ val => pure val | a => .throwExpected "identifier" a -def ofNumM {α} [Repr α] : ArgF α → OfAstM (Ann Nat α) -| .num ann val => pure { ann := ann, val := val } +@[inline] +def ofNumM {α} [Repr α] : ArgF α → OfAstM Nat +| .num _ val => pure val | a => .throwExpected "numeric literal" a -def ofDecimalM {α} [Repr α] : ArgF α → OfAstM (Ann Decimal α) -| .decimal ann val => pure { ann := ann, val := val } -| a => .throwExpected "scientific literal" a - -def ofStrlitM {α} [Repr α] : ArgF α → OfAstM (Ann String α) -| .strlit ann val => pure { ann := ann, val := val } +@[inline] +def ofStrlitM {α} [Repr α] : ArgF α → OfAstM String +| .strlit _ val => pure val | a => .throwExpected "string literal" a -def ofBytesM {α} [Repr α] : ArgF α → OfAstM (Ann ByteArray α) -| .bytes ann val => pure { ann := ann, val := val } +def ofBytesM {α} [Repr α] : ArgF α → OfAstM ByteArray +| .bytes _ val => pure val | a => .throwExpected "byte array" a def ofOptionM {α β} [Repr α] [SizeOf α] diff --git a/StrataTest/DDM/Gen.lean b/StrataTest/DDM/Gen.lean index ed073024f..65ab8f690 100644 --- a/StrataTest/DDM/Gen.lean +++ b/StrataTest/DDM/Gen.lean @@ -4,7 +4,8 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ -import Strata.DDM.Integration.Lean +import Strata.DDM.Integration.Lean.Gen +import Strata.DDM.Integration.Lean.HashCommands namespace Strata @@ -55,6 +56,7 @@ op mkMutACommaSep (a : CommaSepBy MutA) : MutACommaSep => a; namespace TestDialect +set_option trace.Strata.generator true #strata_gen TestDialect /-- @@ -98,6 +100,32 @@ TestDialect.TypeP.type : {α : Type} → α → TypeP α #guard_msgs in #print TypeP +/-- +info: def TestDialect.TypeP.ann : {α : Type} → TypeP α → α := +fun {α} a => TypeP.casesOn a (fun tp => tp.ann) fun ann => ann +-/ +#guard_msgs in +#print TypeP.ann + +/-- +info: inductive TestDialect.Expr : Type → Type +number of parameters: 1 +constructors: +TestDialect.Expr.fvar : {α : Type} → α → Nat → Expr α +TestDialect.Expr.trueExpr : {α : Type} → α → Expr α +TestDialect.Expr.and : {α : Type} → α → Expr α → Expr α → Expr α +TestDialect.Expr.lambda : {α : Type} → α → TestDialectType α → Bindings α → Expr α → Expr α +-/ +#guard_msgs in +#print Expr + +/-- +info: def TestDialect.Expr.ann : {α : Type} → Expr α → α := +fun {α} a => Expr.casesOn a (fun ann idx => ann) (fun ann => ann) (fun ann x y => ann) fun ann tp b res => ann +-/ +#guard_msgs in +#print Expr.ann + /-- info: Strata.ExprF.fvar () 1 -/