/-
Files
agda-spa/lean/Spa/Language/Graphs.lean
2026-06-25 18:55:09 -05:00

248 lines
10 KiB
Lean4
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
-/
import Spa.Language.Base
import Mathlib.Data.Fin.Tuple.Basic
import Mathlib.Data.List.ProdSigma
import Mathlib.Data.List.FinRange
/-!
# Algebraic Control Flow Graphs
This file defines control flow graphs and operations to naturally compose them,
making it possible to inductively convert a program in the object language
(see `Spa.Stmt` in `Spa/Language/Base.lean`) into its corresponding graph.
Graphs are, in general, parameterized by their "payload" (the per-node data); see `GGraph`.
This is useful because other operations, such as finding the CFG node corresponding
to an AST node, are performed by embellishing a graph's basic blocks with their AST
identifiers.
The operations are deliberately a little bit sloppy here, creating empty / statement-less
CFG nodes. Additionally, the current CFG construction algorithm doesn't group
consecutive statements in a single notional basic block into one node.
This makes graph construction much easier to define, and might save us the
trouble of (when trying to find the CFG node for an AST node) doing
indexing into a list.
-/
/-- Bump the upper bound of a list of `Fin`s without changing their value:
every `i : Fin n` is re-typed as an element of `Fin (n + m)` via `Fin.castAdd m`. -/
def List.finCastAdd {n : ℕ} (l : List (Fin n)) (m : ℕ) : List (Fin (n + m)) :=
  l.map (Fin.castAdd m)
/-- Bump the upper bound of a list of `Fin`s by adding the amount to their value:
every `i : Fin m` becomes `n + i` in `Fin (n + m)` via `Fin.natAdd n`. -/
def List.finNatAdd {m : ℕ} (l : List (Fin m)) (n : ℕ) : List (Fin (n + m)) :=
  l.map (Fin.natAdd n)
/-- Bump the upper bound of a list of `Fin` pairs without changing their value:
both components of every pair are re-typed into `Fin (n + m)` via `Fin.castAdd m`. -/
def List.finCastAddProd {n : ℕ} (l : List (Fin n × Fin n)) (m : ℕ) :
    List (Fin (n + m) × Fin (n + m)) :=
  l.map (fun e => (e.1.castAdd m, e.2.castAdd m))
/-- Bump the upper bound of a list of `Fin` pairs by adding the amount to their value:
both components of every pair are shifted by `n` into `Fin (n + m)` via `Fin.natAdd n`. -/
def List.finNatAddProd {m : ℕ} (l : List (Fin m × Fin m)) (n : ℕ) :
    List (Fin (n + m) × Fin (n + m)) :=
  l.map (fun e => (e.1.natAdd n, e.2.natAdd n))
namespace Spa
/-- Graph with general (`α`-labeled) nodes. By using a tuple `Fin size → α`
and writing `edges` over the `Fin size`, guarantees all edges are between real nodes.

To make graph composition via operations not force a
[`alga`](https://hackage.haskell.org/package/algebraic-graphs)-style "connect"-based
algebra, explicitly defines `inputs` and `outputs`, which are the only nodes that
get connected when graphs are sequenced. This makes the graph construction
operations more naturally fit with how CFGs are created from `Stmt`s. -/
structure GGraph (α : Type) where
  /-- Number of nodes in the graph. -/
  size : ℕ
  /-- Payload stored at each node, indexed by `Fin size`. -/
  nodes : Fin size → α
  /-- Directed edges, as `(source, target)` pairs of node indices. -/
  edges : List (Fin size × Fin size)
  /-- Nodes that receive incoming edges when another graph is sequenced before this one. -/
  inputs : List (Fin size)
  /-- Nodes that emit outgoing edges when another graph is sequenced after this one. -/
  outputs : List (Fin size)
namespace GGraph
variable {α β : Type}
/-- The type of node indices of the graph `g`, i.e. `Fin g.size`. -/
abbrev Index (g : GGraph α) : Type :=
  Fin g.size
/-- The type of edges of the graph `g`: a pair of node indices of `g`. -/
abbrev Edge (g : GGraph α) : Type :=
  g.Index × g.Index
/-- `GGraph` is a functor in its payload type: mapping `f` over a graph applies `f`
to every node's payload and leaves the size, edges, inputs and outputs unchanged. -/
instance : Functor GGraph where
  map {α β : Type} (f : α → β) (g : GGraph α) : GGraph β :=
    { size := g.size,
      nodes := f ∘ g.nodes,
      edges := g.edges,
      inputs := g.inputs,
      outputs := g.outputs }
/-- Mapping over the payloads does not change the number of nodes. -/
@[simp] lemma map_size (f : α → β) (g : GGraph α) : (f <$> g).size = g.size := rfl
/-- Mapping over the payloads does not change the edges. -/
@[simp] lemma map_edges (f : α → β) (g : GGraph α) : (f <$> g).edges = g.edges := rfl
/-- Mapping over the payloads does not change the input nodes. -/
@[simp] lemma map_inputs (f : α → β) (g : GGraph α) : (f <$> g).inputs = g.inputs := rfl
/-- Mapping over the payloads does not change the output nodes. -/
@[simp] lemma map_outputs (f : α → β) (g : GGraph α) : (f <$> g).outputs = g.outputs := rfl
/-- Overlay two graphs: the result contains the nodes and edges of both `g₁` and `g₂`
and no additional edges. Nodes of `g₁` keep their indices (re-typed with `Fin.castAdd`),
nodes of `g₂` are shifted up by `g₁.size` (with `Fin.natAdd`). The inputs and outputs
of the result are those of `g₁` followed by those of `g₂`. -/
def overlay (g₁ g₂ : GGraph α) : GGraph α :=
  { size := g₁.size + g₂.size,
    nodes := Fin.append g₁.nodes g₂.nodes,
    edges :=
      List.finCastAddProd g₁.edges g₂.size ++ List.finNatAddProd g₂.edges g₁.size,
    inputs :=
      List.finCastAdd g₁.inputs g₂.size ++ List.finNatAdd g₂.inputs g₁.size,
    outputs :=
      List.finCastAdd g₁.outputs g₂.size ++ List.finNatAdd g₂.outputs g₁.size }
-- Scoped right-associative infix notation (precedence 70) for `GGraph.overlay`.
-- NOTE(review): the operator token between the quotes is empty in this copy of the file.
-- The file header warns about ambiguous Unicode, so the symbol was presumably stripped
-- during extraction; restore it from the original source before compiling.
@[inherit_doc] scoped infixr:70 "" => GGraph.overlay
/-- Sequence two CFGs: the result contains the nodes and edges of both `g₁` and `g₂`,
__plus__ an edge from every output of `g₁` to every input of `g₂`. By the semantics
of CFGs, this encodes the fact that code first traverses the basic blocks in the left
graph, and then does the same for the right graph.

The inputs of the result are the inputs of `g₁`; its outputs are the outputs of `g₂`
(shifted up by `g₁.size`). -/
def sequence (g₁ g₂ : GGraph α) : GGraph α :=
  { size := g₁.size + g₂.size,
    nodes := Fin.append g₁.nodes g₂.nodes,
    edges :=
      List.finCastAddProd g₁.edges g₂.size ++ List.finNatAddProd g₂.edges g₁.size ++
        List.product (List.finCastAdd g₁.outputs g₂.size) (List.finNatAdd g₂.inputs g₁.size),
    inputs := List.finCastAdd g₁.inputs g₂.size,
    outputs := List.finNatAdd g₂.outputs g₁.size }
-- Scoped right-associative infix notation (precedence 70) for `GGraph.sequence`.
-- NOTE(review): the operator token between the quotes is empty in this copy of the file.
-- The file header warns about ambiguous Unicode, so the symbol was presumably stripped
-- during extraction; restore it from the original source before compiling.
@[inherit_doc] scoped infixr:70 "" => GGraph.sequence
/-- The index of the fresh entry node that `loop` adds in front of `g`:
index `0` of the `2 + g.size` nodes of the looped graph. -/
def loopIn (g : GGraph α) : Fin (2 + g.size) :=
  Fin.castAdd g.size (0 : Fin 2)
/-- The index of the fresh exit node that `loop` adds in front of `g`:
index `1` of the `2 + g.size` nodes of the looped graph. -/
def loopOut (g : GGraph α) : Fin (2 + g.size) :=
  Fin.castAdd g.size (1 : Fin 2)
/-- Creates a zero-or-more loop in the CFG. Two fresh nodes with empty payload (`[]`)
are placed in front of the nodes of `g`: `loopIn` (index `0`) and `loopOut` (index `1`).
The edges of the result are:

* the edges of `g`, shifted up by `2`;
* an edge from `loopIn` to every input of `g`;
* an edge from every output of `g` to `loopOut`;
* `loopOut → loopIn`, which lets the body run again;
* `loopIn → loopOut`, which bypasses the body entirely.

`loopIn` is the only input of the result and `loopOut` is its only output.

Notably, both the new input (`loopIn`) and new output (`loopOut`)
nodes are necessary for correctness: adding a path from inputs to a
hypothetical no-op end node encodes something like "just the first statement is executed".
Similarly, just adding a path from a hypothetical no-op beginning node
to the outputs encodes "just the last statement is executed".

This is technically sloppy (see module comment), but it's simple. -/
def loop (g : GGraph (List β)) : GGraph (List β) :=
  { size := 2 + g.size,
    nodes := Fin.append (fun _ : Fin 2 => []) g.nodes,
    edges :=
      g.edges.finNatAddProd 2 ++
        ((g.loopIn, ·) <$> g.inputs.finNatAdd 2) ++
        ((·, g.loopOut) <$> g.outputs.finNatAdd 2) ++
        [(g.loopOut, g.loopIn), (g.loopIn, g.loopOut)],
    inputs := [g.loopIn],
    outputs := [g.loopOut] }
/-- The only input of a looped graph is its `loopIn` node. -/
@[simp] lemma loop_inputs (g : GGraph (List β)) :
    (loop g).inputs = [g.loopIn] :=
  rfl
/-- The only output of a looped graph is its `loopOut` node. -/
@[simp] lemma loop_outputs (g : GGraph (List β)) :
    (loop g).outputs = [g.loopOut] :=
  rfl
/-- The graph with exactly one node carrying the payload `a` and no edges.
That node (index `0`) is both the sole input and the sole output. -/
def singleton (a : α) : GGraph α :=
  { size := 1,
    nodes := fun _ => a,
    edges := [],
    inputs := [0],
    outputs := [0] }
/-- Creates a new graph with a single input and single output node. Useful to ensure there's
a single point of entry and single point of exit.

Implemented by sequencing an empty-payload `singleton []` before `g` and another one
after it. The sequencing is written as explicit, right-nested `GGraph.sequence` calls,
which is what the right-associative infix form denotes. -/
def wrap (g : GGraph (List β)) : GGraph (List β) :=
  GGraph.sequence (singleton []) (GGraph.sequence g (singleton []))
/-- Mapping `f` over a singleton graph is the singleton graph of the mapped payload. -/
@[simp] lemma map_singleton (f : α → β) (a : α) :
    f <$> singleton a = singleton (f a) := rfl
/-- Mapping over payloads commutes with `overlay`.

The statement spells out `GGraph.overlay` instead of using its infix notation; the two
elaborate to the same term. -/
@[simp] lemma map_overlay (f : α → β) (g₁ g₂ : GGraph α) :
    f <$> GGraph.overlay g₁ g₂ = GGraph.overlay (f <$> g₁) (f <$> g₂) := by
  rcases g₁ with ⟨n₁, nd₁, e₁, i₁, o₁⟩; rcases g₂ with ⟨n₂, nd₂, e₂, i₂, o₂⟩
  simp only [Functor.map, GGraph.overlay]
  -- Only the `nodes` fields differ syntactically; compare them pointwise.
  congr 1
  funext i
  -- Split on whether `i` indexes into the left or the right graph.
  refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
/-- Mapping over payloads commutes with `sequence`.

The statement spells out `GGraph.sequence` instead of using its infix notation; the two
elaborate to the same term. -/
@[simp] lemma map_sequence (f : α → β) (g₁ g₂ : GGraph α) :
    f <$> GGraph.sequence g₁ g₂ = GGraph.sequence (f <$> g₁) (f <$> g₂) := by
  rcases g₁ with ⟨n₁, nd₁, e₁, i₁, o₁⟩; rcases g₂ with ⟨n₂, nd₂, e₂, i₂, o₂⟩
  simp only [Functor.map, GGraph.sequence]
  -- Only the `nodes` fields differ syntactically; compare them pointwise.
  congr 1
  funext i
  -- Split on whether `i` indexes into the left or the right graph.
  refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
/-- Mapping a function over every basic block commutes with `loop`.

`γ` is bound explicitly so the statement does not depend on the `autoImplicit` option. -/
@[simp] lemma map_loop {γ : Type} (h : β → γ) (g : GGraph (List β)) :
    (List.map h) <$> (loop g) = loop (List.map h <$> g) := by
  rcases g with ⟨n, nd, e, i, o⟩
  simp only [Functor.map, GGraph.loop]
  -- Only the `nodes` fields differ syntactically; compare them pointwise.
  congr 1
  funext i
  -- Split on whether `i` is one of the two fresh nodes or a node of the wrapped graph.
  refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
/-- Mapping a function over every basic block commutes with `wrap`.

`γ` is bound explicitly so the statement does not depend on the `autoImplicit` option. -/
@[simp] lemma map_wrap {γ : Type} (h : β → γ) (g : GGraph (List β)) :
    (List.map h) <$> wrap g = wrap (List.map h <$> g) := by
  simp [GGraph.wrap, GGraph.map_sequence, GGraph.map_singleton]
variable (g : GGraph α)
/-- The list of all node indices of `g`, obtained as `List.finRange g.size`. -/
def indices : List g.Index :=
  List.finRange g.size
/-- All of the graph's indices are listed in `indices`. -/
lemma mem_indices (idx : g.Index) : idx ∈ g.indices :=
  List.mem_finRange idx
/-- No index occurs more than once in `indices`. -/
lemma nodup_indices : List.Nodup g.indices :=
  List.nodup_finRange g.size
/-- Predecessors of a particular node in the graph: every index `idx'` of `g`
for which `(idx', idx)` is listed in `g.edges`. -/
def predecessors (idx : g.Index) : List g.Index :=
  g.indices.filter (fun idx' => (idx', idx) ∈ g.edges)
/-- If there's an edge between two nodes `idx₁` and `idx₂`,
then `idx₁` is a predecessor of `idx₂`. -/
lemma mem_predecessors_of_edge {idx₁ idx₂ : g.Index}
    (h : (idx₁, idx₂) ∈ g.edges) : idx₁ ∈ g.predecessors idx₂ :=
  List.mem_filter.mpr ⟨g.mem_indices idx₁, by simpa using h⟩
/-- A node is a predecessor of another node only if there's an
edge between them. -/
lemma edge_of_mem_predecessors {idx₁ idx₂ : g.Index}
    (h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by
  simpa using (List.mem_filter.mp h).2
end GGraph
/-- The graphs used by the analyses in this framework: a `GGraph` whose node
payload is a basic block, i.e. a `List BasicStmt`, and nothing else. -/
abbrev Graph : Type :=
  GGraph (List BasicStmt)
namespace Graph
-- Re-export the graph-building operations and the `loop` simp lemmas of `GGraph`
-- under the `Graph` namespace.
export GGraph (overlay sequence loop singleton wrap loop_inputs loop_outputs)
-- Scoped right-associative infix notations (precedence 70) for `overlay` and `sequence`.
-- NOTE(review): both operator tokens are empty in this copy of the file. The file header
-- warns about ambiguous Unicode, so the symbols were presumably stripped during
-- extraction; restore them from the original source before compiling.
@[inherit_doc] scoped infixr:70 "" => GGraph.overlay
@[inherit_doc] scoped infixr:70 "" => GGraph.sequence
end Graph
open Graph in
/-- Converts a statement into its control flow graph by recursion on the statement,
using one graph operation per statement form. The graph operations are written as
explicit `GGraph.sequence` / `GGraph.overlay` calls rather than in infix form. -/
def Stmt.cfg : Stmt → Graph
  -- A basic statement goes into a single basic block
  | .basic bs => singleton [bs]
  -- Sequencing of statements corresponds naturally to CFG sequencing
  | .andThen s₁ s₂ => GGraph.sequence s₁.cfg s₂.cfg
  -- An if can execute either one branch or the other; overlay them.
  -- Subsequent sequencing (etc.) will end up creating the forks and joins.
  | .ifElse _ s₁ s₂ => GGraph.overlay s₁.cfg s₂.cfg
  -- The `loop` construct was developed specifically for zero-or-more loops like this.
  | .whileLoop _ s => loop s.cfg
end Spa