Files
agda-spa/lean/Spa/Language/Graphs.lean
2026-06-25 09:26:15 -05:00

177 lines
6.2 KiB
Lean4
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import Spa.Language.Base
import Mathlib.Data.Fin.Tuple.Basic
import Mathlib.Data.List.ProdSigma
import Mathlib.Data.List.FinRange
/-- Embed a list of indices from `Fin n` into `Fin (n + m)` with `Fin.castAdd`.
Each index keeps its numeric value, so the result lives in the lower range `[0, n)`. -/
def List.finCastAdd {n : Nat} (l : List (Fin n)) (m : Nat) : List (Fin (n + m)) :=
  l.map (Fin.castAdd m)
/-- Embed a list of indices from `Fin m` into `Fin (n + m)` with `Fin.natAdd`.
Each index is shifted up by `n`, so the result lives in the upper range `[n, n + m)`. -/
def List.finNatAdd {m : Nat} (l : List (Fin m)) (n : Nat) : List (Fin (n + m)) :=
  l.map (Fin.natAdd n)
/-- Pairwise version of `List.finCastAdd`: embed both components of every pair
from `Fin n` into `Fin (n + m)` with `Fin.castAdd`, keeping their numeric values. -/
def List.finCastAddProd {n : Nat} (l : List (Fin n × Fin n)) (m : Nat) :
    List (Fin (n + m) × Fin (n + m)) :=
  l.map (fun e => (e.1.castAdd m, e.2.castAdd m))
/-- Pairwise version of `List.finNatAdd`: embed both components of every pair
from `Fin m` into `Fin (n + m)` with `Fin.natAdd`, shifting each up by `n`. -/
def List.finNatAddProd {m : Nat} (l : List (Fin m × Fin m)) (n : Nat) :
    List (Fin (n + m) × Fin (n + m)) :=
  l.map (fun e => (e.1.natAdd n, e.2.natAdd n))
namespace Spa
/-- A graph whose nodes are identified by the indices `Fin size` and labelled with values
of type `α`. -/
structure GGraph (α : Type) where
  /-- Number of nodes. -/
  size : Nat
  /-- Label of each node. -/
  nodes : Fin size → α
  /-- Directed edges, written as `(source, target)` index pairs. -/
  edges : List (Fin size × Fin size)
  /-- Indices of the nodes where the graph is entered. -/
  inputs : List (Fin size)
  /-- Indices of the nodes where the graph is left. -/
  outputs : List (Fin size)
namespace GGraph
variable {α β : Type}

/-- The type of node indices of `g`. -/
abbrev Index (g : GGraph α) : Type := Fin (GGraph.size g)

/-- The type of edges of `g`: ordered pairs of node indices. -/
abbrev Edge (g : GGraph α) : Type := Prod (Index g) (Index g)
/-- Relabel every node of `g` with `f`. The size, edges, inputs and outputs are left
unchanged. -/
def map (f : α → β) (g : GGraph α) : GGraph β where
  size := g.size
  nodes := fun i => f (g.nodes i)
  edges := g.edges
  inputs := g.inputs
  outputs := g.outputs
/-- `map` keeps the number of nodes. -/
@[simp] theorem map_size (f : α → β) (g : GGraph α) : (g.map f).size = g.size := rfl
/-- `map` keeps the edge list. -/
@[simp] theorem map_edges (f : α → β) (g : GGraph α) : (g.map f).edges = g.edges := rfl
/-- `map` keeps the input list. -/
@[simp] theorem map_inputs (f : α → β) (g : GGraph α) : (g.map f).inputs = g.inputs := rfl
/-- `map` keeps the output list. -/
@[simp] theorem map_outputs (f : α → β) (g : GGraph α) : (g.map f).outputs = g.outputs := rfl
/-- Disjoint union of `g₁` and `g₂`, with no edges added between them.

`g₁`'s nodes keep their indices `[0, g₁.size)`. `g₂`'s nodes are shifted up by `g₁.size`.
The result's edges, inputs and outputs are those of `g₁` followed by those of `g₂`,
reindexed accordingly. -/
def comp (g₁ g₂ : GGraph α) : GGraph α :=
  { size := g₁.size + g₂.size,
    nodes := Fin.append g₁.nodes g₂.nodes,
    edges := List.finCastAddProd g₁.edges g₂.size ++ List.finNatAddProd g₂.edges g₁.size,
    inputs := List.finCastAdd g₁.inputs g₂.size ++ List.finNatAdd g₂.inputs g₁.size,
    outputs := List.finCastAdd g₁.outputs g₂.size ++ List.finNatAdd g₂.outputs g₁.size }
-- Right-associative infix notation (precedence 70) for `GGraph.comp`. It is `scoped`, so it
-- is only active where `Spa.GGraph` is opened.
-- NOTE(review): the token string is empty in this copy of the file. The original symbol
-- appears to have been stripped during extraction, and an empty token is presumably
-- rejected by `infixr`. Restore the intended symbol.
@[inherit_doc] scoped infixr:70 "" => GGraph.comp
/-- Sequential composition of `g₁` and `g₂`.

The nodes are laid out as in a disjoint union: `g₂`'s indices are shifted up by `g₁.size`.
Besides both graphs' own edges, every output of `g₁` gets an edge to every input of `g₂`.
The result's inputs are `g₁`'s inputs and its outputs are `g₂`'s outputs. -/
def link (g₁ g₂ : GGraph α) : GGraph α :=
  { size := g₁.size + g₂.size,
    nodes := Fin.append g₁.nodes g₂.nodes,
    edges := List.finCastAddProd g₁.edges g₂.size ++ List.finNatAddProd g₂.edges g₁.size ++
      List.product (List.finCastAdd g₁.outputs g₂.size) (List.finNatAdd g₂.inputs g₁.size),
    inputs := List.finCastAdd g₁.inputs g₂.size,
    outputs := List.finNatAdd g₂.outputs g₁.size }
-- Right-associative infix notation (precedence 70) for `GGraph.link`. It is `scoped`, so it
-- is only active where `Spa.GGraph` is opened.
-- NOTE(review): the token string is empty in this copy of the file. The original symbol
-- appears to have been stripped during extraction, and an empty token is presumably
-- rejected by `infixr`. Restore the intended symbol.
@[inherit_doc] scoped infixr:70 "" => GGraph.link
/-- Index `0` in the `2 + g.size` layout used by `loop`: the entry node added in front of
`g`. -/
def loopIn (g : GGraph α) : Fin (2 + g.size) := Fin.castAdd g.size (0 : Fin 2)

/-- Index `1` in the `2 + g.size` layout used by `loop`: the exit node added in front of
`g`. -/
def loopOut (g : GGraph α) : Fin (2 + g.size) := Fin.castAdd g.size (1 : Fin 2)
/-- Wrap `g` in a loop.

Two nodes labelled `[]` are placed at indices `0` (`g.loopIn`) and `1` (`g.loopOut`), and
`g`'s nodes are shifted up by `2`. Besides `g`'s own edges, the result has:
* an edge from `loopIn` to each of `g`'s inputs;
* an edge from each of `g`'s outputs to `loopOut`;
* the edges `loopOut → loopIn` and `loopIn → loopOut`.

`loopIn` is the only input and `loopOut` is the only output. -/
def loop (g : GGraph (List β)) : GGraph (List β) :=
  { size := 2 + g.size,
    nodes := Fin.append (fun (_ : Fin 2) => []) g.nodes,
    edges := List.finNatAddProd g.edges 2 ++
      List.map (fun i => (g.loopIn, i)) (List.finNatAdd g.inputs 2) ++
      List.map (fun o => (o, g.loopOut)) (List.finNatAdd g.outputs 2) ++
      [(g.loopOut, g.loopIn), (g.loopIn, g.loopOut)],
    inputs := [g.loopIn],
    outputs := [g.loopOut] }
/-- The only input of `loop g` is `g.loopIn`. -/
@[simp] theorem loop_inputs (g : GGraph (List β)) : (GGraph.loop g).inputs = [g.loopIn] := by
  rfl
/-- The only output of `loop g` is `g.loopOut`. -/
@[simp] theorem loop_outputs (g : GGraph (List β)) : (GGraph.loop g).outputs = [g.loopOut] := by
  rfl
/-- Combine `g₁` and `g₂` with the same node layout as `link` (`g₂` shifted up by
`g₁.size`).

The extra edges go from every *input* of `g₁` to every input of `g₂`. The result's inputs
are `g₁`'s inputs and its outputs are `g₂`'s inputs.
NOTE(review): the outputs are taken from `g₂.inputs`, not `g₂.outputs`. Confirm this is
intended. -/
def skipto (g₁ g₂ : GGraph α) : GGraph α :=
  { size := g₁.size + g₂.size,
    nodes := Fin.append g₁.nodes g₂.nodes,
    edges := List.finCastAddProd g₁.edges g₂.size ++ List.finNatAddProd g₂.edges g₁.size ++
      List.product (List.finCastAdd g₁.inputs g₂.size) (List.finNatAdd g₂.inputs g₁.size),
    inputs := List.finCastAdd g₁.inputs g₂.size,
    outputs := List.finNatAdd g₂.inputs g₁.size }
/-- The graph with a single node (index `0`) labelled `a` and no edges. That node is both
its only input and its only output. -/
def singleton (a : α) : GGraph α :=
  { size := 1,
    nodes := fun (_ : Fin 1) => a,
    edges := [],
    inputs := [0],
    outputs := [0] }
/-- Surround `g` with `[]`-labelled single-node graphs, linking them in sequence:
`singleton []`, then `g`, then `singleton []`.

The result has exactly one input (the first added node) and one output (the last added
node). -/
def wrap (g : GGraph (List β)) : GGraph (List β) :=
  -- Explicit calls (right-nested, matching the `infixr` link notation) instead of the
  -- stripped infix operator.
  GGraph.link (GGraph.singleton []) (GGraph.link g (GGraph.singleton []))
/-- Relabelling a single-node graph relabels its only node. -/
@[simp] theorem map_singleton (f : α → β) (a : α) :
    (GGraph.singleton a).map f = GGraph.singleton (f a) := rfl
/-- Relabelling commutes with `comp`. -/
@[simp] theorem map_comp (f : α → β) (g₁ g₂ : GGraph α) :
    (GGraph.comp g₁ g₂).map f = GGraph.comp (g₁.map f) (g₂.map f) := by
  rcases g₁ with ⟨n₁, nd₁, e₁, i₁, o₁⟩; rcases g₂ with ⟨n₂, nd₂, e₂, i₂, o₂⟩
  simp only [GGraph.map, GGraph.comp]
  -- Only the `nodes` field differs syntactically; compare it pointwise on each half of
  -- the appended index range.
  congr 1
  funext i
  refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
/-- Relabelling commutes with `link`. -/
@[simp] theorem map_link (f : α → β) (g₁ g₂ : GGraph α) :
    (GGraph.link g₁ g₂).map f = GGraph.link (g₁.map f) (g₂.map f) := by
  rcases g₁ with ⟨n₁, nd₁, e₁, i₁, o₁⟩; rcases g₂ with ⟨n₂, nd₂, e₂, i₂, o₂⟩
  simp only [GGraph.map, GGraph.link]
  -- Only the `nodes` field differs syntactically; compare it pointwise on each half of
  -- the appended index range.
  congr 1
  funext i
  refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
/-- Mapping `h` over every node's label list commutes with `loop`. -/
@[simp] theorem map_loop {γ : Type} (h : β → γ) (g : GGraph (List β)) :
    (loop g).map (List.map h) = loop (g.map (List.map h)) := by
  rcases g with ⟨n, nd, e, i, o⟩
  simp only [GGraph.map, GGraph.loop]
  -- Only the `nodes` field differs syntactically: the two added `[]` nodes, then `g`'s.
  congr 1
  funext i
  refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right]
/-- Mapping `h` over every node's label list commutes with `wrap`. -/
@[simp] theorem map_wrap {γ : Type} (h : β → γ) (g : GGraph (List β)) :
    (wrap g).map (List.map h) = wrap (g.map (List.map h)) := by
  simp [GGraph.wrap, GGraph.map_link, GGraph.map_singleton]
variable (g : GGraph α)

/-- Every node index of `g`, built with `List.finRange`. -/
def indices : List g.Index := List.finRange g.size

/-- Each node index of `g` occurs in `g.indices`. -/
theorem mem_indices (idx : g.Index) : idx ∈ g.indices :=
  List.mem_finRange idx

/-- `g.indices` contains no duplicates. -/
theorem nodup_indices : g.indices.Nodup :=
  List.nodup_finRange g.size
/-- The indices `idx'` for which `(idx', idx)` is an edge of `g`, obtained by filtering
`g.indices`. -/
def predecessors (idx : g.Index) : List g.Index :=
  g.indices.filter (fun idx' => (idx', idx) ∈ g.edges)

/-- The source of an edge is a predecessor of its target. -/
theorem mem_predecessors_of_edge {idx₁ idx₂ : g.Index}
    (h : (idx₁, idx₂) ∈ g.edges) : idx₁ ∈ g.predecessors idx₂ :=
  List.mem_filter.mpr ⟨g.mem_indices idx₁, by simpa using h⟩

/-- Every predecessor of `idx₂` is the source of an edge into `idx₂`. -/
theorem edge_of_mem_predecessors {idx₁ idx₂ : g.Index}
    (h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by
  simpa using (List.mem_filter.mp h).2
end GGraph
/-- Graphs whose node labels are lists of `BasicStmt`s. -/
abbrev Graph := GGraph (List BasicStmt)
namespace Graph
-- Re-export the listed `GGraph` definitions and lemmas under the `Graph` namespace, so that
-- `open Graph` makes them usable unqualified.
export GGraph (comp link loop skipto singleton wrap loop_inputs loop_outputs)
-- Right-associative infix notations (precedence 70) for `GGraph.comp` and `GGraph.link`,
-- scoped to `Spa.Graph`.
-- NOTE(review): both token strings are empty in this copy of the file. The original symbols
-- appear to have been stripped during extraction, and an empty token is presumably rejected
-- by `infixr`. Restore the intended symbols (presumably the same ones used for the scoped
-- `GGraph` notations).
@[inherit_doc] scoped infixr:70 "" => GGraph.comp
@[inherit_doc] scoped infixr:70 "" => GGraph.link
end Graph
/-- Build the control-flow graph of a statement, by structural recursion:
* a basic statement becomes a single node labelled `[bs]`;
* `andThen s₁ s₂` links `s₁`'s graph into `s₂`'s (`GGraph.link`: outputs → inputs);
* `ifElse _ s₁ s₂` places the two branch graphs side by side (`GGraph.comp`); the condition
  is not used;
* `whileLoop _ s` wraps the body's graph with `GGraph.loop`; the condition is not used. -/
def buildCfg : Stmt → Graph
  | .basic bs => GGraph.singleton [bs]
  | .andThen s₁ s₂ => GGraph.link (buildCfg s₁) (buildCfg s₂)
  | .ifElse _ s₁ s₂ => GGraph.comp (buildCfg s₁) (buildCfg s₂)
  | .whileLoop _ s => GGraph.loop (buildCfg s)
end Spa