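(** * clutch.con_prob_lang.erasure

    Erasure of tape sampling: a [state_step] on an existing tape can be
    inserted before (or removed from) an execution under a tape-oblivious
    scheduler without changing the distribution of execution results. *)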

From Stdlib Require Import Reals Psatz.
From Coquelicot Require Import Rcomplements Rbar Lim_seq.
From stdpp Require Import fin_maps fin_map_dom.
From clutch.prelude Require Import stdpp_ext.
From clutch.common Require Import con_language con_ectx_language sch_erasable.
From clutch.con_prob_lang Require Import notation lang metatheory.
From clutch.prob Require Import couplings couplings_app mdp.

Set Default Proof Using "Type*".
Local Open Scope R.
Local Opaque state_upd_tapes.

Section erasure_helpers.

  (** Helpers for the inductive step of [prim_coupl_upd_tapes_dom] (below).
      [m] is the number of remaining scheduler steps and [IH] is the induction
      hypothesis for [m]: running [sch] for [m] steps is coupled, on the
      expression-list component [x.2.1] of the result, with first appending a
      sample of [dunifP N] to the existing tape [α] and then running.

      Each [ind_case_*] lemma handles one shape of head redex. The left-hand
      side steps the redex in [σ]. The right-hand side first extends tape [α]
      and then steps the same redex in the updated state. *)
  Variable (m : nat).
  Context {sch_int_σ: Type}.
  Context `{TapeOblivious sch_int_σ sch}.
  Hypothesis IH :
    ∀ (es1 : list expr) (σ1 : state) α N zs ζ,
    tapes σ1 !! α = Some (N; zs) →
    Rcoupl
      (dmap (λ x, x.2.1) (sch_pexec sch m (ζ, (es1, σ1))))
      (dmap (λ x, x.2.1) (dunifP N ≫= (λ z, sch_pexec sch m (ζ, (es1, state_upd_tapes <[α:= (N; zs ++ [z])]> σ1))))) eq.

  (** Deterministic head step: the step is unaffected by extending tape [α]
      ([det_head_step_upd_tapes]), so the goal reduces to [IH]. *)
  Local Lemma ind_case_det e σ α N zs K (n:nat) s es:
    tapes σ !! α = Some (N; zs) →
    is_det_head_step e σ = true ->
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[n:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step e σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP N
      ≫= λ a0 : fin (S N),
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[n:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step e (state_upd_tapes <[α:=(N; zs ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hα (e2 & (σ2 & (efs & Hdet)))%is_det_head_step_true%det_step_pred_ex_rel.
    erewrite 1!det_head_step_singleton; [|done..].
    setoid_rewrite (det_head_step_singleton ); eauto; last first.
    - eapply det_head_step_upd_tapes; eauto.
    - (* transport [Hα] to the post-step state via [det_step_eq_tapes] *)
      erewrite det_step_eq_tapes in Hα; [|done].
      rewrite !dmap_dret.
      rewrite !dret_id_left /=.
      rewrite -!dbind_assoc.
      erewrite (distr_ext (dunifP _ ≫= _) _); last first.
      { intros. apply dbind_pmf_ext; [|done..]. intros.
        rewrite dmap_dret dret_id_left. simpl. done. }
      rewrite -dmap_dbind. apply IH. done.
  Qed.

  (** Stuck head redex ([head_step e σ = dzero]): it is still stuck after
      extending [α] ([head_step_dzero_upd_tapes]), so both sides are [dzero]. *)
  Local Lemma ind_case_dzero e σ α N zs K (n:nat) s es:
    tapes σ !! α = Some (N; zs) →
    head_step e σ = dzero ->
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[n:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step e σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP N
      ≫= λ a0 : fin (S N),
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[n:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step e (state_upd_tapes <[α:=(N; zs ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hα Hz.
    rewrite Hz.
    setoid_rewrite head_step_dzero_upd_tapes; [|by eapply elem_of_dom_2|done].
    rewrite dmap_dzero dbind_dzero dzero_dbind.
    rewrite dbind_dzero.
    apply Rcoupl_dzero_dzero.
  Qed.

  (** Allocation of a fresh tape ([alloc #z]): the fresh label differs from
      [α] ([elem_fresh_ne]), so allocating it commutes with extending [α]
      ([fresh_loc_upd_swap]). *)
  Local Lemma ind_case_alloc (z:Z) σ α N zs K (n:nat) s es:
    tapes σ !! α = Some (N; zs) →
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[n:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step (alloc #z) σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP N
      ≫= λ a0 : fin (S N),
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[n:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step (alloc #z) (state_upd_tapes <[α:=(N; zs ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hα.
    rewrite dmap_dret dret_id_left.
    rewrite {3}/dmap.
    erewrite dbind_ext_right'; [|done|]; last first.
    { rewrite dbind_assoc dmap_fold.
      rewrite /dmap -dbind_assoc.
      erewrite dbind_ext_right; last first.
      - intros. rewrite dret_id_left. done.
      - done.
    }
    rewrite -dmap_dbind.
    rewrite -dbind_assoc.
    erewrite dbind_ext_right; last first.
    { intros.
      rewrite dret_id_left'.
      rewrite sch_pexec_fold.
      done.
    }
    apply lookup_total_correct in Hα as Hαtot.
    pose proof (elem_fresh_ne _ _ _ Hα) as Hne.
    erewrite dbind_ext_right; last first.
    { intros ?.
      rewrite -(fresh_loc_upd_some _ _ (N; zs)); [|done].
      rewrite (fresh_loc_upd_swap σ α (N; zs) (_; [])) //. }
    apply IH.
    by apply fresh_loc_lookup.
  Qed.

  (** [rand(#lbl:α') #z] where tape [α'] is non-empty ([n::ns]) and its bound
      [N] equals [Z.to_nat z]. Split on whether [α] is [α']. *)
  Local Lemma ind_case_rand_some (z:Z) σ α α' (N M:nat) n ns ns' K (id:nat) s es:
    N=Z.to_nat z ->
    tapes σ!!α = Some (M;ns') ->
    tapes σ !! α' = Some (N; n::ns) →
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step (rand(#lbl:α') #z) σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP M
      ≫= λ a0 ,
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step (rand(#lbl:α') #z) (state_upd_tapes <[α:=(M; ns' ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hz Hα Hα'.
    apply lookup_total_correct in Hα as Hαtot.
    apply lookup_total_correct in Hα' as Hα'tot.
    destruct (decide (α = α')) as [-> | Hαneql].
    - simplify_eq. rewrite /head_step Hα.
      setoid_rewrite lookup_insert_eq.
      rewrite bool_decide_eq_true_2 //.
      rewrite dmap_dret dret_id_left.
      rewrite -dmap_dbind.
      erewrite dbind_ext_right'; [|done|]; last first.
      { apply dbind_ext_right.
        intros.
        simpl. rewrite dmap_dret. done.
      }
      assert (Haux : ∀ n,
                 state_upd_tapes <[α':=(Z.to_nat z; ns ++ [n])]> σ =
                 state_upd_tapes <[α':=(Z.to_nat z; ns ++ [n])]> (state_upd_tapes <[α':=(Z.to_nat z; ns)]> σ)).
      { intros. by rewrite state_upd_tapes_twice. }
      rewrite -!dbind_assoc.
      erewrite dbind_ext_right; last first.
      { intros. by rewrite dret_id_left sch_pexec_fold state_upd_tapes_twice Haux. }
      apply IH.
      apply lookup_insert_eq.
    - rewrite /head_step Hα'.
      rewrite bool_decide_eq_true_2 //.
      setoid_rewrite lookup_insert_ne; [|done].
      rewrite Hα' bool_decide_eq_true_2 //.
      rewrite !dmap_dret !dret_id_left.
      rewrite -dmap_dbind.
      rewrite -dbind_assoc.
      erewrite dbind_ext_right; last first.
      { intros.
        rewrite upd_diff_tape_comm; [|done].
        rewrite dmap_dret dret_id_left sch_pexec_fold //. }
      eapply IH.
      rewrite lookup_insert_ne //.
  Qed.

  (** [rand(#lbl:α') #z] where tape [α'] is empty and its bound [M] equals
      [Z.to_nat z]. Split on whether [α] is [α']. *)
  Local Lemma ind_case_rand_empty (z:Z) σ α α' (N M:nat) ns K (id:nat) s es:
    M=Z.to_nat z ->
    tapes σ!!α = Some (N;ns) ->
    tapes σ !! α' = Some (M; []) →
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step (rand(#lbl:α') #z) σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP N
      ≫= λ a0 ,
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step (rand(#lbl:α') #z) (state_upd_tapes <[α:=(N; ns ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hz Hα Hα'.
    destruct (decide (α = α')) as [-> | Hαneql].
    + simplify_eq. rewrite /head_step Hα.
      rewrite bool_decide_eq_true_2 //.
      rewrite {1 2}/dmap.
      rewrite -!dbind_assoc.
      eapply (Rcoupl_dbind _ _ _ _ (=)); [ |apply Rcoupl_eq].
      intros ? b ->.
      do 2 rewrite dret_id_left.
      rewrite lookup_insert_eq.
      rewrite bool_decide_eq_true_2 //.
      rewrite dmap_dret dret_id_left.
      rewrite upd_tape_twice.
      rewrite dmap_fold.
      rewrite state_upd_tapes_no_change; [|done].
      apply Rcoupl_eq.
    + rewrite /head_step /=.
      setoid_rewrite lookup_insert_ne; [|done].
      rewrite Hα'.
      rewrite bool_decide_eq_true_2 //.
      rewrite !dbind_assoc.
      rewrite -!dbind_assoc.
      erewrite (dbind_ext_right (dunifP N)); last first.
      { intro.
        rewrite /dmap.
        rewrite -!dbind_assoc.
        apply dbind_ext_right.
        intros.
        rewrite !dret_id_left. done.
      }
      rewrite dbind_comm.
      eapply Rcoupl_dbind; [|apply Rcoupl_eq].
      intros; simplify_eq.
      do 2 rewrite dret_id_left /=.
      rewrite !dbind_assoc.
      rewrite !dmap_fold.
      by apply IH.
  Qed.

  (** [rand(#lbl:α') #z] where the bound [N] of tape [α'] differs from
      [Z.to_nat z]: the bound check fails ([bool_decide_eq_false_2]) and the
      two independent samples are swapped with [dbind_comm]. *)
  Local Lemma ind_case_rand_some_neq (z:Z) σ α α' (N M:nat) ns ns' K (id:nat) s es:
    N≠Z.to_nat z ->
    tapes σ!!α = Some (M;ns') ->
    tapes σ !! α' = Some (N; ns) →
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step (rand(#lbl:α') #z) σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP M
      ≫= λ a0 ,
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step (rand(#lbl:α') #z) (state_upd_tapes <[α:=(M; ns' ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hz Hα Hα'.
    rewrite /head_step Hα'.
    rewrite bool_decide_eq_false_2 //.
    destruct (decide (α = α')) as [-> | Heq].
    - simplify_eq.
      setoid_rewrite lookup_insert_eq.
      rewrite bool_decide_eq_false_2 //.
      rewrite /dmap /=.
      rewrite -!dbind_assoc.
      erewrite (dbind_ext_right (dunifP M)); last first.
      { intros. rewrite -!dbind_assoc.
        apply dbind_ext_right.
        intros.
        rewrite !dret_id_left'//.
      }
      rewrite dbind_comm.
      eapply Rcoupl_dbind; [|apply Rcoupl_eq].
      intros; simplify_eq.
      rewrite 2!dret_id_left.
      rewrite -!dmap_dbind.
      by apply IH.
    - setoid_rewrite lookup_insert_ne; [|done].
      rewrite Hα' bool_decide_eq_false_2 //.
      rewrite /dmap.
      rewrite -!dbind_assoc.
      erewrite (dbind_ext_right (dunifP M)); last first.
      { intros. rewrite -!dbind_assoc.
        apply dbind_ext_right.
        intros.
        rewrite !dret_id_left//.
      }
      rewrite dbind_comm.
      eapply Rcoupl_dbind; [|apply Rcoupl_eq].
      intros; simplify_eq.
      rewrite 2!dret_id_left.
      rewrite -dmap_dbind.
      by apply IH.
  Qed.

  (** Unlabelled [rand #z]: the two independent samples are swapped with
      [dbind_comm] before applying [IH]. *)
  Local Lemma ind_case_rand (z:Z) σ α (N M:nat) ns K (id:nat) s es:
    N=Z.to_nat z ->
    tapes σ!!α = Some (M;ns) ->
    Rcoupl
    (dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
       (head_step (rand #z) σ)
     ≫= λ a,
          dmap (λ x, x.2.1) (sch_pexec sch m a))
    ((dunifP M
      ≫= λ a0 ,
           dmap (pair s ∘ ((λ '(expr', σ', efs), (<[id:=expr']> es ++ efs, σ')) ∘ fill_lift' K))
             (head_step (rand #z) (state_upd_tapes <[α:=(M; ns ++ [a0])]> σ)))
     ≫= λ b,
       dmap (λ x, x.2.1) (sch_pexec sch m b)) eq.
  Proof using m IH.
    intros Hz Hα.
    rewrite /head_step.
    erewrite (dbind_ext_right (dunifP M)); last first.
    { intro.
      rewrite {1 2}/dmap.
      rewrite -dbind_assoc.
      apply dbind_ext_right.
      intros. rewrite !dret_id_left. done.
    }
    rewrite {2 3}/dmap.
    rewrite -!dbind_assoc.
    erewrite (dbind_ext_right (dunifP M)); last first.
    { intros n. rewrite -!dbind_assoc. done. }
    rewrite dbind_comm.
    eapply Rcoupl_dbind; [|apply Rcoupl_eq].
    intros; simplify_eq.
    do 2 rewrite dret_id_left.
    erewrite (distr_ext (dunifP M ≫=_ )); last first.
    { intros. apply dbind_pmf_ext; [|done..].
      intros. rewrite !dret_id_left. done.
    }
    rewrite -dmap_dbind.
    apply IH; auto.
  Qed.

End erasure_helpers.

(** Appending a uniformly sampled value (from [dunifP N]) to an existing tape
    [α] before running [sch] for [m] steps does not change the distribution of
    the resulting expression list ([x.2.1]); the two are coupled by [=].
    Proof by induction on [m]. In the step case, the scheduler's choice is
    moved in front of the sample ([sch_tape_oblivious_state_upd_tapes],
    [dbind_comm]). The chosen thread's head redex is then handled by the
    [ind_case_*] helpers of section [erasure_helpers]. *)
Lemma prim_coupl_upd_tapes_dom `{Countable sch_int_σ} m (es1: list expr) σ1 α N ns ζ `{TapeOblivious sch_int_σ sch}:
  σ1.(tapes) !! α = Some (N; ns) →
  Rcoupl
    (dmap (λ x, x.2.1) (sch_pexec sch m (ζ, (es1, σ1))))
    (dunifP N ≫=
       (λ n, dmap (λ x, x.2.1) (sch_pexec sch m (ζ, (es1, state_upd_tapes <[α := (N; ns ++ [n])]> σ1)))))
    (=).
Proof.
  rewrite -dmap_dbind.
  revert es1 σ1 α N ns ζ; induction m; intros es1 σ1 α N ns ζ Hα.
  - (* Base case: zero steps. *)
    rewrite /sch_pexec /=.
    rewrite dmap_dret.
    rewrite dmap_dbind.
    erewrite (distr_ext (dunifP N≫=_)); last first.
    { intros. apply dbind_pmf_ext; [|done..].
      intros. rewrite dmap_dret. done.
    }
    rewrite (dret_const (dunifP N)); [apply Rcoupl_eq | apply dunif_mass; lia].
  - rewrite sch_pexec_Sn /sch_step_or_final.
    case_match eqn:He1; simpl in He1.
    + (* Final configuration: [sch_pexec] no longer steps ([sch_pexec_is_final]). *)
      rewrite dret_id_left.
      rewrite -/(sch_pexec sch m (ζ, (es1, σ1))).
      rewrite sch_pexec_is_final; last by rewrite /is_final.
      rewrite dmap_dret. simpl.
      rewrite dmap_dbind.
      erewrite (distr_ext (dunifP N ≫=_)); last first.
      { intros. apply dbind_pmf_ext; [|done..].
        intros. rewrite sch_pexec_is_final; last by rewrite /is_final.
        rewrite dmap_dret. simpl. done.
      }
      rewrite dret_const; [|solve_distr_mass].
      apply Rcoupl_eq.
    + rewrite !dmap_dbind.
      erewrite (distr_ext (dunifP N ≫= _)); last first.
      { intros. apply dbind_pmf_ext; [|done..].
        intros. setoid_rewrite sch_pexec_Sn.
        rewrite /sch_step_or_final/=He1/sch_step/=.
        rewrite !dmap_dbind/=.
        apply dbind_pmf_ext; [done| |done].
        apply dbind_ext_right.
        intros [sch_int_σ' thread_id].
        rewrite /mbind/option_bind He1.
        instantiate (1 := λ '(sch_int_σ', thread_id),
                       dmap (λ mdp_σ' : con_language.cfg con_prob_lang, (sch_int_σ', mdp_σ'))
                         match es1 !! thread_id with
                         | Some expr0 =>
                             match to_val expr0 with
                             | Some _ => dret (es1, state_upd_tapes <[α:=(N; ns ++ [a0])]> σ1)
                             | None =>
                                 dmap (λ '(expr', σ', efs), (<[thread_id:=expr']> es1 ++ efs, σ'))
                                   (prim_step expr0 (state_upd_tapes <[α:=(N; ns ++ [a0])]> σ1))
                             end
                         | None => dret (es1, state_upd_tapes <[α:=(N; ns ++ [a0])]> σ1)
                         end ).
        done.
      }
      rewrite /sch_step/prim_step/=.
      rewrite /mbind/option_bind He1.
      setoid_rewrite sch_tape_oblivious_state_upd_tapes.
      rewrite dbind_assoc dbind_comm -!dbind_assoc'.
      eapply Rcoupl_dbind; last apply Rcoupl_eq.
      intros ?[] ->.
      rewrite /dmap.
      destruct (es1 !! n) as [e1|]eqn:He2; last first.
      { (* we step a thread id that is out of bound *)
        rewrite !dret_id_left. rewrite -/dmap.
        rewrite dmap_fold.
        erewrite (distr_ext(_≫=_)); last first.
        - intros.
          rewrite dbind_assoc. rewrite dmap_fold.
          erewrite dmap_eq; first done; last first.
          + intros [?[??]].
            apply dbind_pmf_ext; [done| |done].
            etrans.
            * apply dbind_ext_right.
              intros. rewrite dret_id_left. done.
            * rewrite dmap_fold. done.
          + done.
        - rewrite {3}/dmap. rewrite -dbind_assoc.
          eapply Rcoupl_eq_trans; first apply IHm; last first.
          + apply Rcoupl_dmap.
            eapply Rcoupl_dbind; last apply Rcoupl_eq.
            intros ?? ->.
            rewrite dret_id_left. eapply Rcoupl_mono; first apply Rcoupl_eq.
            naive_solver.
          + done.
      }
      destruct (to_val e1) eqn:He3.
      { (* the thread we chose is already a value *)
        rewrite !dret_id_left. rewrite -/dmap.
        rewrite dmap_fold.
        erewrite (distr_ext(_≫=_)); last first.
        - intros.
          rewrite dbind_assoc. rewrite dmap_fold.
          erewrite dmap_eq; first done; last first.
          + intros [?[??]].
            apply dbind_pmf_ext; [done| |done].
            etrans.
            * apply dbind_ext_right.
              intros. rewrite dret_id_left. done.
            * rewrite dmap_fold. done.
          + done.
        - rewrite {3}/dmap. rewrite -dbind_assoc.
          eapply Rcoupl_eq_trans; first apply IHm; last first.
          + apply Rcoupl_dmap.
            eapply Rcoupl_dbind; last apply Rcoupl_eq.
            intros ?? ->.
            rewrite dret_id_left. eapply Rcoupl_mono; first apply Rcoupl_eq.
            naive_solver.
          + done.
      }
      rewrite /prim_step/=.
      destruct (decomp e1) as [K ered] eqn:Hdecomp_e1.
      rewrite Hdecomp_e1.
      rewrite !dmap_fold !dmap_comp /=.
      erewrite (distr_ext (dunifP N≫=_)); last first.
      { intros. apply dbind_pmf_ext; [|done..].
      intros. rewrite !dmap_fold !dmap_comp. done. }
      (* Case split on the head redex: deterministic, probabilistic, or stuck. *)
      destruct (det_or_prob_or_dzero ered σ1) as [ HD | [HP | HZ]].
      * eapply ind_case_det; [done|done|by apply is_det_head_step_true].
      * inversion HP; simplify_eq.
        -- by eapply ind_case_alloc.
        -- by eapply ind_case_rand_some.
        -- by eapply ind_case_rand_empty.
        -- by eapply ind_case_rand_some_neq.
        -- by eapply ind_case_rand.
      * by eapply ind_case_dzero.
Qed.

(** Variant of [prim_coupl_upd_tapes_dom] in which the tape is extended by
    [state_step σ1 α] (unfolded with [state_step_unfold]). *)
Lemma pexec_coupl_step_pexec `{Countable sch_int_σ} m es1 σ1 α bs ζ `{TapeOblivious sch_int_σ sch} :
  σ1.(tapes) !! α = Some bs →
  Rcoupl
    (dmap (λ ρ, ρ.2.1) (sch_pexec sch m (ζ, (es1, σ1))))
    (dmap (λ ρ, ρ.2.1) (state_step σ1 α ≫= (λ σ2, sch_pexec sch m (ζ, (es1, σ2)))))
    eq.
Proof.
  intros.
  destruct bs.
  eapply Rcoupl_eq_trans; first eapply prim_coupl_upd_tapes_dom; try done.
  rewrite <-dmap_dbind.
  apply Rcoupl_dmap.
  erewrite state_step_unfold; last done.
  rewrite /dmap.
  rewrite -dbind_assoc.
  eapply Rcoupl_dbind; last apply Rcoupl_eq.
  intros ??->.
  rewrite dret_id_left.
  eapply Rcoupl_mono; first apply Rcoupl_eq.
  intros. naive_solver.
Qed.

(** [state_step σ1 α] is [sch_erasable] at [σ1] for every tape-oblivious
    scheduler, provided tape [α] exists in [σ1]. *)
Lemma state_step_sch_erasable σ1 α bs :
  σ1.(tapes) !! α = Some bs →
  sch_erasable (λ t Heq Hcount sch', TapeOblivious t sch') (state_step σ1 α) σ1.
Proof.
  intros. rewrite /sch_erasable.
  intros.
  symmetry.
  apply Rcoupl_eq_elim.
  by eapply pexec_coupl_step_pexec.
Qed.

(** Counterpart of [pexec_coupl_step_pexec] for [sch_exec]. It is derived
    from [state_step_sch_erasable] via [sch_erasable_sch_erasable_val]. *)
Lemma prim_coupl_step_prim `{Hcountable:Countable sch_int_σ} m es1 σ1 α bs ζ `{HTO: TapeOblivious sch_int_σ sch} :
  σ1.(tapes) !! α = Some bs →
  Rcoupl
    (sch_exec sch m (ζ, (es1, σ1)))
    (state_step σ1 α ≫= (λ σ2, sch_exec sch m (ζ, (es1, σ2))))
    eq.
Proof.
  intros Hα.
  erewrite (distr_ext _ _); first apply Rcoupl_eq.
  intros.
  erewrite sch_erasable_sch_erasable_val; [done|by eapply state_step_sch_erasable|done].
Qed.

(** [state_step σ1 α] is [sch_erasable_val] at [σ1]; this is the [sch_exec]
    analogue of [state_step_sch_erasable]. *)
Lemma state_step_sch_erasable_val σ1 α bs :
  σ1.(tapes) !! α = Some bs →
  sch_erasable_val (λ t Heq Hcount sch', TapeOblivious t sch') (state_step σ1 α) σ1.
Proof.
  intros. rewrite /sch_erasable_val.
  intros.
  symmetry.
  apply Rcoupl_eq_elim.
  by eapply prim_coupl_step_prim.
Qed.

(** Iterating [λ σ, state_step σ α] [n] times from [σ1] is still
    [sch_erasable]. Proof by induction on [n]: every state in the support of
    [state_step] still has tape [α] ([lookup_insert_eq]). *)
Lemma iterM_state_step_sch_erasable
  σ1 α bs n:
  σ1.(tapes) !! α = Some bs →
  sch_erasable (λ t Heq Hcount sch', TapeOblivious t sch') (iterM n (λ σ, state_step σ α) σ1) σ1.
Proof.
  revert σ1 bs.
  induction n; intros σ1 bs K.
  - simpl. apply dret_sch_erasable.
  - simpl. apply sch_erasable_dbind; first by eapply state_step_sch_erasable.
    intros ? K'.
    destruct bs.
    erewrite state_step_unfold in K'; last done.
    rewrite dmap_pos in K'. destruct K' as (?&->&?).
    eapply IHn. simpl. apply lookup_insert_eq.
Qed.

(** Pointwise equality of [sch_lim_exec] before and after a [state_step] on
    an existing tape [α]. *)
Lemma limprim_coupl_step_limprim_aux
  `{Hcountable:Countable sch_int_σ} e1 σ1 α bs v ζ `{TapeOblivious sch_int_σ sch}:
  σ1.(tapes) !! α = Some bs →
  (sch_lim_exec sch (ζ, (e1, σ1))) v =
  (state_step σ1 α ≫= (λ σ2, sch_lim_exec sch (ζ, (e1, σ2)))) v.
Proof.
  intro Hsome.
  erewrite <-sch_erasable_sch_lim_exec; [done|by eapply state_step_sch_erasable|done].
Qed.

(** Coupling form of [limprim_coupl_step_limprim_aux]: prefixing
    [sch_lim_exec] with a [state_step] on an existing tape leaves the
    distribution unchanged. *)
Lemma sch_limprim_coupl_step_sch_limprim
  `{Hcountable:Countable sch_int_σ} (e1 : list expr) σ1 α bs ζ `{TapeOblivious sch_int_σ sch} :
  σ1.(tapes) !! α = Some bs →
  Rcoupl
    (sch_lim_exec sch (ζ, (e1, σ1)))
    (state_step σ1 α ≫= (λ σ2, sch_lim_exec sch (ζ, (e1, σ2))))
    eq.
Proof.
  intro Hsome.
  erewrite (distr_ext (sch_lim_exec sch (ζ, (e1, σ1)))); last first.
  - intro a.
    apply (limprim_coupl_step_limprim_aux _ _ _ _ _ _ Hsome).
  - apply Rcoupl_eq.
Qed.

(** Prefixing [sch_lim_exec] with [state_step]s on a list [αs] of active
    tapes (sequenced with [foldlM]) leaves the limit distribution unchanged.
    Proof by induction on [αs], erasing one [state_step] at a time with
    [limprim_coupl_step_limprim_aux]. *)
Lemma sch_lim_exec_eq_erasure `{Hcountable:Countable sch_int_σ} αs e σ ζ `{TapeOblivious sch_int_σ sch}:
  αs ⊆ get_active σ →
  sch_lim_exec sch (ζ, (e, σ)) = foldlM state_step σ αs ≫= (λ σ', sch_lim_exec sch (ζ, (e, σ'))).
Proof.
  induction αs as [|α αs IH] in σ |-*.
  { rewrite /= dret_id_left //. }
  intros Hα.
  eapply Rcoupl_eq_elim.
  assert (sch_lim_exec sch (ζ, (e, σ)) = state_step σ α ≫= (λ σ2, sch_lim_exec sch (ζ, (e, σ2)))) as ->.
  { apply distr_ext => v.
    assert (α ∈ get_active σ) as Hel; [apply Hα; left|].
    rewrite /get_active in Hel.
    apply elem_of_elements, elem_of_dom in Hel as [? ?].
    by eapply limprim_coupl_step_limprim_aux. }
  rewrite foldlM_cons -dbind_assoc.
  eapply Rcoupl_dbind; [|eapply Rcoupl_pos_R, Rcoupl_eq].
  intros ?? (-> & Hs%state_step_support_equiv_rel & _).
  inversion_clear Hs.
  rewrite IH; [eapply Rcoupl_eq|].
  (* Side condition of [IH]: the remaining labels [αs] are still active
     after tape [α] is updated. *)
  intros α' ?. rewrite /get_active /=.
  apply elem_of_elements.
  apply elem_of_dom.
  destruct (decide (α = α')); subst.
  + eexists. rewrite lookup_insert_eq //.
  + rewrite lookup_insert_ne //.
    apply elem_of_dom. eapply elem_of_elements, Hα. by right.
Qed.

(** A wrapper around [sch] whose internal state is [option sch_int_σ].
    In state [None] it deterministically selects thread [num] and moves to
    [Some initial]. In state [Some s] it runs [sch] from [s] and wraps the
    resulting internal state in [Some]. *)
Local Definition force_first_thread_scheduler `{Hcountable:Countable sch_int_σ} sch (num:nat) (initial: sch_int_σ)
  `{!TapeOblivious sch_int_σ sch} :=
  Build_scheduler {|
      scheduler_f '(sch_st, conf) :=
        match sch_st with
        | Some inner =>
            dmap (λ '(inner', tid), (Some inner', tid)) (sch (inner, conf))
        | None => dret (Some initial, num)
        end
    |}.

(** The wrapped scheduler is tape oblivious whenever [sch] is. The [None]
    branch returns a constant [dret]. The [Some] branch defers to [sch] and
    is closed with [HTO]. *)
Local Lemma force_first_thread_scheduler_tape_oblivious
  `{Hcountable:Countable sch_int_σ} sch (num:nat) (initial: sch_int_σ)
  `{HTO: !TapeOblivious sch_int_σ sch}:
  TapeOblivious _ (force_first_thread_scheduler sch num initial).
Proof.
  rewrite /TapeOblivious/force_first_thread_scheduler/=.
  intros. case_match; try done.
  f_equal.
  by apply HTO.
Qed.

(** Started in [Some ζ], [force_first_thread_scheduler] yields the same
    distribution of expression lists ([ρ.2.1]) as [sch] started in [ζ], for
    any number [n] of steps. Proof by induction on [n]. *)
Local Lemma force_first_thread_scheduler_pexec_lemma `{Hcountable:Countable sch_int_σ} ζ ρ sch num initial n `{!TapeOblivious sch_int_σ sch} :
  dmap (λ ρ, ρ.2.1) (sch_pexec sch n (ζ, ρ)) =
  dmap (λ ρ, ρ.2.1) (sch_pexec (force_first_thread_scheduler sch num initial) n (Some ζ, ρ)).
Proof.
  revert ζ ρ.
  induction n.
  (* Base case: zero steps on both sides. *)
  - intros. rewrite !sch_pexec_O/=. by rewrite !dmap_dret.
  - intros. rewrite !sch_pexec_Sn. rewrite {2}/force_first_thread_scheduler/scheduler_f.
    rewrite /sch_step_or_final /=.
    case_match.
    + rewrite !dret_id_left. naive_solver.
    + rewrite /sch_step /=. destruct sch. simpl.
      rewrite /dmap -!dbind_assoc'.
      apply dbind_ext_right.
      intros [].
      rewrite dret_id_left.
      rewrite -!dbind_assoc'.
      apply dbind_ext_right.
      intros.
      rewrite !dret_id_left.
      rewrite /dmap in IHn. naive_solver.
Qed.

(** Started in [None] on the pool [e::es1], one step of
    [force_first_thread_scheduler] performs a [prim_step] of thread [num]
    ([e1], not a value). It then behaves like [sch] started in [initial],
    projected to the expression list. The hypothesis [to_val e = None] is used
    to discharge the check in [sch_step_or_final]. *)
Local Lemma force_first_thread_scheduler_pexec_lemma' `{Hcountable:Countable sch_int_σ} e1 e es1 σ1 sch num initial n `{!TapeOblivious sch_int_σ sch} :
  (e::es1)!!num=Some e1 ->
  to_val e = None ->
  to_val e1 = None ->
  dmap (λ ρ, ρ.2.1) (prim_step e1 σ1 ≫= λ '(e', s, l), sch_pexec sch n (initial, (<[num:=e']> (e::es1) ++ l, s))) =
  dmap (λ ρ, ρ.2.1) (sch_pexec (force_first_thread_scheduler sch num initial) (S n) (None, (e::es1, σ1))).
Proof.
  intros H Hv Hv'.
  rewrite /force_first_thread_scheduler sch_pexec_Sn.
  rewrite /sch_step_or_final/=.
  rewrite Hv /sch_step -!dbind_assoc'/= dret_id_left Hv H Hv' /dmap -!dbind_assoc'.
  apply dbind_ext_right.
  intros [[]].
  rewrite !dret_id_left.
  epose proof force_first_thread_scheduler_pexec_lemma as K.
  rewrite /dmap in K.
  erewrite K.
  repeat f_equal.
Qed.

(** Any [μ] that is [sch_erasable] at [σ1] may be inserted before a
    [prim_step] of thread [num] followed by [n] steps of [sch]. This does not
    change the distribution of the resulting expression list. The composite
    is rewritten as an execution of [force_first_thread_scheduler] started in
    [None], to which [sch_erasable] applies. *)
Lemma prim_coupl_step_prim_pexec_sch_erasable `{Hcountable:Countable sch_int_σ} e n es1 σ1 ζ e1 (num:nat) μ `{HTO: TapeOblivious sch_int_σ sch} :
  sch_erasable (λ t _ _ sch, TapeOblivious t sch) μ σ1 ->
  (e::es1)!!num=Some e1 ->
  to_val e = None ->
  to_val e1 = None ->
  Rcoupl
    (dmap (λ ρ, ρ.2.1) (prim_step e1 σ1 ≫= λ '(e', s, l), sch_pexec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s))))
    (dmap (λ ρ, ρ.2.1) (μ ≫= (λ σ2, prim_step e1 σ2 ≫= λ '(e', s, l), sch_pexec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s)))))
    eq.
Proof.
  intros H1 H2 H3 H4.
  erewrite force_first_thread_scheduler_pexec_lemma'; try done.
  eapply Rcoupl_eq_trans.
  - erewrite <-H1; last apply force_first_thread_scheduler_tape_oblivious.
    apply Rcoupl_eq.
  - rewrite /dmap -!dbind_assoc. eapply Rcoupl_dbind; last apply Rcoupl_eq.
    intros ??->.
    apply Rcoupl_eq_sym.
    unshelve epose proof force_first_thread_scheduler_pexec_lemma' as K; try done.
    rewrite /dmap in K. rewrite K; first apply Rcoupl_eq; done.
Qed.

(** Instance of [prim_coupl_step_prim_pexec_sch_erasable] with
    [μ := state_step σ1 α], for an existing tape [α]. *)
Lemma prim_coupl_step_prim_pexec' `{Hcountable:Countable sch_int_σ} e n es1 σ1 α bs ζ e1 (num:nat) `{HTO: TapeOblivious sch_int_σ sch} :
  σ1.(tapes) !! α = Some bs →
  (e::es1)!!num=Some e1 ->
  to_val e = None ->
  to_val e1 = None ->
  Rcoupl
    (dmap (λ ρ, ρ.2.1) (prim_step e1 σ1 ≫= λ '(e', s, l), sch_pexec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s))))
    (dmap (λ ρ, ρ.2.1) (state_step σ1 α ≫= (λ σ2, prim_step e1 σ2 ≫= λ '(e', s, l), sch_pexec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s)))))
    eq.
Proof.
  intros H1 H2 H3 H4.
  apply prim_coupl_step_prim_pexec_sch_erasable; try done.
  by eapply state_step_sch_erasable.
Qed.

(** [sch_exec] analogue of [force_first_thread_scheduler_pexec_lemma]:
    started in [Some ζ], the wrapped scheduler has the same [sch_exec]
    distribution as [sch] started in [ζ]. Proof by induction on [n]. *)
Local Lemma force_first_thread_scheduler_lemma `{Hcountable:Countable sch_int_σ} ζ ρ sch num initial n `{!TapeOblivious sch_int_σ sch} :
  sch_exec sch n (ζ, ρ) = sch_exec (force_first_thread_scheduler sch num initial) n (Some ζ, ρ).
Proof.
  revert ζ ρ.
  induction n.
  - intros. by rewrite /sch_exec.
  - intros. rewrite !sch_exec_Sn. rewrite {2}/force_first_thread_scheduler/scheduler_f.
    rewrite /sch_step_or_final /=.
    case_match.
    + rewrite !dret_id_left. naive_solver.
    + rewrite /sch_step /=. destruct sch. simpl.
      rewrite /dmap -!dbind_assoc'.
      apply dbind_ext_right.
      intros [].
      rewrite dret_id_left.
      rewrite -!dbind_assoc'.
      apply dbind_ext_right.
      intros.
      by rewrite !dret_id_left.
Qed.

(** [sch_exec] analogue of [force_first_thread_scheduler_pexec_lemma']:
    started in [None], the wrapped scheduler first performs a [prim_step] of
    thread [num] and then behaves like [sch] started in [initial]. *)
Local Lemma force_first_thread_scheduler_lemma' `{Hcountable:Countable sch_int_σ} e1 e es1 σ1 sch num initial n `{!TapeOblivious sch_int_σ sch} :
  (e::es1)!!num=Some e1 ->
  to_val e = None ->
  to_val e1 = None ->
  (prim_step e1 σ1 ≫= λ '(e', s, l), sch_exec sch n (initial, (<[num:=e']> (e::es1) ++ l, s))) = sch_exec (force_first_thread_scheduler sch num initial) (S n) (None, (e::es1, σ1)).
Proof.
  intros H Hv Hv'.
  rewrite /force_first_thread_scheduler{2}/sch_exec.
  simpl.
  rewrite Hv /sch_step -!dbind_assoc'/= dret_id_left Hv H Hv' /dmap -!dbind_assoc'.
  apply dbind_ext_right.
  intros [[]].
  rewrite !dret_id_left.
  erewrite force_first_thread_scheduler_lemma.
  repeat f_equal.
Qed.

(** [sch_exec] analogue of [prim_coupl_step_prim_pexec_sch_erasable]. The
    [sch_erasable] hypothesis is first turned into [sch_erasable_val]
    ([sch_erasable_sch_erasable_val]). *)
Lemma prim_coupl_step_prim_sch_erasable `{Hcountable:Countable sch_int_σ} e n es1 σ1 ζ e1 (num:nat) μ `{HTO: TapeOblivious sch_int_σ sch} :
  sch_erasable (λ t _ _ sch, TapeOblivious t sch) μ σ1 ->
  (e::es1)!!num=Some e1 ->
  to_val e = None ->
  to_val e1 = None ->
  Rcoupl
    (prim_step e1 σ1 ≫= λ '(e', s, l), sch_exec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s)))
    (μ ≫= (λ σ2, prim_step e1 σ2 ≫= λ '(e', s, l), sch_exec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s))))
    eq.
Proof.
  intros H1 H2 H3 H4.
  apply sch_erasable_sch_erasable_val in H1.
  erewrite force_first_thread_scheduler_lemma'; try done.
  eapply Rcoupl_eq_trans.
  - erewrite <-H1; last apply force_first_thread_scheduler_tape_oblivious.
    apply Rcoupl_eq.
  - eapply Rcoupl_dbind; last apply Rcoupl_eq.
    intros ??->.
    apply Rcoupl_eq_sym.
    rewrite force_first_thread_scheduler_lemma'; try done.
    apply Rcoupl_eq.
Qed.

(** Instance of [prim_coupl_step_prim_sch_erasable] with
    [μ := state_step σ1 α], for an existing tape [α]. *)
Lemma prim_coupl_step_prim' `{Hcountable:Countable sch_int_σ} e n es1 σ1 α bs ζ e1 (num:nat) `{HTO: TapeOblivious sch_int_σ sch} :
  σ1.(tapes) !! α = Some bs →
  (e::es1)!!num=Some e1 ->
  to_val e = None ->
  to_val e1 = None ->
  Rcoupl
    (prim_step e1 σ1 ≫= λ '(e', s, l), sch_exec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s)))
    (state_step σ1 α ≫= (λ σ2, prim_step e1 σ2 ≫= λ '(e', s, l), sch_exec sch n (ζ, (<[num:=e']> (e::es1) ++ l, s))))
    eq.
Proof.
  intros H1 H2 H3 H4.
  apply prim_coupl_step_prim_sch_erasable; try done.
  by eapply state_step_sch_erasable.
Qed.