Add proof of reaching definition analysis

This requires a few pieces:

* Make node tags use `Fin n` intead of natural numbers. This makes
  it possible to build a finite lattice over AST nodes, and also
  ensure automatic, total indexing from CFG nodes into the AST that
  created them. For this, use the elaborator to derive the ordering
  statements etc. where possible.
* Adjust the forward framework to enable proofs that don't just state
  correctness on the environment, but also on an arbitrary additional
  state accumulated from traversing the trace.
* State the reaching definition analysis's correctness in terms
  of this new framework.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-27 16:29:16 -05:00
parent 5737805125
commit b6b30958aa
20 changed files with 678 additions and 197 deletions

View File

@@ -130,9 +130,9 @@ def loopOut (g : GGraph α) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size
This is technically sloppy (see module comment), but it's simple.
-/
def loop (g : GGraph (List β)) : GGraph (List β) where
def loop (g : GGraph (Option β)) : GGraph (Option β) where
size := 2 + g.size
nodes := Fin.append (fun _ : Fin 2 => []) g.nodes
nodes := Fin.append (fun _ : Fin 2 => none) g.nodes
edges := g.edges.finNatAddProd 2 ++
((g.loopIn, ·) <$> g.inputs.finNatAdd 2) ++
((·, g.loopOut) <$> g.outputs.finNatAdd 2) ++
@@ -140,9 +140,9 @@ def loop (g : GGraph (List β)) : GGraph (List β) where
inputs := [g.loopIn]
outputs := [g.loopOut]
@[simp] lemma loop_inputs (g : GGraph (List β)) : (loop g).inputs = [g.loopIn] := rfl
@[simp] lemma loop_inputs (g : GGraph (Option β)) : (loop g).inputs = [g.loopIn] := rfl
@[simp] lemma loop_outputs (g : GGraph (List β)) : (loop g).outputs = [g.loopOut] := rfl
@[simp] lemma loop_outputs (g : GGraph (Option β)) : (loop g).outputs = [g.loopOut] := rfl
/-- Creates a single-node graph whose node contains the given value. -/
def singleton (a : α) : GGraph α where
@@ -154,8 +154,8 @@ def singleton (a : α) : GGraph α where
/-- Creates a new graph with a single input and single output node. Useful to ensure there's
a single point of entry and single point of exit. -/
def wrap (g : GGraph (List β)) : GGraph (List β) :=
singleton [] g singleton []
def wrap (g : GGraph (Option β)) : GGraph (Option β) :=
singleton none g singleton none
@[simp] lemma map_singleton (f : α β) (a : α) :
f <$> singleton a = singleton (f a) := rfl
@@ -176,16 +176,16 @@ def wrap (g : GGraph (List β)) : GGraph (List β) :=
funext i
refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
@[simp] lemma map_loop (h : β γ) (g : GGraph (List β)) :
(List.map h) <$> (loop g) = loop (List.map h <$> g) := by
@[simp] lemma map_loop (h : β γ) (g : GGraph (Option β)) :
(Option.map h) <$> (loop g) = loop (Option.map h <$> g) := by
rcases g with n, nd, e, i, o
simp only [Functor.map, GGraph.loop]
congr 1
funext i
refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
@[simp] lemma map_wrap (h : β γ) (g : GGraph (List β)) :
(List.map h) <$> wrap g = wrap (List.map h <$> g) := by
@[simp] lemma map_wrap (h : β γ) (g : GGraph (Option β)) :
(Option.map h) <$> wrap g = wrap (Option.map h <$> g) := by
simp [GGraph.wrap, GGraph.map_sequence, GGraph.map_singleton]
variable (g : GGraph α)
@@ -220,8 +220,8 @@ lemma edge_of_mem_predecessors {idx₁ idx₂ : g.Index}
end GGraph
/-- "Normal" graphs, for the purposes of the analyses in this
framework, have basic blocks in their nodes, and nothing else. -/
abbrev Graph : Type := GGraph (List BasicStmt)
framework, have basic statements in their nodes, and nothing else. -/
abbrev Graph : Type := GGraph (Option BasicStmt)
namespace Graph
@@ -235,7 +235,7 @@ end Graph
open Graph in
def Stmt.cfg : Stmt Graph
-- A basic statement goes into a single basic block
| .basic bs => singleton [bs]
| .basic bs => singleton (some bs)
-- Sequencing of statements corresponds naturally to CFG sequencing
| .andThen s₁ s₂ => s₁.cfg s₂.cfg
-- An if can execute either one branch or the other; overlap them.

View File

@@ -17,7 +17,7 @@ section Embeddings
variable {g₁ g₂ : Graph} {ρ₁ ρ₂ : Env}
lemma Trace.overlay_left {idx₁ idx₂ : g₁.Index}
noncomputable def Trace.overlay_left {idx₁ idx₂ : g₁.Index}
(tr : Trace g₁ idx₁ idx₂ ρ₁ ρ₂) :
Trace (g₁ g₂) (idx₁.castAdd g₂.size) (idx₂.castAdd g₂.size) ρ₁ ρ₂ := by
induction tr with
@@ -29,7 +29,7 @@ lemma Trace.overlay_left {idx₁ idx₂ : g₁.Index}
· rwa [show (g₁ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_left]
· exact List.mem_append_left _ (List.mem_map_of_mem _ he)
lemma Trace.overlay_right {idx₁ idx₂ : g₂.Index}
noncomputable def Trace.overlay_right {idx₁ idx₂ : g₂.Index}
(tr : Trace g₂ idx₁ idx₂ ρ₁ ρ₂) :
Trace (g₁ g₂) (idx₁.natAdd g₁.size) (idx₂.natAdd g₁.size) ρ₁ ρ₂ := by
induction tr with
@@ -41,7 +41,7 @@ lemma Trace.overlay_right {idx₁ idx₂ : g₂.Index}
· rwa [show (g₁ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_right]
· exact List.mem_append_right _ (List.mem_map_of_mem _ he)
lemma Trace.sequence_left {idx₁ idx₂ : g₁.Index}
noncomputable def Trace.sequence_left {idx₁ idx₂ : g₁.Index}
(tr : Trace g₁ idx₁ idx₂ ρ₁ ρ₂) :
Trace (g₁ g₂) (idx₁.castAdd g₂.size) (idx₂.castAdd g₂.size) ρ₁ ρ₂ := by
induction tr with
@@ -53,7 +53,7 @@ lemma Trace.sequence_left {idx₁ idx₂ : g₁.Index}
· rwa [show (g₁ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_left]
· exact List.mem_append_left _ (List.mem_append_left _ (List.mem_map_of_mem _ he))
lemma Trace.sequence_right {idx₁ idx₂ : g₂.Index}
noncomputable def Trace.sequence_right {idx₁ idx₂ : g₂.Index}
(tr : Trace g₂ idx₁ idx₂ ρ₁ ρ₂) :
Trace (g₁ g₂) (idx₁.natAdd g₁.size) (idx₂.natAdd g₁.size) ρ₁ ρ₂ := by
induction tr with
@@ -66,21 +66,21 @@ lemma Trace.sequence_right {idx₁ idx₂ : g₂.Index}
· exact List.mem_append_left _
(List.mem_append_right _ (List.mem_map_of_mem _ he))
lemma EndToEndTrace.overlay_left (etr : EndToEndTrace g₁ ρ₁ ρ₂) :
noncomputable def EndToEndTrace.overlay_left (etr : EndToEndTrace g₁ ρ₁ ρ₂) :
EndToEndTrace (g₁ g₂) ρ₁ ρ₂ := by
obtain i₁, h₁, i₂, h₂, tr := etr
exact i₁.castAdd g₂.size, List.mem_append_left _ (List.mem_map_of_mem _ h₁),
i₂.castAdd g₂.size, List.mem_append_left _ (List.mem_map_of_mem _ h₂),
tr.overlay_left
lemma EndToEndTrace.overlay_right (etr : EndToEndTrace g₂ ρ₁ ρ₂) :
noncomputable def EndToEndTrace.overlay_right (etr : EndToEndTrace g₂ ρ₁ ρ₂) :
EndToEndTrace (g₁ g₂) ρ₁ ρ₂ := by
obtain i₁, h₁, i₂, h₂, tr := etr
exact i₁.natAdd g₁.size, List.mem_append_right _ (List.mem_map_of_mem _ h₁),
i₂.natAdd g₁.size, List.mem_append_right _ (List.mem_map_of_mem _ h₂),
tr.overlay_right
lemma EndToEndTrace.concat {ρ₃ : Env} (etr₁ : EndToEndTrace g₁ ρ₁ ρ₂)
noncomputable def EndToEndTrace.concat {ρ₃ : Env} (etr₁ : EndToEndTrace g₁ ρ₁ ρ₂)
(etr₂ : EndToEndTrace g₂ ρ₂ ρ₃) : EndToEndTrace (g₁ g₂) ρ₁ ρ₃ := by
obtain i₁, h₁, i₂, h₂, tr₁ := etr₁
obtain j₁, k₁, j₂, k₂, tr₂ := etr₂
@@ -98,29 +98,29 @@ section Loop
variable {g : Graph} {ρ₁ ρ₂ ρ₃ : Env}
lemma Trace.loop {idx₁ idx₂ : g.Index} (tr : Trace g idx₁ idx₂ ρ₁ ρ₂) :
noncomputable def Trace.loop {idx₁ idx₂ : g.Index} (tr : Trace g idx₁ idx₂ ρ₁ ρ₂) :
Trace (Graph.loop g) (idx₁.natAdd 2) (idx₂.natAdd 2) ρ₁ ρ₂ := by
induction tr with
| single hbs =>
exact Trace.single (by
rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => []) g.nodes from rfl,
rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => none) g.nodes from rfl,
Fin.append_right])
| edge hbs he _ ih =>
refine Trace.edge ?_ ?_ ih
· rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => []) g.nodes from rfl,
· rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => none) g.nodes from rfl,
Fin.append_right]
· exact List.mem_append_left _ (List.mem_append_left _
(List.mem_append_left _ (List.mem_map_of_mem _ he)))
private lemma loop_nodes_at_in :
(Graph.loop g).nodes g.loopIn = [] :=
Fin.append_left (fun _ : Fin 2 => []) g.nodes 0
(Graph.loop g).nodes g.loopIn = none :=
Fin.append_left (fun _ : Fin 2 => none) g.nodes 0
private lemma loop_nodes_at_out :
(Graph.loop g).nodes g.loopOut = [] :=
Fin.append_left (fun _ : Fin 2 => []) g.nodes 1
(Graph.loop g).nodes g.loopOut = none :=
Fin.append_left (fun _ : Fin 2 => none) g.nodes 1
lemma EndToEndTrace.loop (etr : EndToEndTrace g ρ₁ ρ₂) :
noncomputable def EndToEndTrace.loop (etr : EndToEndTrace g ρ₁ ρ₂) :
EndToEndTrace (Graph.loop g) ρ₁ ρ₂ := by
obtain i₁, h₁, i₂, h₂, tr := etr
-- the edge in → (2 ↑ʳ i₁), reached through the second edge group
@@ -132,15 +132,15 @@ lemma EndToEndTrace.loop (etr : EndToEndTrace g ρ₁ ρ₂) :
refine List.mem_append_left _ (List.mem_append_right _ ?_)
exact List.mem_map_of_mem _ (List.mem_map_of_mem _ h₂)
refine g.loopIn, List.mem_singleton_self _, g.loopOut, List.mem_singleton_self _, ?_
exact Trace.concat (Trace.single (loop_nodes_at_in EvalBasicStmts.nil)) hin
(Trace.concat tr.loop hout (Trace.single (loop_nodes_at_out EvalBasicStmts.nil)))
exact Trace.concat (Trace.single (loop_nodes_at_in EvalBasicStmtOpt.none)) hin
(Trace.concat tr.loop hout (Trace.single (loop_nodes_at_out EvalBasicStmtOpt.none)))
private lemma loop_edge_out_in :
((g.loopOut, g.loopIn) : (Graph.loop g).Edge) (Graph.loop g).edges := by
refine List.mem_append_right _ ?_
exact List.mem_cons_self _ _
lemma EndToEndTrace.loop_concat (etr₁ : EndToEndTrace (Graph.loop g) ρ₁ ρ₂)
noncomputable def EndToEndTrace.loop_concat (etr₁ : EndToEndTrace (Graph.loop g) ρ₁ ρ₂)
(etr₂ : EndToEndTrace (Graph.loop g) ρ₂ ρ₃) :
EndToEndTrace (Graph.loop g) ρ₁ ρ₃ := by
obtain i₁, h₁, i₂, h₂, tr₁ := etr₁
@@ -150,35 +150,35 @@ lemma EndToEndTrace.loop_concat (etr₁ : EndToEndTrace (Graph.loop g) ρ₁ ρ
exact g.loopIn, List.mem_singleton_self _, g.loopOut, List.mem_singleton_self _,
Trace.concat tr₁ loop_edge_out_in tr₂
lemma EndToEndTrace.loop_empty {ρ : Env} : EndToEndTrace (Graph.loop g) ρ ρ := by
noncomputable def EndToEndTrace.loop_empty {ρ : Env} : EndToEndTrace (Graph.loop g) ρ ρ := by
have hedge : ((g.loopIn, g.loopOut) : (Graph.loop g).Edge) (Graph.loop g).edges :=
List.mem_append_right _ (List.mem_cons_of_mem _ (List.mem_cons_self _ _))
exact g.loopIn, List.mem_singleton_self _, g.loopOut, List.mem_singleton_self _,
Trace.concat (Trace.single (loop_nodes_at_in EvalBasicStmts.nil)) hedge
(Trace.single (loop_nodes_at_out EvalBasicStmts.nil))
Trace.concat (Trace.single (loop_nodes_at_in EvalBasicStmtOpt.none)) hedge
(Trace.single (loop_nodes_at_out EvalBasicStmtOpt.none))
end Loop
/-! ### Singletons, wrap, and the main result -/
lemma EndToEndTrace.singleton {bss : List BasicStmt} {ρ₁ ρ₂ : Env}
(h : EvalBasicStmts ρ₁ bss ρ₂) : EndToEndTrace (Graph.singleton bss) ρ₁ ρ₂ :=
noncomputable def EndToEndTrace.singleton {o : Option BasicStmt} {ρ₁ ρ₂ : Env}
(h : EvalBasicStmtOpt ρ₁ o ρ₂) : EndToEndTrace (Graph.singleton o) ρ₁ ρ₂ :=
(0 : Fin 1), List.mem_singleton_self _, (0 : Fin 1), List.mem_singleton_self _,
Trace.single h
lemma EndToEndTrace.singleton_nil (ρ : Env) :
EndToEndTrace (Graph.singleton []) ρ ρ :=
EndToEndTrace.singleton EvalBasicStmts.nil
noncomputable def EndToEndTrace.singleton_nil (ρ : Env) :
EndToEndTrace (Graph.singleton none) ρ ρ :=
EndToEndTrace.singleton EvalBasicStmtOpt.none
lemma EndToEndTrace.wrap {g : Graph} {ρ₁ ρ₂ : Env}
noncomputable def EndToEndTrace.wrap {g : Graph} {ρ₁ ρ₂ : Env}
(etr : EndToEndTrace g ρ₁ ρ₂) : EndToEndTrace (Graph.wrap g) ρ₁ ρ₂ :=
(EndToEndTrace.singleton_nil ρ₁).concat (etr.concat (EndToEndTrace.singleton_nil ρ₂))
theorem Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env}
noncomputable def Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env}
(h : EvalStmt ρ₁ s ρ₂) : EndToEndTrace s.cfg ρ₁ ρ₂ := by
induction h with
| basic ρ₁ ρ₂ bs hbs =>
exact EndToEndTrace.singleton (EvalBasicStmts.cons hbs EvalBasicStmts.nil)
exact EndToEndTrace.singleton (EvalBasicStmtOpt.some hbs)
| andThen ρ₁ ρ₂ ρ₃ s₁ s₂ _ _ ih₁ ih₂ =>
exact ih₁.concat ih₂
| ifTrue ρ₁ ρ₂ e z s₁ s₂ _ _ _ ih =>
@@ -193,7 +193,7 @@ theorem Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env}
/-! ### The wrapped graph's entry has no predecessors (Agda's "ugly" block) -/
def Graph.wrapInput (g : Graph) : (Graph.wrap g).Index :=
(0 : Fin 1).castAdd ((g Graph.singleton []).size)
(0 : Fin 1).castAdd ((g Graph.singleton none).size)
def Graph.wrapOutput (g : Graph) : (Graph.wrap g).Index :=
Fin.natAdd 1 ((Fin.natAdd g.size (0 : Fin 1)))
@@ -205,9 +205,9 @@ lemma Graph.wrap_outputs (g : Graph) :
(Graph.wrap g).outputs = [g.wrapOutput] := rfl
private lemma not_mem_edges_castAdd_sequence {g₂ : Graph} (i : Fin 1)
(idx : (Graph.singleton [] g₂).Index) :
((idx, i.castAdd g₂.size) : (Graph.singleton [] g₂).Edge)
(Graph.singleton [] g₂).edges := by
(idx : (Graph.singleton none g₂).Index) :
((idx, i.castAdd g₂.size) : (Graph.singleton none g₂).Edge)
(Graph.singleton none g₂).edges := by
intro h
rcases List.mem_append.mp h with h' | h'
· rcases List.mem_append.mp h' with h'' | h''
@@ -228,6 +228,6 @@ lemma Graph.wrap_predecessors_eq_nil (g : Graph) (idx : (Graph.wrap g).Index)
subst h
rw [GGraph.predecessors, List.filter_eq_nil_iff]
intro idx' _
simpa using not_mem_edges_castAdd_sequence (g₂ := g Graph.singleton []) 0 idx'
simpa using not_mem_edges_castAdd_sequence (g₂ := g Graph.singleton none) 0 idx'
end Spa

View File

@@ -46,22 +46,20 @@ inductive EvalExpr : Env → Expr → Value → Prop
/-- Inference rules for evaluating a basic statement (`Spa.BasicStmt`) in
a given environment, potentially changing the environment.
Pretty standard big-step evaluation. -/
inductive EvalBasicStmt : Env BasicStmt Env Prop
inductive EvalBasicStmt : Env BasicStmt Env Type
| noop (ρ : Env) : EvalBasicStmt ρ .noop ρ
| assign (ρ : Env) (x : String) (e : Expr) (v : Value) :
EvalExpr ρ e v EvalBasicStmt ρ (.assign x e) ((x, v) :: ρ)
/-- Inference rules for evaluating a sequence of basic statements. -/
inductive EvalBasicStmts : Env List BasicStmt Env Prop
| nil {ρ : Env} : EvalBasicStmts ρ [] ρ
| cons {ρ₁ ρ₂ ρ : Env} {bs : BasicStmt} {bss : List BasicStmt} :
EvalBasicStmt ρ₁ bs ρ₂ EvalBasicStmts ρ₂ bss ρ₃
EvalBasicStmts ρ₁ (bs :: bss) ρ₃
inductive EvalBasicStmtOpt : Env Option BasicStmt Env Type
| none {ρ : Env} : EvalBasicStmtOpt ρ Option.none ρ
| some {ρ ρ₂ : Env} {bs : BasicStmt} :
EvalBasicStmt ρ₁ bs ρ EvalBasicStmtOpt ρ₁ (Option.some bs) ρ₂
/-- Inference rules for evaluating statements (`Spa.Stmt`) in a given
environment, potentially changing the environment.
Pretty standard big-step evaluation. -/
inductive EvalStmt : Env Stmt Env Prop
inductive EvalStmt : Env Stmt Env Type
| basic (ρ₁ ρ₂ : Env) (bs : BasicStmt) :
EvalBasicStmt ρ₁ bs ρ₂ EvalStmt ρ₁ (.basic bs) ρ₂
| andThen (ρ₁ ρ₂ ρ₃ : Env) (s₁ s₂ : Stmt) :

View File

@@ -6,12 +6,13 @@ derive_tagged Spa.Expr Spa.BasicStmt Spa.Stmt
namespace Spa
def tagStmt (s : Stmt) : Stmt.Tagged NodeId := (s.tag 0).1
def tagStmt (s : Stmt) : Stmt.Tagged RawId := (s.tag 0).1
def Stmt.Tagged.subtreeIds (s : Stmt.Tagged NodeId) : List NodeId :=
def Stmt.Tagged.subtreeIds {τ : Type} (s : Stmt.Tagged τ) : List τ :=
s.foldTags (· :: ·) []
def Stmt.Tagged.isInLoopBody (body : Stmt.Tagged NodeId) (id : NodeId) : Bool :=
def Stmt.Tagged.isInLoopBody {τ : Type} [DecidableEq τ]
(body : Stmt.Tagged τ) (id : τ) : Bool :=
decide (id body.subtreeIds)
end Spa

View File

@@ -13,8 +13,8 @@ inductive types and generates, for each `Tᵢ`:
carries a leading `tag : τ` field and every field whose type is a family
member is retyped to its `.Tagged τ` counterpart;
* `Tᵢ.Tagged.erase : Tᵢ.Tagged τ → Tᵢ`, forgetting all tags;
* `Tᵢ.tag : Tᵢ → → Tᵢ.Tagged NodeId × `, assigning every node a unique
`NodeId` (its postorder index) by a single unified traversal that threads a
* `Tᵢ.tag : Tᵢ → → Tᵢ.Tagged RawId × `, assigning every node a unique
`RawId` (its postorder index) by a single unified traversal that threads a
counter; the whole family shares one counter, so identifiers are unique across
types.
@@ -54,6 +54,45 @@ def eraseOf (n : Name) : Name := n ++ `Tagged ++ `erase
def rootTagOf (n : Name) : Name := n ++ `Tagged ++ `rootTag
def tagOf (n : Name) : Name := n ++ `tag
def foldTagsOf (n : Name) : Name := n ++ `Tagged ++ `foldTags
def wfOf (n : Name) : Name := n ++ `Tagged ++ `WF
def narrowOf (n : Name) : Name := n ++ `Tagged ++ `narrow
def narrowEraseOf (n : Name) : Name := n ++ `Tagged ++ `narrow_erase
def tagLeOf (n : Name) : Name := n ++ `tag_le
def tagRootTagPostOf (n : Name) : Name := n ++ `tag_rootTag_post
def tagWfOf (n : Name) : Name := n ++ `tag_wf
/-- Project the `i`-th conjunct (1-based) out of `hyp`, which has type a
right-nested `And` of `total` conjuncts, e.g. `hyp |>.2 |>.2 |>.1`. -/
def projAnd {m : Type Type} [Monad m] [MonadQuotation m]
(hyp : Term) (i total : Nat) : m Term := do
let mut t := hyp
for _ in [0:i-1] do
t `($t |>.2)
if i < total then
t `($t |>.1)
return t
/-- Combine a non-empty array of propositions into a right-nested conjunction. -/
def mkAndR {m : Type Type} [Monad m] [MonadQuotation m]
(cs : Array Term) : m Term := do
let mut t := cs.back!
for c in cs.pop.reverse do
t `($c $t)
return t
/-- For a constructor, return one entry per *recursive* field: its argument
identifier, the family member it references, and the start-counter expression at
which it is tagged (`n`, then `(a.tag n).2`, …) — the same threading `mkTag`
uses. -/
def recChildren (cd : CtorData) (argNames : Array Ident) (nStart : Term) :
CommandElabM (Array (Ident × Name × Term)) := do
let mut res : Array (Ident × Name × Term) := #[]
let mut cur := nStart
for (f, a) in cd.fields.zip argNames do
if f.isRec then
res := res.push (a, f.recType, cur)
cur `(($(mkIdent (tagOf f.recType)) $a $cur) |>.2)
return res
/-- Inspect the family, classifying each constructor field. -/
def gather (family : Array Name) (τ : Ident) : TermElabM (Array TypeData) := do
@@ -158,7 +197,7 @@ def mkRootTag (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) :
/-- The postorder `tag` functions, one per family member (separate defs in
dependency order). -/
def mkTag (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let nId := mkIdent ``Spa.NodeId
let nId := mkIdent ``Spa.RawId
tds.mapM fun td => do
let mut pats : Array Term := #[]
let mut rhss : Array Term := #[]
@@ -219,6 +258,229 @@ def mkFoldTags (tds : Array TypeData) : CommandElabM (Array (TSyntax `command))
$(mkIdent (taggedOf td.name)) $τ $m :=
fun x => match x with $[| $pats => $rhss]*)
/-- The well-formedness predicate `T.Tagged.WF : T.Tagged RawId → Prop`: every
recursive child's root tag has a strictly smaller postorder index than the node's
own tag, and each child is itself well-formed. Leaf constructors are `True`. -/
def mkWF (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let tId := mkIdent `t
let rawId := mkIdent ``Spa.RawId
tds.mapM fun td => do
let mut pats : Array Term := #[]
let mut rhss : Array Term := #[]
for cd in td.ctors do
let hasRec := cd.fields.any (·.isRec)
let mut patArgs : Array Term := #[]
let mut recArgs : Array Ident := #[]
let mut i := 0
for f in cd.fields do
if f.isRec then
let a := mkIdent (.mkSimple s!"a{i}")
patArgs := patArgs.push a
recArgs := recArgs.push a
else
patArgs := patArgs.push ( `(_))
i := i + 1
let tagBind : Term if hasRec then `($tId) else `(_)
let pat `($(mkIdent (taggedOf td.name ++ cd.shortName)) $tagBind $patArgs*)
let rhs if recArgs.isEmpty then `(True) else do
let bounds recArgs.mapM fun a => `($(a).rootTag.post < $(tId).post)
let wfs recArgs.mapM fun a => `($(a).WF)
mkAndR (bounds ++ wfs)
pats := pats.push pat
rhss := rhss.push rhs
`(command| def $(mkIdent (wfOf td.name)) :
$(mkIdent (taggedOf td.name)) $rawId Prop :=
fun x => match x with $[| $pats => $rhss]*)
/-- The `narrow` coercion `T.Tagged RawId → T.Tagged (Fin N)`, given a bound on
the root tag and a well-formedness proof. Each node's tag becomes the `Fin N`
built from its postorder index, and recursion threads the bound through `lt_trans`
and the (definitionally unfolded) `WF` conjunction. -/
def mkNarrow (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let rawId := mkIdent ``Spa.RawId
let tId := mkIdent `t
let nId := mkIdent `N
let hId := mkIdent `h
let hwfId := mkIdent `hwf
let tgId := mkIdent `tg
tds.mapM fun td => do
let self `($(mkIdent (taggedOf td.name)) $rawId)
let mut patss : Array (Array Term) := #[]
let mut rhss : Array Term := #[]
for cd in td.ctors do
let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}")
let ctorPat `($(mkIdent (taggedOf td.name ++ cd.shortName)) $tgId $argNames*)
let k := (cd.fields.filter (·.isRec)).size
let mut newArgs : Array Term := #[]
let mut ri := 0
for (f, a) in cd.fields.zip argNames do
if f.isRec then
let bound projAnd hwfId (ri + 1) (2 * k)
let wf projAnd hwfId (k + ri + 1) (2 * k)
newArgs := newArgs.push ( `($(a).narrow (lt_trans $bound $hId) $wf))
ri := ri + 1
else
newArgs := newArgs.push a
let built `($(mkIdent (taggedOf td.name ++ cd.shortName)) $(tgId).post, $hId $newArgs*)
let nPat `(_)
let hPat `($hId)
let hwfPat : Term if k == 0 then `(_) else `($hwfId)
patss := patss.push #[ctorPat, nPat, hPat, hwfPat]
rhss := rhss.push built
`(command| def $(mkIdent (narrowOf td.name)) : ($tId : $self) {$nId : }
$(tId).rootTag.post < $nId $(tId).WF $(mkIdent (taggedOf td.name)) (Fin $nId)
$[| $[$patss],* => $rhss]*)
/-- `T.tag_rootTag_post`: the root tag of a freshly tagged node is exactly one
below the threaded-out counter, i.e. the node itself is numbered last (postorder).
A uniform `cases <;> simp` discharges every constructor. -/
def mkTagRootTagPost (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let eId := mkIdent `e
let nId := mkIdent `n
tds.mapM fun td =>
`(command| theorem $(mkIdent (tagRootTagPostOf td.name))
($eId : $(mkIdent td.name)) ($nId : ) :
($(eId).tag $nId).1.rootTag.post + 1 = ($(eId).tag $nId).2 := by
cases $eId:ident <;>
simp [$(mkIdent (tagOf td.name)):ident, $(mkIdent (rootTagOf td.name)):ident])
/-- `T.tag_le`: tagging only ever advances the counter (`n ≤ (e.tag n).2`).
Proved by induction; each arm threads the counter through its recursive children
(using the relevant `tag_le`/induction hypothesis) and closes with `omega`. -/
def mkTagLe (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let eId := mkIdent `e
let nId := mkIdent `n
tds.mapM fun td => do
let mut ctorLabels : Array Ident := #[]
let mut binderss : Array (Array Ident) := #[]
let mut tacs : Array (TSyntax ``Lean.Parser.Tactic.tacticSeq) := #[]
for cd in td.ctors do
let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}")
let mut ihBinders : Array Ident := #[]
let mut haveTacs : Array (TSyntax `tactic) := #[]
let mut cur : Term `($nId)
let mut i := 0
for (f, a) in cd.fields.zip argNames do
if f.isRec then
let fact if f.recType == td.name then
`($(mkIdent (.mkSimple s!"ih{i}")) $cur)
else
`($(mkIdent (tagLeOf f.recType)) $a $cur)
if f.recType == td.name then
ihBinders := ihBinders.push (mkIdent (.mkSimple s!"ih{i}"))
haveTacs := haveTacs.push ( `(tactic| have := $fact))
cur `(($(mkIdent (tagOf f.recType)) $a $cur) |>.2)
i := i + 1
let simpTac `(tactic| simp only [$(mkIdent (tagOf td.name)):ident])
let omegaTac `(tactic| omega)
let allTacs := #[simpTac] ++ haveTacs ++ #[omegaTac]
ctorLabels := ctorLabels.push (mkIdent cd.shortName)
binderss := binderss.push (argNames ++ ihBinders)
tacs := tacs.push ( `(tacticSeq| $[$allTacs]*))
`(command| theorem $(mkIdent (tagLeOf td.name)) ($eId : $(mkIdent td.name)) ($nId : ) :
$nId ($(eId).tag $nId).2 := by
induction $eId:ident generalizing $nId:ident with
$[| $ctorLabels:ident $binderss* => $tacs]*)
/-- `T.tag_wf`: a freshly tagged term is well-formed. Each recursive child's
bound conjunct is closed by `omega` from that child's `tag_rootTag_post` plus the
`tag_le` of every later child (which bounds the threaded-out counter), and each
well-formedness conjunct is the child's induction hypothesis / `tag_wf`. -/
def mkTagWf (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let eId := mkIdent `e
let nId := mkIdent `n
tds.mapM fun td => do
let mut ctorLabels : Array Ident := #[]
let mut binderss : Array (Array Ident) := #[]
let mut tacs : Array (TSyntax ``Lean.Parser.Tactic.tacticSeq) := #[]
for cd in td.ctors do
let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}")
-- recursive children: (arg, recType, startCounter, sameType?, fieldIndex)
let mut recs : Array (Ident × Name × Term × Bool × Nat) := #[]
let mut cur : Term `($nId)
let mut i := 0
for (f, a) in cd.fields.zip argNames do
if f.isRec then
recs := recs.push (a, f.recType, cur, f.recType == td.name, i)
cur `(($(mkIdent (tagOf f.recType)) $a $cur) |>.2)
i := i + 1
let k := recs.size
let ihBinders := (recs.filter (·.2.2.2.1)).map fun r => mkIdent (.mkSimple s!"ih{r.2.2.2.2}")
let tac : TSyntax ``Lean.Parser.Tactic.tacticSeq if k == 0 then
`(tacticSeq| exact True.intro)
else do
let mut comps : Array Term := #[]
-- bound conjuncts
for idx in [0:k] do
let (a, rt, s, _, _) := recs[idx]!
let mut bHaves : Array (TSyntax `tactic) :=
#[ `(tactic| have := $(mkIdent (tagRootTagPostOf rt)) $a $s)]
for j in [idx+1:k] do
let (aj, rtj, sj, _, _) := recs[j]!
bHaves := bHaves.push ( `(tactic| have := $(mkIdent (tagLeOf rtj)) $aj $sj))
bHaves := bHaves.push ( `(tactic| omega))
comps := comps.push ( `(by $( `(tacticSeq| $[$bHaves]*))))
-- well-formedness conjuncts
for idx in [0:k] do
let (a, rt, s, same, fi) := recs[idx]!
comps := comps.push <| if same then `($(mkIdent (.mkSimple s!"ih{fi}")) $s)
else `($(mkIdent (tagWfOf rt)) $a $s)
let simpTac `(tactic| simp only
[$(mkIdent (tagOf td.name)):ident, $(mkIdent (wfOf td.name)):ident])
let exactTac `(tactic| exact $comps,*)
`(tacticSeq| $[$(#[simpTac, exactTac])]*)
ctorLabels := ctorLabels.push (mkIdent cd.shortName)
binderss := binderss.push (argNames ++ ihBinders)
tacs := tacs.push tac
`(command| theorem $(mkIdent (tagWfOf td.name)) ($eId : $(mkIdent td.name)) ($nId : ) :
($(eId).tag $nId).1.WF := by
induction $eId:ident generalizing $nId:ident with
$[| $ctorLabels:ident $binderss* => $tacs]*)
/-- `T.Tagged.narrow_erase`: narrowing the tag type does not change the erased
(untagged) term. A per-constructor `simp` with the local `narrow`/`erase`
equations, the lower members' `narrow_erase`, and the induction hypotheses. -/
def mkNarrowErase (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do
let rawId := mkIdent ``Spa.RawId
let tId := mkIdent `t
let nId := mkIdent `N
let hId := mkIdent `h
let hwfId := mkIdent `hwf
let tgId := mkIdent `tg
tds.mapM fun td => do
let mut ctorLabels : Array Ident := #[]
let mut binderss : Array (Array Ident) := #[]
let mut tacs : Array (TSyntax ``Lean.Parser.Tactic.tacticSeq) := #[]
for cd in td.ctors do
let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}")
let mut lemmas : Array Term :=
#[ `($(mkIdent (narrowOf td.name))), `($(mkIdent (eraseOf td.name)))]
let mut ihBinders : Array Ident := #[]
let mut seenLower : Array Name := #[]
let mut i := 0
for f in cd.fields do
if f.isRec then
if f.recType == td.name then
let ih := mkIdent (.mkSimple s!"ih{i}")
ihBinders := ihBinders.push ih
lemmas := lemmas.push ( `($ih))
else if !seenLower.contains f.recType then
seenLower := seenLower.push f.recType
lemmas := lemmas.push ( `($(mkIdent (narrowEraseOf f.recType))))
i := i + 1
let introTac `(tactic| intro $nId $hId $hwfId)
let simpTac `(tactic| simp [$[$lemmas:term],*])
ctorLabels := ctorLabels.push (mkIdent cd.shortName)
binderss := binderss.push (#[tgId] ++ argNames ++ ihBinders)
tacs := tacs.push ( `(tacticSeq| $[$(#[introTac, simpTac])]*))
`(command| theorem $(mkIdent (narrowEraseOf td.name)) :
($tId : $(mkIdent (taggedOf td.name)) $rawId) {$nId : }
($hId : $(tId).rootTag.post < $nId) ($hwfId : $(tId).WF),
($(tId).narrow $hId $hwfId).erase = $(tId).erase := by
intro $tId:ident
induction $tId:ident with
$[| $ctorLabels:ident $binderss* => $tacs]*)
/-- `derive_tagged T₁ … Tₙ` — generate tagged mirrors, `erase`, and `tag` for the
given family of inductives. -/
syntax (name := deriveTaggedCmd) "derive_tagged " ident+ : command
@@ -236,6 +498,12 @@ def elabDeriveTagged : CommandElab := fun stx => do
for d in ( mkErase tds) do elabCommand d
for d in ( mkTag tds) do elabCommand d
for d in ( mkFoldTags tds) do elabCommand d
for d in ( mkWF tds) do elabCommand d
for d in ( mkNarrow tds) do elabCommand d
for d in ( mkTagRootTagPost tds) do elabCommand d
for d in ( mkTagLe tds) do elabCommand d
for d in ( mkTagWf tds) do elabCommand d
for d in ( mkNarrowErase tds) do elabCommand d
| _ => throwUnsupportedSyntax
end Spa.DeriveTagged

View File

@@ -7,14 +7,14 @@ namespace Spa
open GGraph
def Stmt.Tagged.cfg : Stmt.Tagged NodeId GGraph (List (BasicStmt.Tagged NodeId))
| .basic _ bs => GGraph.singleton [bs]
def Stmt.Tagged.cfg {τ : Type} : Stmt.Tagged τ GGraph (Option (BasicStmt.Tagged τ))
| .basic _ bs => GGraph.singleton (some bs)
| .andThen _ s₁ s₂ => s₁.cfg s₂.cfg
| .ifElse _ _ s₁ s₂ => s₁.cfg s₂.cfg
| .whileLoop _ _ s => GGraph.loop s.cfg
theorem Stmt.Tagged.cfg_graph : (t : Stmt.Tagged NodeId),
t.cfg.map (List.map BasicStmt.Tagged.erase) = t.erase.cfg
theorem Stmt.Tagged.cfg_graph {τ : Type} : (t : Stmt.Tagged τ),
(Option.map BasicStmt.Tagged.erase) <$> t.cfg = t.erase.cfg
| .basic _ bs => by simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, BasicStmt.Tagged.erase]
| .andThen _ s₁ s₂ => by
simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s₁, Stmt.Tagged.cfg_graph s₂]
@@ -23,13 +23,16 @@ theorem Stmt.Tagged.cfg_graph : ∀ (t : Stmt.Tagged NodeId),
| .whileLoop _ _ s => by
simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s]
def GGraph.nodeLabel (g : GGraph (List (BasicStmt.Tagged NodeId))) (i : g.Index) : Option NodeId :=
(g.nodes i).head?.map BasicStmt.Tagged.rootTag
def GGraph.nodeLabel {τ : Type} (g : GGraph (Option (BasicStmt.Tagged τ))) (i : g.Index) :
Option τ :=
(g.nodes i).map BasicStmt.Tagged.rootTag
def GGraph.stateOf (g : GGraph (List (BasicStmt.Tagged NodeId))) (id : NodeId) : Option g.Index :=
def GGraph.stateOf {τ : Type} [DecidableEq τ] (g : GGraph (Option (BasicStmt.Tagged τ)))
(id : τ) : Option g.Index :=
g.indices.find? (fun i => decide (g.nodeLabel i = some id))
theorem GGraph.stateOf_label {g : GGraph (List (BasicStmt.Tagged NodeId))} {id : NodeId}
theorem GGraph.stateOf_label {τ : Type} [DecidableEq τ]
{g : GGraph (Option (BasicStmt.Tagged τ))} {id : τ}
{i : g.Index} (h : g.stateOf id = some i) : g.nodeLabel i = some id := by
rw [GGraph.stateOf] at h
simpa using List.find?_some h
@@ -38,26 +41,64 @@ namespace Program
variable (p : Program)
def tagged : Stmt.Tagged NodeId := tagStmt p.rootStmt
def tagged : Stmt.Tagged RawId := tagStmt p.rootStmt
def taggedCfg : GGraph (List (BasicStmt.Tagged NodeId)) :=
GGraph.wrap p.tagged.cfg
def size : := p.tagged.rootTag.post + 1
theorem size_pos : 0 < p.size := Nat.succ_pos _
abbrev NodeId : Type := Fin p.size
theorem tagged_wf : p.tagged.WF := Stmt.tag_wf p.rootStmt 0
def taggedFin : Stmt.Tagged p.NodeId :=
p.tagged.narrow (Nat.lt_succ_self _) p.tagged_wf
def taggedCfg : GGraph (Option (BasicStmt.Tagged p.NodeId)) :=
GGraph.wrap p.taggedFin.cfg
theorem taggedCfg_erase :
p.taggedCfg.map (List.map BasicStmt.Tagged.erase) = p.cfg := by
rw [taggedCfg, GGraph.map_wrap, Stmt.Tagged.cfg_graph, tagged, erase_tagStmt]
(Option.map BasicStmt.Tagged.erase) <$> p.taggedCfg = p.cfg := by
rw [taggedCfg, GGraph.map_wrap, Stmt.Tagged.cfg_graph, taggedFin,
Stmt.Tagged.narrow_erase, tagged, erase_tagStmt]
rfl
theorem taggedCfg_size : p.taggedCfg.size = p.cfg.size := by
conv_rhs => rw [ p.taggedCfg_erase]
rfl
def nodeIdOf (s : p.State) : Option NodeId :=
def nodeIdOf (s : p.State) : Option p.NodeId :=
p.taggedCfg.nodeLabel (Fin.cast p.taggedCfg_size.symm s)
def stateOfNodeId (id : NodeId) : Option p.State :=
def stateOfNodeId (id : p.NodeId) : Option p.State :=
(p.taggedCfg.stateOf id).map (Fin.cast p.taggedCfg_size)
theorem cfg_nodes_eq (s : p.State) :
p.cfg.nodes s = Option.map BasicStmt.Tagged.erase
(p.taggedCfg.nodes (Fin.cast p.taggedCfg_size.symm s)) := by
have key : (g : Graph) (hsz : p.taggedCfg.size = g.size),
(Option.map BasicStmt.Tagged.erase) <$> p.taggedCfg = g
i : Fin g.size,
g.nodes i = Option.map BasicStmt.Tagged.erase
(p.taggedCfg.nodes (Fin.cast hsz.symm i)) := by
intro g hsz hg i
subst hg
rfl
exact key p.cfg p.taggedCfg_size p.taggedCfg_erase s
theorem nodeIdOf_isSome_of_code {s : p.State} {bs : BasicStmt}
(h : p.code s = some bs) : (p.nodeIdOf s).isSome = true := by
have hc : Option.map BasicStmt.Tagged.erase
(p.taggedCfg.nodes (Fin.cast p.taggedCfg_size.symm s)) = some bs := by
rw [ p.cfg_nodes_eq s]; exact h
unfold Program.nodeIdOf GGraph.nodeLabel
cases hcase : p.taggedCfg.nodes (Fin.cast p.taggedCfg_size.symm s) with
| none => rw [hcase] at hc; simp at hc
| some tbs => simp
def nodeIdOfNonempty (s : p.State) {bs : BasicStmt} (h : p.code s = some bs) : p.NodeId :=
(p.nodeIdOf s).get (p.nodeIdOf_isSome_of_code h)
end Program
end Spa

View File

@@ -2,7 +2,7 @@ import Mathlib.Data.Nat.Notation
namespace Spa
structure NodeId where
structure RawId where
post :
deriving DecidableEq, Repr

View File

@@ -3,14 +3,14 @@ import Spa.Language.Graphs
namespace Spa
inductive Trace (g : Graph) : g.Index g.Index Env Env Prop
inductive Trace (g : Graph) : g.Index g.Index Env Env Type
| single {ρ₁ ρ₂ : Env} {idx : g.Index} :
EvalBasicStmts ρ₁ (g.nodes idx) ρ₂ Trace g idx idx ρ₁ ρ₂
EvalBasicStmtOpt ρ₁ (g.nodes idx) ρ₂ Trace g idx idx ρ₁ ρ₂
| edge {ρ₁ ρ₂ ρ₃ : Env} {idx₁ idx₂ idx₃ : g.Index} :
EvalBasicStmts ρ₁ (g.nodes idx₁) ρ₂ (idx₁, idx₂) g.edges
EvalBasicStmtOpt ρ₁ (g.nodes idx₁) ρ₂ (idx₁, idx₂) g.edges
Trace g idx₂ idx₃ ρ₂ ρ₃ Trace g idx₁ idx₃ ρ₁ ρ₃
lemma Trace.concat {g : Graph} {idx₁ idx₂ idx₃ idx₄ : g.Index}
noncomputable def Trace.concat {g : Graph} {idx₁ idx₂ idx₃ idx₄ : g.Index}
{ρ₁ ρ₂ ρ₃ : Env} (tr₁ : Trace g idx₁ idx₂ ρ₁ ρ₂)
(he : (idx₂, idx₃) g.edges) (tr₂ : Trace g idx₃ idx₄ ρ₂ ρ₃) :
Trace g idx₁ idx₄ ρ₁ ρ₃ := by
@@ -18,7 +18,7 @@ lemma Trace.concat {g : Graph} {idx₁ idx₂ idx₃ idx₄ : g.Index}
| single hbs => exact Trace.edge hbs he tr₂
| edge hbs he' _ ih => exact Trace.edge hbs he' (ih he tr₂)
inductive EndToEndTrace (g : Graph) (ρ₁ ρ₂ : Env) : Prop
inductive EndToEndTrace (g : Graph) (ρ₁ ρ₂ : Env) : Type
| intro (idx₁ : g.Index) (idx₁_mem : idx₁ g.inputs)
(idx₂ : g.Index) (idx₂_mem : idx₂ g.outputs)
(trace : Trace g idx₁ idx₂ ρ₁ ρ₂) : EndToEndTrace g ρ₁ ρ₂