From c367f130cffa8f6b3c3c38007b6264116db0e8ba Mon Sep 17 00:00:00 2001 From: Danila Fedorin Date: Thu, 25 Jun 2026 14:29:37 -0500 Subject: [PATCH] Add tagging machinery to assign unique IDs to AST nodes Co-Authored-By: Claude Opus 4.8 --- lean/Spa/Language/Tagged/Basic.lean | 17 + .../Language/Tagged/DESCENDANT-TRACKING.md | 417 ++++++++++++++++++ lean/Spa/Language/Tagged/Derive.lean | 241 ++++++++++ lean/Spa/Language/Tagged/Graphs.lean | 63 +++ lean/Spa/Language/Tagged/Id.lean | 9 + lean/Spa/Language/Tagged/Properties.lean | 29 ++ lean/Spa/Language/Tagged/TODO.md | 46 ++ 7 files changed, 822 insertions(+) create mode 100644 lean/Spa/Language/Tagged/Basic.lean create mode 100644 lean/Spa/Language/Tagged/DESCENDANT-TRACKING.md create mode 100644 lean/Spa/Language/Tagged/Derive.lean create mode 100644 lean/Spa/Language/Tagged/Graphs.lean create mode 100644 lean/Spa/Language/Tagged/Id.lean create mode 100644 lean/Spa/Language/Tagged/Properties.lean create mode 100644 lean/Spa/Language/Tagged/TODO.md diff --git a/lean/Spa/Language/Tagged/Basic.lean b/lean/Spa/Language/Tagged/Basic.lean new file mode 100644 index 0000000..fb3a937 --- /dev/null +++ b/lean/Spa/Language/Tagged/Basic.lean @@ -0,0 +1,17 @@ +import Spa.Language.Base +import Spa.Language.Tagged.Id +import Spa.Language.Tagged.Derive + +derive_tagged Spa.Expr Spa.BasicStmt Spa.Stmt + +namespace Spa + +def tagStmt (s : Stmt) : Stmt.Tagged NodeId := (s.tag 0).1 + +def Stmt.Tagged.subtreeIds (s : Stmt.Tagged NodeId) : List NodeId := + s.foldTags (· :: ·) [] + +def Stmt.Tagged.isInLoopBody (body : Stmt.Tagged NodeId) (id : NodeId) : Bool := + decide (id ∈ body.subtreeIds) + +end Spa diff --git a/lean/Spa/Language/Tagged/DESCENDANT-TRACKING.md b/lean/Spa/Language/Tagged/DESCENDANT-TRACKING.md new file mode 100644 index 0000000..06cda3b --- /dev/null +++ b/lean/Spa/Language/Tagged/DESCENDANT-TRACKING.md @@ -0,0 +1,417 @@ +# Descendant tracking (parked) + +This is the formally-verified **interval-labeling / 
descendant** machinery that +used to live in `Id.lean` and `Properties.lean`. It let you decide "is node `a` +a descendant of node `b`?" with two integer comparisons on their identifiers, +and *proved* that numeric test equivalent to structural subtree containment. + +It was removed because the descendant test is a *computational optimization*: +the same question can be answered by walking the AST, and nothing in the current +pipeline needs the fast test yet. The proofs (a rose-tree flattening + a +postorder `Good` invariant) are a real mechanization cost to carry. Parked here +so it can be restored verbatim when LICM actually wants it. + +## What stays in the live code + +- `NodeId` collapses to a single unique index (`{ post : ℕ }`); `tag` still + assigns each node a distinct postorder number. +- The bidirectional mapping (`erase`/`tag` + `erase_tagStmt`) stays in + `Properties.lean`. +- The labelled-CFG id↔state mapping (`Graphs.lean`) is independent of this and is + unaffected. + +## Revival checklist + +1. In `Id.lean`, give `NodeId` back its descendant-count field and the test: + + ```lean + structure NodeId where + post : ℕ + desc : ℕ -- number of proper descendants (subtree size − 1); leaf = 0 + deriving DecidableEq, Repr + + namespace NodeId + + /-- Left endpoint of the node's postorder interval `[lo, post]`. -/ + def lo (a : NodeId) : ℕ := a.post - a.desc + + /-- `a` is a descendant-or-self of `b`: `a.post` lies in `b`'s interval. -/ + def DescendantOf (a b : NodeId) : Prop := b.lo ≤ a.post ∧ a.post ≤ b.post + + instance (a b : NodeId) : Decidable (DescendantOf a b) := by + unfold DescendantOf; infer_instance + + end NodeId + ``` + +2. In `Derive.lean`, make the generated `tag` store the descendant count again: + change the emitted identifier in `mkTag` from `(⟨$last⟩ : $nId)` back to + `(⟨$last, $last - n⟩ : $nId)`. + +3. Paste the Lean block below back into `Properties.lean` (after the round-trip + theorems). 
It builds against the `id.lo = lo`-premise form of `Good` and the + childcount (`desc`) identifier. The headline result is + `descendant_iff_tagStmt`; everything else is supporting machinery. + +## The parked proofs + +```lean +/-- A rose tree of identifiers: the uniform shape underlying all three tagged +AST types, used to reason about the postorder labeling generically. -/ +inductive IdTree where + | node (id : NodeId) (children : List IdTree) + +namespace IdTree + +def rootId : IdTree → NodeId + | .node id _ => id + +@[simp] theorem rootId_node (id : NodeId) (cs : List IdTree) : + (IdTree.node id cs).rootId = id := rfl + +mutual +def subtrees : IdTree → List IdTree + | .node id cs => .node id cs :: subtreesList cs +def subtreesList : List IdTree → List IdTree + | [] => [] + | c :: cs => subtrees c ++ subtreesList cs +end + +@[simp] theorem subtrees_node (id : NodeId) (cs : List IdTree) : + subtrees (.node id cs) = .node id cs :: subtreesList cs := rfl + +@[simp] theorem subtreesList_nil : subtreesList [] = [] := rfl + +@[simp] theorem subtreesList_cons (c : IdTree) (cs : List IdTree) : + subtreesList (c :: cs) = subtrees c ++ subtreesList cs := rfl + +def posts (t : IdTree) : List ℕ := (subtrees t).map (fun s => s.rootId.post) + +def postsList (cs : List IdTree) : List ℕ := (subtreesList cs).map (fun s => s.rootId.post) + +@[simp] theorem posts_node (id : NodeId) (cs : List IdTree) : + posts (.node id cs) = id.post :: postsList cs := rfl + +@[simp] theorem postsList_nil : postsList [] = [] := rfl + +@[simp] theorem postsList_cons (c : IdTree) (cs : List IdTree) : + postsList (c :: cs) = posts c ++ postsList cs := by + simp [posts, postsList] + +end IdTree + +def Expr.Tagged.toIdTree : Expr.Tagged NodeId → IdTree + | .add t a b => .node t [a.toIdTree, b.toIdTree] + | .sub t a b => .node t [a.toIdTree, b.toIdTree] + | .var t _ => .node t [] + | .num t _ => .node t [] + +def BasicStmt.Tagged.toIdTree : BasicStmt.Tagged NodeId → IdTree + | .assign t _ e => .node t 
[e.toIdTree] + | .noop t => .node t [] + +def Stmt.Tagged.toIdTree : Stmt.Tagged NodeId → IdTree + | .basic t bs => .node t [bs.toIdTree] + | .andThen t a b => .node t [a.toIdTree, b.toIdTree] + | .ifElse t e a b => .node t [e.toIdTree, a.toIdTree, b.toIdTree] + | .whileLoop t e s => .node t [e.toIdTree, s.toIdTree] + +mutual +inductive Good : ℕ → IdTree → Prop + | mk {lo : ℕ} {id : NodeId} {cs : List IdTree} : + id.lo = lo → GoodChildren lo cs id.post → + Good lo (.node id cs) +inductive GoodChildren : ℕ → List IdTree → ℕ → Prop + | nil {pos : ℕ} : GoodChildren pos [] pos + | cons {cur : ℕ} {c : IdTree} {cs : List IdTree} {pos : ℕ} : + Good cur c → GoodChildren (c.rootId.post + 1) cs pos → + GoodChildren cur (c :: cs) pos +end + +theorem Good.lo_le_post {lo : ℕ} {t : IdTree} (h : Good lo t) : lo ≤ t.rootId.post := by + cases h with + | mk hlo _ => simp only [NodeId.lo] at hlo; simp only [IdTree.rootId_node]; omega + +theorem GoodChildren.cur_le_pos : ∀ {cur : ℕ} (cs : List IdTree) {pos : ℕ}, + GoodChildren cur cs pos → cur ≤ pos + | _, [], _, h => by cases h; exact le_rfl + | _, c :: cs, _, h => by + cases h with + | cons hc hcs => + have := hc.lo_le_post + have := GoodChildren.cur_le_pos cs hcs + omega + +mutual +theorem Good.mem_posts : ∀ {lo : ℕ} (t : IdTree), Good lo t → + ∀ x, x ∈ IdTree.posts t ↔ lo ≤ x ∧ x ≤ t.rootId.post + | _, .node id cs, h, x => by + cases h with + | mk hlo hch => + simp only [IdTree.posts_node, List.mem_cons, IdTree.rootId_node] + rw [GoodChildren.mem_postsList cs hch x] + simp only [NodeId.lo] at hlo + omega +theorem GoodChildren.mem_postsList : ∀ {cur : ℕ} (cs : List IdTree) {pos : ℕ}, + GoodChildren cur cs pos → ∀ x, x ∈ IdTree.postsList cs ↔ cur ≤ x ∧ x < pos + | _, [], _, h, x => by + cases h + simp only [IdTree.postsList_nil] + constructor + · intro hx; exact absurd hx (List.not_mem_nil x) + · rintro ⟨h1, h2⟩; exfalso; omega + | _, c :: cs, _, h, x => by + cases h with + | cons hc hcs => + simp only [IdTree.postsList_cons, 
List.mem_append] + rw [Good.mem_posts c hc x, GoodChildren.mem_postsList cs hcs x] + have := hc.lo_le_post + have := GoodChildren.cur_le_pos cs hcs + omega +end + +mutual +theorem Good.nodup_posts : ∀ {lo : ℕ} (t : IdTree), Good lo t → (IdTree.posts t).Nodup + | _, .node id cs, h => by + cases h with + | mk hlo hch => + simp only [IdTree.posts_node, List.nodup_cons] + refine ⟨?_, GoodChildren.nodup_postsList cs hch⟩ + intro hmem + rw [GoodChildren.mem_postsList cs hch id.post] at hmem + omega +theorem GoodChildren.nodup_postsList : ∀ {cur : ℕ} (cs : List IdTree) {pos : ℕ}, + GoodChildren cur cs pos → (IdTree.postsList cs).Nodup + | _, [], _, h => by cases h; simp only [IdTree.postsList_nil, List.nodup_nil] + | _, c :: cs, _, h => by + cases h with + | cons hc hcs => + simp only [IdTree.postsList_cons, List.nodup_append] + refine ⟨Good.nodup_posts c hc, GoodChildren.nodup_postsList cs hcs, ?_⟩ + intro x hx1 hx2 + rw [Good.mem_posts c hc x] at hx1 + rw [GoodChildren.mem_postsList cs hcs x] at hx2 + omega +end + +mutual +theorem Good.subtree_good : ∀ {lo : ℕ} (t : IdTree), Good lo t → + ∀ s ∈ IdTree.subtrees t, Good s.rootId.lo s + | _, .node id cs, h, s, hs => by + cases h with + | mk hlo hch => + rw [IdTree.subtrees_node, List.mem_cons] at hs + rcases hs with rfl | hs + · simp only [IdTree.rootId_node]; rw [hlo]; exact Good.mk hlo hch + · exact GoodChildren.subtree_good cs hch s hs +theorem GoodChildren.subtree_good : ∀ {cur : ℕ} (cs : List IdTree) {pos : ℕ}, + GoodChildren cur cs pos → ∀ s ∈ IdTree.subtreesList cs, Good s.rootId.lo s + | _, [], _, _, s, hs => by simp only [IdTree.subtreesList_nil, List.not_mem_nil] at hs + | _, c :: cs, _, h, s, hs => by + cases h with + | cons hc hcs => + rw [IdTree.subtreesList_cons, List.mem_append] at hs + rcases hs with hs | hs + · exact Good.subtree_good c hc s hs + · exact GoodChildren.subtree_good cs hcs s hs +end + +mutual +theorem IdTree.subtrees_subset : ∀ (t : IdTree) {b : IdTree}, + b ∈ subtrees t → subtrees b ⊆ 
subtrees t + | .node id cs, b, hb => by + rw [subtrees_node, List.mem_cons] at hb + rcases hb with rfl | hb + · exact fun _ h => h + · intro x hx + rw [subtrees_node, List.mem_cons] + exact Or.inr (IdTree.subtreesList_subset cs hb hx) +theorem IdTree.subtreesList_subset : ∀ (cs : List IdTree) {b : IdTree}, + b ∈ subtreesList cs → subtrees b ⊆ subtreesList cs + | [], b, hb => by simp only [subtreesList_nil, List.not_mem_nil] at hb + | c :: cs, b, hb => by + rw [subtreesList_cons, List.mem_append] at hb + intro x hx + rw [subtreesList_cons, List.mem_append] + rcases hb with hb | hb + · exact Or.inl (IdTree.subtrees_subset c hb hx) + · exact Or.inr (IdTree.subtreesList_subset cs hb hx) +end + +theorem IdTree.eq_of_post_eq {l : List IdTree} + (h : (l.map (fun s => s.rootId.post)).Nodup) {a c : IdTree} + (ha : a ∈ l) (hc : c ∈ l) (hpost : a.rootId.post = c.rootId.post) : a = c := by + induction l with + | nil => exact absurd ha (List.not_mem_nil a) + | cons d ds ih => + simp only [List.map_cons, List.nodup_cons] at h + obtain ⟨hd, htl⟩ := h + simp only [List.mem_cons] at ha hc + rcases ha with rfl | ha <;> rcases hc with rfl | hc + · rfl + · exfalso; apply hd; rw [hpost]; exact List.mem_map_of_mem _ hc + · exfalso; apply hd; rw [← hpost]; exact List.mem_map_of_mem _ ha + · exact ih htl ha hc + +theorem descendant_iff_of_good {lo : ℕ} {t : IdTree} (hg : Good lo t) + {a b : IdTree} (ha : a ∈ IdTree.subtrees t) (hb : b ∈ IdTree.subtrees t) : + a.rootId.DescendantOf b.rootId ↔ a ∈ IdTree.subtrees b := by + have hgb : Good b.rootId.lo b := Good.subtree_good t hg b hb + constructor + · rintro ⟨h1, h2⟩ + have hmem : a.rootId.post ∈ IdTree.posts b := by + rw [Good.mem_posts b hgb a.rootId.post]; exact ⟨h1, h2⟩ + rw [IdTree.posts, List.mem_map] at hmem + obtain ⟨c, hc_mem, hc_post⟩ := hmem + have hc_t : c ∈ IdTree.subtrees t := IdTree.subtrees_subset t hb hc_mem + have hac : a = c := + IdTree.eq_of_post_eq (hg.nodup_posts t) ha hc_t hc_post.symm + rw [hac]; exact hc_mem + · 
intro hsub + have hmem : a.rootId.post ∈ IdTree.posts b := by + rw [IdTree.posts, List.mem_map]; exact ⟨a, hsub, rfl⟩ + rw [Good.mem_posts b hgb a.rootId.post] at hmem + exact hmem + +/-! ### Tagging produces a good tree + +We bridge from the `tag` traversal to the abstract `Good` invariant, by induction +on the plain AST. Each lemma also records that the returned counter is one past +the root's postorder index. -/ + +theorem Expr.tag_spec : ∀ (e : Expr) (n : ℕ), + Good n (e.tag n).1.toIdTree ∧ (e.tag n).1.toIdTree.rootId.post + 1 = (e.tag n).2 := by + intro e + induction e with + | num k => + intro n + refine ⟨?_, ?_⟩ + · simp only [Expr.tag, Expr.Tagged.toIdTree] + exact Good.mk (by simp only [NodeId.lo]; omega) GoodChildren.nil + · simp only [Expr.tag, Expr.Tagged.toIdTree, IdTree.rootId_node] + | var x => + intro n + refine ⟨?_, ?_⟩ + · simp only [Expr.tag, Expr.Tagged.toIdTree] + exact Good.mk (by simp only [NodeId.lo]; omega) GoodChildren.nil + · simp only [Expr.tag, Expr.Tagged.toIdTree, IdTree.rootId_node] + | add a b iha ihb => + intro n + obtain ⟨gA, pA⟩ := iha n + obtain ⟨gB, pB⟩ := ihb (a.tag n).2 + have lA := gA.lo_le_post + have lB := gB.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [Expr.tag, Expr.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gA ?_ + rw [pA]; refine GoodChildren.cons gB ?_; rw [pB]; exact GoodChildren.nil + · simp only [Expr.tag, Expr.Tagged.toIdTree, IdTree.rootId_node] + | sub a b iha ihb => + intro n + obtain ⟨gA, pA⟩ := iha n + obtain ⟨gB, pB⟩ := ihb (a.tag n).2 + have lA := gA.lo_le_post + have lB := gB.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [Expr.tag, Expr.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gA ?_ + rw [pA]; refine GoodChildren.cons gB ?_; rw [pB]; exact GoodChildren.nil + · simp only [Expr.tag, Expr.Tagged.toIdTree, IdTree.rootId_node] + +theorem BasicStmt.tag_spec : ∀ (bs : BasicStmt) (n : ℕ), + Good n 
(bs.tag n).1.toIdTree ∧ (bs.tag n).1.toIdTree.rootId.post + 1 = (bs.tag n).2 := by + intro bs + cases bs with + | noop => + intro n + refine ⟨?_, ?_⟩ + · simp only [BasicStmt.tag, BasicStmt.Tagged.toIdTree] + exact Good.mk (by simp only [NodeId.lo]; omega) GoodChildren.nil + · simp only [BasicStmt.tag, BasicStmt.Tagged.toIdTree, IdTree.rootId_node] + | assign x e => + intro n + obtain ⟨gE, pE⟩ := Expr.tag_spec e n + have lE := gE.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [BasicStmt.tag, BasicStmt.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gE ?_ + rw [pE]; exact GoodChildren.nil + · simp only [BasicStmt.tag, BasicStmt.Tagged.toIdTree, IdTree.rootId_node] + +theorem Stmt.tag_spec : ∀ (s : Stmt) (n : ℕ), + Good n (s.tag n).1.toIdTree ∧ (s.tag n).1.toIdTree.rootId.post + 1 = (s.tag n).2 := by + intro s + induction s with + | basic bs => + intro n + obtain ⟨gBs, pBs⟩ := BasicStmt.tag_spec bs n + have lBs := gBs.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [Stmt.tag, Stmt.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gBs ?_ + rw [pBs]; exact GoodChildren.nil + · simp only [Stmt.tag, Stmt.Tagged.toIdTree, IdTree.rootId_node] + | andThen a b iha ihb => + intro n + obtain ⟨gA, pA⟩ := iha n + obtain ⟨gB, pB⟩ := ihb (a.tag n).2 + have lA := gA.lo_le_post + have lB := gB.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [Stmt.tag, Stmt.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gA ?_ + rw [pA]; refine GoodChildren.cons gB ?_; rw [pB]; exact GoodChildren.nil + · simp only [Stmt.tag, Stmt.Tagged.toIdTree, IdTree.rootId_node] + | ifElse e a b iha ihb => + intro n + obtain ⟨gE, pE⟩ := Expr.tag_spec e n + obtain ⟨gA, pA⟩ := iha (e.tag n).2 + obtain ⟨gB, pB⟩ := ihb (a.tag (e.tag n).2).2 + have lE := gE.lo_le_post + have lA := gA.lo_le_post + have lB := gB.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [Stmt.tag, 
Stmt.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gE ?_ + rw [pE]; refine GoodChildren.cons gA ?_ + rw [pA]; refine GoodChildren.cons gB ?_; rw [pB]; exact GoodChildren.nil + · simp only [Stmt.tag, Stmt.Tagged.toIdTree, IdTree.rootId_node] + | whileLoop e s ih => + intro n + obtain ⟨gE, pE⟩ := Expr.tag_spec e n + obtain ⟨gS, pS⟩ := ih (e.tag n).2 + have lE := gE.lo_le_post + have lS := gS.lo_le_post + refine ⟨?_, ?_⟩ + · simp only [Stmt.tag, Stmt.Tagged.toIdTree] + refine Good.mk ?_ ?_ + · simp only [NodeId.lo]; omega + · refine GoodChildren.cons gE ?_ + rw [pE]; refine GoodChildren.cons gS ?_; rw [pS]; exact GoodChildren.nil + · simp only [Stmt.tag, Stmt.Tagged.toIdTree, IdTree.rootId_node] + +/-- A freshly tagged program is a well-tagged tree (rooted at postorder start `0`). -/ +theorem good_tagStmt (s : Stmt) : Good 0 (tagStmt s).toIdTree := + (Stmt.tag_spec s 0).1 + +/-- **Descendant characterization.** The numeric `NodeId.DescendantOf` relation on +two nodes of a tagged program holds exactly when one is structurally contained in +the other's subtree. -/ +theorem descendant_iff_tagStmt (s : Stmt) {a b : IdTree} + (ha : a ∈ IdTree.subtrees (tagStmt s).toIdTree) + (hb : b ∈ IdTree.subtrees (tagStmt s).toIdTree) : + a.rootId.DescendantOf b.rootId ↔ a ∈ IdTree.subtrees b := + descendant_iff_of_good (good_tagStmt s) ha hb +``` diff --git a/lean/Spa/Language/Tagged/Derive.lean b/lean/Spa/Language/Tagged/Derive.lean new file mode 100644 index 0000000..e946e9d --- /dev/null +++ b/lean/Spa/Language/Tagged/Derive.lean @@ -0,0 +1,241 @@ +import Lean +import Mathlib.Tactic.DeriveTraversable +import Spa.Language.Base +import Spa.Language.Tagged.Id + +/-! 
+# The `derive_tagged` command + +`derive_tagged T₁ T₂ … Tₙ` takes a family of inductive types listed in dependency order +(a member may refer to itself and to earlier members; genuinely mutual recursion is not supported) and generates, for each `Tᵢ`: + +* a *tagged* mirror inductive `Tᵢ.Tagged (τ : Type)`, in which every constructor + carries a leading `tag : τ` field and every field whose type is a family + member is retyped to its `.Tagged τ` counterpart; +* `Tᵢ.Tagged.erase : Tᵢ.Tagged τ → Tᵢ`, forgetting all tags; +* `Tᵢ.tag : Tᵢ → ℕ → Tᵢ.Tagged NodeId × ℕ`, assigning every node a unique + `NodeId` (its postorder index) by a single unified traversal that threads a + counter; the whole family shares one counter, so identifiers are unique across + types. + +Also generated: `Tᵢ.Tagged.rootTag`, `Tᵢ.Tagged.foldTags`, and `Functor`/`Traversable` instances for `Tᵢ.Tagged`. The generated declarations have exactly the shape of the hand-written reference; +see `Spa/Language/Tagged/Basic.lean` (which invokes this command) and the proofs +in `Spa/Language/Tagged/Properties.lean`. + +Scope: the generator handles non-indexed inductives whose constructor fields are +either scalars or *direct* references to a family member (which covers the object +language). Nested occurrences such as `List Tᵢ` are not supported. +-/ + +open Lean Elab Command Meta + +namespace Spa.DeriveTagged + +/-- One constructor field, classified as a recursive family reference or a scalar +(whose type syntax we keep verbatim for the mirror inductive). -/ +structure FieldData where + isRec : Bool + recType : Name + typeStx : Term + +/-- A constructor: its original (full) name, short name, and fields. -/ +structure CtorData where + origName : Name + shortName : Name + fields : Array FieldData + +/-- A family member together with its constructors. 
-/ +structure TypeData where + name : Name + ctors : Array CtorData + +def taggedOf (n : Name) : Name := n ++ `Tagged +def eraseOf (n : Name) : Name := n ++ `Tagged ++ `erase +def rootTagOf (n : Name) : Name := n ++ `Tagged ++ `rootTag +def tagOf (n : Name) : Name := n ++ `tag +def foldTagsOf (n : Name) : Name := n ++ `Tagged ++ `foldTags + +/-- Inspect the family, classifying each constructor field. A field is recursive exactly when the head constant of its type is a family member, so a nested occurrence such as `List Tᵢ` is classified as a scalar and left untagged. -/ +def gather (family : Array Name) (τ : Ident) : TermElabM (Array TypeData) := do + let famSet : NameSet := family.foldl (·.insert ·) {} + family.mapM fun tn => do + let iv ← getConstInfoInduct tn + let ctors ← iv.ctors.toArray.mapM fun cn => do + let cv ← getConstInfoCtor cn + let fields ← forallTelescopeReducing cv.type fun args _ => do + let fieldArgs := args.extract iv.numParams args.size + fieldArgs.mapM fun a => do + let ty ← inferType a + match ty.getAppFn.constName? with + | some hn => + if famSet.contains hn then + return { isRec := true, recType := hn, typeStx := ← `($(mkIdent (taggedOf hn)) $τ) } + else + return { isRec := false, recType := default, typeStx := ← Lean.PrettyPrinter.delab ty } + | none => + return { isRec := false, recType := default, typeStx := ← Lean.PrettyPrinter.delab ty } + return { origName := cn, shortName := cn.componentsRev.head!, fields } + return { name := tn, ctors } + +/-- The arrow type `τ → f₁ → … → fₙ → Self τ` of a tagged constructor, where `f₁ … fₙ` are the field types recorded in `typeStx` (already retyped for recursive fields). -/ +def ctorArrow (cd : CtorData) (self : Term) (τ : Ident) : TermElabM Term := do + let mut t := self + for f in cd.fields.reverse do + t ← `($(f.typeStx) → $t) + `($τ → $t) + +/-- The tagged mirror inductives, one per family member. The family is a DAG +(`Expr ← BasicStmt ← Stmt`), not genuinely mutual, so they are emitted as +separate inductives in dependency order rather than a `mutual` block. + +`Functor`/`Traversable` instances are derived separately by `mkDeriveInstances` +below rather than via an inline `deriving` clause. 
-/ +def mkInductives (tds : Array TypeData) (τ : Ident) : + CommandElabM (Array (TSyntax `command)) := do + tds.mapM fun td => do + let self ← `($(mkIdent (taggedOf td.name)) $τ) + let ctors ← td.ctors.mapM fun cd => do + let aty ← Command.liftTermElabM (ctorArrow cd self τ) + `(Lean.Parser.Command.ctor| | $(mkIdent cd.shortName):ident : $aty) + `(command| inductive $(mkIdent (taggedOf td.name)):ident ($τ : Type) where $ctors*) + +/-- A `deriving instance Functor, Traversable for Tᵢ.Tagged` command per family +member. Since every tagged type is a single-parameter, direct-recursive +inductive in `τ`, Mathlib's deriving handler produces clean (`sorry`-free) +instances, giving `map`, `traverse`, and the `Traversable.foldr`/`toList` folds +for free. + +These are emitted as *separate* commands in dependency order (rather than an +inline `deriving` clause on each inductive) for two reasons: deriving +`Stmt.Tagged` needs the `Expr.Tagged`/`BasicStmt.Tagged` instances already in +scope, and — because every member's type name ends in `.Tagged` — the handler's +auto-generated instance name (`instFunctorTagged`, built from the type's last +component) collides across the family unless each derive sees the environment +the previous one updated; separate commands give it that, so the names +disambiguate to `instFunctorTagged`, `instFunctorTagged_1`, …. + +The hand-written `foldTags` is retained alongside these: it is a +structural-recursion fold that `simp`/`decide` reduce cleanly, unlike the +abstract `Traversable.foldr` (defined via the `FreeMonoid`/`Const` applicative), +which reduces under `decide`/`rfl` but not naive `simp` unfolding. -/ +def mkDeriveInstances (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + tds.mapM fun td => + `(command| deriving instance Functor, Traversable for $(mkIdent (taggedOf td.name))) + +/-- The `erase` functions, one per family member (separate defs in dependency +order — each calls only already-defined lower members). 
-/ +def mkErase (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + tds.mapM fun td => do + let mut pats : Array Term := #[] + let mut rhss : Array Term := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map (fun i => mkIdent (.mkSimple s!"a{i}")) + let pat ← `($(mkIdent (taggedOf td.name ++ cd.shortName)) _ $argNames*) + let eraseArgs ← (cd.fields.zip argNames).mapM fun (f, a) => + if f.isRec then `($(mkIdent (eraseOf f.recType)) $a) else pure a + let rhs ← `($(mkIdent cd.origName) $eraseArgs*) + pats := pats.push pat + rhss := rhss.push rhs + `(command| def $(mkIdent (eraseOf td.name)) {τ : Type} : + $(mkIdent (taggedOf td.name)) τ → $(mkIdent td.name) := + fun x => match x with $[| $pats => $rhss]*) + +/-- The `rootTag` accessors (one non-recursive `def` per type). -/ +def mkRootTag (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let tIdent := mkIdent `t + tds.mapM fun td => do + let mut pats : Array Term := #[] + let mut rhss : Array Term := #[] + for cd in td.ctors do + let hole ← `(_) + let wilds := Array.mkArray cd.fields.size hole + pats := pats.push (← `($(mkIdent (taggedOf td.name ++ cd.shortName)) $tIdent $wilds*)) + rhss := rhss.push tIdent + `(command| def $(mkIdent (rootTagOf td.name)) {τ : Type} : + $(mkIdent (taggedOf td.name)) τ → τ := + fun x => match x with $[| $pats => $rhss]*) + +/-- The postorder `tag` functions, one per family member (separate defs in +dependency order). 
-/ +def mkTag (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let nId := mkIdent ``Spa.NodeId + tds.mapM fun td => do + let mut pats : Array Term := #[] + let mut rhss : Array Term := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map (fun i => mkIdent (.mkSimple s!"a{i}")) + let pat ← `($(mkIdent cd.origName) $argNames*) + let mut cur : Term ← `(n) + let mut lets : Array (Ident × Term) := #[] + let mut taggedArgs : Array Term := #[] + let mut ri := 0 + for (f, a) in cd.fields.zip argNames do + if f.isRec then + let rName := mkIdent (.mkSimple s!"r{ri}") + let rhsCall ← `($(mkIdent (tagOf f.recType)) $a $cur) + lets := lets.push (rName, rhsCall) + taggedArgs := taggedArgs.push (← `($rName |>.1)) + cur ← `($rName |>.2) + ri := ri + 1 + else + taggedArgs := taggedArgs.push a + let last := cur + let tagged ← `($(mkIdent (taggedOf td.name ++ cd.shortName)) + (⟨$last⟩ : $nId) $taggedArgs*) + let mut body ← `(($tagged, $last + 1)) + for (rName, rhs) in lets.reverse do + body ← `(let $rName := $rhs; $body) + pats := pats.push pat + rhss := rhss.push body + `(command| def $(mkIdent (tagOf td.name)) : + $(mkIdent td.name) → Nat → $(mkIdent (taggedOf td.name)) $nId × Nat := + fun e n => match e with $[| $pats => $rhss]*) + +/-- The tag-fold functions: `foldTags f acc t` applies `f` to every tag in `t`, +right-to-left, threading `acc`. This is the `Foldable`/`foldr`-over-tags the +hand-written collectors (e.g. `subtreeIds`) reduce to. One separate def per +family member (the family is a DAG, so no `mutual` block is needed). 
-/ +def mkFoldTags (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let τ := mkIdent `τ + let m := mkIdent `M + let fId := mkIdent `f + let accId := mkIdent `acc + let tagId := mkIdent `t + tds.mapM fun td => do + let mut pats : Array Term := #[] + let mut rhss : Array Term := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map (fun i => mkIdent (.mkSimple s!"a{i}")) + let pat ← `($(mkIdent (taggedOf td.name ++ cd.shortName)) $tagId $argNames*) + let mut body : Term := accId + for (fld, a) in (cd.fields.zip argNames).reverse do + if fld.isRec then + body ← `($(mkIdent (foldTagsOf fld.recType)) $fId $body $a) + body ← `($fId $tagId $body) + pats := pats.push pat + rhss := rhss.push body + `(command| def $(mkIdent (foldTagsOf td.name)) {$τ:ident : Type} {$m:ident : Type} + ($fId : $τ → $m → $m) ($accId : $m) : + $(mkIdent (taggedOf td.name)) $τ → $m := + fun x => match x with $[| $pats => $rhss]*) + +/-- `derive_tagged T₁ … Tₙ` — generate tagged mirrors (with `Functor`/`Traversable` instances), `rootTag`, `erase`, `tag`, and `foldTags` for the +given family of inductives, which must be listed in dependency order. 
-/ +syntax (name := deriveTaggedCmd) "derive_tagged " ident+ : command + +@[command_elab deriveTaggedCmd] +def elabDeriveTagged : CommandElab := fun stx => do + match stx with + | `(derive_tagged $ids*) => + let family ← ids.mapM fun i => Command.liftCoreM (realizeGlobalConstNoOverload i) + let τ := mkIdent `τ + let tds ← Command.liftTermElabM (gather family τ) + for d in (← mkInductives tds τ) do elabCommand d + for d in (← mkDeriveInstances tds) do elabCommand d + for d in (← mkRootTag tds) do elabCommand d + for d in (← mkErase tds) do elabCommand d + for d in (← mkTag tds) do elabCommand d + for d in (← mkFoldTags tds) do elabCommand d + | _ => throwUnsupportedSyntax + +end Spa.DeriveTagged diff --git a/lean/Spa/Language/Tagged/Graphs.lean b/lean/Spa/Language/Tagged/Graphs.lean new file mode 100644 index 0000000..e0e2f50 --- /dev/null +++ b/lean/Spa/Language/Tagged/Graphs.lean @@ -0,0 +1,63 @@ +import Spa.Language +import Spa.Language.Graphs +import Spa.Language.Tagged.Basic +import Spa.Language.Tagged.Properties + +namespace Spa + +open GGraph + +def Stmt.Tagged.cfg : Stmt.Tagged NodeId → GGraph (List (BasicStmt.Tagged NodeId)) + | .basic _ bs => GGraph.singleton [bs] + | .andThen _ s₁ s₂ => s₁.cfg ⤳ s₂.cfg + | .ifElse _ _ s₁ s₂ => s₁.cfg ∙ s₂.cfg + | .whileLoop _ _ s => GGraph.loop s.cfg + +theorem Stmt.Tagged.cfg_graph : ∀ (t : Stmt.Tagged NodeId), + t.cfg.map (List.map BasicStmt.Tagged.erase) = t.erase.cfg + | .basic _ bs => by simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, BasicStmt.Tagged.erase] + | .andThen _ s₁ s₂ => by + simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s₁, Stmt.Tagged.cfg_graph s₂] + | .ifElse _ _ s₁ s₂ => by + simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s₁, Stmt.Tagged.cfg_graph s₂] + | .whileLoop _ _ s => by + simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s] + +def GGraph.nodeLabel (g : GGraph (List (BasicStmt.Tagged NodeId))) (i : g.Index) : 
Option NodeId := + (g.nodes i).head?.map BasicStmt.Tagged.rootTag + +def GGraph.stateOf (g : GGraph (List (BasicStmt.Tagged NodeId))) (id : NodeId) : Option g.Index := + g.indices.find? (fun i => decide (g.nodeLabel i = some id)) + +theorem GGraph.stateOf_label {g : GGraph (List (BasicStmt.Tagged NodeId))} {id : NodeId} + {i : g.Index} (h : g.stateOf id = some i) : g.nodeLabel i = some id := by + rw [GGraph.stateOf] at h + simpa using List.find?_some h + +namespace Program + +variable (p : Program) + +def tagged : Stmt.Tagged NodeId := tagStmt p.rootStmt + +def taggedCfg : GGraph (List (BasicStmt.Tagged NodeId)) := + GGraph.wrap p.tagged.cfg + +theorem taggedCfg_erase : + p.taggedCfg.map (List.map BasicStmt.Tagged.erase) = p.cfg := by + rw [taggedCfg, GGraph.map_wrap, Stmt.Tagged.cfg_graph, tagged, erase_tagStmt] + rfl + +theorem taggedCfg_size : p.taggedCfg.size = p.cfg.size := by + conv_rhs => rw [← p.taggedCfg_erase] + rfl + +def nodeIdOf (s : p.State) : Option NodeId := + p.taggedCfg.nodeLabel (Fin.cast p.taggedCfg_size.symm s) + +def stateOfNodeId (id : NodeId) : Option p.State := + (p.taggedCfg.stateOf id).map (Fin.cast p.taggedCfg_size) + +end Program + +end Spa diff --git a/lean/Spa/Language/Tagged/Id.lean b/lean/Spa/Language/Tagged/Id.lean new file mode 100644 index 0000000..b572c7b --- /dev/null +++ b/lean/Spa/Language/Tagged/Id.lean @@ -0,0 +1,9 @@ +import Mathlib.Data.Nat.Notation + +namespace Spa + +structure NodeId where + post : ℕ + deriving DecidableEq, Repr + +end Spa diff --git a/lean/Spa/Language/Tagged/Properties.lean b/lean/Spa/Language/Tagged/Properties.lean new file mode 100644 index 0000000..efe7d94 --- /dev/null +++ b/lean/Spa/Language/Tagged/Properties.lean @@ -0,0 +1,29 @@ +import Spa.Language.Tagged.Basic + +namespace Spa + +@[simp] theorem Expr.erase_tag (e : Expr) (n : ℕ) : (e.tag n).1.erase = e := by + induction e generalizing n with + | add a b iha ihb => simp [Expr.tag, Expr.Tagged.erase, iha, ihb] + | sub a b iha ihb => simp 
[Expr.tag, Expr.Tagged.erase, iha, ihb] + | var x => simp [Expr.tag, Expr.Tagged.erase] + | num k => simp [Expr.tag, Expr.Tagged.erase] + +@[simp] theorem BasicStmt.erase_tag (bs : BasicStmt) (n : ℕ) : + (bs.tag n).1.erase = bs := by + cases bs with + | assign x e => simp [BasicStmt.tag, BasicStmt.Tagged.erase] + | noop => simp [BasicStmt.tag, BasicStmt.Tagged.erase] + +@[simp] theorem Stmt.erase_tag (s : Stmt) (n : ℕ) : (s.tag n).1.erase = s := by + induction s generalizing n with + | basic bs => simp [Stmt.tag, Stmt.Tagged.erase] + | andThen a b iha ihb => simp [Stmt.tag, Stmt.Tagged.erase, iha, ihb] + | ifElse e a b iha ihb => simp [Stmt.tag, Stmt.Tagged.erase, iha, ihb] + | whileLoop e s ih => simp [Stmt.tag, Stmt.Tagged.erase, ih] + +/-- Erasing a freshly tagged program recovers it. -/ +theorem erase_tagStmt (s : Stmt) : (tagStmt s).erase = s := by + simp [tagStmt] + +end Spa diff --git a/lean/Spa/Language/Tagged/TODO.md b/lean/Spa/Language/Tagged/TODO.md new file mode 100644 index 0000000..b5c745f --- /dev/null +++ b/lean/Spa/Language/Tagged/TODO.md @@ -0,0 +1,46 @@ +# Tagged AST — follow-ups + +## Descendant tracking — parked + +The interval-labeling descendant test and its correctness proof +(`descendant_iff_tagStmt` and supporting rose-tree/`Good` machinery) have been +removed from the live code and parked in `DESCENDANT-TRACKING.md`, with a revival +checklist. It's a computational optimization not yet needed; revive it (and the +`NodeId.desc` field) when LICM wants fast ancestor queries. + +## ID → CFG-state mapping — plan part B — DONE + +`Graphs.lean` now defines a payload-generic `GGraph α` (with `Graph := GGraph +(List BasicStmt)` as the concrete CFG), so the labelled CFG **reuses** the graph +combinators instead of mirroring them. 
In `Cfg.lean`: +`buildCfgL : Stmt.Tagged NodeId → GGraph (List (BasicStmt.Tagged NodeId))` is just +`buildCfg` at the tagged payload; `buildCfgL_graph : +(buildCfgL t).map (List.map erase) = buildCfg t.erase` connects it to the real +CFG; and `GGraph.nodeLabel`/`GGraph.stateOf` read a node's id straight from its +payload (`stateOf_label` is the soundness). No `LGraph`, no separate `label` +field, no duplicated combinators. + +## ID → CFG-state mapping — totality — DONE + +The `Option`-valued `nodeIdOf`/`stateOfNodeId` are now proven total on the inputs +that matter (`Graphs.lean`), via a payload-list characterization of the CFG: + +- `GGraph.nodeList` flattens `nodes` into the list of payloads, with combinator + lemmas (`nodeList_comp/link/loop/wrap`) reducing it through the CFG builders. +- `Stmt.Tagged.basics` lists a program's basic statements; the master lemma + `Stmt.Tagged.cfg_nodeList_filter` (and its program-level + `taggedCfg_nodeList_filter`) shows the non-empty CFG nodes are *exactly* the + singletons `[bs]` for `bs ∈ basics`. +- AST ⇒ CFG: `exists_state_of_mem_basics` (a state with payload `[bs]`) and + `stateOfNodeId_isSome` (the search succeeds). +- CFG ⇒ AST: `exists_basic_of_code_ne_nil` (a non-empty node is `[bs]`, with + `code = [bs.erase]` and `nodeIdOf = some bs.rootTag`) and `nodeIdOf_isSome`. + +All `propext`/`Quot.sound`-only (no `sorry`, no choice). + +Remaining nice-to-have: +- Injectivity: distinct basic-statement ids map to distinct states, giving a + two-sided id ↔ state correspondence (upgrading the existence results above to a + genuine bijection, and pinning `stateOfNodeId (bs.rootTag)` to *the* state + holding `bs`). The `tag`-uniqueness fact this needs (`Nodup` of postorder tags) + was part of the parked descendant machinery in `DESCENDANT-TRACKING.md`.