(* clutch.prob_lang.lang *)

From Stdlib Require Import Reals Psatz.
From stdpp Require Export binders strings.
From stdpp Require Import gmap fin_maps countable fin.
From iris.algebra Require Export ofe.
From clutch.prelude Require Export stdpp_ext.
From clutch.prob Require Export distribution.
From clutch.common Require Export language ectx_language ectxi_language locations.
From iris.prelude Require Import options.

(* Scope keys: [%E] selects [expr_scope], [%V] selects [val_scope]. *)
Delimit Scope expr_scope with E.
Delimit Scope val_scope with V.

Module prob_lang.

(** Base literals: integers, booleans, unit, heap locations ([LitLoc]) and
    tape labels ([LitLbl], the result of tape allocation in [head_step]). *)
Inductive base_lit : Set :=
  | LitInt (n : Z) | LitBool (b : bool) | LitUnit | LitLoc (l : loc) | LitLbl (l : loc).
(** Unary operators; their semantics is given by [un_op_eval]. *)
Inductive un_op : Set :=
  | NegOp | MinusUnOp.
(** Binary operators; their semantics is given by [bin_op_eval]. *)
Inductive bin_op : Set :=
  | PlusOp | MinusOp | MultOp | QuotOp | RemOp (* Arithmetic *)
  | AndOp | OrOp | XorOp (* Bitwise *)
  | ShiftLOp | ShiftROp (* Shifts *)
  | LeOp | LtOp | EqOp (* Relations *)
  | OffsetOp. (* Pointer offset *)

(** Expressions and values are mutually inductive: [Val] embeds a value into an
    expression, and a closure value [RecV] carries an expression body. *)
Inductive expr :=
  (* Values *)
  | Val (v : val)
  (* Base lambda calculus *)
  | Var (x : string)
  | Rec (f x : binder) (e : expr)
  | App (e1 e2 : expr)
  (* Base types and their operations *)
  | UnOp (op : un_op) (e : expr)
  | BinOp (op : bin_op) (e1 e2 : expr)
  | If (e0 e1 e2 : expr)
  (* Products *)
  | Pair (e1 e2 : expr)
  | Fst (e : expr)
  | Snd (e : expr)
  (* Sums *)
  | InjL (e : expr)
  | InjR (e : expr)
  | Case (e0 : expr) (e1 : expr) (e2 : expr)
  (* Heap *)
  | AllocN (e1 e2 : expr) (* Array length and initial value *)
  | Load (e : expr)
  | Store (e1 : expr) (e2 : expr)
  (* Probabilistic choice *)
  (* Allocate a tape for [Rand]; the argument is the tape bound *)
  | AllocTape (e : expr)
  (* Allocate a Laplace tape; the arguments are its numerator, denominator and
     mean (see [head_step]). NOTE(review): all three arguments share the binder
     name [e]; consider distinct names (e1 e2 e3) after checking that no proof
     relies on the names generated from these binders. *)
  | AllocTapeLaplace (e : expr) (e : expr) (e : expr)
  (* Uniform sample from 0, ..., e1; e2 is unit (no tape) or a tape label *)
  | Rand (e1 e2 : expr)
  (* Sample from discrete Laplace distribution, with scale e1/e2, located at e3, with tape e4 *)
  | Laplace (e1 : expr) (e2 : expr) (e3 : expr) (e4 : expr)
  (* No-op operator used for cost *)
  | Tick (e : expr)
with val :=
  | LitV (l : base_lit)
  | RecV (f x : binder) (e : expr)
  | PairV (v1 v2 : val)
  | InjLV (v : val)
  | InjRV (v : val).

Bind Scope expr_scope with expr.
Bind Scope val_scope with val.

(* [of_val] is a parsing-only alias for the [Val] constructor. *)
Notation of_val := Val (only parsing).

(** [to_val e] is [Some v] exactly when [e] is syntactically [Val v]. *)
Definition to_val (e : expr) : option val :=
  match e with
  | Val v => Some v
  | _ => None
  end.

(** A default value: the unit literal. *)
Definition def_val : val := LitV LitUnit.

(** We assume the following encoding of values to 64-bit words: The 3 least
significant bits of every word are a "tag", and we have 61 bits of payload,
which is enough if all pointers are 8-byte-aligned (common on 64-bit
architectures). The tags have the following meaning:

0: Payload is the data for a LitV (LitInt _).
1: Payload is the data for an InjLV (LitV (LitInt _)).
2: Payload is the data for an InjRV (LitV (LitInt _)).
3: Payload is the data for a LitV (LitLoc _).
4: Payload is the data for an InjLV (LitV (LitLoc _)).
5: Payload is the data for an InjRV (LitV (LitLoc _)).
6: Payload is one of the following finitely many values, which 61 bits are more
   than enough to encode:
   LitV LitUnit, InjLV (LitV LitUnit), InjRV (LitV LitUnit),
   LitV LitPoison, InjLV (LitV LitPoison), InjRV (LitV LitPoison),
   LitV (LitBool _), InjLV (LitV (LitBool _)), InjRV (LitV (LitBool _)).
7: Value is boxed, i.e., payload is a pointer to some read-only memory area on
   the heap which stores whether this is a RecV, PairV, InjLV or InjRV and the
   relevant data for those cases. However, the boxed representation is never
   used if any of the above representations could be used.

Note: [base_lit] in this language has no LitPoison, and tape labels (LitLbl)
are not covered by the table above; [lit_is_unboxed] nevertheless treats them
as unboxed.

Ignoring (as usual) the fact that we have to fit the infinite Z/loc into 61
bits, this means every value is machine-word-sized and can hence be atomically
read and written. Also notice that the sets of boxed and unboxed values are
disjoint. *)
(** [lit_is_unboxed l] holds when [l] is represented directly in a machine word
    and can therefore be compared atomically. Every literal of [base_lit] is
    unboxed. *)
Definition lit_is_unboxed (l: base_lit) : Prop :=
  match l with
  (* Disallow comparing (erased) prophecies with (erased) prophecies, by
     considering them boxed. [base_lit] has no prophecy or poison literals, so
     the case below is kept for reference only. *)
  (* | LitProphecy _ | LitPoison => False *)
  | LitInt _ | LitBool _ | LitLoc _ | LitLbl _ | LitUnit => True
  end.
(** A value is unboxed if it is an unboxed literal, or an [InjLV]/[InjRV]
    directly wrapping one. Everything else (closures, pairs, nested
    injections) is treated as boxed. *)
Definition val_is_unboxed (v : val) : Prop :=
  match v with
  | LitV l => lit_is_unboxed l
  | InjLV (LitV l) => lit_is_unboxed l
  | InjRV (LitV l) => lit_is_unboxed l
  | _ => False
  end.

(* Both predicates are decidable by case analysis; [Defined] keeps the
   instances transparent. *)
Global Instance lit_is_unboxed_dec l : Decision (lit_is_unboxed l).
Proof. destruct l; simpl; exact (decide _). Defined.
(* The nested [[]] patterns destruct the payload of [InjLV]/[InjRV], so the
   inner match of [val_is_unboxed] reduces. *)
Global Instance val_is_unboxed_dec v : Decision (val_is_unboxed v).
Proof. destruct v as [ | | | [] | [] ]; simpl; exact (decide _). Defined.

(** We just compare the word-sized representation of two values, without looking
into boxed data. This works out fine if at least one of the to-be-compared
values is unboxed (exploiting the fact that an unboxed and a boxed value can
never be equal because these are disjoint sets). *)
(** Comparing [vl] and [v1] for equality is safe when at least one of them is
    unboxed. *)
Definition vals_compare_safe (vl v1 : val) : Prop :=
  val_is_unboxed vl \/ val_is_unboxed v1.
Global Arguments vals_compare_safe !_ !_ /.

(** A tape with bound [n]: a list of samples in [fin (S n)], i.e. numbers in
    0, ..., n. *)
Definition tape := { n : nat & list (fin (S n)) }.
(** A Laplace tape: the parameters [num], [den], [mean] it was allocated with,
    and a list of pre-sampled integers (consumed by [Laplace] in [head_step]). *)
Variant tape_laplace :=
  Tape_Laplace (num den mean : Z) (tape_content : list Z).

(* Instances for [tape] are found by typeclass search on the sigma type. *)
Global Instance tape_inhabited : Inhabited tape := populate (existT 0%nat []).
Global Instance tape_eq_dec : EqDecision tape. Proof. apply _. Defined.
Global Instance tape_countable : Countable tape. Proof. apply _. Qed.

(* The default Laplace tape has all parameters 0 and no contents. *)
Global Instance tape_laplace_inhabited : Inhabited tape_laplace := populate (Tape_Laplace 0 0 0 []).
Global Instance tape_laplace_eq_dec : EqDecision tape_laplace. Proof. solve_decision. Defined.
(** [tape_laplace] is countable through its bijection with the tuple
    [(num, den, mean, tape_content)]. This uses the same [inj_countable']
    pattern as the other [Countable] instances in this file, instead of
    hand-rolled nested [prod_encode]/[prod_decode] code. The instance is
    closed with [Qed], so no client can depend on the concrete encoding. *)
Global Instance tape_laplace_countable : Countable tape_laplace.
Proof.
  refine (inj_countable'
    (λ t, match t with Tape_Laplace num den mean xs => (num, den, mean, xs) end)
    (λ (p : Z * Z * Z * list Z),
       match p with (num, den, mean, xs) => Tape_Laplace num den mean xs end)
    _).
  by intros [].
Qed.

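(** The state: a [loc]-indexed heap of [val]s, [loc]-indexed tapes of bounded
natural numbers (elements of [fin (S n)]), and [loc]-indexed Laplace tapes
(their parameters plus a list of integers). *)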
Record state : Type := {
  (* the heap, read and written by [Load], [Store] and [AllocN] *)
  heap : gmap loc val;
  (* tapes consumed by [Rand] *)
  tapes : gmap loc tape;
  (* tapes consumed by [Laplace] *)
  tapes_laplace : gmap loc tape_laplace
}.

(** Equality and other typeclass stuff *)
(** [to_val] is a left inverse of [of_val]; this holds by computation. *)
Lemma to_of_val v : to_val (of_val v) = Some v.
Proof. reflexivity. Qed.

(** Conversely, if [to_val e] succeeds with [v], then [e] is [of_val v]. *)
Lemma of_to_val e v : to_val e = Some v -> of_val v = e.
Proof. destruct e=>//=. by intros [= <-]. Qed.

(** [of_val] (i.e. the [Val] constructor) is injective. *)
Global Instance of_val_inj : Inj (=) (=) of_val.
Proof. intros ??. congruence. Qed.

(* Decidable equality for the flat syntax types; [Defined] keeps them
   transparent so that [decide] on operators (e.g. in [bin_op_eval]) computes. *)
Global Instance base_lit_eq_dec : EqDecision base_lit.
Proof. solve_decision. Defined.
Global Instance un_op_eq_dec : EqDecision un_op.
Proof. solve_decision. Defined.
Global Instance bin_op_eq_dec : EqDecision bin_op.
Proof. solve_decision. Defined.
(** Decidable equality on [expr], by mutual structural recursion with [val]
    ([gov]). Inside the fixpoints, [decide] on sub-expressions and sub-values is
    resolved by typeclass search to the local [go]/[gov]. Mismatched
    constructors fall to [right _]; those proof obligations and the [cast_if]
    side goals are discharged by [intuition congruence]. *)
Global Instance expr_eq_dec : EqDecision expr.
Proof.
  refine (
   fix go (e1 e2 : expr) {struct e1} : Decision (e1 = e2) :=
     match e1, e2 with
     | Val v, Val v' => cast_if (decide (v = v'))
     | Var x, Var x' => cast_if (decide (x = x'))
     | Rec f x e, Rec f' x' e' =>
        cast_if_and3 (decide (f = f')) (decide (x = x')) (decide (e = e'))
     | App e1 e2, App e1' e2' => cast_if_and (decide (e1 = e1')) (decide (e2 = e2'))
     | UnOp o e, UnOp o' e' => cast_if_and (decide (o = o')) (decide (e = e'))
     | BinOp o e1 e2, BinOp o' e1' e2' =>
        cast_if_and3 (decide (o = o')) (decide (e1 = e1')) (decide (e2 = e2'))
     | If e0 e1 e2, If e0' e1' e2' =>
        cast_if_and3 (decide (e0 = e0')) (decide (e1 = e1')) (decide (e2 = e2'))
     | Pair e1 e2, Pair e1' e2' =>
        cast_if_and (decide (e1 = e1')) (decide (e2 = e2'))
     | Fst e, Fst e' => cast_if (decide (e = e'))
     | Snd e, Snd e' => cast_if (decide (e = e'))
     | InjL e, InjL e' => cast_if (decide (e = e'))
     | InjR e, InjR e' => cast_if (decide (e = e'))
     | Case e0 e1 e2, Case e0' e1' e2' =>
        cast_if_and3 (decide (e0 = e0')) (decide (e1 = e1')) (decide (e2 = e2'))
     | AllocN e1 e2, AllocN e1' e2' => cast_if_and (decide (e1 = e1')) (decide (e2 = e2'))
     | Load e, Load e' => cast_if (decide (e = e'))
     | Store e1 e2, Store e1' e2' =>
        cast_if_and (decide (e1 = e1')) (decide (e2 = e2'))
     | AllocTape e, AllocTape e' => cast_if (decide (e = e'))
     | AllocTapeLaplace e1 e2 e3, AllocTapeLaplace e1' e2' e3' => cast_if_and3 (decide (e1 = e1')) (decide (e2 = e2')) (decide (e3 = e3'))
     | Rand e1 e2, Rand e1' e2' => cast_if_and (decide (e1 = e1')) (decide (e2 = e2'))
     | Laplace e1 e2 e3 e4, Laplace e1' e2' e3' e4' => cast_if_and4 (decide (e1 = e1')) (decide (e2 = e2')) (decide (e3 = e3')) (decide (e4 = e4'))
     | Tick e, Tick e' => cast_if (decide (e = e'))
     | _, _ => right _
     end
   with gov (v1 v2 : val) {struct v1} : Decision (v1 = v2) :=
     match v1, v2 with
     | LitV l, LitV l' => cast_if (decide (l = l'))
     | RecV f x e, RecV f' x' e' =>
        cast_if_and3 (decide (f = f')) (decide (x = x')) (decide (e = e'))
     | PairV e1 e2, PairV e1' e2' =>
        cast_if_and (decide (e1 = e1')) (decide (e2 = e2'))
     | InjLV e, InjLV e' => cast_if (decide (e = e'))
     | InjRV e, InjRV e' => cast_if (decide (e = e'))
     | _, _ => right _
     end
   for go); try (clear go gov; abstract intuition congruence).
Defined.
(* [solve_decision] can use [expr_eq_dec] for the [expr] bodies of [RecV], and
   the instances above for the fields of [state]. *)
Global Instance val_eq_dec : EqDecision val.
Proof. solve_decision. Defined.
Global Instance state_eq_dec : EqDecision state.
Proof. solve_decision. Defined.

(** Countable via an injection into the sum type [(Z + bool) + (unit + (loc + loc))];
    [LitLbl] uses the left and [LitLoc] the right [loc] component. *)
Global Instance base_lit_countable : Countable base_lit.
Proof.
 refine (inj_countable' (λ l, match l with
  | LitInt n => inl (inl n)
  | LitBool b => inl (inr b)
  | LitUnit => inr (inl ())
  | LitLoc l => inr (inr (inr l))
  | LitLbl l => inr (inr (inl l))
  end) (λ l, match l with
  | inl (inl n) => LitInt n
  | inl (inr b) => LitBool b
  | inr (inl ()) => LitUnit
  | inr (inr (inr l)) => LitLoc l
  | inr (inr (inl l)) => LitLbl l
  end) _); by intros [].
Qed.
(** Despite its name, this is a [Countable] (not [Finite]) instance, obtained
    from an injection into [nat]. *)
Global Instance un_op_finite : Countable un_op.
Proof.
 refine (inj_countable' (λ op, match op with NegOp => 0 | MinusUnOp => 1 end)
  (λ n, match n with 0 => NegOp | _ => MinusUnOp end) _); by intros [].
Qed.
(** Countable via an injection into [nat]; the decoder maps every number not
    used by the encoder to [OffsetOp]. *)
Global Instance bin_op_countable : Countable bin_op.
Proof.
 refine (inj_countable' (λ op, match op with
  | PlusOp => 0 | MinusOp => 1 | MultOp => 2 | QuotOp => 3 | RemOp => 4
  | AndOp => 5 | OrOp => 6 | XorOp => 7 | ShiftLOp => 8 | ShiftROp => 9
  | LeOp => 10 | LtOp => 11 | EqOp => 12 | OffsetOp => 13
  end) (λ n, match n with
  | 0 => PlusOp | 1 => MinusOp | 2 => MultOp | 3 => QuotOp | 4 => RemOp
  | 5 => AndOp | 6 => OrOp | 7 => XorOp | 8 => ShiftLOp | 9 => ShiftROp
  | 10 => LeOp | 11 => LtOp | 12 => EqOp | _ => OffsetOp
  end) _); by intros [].
Qed.
(** Countable by encoding expressions as generic trees. Leaves carry variable
    names ([inl (inl _)]), binders ([inl (inr _)]), base literals
    ([inr (inl _)]) and operators ([inr (inr _)]). Node labels identify the
    constructor, with separate numberings for [expr] ([go]) and [val] ([gov]).
    The decoder maps trees outside the image of [enc] to a dummy unit value.
    The final proof shows [dec (enc e) = e] by mutual recursion. *)
Global Instance expr_countable : Countable expr.
Proof.
 set (enc :=
   fix go e :=
     match e with
     | Val v => GenNode 0 [gov v]
     | Var x => GenLeaf (inl (inl x))
     | Rec f x e => GenNode 1 [GenLeaf (inl (inr f)); GenLeaf (inl (inr x)); go e]
     | App e1 e2 => GenNode 2 [go e1; go e2]
     | UnOp op e => GenNode 3 [GenLeaf (inr (inr (inl op))); go e]
     | BinOp op e1 e2 => GenNode 4 [GenLeaf (inr (inr (inr op))); go e1; go e2]
     | If e0 e1 e2 => GenNode 5 [go e0; go e1; go e2]
     | Pair e1 e2 => GenNode 6 [go e1; go e2]
     | Fst e => GenNode 7 [go e]
     | Snd e => GenNode 8 [go e]
     | InjL e => GenNode 9 [go e]
     | InjR e => GenNode 10 [go e]
     | Case e0 e1 e2 => GenNode 11 [go e0; go e1; go e2]
     | AllocN e1 e2 => GenNode 12 [go e1; go e2]
     | Load e => GenNode 13 [go e]
     | Store e1 e2 => GenNode 14 [go e1; go e2]
     | AllocTape e => GenNode 15 [go e]
     | AllocTapeLaplace e1 e2 e3 => GenNode 16 [go e1; go e2; go e3]
     | Rand e1 e2 => GenNode 17 [go e1; go e2]
     | Laplace e1 e2 e3 e4 => GenNode 18 [go e1; go e2; go e3; go e4]
     | Tick e => GenNode 19 [go e]
     end
   with gov v :=
     match v with
     | LitV l => GenLeaf (inr (inl l))
     | RecV f x e =>
        GenNode 0 [GenLeaf (inl (inr f)); GenLeaf (inl (inr x)); go e]
     | PairV v1 v2 => GenNode 1 [gov v1; gov v2]
     | InjLV v => GenNode 2 [gov v]
     | InjRV v => GenNode 3 [gov v]
     end
   for go).
 set (dec :=
   fix go e :=
     match e with
     | GenNode 0 [v] => Val (gov v)
     | GenLeaf (inl (inl x)) => Var x
     | GenNode 1 [GenLeaf (inl (inr f)); GenLeaf (inl (inr x)); e] => Rec f x (go e)
     | GenNode 2 [e1; e2] => App (go e1) (go e2)
     | GenNode 3 [GenLeaf (inr (inr (inl op))); e] => UnOp op (go e)
     | GenNode 4 [GenLeaf (inr (inr (inr op))); e1; e2] => BinOp op (go e1) (go e2)
     | GenNode 5 [e0; e1; e2] => If (go e0) (go e1) (go e2)
     | GenNode 6 [e1; e2] => Pair (go e1) (go e2)
     | GenNode 7 [e] => Fst (go e)
     | GenNode 8 [e] => Snd (go e)
     | GenNode 9 [e] => InjL (go e)
     | GenNode 10 [e] => InjR (go e)
     | GenNode 11 [e0; e1; e2] => Case (go e0) (go e1) (go e2)
     | GenNode 12 [e1 ; e2] => AllocN (go e1) (go e2)
     | GenNode 13 [e] => Load (go e)
     | GenNode 14 [e1; e2] => Store (go e1) (go e2)
     | GenNode 15 [e] => AllocTape (go e)
     | GenNode 16 [e1; e2; e3] => AllocTapeLaplace (go e1) (go e2) (go e3)
     | GenNode 17 [e1; e2] => Rand (go e1) (go e2)
     | GenNode 18 [e1; e2; e3; e4] => Laplace (go e1) (go e2) (go e3) (go e4)
     | GenNode 19 [e] => Tick (go e)
     | _ => Val $ LitV LitUnit (* dummy *)
     end
   with gov v :=
     match v with
     | GenLeaf (inr (inl l)) => LitV l
     | GenNode 0 [GenLeaf (inl (inr f)); GenLeaf (inl (inr x)); e] => RecV f x (go e)
     | GenNode 1 [v1; v2] => PairV (gov v1) (gov v2)
     | GenNode 2 [v] => InjLV (gov v)
     | GenNode 3 [v] => InjRV (gov v)
     | _ => LitV LitUnit (* dummy *)
     end
   for go).
 refine (inj_countable' enc dec _).
 refine (fix go (e : expr) {struct e} := _ with gov (v : val) {struct v} := _ for go).
 - destruct e as [v| | | | | | | | | | | | | | | | | | | | ]; simpl; f_equal;
     [exact (gov v)|done..].
 - destruct v; by f_equal.
Qed.
(** [val] is countable by injecting into [expr] with [of_val], using [to_val]
    as the partial inverse. *)
Global Instance val_countable : Countable val.
Proof. refine (inj_countable of_val to_val _); auto using to_of_val. Qed.

(** A state is encoded as the triple of its heap, tapes and Laplace tapes. *)
Global Program Instance state_countable : Countable state :=
  {| encode σ := encode (σ.(heap), σ.(tapes), σ.(tapes_laplace));
     decode p := match decode p : option (_ * _ * _) with
                 | Some (h, t, tl) => Some {|heap:=h; tapes:=t; tapes_laplace:=tl|}
                 | None => None end |}.
Next Obligation. intros [h t tl]. rewrite decode_encode //=. Qed.

(* Default inhabitants: each state component uses its own default (presumably
   the empty map); values and expressions default to unit. *)
Global Instance state_inhabited : Inhabited state :=
  populate {| heap := inhabitant; tapes := inhabitant; tapes_laplace := inhabitant |}.
Global Instance val_inhabited : Inhabited val := populate (LitV LitUnit).
Global Instance expr_inhabited : Inhabited expr := populate (Val inhabitant).

(* Leibniz OFEs for use with Iris: equivalence on these types is equality. *)
Canonical Structure stateO := leibnizO state.
Canonical Structure locO := leibnizO loc.
Canonical Structure valO := leibnizO val.
Canonical Structure exprO := leibnizO expr.

(** Evaluation contexts *)
(** Evaluation context items. Arguments are evaluated right to left: a context
    focusing a left argument stores the already-evaluated arguments to its
    right as values ([v2], [v3], [v4]), and the unevaluated ones to its left as
    expressions. *)
Inductive ectx_item :=
  | AppLCtx (v2 : val)
  | AppRCtx (e1 : expr)
  | UnOpCtx (op : un_op)
  | BinOpLCtx (op : bin_op) (v2 : val)
  | BinOpRCtx (op : bin_op) (e1 : expr)
  | IfCtx (e1 e2 : expr)
  | PairLCtx (v2 : val)
  | PairRCtx (e1 : expr)
  | FstCtx
  | SndCtx
  | InjLCtx
  | InjRCtx
  | CaseCtx (e1 : expr) (e2 : expr)
  | AllocNLCtx (v2 : val)
  | AllocNRCtx (e1 : expr)
  | LoadCtx
  | StoreLCtx (v2 : val)
  | StoreRCtx (e1 : expr)
  | AllocTapeCtx
  | AllocTapeLaplaceNumCtx (v2 : val) (v3 : val)
  | AllocTapeLaplaceDenCtx (e1 : expr) (v3 : val)
  | AllocTapeLaplaceMeanCtx (e1 : expr) (e2 : expr)
  | RandLCtx (v2 : val)
  | RandRCtx (e1 : expr)
  | LaplaceNumCtx (v2 : val) (v3 : val) (v4 : val)
  | LaplaceDenCtx (e1 : expr) (v3 : val) (v4 : val)
  | LaplaceMeanCtx (e1 : expr) (e2 : expr) (v4 : val)
  | LaplaceTapeCtx (e1 : expr) (e2 : expr) (e3 : expr)
  | TickCtx.

(** [fill_item Ki e] plugs [e] into the hole of the context item [Ki]. Stored
    values are re-embedded with [Val]. *)
Definition fill_item (Ki : ectx_item) (e : expr) : expr :=
  match Ki with
  | AppLCtx v2 => App e (Val v2)
  | AppRCtx e1 => App e1 e
  | UnOpCtx op => UnOp op e
  | BinOpLCtx op v2 => BinOp op e (Val v2)
  | BinOpRCtx op e1 => BinOp op e1 e
  | IfCtx e1 e2 => If e e1 e2
  | PairLCtx v2 => Pair e (Val v2)
  | PairRCtx e1 => Pair e1 e
  | FstCtx => Fst e
  | SndCtx => Snd e
  | InjLCtx => InjL e
  | InjRCtx => InjR e
  | CaseCtx e1 e2 => Case e e1 e2
  | AllocNLCtx v2 => AllocN e (Val v2)
  | AllocNRCtx e1 => AllocN e1 e
  | LoadCtx => Load e
  | StoreLCtx v2 => Store e (Val v2)
  | StoreRCtx e1 => Store e1 e
  | AllocTapeCtx => AllocTape e
  | AllocTapeLaplaceNumCtx v2 v3 => AllocTapeLaplace e (Val v2) (Val v3)
  | AllocTapeLaplaceDenCtx e1 v3 => AllocTapeLaplace e1 e (Val v3)
  | AllocTapeLaplaceMeanCtx e1 e2 => AllocTapeLaplace e1 e2 e
  | RandLCtx v2 => Rand e (Val v2)
  | RandRCtx e1 => Rand e1 e
  | LaplaceNumCtx v2 v3 v4 => Laplace e (Val v2) (Val v3) (Val v4)
  | LaplaceDenCtx e1 v3 v4 => Laplace e1 e (Val v3) (Val v4)
  | LaplaceMeanCtx e1 e2 v4 => Laplace e1 e2 e (Val v4)
  | LaplaceTapeCtx e1 e2 e3 => Laplace e1 e2 e3 e
  | TickCtx => Tick e
  end.

(** [decomp_item e] splits off the outermost context item of [e] together with
    the subterm to evaluate next, namely the rightmost argument that is not yet
    a value. It returns [None] when there is no such subterm: [noval] rejects
    subterms that are already values, and [Val], [Var] and [Rec] have no
    evaluation position. *)
Definition decomp_item (e : expr) : option (ectx_item * expr) :=
  let noval (e : expr) (ei : ectx_item) :=
    match e with Val _ => None | _ => Some (ei, e) end in
  match e with
  | App e1 e2 =>
      match e2 with
      | (Val v) => noval e1 (AppLCtx v)
      | _ => Some (AppRCtx e1, e2)
      end
  | UnOp op e => noval e (UnOpCtx op)
  | BinOp op e1 e2 =>
      match e2 with
      | Val v => noval e1 (BinOpLCtx op v)
      | _ => Some (BinOpRCtx op e1, e2)
      end
  | If e0 e1 e2 => noval e0 (IfCtx e1 e2)
  | Pair e1 e2 =>
      match e2 with
      | Val v => noval e1 (PairLCtx v)
      | _ => Some (PairRCtx e1, e2)
      end
  | Fst e => noval e FstCtx
  | Snd e => noval e SndCtx
  | InjL e => noval e InjLCtx
  | InjR e => noval e InjRCtx
  | Case e0 e1 e2 => noval e0 (CaseCtx e1 e2)
  | AllocN e1 e2 =>
      match e2 with
      | Val v => noval e1 (AllocNLCtx v)
      | _ => Some (AllocNRCtx e1, e2)
      end

  | Load e => noval e LoadCtx
  | Store e1 e2 =>
      match e2 with
      | Val v => noval e1 (StoreLCtx v)
      | _ => Some (StoreRCtx e1, e2)
      end
  | AllocTape e => noval e AllocTapeCtx
  | Rand e1 e2 =>
      match e2 with
      | Val v => noval e1 (RandLCtx v)
      | _ => Some (RandRCtx e1, e2)
      end
  | AllocTapeLaplace e1 e2 e3 =>
      match e3 with
      | Val v3 =>
          match e2 with
          | Val v2 => noval e1 (AllocTapeLaplaceNumCtx v2 v3)
          | _ => Some (AllocTapeLaplaceDenCtx e1 v3, e2)
          end
      | _ => Some (AllocTapeLaplaceMeanCtx e1 e2, e3)
      end
  | Laplace e1 e2 e3 e4 =>
      match e4 with
      | Val v4 =>
          match e3 with
          | Val v3 =>
              match e2 with
              | Val v2 => noval e1 (LaplaceNumCtx v2 v3 v4)
              | _ => Some (LaplaceDenCtx e1 v3 v4, e2)
              end
          | _ => Some (LaplaceMeanCtx e1 e2 v4, e3)
          end
      | _ => Some (LaplaceTapeCtx e1 e2 e3, e4)
      end
  | Tick e => noval e TickCtx
  | _ => None
  end.

(** Substitution *)
(** [subst x v e] replaces the free occurrences of variable [x] in [e] with
    [Val v]. It does not descend under a [Rec] that binds [x] (as the function
    name or the argument), and it never traverses embedded values ([Val]). *)
Fixpoint subst (x : string) (v : val) (e : expr) : expr :=
  match e with
  | Val _ => e
  | Var y => if decide (x = y) then Val v else Var y
  | Rec f y e =>
     Rec f y $ if decide (BNamed x <> f /\ BNamed x <> y) then subst x v e else e
  | App e1 e2 => App (subst x v e1) (subst x v e2)
  | UnOp op e => UnOp op (subst x v e)
  | BinOp op e1 e2 => BinOp op (subst x v e1) (subst x v e2)
  | If e0 e1 e2 => If (subst x v e0) (subst x v e1) (subst x v e2)
  | Pair e1 e2 => Pair (subst x v e1) (subst x v e2)
  | Fst e => Fst (subst x v e)
  | Snd e => Snd (subst x v e)
  | InjL e => InjL (subst x v e)
  | InjR e => InjR (subst x v e)
  | Case e0 e1 e2 => Case (subst x v e0) (subst x v e1) (subst x v e2)
  | AllocN e1 e2 => AllocN (subst x v e1) (subst x v e2)
  | Load e => Load (subst x v e)
  | Store e1 e2 => Store (subst x v e1) (subst x v e2)
  | AllocTape e => AllocTape (subst x v e)
  | AllocTapeLaplace e1 e2 e3 => AllocTapeLaplace (subst x v e1) (subst x v e2) (subst x v e3)
  | Rand e1 e2 => Rand (subst x v e1) (subst x v e2)
  | Laplace e1 e2 e3 e4 => Laplace (subst x v e1) (subst x v e2) (subst x v e3) (subst x v e4)
  | Tick e => Tick (subst x v e)
  end.

(** [subst' mx v] substitutes [v] for the binder [mx]. An anonymous binder
    leaves the expression unchanged. *)
Definition subst' (mx : binder) (v : val) : expr -> expr :=
  match mx with BNamed x => subst x v | BAnon => λ x, x end.

(** The stepping relation *)
(** Unary operators: [NegOp] is boolean negation on booleans and bitwise
    complement ([Z.lnot]) on integers, and [MinusUnOp] negates integers. Any
    other combination is undefined ([None]). *)
Definition un_op_eval (op : un_op) (v : val) : option val :=
  match op, v with
  | NegOp, LitV (LitBool b) => Some $ LitV $ LitBool (negb b)
  | NegOp, LitV (LitInt z) => Some $ LitV $ LitInt (Z.lnot z)
  | MinusUnOp, LitV (LitInt z) => Some $ LitV $ LitInt (- z)
  | _, _ => None
  end.

(** Binary operators on two integers. Comparisons yield booleans; all other
    operators, including [OffsetOp], yield an integer. Division and remainder
    use [quot]/[rem], and shifts use [Z.shiftl]/[Z.shiftr]. *)
Definition bin_op_eval_int (op : bin_op) (n1 n2 : Z) : base_lit :=
  match op with
  | PlusOp => LitInt (n1 + n2)
  | MinusOp => LitInt (n1 - n2)
  | MultOp => LitInt (n1 * n2)
  | QuotOp => LitInt (n1 `quot` n2)
  | RemOp => LitInt (n1 `rem` n2)
  | AndOp => LitInt (Z.land n1 n2)
  | OrOp => LitInt (Z.lor n1 n2)
  | XorOp => LitInt (Z.lxor n1 n2)
  | ShiftLOp => LitInt (Z.shiftl n1 n2)
  | ShiftROp => LitInt (Z.shiftr n1 n2)
  | LeOp => LitBool (bool_decide (n1 <= n2))
  | LtOp => LitBool (bool_decide (n1 < n2))
  | EqOp => LitBool (bool_decide (n1 = n2))
  | OffsetOp => LitInt (n1 + n2) (* Treat offsets as ints *)
  end%Z.

(** Binary operators on two booleans: the bitwise operators act as logical
    connectives, and [EqOp] compares. All other operators are undefined. *)
Definition bin_op_eval_bool (op : bin_op) (b1 b2 : bool) : option base_lit :=
  match op with
  | PlusOp | MinusOp | MultOp | QuotOp | RemOp => None (* Arithmetic *)
  | AndOp => Some (LitBool (b1 && b2))
  | OrOp => Some (LitBool (b1 || b2))
  | XorOp => Some (LitBool (xorb b1 b2))
  | ShiftLOp | ShiftROp => None (* Shifts *)
  | LeOp | LtOp => None (* Inequality *)
  | EqOp => Some (LitBool (bool_decide (b1 = b2)))
  | OffsetOp => None
  end.

(** Binary operators with a location on the left: offset by an integer, and
    ordering comparisons between two locations. *)
Definition bin_op_eval_loc (op : bin_op) (l1 : loc) (v2 : base_lit) : option base_lit :=
  match op, v2 with
  | OffsetOp, LitInt off => Some $ LitLoc (l1 +ₗ off)
  | LeOp, LitLoc l2 => Some $ LitBool (bool_decide (l1 ≤ₗ l2))
  | LtOp, LitLoc l2 => Some $ LitBool (bool_decide (l1 <ₗ l2))
  | _, _ => None
  end.

(** [bin_op_eval op v1 v2]: [EqOp] compares arbitrary values, but only when
    [vals_compare_safe v1 v2] holds. Other operators dispatch on the literal
    types of the operands. [None] means the operation is undefined. *)
Definition bin_op_eval (op : bin_op) (v1 v2 : val) : option val :=
  if decide (op = EqOp) then
    if decide (vals_compare_safe v1 v2) then
      Some $ LitV $ LitBool $ bool_decide (v1 = v2)
    else
      None
  else
    match v1, v2 with
    | LitV (LitInt n1), LitV (LitInt n2) => Some $ LitV $ bin_op_eval_int op n1 n2
    | LitV (LitBool b1), LitV (LitBool b2) => LitV <$> bin_op_eval_bool op b1 b2
    | LitV (LitLoc l1), LitV v2 => LitV <$> bin_op_eval_loc op l1 v2
    | _, _ => None
    end.

(** Apply [f] to the heap, leaving both kinds of tapes unchanged. *)
Definition state_upd_heap (f : gmap loc val -> gmap loc val) (σ : state) : state :=
  {| heap := f σ.(heap); tapes := σ.(tapes); tapes_laplace := σ.(tapes_laplace) |}.
Global Arguments state_upd_heap _ !_ /.

(** Apply [f] to the [Rand] tapes, leaving the heap and Laplace tapes unchanged. *)
Definition state_upd_tapes (f : gmap loc tape -> gmap loc tape) (σ : state) : state :=
  {| heap := σ.(heap); tapes := f σ.(tapes); tapes_laplace := σ.(tapes_laplace) |}.
Global Arguments state_upd_tapes _ !_ /.

(** Apply [f] to the Laplace tapes, leaving the heap and [Rand] tapes unchanged. *)
Definition state_upd_tapes_laplace (f : gmap loc tape_laplace -> gmap loc tape_laplace) (σ : state) : state :=
  {| heap := σ.(heap); tapes := σ.(tapes); tapes_laplace := f σ.(tapes_laplace) |}.
Global Arguments state_upd_tapes_laplace _ !_ /.

(** Two successive writes of tape [l] collapse into the last one. *)
Lemma state_upd_tapes_twice σ l n xs ys :
  state_upd_tapes <[l:=(n; ys)]> (state_upd_tapes <[l:=(n; xs)]> σ) = state_upd_tapes <[l:=(n; ys)]> σ.
Proof. rewrite /state_upd_tapes /=. f_equal. apply insert_insert_eq. Qed.

(** If writing tape [l] with the same bound yields equal states, then the
    written contents are equal. *)
Lemma state_upd_tapes_same σ σ' l n xs ys :
  state_upd_tapes <[l:=(n; ys)]> σ = state_upd_tapes <[l:=(n; xs)]> σ' -> xs = ys.
Proof. rewrite /state_upd_tapes /=. intros K. simplify_eq.
       (* NOTE(review): this relies on [simplify_eq] naming the equation on the
          [tapes] field [H]; an explicit name would make the proof less fragile. *)
       rewrite map_eq_iff in H.
       specialize (H l).
       rewrite !lookup_insert_eq in H.
       by simplify_eq.
Qed.

(** Re-inserting the current contents of tape [l] leaves the state unchanged. *)
Lemma state_upd_tapes_no_change σ l n ys :
  tapes σ !! l = Some (n; ys)->
  state_upd_tapes <[l:=(n; ys)]> σ = σ .
Proof.
  destruct σ as [? t]. simpl.
  intros Ht.
  f_equal.
  apply insert_id. done.
Qed.

(** Variant of [state_upd_tapes_same]: equal results after appending [x]
    resp. [y] to the same prefix imply [x = y]. *)
Lemma state_upd_tapes_same' σ σ' l n xs (x y : fin (S n)) :
  state_upd_tapes <[l:=(n; xs++[x])]> σ = state_upd_tapes <[l:=(n; xs++[y])]> σ' -> x = y.
Proof. intros H. apply state_upd_tapes_same in H.
       by simplify_eq.
Qed.

(** Appending different last elements to the same tape prefix yields different
    states. *)
Lemma state_upd_tapes_neq' σ σ' l n xs (x y : fin (S n)) :
  x <> y -> state_upd_tapes <[l:=(n; xs++[x])]> σ <> state_upd_tapes <[l:=(n; xs++[y])]> σ'.
Proof. move => H /state_upd_tapes_same ?. simplify_eq.
Qed.

(** [heap_array l vs] is the heap fragment that stores the elements of [vs] at
    the consecutive locations [l], [l +ₗ 1], and so on. *)
Fixpoint heap_array (l : loc) (vs : list val) : gmap loc val :=
  match vs with
  | [] => empty
  | v :: vs' => union {[l := v]} (heap_array (l +ₗ 1) vs')
  end.

(** A one-element array is a singleton heap. *)
Lemma heap_array_singleton l v : heap_array l [v] = {[l := v]}.
Proof. by rewrite /heap_array right_id. Qed.

(** The array for [vs1 ++ vs2] is the union of the array for [vs1] at [l] and
    the array for [vs2] starting at offset [length vs1]. *)
Lemma heap_array_app l vs1 vs2 : heap_array l (vs1 ++ vs2) = union (heap_array l vs1) (heap_array (l +ₗ (length vs1)) vs2) .
Proof.
  revert l.
  induction vs1; intro l.
  - simpl.
    rewrite map_empty_union loc_add_0 //.
  - rewrite -app_comm_cons /= IHvs1.
    rewrite map_union_assoc.
    do 2 f_equiv.
    rewrite Nat2Z.inj_succ /=.
    rewrite /Z.succ
      Z.add_comm
      loc_add_assoc //.
Qed.

(** Lookups in [heap_array l vs]: [k] maps to [v] iff [k] lies at a
    non-negative offset [j] from [l] and [vs] holds [v] at index [j]. *)
Lemma heap_array_lookup l vs v k :
  heap_array l vs !! k = Some v <->
  exists j, (0 <= j)%Z /\ k = l +ₗ j /\ vs !! (Z.to_nat j) = Some v.
Proof.
  revert k l; induction vs as [|v' vs IH]=> l' l /=.
  { rewrite lookup_empty. naive_solver lia. }
  rewrite -insert_union_singleton_l lookup_insert_Some IH. split.
  - intros [[-> ?] | (Hl & j & ? & -> & ?)].
    { eexists 0. rewrite loc_add_0. naive_solver lia. }
    eexists (1 + j)%Z. rewrite loc_add_assoc !Z.add_1_l Z2Nat.inj_succ; auto with lia.
  - intros (j & ? & -> & Hil). destruct (decide (j = 0)); simplify_eq/=.
    { rewrite loc_add_0; eauto. }
    right. split.
    { rewrite -{1}(loc_add_0 l). intros ?%(inj (loc_add _)); lia. }
    assert (Z.to_nat j = S (Z.to_nat (j - 1))) as Hj.
    { rewrite -Z2Nat.inj_succ; last lia. f_equal; lia. }
    rewrite Hj /= in Hil.
    eexists (j - 1)%Z. rewrite loc_add_assoc Z.add_sub_assoc Z.add_simpl_l.
    auto with lia.
Qed.

(** If [h] has no entry at any offset [i] with [0 <= i < length vs] from [l],
    then the array fragment is disjoint from [h]. *)
Lemma heap_array_map_disjoint (h : gmap loc val) (l : loc) (vs : list val) :
  (forall i, (0 <= i)%Z -> (i < length vs)%Z -> h !! (l +ₗ i) = None) ->
  (heap_array l vs) ##ₘ h.
Proof.
  intros Hdisj. apply map_disjoint_spec=> l' v1 v2.
  intros (j&?&->&Hj%lookup_lt_Some%inj_lt)%heap_array_lookup.
  move: Hj. rewrite Z2Nat.id // => ?. by rewrite Hdisj.
Qed.

(** Write [n] copies of [v] starting at [l]. The new cells are the left operand
    of the union, so they take precedence over existing heap entries
    (cf. [state_upd_heap_singleton]). *)
Definition state_upd_heap_N (l : loc) (n : nat) (v : val) (σ : state) : state :=
  state_upd_heap (λ h, union (heap_array l (replicate n v)) h) σ.

(** Allocating a single cell amounts to a single heap insertion. *)
Lemma state_upd_heap_singleton l v σ :
  state_upd_heap_N l 1 v σ = state_upd_heap <[l:= v]> σ.
Proof.
  destruct σ as [h p]. rewrite /state_upd_heap_N /=. f_equiv.
  rewrite right_id insert_union_singleton_l. done.
Qed.

(** Tape updates and heap allocation touch disjoint fields, so they commute;
    both sides reduce to the same record. *)
Lemma state_upd_tapes_heap σ l1 l2 n xs m v :
  state_upd_tapes <[l2:=(n; xs)]> (state_upd_heap_N l1 m v σ) =
  state_upd_heap_N l1 m v (state_upd_tapes <[l2:=(n; xs)]> σ).
Proof. reflexivity. Qed.

(** An array of [S n] copies is the array of [n] copies plus one extra cell at
    offset [n]. *)
Lemma heap_array_replicate_S_end l v n :
  heap_array l (replicate (S n) v) = union (heap_array l (replicate n v)) {[l +ₗ n:= v]}.
Proof.
  induction n.
  - simpl.
    rewrite map_union_empty.
    rewrite map_empty_union.
    by rewrite loc_add_0.
  - rewrite replicate_S_end
     heap_array_app
     IHn /=.
    rewrite map_union_empty length_replicate //.
Qed.

#[local] Open Scope R.

(** [head_step e1 σ1] is the distribution over configurations reachable from
    [(e1, σ1)] by one head-reduction step. Expressions with no applicable rule
    get [dzero]: ill-typed operations, undefined operators, non-positive
    allocation sizes, and accesses to unallocated locations or tapes. *)
Definition head_step (e1 : expr) (σ1 : state) : distr (expr * state) :=
  match e1 with
  | Rec f x e =>
      dret (Val $ RecV f x e, σ1)
  | Pair (Val v1) (Val v2) =>
      dret (Val $ PairV v1 v2, σ1)
  | InjL (Val v) =>
      dret (Val $ InjLV v, σ1)
  | InjR (Val v) =>
      dret (Val $ InjRV v, σ1)
  | App (Val (RecV f x e1)) (Val v2) =>
      dret (subst' x v2 (subst' f (RecV f x e1) e1) , σ1)
  | UnOp op (Val v) =>
      match un_op_eval op v with
        | Some w => dret (Val w, σ1)
        | _ => dzero
      end
  | BinOp op (Val v1) (Val v2) =>
      match bin_op_eval op v1 v2 with
        | Some w => dret (Val w, σ1)
        | _ => dzero
      end
  | If (Val (LitV (LitBool true))) e1 e2 =>
      dret (e1 , σ1)
  | If (Val (LitV (LitBool false))) e1 e2 =>
      dret (e2 , σ1)
  | Fst (Val (PairV v1 v2)) =>
      dret (Val v1, σ1)
  | Snd (Val (PairV v1 v2)) =>
      dret (Val v2, σ1)
  | Case (Val (InjLV v)) e1 e2 =>
      dret (App e1 (Val v), σ1)
  | Case (Val (InjRV v)) e1 e2 =>
      dret (App e2 (Val v), σ1)
  | AllocN (Val (LitV (LitInt N))) (Val v) =>
      (* [l] is the first of the freshly allocated locations *)
      let l := fresh_loc σ1.(heap) in
      if bool_decide (0 < Z.to_nat N)%nat
        then dret (Val $ LitV $ LitLoc l, state_upd_heap_N l (Z.to_nat N) v σ1)
        else dzero
  | Load (Val (LitV (LitLoc l))) =>
      match σ1.(heap) !! l with
        | Some v => dret (Val v, σ1)
        | None => dzero
      end
  | Store (Val (LitV (LitLoc l))) (Val w) =>
      match σ1.(heap) !! l with
        | Some v => dret (Val $ LitV LitUnit, state_upd_heap <[l:=w]> σ1)
        | None => dzero
      end
  (* Since our language only has integers, we use Z.to_nat, which maps positive
     integers to the corresponding nat, and the rest to 0. We sample from
     dunifP N = dunif (1 + N) to avoid the case dunif 0 = dzero. *)

  (* Uniform sampling from 0, 1 , ..., N *)
  | Rand (Val (LitV (LitInt N))) (Val (LitV LitUnit)) =>
      dmap (λ n : fin _, (Val $ LitV $ LitInt n, σ1)) (dunifP (Z.to_nat N))
  | AllocTape (Val (LitV (LitInt z))) =>
      let ι := fresh_loc σ1.(tapes) in
      dret (Val $ LitV $ LitLbl ι, state_upd_tapes <[ι := (Z.to_nat z; []) ]> σ1)
  | AllocTapeLaplace (Val (LitV (LitInt num))) (Val (LitV (LitInt den))) (Val (LitV (LitInt mean))) =>
      let ι := fresh_loc σ1.(tapes_laplace) in
      dret (Val $ LitV $ LitLbl ι, state_upd_tapes_laplace <[ι := Tape_Laplace num den mean [] ]> σ1)
  (* Labelled sampling, conditional on tape contents *)
  | Rand (Val (LitV (LitInt N))) (Val (LitV (LitLbl l))) =>
      match σ1.(tapes) !! l with
      | Some (M; ns) =>
          if bool_decide (M = Z.to_nat N) then
            match ns with
            | n :: ns =>
                (* the tape is non-empty so we consume the first number *)
                dret (Val $ LitV $ LitInt $ fin_to_nat n, state_upd_tapes <[l:=(M; ns)]> σ1)
            | [] =>
                (* the tape is allocated but empty, so we sample from 0, 1, ..., M uniformly *)
                dmap (λ n : fin _, (Val $ LitV $ LitInt n, σ1)) (dunifP M)
            end
          else
            (* bound did not match the bound of the tape *)
            dmap (λ n : fin _, (Val $ LitV $ LitInt n, σ1)) (dunifP (Z.to_nat N))
      | None => dzero
      end
  (* Laplace sampling without a tape *)
  | Laplace (Val (LitV (LitInt num))) (Val (LitV (LitInt den))) (Val (LitV (LitInt mean))) (Val (LitV LitUnit)) =>
      dmap (λ z : Z, (Val $ LitV $ LitInt z, σ1)) (laplace_rat num den mean)
  (* Laplace sampling from a tape, consuming its first element when the tape
     parameters match the arguments *)
  | Laplace (Val (LitV (LitInt num))) (Val (LitV (LitInt den))) (Val (LitV (LitInt mean))) (Val (LitV (LitLbl l))) =>
      match σ1.(tapes_laplace) !! l with
      | Some (Tape_Laplace num' den' mean' xs) =>
          if (bool_decide ((num = num') /\ (den = den') /\ (mean = mean')))%Z then
            match xs with
            | x :: xs => dret (Val $ LitV $ LitInt $ x, state_upd_tapes_laplace <[l:=(Tape_Laplace num' den' mean' xs)]> σ1)
            | [] => dmap (λ z : Z, (Val $ LitV $ LitInt z, σ1)) (laplace_rat num den mean)
            end
          else
            (* tape and laplace parameters mismatch; follow laplace args *)
            dmap (λ z : Z, (Val $ LitV $ LitInt z, σ1)) (laplace_rat num den mean)
      | None => dzero
      end

  | Tick (Val (LitV (LitInt n))) => dret (Val $ LitV $ LitUnit, σ1)
  | _ => dzero
  end.

(** [state_step σ1 α] appends one fresh sample to the tape [α] of [σ1].
    If [α] is allocated in [σ1.(tapes)] with bound [N] and contents [ns], the
    result is the distribution of states whose tape [α] is extended with a
    value [n] drawn uniformly by [dunifP N] (i.e. from [0, 1, …, N]).
    If [α] is not an allocated tape, the result is [dzero]. *)
Definition state_step (σ1 : state) (α : loc) : distr state :=
  if bool_decide (α ∈ dom σ1.(tapes)) then
    let: (N; ns) := (σ1.(tapes) !!! α) in
    dmap (λ n, state_upd_tapes (<[α := (N; ns ++ [n])]>) σ1) (dunifP N)
  else dzero.

(** Unfolding lemma for [state_step] on an allocated tape [α] with bound [N]
    and contents [ns]: the [bool_decide] guard and the total lookup [!!!] are
    resolved, leaving the [dmap] that appends a sample from [dunifP N]. *)
Lemma state_step_unfold σ α N ns:
  tapes σ !! α = Some (N; ns) ->
  state_step σ α = dmap (λ n, state_upd_tapes (<[α := (N; ns ++ [n])]>) σ) (dunifP N).
Proof.
  intros H.
  rewrite /state_step.
  (* The guard holds because [α] is mapped in [tapes σ] (hypothesis [H]). *)
  rewrite bool_decide_eq_true_2; last first.
  { by apply elem_of_dom. }
  (* Replace the total lookup by the known tape contents [(N; ns)]. *)
  by rewrite (lookup_total_correct (tapes σ) α (N; ns)); last done.
Qed.

(** [state_step_laplace σ1 α] is the Laplace analogue of [state_step]. If [α]
    is an allocated Laplace tape [Tape_Laplace num den mean zs] in
    [σ1.(tapes_laplace)], it appends one sample [z] drawn from
    [laplace_rat num den mean] to that tape; otherwise the result is
    [dzero]. *)
Definition state_step_laplace (σ1 : state) (α : loc) : distr state :=
  if bool_decide (α ∈ dom σ1.(tapes_laplace)) then
    let '(Tape_Laplace num den mean zs) := (σ1.(tapes_laplace) !!! α) in
    dmap (λ z, state_upd_tapes_laplace (<[α := (Tape_Laplace num den mean (zs ++ [z]))]>) σ1) (laplace_rat num den mean)
  else dzero.

(** Unfolding lemma for [state_step_laplace] on an allocated Laplace tape [α]
    with parameters [num], [den], [mean] and contents [zs]: the guard and the
    total lookup are resolved, leaving the [dmap] that appends a sample from
    [laplace_rat num den mean]. *)
Lemma state_step_laplace_unfold σ α num den mean zs:
  tapes_laplace σ !! α = Some (Tape_Laplace num den mean zs) ->
  state_step_laplace σ α = dmap (λ z, state_upd_tapes_laplace (<[α := (Tape_Laplace num den mean (zs ++ [z]))]>) σ) (laplace_rat num den mean).
Proof.
  intros H.
  rewrite /state_step_laplace.
  (* The guard holds because [α] is mapped in [tapes_laplace σ] (hypothesis [H]). *)
  rewrite bool_decide_eq_true_2; last first.
  { by apply elem_of_dom. }
  (* Replace the total lookup by the known tape contents. *)
  by rewrite (lookup_total_correct (tapes_laplace σ) α (Tape_Laplace num den mean zs)); last done.
Qed.

(** * Basic properties about the language *)
(** [fill_item Ki] is injective: filling the same context item with two
    different expressions yields two different expressions. *)
Global Instance fill_item_inj Ki : Inj (=) (=) (fill_item Ki).
Proof. induction Ki; intros ???; simplify_eq/=; auto with f_equal. Qed.

(** If filling the context item [Ki] with [e] yields a value, then [e] itself
    is a value. *)
Lemma fill_item_val Ki e :
  is_Some (to_val (fill_item Ki e)) → is_Some (to_val e).
Proof. intros [v ?]. induction Ki; simplify_option_eq; eauto. Qed.

(** Values are head-stuck: an expression that takes a head step with positive
    probability is not a value. *)
Lemma val_head_stuck e σ ρ :
  head_step e σ ρ > 0 → to_val e = None.
Proof. destruct ρ, e; [|done..]. rewrite /pmf /=. lra. Qed.
(** If [fill_item Ki e] takes a head step with positive probability, then the
    expression [e] in the hole is a value. *)
Lemma head_ctx_step_val Ki e σ ρ :
  head_step (fill_item Ki e) σ ρ > 0 → is_Some (to_val e).
Proof.
  destruct ρ, Ki ;
    rewrite /pmf/= ;
    repeat case_match; clear -H ; inversion H; intros ; (lra || done).
Qed.

(** A relational characterization of the support of [head_step], which makes
    inversion easier and simplifies proofs of reducibility; cf.
    [head_step_support_equiv_rel] below. *)
Inductive head_step_rel : expr → state → expr → state → Prop :=
(* Deterministic, state-preserving reductions *)
| RecS f x e σ :
  head_step_rel (Rec f x e) σ (Val $ RecV f x e) σ
| PairS v1 v2 σ :
  head_step_rel (Pair (Val v1) (Val v2)) σ (Val $ PairV v1 v2) σ
| InjLS v σ :
  head_step_rel (InjL $ Val v) σ (Val $ InjLV v) σ
| InjRS v σ :
  head_step_rel (InjR $ Val v) σ (Val $ InjRV v) σ
| BetaS f x e1 v2 e' σ :
  e' = subst' x v2 (subst' f (RecV f x e1) e1) →
  head_step_rel (App (Val $ RecV f x e1) (Val v2)) σ e' σ
| UnOpS op v v' σ :
  un_op_eval op v = Some v' →
  head_step_rel (UnOp op (Val v)) σ (Val v') σ
| BinOpS op v1 v2 v' σ :
  bin_op_eval op v1 v2 = Some v' →
  head_step_rel (BinOp op (Val v1) (Val v2)) σ (Val v') σ
| IfTrueS e1 e2 σ :
  head_step_rel (If (Val $ LitV $ LitBool true) e1 e2) σ e1 σ
| IfFalseS e1 e2 σ :
  head_step_rel (If (Val $ LitV $ LitBool false) e1 e2) σ e2 σ
| FstS v1 v2 σ :
  head_step_rel (Fst (Val $ PairV v1 v2)) σ (Val v1) σ
| SndS v1 v2 σ :
  head_step_rel (Snd (Val $ PairV v1 v2)) σ (Val v2) σ
| CaseLS v e1 e2 σ :
  head_step_rel (Case (Val $ InjLV v) e1 e2) σ (App e1 (Val v)) σ
| CaseRS v e1 e2 σ :
  head_step_rel (Case (Val $ InjRV v) e1 e2) σ (App e2 (Val v)) σ
(* Heap operations *)
| AllocNS z N v σ l :
  l = fresh_loc σ.(heap) →
  N = Z.to_nat z →
  (0 < N)%nat ->
  head_step_rel (AllocN (Val (LitV (LitInt z))) (Val v)) σ
    (Val $ LitV $ LitLoc l) (state_upd_heap_N l N v σ)
| LoadS l v σ :
  σ.(heap) !! l = Some v →
  head_step_rel (Load (Val $ LitV $ LitLoc l)) σ (of_val v) σ
| StoreS l v w σ :
  σ.(heap) !! l = Some v →
  head_step_rel (Store (Val $ LitV $ LitLoc l) (Val w)) σ
    (Val $ LitV LitUnit) (state_upd_heap <[l:=w]> σ)
(* Uniform sampling and tape allocation; a sampled [n : fin (S N)] ranges
   over [0, …, N]. *)
| RandNoTapeS z N (n : fin (S N)) σ :
  N = Z.to_nat z →
  head_step_rel (Rand (Val $ LitV $ LitInt z) (Val $ LitV LitUnit)) σ (Val $ LitV $ LitInt n) σ
| AllocTapeS z N σ l :
  l = fresh_loc σ.(tapes) →
  N = Z.to_nat z →
  head_step_rel (AllocTape (Val (LitV (LitInt z)))) σ
    (Val $ LitV $ LitLbl l) (state_upd_tapes <[l := (N; []) : tape]> σ)
| AllocTapeLaplaceS num den mean σ l :
  l = fresh_loc σ.(tapes_laplace) →
  head_step_rel (AllocTapeLaplace (Val (LitV (LitInt num))) (Val (LitV (LitInt den))) (Val (LitV (LitInt mean)))) σ
    (Val $ LitV $ LitLbl l) (state_upd_tapes_laplace <[l := Tape_Laplace num den mean []]> σ)

(* Sampling from a non-empty tape with a matching bound consumes its head. *)
| RandTapeS l z N n ns σ :
  N = Z.to_nat z →
  σ.(tapes) !! l = Some ((N; n :: ns) : tape) →
  head_step_rel (Rand (Val (LitV (LitInt z))) (Val (LitV (LitLbl l)))) σ
    (Val $ LitV $ LitInt $ n) (state_upd_tapes <[l := (N; ns) : tape]> σ)
(* An empty tape with a matching bound: sample fresh, leave the state alone. *)
| RandTapeEmptyS l z N (n : fin (S N)) σ :
  N = Z.to_nat z →
  σ.(tapes) !! l = Some ((N; []) : tape) →
  head_step_rel (Rand (Val (LitV (LitInt z))) (Val $ LitV $ LitLbl l)) σ (Val $ LitV $ LitInt n) σ
(* A tape whose bound [M] differs from the requested bound [N]: sample fresh
   with bound [N], leave the state alone. *)
| RandTapeOtherS l z M N ms (n : fin (S N)) σ :
  N = Z.to_nat z →
  σ.(tapes) !! l = Some ((M; ms) : tape) →
  N ≠ M →
  head_step_rel (Rand (Val (LitV (LitInt z))) (Val $ LitV $ LitLbl l)) σ (Val $ LitV $ LitInt n) σ

(* Laplace sampling. The premise [(0 < IZR num / IZR den) ∨ mean = z] admits
   any result [z] when the scale [num / den] is positive, and only [z = mean]
   otherwise. *)
| LaplaceNoTapeS num den mean σ z :
  ((0 < IZR num / IZR den) ∨ mean = z) →
  head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val $ LitV $ LitUnit)) σ (Val $ LitV $ LitInt z) σ

(* Earlier variant, kept for reference; its case is now covered by the
   [mean = z] disjunct above.
   | LaplaceNoTapeS0 num den mean σ :
     (not (0 < IZR num / IZR den)) →
     head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val $ LitV $ LitUnit)) σ (Val $ LitV $ LitInt mean) σ *)

(* A non-empty Laplace tape with matching parameters: consume its head. *)
| LaplaceTapeConsS num den mean lbl x xs σ :
  σ.(tapes_laplace) !! lbl = Some (Tape_Laplace num den mean (x :: xs)) →
  head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val (LitV (LitLbl lbl)))) σ
    (Val $ LitV $ LitInt $ x) (state_upd_tapes_laplace <[lbl := (Tape_Laplace num den mean xs)]> σ)

(* An empty Laplace tape with matching parameters: sample fresh. *)
| LaplaceTapeEmptyS num den mean lbl σ z :
  ((0 < IZR num / IZR den) ∨ mean = z) →
  σ.(tapes_laplace) !! lbl = Some (Tape_Laplace num den mean []) →
  head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val (LitV (LitLbl lbl)))) σ
    (Val $ LitV $ LitInt $ z) σ

(* Earlier variant, kept for reference; its case is now covered by the
   [mean = z] disjunct above.
   | LaplaceTapeEmptyS0 num den mean lbl σ :
     (not (0 < IZR num / IZR den)) →
     σ.(tapes_laplace) !! lbl = Some (Tape_Laplace num den mean []) →
     head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val (LitV (LitLbl lbl)))) σ
       (Val $ LitV $ LitInt $ mean) σ *)

(* A Laplace tape whose parameters differ from the requested ones: sample
   fresh using the requested parameters, leave the state alone. *)
| LaplaceTapeOtherS num den mean lbl num' den' mean' xs σ z :
  σ.(tapes_laplace) !! lbl = Some (Tape_Laplace num' den' mean' xs) →
  (not ((num = num') ∧ (den = den') ∧ (mean = mean'))) →
  ((0 < IZR num / IZR den) ∨ mean = z) →
  head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val (LitV (LitLbl lbl)))) σ
    (Val $ LitV $ LitInt $ z) σ

(* Earlier variant, kept for reference; its case is now covered by the
   [mean = z] disjunct above.
   | LaplaceTapeOtherS0 num den mean lbl num' den' mean' xs σ :
     σ.(tapes_laplace) !! lbl = Some (Tape_Laplace num' den' mean' xs) →
     (not ((num = num') ∧ (den = den') ∧ (mean = mean'))) →
     (not (0 < IZR num / IZR den)) →
     head_step_rel (Laplace (Val $ LitV $ LitInt num) (Val $ LitV $ LitInt den) (Val $ LitV $ LitInt mean) (Val (LitV (LitLbl lbl)))) σ
       (Val $ LitV $ LitInt $ mean) σ *)

(* Cost-only no-op. *)
| TickS σ z :
  head_step_rel (Tick $ Val $ LitV $ LitInt z) σ (Val $ LitV $ LitUnit) σ.

Create HintDb head_step.
Global Hint Constructors head_step_rel : head_step.
(* The [Rand] constructors of [head_step_rel] are parametric in the sampled
   value [n : fin (S N)]. These extern hints apply them with [n := 0%fin],
   which always inhabits [fin (S N)], so [eauto with head_step] can build a
   [Rand] step without choosing a sample itself. *)
Global Hint Extern 1
  (head_step_rel (Rand (Val (LitV _)) (Val (LitV LitUnit))) _ _ _) =>
         eapply (RandNoTapeS _ _ 0%fin) : head_step.
Global Hint Extern 1
  (head_step_rel (Rand (Val (LitV _)) (Val (LitV (LitLbl _)))) _ _ _) =>
         eapply (RandTapeEmptyS _ _ _ 0%fin) : head_step.
Global Hint Extern 1
  (head_step_rel (Rand (Val (LitV _)) (Val (LitV (LitLbl _)))) _ _ _) =>
         eapply (RandTapeOtherS _ _ _ _ _ 0%fin) : head_step.

(* Global Hint Extern 1
     (head_step_rel (Laplace (Val (LitV _)) (Val (LitV _)) (Val (LitV (LitInt ?mean))) (Val (LitV LitUnit))) _ _ _) =>
            eapply (LaplaceNoTapeS _ _ _ _ _ mean) : head_step. *)


(** Relational characterization of the support of [state_step]; cf.
    [state_step_support_equiv_rel] below. A state step on an allocated tape
    [α] with bound [N] and contents [ns] appends some [n : fin (S N)]. *)
Inductive state_step_rel : state → loc → state → Prop :=
| AddTapeS α N (n : fin (S N)) ns σ :
  α ∈ dom σ.(tapes) →
  σ.(tapes) !!! α = ((N; ns) : tape) →
  state_step_rel σ α (state_upd_tapes <[α := (N; ns ++ [n]) : tape]> σ).

(* [inv_head_step] repeatedly simplifies hypotheses produced by unfolding
   [head_step]:
   - a [bool_decide] condition is rewritten away when [done] can discharge it
     (in either polarity), and is otherwise case-split on;
   - otherwise it makes progress via [simplify_map_eq], [inv_distr] and
     [case_match];
   - [to_val _ = Some _] is turned into an equation via [of_to_val];
   - [is_Some] lookup hypotheses are destructed. *)
Ltac inv_head_step :=
  repeat
    match goal with
    | H : context [@bool_decide ?P ?dec] |- _ =>
        try (rewrite bool_decide_eq_true_2 in H; [|done]);
        try (rewrite bool_decide_eq_false_2 in H; [|done]);
        destruct_decide (@bool_decide_reflect P dec); simplify_eq
    | _ => progress simplify_map_eq; simpl in *; inv_distr; repeat case_match; inv_distr
    | H : to_val _ = Some _ |- _ => apply of_to_val in H
    | H : is_Some (_ !! _) |- _ => destruct H
    end.

(** The positive-probability support of [head_step] is exactly
    [head_step_rel]. *)
Lemma head_step_support_equiv_rel e1 e2 σ1 σ2 :
  head_step e1 σ1 (e2, σ2) > 0 ↔ head_step_rel e1 σ1 e2 σ2.
Proof.
  split.
  - intros ?. destruct e1; inv_head_step ; try by eauto with head_step.
    + econstructor ; inv_distr ; intuition simplify_eq ; eauto.
    + econstructor ; inv_distr ; intuition simplify_eq ; eauto.
    + intuition simplify_eq. econstructor. done.
  - inversion 1; simplify_map_eq/= ; try case_bool_decide ; try case_decide ; simplify_eq; solve_distr; try done.
    all: intuition auto.
Qed.

(** The positive-probability support of [state_step] is exactly
    [state_step_rel]. *)
Lemma state_step_support_equiv_rel σ1 α σ2 :
  state_step σ1 α σ2 > 0 ↔ state_step_rel σ1 α σ2.
Proof.
  rewrite /state_step. split.
  - case_bool_decide; [|intros; inv_distr].
    case_match. intros ?. inv_distr.
    econstructor; eauto with lia.
  - inversion_clear 1.
    rewrite bool_decide_eq_true_2 // H1. solve_distr.
Qed.

(* [solve_distr] proves goals of the form [d.(pmf) x > 0], and [dret] mass
   goals [(dret x).(pmf) x = 1]. It repeatedly decomposes the distribution at
   the head of the goal: [dret], [dbind], [dmap], [dunifP], [dunifv] and
   [d_proj_Some].
   NOTE(review): [solve_distr] is already used earlier in this file (in
   [head_step_support_equiv_rel] and [state_step_support_equiv_rel]). This
   definition therefore presumably shadows a tactic of the same name defined
   elsewhere, and only affects later uses — confirm this is intentional. *)
Ltac solve_distr :=
  repeat
    match goal with
    | |- (dret _).(pmf) _ > 0 => rewrite dret_1_1 //; lra
    | |- (dret ?x).(pmf) ?x = 1 => by apply dret_1_1
    | |- (dbind _ _).(pmf) _ > 0 => apply dbind_pos; eexists; split
    | |- (dmap _ _).(pmf) _ > 0 =>
        apply dmap_pos; eexists; (split; [done|]); try done
    | |- (dunifP _).(pmf) _ > 0 => apply dunifP_pos
    | |- (dunifv _ _).(pmf) _ > 0 => apply dunifv_pos
    | |- (d_proj_Some _).(pmf) _ > 0 => rewrite d_proj_Some_pos
    end.

(** A state step on tape [α] (from [σ] to [σ'] with positive probability)
    does not change whether [e] can take a head step: [e] is head-reducible
    in [σ] iff it is head-reducible in [σ']. *)
Lemma state_step_head_step_not_stuck e σ σ' α :
  state_step σ α σ' > 0 → (∃ ρ, head_step e σ ρ > 0) ↔ (∃ ρ', head_step e σ' ρ' > 0).
Proof.
  rewrite state_step_support_equiv_rel.
  inversion_clear 1.
  split; intros [[e2 σ2] Hs].
  (* TODO: the sub goals used to be solved by simplify_map_eq  *)
  - destruct e; inv_head_step; ( (eexists; solve_distr)) ; try done.
    all: try rewrite dzero_0.
    4-9: eapply laplace_rat_pos ; eauto.
    + exfalso.
      destruct (decide (α = l1)) ; simplify_eq.
      * eapply (Some_ne_None ((_; _) : tape)).
        rewrite -H11. by simplify_map_eq.
      * eapply (Some_ne_None ((_; _) : tape)).
        rewrite -H11.
        rewrite lookup_insert_ne => //.
    + exfalso.
      destruct (decide (α = l1)) ; simplify_eq.
      * eapply (Some_ne_None ((_; _) : tape)).
        rewrite -H11. simplify_map_eq.
      * eapply (Some_ne_None ((_; _) : tape)).
        rewrite -H11.
        rewrite lookup_insert_ne => //.
    + exfalso.
      destruct (decide (α = l1)) ; simplify_eq.
      * eapply (Some_ne_None ((_; _) : tape)).
        rewrite -H10. simplify_map_eq.
      * eapply (Some_ne_None ((_; _) : tape)).
        rewrite -H10.
        rewrite lookup_insert_ne => //.
  - destruct e; inv_head_step; try ((eexists; solve_distr)) ; try done.
    all: try rewrite dzero_0.
    4-9: eapply laplace_rat_pos ; eauto.

    + destruct (decide (α = l1)); simplify_eq.
      * apply not_elem_of_dom_2 in H11. done.
      * rewrite lookup_insert_ne // in H7. rewrite H11 in H7. done.
    + destruct (decide (α = l1)); simplify_eq.
      * rewrite lookup_insert_eq // in H7.
        apply not_elem_of_dom_2 in H11. done.
      * rewrite lookup_insert_ne // in H7. rewrite H11 in H7. done.
    + destruct (decide (α = l1)); simplify_eq.
      * rewrite lookup_insert_eq // in H7.
        apply not_elem_of_dom_2 in H10. done.
      * rewrite lookup_insert_ne // in H7. rewrite H10 in H7. done.
        Unshelve.
        all: try apply (0%fin).
        all: try apply ((of_val $ LitV LitUnit), σ).
        all: try apply (0%nat).
        all: apply [].
Qed.

(** On an allocated tape [α], [state_step σ α] has total mass [1]. *)
Lemma state_step_mass σ α :
  α ∈ dom σ.(tapes) → SeriesC (state_step σ α) = 1.
Proof.
  intros Hdom.
  rewrite /state_step bool_decide_eq_true_2 //=.
  case_match.
  rewrite dmap_mass dunif_mass //.
Qed.

(** If [e] can take a head step in [σ] with positive probability, then
    [head_step e σ] has total mass [1]. *)
Lemma head_step_mass e σ :
  (∃ ρ, head_step e σ ρ > 0) → SeriesC (head_step e σ) = 1.
Proof.
  intros [[] Hs%head_step_support_equiv_rel].
  inversion Hs;
    repeat (simplify_map_eq/=; solve_distr_mass || (case_match ; try done) ;
            try (case_bool_decide; done)).
Qed.

(** Two context items that produce the same expression when filled with
    non-values are equal. *)
Lemma fill_item_no_val_inj Ki1 Ki2 e1 e2 :
  to_val e1 = None → to_val e2 = None →
  fill_item Ki1 e1 = fill_item Ki2 e2 → Ki1 = Ki2.
Proof. destruct Ki2, Ki1; naive_solver eauto with f_equal. Qed.

(** [height e] counts the expression constructors in [e]. Every constructor
    contributes [1] plus the sum of [height] over its immediate
    sub-expressions. [Val] and [Var] are leaves of height [1]; expressions
    embedded inside values are not counted. *)
Fixpoint height (e : expr) : nat :=
  match e with
  | Val _ | Var _ => 1
  | Rec _ _ body => S (height body)
  | App ef ea => S (height ef + height ea)
  | UnOp _ arg => S (height arg)
  | BinOp _ lhs rhs => S (height lhs + height rhs)
  | If c t f => S (height c + height t + height f)
  | Pair l r => S (height l + height r)
  | Fst p => S (height p)
  | Snd p => S (height p)
  | InjL a => S (height a)
  | InjR a => S (height a)
  | Case s l r => S (height s + height l + height r)
  | AllocN n init => S (height n + height init)
  | Load ptr => S (height ptr)
  | Store ptr v => S (height ptr + height v)
  | AllocTape bound => S (height bound)
  | AllocTapeLaplace a b c => S (height a + height b + height c)
  | Rand bound tape => S (height bound + height tape)
  | Laplace a b c d => S (height a + height b + height c + height d)
  | Tick arg => S (height arg)
  end.

(** [expr_ord e1 e2] holds when [e1] has strictly smaller [height] than [e2];
    it is shown well founded below. *)
Definition expr_ord (e1 e2 : expr) : Prop := lt (height e1) (height e2).

(** Every expression of height at most [h] is accessible for [expr_ord];
    proved by induction on [h]. Ended with [Defined], so it stays
    transparent. *)
Lemma expr_ord_wf' h e : (height e ≤ h)%nat → Acc expr_ord e.
Proof.
  rewrite /expr_ord. revert e; induction h.
  { destruct e; simpl; lia. }
  intros []; simpl;
    constructor; simpl; intros []; eauto with lia.
Defined.

(** [expr_ord] is well founded: instantiate [expr_ord_wf'] with the height of
    the expression itself. Ended with [Defined] to keep it transparent. *)
Lemma expr_ord_wf : well_founded expr_ord.
Proof. intros e. apply (expr_ord_wf' (height e)). lia. Defined.

(** Decomposing [e] into a context item [Ki] and a sub-expression [e'] yields
    an [e'] strictly below [e] in [expr_ord]. *)
(* TODO: this proof is slow, but I do not see how to make it faster... *)
Lemma decomp_expr_ord Ki e e' : decomp_item e = Some (Ki, e') → expr_ord e' e.
Proof.
  rewrite /expr_ord /decomp_item.
  destruct Ki ; repeat destruct_match ; intros [=] ; subst ; cbn ; lia.
Qed.

(** [decomp_item] inverts [fill_item] when the filled expression is not a
    value. *)
Lemma decomp_fill_item Ki e :
  to_val e = None → decomp_item (fill_item Ki e) = Some (Ki, e).
Proof. destruct Ki ; simpl ; by repeat destruct_match. Qed.

(** A successful decomposition of [e] into [Ki] and [e'] refills to [e], and
    the extracted [e'] is not a value. *)
(* TODO: this proof is slow, but I do not see how to make it faster... *)
Lemma decomp_fill_item_2 e e' Ki :
  decomp_item e = Some (Ki, e') → fill_item Ki e' = e ∧ to_val e' = None.
Proof.
  rewrite /decomp_item ;
    destruct e ; try done ;
    destruct Ki ; cbn ; repeat destruct_match ; intros [=] ; subst ; auto.
Qed.

(** The active tapes of [σ]: the locations in the domain of [tapes σ], listed
    via [elements]. Laplace tapes ([tapes_laplace]) are not included. *)
Definition get_active (σ : state) : list loc := elements (dom (tapes σ)).

(** A state step on any active tape has total mass [1]. *)
Lemma state_step_get_active_mass σ α :
  α ∈ get_active σ → SeriesC (state_step σ α) = 1.
Proof. rewrite elem_of_elements. apply state_step_mass. Qed.

(** Performing state steps on the tapes [αs] in sequence (via [foldlM]) has
    total mass [1], provided all of [αs] are active in [σ]. The proof
    uses the fact that a state step only extends an existing tape, so the
    remaining tapes stay active. *)
Lemma state_steps_mass σ αs :
  αs ⊆ get_active σ →
  SeriesC (foldlM state_step σ αs) = 1.
Proof.
  induction αs as [|α αs IH] in σ |-* ; intros Hact.
  { rewrite /= dret_mass //. }
  rewrite foldlM_cons.
  rewrite dbind_det //.
  - apply state_step_get_active_mass. set_solver.
  - intros σ' Hσ'. apply IH.
    apply state_step_support_equiv_rel in Hσ'.
    inversion Hσ'; simplify_eq.
    intros α' ?. rewrite /get_active /=.
    apply elem_of_elements, elem_of_dom.
    destruct (decide (α = α')); subst.
    + eexists. rewrite lookup_insert_eq //.
    + rewrite lookup_insert_ne //.
      apply elem_of_dom. eapply elem_of_elements, Hact. by right.
Qed.

(** [prob_lang] satisfies [EctxiLanguageMixin]. Its components are [of_val] and
    [to_val], [fill_item] and [decomp_item], [expr_ord], [head_step],
    [state_step] and [get_active]. Every obligation is discharged by
    [apply _] or by [eauto] using the lemmas listed below. *)
Lemma prob_lang_mixin :
  EctxiLanguageMixin of_val to_val fill_item decomp_item expr_ord head_step state_step get_active.
Proof.
  split; apply _ || eauto using to_of_val, of_to_val, val_head_stuck,
    state_step_head_step_not_stuck, state_step_get_active_mass, head_step_mass,
    fill_item_val, fill_item_no_val_inj, head_ctx_step_val,
    decomp_fill_item, decomp_fill_item_2, expr_ord_wf, decomp_expr_ord.
Qed.

End prob_lang.

(** * Language *)
(* Canonical language structures for [prob_lang]. The ectxi-language is built
   from [prob_lang.prob_lang_mixin]. The ectx-language and the language are
   then derived from it with [EctxLanguageOfEctxi] and [LanguageOfEctx]. *)
Canonical Structure prob_ectxi_lang := EctxiLanguage prob_lang.get_active prob_lang.prob_lang_mixin (def_val := prob_lang.def_val).
Canonical Structure prob_ectx_lang := EctxLanguageOfEctxi prob_ectxi_lang.
Canonical Structure prob_lang := LanguageOfEctx prob_ectx_lang.

(* Prefer prob_lang names over ectx_language names. *)
Export prob_lang.

(** A configuration: an expression paired with a state. *)
Definition cfg : Type := prod expr state.