105 lines
4.3 KiB
Lean4
105 lines
4.3 KiB
Lean4
import Spa.Analysis.Reaching
|
||
import Spa.Language.Tagged.Graphs
|
||
|
||
/-!
|
||
# Finding loop-invariant assignments (LICM groundwork)
|
||
|
||
This wires the **reaching-definitions** analysis (`Spa/Analysis/Reaching.lean`)
|
||
to the **tagged AST** to *find* — not yet move — assignments inside a `while`
|
||
loop whose right-hand side depends only on definitions made *outside* the loop.
|
||
These are the candidates a later LICM pass could hoist.
|
||
|
||
The pipeline, for each assignment immediately enclosed by a loop:
|
||
|
||
1. locate its CFG state via the tagged-graph bridge (`Program.stateOfNodeId`);
|
||
2. read the reaching definitions at the assignment's *entry*
|
||
(`joinForKey s result` — the join over predecessors, i.e. before the
|
||
assignment itself runs);
|
||
3. union the definition sets of the RHS variables;
|
||
4. map each definition site back to its `NodeId` (`Program.nodeIdOf`) and check
|
||
it is **not** inside the loop body (structural `subtreeIds` membership).
|
||
|
||
If every reaching definition of every RHS variable lies outside the loop, the
|
||
assignment is reported as loop-invariant. This is the first-order check ("all
|
||
reaching definitions outside the loop"); transitive/iterated invariance and the
|
||
actual hoisting are out of scope here.
|
||
-/
|
||
|
||
namespace Spa
|
||
|
||
namespace LicmTransformation
|
||
|
||
open Forward
|
||
|
||
/-- One in-loop assignment, bundled with the data `isInvariant` needs in order
to judge it against its innermost enclosing loop. -/
structure Candidate where
  /-- Tag of the enclosing `whileLoop`; carried along for reporting. -/
  loopId : NodeId
  /-- All `NodeId`s occurring in the loop body. Membership in this list is the
  "is inside the loop" test. -/
  bodyIds : List NodeId
  /-- Tag of the assignment `BasicStmt`, i.e. the label of its CFG node. -/
  assignId : NodeId
  /-- Variables read by the assignment's right-hand side. -/
  rhsVars : List String
|
||
|
||
/-- Walk a tagged statement and gather one `Candidate` per assignment that sits
inside a loop, pairing it with its *innermost* enclosing loop.

`enc` is `none` while no loop encloses the current statement. Inside a loop it
holds that loop's tag together with the `subtreeIds` of its body. Entering a
`whileLoop` replaces `enc`, so an inner loop shadows the outer one. Assignments
met while `enc` is `none`, and basic statements that are not assignments,
contribute nothing. -/
def collectCandidates (enc : Option (NodeId × List NodeId)) :
    Stmt.Tagged NodeId → List Candidate
  | .whileLoop tag _ body =>
    -- From here on, `body`'s assignments belong to this loop.
    collectCandidates (some (tag, body.subtreeIds)) body
  | .ifElse _ _ left right =>
    collectCandidates enc left ++ collectCandidates enc right
  | .andThen _ left right =>
    collectCandidates enc left ++ collectCandidates enc right
  | .basic _ stmt =>
    match enc, stmt with
    | some (lid, ids), .assign tag _ rhs =>
      [{ assignId := tag, loopId := lid, bodyIds := ids,
         rhsVars := rhs.erase.vars.sort (· ≤ ·) }]
    | _, _ => []
|
||
|
||
/-- Look up the definition set stored for variable `k` in `vs`.

Returns the stored value when `k` is a key of `vs`, and `⊥` otherwise. -/
def lookupDef (prog : Program) (vs : VariableValues (DefSet prog) prog)
    (k : String) : DefSet prog :=
  if bound : FiniteMap.MemKey k vs then
    -- `bound` is the membership proof that `FiniteMap.locate` requires.
    (FiniteMap.locate bound).1
  else
    ⊥
|
||
|
||
/-- List the states of `prog` that `d` marks as definition sites: the states
that are keys of `d` and whose stored value is `true`. States that are not keys
of `d` are excluded. -/
def defSites (prog : Program) (d : DefSet prog) : List prog.State :=
  let marked : prog.State → Bool := fun st =>
    if present : FiniteMap.MemKey st d then (FiniteMap.locate present).1 else false
  prog.states.filter marked
|
||
|
||
/-- Decide whether candidate `c` is loop-invariant, i.e. whether every reaching
definition of every variable in `c.rhsVars` lies outside the loop body.

The answer is `false` when the assignment's tag has no CFG state, and also when
any definition site cannot be mapped back to a `NodeId`. A candidate whose
`rhsVars` is empty folds to `⊥` and is then judged on `defSites` of `⊥`. -/
def isInvariant (prog : Program) (c : Candidate) : Bool :=
  match prog.stateOfNodeId c.assignId with
  | none => false
  | some s =>
    -- A site is acceptable only if it has a `NodeId` and that id is not in the
    -- loop body.
    let definedOutside : prog.State → Bool := fun site =>
      match prog.nodeIdOf site with
      | none => false
      | some nid => !decide (nid ∈ c.bodyIds)
    -- Analysis value on entry to the assignment's state.
    let atEntry := joinForKey s (result (DefSet prog) prog)
    -- Join of the definition sets of all RHS variables.
    let reaching : DefSet prog :=
      c.rhsVars.foldl (fun acc v => acc ⊔ lookupDef prog atEntry v) ⊥
    (defSites prog reaching).all definedOutside
|
||
|
||
/-- All loop-invariant assignments of `prog`, each reported as a
`(loopId, assignId)` pair, in the order `collectCandidates` produces them. -/
def licmCandidates (prog : Program) : List (NodeId × NodeId) :=
  let inLoop := collectCandidates none prog.tagged
  let invariant := inLoop.filter (isInvariant prog)
  invariant.map (fun c => (c.loopId, c.assignId))
|
||
|
||
/-- Render the result of `licmCandidates prog` as text: a fixed message when
nothing was found, otherwise a header line followed by one line per
`(loop, assignment)` pair. -/
def output (prog : Program) : String :=
  let found := licmCandidates prog
  if found.isEmpty then
    "no loop-invariant assignments found"
  else
    let rows := found.map (fun (loopId, assignId) =>
      s!" loop #{loopId.post}: assignment #{assignId.post}")
    "loop-invariant assignments (loop ↦ assignment):\n" ++
      String.intercalate "\n" rows
|
||
|
||
end LicmTransformation
|
||
|
||
end Spa
|