From b1dc725cede68dbf9ce57b58220b2cd4211eef6e Mon Sep 17 00:00:00 2001 From: Danila Fedorin Date: Tue, 23 Jun 2026 14:00:06 -0500 Subject: [PATCH] Apply some cleanups to Graphs.lean --- lean/Spa/Language/Graphs.lean | 56 +++++++++++++++---------------- lean/Spa/Language/Properties.lean | 2 +- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/lean/Spa/Language/Graphs.lean b/lean/Spa/Language/Graphs.lean index 7b89631..727ace0 100644 --- a/lean/Spa/Language/Graphs.lean +++ b/lean/Spa/Language/Graphs.lean @@ -3,6 +3,20 @@ import Mathlib.Data.Fin.Tuple.Basic import Mathlib.Data.List.ProdSigma import Mathlib.Data.List.FinRange +def List.finCastAdd {n : ℕ} (l : List (Fin n)) (m : ℕ) : List (Fin (n + m)) := + l.map (Fin.castAdd m) + +def List.finNatAdd {m : ℕ} (l : List (Fin m)) (n : ℕ) : List (Fin (n + m)) := + l.map (Fin.natAdd n) + +def List.finCastAddProd {n : ℕ} (l : List (Fin n × Fin n)) (m : ℕ) : + List (Fin (n + m) × Fin (n + m)) := + l.map (fun e => (e.1.castAdd m, e.2.castAdd m)) + +def List.finNatAddProd {m : ℕ} (l : List (Fin m × Fin m)) (n : ℕ) : + List (Fin (n + m) × Fin (n + m)) := + l.map (fun e => (e.1.natAdd n, e.2.natAdd n)) + namespace Spa structure Graph where @@ -18,36 +32,22 @@ abbrev Index (g : Graph) : Type := Fin g.size abbrev Edge (g : Graph) : Type := g.Index × g.Index -def liftIdxL {n : ℕ} (l : List (Fin n)) (m : ℕ) : List (Fin (n + m)) := - l.map (Fin.castAdd m) - -def liftIdxR (n : ℕ) {m : ℕ} (l : List (Fin m)) : List (Fin (n + m)) := - l.map (Fin.natAdd n) - -def liftEdgeL {n : ℕ} (l : List (Fin n × Fin n)) (m : ℕ) : - List (Fin (n + m) × Fin (n + m)) := - l.map (fun e => (e.1.castAdd m, e.2.castAdd m)) - -def liftEdgeR (n : ℕ) {m : ℕ} (l : List (Fin m × Fin m)) : - List (Fin (n + m) × Fin (n + m)) := - l.map (fun e => (e.1.natAdd n, e.2.natAdd n)) - def comp (g₁ g₂ : Graph) : Graph where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes - edges := liftEdgeL g₁.edges g₂.size ++ liftEdgeR g₁.size g₂.edges - inputs := liftIdxL g₁.inputs g₂.size ++ liftIdxR g₁.size g₂.inputs - outputs := liftIdxL g₁.outputs g₂.size ++ liftIdxR g₁.size g₂.outputs + edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size + inputs := g₁.inputs.finCastAdd g₂.size ++ g₂.inputs.finNatAdd g₁.size + outputs := g₁.outputs.finCastAdd g₂.size ++ g₂.outputs.finNatAdd g₁.size @[inherit_doc] scoped infixr:70 " ∙ " => Graph.comp def link (g₁ g₂ : Graph) : Graph where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes - edges := liftEdgeL g₁.edges g₂.size ++ liftEdgeR g₁.size g₂.edges ++ - (liftIdxL g₁.outputs g₂.size).product (liftIdxR g₁.size g₂.inputs) - inputs := liftIdxL g₁.inputs g₂.size - outputs := liftIdxR g₁.size g₂.outputs + edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++ + (g₁.outputs.finCastAdd g₂.size).product (g₂.inputs.finNatAdd g₁.size) + inputs := g₁.inputs.finCastAdd g₂.size + outputs := g₂.outputs.finNatAdd g₁.size @[inherit_doc] scoped infixr:70 " ⤳ " => Graph.link @@ -60,9 +60,9 @@ def loopOut (g : Graph) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size def loop (g : Graph) : Graph where size := 2 + g.size nodes := Fin.append (fun _ : Fin 2 => []) g.nodes - edges := liftEdgeR 2 g.edges ++ - (liftIdxR 2 g.inputs).map (g.loopIn, ·) ++ - (liftIdxR 2 g.outputs).map (·, g.loopOut) ++ + edges := g.edges.finNatAddProd 2 ++ + (g.inputs.finNatAdd 2).map (g.loopIn, ·) ++ + (g.outputs.finNatAdd 2).map (·, g.loopOut) ++ [(g.loopOut, g.loopIn), (g.loopIn, g.loopOut)] inputs := [g.loopIn] outputs := [g.loopOut] @@ -74,10 +74,10 @@ def loop (g : Graph) : Graph where def skipto (g₁ g₂ : Graph) : Graph where size := g₁.size + g₂.size nodes := Fin.append g₁.nodes g₂.nodes - edges := liftEdgeL g₁.edges g₂.size ++ liftEdgeR g₁.size g₂.edges ++ - (liftIdxL g₁.inputs g₂.size).product (liftIdxR g₁.size g₂.inputs) - inputs := liftIdxL g₁.inputs g₂.size - outputs := liftIdxR g₁.size g₂.inputs + edges := g₁.edges.finCastAddProd g₂.size ++ g₂.edges.finNatAddProd g₁.size ++ + (g₁.inputs.finCastAdd g₂.size).product (g₂.inputs.finNatAdd g₁.size) + inputs := g₁.inputs.finCastAdd g₂.size + outputs := g₂.inputs.finNatAdd g₁.size def singleton (bss : List BasicStmt) : Graph where size := 1 diff --git a/lean/Spa/Language/Properties.lean b/lean/Spa/Language/Properties.lean index ab4b2b0..909fd21 100644 --- a/lean/Spa/Language/Properties.lean +++ b/lean/Spa/Language/Properties.lean @@ -212,7 +212,7 @@ private theorem not_mem_edges_castAdd_link {g₂ : Graph} (i : Fin 1) rcases List.mem_append.mp h with h' | h' · rcases List.mem_append.mp h' with h'' | h'' · -- lifted edges of `singleton []`: there are none - simp [Graph.singleton, Graph.liftEdgeL] at h'' + simp [Graph.singleton, List.finCastAddProd] at h'' · -- lifted edges of g₂: targets are natAdd obtain ⟨e, _, heq⟩ := List.mem_map.mp h'' exact Fin.castAdd_ne_natAdd i e.2 (congrArg Prod.snd heq).symm