-- agda-spa/lean/Spa/Transformation/Licm.lean (Lean 4)
import Spa.Analysis.Reaching
import Spa.Language.Tagged.Graphs
/-!
# Finding loop-invariant assignments (LICM groundwork)
This wires the **reaching-definitions** analysis (`Spa/Analysis/Reaching.lean`)
to the **tagged AST** to *find* — not yet move — assignments inside a `while`
loop whose right-hand side depends only on definitions made *outside* the loop.
These are the candidates a later LICM pass could hoist.
The pipeline, for each assignment immediately enclosed by a loop:
1. locate its CFG state via the tagged-graph bridge (`Program.stateOfNodeId`);
2. read the reaching definitions at the assignment's *entry*
(`joinForKey s result` — the join over predecessors, i.e. before the
assignment itself runs);
3. union the definition sets of the RHS variables;
4. map each definition site back to its `NodeId` (`Program.nodeIdOf`) and check
it is **not** inside the loop body (structural `subtreeIds` membership).
If every reaching definition of every RHS variable lies outside the loop, the
assignment is reported as loop-invariant. This is the first-order check ("all
reaching definitions outside the loop"); transitive/iterated invariance and the
actual hoisting are out of scope here.
-/
namespace Spa
namespace LicmTransformation
open Forward
/-- An assignment found inside a loop, paired with the data needed to test its
invariance against that (immediately enclosing) loop.

Values are produced by `collectCandidates` (only for assignments that have an
enclosing loop) and tested by `isInvariant`. -/
structure Candidate where
  /-- The enclosing `whileLoop`'s tag (for reporting). -/
  loopId : NodeId
  /-- Every `NodeId` inside the loop body (the "is-child-of-loop" set), taken from
  `body.subtreeIds` of the enclosing loop's body. A reaching definition whose
  site maps to one of these ids counts as "inside the loop". -/
  bodyIds : List NodeId
  /-- The assignment `BasicStmt`'s tag — what labels its CFG node; used to find
  the assignment's CFG state via `Program.stateOfNodeId`. -/
  assignId : NodeId
  /-- The variables read by the assignment's RHS, sorted by name. The list may
  contain duplicates, which are harmless for the union taken in `isInvariant`. -/
  rhsVars : List String
/-- Collect every assignment together with its *immediately enclosing* loop.

The first argument carries the current loop's tag and body id-set, or `none`
outside any loop (in which case assignments are skipped — only in-loop
assignments are candidates).

Entering a `whileLoop` replaces the context with that loop's own tag and body
ids. An assignment inside nested loops is therefore attributed only to the
innermost one. Conditionals and sequencing pass the context through unchanged.
Candidates are returned in left-to-right source order. -/
def collectCandidates :
    Option (NodeId × List NodeId) → Stmt.Tagged NodeId → List Candidate
  | enc, .basic _ bs =>
    match bs, enc with
    | .assign t _ e, some (loopId, bodyIds) =>
      -- `erase` drops the tags so the plain expression's variables can be read.
      [{ loopId := loopId, bodyIds := bodyIds, assignId := t,
         rhsVars := e.erase.vars.sort (· ≤ ·) }]
    -- Non-assignment statements, and assignments outside any loop.
    | _, _ => []
  | enc, .andThen _ a b => collectCandidates enc a ++ collectCandidates enc b
  | enc, .ifElse _ _ a b => collectCandidates enc a ++ collectCandidates enc b
  | _, .whileLoop loopT _ body =>
    collectCandidates (some (loopT, body.subtreeIds)) body
/-- Read the definition set assigned to variable `k` in `vs`.

Returns the lattice bottom `⊥` when `k` has no entry in `vs`. That makes a
missing variable the identity for the `⊔`-union taken in `isInvariant`. -/
def lookupDef (prog : Program) (vs : VariableValues (DefSet prog) prog)
    (k : String) : DefSet prog :=
  if h : FiniteMap.MemKey k vs then (FiniteMap.locate h).1 else ⊥
/-- List the CFG states that `d` marks as definition sites.

A state is a site when its entry in `d` is `true`. States with no entry in `d`
are treated as unmarked. The result follows the order of `prog.states`. -/
def defSites (prog : Program) (d : DefSet prog) : List prog.State :=
  let isMarked (s : prog.State) : Bool :=
    if h : FiniteMap.MemKey s d then (FiniteMap.locate h).1 else false
  prog.states.filter isMarked
/-- Is the candidate assignment loop-invariant: do all reaching definitions of
its RHS variables lie outside the loop body?

The check is conservative and returns `false` in two cases:
- the assignment has no CFG state (`stateOfNodeId` gives `none`);
- some reaching definition site cannot be mapped back to a `NodeId`.

If the RHS reads no variables, or only variables with no reaching definitions,
the union below stays `⊥`. The `all` check then holds vacuously, assuming `⊥`
marks no definition sites. -/
def isInvariant (prog : Program) (c : Candidate) : Bool :=
  match prog.stateOfNodeId c.assignId with
  | none => false
  | some s =>
    -- Reaching definitions on *entry* to the assignment: the join over its
    -- predecessors, i.e. before the assignment itself runs.
    let entry := joinForKey s (result (DefSet prog) prog)
    -- Union (starting from `⊥`) of the definition sets of every RHS variable.
    let combined : DefSet prog :=
      c.rhsVars.foldl (fun acc k => acc ⊔ lookupDef prog entry k) ⊥
    (defSites prog combined).all (fun site =>
      match prog.nodeIdOf site with
      | some nid => !decide (nid ∈ c.bodyIds)
      | none => false)
/-- Report the loop-invariant assignments of `prog` as `(loopId, assignId)`
pairs.

Pairs appear in the order `collectCandidates` yields them. Only candidates
accepted by `isInvariant` are kept. -/
def licmCandidates (prog : Program) : List (NodeId × NodeId) :=
  let invariant := (collectCandidates none prog.tagged).filter (isInvariant prog)
  invariant.map fun c => (c.loopId, c.assignId)
/-- Render the loop-invariant assignments of `prog` as a human-readable report.

The report is a header followed by one line per `(loop, assignment)` pair,
identified by each `NodeId`'s `post` field. A fixed message is returned when
nothing was found. -/
def output (prog : Program) : String :=
  let found := licmCandidates prog
  if found.isEmpty then
    "no loop-invariant assignments found"
  else
    let rendered := found.map fun (loopId, assignId) =>
      s!" loop #{loopId.post}: assignment #{assignId.post}"
    "loop-invariant assignments (loop ↦ assignment):\n" ++
      String.intercalate "\n" rendered
end LicmTransformation
end Spa