From c2ad0db668be13b7d02d66b20acfaa7ab9a8c74f Mon Sep 17 00:00:00 2001 From: Danila Fedorin Date: Thu, 25 Jun 2026 17:01:27 -0500 Subject: [PATCH] Update comments in Graph and make map be a Functor instance Signed-off-by: Danila Fedorin --- lean/Spa/Language/Graphs.lean | 147 ++++++++++++++++++++++-------- lean/Spa/Language/Properties.lean | 26 +++--- 2 files changed, 122 insertions(+), 51 deletions(-) diff --git a/lean/Spa/Language/Graphs.lean b/lean/Spa/Language/Graphs.lean index 7dc919f..e10c535 100644 --- a/lean/Spa/Language/Graphs.lean +++ b/lean/Spa/Language/Graphs.lean @@ -3,22 +3,56 @@ import Mathlib.Data.Fin.Tuple.Basic import Mathlib.Data.List.ProdSigma import Mathlib.Data.List.FinRange +/-! + +Algebraic Control Flow Graphs. + +This file defines control flow graphs and operations to naturally compose them, +making it possible to inductively convert a program in the object language +(see `Spa.Stmt` in `Spa/Language/Base.lean`) into its corresponding graph. + +Graphs are, in general, parameterized by their "payload" (the per-node data); see `GGraph`. +This is useful because other operations, such as finding the CFG node corresponding +to an AST node, are performed by embellishing a graph's basic blocks with their AST +identifiers. + +The operations are deliberately a little bit sloppy here, creating empty / statement-less +CFG nodes. Additionally, the current CFG construction algorithm doesn't group +consecutive statements in a single notional basic block into one node. +This makes graph construction much easier to define, and might save us the +trouble of (when trying to find the CFG node for an AST node) doing +indexing into a list. + +-/ + +/-- Bump the upper bound of a list of `Fin`s without changing their value. -/ def List.finCastAdd {n : ℕ} (l : List (Fin n)) (m : ℕ) : List (Fin (n + m)) := l.map (Fin.castAdd m) +/-- Bump the upper bound of a list of `Fin`s by adding the amount to their value. 
-/ def List.finNatAdd {m : ℕ} (l : List (Fin m)) (n : ℕ) : List (Fin (n + m)) := l.map (Fin.natAdd n) +/-- Bump the upper bound of a list of `Fin` pairs without changing their value. -/ def List.finCastAddProd {n : ℕ} (l : List (Fin n × Fin n)) (m : ℕ) : List (Fin (n + m) × Fin (n + m)) := l.map (fun e => (e.1.castAdd m, e.2.castAdd m)) +/-- Bump the upper bound of a list of `Fin` pairs by adding the amount to their value. -/ def List.finNatAddProd {m : ℕ} (l : List (Fin m × Fin m)) (n : ℕ) : List (Fin (n + m) × Fin (n + m)) := l.map (fun e => (e.1.natAdd n, e.2.natAdd n)) namespace Spa +/-- Graph with general (`α`-labeled) nodes. By using a tuple `Fin size → α` + and writing `edges` over the `Fin size`, guarantees all edges are between real nodes. + + To make graph composition via operations not force a + [`alga`](https://hackage.haskell.org/package/algebraic-graphs)-style "connect"-based + algebra, explicitly defines `inputs` and `outputs`, which are the only nodes that + get connected when graphs are sequenced. This makes the graph construction + operations more naturally fit with how CFGs are created from `Stmt`s. -/ structure GGraph (α : Type) where size : ℕ nodes : Fin size → α @@ -30,32 +64,43 @@ namespace GGraph variable {α β : Type} +/-- An index (node) in the CFG. -/ abbrev Index (g : GGraph α) : Type := Fin g.size +/-- An edge in the CFG. 
-/ abbrev Edge (g : GGraph α) : Type := g.Index × g.Index -def map (f : α → β) (g : GGraph α) : GGraph β where - size := g.size - nodes := fun i => f (g.nodes i) - edges := g.edges - inputs := g.inputs - outputs := g.outputs +instance : Functor GGraph where + map {α β : Type} (f : α → β) (g : GGraph α) : GGraph β := + { size := g.size, + nodes := f ∘ g.nodes + edges := g.edges, + inputs := g.inputs, + outputs := g.outputs } -@[simp] lemma map_size (f : α → β) (g : GGraph α) : (g.map f).size = g.size := rfl -@[simp] lemma map_edges (f : α → β) (g : GGraph α) : (g.map f).edges = g.edges := rfl -@[simp] lemma map_inputs (f : α → β) (g : GGraph α) : (g.map f).inputs = g.inputs := rfl -@[simp] lemma map_outputs (f : α → β) (g : GGraph α) : (g.map f).outputs = g.outputs := rfl +@[simp] lemma map_size (f : α → β) (g : GGraph α) : (f <$> g).size = g.size := rfl +@[simp] lemma map_edges (f : α → β) (g : GGraph α) : (f <$> g).edges = g.edges := rfl +@[simp] lemma map_inputs (f : α → β) (g : GGraph α) : (f <$> g).inputs = g.inputs := rfl +@[simp] lemma map_outputs (f : α → β) (g : GGraph α) : (f <$> g).outputs = g.outputs := rfl -def comp (g₁ g₂ : GGraph α) : GGraph α where +/-- Overlay two graphs: create a new graph whose nodes and edges come from two + sub-graphs, without inserting any additional edges. Also combines the + input and output node sets. 
-/ +def overlay (g₁ g₂ : GGraph α) : GGraph α where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size inputs := g₁.inputs.finCastAdd g₂.size ++ g₂.inputs.finNatAdd g₁.size outputs := g₁.outputs.finCastAdd g₂.size ++ g₂.outputs.finNatAdd g₁.size -@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.comp +@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.overlay -def link (g₁ g₂ : GGraph α) : GGraph α where +/-- Sequence two CFGs: create a combined graph whose nodes and edges come + from two subgraphs, __and__ make all the outputs of the left graph have edges to + all the inputs of the right graph. By the semantics of CFGs, this + encodes the fact that code first traverses the basic blocks in the left + graph, and then does the same for the right graph. -/ +def sequence (g₁ g₂ : GGraph α) : GGraph α where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++ @@ -63,18 +108,34 @@ def link (g₁ g₂ : GGraph α) : GGraph α where inputs := g₁.inputs.finCastAdd g₂.size outputs := g₂.outputs.finNatAdd g₁.size -@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.link +@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.sequence +/-- When a graph `g` is wrapped in a `loop`, the index / node corresponding + to the input of the new loop. -/ def loopIn (g : GGraph α) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size +/-- When a graph `g` is wrapped in a `loop`, the index / node corresponding + to the output of the new loop. -/ def loopOut (g : GGraph α) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size +/-- Creates a zero-or-more loop in the CFG: connects all the output + nodes of the CFG back to the graph's beginning, and also introduces a path + to a new ending node (see `loopOut`) which bypasses the entire graph. 
+ + Notably, both the new input (`loopIn`) and new output (`loopOut`) + nodes are necessary for correctness: adding a path from inputs to a + hypothetical no-op end node encodes something like "just the first statement is executed". + Similarly, just adding a path from a hypothetical no-op beginning node + to the outputs encodes "just the last statement is executed". + + This is technically sloppy (see module comment), but it's simple. +-/ def loop (g : GGraph (List β)) : GGraph (List β) where size := 2 + g.size nodes := Fin.append (fun _ : Fin 2 => []) g.nodes edges := g.edges.finNatAddProd 2 ++ - (g.inputs.finNatAdd 2).map (g.loopIn, ·) ++ - (g.outputs.finNatAdd 2).map (·, g.loopOut) ++ + ((g.loopIn, ·) <$> g.inputs.finNatAdd 2) ++ + ((·, g.loopOut) <$> g.outputs.finNatAdd 2) ++ [(g.loopOut, g.loopIn), (g.loopIn, g.loopOut)] inputs := [g.loopIn] outputs := [g.loopOut] @@ -83,14 +144,7 @@ def loop (g : GGraph (List β)) : GGraph (List β) where @[simp] lemma loop_outputs (g : GGraph (List β)) : (loop g).outputs = [g.loopOut] := rfl -def skipto (g₁ g₂ : GGraph α) : GGraph α where - size := g₁.size + g₂.size - nodes := Fin.append g₁.nodes g₂.nodes - edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++ - (g₁.inputs.finCastAdd g₂.size).product (g₂.inputs.finNatAdd g₁.size) - inputs := g₁.inputs.finCastAdd g₂.size - outputs := g₂.inputs.finNatAdd g₁.size - +/-- Creates a single-node graph whose node contains the given value. -/ def singleton (a : α) : GGraph α where size := 1 nodes := fun _ => a @@ -98,79 +152,96 @@ def singleton (a : α) : GGraph α where inputs := [0] outputs := [0] +/-- Creates a new graph with a single input and single output node. Useful to ensure there's + a single point of entry and single point of exit. 
-/ def wrap (g : GGraph (List β)) : GGraph (List β) := singleton [] ⤳ g ⤳ singleton [] @[simp] lemma map_singleton (f : α → β) (a : α) : - (singleton a).map f = singleton (f a) := rfl + f <$> singleton a = singleton (f a) := rfl -@[simp] lemma map_comp (f : α → β) (g₁ g₂ : GGraph α) : - (g₁ ∙ g₂).map f = g₁.map f ∙ g₂.map f := by +@[simp] lemma map_overlay (f : α → β) (g₁ g₂ : GGraph α) : + f<$> (g₁ ∙ g₂) = f <$> g₁ ∙ f <$> g₂ := by rcases g₁ with ⟨n₁, nd₁, e₁, i₁, o₁⟩; rcases g₂ with ⟨n₂, nd₂, e₂, i₂, o₂⟩ - simp only [GGraph.map, GGraph.comp] + simp only [Functor.map, GGraph.overlay] congr 1 funext i refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right] -@[simp] lemma map_link (f : α → β) (g₁ g₂ : GGraph α) : - (g₁ ⤳ g₂).map f = g₁.map f ⤳ g₂.map f := by +@[simp] lemma map_sequence (f : α → β) (g₁ g₂ : GGraph α) : + f <$> (g₁ ⤳ g₂) = (f <$> g₁) ⤳ (f <$> g₂) := by rcases g₁ with ⟨n₁, nd₁, e₁, i₁, o₁⟩; rcases g₂ with ⟨n₂, nd₂, e₂, i₂, o₂⟩ - simp only [GGraph.map, GGraph.link] + simp only [Functor.map, GGraph.sequence] congr 1 funext i refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right] @[simp] lemma map_loop (h : β → γ) (g : GGraph (List β)) : - (loop g).map (List.map h) = loop (g.map (List.map h)) := by + (List.map h) <$> (loop g) = loop (List.map h <$> g) := by rcases g with ⟨n, nd, e, i, o⟩ - simp only [GGraph.map, GGraph.loop] + simp only [Functor.map, GGraph.loop] congr 1 funext i refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right] @[simp] lemma map_wrap (h : β → γ) (g : GGraph (List β)) : - (wrap g).map (List.map h) = wrap (g.map (List.map h)) := by - simp [GGraph.wrap, GGraph.map_link, GGraph.map_singleton] + (List.map h) <$> wrap g = wrap (List.map h <$> g) := by + simp [GGraph.wrap, GGraph.map_sequence, GGraph.map_singleton] variable (g : GGraph α) +/-- All the nodes in the graph. 
-/ def indices : List g.Index := List.finRange g.size +/-- All of the graph's indices are listed in `indices`. -/ lemma mem_indices (idx : g.Index) : idx ∈ g.indices := List.mem_finRange idx +/-- `indices` does not have duplicates. -/ lemma nodup_indices : g.indices.Nodup := List.nodup_finRange g.size +/-- Predecessors of a particular node in the graph. -/ def predecessors (idx : g.Index) : List g.Index := g.indices.filter (fun idx' => (idx', idx) ∈ g.edges) +/-- If there's an edge between two nodes `idx₁` and `idx₂`, + then `idx₁` is a predecessor of `idx₂`. -/ lemma mem_predecessors_of_edge {idx₁ idx₂ : g.Index} (h : (idx₁, idx₂) ∈ g.edges) : idx₁ ∈ g.predecessors idx₂ := List.mem_filter.mpr ⟨g.mem_indices idx₁, by simpa using h⟩ +/-- A node is a predecessor of another node only if there's an + edge between them. -/ lemma edge_of_mem_predecessors {idx₁ idx₂ : g.Index} (h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by simpa using (List.mem_filter.mp h).2 end GGraph +/-- "Normal" graphs, for the purposes of the analyses in this + framework, have basic blocks in their nodes, and nothing else. -/ abbrev Graph : Type := GGraph (List BasicStmt) namespace Graph -export GGraph (comp link loop skipto singleton wrap loop_inputs loop_outputs) +export GGraph (overlay sequence loop singleton wrap loop_inputs loop_outputs) -@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.comp -@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.link +@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.overlay +@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.sequence end Graph open Graph in def Stmt.cfg : Stmt → Graph + -- A basic statement goes into a single basic block | .basic bs => singleton [bs] + -- Sequencing of statements corresponds naturally to CFG sequencing | .andThen s₁ s₂ => s₁.cfg ⤳ s₂.cfg + -- An if can execute either one branch or the other; overlay them. + -- Subsequent sequencing (etc.) will end up creating the forks and joins. 
| .ifElse _ s₁ s₂ => s₁.cfg ∙ s₂.cfg + -- The `loop` construct was developed specifically for zero-or-more loops like this. | .whileLoop _ s => loop s.cfg end Spa diff --git a/lean/Spa/Language/Properties.lean b/lean/Spa/Language/Properties.lean index ce57bf4..e55c69d 100644 --- a/lean/Spa/Language/Properties.lean +++ b/lean/Spa/Language/Properties.lean @@ -17,7 +17,7 @@ section Embeddings variable {g₁ g₂ : Graph} {ρ₁ ρ₂ : Env} -lemma Trace.comp_left {idx₁ idx₂ : g₁.Index} +lemma Trace.overlay_left {idx₁ idx₂ : g₁.Index} (tr : Trace g₁ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ∙ g₂) (idx₁.castAdd g₂.size) (idx₂.castAdd g₂.size) ρ₁ ρ₂ := by induction tr with @@ -29,7 +29,7 @@ lemma Trace.comp_left {idx₁ idx₂ : g₁.Index} · rwa [show (g₁ ∙ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_left] · exact List.mem_append_left _ (List.mem_map_of_mem _ he) -lemma Trace.comp_right {idx₁ idx₂ : g₂.Index} +lemma Trace.overlay_right {idx₁ idx₂ : g₂.Index} (tr : Trace g₂ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ∙ g₂) (idx₁.natAdd g₁.size) (idx₂.natAdd g₁.size) ρ₁ ρ₂ := by induction tr with @@ -41,7 +41,7 @@ lemma Trace.comp_right {idx₁ idx₂ : g₂.Index} · rwa [show (g₁ ∙ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_right] · exact List.mem_append_right _ (List.mem_map_of_mem _ he) -lemma Trace.link_left {idx₁ idx₂ : g₁.Index} +lemma Trace.sequence_left {idx₁ idx₂ : g₁.Index} (tr : Trace g₁ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ⤳ g₂) (idx₁.castAdd g₂.size) (idx₂.castAdd g₂.size) ρ₁ ρ₂ := by induction tr with @@ -53,7 +53,7 @@ lemma Trace.link_left {idx₁ idx₂ : g₁.Index} · rwa [show (g₁ ⤳ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_left] · exact List.mem_append_left _ (List.mem_append_left _ (List.mem_map_of_mem _ he)) -lemma Trace.link_right {idx₁ idx₂ : g₂.Index} +lemma Trace.sequence_right {idx₁ idx₂ : g₂.Index} (tr : Trace g₂ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ⤳ g₂) (idx₁.natAdd g₁.size) (idx₂.natAdd g₁.size) ρ₁ ρ₂ := by induction tr with @@ -66,19 +66,19 @@ lemma 
Trace.link_right {idx₁ idx₂ : g₂.Index} · exact List.mem_append_left _ (List.mem_append_right _ (List.mem_map_of_mem _ he)) -lemma EndToEndTrace.comp_left (etr : EndToEndTrace g₁ ρ₁ ρ₂) : +lemma EndToEndTrace.overlay_left (etr : EndToEndTrace g₁ ρ₁ ρ₂) : EndToEndTrace (g₁ ∙ g₂) ρ₁ ρ₂ := by obtain ⟨i₁, h₁, i₂, h₂, tr⟩ := etr exact ⟨i₁.castAdd g₂.size, List.mem_append_left _ (List.mem_map_of_mem _ h₁), i₂.castAdd g₂.size, List.mem_append_left _ (List.mem_map_of_mem _ h₂), - tr.comp_left⟩ + tr.overlay_left⟩ -lemma EndToEndTrace.comp_right (etr : EndToEndTrace g₂ ρ₁ ρ₂) : +lemma EndToEndTrace.overlay_right (etr : EndToEndTrace g₂ ρ₁ ρ₂) : EndToEndTrace (g₁ ∙ g₂) ρ₁ ρ₂ := by obtain ⟨i₁, h₁, i₂, h₂, tr⟩ := etr exact ⟨i₁.natAdd g₁.size, List.mem_append_right _ (List.mem_map_of_mem _ h₁), i₂.natAdd g₁.size, List.mem_append_right _ (List.mem_map_of_mem _ h₂), - tr.comp_right⟩ + tr.overlay_right⟩ lemma EndToEndTrace.concat {ρ₃ : Env} (etr₁ : EndToEndTrace g₁ ρ₁ ρ₂) (etr₂ : EndToEndTrace g₂ ρ₂ ρ₃) : EndToEndTrace (g₁ ⤳ g₂) ρ₁ ρ₃ := by @@ -86,7 +86,7 @@ lemma EndToEndTrace.concat {ρ₃ : Env} (etr₁ : EndToEndTrace g₁ ρ₁ ρ obtain ⟨j₁, k₁, j₂, k₂, tr₂⟩ := etr₂ refine ⟨i₁.castAdd g₂.size, List.mem_map_of_mem _ h₁, j₂.natAdd g₁.size, List.mem_map_of_mem _ k₂, - Trace.concat tr₁.link_left ?_ tr₂.link_right⟩ + Trace.concat tr₁.sequence_left ?_ tr₂.sequence_right⟩ exact List.mem_append_right _ (List.mem_product.mpr ⟨List.mem_map_of_mem _ h₂, List.mem_map_of_mem _ k₁⟩) @@ -182,9 +182,9 @@ theorem Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env} | andThen ρ₁ ρ₂ ρ₃ s₁ s₂ _ _ ih₁ ih₂ => exact ih₁.concat ih₂ | ifTrue ρ₁ ρ₂ e z s₁ s₂ _ _ _ ih => - exact ih.comp_left + exact ih.overlay_left | ifFalse ρ₁ ρ₂ e s₁ s₂ _ _ ih => - exact ih.comp_right + exact ih.overlay_right | whileTrue ρ₁ ρ₂ ρ₃ e z s _ _ _ _ ih₁ ih₂ => exact (ih₁.loop).loop_concat ih₂ | whileFalse ρ e s _ => @@ -204,7 +204,7 @@ lemma Graph.wrap_inputs (g : Graph) : lemma Graph.wrap_outputs (g : Graph) : (Graph.wrap g).outputs = 
[g.wrapOutput] := rfl -private lemma not_mem_edges_castAdd_link {g₂ : Graph} (i : Fin 1) +private lemma not_mem_edges_castAdd_sequence {g₂ : Graph} (i : Fin 1) (idx : (Graph.singleton [] ⤳ g₂).Index) : ((idx, i.castAdd g₂.size) : (Graph.singleton [] ⤳ g₂).Edge) ∉ (Graph.singleton [] ⤳ g₂).edges := by @@ -228,6 +228,6 @@ lemma Graph.wrap_predecessors_eq_nil (g : Graph) (idx : (Graph.wrap g).Index) subst h rw [GGraph.predecessors, List.filter_eq_nil_iff] intro idx' _ - simpa using not_mem_edges_castAdd_link (g₂ := g ⤳ Graph.singleton []) 0 idx' + simpa using not_mem_edges_castAdd_sequence (g₂ := g ⤳ Graph.singleton []) 0 idx' end Spa