This requires a few pieces: * Make node tags use `Fin n` instead of natural numbers. This makes it possible to build a finite lattice over AST nodes, and also ensure automatic, total indexing from CFG nodes into the AST that created them. For this, use the elaborator to derive the ordering statements etc. where possible. * Adjust the forward framework to enable proofs that don't just state correctness on the environment, but also on an arbitrary additional state accumulated from traversing the trace. * State the reaching definition analysis's correctness in terms of this new framework. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
103 lines
4.5 KiB
Lean4
103 lines
4.5 KiB
Lean4
import Spa.Analysis.Reaching
|
||
import Spa.Language.Tagged.Graphs
|
||
|
||
/-!
|
||
# Finding loop-invariant assignments (LICM groundwork)
|
||
|
||
This wires the **reaching-definitions** analysis (`Spa/Analysis/Reaching.lean`)
|
||
to the **tagged AST** to *find* — not yet move — assignments inside a `while`
|
||
loop whose right-hand side depends only on definitions made *outside* the loop.
|
||
These are the candidates a later LICM pass could hoist.
|
||
|
||
The pipeline, for each assignment immediately enclosed by a loop:
|
||
|
||
1. locate its CFG state via the tagged-graph bridge (`Program.stateOfNodeId`);
|
||
2. read the reaching definitions at the assignment's *entry*
|
||
(`joinForKey s result` — the join over predecessors, i.e. before the
|
||
assignment itself runs);
|
||
3. union the definition sets of the RHS variables;
|
||
4. read the definition sites off that set as AST node ids (`defSites`) and check
   that each is **not** inside the loop body (structural `subtreeIds` membership).
|
||
|
||
If every reaching definition of every RHS variable lies outside the loop, the
|
||
assignment is reported as loop-invariant. This is the first-order check ("all
|
||
reaching definitions outside the loop"); transitive/iterated invariance and the
|
||
actual hoisting are out of scope here.
|
||
-/
|
||
|
||
namespace Spa
|
||
|
||
namespace LicmTransformation
|
||
|
||
open Forward
|
||
|
||
/-- One in-loop assignment together with everything `isInvariant` needs in
order to test it against its immediately enclosing loop. -/
structure Candidate (prog : Program) where
  /-- Tag of the enclosing `whileLoop`; only used when reporting results. -/
  loopId : prog.NodeId
  /-- The node ids of the loop body (its `subtreeIds`); a definition site found
  in this list counts as "inside the loop". -/
  bodyIds : List prog.NodeId
  /-- Tag carried by the assignment itself; `isInvariant` uses it to look up the
  corresponding CFG state. -/
  assignId : prog.NodeId
  /-- Names of the variables occurring in the assignment's right-hand side. -/
  rhsVars : List String
|
||
|
||
/-- Gather the assignments of a tagged statement that sit inside a loop, each
paired with its *innermost* enclosing loop.

`enc` is the loop context inherited from the caller: `some (tag, ids)` holds the
tag of the nearest enclosing `whileLoop` together with the `subtreeIds` of its
body, while `none` means that no loop encloses the statement. An assignment met
with `enc = none` produces no candidate, and neither does any other basic
statement. Entering a `whileLoop` replaces the context, so an inner loop shadows
the outer one for everything in its body. -/
def collectCandidates (prog : Program) (enc : Option (prog.NodeId × List prog.NodeId)) :
    Stmt.Tagged prog.NodeId → List (Candidate prog)
  | .basic _ stmt =>
    match enc, stmt with
    | some (loopTag, inside), .assign tag _ rhs =>
      [{ loopId := loopTag, bodyIds := inside, assignId := tag,
         rhsVars := rhs.erase.vars.sort (· ≤ ·) }]
    | _, _ => []
  | .andThen _ first second =>
    collectCandidates prog enc first ++ collectCandidates prog enc second
  | .ifElse _ _ thenBr elseBr =>
    collectCandidates prog enc thenBr ++ collectCandidates prog enc elseBr
  | .whileLoop loopTag _ body =>
    -- The loop's own tag and body ids become the context for its body.
    collectCandidates prog (some (loopTag, body.subtreeIds)) body
|
||
|
||
/-- Fetch the definition set stored under variable `k` in `vs`.

When `k` is a key of `vs` (`FiniteMap.MemKey k vs`), the result is the first
component of `FiniteMap.locate` applied to that membership proof; otherwise the
bottom element `⊥` is returned. -/
def lookupDef (prog : Program) (vs : VariableValues (DefSet prog) prog)
    (k : String) : DefSet prog :=
  if hk : FiniteMap.MemKey k vs then
    (FiniteMap.locate hk).1
  else
    ⊥
|
||
|
||
/-- Enumerate the definition sites recorded in `d`: the node ids that `d` sends
to `true`, listed in the order of `List.finRange prog.size`. -/
def defSites (prog : Program) (d : DefSet prog) : List prog.NodeId :=
  List.finRange prog.size |>.filter (d ·)
|
||
|
||
/-- Decide whether candidate `c` is loop-invariant with respect to its enclosing
loop.

The assignment's CFG state is looked up from `c.assignId`; if there is none the
answer is `false`. Otherwise the analysis result is read at that state with
`joinForKey`, the definition sets of all variables in `c.rhsVars` are joined
(starting from `⊥`), and the candidate is accepted exactly when none of the
resulting definition sites occurs in `c.bodyIds`. Both the definition sites and
`c.bodyIds` are `prog.NodeId`s, so they are compared directly. -/
def isInvariant (prog : Program) (c : Candidate prog) : Bool :=
  if let some st := prog.stateOfNodeId c.assignId then
    let atEntry := joinForKey st (result (DefSet prog) prog)
    -- Join of the reaching-definition sets of every variable the RHS reads.
    let reaching : DefSet prog :=
      c.rhsVars.foldl (fun soFar v => soFar ⊔ lookupDef prog atEntry v) ⊥
    (defSites prog reaching).all (fun site => ! decide (site ∈ c.bodyIds))
  else
    false
|
||
|
||
/-- All assignments of `prog` that pass `isInvariant`, each reported as a
`(loopId, assignId)` pair in the order `collectCandidates` produced them. -/
def licmCandidates (prog : Program) : List (prog.NodeId × prog.NodeId) :=
  -- Start with no enclosing loop at the root of the tagged program.
  let inLoop := collectCandidates prog none prog.taggedFin
  (inLoop.filter (isInvariant prog)).map fun cand => (cand.loopId, cand.assignId)
|
||
|
||
/-- Render the result of `licmCandidates prog` as text: a fixed message when
nothing was found, otherwise a header line followed by one line per
`(loop, assignment)` pair showing the numeric value of each tag. -/
def output (prog : Program) : String :=
  let found := licmCandidates prog
  if found.isEmpty then
    "no loop-invariant assignments found"
  else
    let rows := found.map fun (loopId, assignId) =>
      s!" loop #{loopId.val}: assignment #{assignId.val}"
    "loop-invariant assignments (loop ↦ assignment):\n" ++ String.intercalate "\n" rows
|
||
|
||
end LicmTransformation
|
||
|
||
end Spa
|