(* clutch.delay_prob_lang.urn_erasable *)

From Stdlib Require Import Reals Psatz.
From stdpp Require Import gmap countable.
From clutch.delay_prob_lang Require Import lang metatheory.
From clutch.prob Require Export couplings distribution mdp.

Section urn_erasable.
  Definition urn_erasable (μ : distr (gmap loc urn)) m:=
     (μ ≫= urns_f_distr) = urns_f_distr m.

  Lemma urn_erasable_dbind (μ1 : distr _) (μ2 : _ distr (_)) σ:
    urn_erasable μ1 σ ( σ', μ1 σ' > 0 urn_erasable (μ2 σ') σ') urn_erasable (μ1 ≫= μ2) σ.
  Proof.
    intros H1 H2.
    rewrite /urn_erasable.
    intros. rewrite -dbind_assoc'.
    erewrite <-H1. rewrite /dmap.
    apply dbind_eq; last naive_solver.
    intros. by apply H2.
  Qed.

  Lemma urn_erasable_dbind_predicate `{Countable A} (μ : distr _) μ1 μ2 σ (f: A bool):
    SeriesC μ = 1
    urn_erasable μ1 σ
    urn_erasable μ2 σ
    urn_erasable (dbind (λ x, if f x then μ1 else μ2) μ) σ.
  Proof.
    rewrite /urn_erasable.
    intros Hsum H1 H2. rewrite -dbind_assoc'.
    apply distr_ext. intros v.
    rewrite {1}/dbind{1}/dbind_pmf{1}/pmf/=.
    erewrite (SeriesC_ext _ (λ a : A, μ a * urns_f_distr σ v)).
    { erewrite SeriesC_scal_r. rewrite Hsum. apply Rmult_1_l. }
    intros. rewrite /dmap in H1 H2. case_match.
    - by rewrite H1.
    - by rewrite H2.
  Qed.

  Lemma urn_erasable_same_support_set μ m:
    urn_erasable μ m ->
     m', μ m' > 0 -> urns_support_set m' = urns_support_set m.
  Proof.
    rewrite /urn_erasable.
    intros H m' Hpos.
    pose proof urns_f_distr_witness m' as Hexists.
    destruct Hexists as [f Hexists].
    assert (urns_f_distr m f > 0) as Hpos'.
    { rewrite -H. rewrite dbind_pos.
      naive_solver. }
    apply urns_f_distr_pos in Hpos', Hexists.
    apply urns_f_valid_support in Hpos', Hexists.
    set_solver.
  Qed.

End urn_erasable.

Section urn_erasable_functions.

  Lemma dret_urn_erasable σ:
    urn_erasable (dret σ) σ.
  Proof.
    intros.
    rewrite /urn_erasable.
    intros. rewrite dret_id_left'. done.
  Qed.

  Lemma complete_split_urn_erasable (m:gmap loc urn) u s :
  s ->
  m!!u=Some (urn_unif s) ->
  urn_erasable (unif_set s ≫= (λ y, dret (<[u:= urn_unif {[y]}]> m))) m.
  Proof.
    rewrite /urn_erasable.
    intros. rewrite -dbind_assoc'. by erewrite urns_f_distr_split.
  Qed.

  Lemma two_split_urn_erasable m u s1 s2 s (H:(0<=size s1/size s<=1)%R):
    s1 ->
    s2 ->
    s1 ##s2 ->
    s1s2 = s ->
    m!!u=Some (urn_unif s)->
    urn_erasable (dmap (λ b, if b then (<[u:= urn_unif s1]> m) else (<[u:= urn_unif s2]> m)) (biased_coin _ H)) m.
  Proof.
    intros Hnonempty1 Hnonempty2 Hdisjoint Hunion Hlookup.
    rewrite /urn_erasable.
    rewrite /dmap -dbind_assoc'.
    apply distr_ext.
    intros f.
    replace (m) with (<[u:=urn_unif s]> (delete u m)) at 1; last first.
    { apply map_eq.
      intros u'.
      destruct (decide (u=u')); subst; by simplify_map_eq.
    }
    rewrite urns_f_distr_insert; simpl; simplify_map_eq; try done; last by apply union_positive_l_alt_L.
    rewrite dbind_comm.
    replace (urns_f_distr_compute_distr _) with (dbind (λ b, if b then urns_f_distr_compute_distr (urn_unif s1) else urns_f_distr_compute_distr (urn_unif s2)) (biased_coin (size s1/size (s1 s2)) H)); last first.
    - apply distr_ext.
      intros x.
      destruct (decide (xs1)).
      + rewrite {1}/dbind/dbind_pmf{1}/pmf. rewrite SeriesC_finite_foldr/=.
        rewrite Rplus_0_r.
        rewrite /biased_coin/biased_coin_pmf{1 3}/pmf.
        rewrite /urns_f_distr_compute_distr/urns_f_distr_compute/pmf.
        rewrite bool_decide_eq_true_2; last done.
        rewrite bool_decide_eq_false_2; last set_solver.
        rewrite bool_decide_eq_true_2; last set_solver.
        rewrite Rmult_0_r Rplus_0_r.
        rewrite Rdiv_def.
        rewrite Rmult_comm.
        rewrite -Rmult_assoc.
        replace (/_ * _) with 1%R; first lra.
        rewrite Rmult_comm.
        rewrite Rmult_inv_r; first lra.
        apply not_0_INR.
        intros Hcontra.
        apply size_empty_inv in Hcontra.
        rewrite leibniz_equiv_iff in Hcontra.
        naive_solver.
      + destruct (decide (xs2)).
        * rewrite {1}/dbind/dbind_pmf{1}/pmf. rewrite SeriesC_finite_foldr/=.
          rewrite Rplus_0_r.
          rewrite /biased_coin/biased_coin_pmf{1 3}/pmf.
          rewrite /urns_f_distr_compute_distr/urns_f_distr_compute/pmf.
          rewrite bool_decide_eq_false_2; last set_solver.
          rewrite bool_decide_eq_true_2; last done.
          rewrite bool_decide_eq_true_2; last set_solver.
          rewrite Rmult_0_r Rplus_0_l.
          rewrite Rdiv_def.
          rewrite -(Rdiv_diag (size s1 + size s2)); last first.
          -- rewrite -plus_INR.
             apply not_0_INR.
             assert (0<size s1)%nat; last lia.
             destruct (size s1) eqn:K; last lia.
             exfalso.
             apply size_empty_inv in K.
             rewrite leibniz_equiv_iff in K. naive_solver.
          -- rewrite Rdiv_def.
             rewrite Rmult_plus_distr_r.
             rewrite size_union; last done.
             rewrite -plus_INR.
             rewrite Rplus_minus_l.
             rewrite Rmult_comm.
             rewrite -Rmult_assoc.
             rewrite Rinv_l; first lra.
             apply not_0_INR.
             intros K.
             apply size_empty_inv in K.
             set_solver.
        * rewrite {1}/dbind/dbind_pmf{1}/pmf. rewrite SeriesC_finite_foldr/=.
          rewrite Rplus_0_r.
          rewrite /biased_coin/biased_coin_pmf{1 3}/pmf.
          rewrite /urns_f_distr_compute_distr/urns_f_distr_compute/pmf.
          rewrite bool_decide_eq_false_2; last set_solver.
          rewrite bool_decide_eq_false_2; last done.
          rewrite bool_decide_eq_false_2; last set_solver.
          lra.
    - rewrite -dbind_assoc'. apply dbind_pmf_ext; last done; last by apply distr_ext.
      intros []?; rewrite dret_id_left -/(urns_f_distr _);
      rewrite -insert_delete_eq;
        rewrite urns_f_distr_insert; try rewrite Hlookup/=; simplify_map_eq; try done; by rewrite dbind_comm.
  Qed.
End urn_erasable_functions.