/-
Files
agda-spa/lean/Spa/Language/Graphs.lean

127 lines
4.0 KiB
Lean4
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
-/
import Spa.Language.Base
import Mathlib.Data.Fin.Tuple.Basic
import Mathlib.Data.List.ProdSigma
import Mathlib.Data.List.FinRange
namespace Spa
/-- A graph whose nodes each carry a list of basic statements.

Nodes are addressed by `Fin size`, so every index stored in `edges`, `inputs`
and `outputs` is in range by construction. -/
structure Graph where
  /-- Number of nodes in the graph. -/
  size : ℕ
  /-- The basic statements stored at each node. -/
  nodes : Fin size → List BasicStmt
  /-- Directed edges, stored as `(source, target)` pairs of node indices. -/
  edges : List (Fin size × Fin size)
  /-- Nodes through which the graph is entered. -/
  inputs : List (Fin size)
  /-- Nodes through which the graph is left. -/
  outputs : List (Fin size)
namespace Graph
/-- The type of node indices of `g`, i.e. `Fin g.size`. -/
abbrev Index (g : Graph) : Type :=
  Fin g.size
/-- The type of edges of `g`: ordered pairs of node indices of `g`. -/
abbrev Edge (g : Graph) : Type :=
  Graph.Index g × Graph.Index g
/-- Reinterpret a list of indices into `Fin n` as indices into `Fin (n + m)`.

Each index keeps its numeric value (`Fin.castAdd`), so the result addresses the
first `n` positions of the larger range. -/
def liftIdxL {n : ℕ} (l : List (Fin n)) (m : ℕ) : List (Fin (n + m)) :=
  l.map (Fin.castAdd m)
/-- Shift a list of indices into `Fin m` up by `n`, giving indices into `Fin (n + m)`.

Each index `i` becomes `n + i` (`Fin.natAdd`), so the result addresses the last
`m` positions of the larger range. -/
def liftIdxR (n : ℕ) {m : ℕ} (l : List (Fin m)) : List (Fin (n + m)) :=
  l.map (Fin.natAdd n)
/-- Reinterpret edges over `Fin n` as edges over `Fin (n + m)`.

Both endpoints of every edge keep their numeric value (`Fin.castAdd`). -/
def liftEdgeL {n : ℕ} (l : List (Fin n × Fin n)) (m : ℕ) :
    List (Fin (n + m) × Fin (n + m)) :=
  l.map (fun e => (e.1.castAdd m, e.2.castAdd m))
/-- Shift edges over `Fin m` up by `n`, giving edges over `Fin (n + m)`.

Both endpoints `i` of every edge become `n + i` (`Fin.natAdd`). -/
def liftEdgeR (n : ℕ) {m : ℕ} (l : List (Fin m × Fin m)) :
    List (Fin (n + m) × Fin (n + m)) :=
  l.map (fun e => (e.1.natAdd n, e.2.natAdd n))
/-- Side-by-side composition of two graphs.

The nodes of `g₁` keep their indices and the nodes of `g₂` are shifted up by
`g₁.size`. The edge list is the lifted edges of `g₁` followed by the lifted edges
of `g₂`; no edge is added between the two operands. The inputs (resp. outputs)
of the result are the lifted inputs (resp. outputs) of `g₁` followed by those of
`g₂`. -/
def comp (g₁ g₂ : Graph) : Graph where
  size := g₁.size + g₂.size
  nodes := Fin.append g₁.nodes g₂.nodes
  edges := liftEdgeL g₁.edges g₂.size ++ liftEdgeR g₁.size g₂.edges
  inputs := liftIdxL g₁.inputs g₂.size ++ liftIdxR g₁.size g₂.inputs
  outputs := liftIdxL g₁.outputs g₂.size ++ liftIdxR g₁.size g₂.outputs

-- NOTE(review): the infix token was an empty string in the reviewed copy (the
-- original symbol was lost); `∙` is a reconstruction — confirm against upstream.
@[inherit_doc] scoped infixr:70 " ∙ " => Graph.comp
/-- Sequential composition of two graphs: `g₁` followed by `g₂`.

The nodes of `g₁` keep their indices and the nodes of `g₂` are shifted up by
`g₁.size`. In addition to the lifted edges of both operands, an edge is added
from every output of `g₁` to every input of `g₂`. The inputs of the result are
the lifted inputs of `g₁`; its outputs are the lifted outputs of `g₂`. -/
def link (g₁ g₂ : Graph) : Graph where
  size := g₁.size + g₂.size
  nodes := Fin.append g₁.nodes g₂.nodes
  edges :=
    liftEdgeL g₁.edges g₂.size ++ liftEdgeR g₁.size g₂.edges ++
      -- every (output of g₁, input of g₂) pair becomes a connecting edge
      (liftIdxL g₁.outputs g₂.size).product (liftIdxR g₁.size g₂.inputs)
  inputs := liftIdxL g₁.inputs g₂.size
  outputs := liftIdxR g₁.size g₂.outputs

-- NOTE(review): the infix token was an empty string in the reviewed copy (the
-- original symbol was lost); `↦` is a reconstruction — confirm against upstream.
@[inherit_doc] scoped infixr:70 " ↦ " => Graph.link
/-- Entry node of `loop g`: index `0` of the two fresh nodes, viewed as an index
into the `2 + g.size` nodes of the loop graph. -/
def loopIn (g : Graph) : Fin (2 + g.size) :=
  Fin.castAdd g.size (0 : Fin 2)
/-- Exit node of `loop g`: index `1` of the two fresh nodes, viewed as an index
into the `2 + g.size` nodes of the loop graph. -/
def loopOut (g : Graph) : Fin (2 + g.size) :=
  Fin.castAdd g.size (1 : Fin 2)
/-- Wrap `g` in a loop with a fresh entry node and a fresh exit node.

Two nodes with no statements are placed in front of the nodes of `g`, whose
indices are shifted up by 2. The edge list consists, in order, of:
* the shifted edges of `g`;
* an edge from the entry node to every input of `g`;
* an edge from every output of `g` to the exit node;
* the edge exit → entry, then the edge entry → exit.

The only input is the entry node and the only output is the exit node. -/
def loop (g : Graph) : Graph :=
  { size := 2 + g.size
    nodes := Fin.append (fun _ : Fin 2 => []) g.nodes
    edges :=
      liftEdgeR 2 g.edges ++
        List.map (fun i => (g.loopIn, i)) (liftIdxR 2 g.inputs) ++
        List.map (fun o => (o, g.loopOut)) (liftIdxR 2 g.outputs) ++
        [(g.loopOut, g.loopIn), (g.loopIn, g.loopOut)]
    inputs := [g.loopIn]
    outputs := [g.loopOut] }
/-- The input list of `loop g` is exactly `[g.loopIn]`. -/
@[simp]
theorem loop_inputs (g : Graph) : g.loop.inputs = [g.loopIn] :=
  rfl
/-- The output list of `loop g` is exactly `[g.loopOut]`. -/
@[simp]
theorem loop_outputs (g : Graph) : g.loop.outputs = [g.loopOut] :=
  rfl
/-- Combine `g₁` and `g₂`, connecting the inputs of `g₁` to the inputs of `g₂`.

The nodes of `g₁` keep their indices and the nodes of `g₂` are shifted up by
`g₁.size`. In addition to the lifted edges of both operands, an edge is added
from every input of `g₁` to every input of `g₂`. The inputs of the result are
the lifted inputs of `g₁`; its outputs are the lifted *inputs* of `g₂`. -/
def skipto (g₁ g₂ : Graph) : Graph :=
  { size := g₁.size + g₂.size
    nodes := Fin.append g₁.nodes g₂.nodes
    edges :=
      liftEdgeL g₁.edges g₂.size ++ liftEdgeR g₁.size g₂.edges ++
        List.product (liftIdxL g₁.inputs g₂.size) (liftIdxR g₁.size g₂.inputs)
    inputs := liftIdxL g₁.inputs g₂.size
    outputs := liftIdxR g₁.size g₂.inputs }
/-- The one-node graph whose single node carries the statements `bss`.

It has no edges, and node `0` is both its only input and its only output. -/
def singleton (bss : List BasicStmt) : Graph :=
  { size := 1
    nodes := fun _ => bss
    edges := []
    inputs := [0]
    outputs := [0] }
/-- Surround `g` with an empty node on each side.

An empty singleton graph is linked in front of `g` and another one behind it, so
the result has exactly one input and one output, each a node with no statements.

NOTE(review): the infix operators were missing from the reviewed copy. The
explicit `Graph.link` calls below reconstruct a right-nested chain of links,
matching `infixr` associativity — confirm against upstream. -/
def wrap (g : Graph) : Graph :=
  Graph.link (singleton []) (Graph.link g (singleton []))
end Graph
open Graph in
/-- Build the graph of a statement by structural recursion.

* a basic statement becomes a one-node graph carrying that statement;
* `andThen s₁ s₂` links the graph of `s₁` to the graph of `s₂` (`Graph.link`);
* `ifElse _ s₁ s₂` places the graphs of both branches side by side (`Graph.comp`);
* `whileLoop _ s` wraps the graph of the body with `Graph.loop`.

The conditions of `ifElse` and `whileLoop` are ignored.

NOTE(review): the infix operators in the `andThen` and `ifElse` arms were missing
from the reviewed copy; the assignment of `link` to sequencing and `comp` to
branching is a reconstruction — confirm against upstream. -/
def buildCfg : Stmt → Graph
  | .basic bs => Graph.singleton [bs]
  | .andThen s₁ s₂ => Graph.link (buildCfg s₁) (buildCfg s₂)
  | .ifElse _ s₁ s₂ => Graph.comp (buildCfg s₁) (buildCfg s₂)
  | .whileLoop _ s => Graph.loop (buildCfg s)
namespace Graph
variable (g : Graph)
/-- All node indices of `g`, as produced by `List.finRange g.size`. -/
def indices : List g.Index :=
  List.finRange g.size
/-- Every node index of `g` occurs in `g.indices`. -/
theorem mem_indices (idx : g.Index) : idx ∈ g.indices :=
  List.mem_finRange idx
/-- `g.indices` contains no repeated index. -/
theorem nodup_indices : g.indices.Nodup :=
  List.nodup_finRange _
/-- The nodes with an edge into `idx`.

Computed by filtering `g.indices`, so each predecessor appears once even if the
edge list contains the same edge several times. -/
def predecessors (idx : g.Index) : List g.Index :=
  g.indices.filter (fun idx' => (idx', idx) ∈ g.edges)
/-- If `(idx₁, idx₂)` is an edge of `g`, then `idx₁` is a predecessor of `idx₂`. -/
theorem mem_predecessors_of_edge {idx₁ idx₂ : g.Index}
    (h : (idx₁, idx₂) ∈ g.edges) : idx₁ ∈ g.predecessors idx₂ :=
  -- membership in a filtered list = membership in the source list + the predicate
  List.mem_filter.mpr ⟨g.mem_indices idx₁, by simpa using h⟩
/-- If `idx₁` is a predecessor of `idx₂`, then `(idx₁, idx₂)` is an edge of `g`. -/
theorem edge_of_mem_predecessors {idx₁ idx₂ : g.Index}
    (h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by
  -- the second component of `mem_filter` is the (decided) filter predicate
  simpa using (List.mem_filter.mp h).2
end Graph
end Spa