105 lines
4.3 KiB
Lean4
105 lines
4.3 KiB
Lean4
|
|
import Spa.Analysis.Reaching
|
|||
|
|
import Spa.Language.Tagged.Graphs
|
|||
|
|
|
|||
|
|
/-!
# Finding loop-invariant assignments (LICM groundwork)

This wires the **reaching-definitions** analysis (`Spa/Analysis/Reaching.lean`)
to the **tagged AST** to *find* — not yet move — assignments inside a `while`
loop whose right-hand side depends only on definitions made *outside* the loop.
These are the candidates a later LICM pass could hoist.

The pipeline, for each assignment inside a loop (including one nested in an
`if` within the loop), checked against its innermost enclosing loop:

1. locate its CFG state via the tagged-graph bridge (`Program.stateOfNodeId`);
2. read the reaching definitions at the assignment's *entry*
   (`joinForKey s result` — the join over predecessors, i.e. before the
   assignment itself runs);
3. union the definition sets of the RHS variables;
4. map each definition site back to its `NodeId` (`Program.nodeIdOf`) and check
   it is **not** inside the loop body (structural `subtreeIds` membership).

If every reaching definition of every RHS variable lies outside the loop, the
assignment is reported as loop-invariant. The check is conservative: if the
assignment's CFG state cannot be located, or some definition site has no
`NodeId`, the assignment is not reported. This is the first-order check ("all
reaching definitions outside the loop"); transitive/iterated invariance and the
actual hoisting are out of scope here.
-/
|
|||
|
|
|
|||
|
|
namespace Spa
|
|||
|
|
|
|||
|
|
namespace LicmTransformation
|
|||
|
|
|
|||
|
|
open Forward
|
|||
|
|
|
|||
|
|
/-- An assignment found inside a loop, paired with the data needed to test its
invariance against its innermost enclosing loop (built by `collectCandidates`). -/
structure Candidate where
  /-- Tag of the innermost enclosing `whileLoop` (used for reporting). -/
  loopId : NodeId
  /-- Every `NodeId` in that loop's body (the body's `subtreeIds`); a reaching
  definition whose site maps into this list counts as "inside the loop". -/
  bodyIds : List NodeId
  /-- Tag of the assignment `BasicStmt` — the label of its CFG node. -/
  assignId : NodeId
  /-- The variables read by the assignment's RHS (sorted by `collectCandidates`). -/
  rhsVars : List String
|
|||
|
|
|
|||
|
|
/-- Collect every assignment that sits inside some `while` loop, paired with its
*innermost* enclosing loop.

The first argument is the current loop context: `some (loopTag, bodyIds)` while
descending through a loop body, `none` outside every loop. Assignments reached
with `none` are dropped — only in-loop assignments become candidates. Entering a
nested loop replaces the context, so inner assignments are paired only with the
innermost loop. -/
def collectCandidates :
    Option (NodeId × List NodeId) → Stmt.Tagged NodeId → List Candidate
  -- An assignment inside a loop: record it against the current loop.
  | some (loopId, bodyIds), .basic _ (.assign t _ e) =>
    [{ loopId := loopId, bodyIds := bodyIds, assignId := t,
       rhsVars := e.erase.vars.sort (· ≤ ·) }]
  -- Any other basic statement, or an assignment outside every loop.
  | _, .basic _ _ => []
  -- Sequencing and branching keep the current loop context.
  | ctx, .andThen _ first second =>
    collectCandidates ctx first ++ collectCandidates ctx second
  | ctx, .ifElse _ _ thenBranch elseBranch =>
    collectCandidates ctx thenBranch ++ collectCandidates ctx elseBranch
  -- A loop becomes the new innermost context for its body.
  | _, .whileLoop loopTag _ body =>
    collectCandidates (some (loopTag, body.subtreeIds)) body
|
|||
|
|
|
|||
|
|
/-- The definition set recorded for variable `k` in `vs`. Keys absent from `vs`
default to `⊥`. -/
def lookupDef (prog : Program) (vs : VariableValues (DefSet prog) prog)
    (k : String) : DefSet prog :=
  if hk : FiniteMap.MemKey k vs then
    -- `.1` projects the stored value out of `locate`'s result.
    (FiniteMap.locate hk).1
  else
    ⊥
|
|||
|
|
|
|||
|
|
/-- The CFG states that `d` marks as definition sites, i.e. the states mapped to
`true`. States with no entry in `d` are treated as unmarked. The result keeps
the order of `prog.states`. -/
def defSites (prog : Program) (d : DefSet prog) : List prog.State :=
  let marked (st : prog.State) : Bool :=
    if hst : FiniteMap.MemKey st d then (FiniteMap.locate hst).1 else false
  prog.states.filter marked
|
|||
|
|
|
|||
|
|
/-- Core of the loop-invariance check. `entryAt s` must give the reaching
definitions at the *entry* of CFG state `s` (the join over its predecessors).

Returns `true` iff every definition site reaching any RHS variable of `c` maps
back to a `NodeId` outside `c.bodyIds`. If no definition site reaches the RHS
variables, the check holds vacuously. The check is conservative:
* it returns `false` when the assignment's CFG state cannot be found;
* it returns `false` when some definition site has no `NodeId`. -/
def isInvariantAt (prog : Program)
    (entryAt : prog.State → VariableValues (DefSet prog) prog)
    (c : Candidate) : Bool :=
  match prog.stateOfNodeId c.assignId with
  | none => false
  | some s =>
    let entry := entryAt s
    -- Union of the definition sets of every variable the RHS reads.
    let combined : DefSet prog :=
      c.rhsVars.foldl (fun acc k => acc ⊔ lookupDef prog entry k) ⊥
    (defSites prog combined).all fun site =>
      match prog.nodeIdOf site with
      | some nid => !decide (nid ∈ c.bodyIds)
      | none => false

/-- Is the candidate assignment loop-invariant? That is, do all reaching
definitions of its RHS variables lie outside the loop body?

This runs the reaching-definitions analysis itself. When checking many
candidates, use `isInvariantAt` with one shared, precomputed analysis result,
as `licmCandidates` does. -/
def isInvariant (prog : Program) (c : Candidate) : Bool :=
  isInvariantAt prog (fun s => joinForKey s (result (DefSet prog) prog)) c

/-- The loop-invariant assignments of `prog`, as `(loopId, assignId)` pairs, in
the order `collectCandidates` produces them.

The reaching-definitions fixpoint is computed once here and shared by all
candidates. Calling `isInvariant` per candidate would recompute it each time. -/
def licmCandidates (prog : Program) : List (NodeId × NodeId) :=
  let res := result (DefSet prog) prog
  let entryAt : prog.State → VariableValues (DefSet prog) prog :=
    fun s => joinForKey s res
  (collectCandidates none prog.tagged).filterMap fun c =>
    if isInvariantAt prog entryAt c then some (c.loopId, c.assignId) else none
|
|||
|
|
|
|||
|
|
/-- Render the loop-invariant assignments of `prog` as a report.

The output is a fixed message when none are found. Otherwise it is a header line
followed by one `loop #… : assignment #…` line per `(loopId, assignId)` pair,
using each tag's `post` field. -/
def output (prog : Program) : String :=
  let cands := licmCandidates prog
  if cands.isEmpty then
    "no loop-invariant assignments found"
  else
    let lines := cands.map fun (loopId, assignId) =>
      s!" loop #{loopId.post}: assignment #{assignId.post}"
    "loop-invariant assignments (loop ↦ assignment):\n" ++ String.intercalate "\n" lines
|
|||
|
|
|
|||
|
|
end LicmTransformation
|
|||
|
|
|
|||
|
|
end Spa
|