Generalize graphs over their node content
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -19,29 +19,43 @@ def List.finNatAddProd {m : ℕ} (l : List (Fin m × Fin m)) (n : ℕ) :
|
|||||||
|
|
||||||
namespace Spa
|
namespace Spa
|
||||||
|
|
||||||
structure Graph where
|
structure GGraph (α : Type) where
|
||||||
size : ℕ
|
size : ℕ
|
||||||
nodes : Fin size → List BasicStmt
|
nodes : Fin size → α
|
||||||
edges : List (Fin size × Fin size)
|
edges : List (Fin size × Fin size)
|
||||||
inputs : List (Fin size)
|
inputs : List (Fin size)
|
||||||
outputs : List (Fin size)
|
outputs : List (Fin size)
|
||||||
|
|
||||||
namespace Graph
|
namespace GGraph
|
||||||
|
|
||||||
abbrev Index (g : Graph) : Type := Fin g.size
|
variable {α β : Type}
|
||||||
|
|
||||||
abbrev Edge (g : Graph) : Type := g.Index × g.Index
|
abbrev Index (g : GGraph α) : Type := Fin g.size
|
||||||
|
|
||||||
def comp (g₁ g₂ : Graph) : Graph where
|
abbrev Edge (g : GGraph α) : Type := g.Index × g.Index
|
||||||
|
|
||||||
|
def map (f : α → β) (g : GGraph α) : GGraph β where
|
||||||
|
size := g.size
|
||||||
|
nodes := fun i => f (g.nodes i)
|
||||||
|
edges := g.edges
|
||||||
|
inputs := g.inputs
|
||||||
|
outputs := g.outputs
|
||||||
|
|
||||||
|
@[simp] theorem map_size (f : α → β) (g : GGraph α) : (g.map f).size = g.size := rfl
|
||||||
|
@[simp] theorem map_edges (f : α → β) (g : GGraph α) : (g.map f).edges = g.edges := rfl
|
||||||
|
@[simp] theorem map_inputs (f : α → β) (g : GGraph α) : (g.map f).inputs = g.inputs := rfl
|
||||||
|
@[simp] theorem map_outputs (f : α → β) (g : GGraph α) : (g.map f).outputs = g.outputs := rfl
|
||||||
|
|
||||||
|
def comp (g₁ g₂ : GGraph α) : GGraph α where
|
||||||
size := g₁.size + g₂.size
|
size := g₁.size + g₂.size
|
||||||
nodes := Fin.append g₁.nodes g₂.nodes
|
nodes := Fin.append g₁.nodes g₂.nodes
|
||||||
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size
|
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size
|
||||||
inputs := g₁.inputs.finCastAdd g₂.size ++ g₂.inputs.finNatAdd g₁.size
|
inputs := g₁.inputs.finCastAdd g₂.size ++ g₂.inputs.finNatAdd g₁.size
|
||||||
outputs := g₁.outputs.finCastAdd g₂.size ++ g₂.outputs.finNatAdd g₁.size
|
outputs := g₁.outputs.finCastAdd g₂.size ++ g₂.outputs.finNatAdd g₁.size
|
||||||
|
|
||||||
@[inherit_doc] scoped infixr:70 " ∙ " => Graph.comp
|
@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.comp
|
||||||
|
|
||||||
def link (g₁ g₂ : Graph) : Graph where
|
def link (g₁ g₂ : GGraph α) : GGraph α where
|
||||||
size := g₁.size + g₂.size
|
size := g₁.size + g₂.size
|
||||||
nodes := Fin.append g₁.nodes g₂.nodes
|
nodes := Fin.append g₁.nodes g₂.nodes
|
||||||
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++
|
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++
|
||||||
@@ -49,15 +63,13 @@ def link (g₁ g₂ : Graph) : Graph where
|
|||||||
inputs := g₁.inputs.finCastAdd g₂.size
|
inputs := g₁.inputs.finCastAdd g₂.size
|
||||||
outputs := g₂.outputs.finNatAdd g₁.size
|
outputs := g₂.outputs.finNatAdd g₁.size
|
||||||
|
|
||||||
@[inherit_doc] scoped infixr:70 " ⤳ " => Graph.link
|
@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.link
|
||||||
|
|
||||||
/-- The entry node of a `loop` graph. -/
|
def loopIn (g : GGraph α) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size
|
||||||
def loopIn (g : Graph) : Fin (2 + g.size) := (0 : Fin 2).castAdd g.size
|
|
||||||
|
|
||||||
/-- The exit node of a `loop` graph. -/
|
def loopOut (g : GGraph α) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size
|
||||||
def loopOut (g : Graph) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size
|
|
||||||
|
|
||||||
def loop (g : Graph) : Graph where
|
def loop (g : GGraph (List β)) : GGraph (List β) where
|
||||||
size := 2 + g.size
|
size := 2 + g.size
|
||||||
nodes := Fin.append (fun _ : Fin 2 => []) g.nodes
|
nodes := Fin.append (fun _ : Fin 2 => []) g.nodes
|
||||||
edges := g.edges.finNatAddProd 2 ++
|
edges := g.edges.finNatAddProd 2 ++
|
||||||
@@ -67,11 +79,11 @@ def loop (g : Graph) : Graph where
|
|||||||
inputs := [g.loopIn]
|
inputs := [g.loopIn]
|
||||||
outputs := [g.loopOut]
|
outputs := [g.loopOut]
|
||||||
|
|
||||||
@[simp] theorem loop_inputs (g : Graph) : (loop g).inputs = [g.loopIn] := rfl
|
@[simp] theorem loop_inputs (g : GGraph (List β)) : (loop g).inputs = [g.loopIn] := rfl
|
||||||
|
|
||||||
@[simp] theorem loop_outputs (g : Graph) : (loop g).outputs = [g.loopOut] := rfl
|
@[simp] theorem loop_outputs (g : GGraph (List β)) : (loop g).outputs = [g.loopOut] := rfl
|
||||||
|
|
||||||
def skipto (g₁ g₂ : Graph) : Graph where
|
def skipto (g₁ g₂ : GGraph α) : GGraph α where
|
||||||
size := g₁.size + g₂.size
|
size := g₁.size + g₂.size
|
||||||
nodes := Fin.append g₁.nodes g₂.nodes
|
nodes := Fin.append g₁.nodes g₂.nodes
|
||||||
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++
|
edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++
|
||||||
@@ -79,17 +91,17 @@ def skipto (g₁ g₂ : Graph) : Graph where
|
|||||||
inputs := g₁.inputs.finCastAdd g₂.size
|
inputs := g₁.inputs.finCastAdd g₂.size
|
||||||
outputs := g₂.inputs.finNatAdd g₁.size
|
outputs := g₂.inputs.finNatAdd g₁.size
|
||||||
|
|
||||||
def singleton (bss : List BasicStmt) : Graph where
|
def singleton (a : α) : GGraph α where
|
||||||
size := 1
|
size := 1
|
||||||
nodes := fun _ => bss
|
nodes := fun _ => a
|
||||||
edges := []
|
edges := []
|
||||||
inputs := [0]
|
inputs := [0]
|
||||||
outputs := [0]
|
outputs := [0]
|
||||||
|
|
||||||
def wrap (g : Graph) : Graph :=
|
def wrap (g : GGraph (List β)) : GGraph (List β) :=
|
||||||
singleton [] ⤳ g ⤳ singleton []
|
singleton [] ⤳ g ⤳ singleton []
|
||||||
|
|
||||||
variable (g : Graph)
|
variable (g : GGraph α)
|
||||||
|
|
||||||
def indices : List g.Index := List.finRange g.size
|
def indices : List g.Index := List.finRange g.size
|
||||||
|
|
||||||
@@ -110,13 +122,24 @@ theorem edge_of_mem_predecessors {idx₁ idx₂ : g.Index}
|
|||||||
(h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by
|
(h : idx₁ ∈ g.predecessors idx₂) : (idx₁, idx₂) ∈ g.edges := by
|
||||||
simpa using (List.mem_filter.mp h).2
|
simpa using (List.mem_filter.mp h).2
|
||||||
|
|
||||||
|
end GGraph
|
||||||
|
|
||||||
|
abbrev Graph : Type := GGraph (List BasicStmt)
|
||||||
|
|
||||||
|
namespace Graph
|
||||||
|
|
||||||
|
export GGraph (comp link loop skipto singleton wrap loop_inputs loop_outputs)
|
||||||
|
|
||||||
|
@[inherit_doc] scoped infixr:70 " ∙ " => GGraph.comp
|
||||||
|
@[inherit_doc] scoped infixr:70 " ⤳ " => GGraph.link
|
||||||
|
|
||||||
end Graph
|
end Graph
|
||||||
|
|
||||||
open Graph in
|
open Graph in
|
||||||
def buildCfg : Stmt → Graph
|
def buildCfg : Stmt → Graph
|
||||||
| .basic bs => Graph.singleton [bs]
|
| .basic bs => singleton [bs]
|
||||||
| .andThen s₁ s₂ => buildCfg s₁ ⤳ buildCfg s₂
|
| .andThen s₁ s₂ => buildCfg s₁ ⤳ buildCfg s₂
|
||||||
| .ifElse _ s₁ s₂ => buildCfg s₁ ∙ buildCfg s₂
|
| .ifElse _ s₁ s₂ => buildCfg s₁ ∙ buildCfg s₂
|
||||||
| .whileLoop _ s => Graph.loop (buildCfg s)
|
| .whileLoop _ s => loop (buildCfg s)
|
||||||
|
|
||||||
end Spa
|
end Spa
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ theorem Graph.wrap_predecessors_eq_nil (g : Graph) (idx : (Graph.wrap g).Index)
|
|||||||
(Graph.wrap g).predecessors idx = [] := by
|
(Graph.wrap g).predecessors idx = [] := by
|
||||||
rw [Graph.wrap_inputs, List.mem_singleton] at h
|
rw [Graph.wrap_inputs, List.mem_singleton] at h
|
||||||
subst h
|
subst h
|
||||||
rw [Graph.predecessors, List.filter_eq_nil_iff]
|
rw [GGraph.predecessors, List.filter_eq_nil_iff]
|
||||||
intro idx' _
|
intro idx' _
|
||||||
simpa using not_mem_edges_castAdd_link (g₂ := g ⤳ Graph.singleton []) 0 idx'
|
simpa using not_mem_edges_castAdd_link (g₂ := g ⤳ Graph.singleton []) 0 idx'
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user