Files
agda-spa/lean/Spa/Transformation/Licm.lean

103 lines
4.5 KiB
Lean4
Raw Normal View History

import Spa.Analysis.Reaching
import Spa.Language.Tagged.Graphs
/-!
# Finding loop-invariant assignments (LICM groundwork)
This wires the **reaching-definitions** analysis (`Spa/Analysis/Reaching.lean`)
to the **tagged AST** to *find* — not yet move — assignments inside a `while`
loop whose right-hand side depends only on definitions made *outside* the loop.
These are the candidates a later LICM pass could hoist.
The pipeline, for each assignment immediately enclosed by a loop:
1. locate its CFG state via the tagged-graph bridge (`Program.stateOfNodeId`);
2. read the reaching definitions at the assignment's *entry*
(`joinForKey s result` — the join over predecessors, i.e. before the
assignment itself runs);
3. union the definition sets of the RHS variables;
4. read off the definition sites as AST node ids (`defSites`; the reaching sets
are keyed by AST node id) and check that none of them lies inside the loop
body (structural `subtreeIds` membership).
If every reaching definition of every RHS variable lies outside the loop, the
assignment is reported as loop-invariant. This is the first-order check ("all
reaching definitions outside the loop"); transitive/iterated invariance and the
actual hoisting are out of scope here.
-/
namespace Spa
namespace LicmTransformation
open Forward
/-- A hoisting candidate: an assignment sitting directly inside a `while` loop,
bundled with everything `isInvariant` needs to test it against that loop. -/
structure Candidate (prog : Program) where
  /-- Tag of the immediately enclosing `whileLoop`, reported alongside the
  assignment. -/
  loopId : prog.NodeId
  /-- Ids of every node inside that loop's body; a definition site in this list
  counts as "inside the loop". -/
  bodyIds : List prog.NodeId
  /-- Tag of the assignment's `BasicStmt`, which is also the label of its CFG
  node. -/
  assignId : prog.NodeId
  /-- Names of the variables read by the assignment's right-hand side. -/
  rhsVars : List String
/-- Collect every assignment together with its *immediately enclosing* loop.

`enc` carries the current loop's tag and body id-set, or `none` outside any
loop. Assignments reached with `enc = none` are skipped — only in-loop
assignments are candidates. Entering a `whileLoop` replaces `enc` with that
loop, so an assignment inside nested loops is paired with the innermost one
only. Each candidate's `rhsVars` lists the RHS variables in ascending order. -/
def collectCandidates (prog : Program) (enc : Option (prog.NodeId × List prog.NodeId)) :
    Stmt.Tagged prog.NodeId → List (Candidate prog)
  | .basic _ bs =>
    match bs, enc with
    | .assign t _ e, some (loopId, bodyIds) =>
      [{ loopId := loopId, bodyIds := bodyIds, assignId := t,
         rhsVars := e.erase.vars.sort (· ≤ ·) }]
    -- Non-assignment statements, and any statement outside a loop, yield nothing.
    | _, _ => []
  | .andThen _ a b => collectCandidates prog enc a ++ collectCandidates prog enc b
  | .ifElse _ _ a b => collectCandidates prog enc a ++ collectCandidates prog enc b
  | .whileLoop loopT _ body =>
    collectCandidates prog (some (loopT, body.subtreeIds)) body
/-- Read the definition set recorded for variable `k` in `vs`, falling back to
the lattice bottom `⊥` when `k` has no entry. -/
def lookupDef (prog : Program) (vs : VariableValues (DefSet prog) prog)
    (k : String) : DefSet prog :=
  if h : FiniteMap.MemKey k vs then (FiniteMap.locate h).1 else ⊥
/-- The definition sites recorded in `d`: every node id `nid` of `prog` for which
`d nid` is `true`, in increasing id order. Since the lattice is keyed by AST
node id, the sites are read off directly. -/
def defSites (prog : Program) (d : DefSet prog) : List prog.NodeId :=
  List.filter (fun nid => d nid) (List.finRange prog.size)
/-- Is the candidate assignment loop-invariant, i.e. do all reaching
definitions of its RHS variables lie outside the loop body?

Takes the reaching definitions at the assignment's entry (the join over its CFG
predecessors), joins the definition sets of every RHS variable, and checks that
none of the resulting sites is in `c.bodyIds`. Reaching sets are keyed by AST
node id, so they are compared with the loop-body ids directly. An assignment
with no CFG state is conservatively reported as not invariant. -/
def isInvariant (prog : Program) (c : Candidate prog) : Bool :=
  match prog.stateOfNodeId c.assignId with
  | none => false
  | some s =>
    let entry := joinForKey s (result (DefSet prog) prog)
    -- Fold from `⊥` so that an RHS reading no variables contributes no sites.
    let combined : DefSet prog :=
      c.rhsVars.foldl (fun acc k => acc ⊔ lookupDef prog entry k) ⊥
    (defSites prog combined).all (fun nid => !decide (nid ∈ c.bodyIds))
/-- The loop-invariant assignments of `prog`, reported as `(loopId, assignId)`
pairs in the order the candidates are collected from the tagged AST. -/
def licmCandidates (prog : Program) : List (prog.NodeId × prog.NodeId) :=
  let candidates := collectCandidates prog none prog.taggedFin
  (candidates.filter (isInvariant prog)).map (fun c => (c.loopId, c.assignId))
/-- Render `licmCandidates prog` as a human-readable report. The report is a
header followed by one line per loop-invariant assignment, or a fixed message
when there are none. -/
def output (prog : Program) : String :=
  let cands := licmCandidates prog
  if cands.isEmpty then
    "no loop-invariant assignments found"
  else
    let lines := cands.map (fun (loopId, assignId) =>
      s!" loop #{loopId.val}: assignment #{assignId.val}")
    "loop-invariant assignments (loop ↦ assignment):\n" ++ String.intercalate "\n" lines
end LicmTransformation
end Spa