Require Import Coq.Lists.List.
From Ltac2 Require Import Ltac2.

(** Primitive stack operations of the language.  Their effect on a value
    stack is specified by the relation [Sem_int]. *)
Inductive intrinsic : Set :=
  | swap : intrinsic      (* exchange the two topmost values *)
  | clone : intrinsic     (* duplicate the topmost value *)
  | drop : intrinsic      (* discard the topmost value *)
  | quote : intrinsic     (* wrap the topmost value in a quotation *)
  | compose : intrinsic   (* merge the two topmost quotations into one *)
  | apply : intrinsic.    (* pop the topmost quotation and run its body *)

(** Program syntax: a single intrinsic, a quoted program, or the sequential
    composition of two programs. *)
Inductive expr : Set :=
  | e_int : forall i : intrinsic, expr
  | e_quote : forall e : expr, expr
  | e_comp : forall e1 e2 : expr, expr.

(** [e_compose e es] sequences [e] followed by each element of [es] in list
    order.  The result nests to the left:
    [e_comp (... (e_comp (e_comp e x1) x2) ...) xn]. *)
Definition e_compose (e : expr) (es : list expr) : expr :=
  List.fold_left e_comp es e.

(** Runtime values.  The only kind of value is a quoted program. *)
Inductive value : Set :=
  | v_quote : forall e : expr, value.

(** A value stack is a list of values; the semantic rules act on its head. *)
Definition value_stack := list value.

(** Turn a value back into program syntax: the quotation [v_quote e] becomes
    the expression [e_quote e]. *)
Definition value_to_expr (v : value) : expr :=
  let 'v_quote body := v in e_quote body.

(** Big-step semantics, defined by mutual induction.

    [Sem_int vs i vs'] : running intrinsic [i] on stack [vs] yields [vs'].
    [Sem_expr vs e vs'] : running expression [e] on stack [vs] yields [vs'].

    The head of the list is the top of the stack. *)
Inductive Sem_int : value_stack -> intrinsic -> value_stack -> Prop :=
  (* swap: exchange the two topmost values. *)
  | Sem_swap : forall (v v' : value) (vs : value_stack), Sem_int (v' :: v :: vs) swap (v :: v' :: vs)
  (* clone: duplicate the topmost value. *)
  | Sem_clone : forall (v : value) (vs : value_stack), Sem_int (v :: vs) clone (v :: v :: vs)
  (* drop: discard the topmost value. *)
  | Sem_drop : forall (v : value) (vs : value_stack), Sem_int (v :: vs) drop vs
  (* quote: replace the topmost value by a quotation of its expression form. *)
  | Sem_quote : forall (v : value) (vs : value_stack), Sem_int (v :: vs) quote ((v_quote (value_to_expr v)) :: vs)
  (* compose: pop two quotations and push one whose body runs the lower
     quotation (e) first and the upper one (e') second. *)
  | Sem_compose : forall (e e' : expr) (vs : value_stack), Sem_int (v_quote e' :: v_quote e :: vs) compose (v_quote (e_comp e e') :: vs)
  (* apply: pop a quotation and run its body on the remaining stack. *)
  | Sem_apply : forall (e : expr) (vs vs': value_stack), Sem_expr vs e vs' -> Sem_int (v_quote e :: vs) apply vs'

with Sem_expr : value_stack -> expr -> value_stack -> Prop :=
  (* An intrinsic used as an expression behaves as given by Sem_int. *)
  | Sem_e_int : forall (i : intrinsic) (vs vs' : value_stack), Sem_int vs i vs' -> Sem_expr vs (e_int i) vs'
  (* A quoted expression pushes itself as a value without running it. *)
  | Sem_e_quote : forall (e : expr) (vs : value_stack), Sem_expr vs (e_quote e) (v_quote e :: vs)
  (* Sequencing: run e1, then run e2 on the intermediate stack vs2. *)
  | Sem_e_comp : forall (e1 e2 : expr) (vs1 vs2 vs3 : value_stack),
      Sem_expr vs1 e1 vs2 -> Sem_expr vs2 e2 vs3 -> Sem_expr vs1 (e_comp e1 e2) vs3.

(** Encoding of the boolean false as a quotation whose body is [drop].
    [false] is the expression form and [false_v] the corresponding value.
    Note that the name shadows the standard boolean constructor of the same
    name within this file. *)
Definition false : expr := e_quote (e_int drop).
Definition false_v : value := v_quote (e_int drop).

(** Encoding of the boolean true as a quotation whose body is [swap]
    followed by [drop].  [true] is the expression form and [true_v] the
    corresponding value.  Note that the name shadows the standard boolean
    constructor of the same name within this file. *)
Definition true : expr := e_quote (e_comp (e_int swap) (e_int drop)).
Definition true_v : value := v_quote (e_comp (e_int swap) (e_int drop)).

(** Pushing [false] and applying it discards the top value: starting from
    [v' :: v :: vs] the resulting stack is [v :: vs]. *)
Theorem false_correct :
  forall (v v' : value) (vs : value_stack),
    Sem_expr (v' :: v :: vs) (e_comp false (e_int apply)) (v :: vs).
Proof.
  intros lower upper rest.
  eapply Sem_e_comp.
  (* First component: the quotation is pushed on the stack. *)
  + apply Sem_e_quote.
  (* Second component: apply pops the quotation and runs its body, drop. *)
  + apply Sem_e_int; apply Sem_apply; apply Sem_e_int; apply Sem_drop.
Qed.

(** Pushing [true] and applying it keeps the top value and discards the one
    below it: starting from [v' :: v :: vs] the result is [v' :: vs]. *)
Theorem true_correct :
  forall (v v' : value) (vs : value_stack),
    Sem_expr (v' :: v :: vs) (e_comp true (e_int apply)) (v' :: vs).
Proof.
  intros lower upper rest.
  eapply Sem_e_comp.
  (* First component: the quotation is pushed on the stack. *)
  + apply Sem_e_quote.
  (* Second component: apply runs the body, which is swap then drop. *)
  + apply Sem_e_int; apply Sem_apply.
    eapply Sem_e_comp; apply Sem_e_int.
    - apply Sem_swap.
    - apply Sem_drop.
Qed.

(** Disjunction on the boolean encodings: duplicate the top value, then
    apply the copy to the stack below it.  The lemmas [or_false_v] and
    [or_true] state the results for [false_v] and [true_v] on top.
    Note that the name shadows the standard logical disjunction within
    this file. *)
Definition or : expr := e_comp (e_int clone) (e_int apply).

(** With [false_v] on top, [or] returns the second operand unchanged. *)
Theorem or_false_v :
  forall (v : value) (vs : value_stack),
    Sem_expr (false_v :: v :: vs) or (v :: vs).
Proof.
  intros operand rest.
  (* Split the sequence and reduce both halves to intrinsic steps. *)
  eapply Sem_e_comp; apply Sem_e_int.
  + apply Sem_clone.
  + (* The applied body is drop, which removes the remaining copy. *)
    apply Sem_apply; apply Sem_e_int.
    apply Sem_drop.
Qed.

(** With [true_v] on top, [or] returns [true_v] whatever the second
    operand is. *)
Theorem or_true :
  forall (v : value) (vs : value_stack),
    Sem_expr (true_v :: v :: vs) or (true_v :: vs).
Proof.
  intros operand rest.
  (* Split the sequence and reduce both halves to intrinsic steps. *)
  eapply Sem_e_comp; apply Sem_e_int.
  + apply Sem_clone.
  + (* The applied body is swap then drop. *)
    apply Sem_apply.
    eapply Sem_e_comp; apply Sem_e_int.
    - apply Sem_swap.
    - apply Sem_drop.
Qed.

(** Truth table of [or], obtained by instantiating [or_false_v] and
    [or_true] with the two boolean values.  The first value listed in each
    stack is the top of the stack. *)
Definition or_false_false :
  forall vs : value_stack,
    Sem_expr (false_v :: false_v :: vs) or (false_v :: vs) :=
  or_false_v false_v.

Definition or_false_true :
  forall vs : value_stack,
    Sem_expr (false_v :: true_v :: vs) or (true_v :: vs) :=
  or_false_v true_v.

Definition or_true_false :
  forall vs : value_stack,
    Sem_expr (true_v :: false_v :: vs) or (true_v :: vs) :=
  or_true false_v.

Definition or_true_true :
  forall vs : value_stack,
    Sem_expr (true_v :: true_v :: vs) or (true_v :: vs) :=
  or_true true_v.

(** [quote_n 0] is the plain [quote] intrinsic.  Each successor appends the
    steps swap, quote, swap, compose to the previous program.  The lemmas
    [quote_2_correct] and [quote_3_correct] show that [quote_n 1] and
    [quote_n 2] pack the two, respectively three, topmost values into a
    single quotation. *)
Fixpoint quote_n (n : nat) : expr :=
  match n with
  | S k =>
      e_compose (quote_n k)
        (e_int swap :: e_int quote :: e_int swap :: e_int compose :: nil)
  | O => e_int quote
  end.

(** [quote_n 1] replaces the two topmost values by one quotation whose body
    is the expression form of the lower value followed by that of the upper
    value. *)
Theorem quote_2_correct : forall (v1 v2 : value) (vs : value_stack),
    Sem_expr (v2 :: v1 :: vs) (quote_n 1) (v_quote (e_comp (value_to_expr v1) (value_to_expr v2)) :: vs).
Proof.
  intros v1 v2 vs. simpl.
  (* Break the program into its five intrinsic steps, one goal per step. *)
  repeat (eapply Sem_e_comp); apply Sem_e_int.
  + apply Sem_quote.
  + apply Sem_swap.
  + apply Sem_quote.
  + apply Sem_swap.
  + apply Sem_compose.
Qed.

(** [quote_n 2] replaces the three topmost values by one quotation whose
    body lists their expression forms from the lowest to the topmost. *)
Theorem quote_3_correct : forall (v1 v2 v3 : value) (vs : value_stack),
  Sem_expr (v3 :: v2 :: v1 :: vs) (quote_n 2) (v_quote (e_comp (value_to_expr v1) (e_comp (value_to_expr v2) (value_to_expr v3))) :: vs).
Proof.
  intros v1 v2 v3 vs. simpl.
  (* Break the program into its nine intrinsic steps, one goal per step. *)
  repeat (eapply Sem_e_comp); apply Sem_e_int.
  (* Steps of quote_n 1. *)
  + apply Sem_quote.
  + apply Sem_swap.
  + apply Sem_quote.
  + apply Sem_swap.
  + apply Sem_compose.
  (* Steps appended by the successor case. *)
  + apply Sem_swap.
  + apply Sem_quote.
  + apply Sem_swap.
  + apply Sem_compose.
Qed.

(* solve_basic () runs on every focused goal and discharges evaluation goals
   by following the syntax of the program:
   - a Sem_int goal is attacked with the rule of its intrinsic;
   - a Sem_expr goal on e_comp or e_int is decomposed and the resulting
     goals are processed recursively;
   - a Sem_expr goal on e_quote is closed with Sem_e_quote.
   The apply intrinsic is only unfolded one step: the premise produced by
   Sem_apply is left to the caller.
   Any goal that no rule handles is left untouched.  The catch-all branch
   matches on the conclusion alone so that this also holds when the proof
   context contains no hypotheses. *)
Ltac2 rec solve_basic () := Control.enter (fun _ =>
  match! goal with
  | [ |- Sem_int _ swap _ ] => apply Sem_swap
  | [ |- Sem_int _ clone _ ] => apply Sem_clone
  | [ |- Sem_int _ drop _ ] => apply Sem_drop
  | [ |- Sem_int _ quote _ ] => apply Sem_quote
  | [ |- Sem_int _ compose _ ] => apply Sem_compose
  | [ |- Sem_int _ apply _ ] => apply Sem_apply
  | [ |- Sem_expr _ (e_comp _ _) _ ] => eapply Sem_e_comp; solve_basic ()
  | [ |- Sem_expr _ (e_int _) _ ] => apply Sem_e_int; solve_basic ()
  | [ |- Sem_expr _ (e_quote _) _ ] => apply Sem_e_quote
  | [ |- _ ] => ()
  end).

(** Same statement as [quote_2_correct], proved with the [solve_basic]
    automation. *)
Theorem quote_2_correct' : forall (v1 v2 : value) (vs : value_stack),
    Sem_expr (v2 :: v1 :: vs) (quote_n 1) (v_quote (e_comp (value_to_expr v1) (value_to_expr v2)) :: vs).
Proof.
  intros v1 v2 vs; simpl.
  solve_basic ().
Qed.

(** Same statement as [quote_3_correct], proved with the [solve_basic]
    automation. *)
Theorem quote_3_correct' : forall (v1 v2 v3 : value) (vs : value_stack),
  Sem_expr (v3 :: v2 :: v1 :: vs) (quote_n 2) (v_quote (e_comp (value_to_expr v1) (e_comp (value_to_expr v2) (value_to_expr v3))) :: vs).
Proof.
  intros v1 v2 v3 vs; simpl.
  solve_basic ().
Qed.

(** [rotate_n n] runs [quote_n n] and then the steps swap, quote, compose,
    apply.  The lemmas [rotate_3_correct] and [rotate_4_correct] show that
    [rotate_n 1] and [rotate_n 2] bring the third, respectively fourth,
    value from the top to the top of the stack. *)
Definition rotate_n (n : nat) : expr :=
  e_compose (quote_n n)
    (e_int swap :: e_int quote :: e_int compose :: e_int apply :: nil).

(** Running the expression form of a value pushes that value back on the
    stack. *)
Lemma eval_value : forall (v : value) (vs : value_stack),
  Sem_expr vs (value_to_expr v) (v :: vs).
Proof.
  intros v vs.
  (* A value is always a quotation, so its expression form is an e_quote. *)
  destruct v; simpl; apply Sem_e_quote.
Qed.

(** [rotate_n 1] moves the third value from the top to the top of the
    stack and keeps the order of the other two. *)
Theorem rotate_3_correct : forall (v1 v2 v3 : value) (vs : value_stack),
  Sem_expr (v3 :: v2 :: v1 :: vs) (rotate_n 1) (v1 :: v3 :: v2 :: vs).
Proof.
  intros v1 v2 v3 vs.
  unfold rotate_n. simpl.
  (* Everything up to the final apply is handled by the automation. *)
  solve_basic ().
  (* What remains is running the composed quotation: one leaf per value. *)
  repeat (eapply Sem_e_comp).
  - apply eval_value.
  - apply eval_value.
  - apply eval_value.
Qed.

(** [rotate_n 2] moves the fourth value from the top to the top of the
    stack and keeps the order of the other three. *)
Theorem rotate_4_correct : forall (v1 v2 v3 v4 : value) (vs : value_stack),
  Sem_expr (v4 :: v3 :: v2 :: v1 :: vs) (rotate_n 2) (v1 :: v4 :: v3 :: v2 :: vs).
Proof.
  intros v1 v2 v3 v4 vs.
  unfold rotate_n. simpl.
  (* Everything up to the final apply is handled by the automation. *)
  solve_basic ().
  (* What remains is running the composed quotation: one leaf per value. *)
  repeat (eapply Sem_e_comp).
  - apply eval_value.
  - apply eval_value.
  - apply eval_value.
  - apply eval_value.
Qed.

(** Sequential composition is associative with respect to the semantics:
    both bracketings relate exactly the same pairs of stacks. *)
Theorem e_comp_assoc : forall (e1 e2 e3 : expr) (vs vs' : value_stack),
  Sem_expr vs (e_comp e1 (e_comp e2 e3)) vs' <-> Sem_expr vs (e_comp (e_comp e1 e2) e3) vs'.
Proof.
  intros e1 e2 e3 vs vs'.
  split; intros Heval.
  (* NOTE: the H-numbered hypotheses below are the names generated by
     inversion; they depend on the exact shape of the Sem_expr declaration. *)
  - (* Right-nested to left-nested. *)
    inversion Heval; subst. inversion H4; subst.
    eapply Sem_e_comp.
    + eapply Sem_e_comp.
      * apply H2.
      * apply H3.
    + apply H6.
  - (* Left-nested to right-nested. *)
    inversion Heval; subst. inversion H2; subst.
    eapply Sem_e_comp.
    + apply H3.
    + eapply Sem_e_comp.
      * apply H6.
      * apply H4.
Qed.