Generalize graphs over their node content

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-24 16:02:49 -05:00
parent 9ab43b34ef
commit a721a8be8b
2 changed files with 47 additions and 24 deletions

View File

@@ -19,29 +19,43 @@ def List.finNatAddProd {m : } (l : List (Fin m × Fin m)) (n : ) :
namespace Spa
structure Graph where
structure GGraph (α : Type) where
size :
nodes : Fin size List BasicStmt
nodes : Fin size α
edges : List (Fin size × Fin size)
inputs : List (Fin size)
outputs : List (Fin size)
namespace Graph
namespace GGraph
abbrev Index (g : Graph) : Type := Fin g.size
variable {α β : Type}
abbrev Edge (g : Graph) : Type := g.Index × g.Index
abbrev Index (g : GGraph α) : Type := Fin g.size
def comp (g g₂ : Graph) : Graph where
abbrev Edge (g : GGraph α) : Type := g.Index × g.Index
def map (f : α β) (g : GGraph α) : GGraph β where
size := g.size
nodes := fun i => f (g.nodes i)
edges := g.edges
inputs := g.inputs
outputs := g.outputs
@[simp] theorem map_size (f : α β) (g : GGraph α) : (g.map f).size = g.size := rfl
@[simp] theorem map_edges (f : α β) (g : GGraph α) : (g.map f).edges = g.edges := rfl
@[simp] theorem map_inputs (f : α β) (g : GGraph α) : (g.map f).inputs = g.inputs := rfl
@[simp] theorem map_outputs (f : α β) (g : GGraph α) : (g.map f).outputs = g.outputs := rfl
def comp (g₁ g₂ : GGraph α) : GGraph α where
size := g₁.size + g₂.size
nodes := Fin.append g₁.nodes g₂.nodes
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size
inputs := g₁.inputs.finCastAdd g₂.size ++ g₂.inputs.finNatAdd g₁.size
outputs := g₁.outputs.finCastAdd g₂.size ++ g₂.outputs.finNatAdd g₁.size
@[inherit_doc] scoped infixr:70 "" => Graph.comp
@[inherit_doc] scoped infixr:70 "" => GGraph.comp
def link (g₁ g₂ : Graph) : Graph where
def link (g₁ g₂ : GGraph α) : GGraph α where
size := g₁.size + g₂.size
nodes := Fin.append g₁.nodes g₂.nodes
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++
@@ -49,15 +63,13 @@ def link (g₁ g₂ : Graph) : Graph where
inputs := g₁.inputs.finCastAdd g₂.size
outputs := g₂.outputs.finNatAdd g₁.size
@[inherit_doc] scoped infixr:70 "" => Graph.link
@[inherit_doc] scoped infixr:70 "" => GGraph.link
/-- The entry node of a `loop` graph. -/
def loopIn (g : Graph) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size
def loopIn (g : GGraph α) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size
/-- The exit node of a `loop` graph. -/
def loopOut (g : Graph) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size
def loopOut (g : GGraph α) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size
def loop (g : Graph) : Graph where
def loop (g : GGraph (List β)) : GGraph (List β) where
size := 2 + g.size
nodes := Fin.append (fun _ : Fin 2 => []) g.nodes
edges := g.edges.finNatAddProd 2 ++
@@ -67,11 +79,11 @@ def loop (g : Graph) : Graph where
inputs := [g.loopIn]
outputs := [g.loopOut]
@[simp] theorem loop_inputs (g : Graph) : (loop g).inputs = [g.loopIn] := rfl
@[simp] theorem loop_inputs (g : GGraph (List β)) : (loop g).inputs = [g.loopIn] := rfl
@[simp] theorem loop_outputs (g : Graph) : (loop g).outputs = [g.loopOut] := rfl
@[simp] theorem loop_outputs (g : GGraph (List β)) : (loop g).outputs = [g.loopOut] := rfl
def skipto (g₁ g₂ : Graph) : Graph where
def skipto (g₁ g₂ : GGraph α) : GGraph α where
size := g₁.size + g₂.size
nodes := Fin.append g₁.nodes g₂.nodes
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++
@@ -79,17 +91,17 @@ def skipto (g₁ g₂ : Graph) : Graph where
inputs := g₁.inputs.finCastAdd g₂.size
outputs := g₂.inputs.finNatAdd g₁.size
def singleton (bss : List BasicStmt) : Graph where
def singleton (a : α) : GGraph α where
size := 1
nodes := fun _ => bss
nodes := fun _ => a
edges := []
inputs := [0]
outputs := [0]
def wrap (g : Graph) : Graph :=
def wrap (g : GGraph (List β)) : GGraph (List β) :=
singleton [] g singleton []
variable (g : Graph)
variable (g : GGraph α)
def indices : List g.Index := List.finRange g.size
@@ -110,13 +122,24 @@ theorem edge_of_mem_predecessors {idx₁ idx₂ : g.Index}
(h : idx₁ g.predecessors idx₂) : (idx₁, idx₂) g.edges := by
simpa using (List.mem_filter.mp h).2
end GGraph
abbrev Graph : Type := GGraph (List BasicStmt)
namespace Graph
export GGraph (comp link loop skipto singleton wrap loop_inputs loop_outputs)
@[inherit_doc] scoped infixr:70 "" => GGraph.comp
@[inherit_doc] scoped infixr:70 "" => GGraph.link
end Graph
open Graph in
def buildCfg : Stmt Graph
| .basic bs => Graph.singleton [bs]
| .basic bs => singleton [bs]
| .andThen s₁ s₂ => buildCfg s₁ buildCfg s₂
| .ifElse _ s₁ s₂ => buildCfg s₁ buildCfg s₂
| .whileLoop _ s => Graph.loop (buildCfg s)
| .whileLoop _ s => loop (buildCfg s)
end Spa

View File

@@ -226,7 +226,7 @@ theorem Graph.wrap_predecessors_eq_nil (g : Graph) (idx : (Graph.wrap g).Index)
(Graph.wrap g).predecessors idx = [] := by
rw [Graph.wrap_inputs, List.mem_singleton] at h
subst h
rw [Graph.predecessors, List.filter_eq_nil_iff]
rw [GGraph.predecessors, List.filter_eq_nil_iff]
intro idx' _
simpa using not_mem_edges_castAdd_link (g₂ := g Graph.singleton []) 0 idx'