diff --git a/lean/Spa/Language/Graphs.lean b/lean/Spa/Language/Graphs.lean index ca1109e..de94ebd 100644 --- a/lean/Spa/Language/Graphs.lean +++ b/lean/Spa/Language/Graphs.lean @@ -19,29 +19,43 @@ def List.finNatAddProd {m : ℕ} (l : List (Fin m × Fin m)) (n : ℕ) : namespace Spa -structure Graph where +structure GGraph (α : Type) where size : ℕ - nodes : Fin size → List BasicStmt + nodes : Fin size → α edges : List (Fin size × Fin size) inputs : List (Fin size) outputs : List (Fin size) -namespace Graph +namespace GGraph -abbrev Index (g : Graph) : Type := Fin g.size +variable {α β : Type} -abbrev Edge (g : Graph) : Type := g.Index × g.Index +abbrev Index (g : GGraph α) : Type := Fin g.size -def comp (g₁ g₂ : Graph) : Graph where +abbrev Edge (g : GGraph α) : Type := g.Index × g.Index + +def map (f : α → β) (g : GGraph α) : GGraph β where + size := g.size + nodes := fun i => f (g.nodes i) + edges := g.edges + inputs := g.inputs + outputs := g.outputs + +@[simp] theorem map_size (f : α → β) (g : GGraph α) : (g.map f).size = g.size := rfl +@[simp] theorem map_edges (f : α → β) (g : GGraph α) : (g.map f).edges = g.edges := rfl +@[simp] theorem map_inputs (f : α → β) (g : GGraph α) : (g.map f).inputs = g.inputs := rfl +@[simp] theorem map_outputs (f : α → β) (g : GGraph α) : (g.map f).outputs = g.outputs := rfl + +def comp (g₁ g₂ : GGraph α) : GGraph α where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size inputs := g₁.inputs.finCastAdd g₂.size ++ g₂.inputs.finNatAdd g₁.size outputs := g₁.outputs.finCastAdd g₂.size ++ g₂.outputs.finNatAdd g₁.size -@[inherit_doc] scoped infixr:70 " ∙ " => Graph.comp +@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.comp -def link (g₁ g₂ : Graph) : Graph where +def link (g₁ g₂ : GGraph α) : GGraph α where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++ @@ -49,15 +63,13 @@ def link (g₁ g₂ : Graph) : Graph where inputs := g₁.inputs.finCastAdd g₂.size outputs := g₂.outputs.finNatAdd g₁.size -@[inherit_doc] scoped infixr:70 " ⤳ " => Graph.link +@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.link -/-- The entry node of a `loop` graph. -/ -def loopIn (g : Graph) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size +def loopIn (g : GGraph α) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size -/-- The exit node of a `loop` graph. -/ -def loopOut (g : Graph) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size +def loopOut (g : GGraph α) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size -def loop (g : Graph) : Graph where +def loop (g : GGraph (List β)) : GGraph (List β) where size := 2 + g.size nodes := Fin.append (fun _ : Fin 2 => []) g.nodes edges := g.edges.finNatAddProd 2 ++ @@ -67,11 +79,11 @@ def loop (g : Graph) : Graph where inputs := [g.loopIn] outputs := [g.loopOut] -@[simp] theorem loop_inputs (g : Graph) : (loop g).inputs = [g.loopIn] := rfl +@[simp] theorem loop_inputs (g : GGraph (List β)) : (loop g).inputs = [g.loopIn] := rfl -@[simp] theorem loop_outputs (g : Graph) : (loop g).outputs = [g.loopOut] := rfl +@[simp] theorem loop_outputs (g : GGraph (List β)) : (loop g).outputs = [g.loopOut] := rfl -def skipto (g₁ g₂ : Graph) : Graph where +def skipto (g₁ g₂ : GGraph α) : GGraph α where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++ @@ -79,17 +91,17 @@ def skipto (g₁ g₂ : Graph) : Graph where inputs := g₁.inputs.finCastAdd g₂.size outputs := g₂.inputs.finNatAdd g₁.size -def singleton (bss : List BasicStmt) : Graph where +def singleton (a : α) : GGraph α where size := 1 - nodes := fun _ => bss + nodes := fun _ => a edges := [] inputs := [0] outputs := [0] -def wrap (g : Graph) : Graph := +def wrap (g : GGraph (List β)) : GGraph (List β) := singleton [] ⤳ g ⤳ singleton [] -variable (g : Graph) +variable (g : GGraph α) def indices : List g.Index := List.finRange g.size @@ -110,13 +122,24 @@ theorem edge_of_mem_predecessors {idx₁ idx₂ : g.Index} (h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by simpa using (List.mem_filter.mp h).2 +end GGraph + +abbrev Graph : Type := GGraph (List BasicStmt) + +namespace Graph + +export GGraph (comp link loop skipto singleton wrap loop_inputs loop_outputs) + +@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.comp +@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.link + end Graph open Graph in def buildCfg : Stmt → Graph - | .basic bs => Graph.singleton [bs] + | .basic bs => singleton [bs] | .andThen s₁ s₂ => buildCfg s₁ ⤳ buildCfg s₂ | .ifElse _ s₁ s₂ => buildCfg s₁ ∙ buildCfg s₂ - | .whileLoop _ s => Graph.loop (buildCfg s) + | .whileLoop _ s => loop (buildCfg s) end Spa diff --git a/lean/Spa/Language/Properties.lean b/lean/Spa/Language/Properties.lean index 909fd21..697e715 100644 --- a/lean/Spa/Language/Properties.lean +++ b/lean/Spa/Language/Properties.lean @@ -226,7 +226,7 @@ theorem Graph.wrap_predecessors_eq_nil (g : Graph) (idx : (Graph.wrap g).Index) (Graph.wrap g).predecessors idx = [] := by rw [Graph.wrap_inputs, List.mem_singleton] at h subst h - rw [Graph.predecessors, List.filter_eq_nil_iff] + rw [GGraph.predecessors, List.filter_eq_nil_iff] intro idx' _ simpa using not_mem_edges_castAdd_link (g₂ := g ⤳ Graph.singleton []) 0 idx'