(** * clutch.prob.couplings_dp *)

From Stdlib Require Import Reals Psatz.
From Stdlib.ssr Require Import ssreflect ssrfun.
From Coquelicot Require Import Rcomplements Lim_seq Rbar.
From stdpp Require Export countable.
From clutch.prelude Require Export base Coquelicot_ext Reals_ext stdpp_ext fiber_bounds.
From clutch.prob Require Export countable_sum distribution couplings graded_predicate_lifting couplings_app couplings_exp.

Open Scope R.

Section couplings.
  Context `{Countable A, Countable B, Countable A', Countable B'}.
  Context (μ1 : distr A) (μ2 : distr B) (S : A -> B -> Prop).

  (** [DPcoupl ε δ]: [μ1] and [μ2] are related by an (ε, δ)-approximate
      coupling with respect to [S]. For every pair of [0,1]-valued functions
      [f] on [A] and [g] on [B] such that [S a b] implies [f a <= g b], the
      [μ1]-weighted sum of [f] is at most [exp ε] times the [μ2]-weighted sum
      of [g], plus [δ]. *)
  Definition DPcoupl (ε : R) (δ : R) :=
    forall (f : A -> R) (g : B -> R)
      (Hf : forall a, 0 <= f a <= 1)
      (Hg : forall b, 0 <= g b <= 1)
      (Hfg : forall a b, S a b -> f a <= g b),
      SeriesC (λ a, μ1 a * f a) <= exp(ε) * SeriesC (λ b, μ2 b * g b) + δ.

End couplings.

Section couplings_theory.
  (* Context `{Countable A, Countable B, Countable A', Countable B'}. *)

  Lemma exp_mono (r s : R) : r <= s -> exp r <= exp s.
  Proof.
    intros [| ->].
    + left.
      by apply exp_increasing.
    + lra.
  Qed.

  Lemma exp_pos_ge_1 (r : R) : 0 <= r -> 1 <= exp r.
  Proof.
    intros.
    trans (exp 0); last by apply exp_mono.
    by rewrite exp_0.
  Qed.

  Lemma DPcoupl_mono `{Countable A', Countable B'} (μ1 μ1': distr A') (μ2 μ2': distr B') R R' ε ε' δ δ':
    ( a, μ1 a = μ1' a) ->
    ( b, μ2 b = μ2' b) ->
    ( x y, R x y -> R' x y) ->
    (ε <= ε') ->
    (δ <= δ') ->
    DPcoupl μ1 μ2 R ε δ ->
    DPcoupl μ1' μ2' R' ε' δ'.
  Proof.
    intros Hμ1 Hμ2 HR Hcoupl f g Hf Hg Hfg.
    specialize (Hcoupl f g Hf Hg).
    replace (μ1') with μ1; last by apply distr_ext.
    replace (μ2') with μ2; last by apply distr_ext.
    trans (exp(ε) * SeriesC (λ b, μ2 b * g b) + δ).
    - apply Hcoupl.
      naive_solver.
    - apply Rplus_le_compat; auto.
      apply Rmult_le_compat_r; first by series.
      by apply exp_mono.
  Qed.

  Lemma DPcoupl_1 `{Countable A', Countable B'} (μ1 : distr A') (μ2 : distr B') R ε δ:
    (1 <= δ) -> DPcoupl μ1 μ2 R ε δ.
  Proof.
    rewrite /DPcoupl.
    intros f g Hf Hg Hfg.
    trans 1.
    - trans (SeriesC μ1); last auto.
      apply SeriesC_le; last auto.
      real_solver.
    - replace 1 with (0+1); last lra.
      apply Rplus_le_compat; last lra.
      apply Rmult_le_pos; [left; apply exp_pos |].
      apply SeriesC_ge_0'; real_solver.
  Qed.

  Lemma DPcoupl_mon_grading `{Countable A', Countable B'} (μ1 : distr A') (μ2 : distr B') (R : A' B' Prop) ε1 ε2 δ1 δ2:
    (ε1 <= ε2) ->
    (δ1 <= δ2) ->
    DPcoupl μ1 μ2 R ε1 δ1 ->
    DPcoupl μ1 μ2 R ε2 δ2.
  Proof.
    intros Hleq.
    by apply DPcoupl_mono.
  Qed.

  Lemma DPcoupl_dret `{Countable A, Countable B} (a : A) (b : B) (R : A B Prop) ε δ :
    0 <= ε
    0 <= δ
    R a b DPcoupl (dret a) (dret b) R ε δ.
  Proof.
    intros HR f g Hf Hg Hfg.
    assert (SeriesC (λ a0 : A, dret a a0 * f a0) = f a) as ->.
    { rewrite <-(SeriesC_singleton a (f a)).
      rewrite /pmf/=/dret_pmf ; series.
    }
    assert (SeriesC (λ b0 : B, dret b b0 * g b0) = g b) as ->.
    { rewrite <-(SeriesC_singleton b (g b)).
      rewrite /pmf/=/dret_pmf ; series.
    }
    specialize (Hfg _ _ HR).
    rewrite <- (Rmult_1_l (f a)).
    rewrite <- (Rplus_0_r (1 * f a)).
    apply Rplus_le_compat; auto.
    apply Rmult_le_compat; [real_solver | real_solver | | auto].
    by apply exp_pos_ge_1.
  Qed.

  Lemma DPcoupl_dbind_adv_lhs `{Countable A, Countable B, Countable A', Countable B'} (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (S : A B Prop) (S' : A' B' Prop)
    ε1 ε2 δ1 δ2 (Δ2 : A R) :
    (0 <= δ1) ( a, 0 <= (Δ2 a) <= 1)
    (* (SeriesC (λ a, μ1 a * (E2 a)) <= ε2) → *)
    (SeriesC (λ a, μ1 a * (Δ2 a)) = δ2)
    ( a b, S a b DPcoupl (f a) (g b) S' ε2 (Δ2 a))
    DPcoupl μ1 μ2 S ε1 δ1
    DPcoupl (dbind f μ1) (dbind g μ2) S' (ε1 + ε2) (δ1 + δ2).
  Proof.
    intros Hδ1 HΔ2 <- Hcoup_fg Hcoup_S h1 h2 Hh1pos Hh2pos Hh1h2S.
    rewrite {-3}/pmf/=/dbind_pmf.
    (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
       we have to rewrite the sums in such a way as to isolate (the expectation
       of) a random variable X on the LHS and Y on the RHS, and ε1 on the
       RHS. *)

    (* First step: rewrite the LHS into a RV X on μ1. *)
    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

    (* Boring Fubini sideconditions. *)
    2: { real_solver. }
    2: { intro a'.
         (* specialize (Hh1pos a'). *)
         apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         + apply Rmult_le_pos.
           * real_solver.
           * real_solver.
         + rewrite <- Rmult_1_r.
           rewrite Rmult_assoc.
           apply Rmult_le_compat_l; auto.
           rewrite <- Rmult_1_r.
           apply Rmult_le_compat; real_solver. }
    2: { setoid_rewrite SeriesC_scal_r.
         apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
         + series.
         + apply (pmf_ex_seriesC (dbind f μ1)). }

    (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
    assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
    { setoid_rewrite <- SeriesC_scal_l. series. }

    (* Second step: rewrite the RHS into a RV Y on μ2. *)
    (* RHS: Fubini. *)
    rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
    2: by series.
    2:{ intro b'.
        specialize (Hh2pos b').
        apply (ex_seriesC_le _ μ2) ; auto.
        intro b; split.
        - series.
        - do 2 rewrite <- Rmult_1_r. series. }
    2:{ setoid_rewrite SeriesC_scal_r.
        apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
        - intros b'; specialize (Hh2pos b'); split.
          + apply Rmult_le_pos; [ | lra].
            apply (pmf_pos ((dbind g μ2)) b').
          + rewrite <- Rmult_1_r.
            apply Rmult_le_compat_l; auto.
            * apply SeriesC_ge_0'. real_solver.
            * real_solver.
        - apply (pmf_ex_seriesC (dbind g μ2)). }

    (* RHS: Factor out (μ2 b) *)
    assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
            = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    rewrite -Rplus_assoc.
    apply Rle_minus_l.

    (* To construct X, we want to push ε2 into the inner sum. We don't do this
       directly, because X might be larger than 1, but
       our assumption on the ε1 R-ACoupling requires it to be valued in 0,1.
       Instead, we take min(1, exp(ε2) * (Σ(a:A')(f b a * h1 a))).
       ALT: could use a more fine-grained min inside the sum?
     *)


    assert (exp (ε1) * SeriesC (λ b : B, μ2 b * (Rmin 1 (exp (ε2) * SeriesC (λ a : B', g b a * h2 a)))) + δ1
            <= exp (ε1 + ε2) * SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a)) + δ1) as <-.
    {
       apply Rplus_le_compat_r.
       rewrite exp_plus.
       rewrite Rmult_assoc.
       rewrite -(SeriesC_scal_l _ (exp ε2)).
       apply Rmult_le_compat_l; [left; apply exp_pos |].
       apply SeriesC_le.
       - intros b; split.
         + apply Rmult_le_pos; auto.
           apply Rmin_glb; [lra |].
           apply Rmult_le_pos; [left; apply exp_pos |].
           apply SeriesC_ge_0'.
           real_solver.
         + rewrite Rmult_min_distr_l; auto.
           etrans; [apply Rmin_r | lra].
       - apply ex_seriesC_scal_l.
         apply (ex_seriesC_le _ μ2); auto.
         intro b; split.
         + apply Rmult_le_pos; auto.
           apply SeriesC_ge_0'.
           intro; apply Rmult_le_pos; auto.
           apply Hh2pos.
         + rewrite <- Rmult_1_r.
           apply Rmult_le_compat_l; auto.
           apply (Rle_trans _ (SeriesC (g b))); auto.
           apply SeriesC_le; auto.
           real_solver.
    }

    assert (
        SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a)) -
          SeriesC (λ a, μ1 a * Δ2 a)
        <= SeriesC (λ b : A, μ1 b * Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - Δ2 b))
      ) as ->.
    {
      apply (Rle_trans _ (SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a) - μ1 b * Δ2 b))).
      - rewrite SeriesC_minus.
        + apply Rplus_le_compat_l.
          apply Ropp_le_contravar.
          done.
        + apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         * apply Rmult_le_pos; auto.
           apply SeriesC_ge_0'.
           intro; apply Rmult_le_pos; auto.
           apply Hh1pos.
         * rewrite <- Rmult_1_r.
           apply Rmult_le_compat_l; auto.
           apply (Rle_trans _ (SeriesC (f a))); auto.
           apply SeriesC_le; auto.
           real_solver.
        + apply (ex_seriesC_le _ μ1); auto.
          intros; real_solver.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_minus_distr_l.
          apply Rmult_le_compat_l; auto.
          apply Rmax_r.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               apply SeriesC_ge_0'.
               intro; apply Rmult_le_pos; auto.
               apply Hh1pos.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
          * apply (ex_seriesC_ext (λ x, -1 * (μ1 x * Δ2 x))).
            1: intros; real_solver.
            apply ex_seriesC_scal_l.
            apply (ex_seriesC_le _ μ1); auto.
            intros; real_solver.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            apply Rmax_l.
          * rewrite -{2}(Rmult_1_r (μ1 a)).
            apply Rmult_le_compat_l; auto.
            apply Rmax_lub; first by lra.
            apply Rle_minus_l.
            apply (Rle_trans _ 1); last by real_solver.
            apply (Rle_trans _ (SeriesC (f a))); auto.
            apply SeriesC_le; auto.
            real_solver.
    }

    (*
        Now we instantiate the lifting definitions and use them to prove the
        inequalities
    *)

    rewrite /DPcoupl in Hcoup_S.
    apply Hcoup_S.
    + intro; split; first apply Rmax_l.
      apply Rmax_lub; first by lra.
      apply Rle_minus_l.
      apply (Rle_trans _ 1); last by real_solver.
      apply (Rle_trans _ (SeriesC (f a))); auto.
      apply SeriesC_le; auto; real_solver.
    + intro; split.
      * apply Rmin_glb; [lra |].
        apply Rmult_le_pos.
        ** left. apply exp_pos.
        ** apply SeriesC_ge_0'; intro b'.
           specialize (Hh2pos b'); real_solver.
      * apply Rmin_l.

    + intros a b Rab.
      apply Rmin_glb; apply Rmax_lub; first by lra.
      * apply Rle_minus_l.
        apply (Rle_trans _ 1); last by real_solver.
        apply (Rle_trans _ (SeriesC (f a))); auto.
        apply SeriesC_le; auto; real_solver.
      * series.
        left.
        by apply exp_pos.
      * apply Rle_minus_l.
        by apply Hcoup_fg.
  Qed.

  (* Advanced composition on the right, specialized to the case where the first coupling has parameter ε_1 = 0. This is
     sufficient for recovering the rules used by Approxis (since they were designed for the case ε=0). Maybe the proof
     could be generalized to support non-zero ε_1. Ideally, DPcoupl_dbind should be a corollary of the general lemma. *)

  Lemma DPcoupl_dbind_adv_rhs_specialized `{Countable A, Countable B, Countable A', Countable B'} (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (S : A B Prop) (S' : A' B' Prop)
    ε2 δ1 δ2 (Δ2 : B R) :
    (0 <= δ1) ( b, 0 <= (Δ2 b) <= 1)
    (SeriesC (λ b, μ2 b * (Δ2 b)) = δ2)
    ( a b, S a b DPcoupl (f a) (g b) S' ε2 (Δ2 b))
    DPcoupl μ1 μ2 S 0 δ1
    DPcoupl (dbind f μ1) (dbind g μ2) S' (0 + ε2) (δ1 + δ2).
  Proof.
    intros Hδ1 HΔ2 <- Hcoup_fg Hcoup_S h1 h2 Hh1pos Hh2pos Hh1h2S.
    rewrite {-3}/pmf/=/dbind_pmf.
    (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
       we have to rewrite the sums in such a way as to isolate (the expectation
       of) a random variable X on the LHS and Y on the RHS, and ε1 on the
       RHS. *)

    (* First step: rewrite the LHS into a RV X on μ1. *)
    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

       (* Boring Fubini sideconditions. *)
       2: { real_solver. }
       2: { intro a'.
            (* specialize (Hh1pos a'). *)
            apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            + apply Rmult_le_pos.
              * real_solver.
              * real_solver.
            + rewrite <- Rmult_1_r.
              rewrite Rmult_assoc.
              apply Rmult_le_compat_l; auto.
              rewrite <- Rmult_1_r.
              apply Rmult_le_compat; real_solver. }
       2: { setoid_rewrite SeriesC_scal_r.
            apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
            + series.
            + apply (pmf_ex_seriesC (dbind f μ1)). }

       (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
       assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
                 SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
       { setoid_rewrite <- SeriesC_scal_l. series. }

       (* Second step: rewrite the RHS into a RV Y on μ2. *)
       (* RHS: Fubini. *)
       rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
       2: by series.
       2:{ intro b'.
           specialize (Hh2pos b').
           apply (ex_seriesC_le _ μ2) ; auto.
           intro b; split.
           - series.
           - do 2 rewrite <- Rmult_1_r. series. }
       2:{ setoid_rewrite SeriesC_scal_r.
           apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
           - intros b'; specialize (Hh2pos b'); split.
             + apply Rmult_le_pos; [ | lra].
               apply (pmf_pos ((dbind g μ2)) b').
             + rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               * apply SeriesC_ge_0'. real_solver.
               * real_solver.
           - apply (pmf_ex_seriesC (dbind g μ2)). }

       (* RHS: Factor out (μ2 b) *)
       assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
               = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
       { apply SeriesC_ext; intro.
         rewrite <- SeriesC_scal_l.
         apply SeriesC_ext; real_solver. }

       (* here we depart from the lhs proof *)
       rewrite (Rplus_comm δ1 _).
       rewrite -Rplus_assoc.

       (* push (exp ε2) into the series over B *)
       rewrite exp_plus.
       rewrite Rmult_assoc.
       rewrite -(SeriesC_scal_l).
       setoid_rewrite (eq_sym (Rmult_assoc (exp ε2) (μ2 _) _)).
       setoid_rewrite ((Rmult_comm (exp ε2) _)).
       setoid_rewrite ((Rmult_assoc (μ2 _) (exp ε2) _)).
       rewrite exp_0. rewrite Rmult_1_l.

       set (gh2 b := exp ε2 * SeriesC (λ b' : B', g b b' * h2 b')).
       replace (λ b, μ2 b * _) with (λ b, μ2 b * gh2 b) by auto.

       assert ( x, 0 <= exp x) by (left ; apply exp_pos).
       assert (forall b, 0 <= gh2 b + Δ2 b).
       { intros b. apply Rplus_le_le_0_compat. 2: apply HΔ2.
         rewrite /gh2. apply Rmult_le_pos => //.
         apply SeriesC_ge_0'. intros b'.
         apply Rmult_le_pos => //.
         apply Hh2pos.
       }

       rewrite -SeriesC_plus.

       2:{
         apply pmf_ex_seriesC_mult_fn.
         eexists (exp ε2).
         intros b.
         split.
         - apply Rmult_le_pos => //. apply SeriesC_ge_0'.
           intros. apply Rmult_le_pos ; auto. naive_solver.
         - trans (exp ε2 * SeriesC (g b)) => //.
           + apply Rmult_le_compat_l => //. apply SeriesC_le => //. intros b'.
             destruct (Hh2pos b'). real_solver.
           + erewrite (eq_sym (Rmult_1_r (exp _))). rewrite Rmult_assoc. apply Rmult_le_compat_l => //. real_solver.
       }
       2:{
         apply pmf_ex_seriesC_mult_fn.
         eexists 1.
         intros b.
         destruct (HΔ2 b) ; lra.
       }
       replace (SeriesC (λ x : B, μ2 x * gh2 x + μ2 x * Δ2 x) + δ1)
         with (exp 0 * SeriesC (λ x : B, μ2 x * gh2 x + μ2 x * Δ2 x) + δ1).
       2:{ rewrite exp_0. lra. }
       setoid_rewrite <-Rmult_plus_distr_l.

       set (Y b := Rmin (gh2 b + Δ2 b) 1).

       transitivity ((exp 0 * SeriesC (λ x : B, μ2 x * Y x) + δ1)).
       {
       apply Hcoup_S.

       - intro; split.
        + apply SeriesC_ge_0'; intro a'.
          specialize (Hh1pos a'); real_solver.
        + apply (Rle_trans _ (SeriesC (f a))); auto.
          apply SeriesC_le; auto.
          intro a'.
          specialize (Hh1pos a'); real_solver.

       - intros. rewrite /Y. split.
        * rewrite /Y /gh2.
          apply Rmin_glb ; [|lra].
          apply Rle_plus_r. 1: apply HΔ2.
          apply Rmult_le_pos => //.
          apply SeriesC_ge_0'. intros b'.
          apply Rmult_le_pos => //.
          apply Hh2pos.
        * rewrite /Y/gh2. apply Rmin_r.
       - intros.
         unfold Y.
         apply Rmin_glb.
         + apply Hcoup_fg => //.
         + trans (SeriesC (λ a' : A', f a a' * 1)).
           * apply SeriesC_le.
             -- intros a'. destruct (Hh1pos a'). real_solver.
             -- apply pmf_ex_seriesC_mult_fn. exists 1. intros ; lra.
           * setoid_rewrite Rmult_1_r. auto.
       }
       apply Rle_plus_proper => //.
       rewrite exp_0. rewrite !Rmult_1_l.
       apply SeriesC_le'.
       1: { intros b. apply Rmult_le_compat_l => //. rewrite /Y.
            apply Rmin_l. }
    - rewrite /Y.
      apply pmf_ex_seriesC_mult_fn.
      eexists 1.
      intros b.
      split.
      + apply Rmin_glb. 2: lra. auto.
      + apply Rmin_r.
    - apply pmf_ex_seriesC_mult_fn.
      exists (exp ε2 * 1 + 1). intros b. split => //.
      apply Rle_plus_proper => //.
      + rewrite /gh2. trans (exp ε2 * SeriesC (λ b' : B', g b b')) => //.
        1,2: apply Rmult_le_compat_l => //.
        apply SeriesC_le => //.
        intros b' => //. destruct (Hh2pos b'). real_solver.
      + apply HΔ2.
  Qed.

  Lemma DPcoupl_dbind_adv_rhs_specialized'
    `{Countable A, Countable B, Countable A', Countable B'} (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (S : A B Prop) (S' : A' B' Prop)
    ε2 δ1 δ2 (Δ2 : B R) :
    (0 <= δ1) ( r, b, 0 <= (Δ2 b) <= r)
    (SeriesC (λ b, μ2 b * (Δ2 b)) <= δ2)
    ( a b, S a b DPcoupl (f a) (g b) S' ε2 (Δ2 b))
    DPcoupl μ1 μ2 S 0 δ1
    DPcoupl (dbind f μ1) (dbind g μ2) S' (0 + ε2) (δ1 + δ2).
  Proof.
    intros Hε1 HE2 Hsum Hfg Hcoupl.
    pose (Δ2' x := Rmin 1 (Δ2 x)).
    eapply (DPcoupl_mon_grading _ _ _ _ _ (δ1 + SeriesC (λ a, μ2 a * (Δ2' a)))).
    1: reflexivity.
    { apply Rplus_le_compat_l; etrans; last exact. apply SeriesC_le; last apply pmf_ex_seriesC_mult_fn.
      - intros. rewrite /Δ2'. split.
        + apply Rmult_le_pos; try done. apply Rmin_glb; [lra|naive_solver].
        + apply Rmult_le_compat_l; first done. apply Rmin_r.
      - naive_solver.
    }
    eapply (DPcoupl_dbind_adv_rhs_specialized _ _ _ _ _ _ _ _ _ Δ2'); try done.
    - intros a; split.
      + apply Rmin_glb; [lra|naive_solver].
      + apply Rmin_l.
    - intros a b Hs. specialize (Hfg a b Hs).
      rewrite /Δ2'.
      rewrite /Rmin.
      case_match.
      + apply DPcoupl_1; done.
      + eapply DPcoupl_mon_grading; done.
  Qed.

  (* Advanced composition on the right; not clear that the statement is entirely correct, but this should be the general
     version (see DPcoupl_dbind_adv_rhs_specialized). *)


  (* Lemma DPcoupl_dbind_adv_rhs `{Countable A, Countable B, Countable A', Countable B'} (f : A → distr A') (g : B → distr B')
       (μ1 : distr A) (μ2 : distr B) (S : A → B → Prop) (S' : A' → B' → Prop)
       ε1 ε2 δ1 δ2 (Δ2 : B → R) :
       (0 <= δ1) → (∀ b, 0 <= (Δ2 b) <= 1) →
       (* (SeriesC (λ a, μ1 a * (E2 a)) <= ε2) → *)
       (SeriesC (λ b, μ2 b * (Δ2 b)) = δ2) →
       (∀ a b, S a b → DPcoupl (f a) (g b) S' ε2 (Δ2 b)) →
       DPcoupl μ1 μ2 S ε1 δ1 →
       DPcoupl (dbind f μ1) (dbind g μ2) S' (ε1 + ε2) (δ1 + δ2).
     Proof.
       intros Hδ1 HΔ2 <- Hcoup_fg Hcoup_S h1 h2 Hh1pos Hh2pos Hh1h2S.
       rewrite {-3}/pmf/=/dbind_pmf.
       (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
          we have to rewrite the sums in such a way as to isolate (the expectation
          of) a random variable X on the LHS and Y on the RHS, and ε1 on the
          RHS. *)

       (* First step: rewrite the LHS into a RV X on μ1. *)
       setoid_rewrite <- SeriesC_scal_r.
       rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

       (* Boring Fubini sideconditions. *)
       2: { real_solver. }
       2: { intro a'.
            (* specialize (Hh1pos a'). *)
            apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            + apply Rmult_le_pos.
              * real_solver.
              * real_solver.
            + rewrite <- Rmult_1_r.
              rewrite Rmult_assoc.
              apply Rmult_le_compat_l; auto.
              rewrite <- Rmult_1_r.
              apply Rmult_le_compat; real_solver. }
       2: { setoid_rewrite SeriesC_scal_r.
            apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
            + series.
            + apply (pmf_ex_seriesC (dbind f μ1)). }

       (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
       assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
                 SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
       { setoid_rewrite <- SeriesC_scal_l. series. }

       (* Second step: rewrite the RHS into a RV Y on μ2. *)
       (* RHS: Fubini. *)
       rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
       2: by series.
       2:{ intro b'.
           specialize (Hh2pos b').
           apply (ex_seriesC_le _ μ2) ; auto.
           intro b; split.
           - series.
           - do 2 rewrite <- Rmult_1_r. series. }
       2:{ setoid_rewrite SeriesC_scal_r.
           apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
           - intros b'; specialize (Hh2pos b'); split.
             + apply Rmult_le_pos; [ | lra].
               apply (pmf_pos ((dbind g μ2)) b').
             + rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               * apply SeriesC_ge_0'. real_solver.
               * real_solver.
           - apply (pmf_ex_seriesC (dbind g μ2)). }

       (* RHS: Factor out (μ2 b) *)
       assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
               = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
       { apply SeriesC_ext; intro.
         rewrite <- SeriesC_scal_l.
         apply SeriesC_ext; real_solver. }

       (* here we depart from the lhs proof *)
       rewrite (Rplus_comm δ1 _).
       rewrite -Rplus_assoc.

       (* push (exp ε2) into the series over B *)
       rewrite exp_plus.
       rewrite Rmult_assoc.
       rewrite -(SeriesC_scal_l).
       setoid_rewrite (eq_sym (Rmult_assoc (exp ε2) (μ2 _) _)).
       setoid_rewrite ((Rmult_comm (exp ε2) _)).
       setoid_rewrite ((Rmult_assoc (μ2 _) (exp ε2) _)).









       (* lhs proof below *)

       apply Rle_minus_l.


       (* To construct X, we want to push ε2 into the inner sum. We don't do this
          directly, because X might be larger than 1, but
          our assumption on the ε1 R-ACoupling requires it to be valued in [0,1].
          Instead, we take min(1, exp(ε2) * (Σ(a:A')(f b a * h1 a))).
          ALT: could use a more fine-grained min inside the sum?
        *)


       assert (exp (ε1) * SeriesC (λ b : B, μ2 b * (Rmin 1 (exp (ε2) * SeriesC (λ a : B', g b a * h2 a)))) + δ1
               <= exp (ε1 + ε2) * SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a)) + δ1) as <-.
       {
          apply Rplus_le_compat_r.
          rewrite exp_plus.
          rewrite Rmult_assoc.
          rewrite -(SeriesC_scal_l _ (exp ε2)).
          apply Rmult_le_compat_l; [left; apply exp_pos |].
          apply SeriesC_le.
          - intros b; split.
            + apply Rmult_le_pos; auto.
              apply Rmin_glb; [lra |].
              apply Rmult_le_pos; [left; apply exp_pos |].
              apply SeriesC_ge_0'.
              real_solver.
            + rewrite Rmult_min_distr_l; auto.
              etrans; [apply Rmin_r | lra].
          - apply ex_seriesC_scal_l.
            apply (ex_seriesC_le _ μ2); auto.
            intro b; split.
            + apply Rmult_le_pos; auto.
              apply SeriesC_ge_0'.
              intro; apply Rmult_le_pos; auto.
              apply Hh2pos.
            + rewrite <- Rmult_1_r.
              apply Rmult_le_compat_l; auto.
              apply (Rle_trans _ (SeriesC (g b))); auto.
              apply SeriesC_le; auto.
              real_solver.
       }

       assert (
           SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a)) -
             SeriesC (λ b, μ2 b * Δ2 b)
           <= SeriesC (λ b : A, μ1 b * Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - Δ2 b))
         ) as ->.
       {
         apply (Rle_trans _ (SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a) - μ1 b * Δ2 b))).
         - rewrite SeriesC_minus.
           + apply Rplus_le_compat_l.
             apply Ropp_le_contravar.
             done.
           + apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            * apply Rmult_le_pos; auto.
              apply SeriesC_ge_0'.
              intro; apply Rmult_le_pos; auto.
              apply Hh1pos.
            * rewrite <- Rmult_1_r.
              apply Rmult_le_compat_l; auto.
              apply (Rle_trans _ (SeriesC (f a))); auto.
              apply SeriesC_le; auto.
              real_solver.
           + apply (ex_seriesC_le _ μ1); auto.
             intros; real_solver.
         - apply SeriesC_le'.
           + intros a.
             rewrite -Rmult_minus_distr_l.
             apply Rmult_le_compat_l; auto.
             apply Rmax_r.
           + apply ex_seriesC_plus.
             * apply (ex_seriesC_le _ μ1); auto.
               intro a; split.
               ** apply Rmult_le_pos; auto.
                  apply SeriesC_ge_0'.
                  intro; apply Rmult_le_pos; auto.
                  apply Hh1pos.
               ** rewrite <- Rmult_1_r.
                  apply Rmult_le_compat_l; auto.
                  apply (Rle_trans _ (SeriesC (f a))); auto.
                  apply SeriesC_le; auto.
                  real_solver.
             * apply (ex_seriesC_ext (λ x, -1 * (μ1 x * Δ2 x))).
               1: intros; real_solver.
               apply ex_seriesC_scal_l.
               apply (ex_seriesC_le _ μ1); auto.
               intros; real_solver.
           + apply (ex_seriesC_le _ μ1); auto.
             intros a; split.
             * apply Rmult_le_pos; auto.
               apply Rmax_l.
             * rewrite -{2}(Rmult_1_r (μ1 a)).
               apply Rmult_le_compat_l; auto.
               apply Rmax_lub; first by lra.
               apply Rle_minus_l.
               apply (Rle_trans _ 1); last by real_solver.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
       }

       (*
           Now we instantiate the lifting definitions and use them to prove the
           inequalities
       *)

       rewrite /DPcoupl in Hcoup_S.
       apply Hcoup_S.
       + intro; split; first apply Rmax_l.
         apply Rmax_lub; first by lra.
         apply Rle_minus_l.
         apply (Rle_trans _ 1); last by real_solver.
         apply (Rle_trans _ (SeriesC (f a))); auto.
         apply SeriesC_le; auto; real_solver.
       + intro; split.
         * apply Rmin_glb; [lra |].
           apply Rmult_le_pos.
           ** left. apply exp_pos.
           ** apply SeriesC_ge_0'; intro b'.
              specialize (Hh2pos b'); real_solver.
         * apply Rmin_l.

       + intros a b Rab.
         apply Rmin_glb; apply Rmax_lub; first by lra.
         * apply Rle_minus_l.
           apply (Rle_trans _ 1); last by real_solver.
           apply (Rle_trans _ (SeriesC (f a))); auto.
           apply SeriesC_le; auto; real_solver.
         * series.
           left.
           by apply exp_pos.
         * apply Rle_minus_l.
           by apply Hcoup_fg.
     Qed. *)


  Lemma DPcoupl_dbind_choice `{Countable A, Countable B, Countable A', Countable B'} (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (P : A -> Prop) (S1 : A B Prop) (S2 : A B Prop) (S' : A' B' Prop)
    ε1 ε2 δ1 δ2 ε1' ε2' δ1' ε δ:
    (0 <= δ1) (0 <= δ2) (0 <= δ1')
    (ε1 + ε2 <= ε) -> (ε1' + ε2' <= ε) ->
    (δ1 + δ1' + δ2 <= δ) ->
    (* Stronger version: (δ1 + δ1' + Rmax δ2 δ2' <= δ) -> *)
    (forall a a' b, P a -> ¬ P a' -> ¬(S1 a b /\ S2 a' b)) ->
    ( a b, (P a /\ S1 a b) DPcoupl (f a) (g b) S' ε2 δ2)
    ( a b, (¬P a /\ S2 a b) DPcoupl (f a) (g b) S' ε2' δ2)
    DPcoupl μ1 μ2 S1 ε1 δ1
    DPcoupl μ1 μ2 S2 ε1' δ1'
    DPcoupl (dbind f μ1) (dbind g μ2) S' ε δ.
  Proof.
    intros Hδ1 Hδ2 Hδ1' Hεleq Hεleq' Hδleq
      Hindep
      Hcoup_fg1 Hcoup_fg2 Hcoup_S1 Hcoup_S2 h1 h2 Hh1pos Hh2pos Hh1h2S'.

    rewrite /pmf/=/dbind_pmf.
    (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
       we have to rewrite the sums in such a way as to isolate (the expectation
       of) a random variable X on the LHS and Y on the RHS, and ε1 on the
       RHS. *)

    (* First step: rewrite the LHS into a RV X on μ1. *)
    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

    (* Boring Fubini sideconditions. *)
    2: { real_solver. }
    2: { intro a'.
         (* specialize (Hh1pos a'). *)
         apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         + apply Rmult_le_pos.
           * real_solver.
           * real_solver.
         + rewrite <- Rmult_1_r.
           rewrite Rmult_assoc.
           apply Rmult_le_compat_l; auto.
           rewrite <- Rmult_1_r.
           apply Rmult_le_compat; real_solver. }
    2: { setoid_rewrite SeriesC_scal_r.
         apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
         + series.
         + apply (pmf_ex_seriesC (dbind f μ1)). }

    (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
    assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
    { setoid_rewrite <- SeriesC_scal_l. series. }

    (* Second step: rewrite the RHS into a RV Y on μ2. *)
    (* RHS: Fubini. *)
    rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
    2: by series.
    2:{ intro b'.
        specialize (Hh2pos b').
        apply (ex_seriesC_le _ μ2) ; auto.
        intro b; split.
        - series.
        - do 2 rewrite <- Rmult_1_r. series. }
    2:{ setoid_rewrite SeriesC_scal_r.
        apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
        - intros b'; specialize (Hh2pos b'); split.
          + apply Rmult_le_pos; [ | lra].
            apply (pmf_pos ((dbind g μ2)) b').
          + rewrite <- Rmult_1_r.
            apply Rmult_le_compat_l; auto.
            * apply SeriesC_ge_0'. real_solver.
            * real_solver.
        - apply (pmf_ex_seriesC (dbind g μ2)). }

    (* RHS: Factor out (μ2 b) *)
    assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
            = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    (* Now let's split the sum depending on whether P holds *)

    assert (SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * (if (bool_decide (P b)) then SeriesC (λ a : A', f b a * h1 a) else 0)) +
              SeriesC (λ b : A, μ1 b * (if (bool_decide (¬ P b)) then SeriesC (λ a : A', f b a * h1 a) else 0))) as ->.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_ext.
        intro a.
        case_bool_decide; case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        + apply Rmult_le_pos; auto.
          case_bool_decide; [|lra].
          apply SeriesC_ge_0'; real_solver.
        + rewrite -{2}(Rmult_1_r (μ1 _)).
          apply Rmult_le_compat_l; auto.
          case_bool_decide; [|lra].
          transitivity (SeriesC (f a)); auto.
          apply SeriesC_le; auto.
          real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        + apply Rmult_le_pos; auto.
          case_bool_decide; [|lra].
          apply SeriesC_ge_0'; real_solver.
        + rewrite -{2}(Rmult_1_r (μ1 _)).
          apply Rmult_le_compat_l; auto.
          case_bool_decide; [|lra].
          transitivity (SeriesC (f a)); auto.
          apply SeriesC_le; auto.
          real_solver.
    }


    (*  Stronger version: assert (
        SeriesC (λ b : A, μ1 b * (if (bool_decide (P b)) then  SeriesC (λ a : A', f b a * h1 a) else 0))
        <= SeriesC (λ b : A, μ1 b * if (bool_decide (P b)) then  (Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2)) else 0) + SeriesC (λ b : A, μ1 b * if (bool_decide (P b)) then δ2 else 0)
      ) as Htrans1.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_le.
        + intros a. split.
          * case_bool_decide; |real_solver.
            apply Rmult_le_pos; auto.
            apply SeriesC_ge_0'.
            real_solver.
          * case_bool_decide; |lra.
            rewrite -Rmult_plus_distr_l.
            apply Rmult_le_compat_l; auto.
            rewrite Rplus_max_distr_r.
            eapply Rle_trans; |apply Rmax_r.
            lra.
        + admit.
      - admit.
      - admit.
*)


    assert (
        SeriesC (λ b : A, μ1 b * (if (bool_decide (P b)) then SeriesC (λ a : A', f b a * h1 a) else 0))
        <= SeriesC (λ b : A, μ1 b * if (bool_decide (P b)) then (Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2)) else 0) +
          SeriesC (λ a, μ1 a * (if (bool_decide (P a)) then δ2 else 0))
      ) as Htrans1.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_plus_distr_l.
          apply Rmult_le_compat_l; auto.
          case_bool_decide; last by lra.
          rewrite Rplus_max_distr_r.
          etrans; [|apply Rmax_r].
          lra.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            case_bool_decide; last by lra.
            apply SeriesC_ge_0'; real_solver.
          * case_bool_decide.
            ** rewrite -{2}(Rmult_1_r (μ1 a)).
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
            ** rewrite Rmult_0_r; auto.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               case_bool_decide; last by lra.
               apply Rmax_l.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               case_bool_decide; last by lra.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply Rmax_lub; auto.
               apply Rle_minus_l.
               apply (Rle_trans _ (SeriesC (f a))); last by lra.
               apply SeriesC_le; auto.
               real_solver.
          * apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); [|apply ex_seriesC_scal_r; auto].
            intro a; split.
            ** case_bool_decide; real_solver.
            ** case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        * case_bool_decide; last by lra.
          apply Rmult_le_pos; auto.
          apply Rmax_l.
        * rewrite -{2}(Rmult_1_r (μ1 a)).
          apply Rmult_le_compat_l; auto.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          case_bool_decide; auto.
          apply Rmax_lub; auto.
          apply Rle_minus_l.
          apply (Rle_trans _ (SeriesC (f a))); last by lra.
          apply SeriesC_le; auto.
          real_solver.
     - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); [|apply ex_seriesC_scal_r; auto].
       intro a; split; case_bool_decide; real_solver.
    }

    assert (
        SeriesC (λ b : A, μ1 b * (if (bool_decide (¬ P b)) then SeriesC (λ a : A', f b a * h1 a) else 0))
        <= SeriesC (λ b : A, μ1 b * if (bool_decide (¬ P b)) then (Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2)) else 0) +
            SeriesC (λ a, μ1 a * (if (bool_decide (¬ P a)) then δ2 else 0))
      ) as Htrans2.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_plus_distr_l.
          apply Rmult_le_compat_l; auto.
          case_bool_decide; last by lra.
          rewrite Rplus_max_distr_r.
          etrans; [|apply Rmax_r].
          lra.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            case_bool_decide; last by lra.
            apply SeriesC_ge_0'; real_solver.
          * case_bool_decide.
            ** rewrite -{2}(Rmult_1_r (μ1 a)).
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
            ** rewrite Rmult_0_r; auto.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               case_bool_decide; last by lra.
               apply Rmax_l.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               case_bool_decide; last by lra.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply Rmax_lub; auto.
               apply Rle_minus_l.
               apply (Rle_trans _ (SeriesC (f a))); last by lra.
               apply SeriesC_le; auto.
               real_solver.
          * apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); [|apply ex_seriesC_scal_r; auto].
            intro a; split.
            ** case_bool_decide; real_solver.
            ** case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        * case_bool_decide; last by lra.
          apply Rmult_le_pos; auto.
          apply Rmax_l.
        * rewrite -{2}(Rmult_1_r (μ1 a)).
          apply Rmult_le_compat_l; auto.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          case_bool_decide; auto.
          apply Rmax_lub; auto.
          apply Rle_minus_l.
          apply (Rle_trans _ (SeriesC (f a))); last by lra.
          apply SeriesC_le; auto.
          real_solver.
     - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); [|apply ex_seriesC_scal_r; auto].
       intro a; split; case_bool_decide; real_solver.
    }

    erewrite (Rplus_le_compat); eauto.
    rewrite /DPcoupl in Hcoup_fg1.
    rewrite /DPcoupl in Hcoup_fg2.
    rewrite /DPcoupl in Hcoup_S1.
    rewrite /DPcoupl in Hcoup_S2.

    assert (forall a b, S1 a b -> (if bool_decide (P a) then Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2) else 0) <=
                 (if bool_decide (exists a', P a' /\ S1 a' b ) then Rmin 1 (exp (ε2) * SeriesC (λ b' : B', g b b' * h2 b')) else 0 ) ) as Htrans3.
    {
      intros a b HS1.
      case_bool_decide as HdecL; case_bool_decide as HdecR.
      - apply Rmin_glb; apply Rmax_lub; first by lra.
        + apply Rle_minus_l.
          apply (Rle_trans _ 1); last by real_solver.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          apply SeriesC_le; auto; real_solver.
        + series.
          left.
          by apply exp_pos.
        + apply Rle_minus_l.
          by apply Hcoup_fg1.
      - exfalso.
        apply HdecR.
        by exists a.
      - apply Rmin_glb; first lra.
        apply Rmult_le_pos.
        + left. by apply exp_pos.
        + apply SeriesC_ge_0'.
          intros; real_solver.
      - lra.
    }


    assert (forall a b, S2 a b -> (if bool_decide (¬ P a) then Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2) else 0) <=
                 (if bool_decide (exists a', ¬ P a' /\ S2 a' b) then Rmin 1 (exp (ε2') * SeriesC (λ b' : B', g b b' * h2 b')) else 0) ) as Htrans4.
    {
      intros a b HS4.
      case_bool_decide as HdecL; case_bool_decide as HdecR.
      - apply Rmin_glb; apply Rmax_lub; first by lra.
        + apply Rle_minus_l.
          apply (Rle_trans _ 1); last by real_solver.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          apply SeriesC_le; auto; real_solver.
        + series.
          left.
          by apply exp_pos.
        + apply Rle_minus_l.
          by apply Hcoup_fg2.
      - exfalso.
        apply HdecR.
        by exists a.
      - apply Rmin_glb; first lra.
        apply Rmult_le_pos.
        + left. by apply exp_pos.
        + apply SeriesC_ge_0'.
          intros; real_solver.
      - lra.
    }

    epose proof (Hcoup_S1 _ _ _ _ Htrans3) as HauxS1.
    epose proof (Hcoup_S2 _ _ _ _ Htrans4) as HauxS2.
    simpl in HauxS1.
    simpl in HauxS2.
    erewrite Rplus_le_compat; eauto.
    2:{
      apply Rplus_le_compat; [apply HauxS1 | apply Rle_refl].
    }
    2:{
      apply Rplus_le_compat; [apply HauxS2 | apply Rle_refl].
    }

    do 3 rewrite -SeriesC_scal_l.
    assert (forall a b c d e f : R, a + b + c + (d + e + f) = (a + d) + (b + e + (c + f))) as ->.
    { intros. lra. }
    rewrite -SeriesC_plus.
    2:{
      apply (ex_seriesC_le _ (λ b, exp ε1 * μ2 b * exp ε2)).
      - intros b.
        split.
        + apply Rmult_le_pos; [left; apply exp_pos |].
          apply Rmult_le_pos; auto.
          case_bool_decide; [|lra].
          apply Rmin_glb; [lra|].
          apply Rmult_le_pos; [left; apply exp_pos |].
          apply SeriesC_ge_0; [real_solver|].
          apply (ex_seriesC_le _ (g b)); auto.
          real_solver.
        + rewrite Rmult_assoc.
          apply Rmult_le_compat_l; [left; apply exp_pos |].
          apply Rmult_le_compat_l; auto.
          case_bool_decide; [|left; apply exp_pos].
          etrans; [apply Rmin_r |].
          rewrite -{2}(Rmult_1_r (exp ε2)).
          apply Rmult_le_compat_l; [left; apply exp_pos|].
          transitivity (SeriesC (g b)); auto.
          apply SeriesC_le; auto.
          real_solver.
      - apply ex_seriesC_scal_r.
        by apply ex_seriesC_scal_l.
    }
    2:{
      apply (ex_seriesC_le _ (λ b, exp ε1' * μ2 b * exp ε2')).
      - intros b.
        split.
        + apply Rmult_le_pos; [left; apply exp_pos |].
          apply Rmult_le_pos; auto.
          case_bool_decide; [|lra].
          apply Rmin_glb; [lra|].
          apply Rmult_le_pos; [left; apply exp_pos |].
          apply SeriesC_ge_0; [real_solver|].
          apply (ex_seriesC_le _ (g b)); auto.
          real_solver.
        + rewrite Rmult_assoc.
          apply Rmult_le_compat_l; [left; apply exp_pos |].
          apply Rmult_le_compat_l; auto.
          case_bool_decide; [|left; apply exp_pos].
          etrans; [apply Rmin_r |].
          rewrite -{2}(Rmult_1_r (exp ε2')).
          apply Rmult_le_compat_l; [left; apply exp_pos|].
          transitivity (SeriesC (g b)); auto.
          apply SeriesC_le; auto.
          real_solver.
      - apply ex_seriesC_scal_r.
        by apply ex_seriesC_scal_l.
    }
    apply Rplus_le_compat; last first.
    {
      rewrite -SeriesC_plus.
      - etrans; [|apply Hδleq].
        apply Rplus_le_compat; [lra|].
        transitivity (SeriesC (λ x, μ1 x * δ2)).
        + apply SeriesC_le; [|apply ex_seriesC_scal_r; auto].
          intros; case_bool_decide; case_bool_decide; try done.
          * rewrite Rmult_0_r Rplus_0_r.
            real_solver.
          * rewrite Rmult_0_r Rplus_0_l.
            real_solver.
        + rewrite SeriesC_scal_r.
          rewrite -{2}(Rmult_1_l δ2).
          apply Rmult_le_compat_r; auto.
      - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); [|apply ex_seriesC_scal_r; auto].
        intro a; split; case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); [|apply ex_seriesC_scal_r; auto].
        intro a; split; case_bool_decide; real_solver.
    }
    apply SeriesC_le.
    2:{
      apply ex_seriesC_scal_l.
      apply (ex_seriesC_le _ μ2); auto.
      intros b; split.
      - apply Rmult_le_pos; auto.
        apply SeriesC_ge_0'; real_solver.
      - rewrite -{2}(Rmult_1_r (μ2 b)).
        apply Rmult_le_compat_l; auto.
        transitivity (SeriesC (g b)); auto.
        apply SeriesC_le; auto.
        real_solver.
    }

    intros b.
    split.
    - case_bool_decide as HdecL; case_bool_decide as HdecR.
      + apply Rplus_le_le_0_compat.
        * apply Rmult_le_pos; [left; apply exp_pos|].
          apply Rmult_le_pos; auto.
          apply Rmin_glb; [lra|].
          apply Rmult_le_pos; [left; apply exp_pos|].
          apply SeriesC_ge_0'; real_solver.
        * apply Rmult_le_pos; [left; apply exp_pos|].
          apply Rmult_le_pos; auto.
          apply Rmin_glb; [lra|].
          apply Rmult_le_pos; [left; apply exp_pos|].
          apply SeriesC_ge_0'; real_solver.
     + rewrite !Rmult_0_r Rplus_0_r.
       apply Rmult_le_pos; [left; apply exp_pos|].
       apply Rmult_le_pos; auto.
       apply Rmin_glb; [lra|].
       apply Rmult_le_pos; [left; apply exp_pos|].
       apply SeriesC_ge_0'; real_solver.
     + rewrite !Rmult_0_r Rplus_0_l.
       apply Rmult_le_pos; [left; apply exp_pos|].
       apply Rmult_le_pos; auto.
       apply Rmin_glb; [lra|].
       apply Rmult_le_pos; [left; apply exp_pos|].
       apply SeriesC_ge_0'; real_solver.
     + lra.

    - do 2 rewrite -Rmult_assoc.
      assert (forall x y z r, x * y * (z * r) = (x * z) * (y * r)) as Haux_rw by real_solver.
      case_bool_decide as HdecL; case_bool_decide as HdecR.
      + destruct HdecL as [a [? ?]].
        destruct HdecR as [a' [? ?]].
        exfalso.
        eapply Hindep; eauto.
      + rewrite Rmult_0_r Rplus_0_r.
        rewrite Rmult_min_distr_l.
        * eapply Rle_trans ; [apply Rmin_r|].
          rewrite Haux_rw.
          rewrite -exp_plus.
          apply Rmult_le_compat.
          ** left; apply exp_pos.
          ** apply Rmult_le_pos; auto.
             apply SeriesC_ge_0'; real_solver.
          ** apply exp_mono; auto.
          ** lra.
       * apply Rmult_le_pos; auto.
         left; apply exp_pos.
      + rewrite Rmult_0_r Rplus_0_l.
        rewrite Rmult_min_distr_l.
        * eapply Rle_trans ; [apply Rmin_r|].
          rewrite Haux_rw.
          rewrite -exp_plus.
          apply Rmult_le_compat.
          ** left; apply exp_pos.
          ** apply Rmult_le_pos; auto.
             apply SeriesC_ge_0'; real_solver.
          ** apply exp_mono; auto.
          ** lra.
       * apply Rmult_le_pos; auto.
         left; apply exp_pos.
      + rewrite !Rmult_0_r Rplus_0_l.
        apply Rmult_le_pos; [left; apply exp_pos|].
        apply Rmult_le_pos; auto.
        apply SeriesC_ge_0'; real_solver.
   Unshelve.
   1:{
     intros a. split.
     - case_bool_decide; [apply Rmax_l|lra].
     - case_bool_decide; [|lra].
       apply Rmax_lub; [lra|].
       apply Rle_minus_r.
       transitivity 1; [|lra].
       transitivity (SeriesC (f a)); auto.
       apply SeriesC_le; auto.
       real_solver.
   }
   1:{
     intros b. split.
     - case_bool_decide; [|lra].
       apply Rmin_glb; [lra|].
       apply Rmult_le_pos; [left; apply exp_pos|].
       apply SeriesC_ge_0'; real_solver.
     - case_bool_decide; [apply Rmin_l|lra].
   }

   1:{
     intros a. split.
     - case_bool_decide; [apply Rmax_l|lra].
     - case_bool_decide; [|lra].
       apply Rmax_lub; [lra|].
       apply Rle_minus_r.
       transitivity 1; [|lra].
       transitivity (SeriesC (f a)); auto.
       apply SeriesC_le; auto.
       real_solver.
   }
   1:{
     intros b. split.
     - case_bool_decide; [|lra].
       apply Rmin_glb; [lra|].
       apply Rmult_le_pos; [left; apply exp_pos|].
       apply SeriesC_ge_0'; real_solver.
     - case_bool_decide; [apply Rmin_l|lra].
   }

  Qed.

  (* Depend on both *)
  Lemma DPcoupl_dbind_adv_kanto_plain `{Countable A, Countable B, Countable A', Countable B'}
    (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (S' : A' B' Prop)
    ε δ (E2 : A B R) (D2 : A -> B -> R) :
    ( a b, Rle 0 (E2 a b))
    ( a b, Rle 0 (D2 a b))
    (forall h1 h2,
        (forall a : A, 0 <= h1 a <= 1) ->
        (forall b : B, 0 <= h2 b <= 1) ->
        (forall a b, h1 a <= exp (E2 a b) * h2 b + D2 a b) ->
        (SeriesC (λ a, μ1 a * h1 a) <=
           exp(ε) * SeriesC (λ b, μ2 b * h2 b) + δ)) ->
    ( a b, DPcoupl (f a) (g b) S' (E2 a b) (D2 a b)) DPcoupl (dbind f μ1) (dbind g μ2) S' ε δ.
  Proof.
    intros HE2 HD2 Hexp Hcoup_fg h1 h2 Hh1pos Hh2pos Hh1h2S.
    rewrite /pmf/=/dbind_pmf.

    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

    (* Boring Fubini sideconditions. *)
    2: { intros a' a. specialize (Hh1pos a') ; real_solver. }
    2: { intro a'.
         specialize (Hh1pos a').
         apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         + apply Rmult_le_pos.
           * real_solver.
           * real_solver.
         + rewrite <- Rmult_1_r.
           rewrite Rmult_assoc.
           apply Rmult_le_compat_l; auto.
           rewrite <- Rmult_1_r.
           apply Rmult_le_compat; real_solver. }
    2: { setoid_rewrite SeriesC_scal_r.
         apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
         + intros a'; specialize (Hh1pos a'); split.
           * apply Rmult_le_pos; [ | lra].
             apply (pmf_pos ((dbind f μ1)) a').
           * rewrite <- Rmult_1_r.
             apply Rmult_le_compat_l; auto.
             -- apply SeriesC_ge_0'. real_solver.
             -- real_solver.
         + apply (pmf_ex_seriesC (dbind f μ1)). }

    (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
    assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    (* Second step: rewrite the RHS into a RV Y on μ2. *)
    rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
    2:{ intros b' b.
        specialize (Hh2pos b').
        real_solver. }
    2:{ intro b'.
        specialize (Hh2pos b').
        apply (ex_seriesC_le _ μ2); auto.
        intro b; split.
        - apply Rmult_le_pos.
          + real_solver.
          + real_solver.
        - rewrite <- Rmult_1_r.
          rewrite Rmult_assoc.
          apply Rmult_le_compat_l; auto.
          rewrite <- Rmult_1_r.
          apply Rmult_le_compat; real_solver. }
    2:{ setoid_rewrite SeriesC_scal_r.
        apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
        - intros b'; specialize (Hh2pos b'); split.
          + apply Rmult_le_pos; [ | lra].
            apply (pmf_pos ((dbind g μ2)) b').
          + rewrite <- Rmult_1_r.
            apply Rmult_le_compat_l; auto.
            * apply SeriesC_ge_0'. real_solver.
            * real_solver.
        - apply (pmf_ex_seriesC (dbind g μ2)). }

    (* RHS: Factor out (μ2 b) *)
    assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
            = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    set (X a := SeriesC (λ a' : A', f a a' * h1 a')).
    replace (λ b, μ1 b * _) with (λ a, μ1 a * X a) by auto.

    set (Y b := SeriesC (λ b' : B', g b b' * h2 b')).
    replace (λ b, μ2 b * _) with (λ b, μ2 b * Y b) by auto.
    rewrite /DPcoupl in Hcoup_fg.
    apply Hexp.
    - intros.
      rewrite /X; split.
      + series.
      + trans (SeriesC (f a)) => //.
        apply SeriesC_le => //. intros a'.
        destruct (Hh1pos a'). real_solver.
    - intros.
      rewrite /Y; split.
      + series.
      + trans (SeriesC (g b)) => //.
        apply SeriesC_le => //. intros b'.
        destruct (Hh2pos b'). real_solver.
    - intros.
      rewrite /X /Y /=.
      apply Hcoup_fg; auto.
  Qed.

  Lemma DPcoupl_dbind_subsampling `{Countable A}
    (μ1 : distr A) (μ2 : distr A) (μ3 : distr A) (S : A -> A -> Prop) (r : R)
    (Hr: 0 <= r <= 1) ε δ :
    (0 <= ε) -> (0 <= δ) ->
    (DPcoupl μ1 μ2 S ε δ)
    (DPcoupl μ1 μ3 S ε δ)
    (DPcoupl μ3 μ2 S ε δ)
    (DPcoupl μ3 μ3 S 0 0)
    DPcoupl (dbind (λ b, if b then μ1 else μ3) (biased_coin r Hr))
      (dbind (λ b, if b then μ2 else μ3) (biased_coin r Hr)) S
        (ln (1 + r*(exp(ε)-1))) (r*δ).
  Proof.
    intros Hcoupl12 Hcoupl13 Hcoupl23 Hcoupl33.
    assert (0 = r \/ 0 < r) as [<- | ] by lra.
    {
      (* degenerate case r=0 *)
      simpl.
      assert (biased_coin 0 Hr = dret (false)) as ->.
      - apply distr_ext.
        rewrite /biased_coin/pmf/=.
        rewrite /biased_coin_pmf/dret_pmf/=.
        intros []; simpl; lra.
      - rewrite !dret_id_left.
        rewrite !Rmult_0_l Rplus_0_r ln_1 //.
    }
    set (E2 b1 b2 := match b1,b2 with
                      | false,false => 0
                      | _,_ => ε
                     end).
    set (D2 b1 b2 := match b1,b2 with
                      | false,false => 0
                      | _,_ => δ
                     end).
    eapply (DPcoupl_dbind_adv_kanto_plain _ _ _ _ _ _ _ E2 D2).
    - intros [][]; rewrite /E2 /=; real_solver.
    - intros [][]; rewrite /D2 /=; real_solver.
    - intros h1 h2 Hh1 Hh2 Hh1h2.
      rewrite exp_ln; last first.
      {
        apply Rplus_lt_le_0_compat; [lra|].
        apply Rmult_le_pos; [lra|].
        apply Rle_minus_r.
        rewrite Rplus_0_l.
        by apply exp_pos_ge_1.
      }
      rewrite !SeriesC_bool.
      rewrite /biased_coin/pmf/=.
      (*  assert (h1 true + h1 false <= (exp ε + 1) * SeriesC (λ b : bool, fair_coin b * h2 b) + δ); last by lra. *)
      transitivity (r * h1 true + (1-r) * h2 false).
      {
        apply Rplus_le_compat_l.
        specialize (Hh1h2 false false).
        rewrite /E2 /D2 exp_0 /= in Hh1h2.
        real_solver.
      }
      set (ρ := r + (1-r) * exp (-ε)).
      transitivity (r * (ρ * (exp (ε) * h2 true + δ) + (1-ρ) * (exp(ε) * h2 false + δ)) + (1-r) * h2 false).
      {
        apply Rplus_le_compat_r.
        replace (h1 true) with* h1 true + (1-ρ) * h1 true) by lra.
        apply Rmult_le_compat_l; [real_solver|].
        apply Rplus_le_compat.
        - apply Rmult_le_compat_l.
          + rewrite /ρ.
            apply Rplus_le_le_0_compat; [lra|].
            apply Rmult_le_pos; [lra|].
            left.
            apply exp_pos.
          + specialize (Hh1h2 true true).
            done.
       - apply Rmult_le_compat_l.
         + rewrite /ρ.
           apply Rle_minus_r.
           rewrite Rplus_0_l.
           replace 1 with (r + (1-r) * 1) at 2 by lra.
           apply Rplus_le_compat_l.
           apply Rmult_le_compat_l; [lra|].
           rewrite exp_Ropp.
           replace 1 with (/1) by lra.
           apply Rinv_le_contravar; [lra|].
           by apply exp_pos_ge_1.
         + specialize (Hh1h2 true false).
           done.
      }
      replace (r * (ρ * (exp ε * h2 true + δ) + (1 - ρ) * (exp ε * h2 false + δ)) + (1-r) * h2 false)
        with (r * ρ * exp ε * h2 true + (r * (1 - ρ) * exp ε + (1-r)) * h2 false + r * δ) by lra.
      apply Rplus_le_compat_r.
      replace (1 + r * (exp ε-1)) with (r * exp ε + (1-r)) by lra.
      rewrite Rmult_plus_distr_l.
      apply Rplus_le_compat.
      + replace (r * ρ * exp ε * h2 true) with* exp ε * (r * h2 true)) by lra.
        apply Rmult_le_compat_r; [real_solver|].
        rewrite /ρ.
        rewrite Rmult_plus_distr_r.
        apply Rplus_le_compat_l.
        rewrite Rmult_assoc.
        rewrite -exp_plus.
        replace (-ε+ε) with 0 by lra.
        rewrite exp_0.
        lra.
      + rewrite -Rmult_assoc.
        apply Rmult_le_compat_r; [real_solver|].
        rewrite /ρ.
        rewrite Rmult_assoc.
        rewrite (Rmult_minus_distr_r _ _(exp ε)).
        rewrite Rmult_1_l.
        rewrite (Rmult_plus_distr_r _ _(exp ε)).
        rewrite Rmult_assoc.
        rewrite -exp_plus.
        replace (-ε+ε) with 0 by lra.
        rewrite exp_0.
        lra.
    - rewrite /E2/D2.
      intros [][]; simpl; auto.
  Qed.

  (*

   Failed attempt at subsampling composition


  Lemma DPcoupl_dbind_choice_adv (f : A → distr A') (g : B → distr B')
    (μ1 : distr A) (μ2 : distr B) (P : A -> Prop) (S : A → B → Prop) (S' : A' → B' → Prop)
    ε1 ε2 δ1 δ2 ε2' δ2' ε δ:
    let ρ := probp μ1 P in
    (0 <= δ1) → (0 <= δ2) → (0 <= δ2') ->
    (ε1 + ln (ρ * exp(ε2) + (1-ρ) * exp(ε2')) <= ε) ->
    (δ1 + ρ * δ2 + (1-ρ)* δ2' <= δ) ->
    (* Stronger version: (δ1 + δ1' + Rmax δ2 δ2' <= δ) -> *)
    (forall a a' b, P a -> ¬ P a' -> ¬(S a b /\ S a' b)) ->
    (∀ a b, (P a /\ S a b) → DPcoupl (f a) (g b) S' ε2 δ2) →
    (∀ a b, (¬P a /\ S a b) → DPcoupl (f a) (g b) S' ε2' δ2') →
    DPcoupl μ1 μ2 S ε1 δ1 →
    DPcoupl (dbind f μ1) (dbind g μ2) S' ε δ.
  Proof.
    intros Hρ.
    rewrite /Hρ.
    intros Hδ1 Hδ2 Hδ2' Hεleq Hδleq
      Hindep
      Hcoup_fg1  Hcoup_fg2 Hcoup_S h1 h2 Hh1pos Hh2pos Hh1h2S'.

    rewrite /pmf/=/dbind_pmf.
    (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
       we have to rewrite the sums in such a way as to isolate (the expectation
       of) a random variable X on the LHS and Y on the RHS, and ε1 on the
       RHS. *)

    (* First step: rewrite the LHS into a RV X on μ1. *)
    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

    (* Boring Fubini sideconditions. *)
    2: { real_solver. }
    2: { intro a'.
         (* specialize (Hh1pos a'). *)
         apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         + apply Rmult_le_pos.
           * real_solver.
           * real_solver.
         + rewrite <- Rmult_1_r.
           rewrite Rmult_assoc.
           apply Rmult_le_compat_l; auto.
           rewrite <- Rmult_1_r.
           apply Rmult_le_compat; real_solver. }
    2: { setoid_rewrite SeriesC_scal_r.
         apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
         + series.
         + apply (pmf_ex_seriesC (dbind f μ1)). }

    (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
    assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
    { setoid_rewrite <- SeriesC_scal_l. series. }

    (* Second step: rewrite the RHS into a RV Y on μ2. *)
    (* RHS: Fubini. *)
    rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
    2: by series.
    2:{ intro b'.
        specialize (Hh2pos b').
        apply (ex_seriesC_le _ μ2) ; auto.
        intro b; split.
        - series.
        - do 2 rewrite <- Rmult_1_r. series. }
    2:{ setoid_rewrite SeriesC_scal_r.
        apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
        - intros b'; specialize (Hh2pos b'); split.
          + apply Rmult_le_pos;  | lra.
            apply (pmf_pos ((dbind g μ2)) b').
          + rewrite <- Rmult_1_r.
            apply Rmult_le_compat_l; auto.
            * apply SeriesC_ge_0'. real_solver.
            * real_solver.
        - apply (pmf_ex_seriesC (dbind g μ2)). }

    (* RHS: Factor out (μ2 b) *)
    assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
            = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    (* Now let's split the sum depending on whether P holds *)

    assert (SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * (if (bool_decide (P b)) then  SeriesC (λ a : A', f b a * h1 a) else 0)) +
              SeriesC (λ b : A, μ1 b * (if (bool_decide (¬ P b)) then SeriesC (λ a : A', f b a * h1 a) else 0))) as ->.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_ext.
        intro a.
        case_bool_decide; case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        + apply Rmult_le_pos; auto.
          case_bool_decide; |lra.
          apply SeriesC_ge_0'; real_solver.
        + rewrite -{2}(Rmult_1_r (μ1 _)).
          apply Rmult_le_compat_l; auto.
          case_bool_decide; |lra.
          transitivity (SeriesC (f a)); auto.
          apply SeriesC_le; auto.
          real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        + apply Rmult_le_pos; auto.
          case_bool_decide; |lra.
          apply SeriesC_ge_0'; real_solver.
        + rewrite -{2}(Rmult_1_r (μ1 _)).
          apply Rmult_le_compat_l; auto.
          case_bool_decide; |lra.
          transitivity (SeriesC (f a)); auto.
          apply SeriesC_le; auto.
          real_solver.
    }


    (*  Stronger version: assert (
        SeriesC (λ b : A, μ1 b * (if (bool_decide (P b)) then  SeriesC (λ a : A', f b a * h1 a) else 0))
        <= SeriesC (λ b : A, μ1 b * if (bool_decide (P b)) then  (Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2)) else 0) + SeriesC (λ b : A, μ1 b * if (bool_decide (P b)) then δ2 else 0)
      ) as Htrans1.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_le.
        + intros a. split.
          * case_bool_decide; |real_solver.
            apply Rmult_le_pos; auto.
            apply SeriesC_ge_0'.
            real_solver.
          * case_bool_decide; |lra.
            rewrite -Rmult_plus_distr_l.
            apply Rmult_le_compat_l; auto.
            rewrite Rplus_max_distr_r.
            eapply Rle_trans; |apply Rmax_r.
            lra.
        + admit.
      - admit.
      - admit.
*)


    assert (
        SeriesC (λ b : A, μ1 b * (if (bool_decide (P b)) then  SeriesC (λ a : A', f b a * h1 a) else 0))
        <= SeriesC (λ b : A, μ1 b * if (bool_decide (P b)) then  (Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2)) else 0) +
          SeriesC (λ a, μ1 a * (if (bool_decide (P a)) then δ2 else 0))
      ) as Htrans1.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_plus_distr_l.
          apply Rmult_le_compat_l; auto.
          case_bool_decide; last by lra.
          rewrite Rplus_max_distr_r.
          etrans; |apply Rmax_r.
          lra.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            case_bool_decide; last by lra.
            apply SeriesC_ge_0'; real_solver.
          * case_bool_decide.
            ** rewrite -{2}(Rmult_1_r (μ1 a)).
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
            ** rewrite Rmult_0_r; auto.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               case_bool_decide; last by lra.
               apply Rmax_l.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               case_bool_decide; last by lra.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply Rmax_lub; auto.
               apply Rle_minus_l.
               apply (Rle_trans _ (SeriesC (f a))); last by lra.
               apply SeriesC_le; auto.
               real_solver.
          * apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); |apply ex_seriesC_scal_r; auto.
            intro a; split.
            ** case_bool_decide; real_solver.
            ** case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        * case_bool_decide; last by lra.
          apply Rmult_le_pos; auto.
          apply Rmax_l.
        * rewrite -{2}(Rmult_1_r (μ1 a)).
          apply Rmult_le_compat_l; auto.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          case_bool_decide; auto.
          apply Rmax_lub; auto.
          apply Rle_minus_l.
          apply (Rle_trans _ (SeriesC (f a))); last by lra.
          apply SeriesC_le; auto.
          real_solver.
     - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); |apply ex_seriesC_scal_r; auto.
       intro a; split; case_bool_decide; real_solver.
    }

    assert (
        SeriesC (λ b : A, μ1 b * (if (bool_decide (¬ P b)) then  SeriesC (λ a : A', f b a * h1 a) else 0))
        <= SeriesC (λ b : A, μ1 b * if (bool_decide (¬ P b)) then  (Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2')) else 0) +
            SeriesC (λ a, μ1 a * (if (bool_decide (¬ P a)) then δ2' else 0))
      ) as Htrans2.
    {
      rewrite -SeriesC_plus.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_plus_distr_l.
          apply Rmult_le_compat_l; auto.
          case_bool_decide; last by lra.
          rewrite Rplus_max_distr_r.
          etrans; |apply Rmax_r.
          lra.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            case_bool_decide; last by lra.
            apply SeriesC_ge_0'; real_solver.
          * case_bool_decide.
            ** rewrite -{2}(Rmult_1_r (μ1 a)).
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
            ** rewrite Rmult_0_r; auto.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               case_bool_decide; last by lra.
               apply Rmax_l.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               case_bool_decide; last by lra.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply Rmax_lub; auto.
               apply Rle_minus_l.
               apply (Rle_trans _ (SeriesC (f a))); last by lra.
               apply SeriesC_le; auto.
               real_solver.
          * apply (ex_seriesC_le _ (λ x, μ1 x * δ2')); |apply ex_seriesC_scal_r; auto.
            intro a; split.
            ** case_bool_decide; real_solver.
            ** case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ μ1); auto.
        intros a; split.
        * case_bool_decide; last by lra.
          apply Rmult_le_pos; auto.
          apply Rmax_l.
        * rewrite -{2}(Rmult_1_r (μ1 a)).
          apply Rmult_le_compat_l; auto.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          case_bool_decide; auto.
          apply Rmax_lub; auto.
          apply Rle_minus_l.
          apply (Rle_trans _ (SeriesC (f a))); last by lra.
          apply SeriesC_le; auto.
          real_solver.
     - apply (ex_seriesC_le _ (λ x, μ1 x * δ2')); |apply ex_seriesC_scal_r; auto.
       intro a; split; case_bool_decide; real_solver.
    }

    erewrite (Rplus_le_compat); eauto.
    rewrite /DPcoupl in Hcoup_fg1.
    rewrite /DPcoupl in Hcoup_fg2.
    rewrite /DPcoupl in Hcoup_S.


    assert (forall a b, S a b -> (if bool_decide (P a) then Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2) else 0) <=
                 (if bool_decide (exists a', P a' /\ S a' b ) then Rmin 1 (exp (ε2) * SeriesC (λ b' : B', g b b' * h2 b')) else 0 ) ) as Htrans3.
    {
      intros a b HS1.
      case_bool_decide as HdecL; case_bool_decide as HdecR.
      - apply Rmin_glb; apply Rmax_lub; first by lra.
        + apply Rle_minus_l.
          apply (Rle_trans _ 1); last by real_solver.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          apply SeriesC_le; auto; real_solver.
        + series.
          left.
          by apply exp_pos.
        + apply Rle_minus_l.
          by apply Hcoup_fg1.
      - exfalso.
        apply HdecR.
        by exists a.
      - apply Rmin_glb; first lra.
        apply Rmult_le_pos.
        + left. by apply exp_pos.
        + apply SeriesC_ge_0'.
          intros; real_solver.
      - lra.
    }


    assert (forall a b, S a b -> (if bool_decide (¬ P a) then Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2') else 0) <=
                 (if bool_decide (exists a', ¬ P a' /\ S a' b) then Rmin 1 (exp (ε2') * SeriesC (λ b' : B', g b b' * h2 b')) else 0) ) as Htrans4.
    {
      intros a b HS4.
      case_bool_decide as HdecL; case_bool_decide as HdecR.
      - apply Rmin_glb; apply Rmax_lub; first by lra.
        + apply Rle_minus_l.
          apply (Rle_trans _ 1); last by real_solver.
          apply (Rle_trans _ (SeriesC (f a))); auto.
          apply SeriesC_le; auto; real_solver.
        + series.
          left.
          by apply exp_pos.
        + apply Rle_minus_l.
          by apply Hcoup_fg2.
      - exfalso.
        apply HdecR.
        by exists a.
      - apply Rmin_glb; first lra.
        apply Rmult_le_pos.
        + left. by apply exp_pos.
        + apply SeriesC_ge_0'.
          intros; real_solver.
      - lra.
    }

    assert (forall a b, S a b -> (if bool_decide (P a) then Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2)
                                                 else Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2'))
                           <=
                           (if bool_decide (exists a', P a' /\ S a' b ) then Rmin 1 (exp (ε2) * SeriesC (λ b' : B', g b b' * h2 b'))
                                                                   else Rmin 1 (exp (ε2') * SeriesC (λ b' : B', g b b' * h2 b'))))
      as Htrans5.
    { admit. }

    epose proof (Hcoup_S _ _ _ _ Htrans5) as HauxS.

    assert (forall a b c d : R, a + b + (c + d) = (a + c) + (b + d)) as ->.
    { intros. lra. }
    rewrite -SeriesC_plus; |admit|admit.
    rewrite (SeriesC_ext _ (λ a, μ1 a * (if bool_decide (P a) then Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2)
                                 else Rmax 0 (SeriesC (λ a' : A', f a a' * h1 a') - δ2')) )); last first.
    {
      intros a.
      case_bool_decide; case_bool_decide.
      - done.
      - lra.
      - lra.
      - done.
    }

    erewrite Rplus_le_compat; eauto.
    2:{ apply Rle_refl. }

    assert (forall a b c d : R, a + b + (c + d) = a + (b + (c + d))) as ->.
    { intros; lra. }
    apply Rplus_le_compat; last first.
    {
      rewrite <- Hδleq.
      admit.
    }
    etrans; last first.
    {
      eapply Rmult_le_compat_r; admit|.
      apply exp_mono.
      apply Hεleq.
    }
    rewrite exp_plus.
    rewrite Rmult_assoc.
    apply Rmult_le_compat_l; admit|.
    rewrite exp_ln; |admit.

    transitivity (SeriesC
              (λ b : B,
                  μ2 b *
                    (if bool_decide (∃ a' : A, P a' ∧ S a' b)
                     then exp ε2 * SeriesC (λ b' : B', g b b' * h2 b')
                     else exp ε2' * SeriesC (λ b' : B', g b b' * h2 b')))).
    {
      apply SeriesC_le; |admit.
      intro b; split.
      - case_bool_decide; admit.
      - case_bool_decide; admit.
    }



    { intros. lra. }
    rewrite -SeriesC_plus.
    2:{
      apply (ex_seriesC_le _ (λ b, exp ε1 * μ2 b * exp ε2)).
      - intros b.
        split.
        + apply Rmult_le_pos; left; apply exp_pos |.
          apply Rmult_le_pos; auto.
          case_bool_decide; |lra.
          apply Rmin_glb; lra|.
          apply Rmult_le_pos; left; apply exp_pos |.
          apply SeriesC_ge_0; real_solver|.
          apply (ex_seriesC_le _ (g b)); auto.
          real_solver.
        + rewrite Rmult_assoc.
          apply Rmult_le_compat_l; left; apply exp_pos |.
          apply Rmult_le_compat_l; auto.
          case_bool_decide; |left; apply exp_pos.
          etrans; apply Rmin_r |.
          rewrite -{2}(Rmult_1_r (exp ε2)).
          apply Rmult_le_compat_l; left; apply exp_pos|.
          transitivity (SeriesC (g b)); auto.
          apply SeriesC_le; auto.
          real_solver.
      - apply ex_seriesC_scal_r.
        by apply ex_seriesC_scal_l.
    }
    2:{
      apply (ex_seriesC_le _ (λ b, exp ε1' * μ2 b * exp ε2')).
      - intros b.
        split.
        + apply Rmult_le_pos; left; apply exp_pos |.
          apply Rmult_le_pos; auto.
          case_bool_decide; |lra.
          apply Rmin_glb; lra|.
          apply Rmult_le_pos; left; apply exp_pos |.
          apply SeriesC_ge_0; real_solver|.
          apply (ex_seriesC_le _ (g b)); auto.
          real_solver.
        + rewrite Rmult_assoc.
          apply Rmult_le_compat_l; left; apply exp_pos |.
          apply Rmult_le_compat_l; auto.
          case_bool_decide; |left; apply exp_pos.
          etrans; apply Rmin_r |.
          rewrite -{2}(Rmult_1_r (exp ε2')).
          apply Rmult_le_compat_l; left; apply exp_pos|.
          transitivity (SeriesC (g b)); auto.
          apply SeriesC_le; auto.
          real_solver.
      - apply ex_seriesC_scal_r.
        by apply ex_seriesC_scal_l.
    }
    apply Rplus_le_compat; last first.
    {
      rewrite -SeriesC_plus; admit.
        (*
      - etrans; |apply Hδleq.
        apply Rplus_le_compat. lra|.
        transitivity (SeriesC (λ x, μ1 x * δ2)).
        + apply SeriesC_le; |apply ex_seriesC_scal_r; auto.
          intros; case_bool_decide; case_bool_decide; try done.
          * rewrite Rmult_0_r Rplus_0_r.
            real_solver.
          * rewrite Rmult_0_r Rplus_0_l.
            real_solver.
        + rewrite SeriesC_scal_r.
          rewrite -{2}(Rmult_1_l δ2).
          apply Rmult_le_compat_r; auto.
      - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); |apply ex_seriesC_scal_r; auto.
        intro a; split; case_bool_decide; real_solver.
      - apply (ex_seriesC_le _ (λ x, μ1 x * δ2)); |apply ex_seriesC_scal_r; auto.
        intro a; split; case_bool_decide; real_solver.
        *)

        
    }


    apply SeriesC_le.
    2:{
      apply ex_seriesC_scal_l.
      apply (ex_seriesC_le _ μ2); auto.
      intros b; split.
      - apply Rmult_le_pos; auto.
        apply SeriesC_ge_0'; real_solver.
      - rewrite -{2}(Rmult_1_r (μ2 b)).
        apply Rmult_le_compat_l; auto.
        transitivity (SeriesC (g b)); auto.
        apply SeriesC_le; auto.
        real_solver.
    }

    intros b.
    split.
    - case_bool_decide as HdecL; case_bool_decide as HdecR.
      + apply Rplus_le_le_0_compat.
        * apply Rmult_le_pos; left; apply exp_pos|.
          apply Rmult_le_pos; auto.
          apply Rmin_glb; lra|.
          apply Rmult_le_pos; left; apply exp_pos|.
          apply SeriesC_ge_0'; real_solver.
        * apply Rmult_le_pos; left; apply exp_pos|.
          apply Rmult_le_pos; auto.
          apply Rmin_glb; lra|.
          apply Rmult_le_pos; left; apply exp_pos|.
          apply SeriesC_ge_0'; real_solver.
     + rewrite !Rmult_0_r Rplus_0_r.
       apply Rmult_le_pos; left; apply exp_pos|.
       apply Rmult_le_pos; auto.
       apply Rmin_glb; lra|.
       apply Rmult_le_pos; left; apply exp_pos|.
       apply SeriesC_ge_0'; real_solver.
     + rewrite !Rmult_0_r Rplus_0_l.
       apply Rmult_le_pos; left; apply exp_pos|.
       apply Rmult_le_pos; auto.
       apply Rmin_glb; lra|.
       apply Rmult_le_pos; left; apply exp_pos|.
       apply SeriesC_ge_0'; real_solver.
     + lra.

    - do 2 rewrite -Rmult_assoc.
      assert (forall x y z r, x * y * (z * r) = (x * z) * (y * r)) as Haux_rw by real_solver.
      case_bool_decide as HdecL; case_bool_decide as HdecR.
      + destruct HdecL as a [? ?].
        destruct HdecR as a' [? ?].
        exfalso.
        eapply Hindep; eauto.
      + rewrite Rmult_0_r Rplus_0_r.
        rewrite Rmult_min_distr_l.
        * eapply Rle_trans ; apply Rmin_r|.
          rewrite Haux_rw.
          rewrite -exp_plus.
          apply Rmult_le_compat.
          ** left; apply exp_pos.
          ** apply Rmult_le_pos; auto.
             apply SeriesC_ge_0'; real_solver.
          ** apply exp_mono; auto.
          ** lra.
       * apply Rmult_le_pos; auto.
         left; apply exp_pos.
      + rewrite Rmult_0_r Rplus_0_l.
        rewrite Rmult_min_distr_l.
        * eapply Rle_trans ; apply Rmin_r|.
          rewrite Haux_rw.
          rewrite -exp_plus.
          apply Rmult_le_compat.
          ** left; apply exp_pos.
          ** apply Rmult_le_pos; auto.
             apply SeriesC_ge_0'; real_solver.
          ** apply exp_mono; auto.
          ** lra.
       * apply Rmult_le_pos; auto.
         left; apply exp_pos.
      + rewrite !Rmult_0_r Rplus_0_l.
        apply Rmult_le_pos; left; apply exp_pos|.
        apply Rmult_le_pos; auto.
        apply SeriesC_ge_0'; real_solver.
   Unshelve.
   1:{
     intros a. split.
     - case_bool_decide; apply Rmax_l|lra.
     - case_bool_decide; |lra.
       apply Rmax_lub; lra|.
       apply Rle_minus_r.
       transitivity 1; |lra.
       transitivity (SeriesC (f a)); auto.
       apply SeriesC_le; auto.
       real_solver.
   }
   1:{
     intros b. split.
     - case_bool_decide; |lra.
       apply Rmin_glb; lra|.
       apply Rmult_le_pos; left; apply exp_pos|.
       apply SeriesC_ge_0'; real_solver.
     - case_bool_decide; apply Rmin_l|lra.
   }

   1:{
     intros a. split.
     - case_bool_decide; apply Rmax_l|lra.
     - case_bool_decide; |lra.
       apply Rmax_lub; lra|.
       apply Rle_minus_r.
       transitivity 1; |lra.
       transitivity (SeriesC (f a)); auto.
       apply SeriesC_le; auto.
       real_solver.
   }
   1:{
     intros b. split.
     - case_bool_decide; |lra.
       apply Rmin_glb; lra|.
       apply Rmult_le_pos; left; apply exp_pos|.
       apply SeriesC_ge_0'; real_solver.
     - case_bool_decide; apply Rmin_l|lra.
   }

  Qed.

  *)


  (*
  Lemma DPcoupl_dbind_adv_rhs (f : A → distr A') (g : B → distr B')
    (μ1 : distr A) (μ2 : distr B) (S : A → B → Prop) (S' : A' → B' → Prop)
    ε1 ε2 δ1 δ2 (Δ2 : B → R) :
    (0 <= δ1) → (∀ b, 0 <= (Δ2 b) <= 1) →
    (* (SeriesC (λ a, μ1 a * (E2 a)) <= ε2) → *)
    (SeriesC (λ b, μ2 b * (Δ2 b)) = δ2) →
    (∀ a b, S a b → DPcoupl (f a) (g b) S' ε2 (Δ2 b)) →
    DPcoupl μ1 μ2 S ε1 δ1 →
    DPcoupl (dbind f μ1) (dbind g μ2) S' (ε1 + ε2) (δ1 + exp(ε1)*δ2).
  Proof.
    intros Hδ1 HΔ2 <- Hcoup_fg Hcoup_S h1 h2 Hh1pos Hh2pos Hh1h2S.
    rewrite {-3}/pmf/=/dbind_pmf.
    (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
       we have to rewrite the sums in such a way as to isolate (the expectation
       of) a random variable X on the LHS and Y on the RHS, and ε1 on the
       RHS. *)

    (* First step: rewrite the LHS into a RV X on μ1. *)
    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

    (* Boring Fubini sideconditions. *)
    2: { real_solver. }
    2: { intro a'.
         (* specialize (Hh1pos a'). *)
         apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         + apply Rmult_le_pos.
           * real_solver.
           * real_solver.
         + rewrite <- Rmult_1_r.
           rewrite Rmult_assoc.
           apply Rmult_le_compat_l; auto.
           rewrite <- Rmult_1_r.
           apply Rmult_le_compat; real_solver. }
    2: { setoid_rewrite SeriesC_scal_r.
         apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
         + series.
         + apply (pmf_ex_seriesC (dbind f μ1)). }

    (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
    assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
    { setoid_rewrite <- SeriesC_scal_l. series. }

    (* Second step: rewrite the RHS into a RV Y on μ2. *)
    (* RHS: Fubini. *)
    rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
    2: by series.
    2:{ intro b'.
        specialize (Hh2pos b').
        apply (ex_seriesC_le _ μ2) ; auto.
        intro b; split.
        - series.
        - do 2 rewrite <- Rmult_1_r. series. }
    2:{ setoid_rewrite SeriesC_scal_r.
        apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
        - intros b'; specialize (Hh2pos b'); split.
          + apply Rmult_le_pos;  | lra.
            apply (pmf_pos ((dbind g μ2)) b').
          + rewrite <- Rmult_1_r.
            apply Rmult_le_compat_l; auto.
            * apply SeriesC_ge_0'. real_solver.
            * real_solver.
        - apply (pmf_ex_seriesC (dbind g μ2)). }

    (* RHS: Factor out (μ2 b) *)
    assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
            = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    rewrite -Rplus_assoc.
    apply Rle_minus_l.


    (* To construct Y, we want to push exp(ε2) into the inner sum on the RHS.
       We don't do this directly, because the resulting value might be larger
       than 1, but our assumption on the ε1 DP-coupling requires it to be
       valued in [0,1].
       Instead, we take min(1, exp(ε2) * (Σ(a:B') g b a * h2 a)).
       ALT: could use a more fine-grained min inside the sum?
     *)


    assert (exp (ε1) * SeriesC (λ b : B, μ2 b * (Rmin 1 (exp (ε2) * SeriesC (λ a : B', g b a * h2 a)))) + δ1
            <= exp (ε1 + ε2) * SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a)) + δ1) as <-.
    {
       apply Rplus_le_compat_r.
       rewrite exp_plus.
       rewrite Rmult_assoc.
       rewrite -(SeriesC_scal_l _ (exp ε2)).
       apply Rmult_le_compat_l; left; apply exp_pos |.
       apply SeriesC_le.
       - intros b; split.
         + apply Rmult_le_pos; auto.
           apply Rmin_glb; lra |.
           apply Rmult_le_pos; left; apply exp_pos |.
           apply SeriesC_ge_0'.
           real_solver.
         + rewrite Rmult_min_distr_l; auto.
           etrans; apply Rmin_r | lra.
       - apply ex_seriesC_scal_l.
         apply (ex_seriesC_le _ μ2); auto.
         intro b; split.
         + apply Rmult_le_pos; auto.
           apply SeriesC_ge_0'.
           intro; apply Rmult_le_pos; auto.
           apply Hh2pos.
         + rewrite <- Rmult_1_r.
           apply Rmult_le_compat_l; auto.
           apply (Rle_trans _ (SeriesC (g b))); auto.
           apply SeriesC_le; auto.
           real_solver.
    }

    assert (
        SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a)) -
          SeriesC (λ a, μ1 a * Δ2 a)
        <= SeriesC (λ b : A, μ1 b * Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - Δ2 b))
      ) as ->.
    {
      apply (Rle_trans _ (SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a) - μ1 b * Δ2 b))).
      - rewrite SeriesC_minus.
        + apply Rplus_le_compat_l.
          apply Ropp_le_contravar.
          done.
        + apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         * apply Rmult_le_pos; auto.
           apply SeriesC_ge_0'.
           intro; apply Rmult_le_pos; auto.
           apply Hh1pos.
         * rewrite <- Rmult_1_r.
           apply Rmult_le_compat_l; auto.
           apply (Rle_trans _ (SeriesC (f a))); auto.
           apply SeriesC_le; auto.
           real_solver.
        + apply (ex_seriesC_le _ μ1); auto.
          intros; real_solver.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_minus_distr_l.
          apply Rmult_le_compat_l; auto.
          apply Rmax_r.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               apply SeriesC_ge_0'.
               intro; apply Rmult_le_pos; auto.
               apply Hh1pos.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
          * apply (ex_seriesC_ext (λ x, -1 * (μ1 x * Δ2 x))).
            1: intros; real_solver.
            apply ex_seriesC_scal_l.
            apply (ex_seriesC_le _ μ1); auto.
            intros; real_solver.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            apply Rmax_l.
          * rewrite -{2}(Rmult_1_r (μ1 a)).
            apply Rmult_le_compat_l; auto.
            apply Rmax_lub; first by lra.
            apply Rle_minus_l.
            apply (Rle_trans _ 1); last by real_solver.
            apply (Rle_trans _ (SeriesC (f a))); auto.
            apply SeriesC_le; auto.
            real_solver.
    }

    (*
        Now we instantiate the lifting definitions and use them to prove the
        inequalities
    *)

    rewrite /DPcoupl in Hcoup_S.
    apply Hcoup_S.
    + intro; split; first apply Rmax_l.
      apply Rmax_lub; first by lra.
      apply Rle_minus_l.
      apply (Rle_trans _ 1); last by real_solver.
      apply (Rle_trans _ (SeriesC (f a))); auto.
      apply SeriesC_le; auto; real_solver.
    + intro; split.
      * apply Rmin_glb; lra |.
        apply Rmult_le_pos.
        ** left. apply exp_pos.
        ** apply SeriesC_ge_0'; intro b'.
           specialize (Hh2pos b'); real_solver.
      * apply Rmin_l.

    + intros a b Rab.
      apply Rmin_glb; apply Rmax_lub; first by lra.
      * apply Rle_minus_l.
        apply (Rle_trans _ 1); last by real_solver.
        apply (Rle_trans _ (SeriesC (f a))); auto.
        apply SeriesC_le; auto; real_solver.
      * series.
        left.
        by apply exp_pos.
      * apply Rle_minus_l.
        by apply Hcoup_fg.
  Qed.
  *)


  (* The hypothesis (0 ≤ δ1) is not really needed, I just kept it for symmetry *)
  Lemma DPcoupl_dbind `{Countable A, Countable B, Countable A', Countable B'} (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (R : A B Prop) (S : A' B' Prop) ε1 ε2 δ1 δ2:
    (0 <= δ1) ->
    (0 <= δ2) ->
    ( a b, R a b DPcoupl (f a) (g b) S ε2 δ2) DPcoupl μ1 μ2 R ε1 δ1 DPcoupl (dbind f μ1) (dbind g μ2) S (ε1 + ε2) (δ1 + δ2).
  Proof.
    intros Hδ1 Hδ2 Hcoup_fg Hcoup_R.
    destruct (decide (δ2 <= 1)).
    2:{
      apply DPcoupl_1.
      lra.
    }
    intros h1 h2 Hh1pos Hh2pos Hh1h2S.
    etransitivity.
    {
      eapply (DPcoupl_dbind_adv_lhs f g μ1 μ2 R S ε1 ε2 δ1
                (SeriesC (λ (a:A), μ1 a * δ2) ) (λ (a:A), δ2)); eauto.
    }
    apply Rplus_le_compat_l.
    apply Rplus_le_compat_l.
    rewrite SeriesC_scal_r.
    rewrite -{2}(Rmult_1_l δ2).
    apply Rmult_le_compat; auto; lra.
  Qed.

 (* OLD Proof
    apply DPcoupl_dbind_adv_lhs.
    rewrite /pmf/=/dbind_pmf.
    (* To use the hypothesis that we have an R-ACoupling up to ε1 for μ1, μ2,
       we have to rewrite the sums in such a way as to isolate (the expectation
       of) a random variable X on the LHS and Y on the RHS, and ε1 on the
       RHS. *)

    (* First step: rewrite the LHS into a RV X on μ1. *)
    setoid_rewrite <- SeriesC_scal_r.
    rewrite <-(fubini_pos_seriesC (λ '(a,x), μ1 x * f x a * h1 a)).

    (* Boring Fubini sideconditions. *)
    2: { real_solver. }
    2: { intro a'.
         (* specialize (Hh1pos a'). *)
         apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         + apply Rmult_le_pos.
           * real_solver.
           * real_solver.
         + rewrite <- Rmult_1_r.
           rewrite Rmult_assoc.
           apply Rmult_le_compat_l; auto.
           rewrite <- Rmult_1_r.
           apply Rmult_le_compat; real_solver. }
    2: { setoid_rewrite SeriesC_scal_r.
         apply (ex_seriesC_le _ (λ a : A', SeriesC (λ x : A, μ1 x * f x a))); auto.
         + series.
         + apply (pmf_ex_seriesC (dbind f μ1)). }

    (* LHS: Pull the (μ1 b) factor out of the inner sum. *)
    assert (SeriesC (λ b : A, SeriesC (λ a : A', μ1 b * f b a * h1 a)) =
              SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a))) as ->.
    { setoid_rewrite <- SeriesC_scal_l. series. }

    (* Second step: rewrite the RHS into a RV Y on μ2. *)
    (* RHS: Fubini. *)
    rewrite <-(fubini_pos_seriesC (λ '(b,x), μ2 x * g x b * h2 b)).
    2: by series.
    2:{ intro b'.
        specialize (Hh2pos b').
        apply (ex_seriesC_le _ μ2) ; auto.
        intro b; split.
        - series.
        - do 2 rewrite <- Rmult_1_r. series. }
    2:{ setoid_rewrite SeriesC_scal_r.
        apply (ex_seriesC_le _ (λ a : B', SeriesC (λ b : B, μ2 b * g b a))); auto.
        - intros b'; specialize (Hh2pos b'); split.
          + apply Rmult_le_pos;  | lra.
            apply (pmf_pos ((dbind g μ2)) b').
          + rewrite <- Rmult_1_r.
            apply Rmult_le_compat_l; auto.
            * apply SeriesC_ge_0'. real_solver.
            * real_solver.
        - apply (pmf_ex_seriesC (dbind g μ2)). }

    (* RHS: Factor out (μ2 b) *)
    assert (SeriesC (λ b : B, SeriesC (λ a : B', μ2 b * g b a * h2 a))
            = SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a))) as ->.
    { apply SeriesC_ext; intro.
      rewrite <- SeriesC_scal_l.
      apply SeriesC_ext; real_solver. }

    rewrite -Rplus_assoc.
    apply Rle_minus_l.


    (* To construct Y, we want to push exp(ε2) into the inner sum on the RHS.
       We don't do this directly, because the resulting value might be larger
       than 1, but our assumption on the ε1 DP-coupling requires it to be
       valued in [0,1].
       Instead, we take min(1, exp(ε2) * (Σ(a:B') g b a * h2 a)).
       ALT: could use a more fine-grained min inside the sum?
     *)


    assert (exp (ε1) * SeriesC (λ b : B, μ2 b * (Rmin 1 (exp (ε2) * SeriesC (λ a : B', g b a * h2 a)))) + δ1
            <= exp (ε1 + ε2) * SeriesC (λ b : B, μ2 b * SeriesC (λ a : B', g b a * h2 a)) + δ1) as <-.
    {
       apply Rplus_le_compat_r.
       rewrite exp_plus.
       rewrite Rmult_assoc.
       rewrite -(SeriesC_scal_l _ (exp ε2)).
       apply Rmult_le_compat_l; left; apply exp_pos |.
       apply SeriesC_le.
       - intros b; split.
         + apply Rmult_le_pos; auto.
           apply Rmin_glb; lra |.
           apply Rmult_le_pos; left; apply exp_pos |.
           apply SeriesC_ge_0'.
           real_solver.
         + rewrite Rmult_min_distr_l; auto.
           etrans; apply Rmin_r | lra.
       - apply ex_seriesC_scal_l.
         apply (ex_seriesC_le _ μ2); auto.
         intro b; split.
         + apply Rmult_le_pos; auto.
           apply SeriesC_ge_0'.
           intro; apply Rmult_le_pos; auto.
           apply Hh2pos.
         + rewrite <- Rmult_1_r.
           apply Rmult_le_compat_l; auto.
           apply (Rle_trans _ (SeriesC (g b))); auto.
           apply SeriesC_le; auto.
           real_solver.
    }

    assert (
        SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a)) - δ2
        <= SeriesC (λ b : A, μ1 b * Rmax 0 (SeriesC (λ a : A', f b a * h1 a) - δ2))
      ) as ->.
    {
      apply (Rle_trans _ (SeriesC (λ b : A, μ1 b * SeriesC (λ a : A', f b a * h1 a) - μ1 b * δ2))).
      - rewrite SeriesC_minus.
        + apply Rplus_le_compat_l.
          apply Ropp_le_contravar.
          rewrite SeriesC_scal_r.
          real_solver.
        + apply (ex_seriesC_le _ μ1); auto.
         intro a; split.
         * apply Rmult_le_pos; auto.
           apply SeriesC_ge_0'.
           intro; apply Rmult_le_pos; auto.
           apply Hh1pos.
         * rewrite <- Rmult_1_r.
           apply Rmult_le_compat_l; auto.
           apply (Rle_trans _ (SeriesC (f a))); auto.
           apply SeriesC_le; auto.
           real_solver.
        + apply ex_seriesC_scal_r; auto.
      - apply SeriesC_le'.
        + intros a.
          rewrite -Rmult_minus_distr_l.
          apply Rmult_le_compat_l; auto.
          apply Rmax_r.
        + apply ex_seriesC_plus.
          * apply (ex_seriesC_le _ μ1); auto.
            intro a; split.
            ** apply Rmult_le_pos; auto.
               apply SeriesC_ge_0'.
               intro; apply Rmult_le_pos; auto.
               apply Hh1pos.
            ** rewrite <- Rmult_1_r.
               apply Rmult_le_compat_l; auto.
               apply (Rle_trans _ (SeriesC (f a))); auto.
               apply SeriesC_le; auto.
               real_solver.
          * setoid_rewrite Ropp_mult_distr_r.
            apply ex_seriesC_scal_r; auto.
        + apply (ex_seriesC_le _ μ1); auto.
          intros a; split.
          * apply Rmult_le_pos; auto.
            apply Rmax_l.
          * rewrite -{2}(Rmult_1_r (μ1 a)).
            apply Rmult_le_compat_l; auto.
            apply Rmax_lub; first by lra.
            apply Rle_minus_l.
            apply (Rle_trans _ 1); last by lra.
            apply (Rle_trans _ (SeriesC (f a))); auto.
            apply SeriesC_le; auto.
            real_solver.
    }

    (*
        Now we instantiate the lifting definitions and use them to prove the
        inequalities
    *)

    rewrite /DPcoupl in Hcoup_R.
    apply Hcoup_R.
    + intro; split; first apply Rmax_l.
      apply Rmax_lub; first by lra.
      apply  (Rle_trans _ (SeriesC (λ a0 : A', f a a0 * h1 a0))); first by lra.
      apply (Rle_trans _ (SeriesC (f a))); auto.
      apply SeriesC_le; auto.
      intro a'.
      specialize (Hh1pos a'); real_solver.
    + intro; split.
      * apply Rmin_glb; lra |.
        apply Rmult_le_pos.
        ** left. apply exp_pos.
        ** apply SeriesC_ge_0'; intro b'.
           specialize (Hh2pos b'); real_solver.
      * apply Rmin_l.

    + intros a b Rab.
      apply Rmin_glb; apply Rmax_lub; first by lra.
      * apply (Rle_trans _ (SeriesC (λ a0 : A', f a a0 * h1 a0))); first by lra.
        apply (Rle_trans _ (SeriesC (f a))); auto.
        apply SeriesC_le; auto.
        intro a'.
        real_solver.
      * series.
        left.
        by apply exp_pos.
      * apply Rle_minus_l.
        by apply Hcoup_fg.
  Qed.
  *)


  Lemma DPcoupl_dbind' `{Countable A, Countable B, Countable A', Countable B'} (ε1 ε2 ε : R) (δ1 δ2 δ : R)
    (f : A distr A') (g : B distr B')
    (μ1 : distr A) (μ2 : distr B) (R : A B Prop) (S : A' B' Prop) :
    ε = ε1 + ε2
    0 <= δ1 ->
    0 <= δ2 ->
    δ = δ1 + δ2
    ( a b, R a b DPcoupl (f a) (g b) S ε2 δ2)
    DPcoupl μ1 μ2 R ε1 δ1
    DPcoupl (dbind f μ1) (dbind g μ2) S ε δ.
  Proof. intros -> ? ? ->. by eapply DPcoupl_dbind. Qed.

  Lemma DPcoupl_map `{Countable A, Countable B, Countable A', Countable B'}
    (f : A A') (g : B B') (μ1 : distr A) (μ2 : distr B) (R : A' B' Prop) ε δ:
    (0 <= ε) -> (0 <= δ) ->
    DPcoupl μ1 μ2 (λ a a', R (f a) (g a')) ε δ DPcoupl (dmap f μ1) (dmap g μ2) R ε δ.
  Proof.
    intros Hleq1 Hleq2 Hcoupl. rewrite /dmap.
    rewrite -(Rplus_0_r ε) -(Rplus_0_r δ).
    eapply (DPcoupl_dbind _ _ _ _ (λ (a : A) (a' : B), R (f a) (g a')) _); auto; [lra|].
    intros a b Hab.
    by eapply DPcoupl_dret.
  Qed.


  Lemma DPcoupl_map_inv `{Countable A, Countable B, Countable A', Countable B'}
    (μ1 : distr A) (μ2 : distr B) (f1 : A A') (f2 : B B') ψ ε δ :
    0 <= ε ->
    DPcoupl (dmap f1 μ1) (dmap f2 μ2) ψ ε δ ->
    DPcoupl μ1 μ2 (fun x y => ψ (f1 x) (f2 y)) ε δ.
  Proof.
    rewrite /DPcoupl.
    intro He.
    intros H3 f g H4 H5 H6.
    assert ( a, f a <= 1). {
      intros. by destruct (H4 a).
    }
    assert ( b, 0 <= g b). {
      intros. by destruct (H5 b).
    }
    set F := sup_fiber f1 f H7.
    set G := inf_fiber f2 g H8.
    epose proof (H3 F G _ _ _).
    Unshelve.
    2 : {
      unfold F.
      apply sup_fiber_range.
    }
    2 : {
      unfold G.
      apply inf_fiber_range.
    }
    2 : {
      intros a' b'.
      destruct (ExcludedMiddle ( a, f1 a = a')). 2 : {
                                                  pose proof (not_exists_forall_not _ _ H9) as H9'.
                                                  simpl in H9'.
                                                  intros.
                                                  rewrite /F sup_fiber_empty; auto.
                                                  epose proof (inf_fiber_range _ _ _ _ ) as [??].
                                                  apply H11.
                                                }
                                                destruct (ExcludedMiddle ( b, b' = f2 b)). 2 : {
                                                                                            pose proof (not_exists_forall_not _ _ H10) as H10'.
                                                                                            simpl in H10'.
                                                                                            intros.
                                                                                            rewrite /G inf_fiber_empty; auto.
                                                                                            epose proof (sup_fiber_range _ _ _ _ ) as [??].
                                                                                            apply H13.
                                                                                          }
                                                                                          destruct H9 as [a H9], H10 as [b H10].
      intros.
      eapply sup_fiber_is_lub.
      move => x [Hx | Hx]; subst; eauto.
      {
        epose proof (inf_fiber_range _ _ _ _ ) as [??].
        apply H9.
      }
      destruct Hx as [a0 [Ha0 Ha1]]; subst.
      eapply inf_fiber_is_glb.
      move => x [Hx | Hx]; subst; eauto.
      destruct Hx as [b0 [Hb0 Hb1]]; subst.
      apply H6.
      by rewrite Ha0 Hb0.
    }
    epose proof (Expval_dmap μ1 f1 F _ _).
    epose proof (Expval_dmap μ2 f2 G _ _).
    unfold Expval in *.
    rewrite H11 H10 in H9.
    trans (SeriesC (λ a : A, μ1 a * (F f1) a)).
    {
      apply SeriesC_le.
      2: { apply ex_expval_unit. intros. simpl. by apply sup_fiber_range. }
      intros.
      split.
      - apply Rmult_le_pos; real_solver.
      - apply Rmult_le_compat_l; auto. simpl.
        apply sup_fiber_is_lub. right. econstructor; eauto.
    }
    etrans.
    { apply H9. }
    apply Rplus_le_compat_r.
    apply Rmult_le_compat_l.
    { specialize (exp_pos ε). lra. }
    apply SeriesC_le.
    2: by apply ex_expval_unit.
    intros.
    split.
    - apply Rmult_le_pos; try real_solver.
      epose proof (inf_fiber_range _ _ _ _ ) as [??].
      apply H12.
    - apply Rmult_le_compat_l; auto. simpl.
      apply inf_fiber_is_glb. right. econstructor; eauto.
      Unshelve.
      + intros. epose proof (sup_fiber_range _ _ _ _ ) as [??].
        apply H10.
      + apply ex_expval_unit. intros. simpl. by apply sup_fiber_range.
      + intros. epose proof (inf_fiber_range _ _ _ _ ) as [??].
        apply H11.
      + apply ex_expval_unit. intros. simpl. by apply inf_fiber_range.
  Qed.


  Lemma DPcoupl_mass_leq `{Countable A, Countable B} (μ1 : distr A) (μ2 : distr B) (R : A B Prop) ε δ :
    DPcoupl μ1 μ2 R ε δ SeriesC μ1 <= exp ε * SeriesC μ2 + δ.
  Proof.
    intros Hcoupl.
    rewrite /DPcoupl in Hcoupl.
    rewrite -(Rmult_1_r (SeriesC μ1)).
    rewrite -(Rmult_1_r (SeriesC μ2)).
    do 2 rewrite -SeriesC_scal_r.
    apply Hcoupl; intros; lra.
  Qed.

  Lemma DPcoupl_eq_elim `{Countable A} (μ1 μ2 : distr A) ε δ :
    DPcoupl μ1 μ2 (=) ε δ forall a, μ1 a <= exp ε * μ2 a + δ.
  Proof.
    intros Hcoupl a.
    rewrite /DPcoupl in Hcoupl.
    rewrite -(SeriesC_singleton a (μ1 a)).
    rewrite -(SeriesC_singleton a (μ2 a)).
    assert (SeriesC (λ n : A, if bool_decide (n = a) then μ1 a else 0)
            = SeriesC (λ n : A, μ1 n * (if bool_decide (n = a) then 1 else 0))) as ->.
    {
      apply SeriesC_ext; real_solver.
    }
    assert (SeriesC (λ n : A, if bool_decide (n = a) then μ2 a else 0)
            = SeriesC (λ n : A, μ2 n * (if bool_decide (n = a) then 1 else 0))) as ->.
    {
      apply SeriesC_ext; real_solver.
    }
    apply Hcoupl; real_solver.
  Qed.

  Lemma DPcoupl_eq_elim_dp `{Countable A} (μ1 μ2 : distr A) ε δ:
    DPcoupl μ1 μ2 (=) ε δ
    forall (P : A -> Prop),
    SeriesC (λ a : A, if bool_decide (P a) then μ1 a else 0) <=
    exp ε * SeriesC (λ a : A, if bool_decide (P a) then μ2 a else 0) + δ.
  Proof.
    intros Hcoupl P.
    rewrite /DPcoupl in Hcoupl.
    assert (SeriesC (λ a : A, if bool_decide (P a) then μ1 a else 0)
            = SeriesC (λ a : A, μ1 a * (if bool_decide (P a) then 1 else 0))) as ->.
    { apply SeriesC_ext; real_solver. }
    assert (SeriesC (λ a : A, if bool_decide (P a) then μ2 a else 0)
            = SeriesC (λ a : A, μ2 a * (if bool_decide (P a) then 1 else 0))) as ->.
    { apply SeriesC_ext; real_solver. }
    apply Hcoupl; real_solver.
  Qed.

  (** A DP coupling with [δ = 0] is an [Mcoupl] with the same [ε]. *)
  Lemma DPcoupl_to_Mcoupl `{Countable A, Countable B} (μ1 : distr A) (μ2 : distr B) Q ε :
    DPcoupl μ1 μ2 Q ε 0 -> Mcoupl μ1 μ2 Q ε.
  Proof.
    intros Hcoupl f g Hf Hg HQ.
    (* Introduce a [+ 0] (via [Rplus_0_r]) so that the conclusion of
       [Hcoupl], which carries [+ 0] for [δ = 0], applies directly. *)
    rewrite <- Rplus_0_r.
    by apply Hcoupl.
  Qed.

  Lemma Mcoupl_to_DPcoupl `{Countable A, Countable B} (μ1 : distr A) (μ2 : distr B) Q ε :
    Mcoupl μ1 μ2 Q ε -> DPcoupl μ1 μ2 Q ε 0.
  Proof.
    intros Hcoupl f g Hf Hg HQ.
    etransitivity; first by apply Hcoupl.
    real_solver.
  Qed.

  Lemma DPcoupl_to_ARcoupl `{Countable A, Countable B} (μ1 : distr A) (μ2 : distr B) Q δ :
    DPcoupl μ1 μ2 Q 0 δ -> ARcoupl μ1 μ2 Q δ.
  Proof.
    intros Hcoupl f g Hf Hg HQ.
    etransitivity; first by apply Hcoupl.
    rewrite exp_0.
    real_solver.
  Qed.

  Lemma ARcoupl_to_DPcoupl `{Countable A, Countable B} (μ1 : distr A) (μ2 : distr B) Q δ :
    ARcoupl μ1 μ2 Q δ -> DPcoupl μ1 μ2 Q 0 δ.
  Proof.
    intros Hcoupl f g Hf Hg HQ.
    etransitivity; first by apply Hcoupl.
    rewrite exp_0.
    real_solver.
  Qed.

  Lemma DPcoupl_to_UB `{Countable A, Countable B} (μ1 : distr A) (μ2 : distr B) (P : A -> Prop) (δ : R) :
    DPcoupl μ1 μ2 (λ a _, P a) 0 δ -> pgl μ1 P δ.
  Proof.
    intros Hcoupl.
    eapply ARcoupl_to_UB.
    eapply DPcoupl_to_ARcoupl.
    apply Hcoupl.
  Qed.

End couplings_theory.

Section DPcoupl.
  Context `{Countable A, Countable B}.
  Variable (μ1 : distr A) (μ2 : distr B).

  (** The zero distribution on [A] is DP-coupled with any [μ], for any
      relation and any [ε], provided [δ] is non-negative. *)
  Lemma DPcoupl_dzero (μ : distr B) φ ε δ :
    (0 <= δ) ->
    DPcoupl (dzero (A:=A)) μ φ ε δ.
  Proof.
    intros Hleq ?????. rewrite SeriesC_scal_l. field_simplify.
    (* The remaining goals are non-negativity facts about the terms of the
       right-hand side. *)
    apply Rle_plus_l; auto.
    apply Rmult_le_pos. 1: left ; apply exp_pos.
    apply SeriesC_ge_0'. intros ; apply Rmult_le_pos => //. apply Hg.
  Qed.

  (** Two distributions of total mass 1 are DP-coupled for the trivial
      relation with [ε = δ = 0]. The chain of bounds goes through
      [LubC f] and [GlbC g]: every [f a] is below every [g b], because the
      relation holds everywhere. *)
  Lemma DPcoupl_trivial :
    SeriesC μ1 = 1 ->
    SeriesC μ2 = 1 ->
    DPcoupl μ1 μ2 (λ _ _, True) 0 0.
  Proof.
    intros Hμ1 Hμ2 f g Hf Hg Hfg.
    destruct (LubC_correct f) as [H1 H2].
    destruct (GlbC_correct g) as [H3 H4].
    rewrite Rplus_0_r.
    (* Step 1: [Σ μ1 a * f a <= Σ μ1 a * LubC f]. *)
    apply (Rle_trans _ (SeriesC (λ a : A, μ1 a * (real (LubC f))))).
    {
      apply SeriesC_le'; auto.
      - intro a.
        apply Rmult_le_compat_l; auto.
        apply rbar_le_finite; auto.
        apply (Rbar_le_sandwich (f a) 1); auto.
        apply H2; auto.
        intro; apply Hf.
      - apply (ex_seriesC_le _ μ1); auto.
        intro a; specialize (Hf a); real_solver.
      - apply ex_seriesC_scal_r; auto.
    }
    (* Trade the mass of [μ1] for the mass of [μ2] (both are 1). *)
    rewrite SeriesC_scal_r Hμ1 -Hμ2 -SeriesC_scal_r.
    (* Step 2: [Σ μ2 b * LubC f <= Σ μ2 b * GlbC g]. *)
    apply (Rle_trans _ (SeriesC (λ b : B, μ2 b * (real (GlbC g))))).
    {
      (* We step from [LubC f] to [GlbC g] here because it is easier if
         we have an inhabitant of B *)

      apply SeriesC_le'; auto.
      - intro b.
        apply Rmult_le_compat_l; auto.
        apply rbar_le_finite; auto.
        + apply (Rbar_le_sandwich 0 (g b)); auto.
          apply H4; auto.
          apply Hg.
        + apply H4.
          intro b'.
          destruct (LubC f) eqn:Hlub.
          * rewrite <- Hlub; simpl.
            apply finite_rbar_le; auto.
            { apply is_finite_correct; eauto. }
            rewrite Hlub.
            apply H2; intro.
            apply Hfg; auto.
          * apply Hg.
          * apply Hg.
      - apply ex_seriesC_scal_r; auto.
      - apply ex_seriesC_scal_r; auto.
    }
    (* Step 3: [Σ μ2 b * GlbC g <= Σ μ2 b * g b]. *)
    rewrite exp_0.
    rewrite Rmult_1_l.
    apply SeriesC_le'; auto.
    - intro b.
      apply Rmult_le_compat_l; auto.
      apply finite_rbar_le.
      + apply (Rbar_le_sandwich 0 (g b)); auto.
        apply H4.
        apply Hg.
      + apply H3.
    - apply ex_seriesC_scal_r; auto.
    - apply (ex_seriesC_le _ μ2); auto.
      intro b; specialize (Hg b); real_solver.
  Qed.

  (** Variant of [DPcoupl_trivial] in which only [μ2] is required to have
      total mass 1; instead, inhabitants [a0 : A] and [b0 : B] are assumed.
      They are used when reasoning about [LubC f] and [GlbC g]. *)
  Lemma DPcoupl_trivial_R :
    SeriesC μ2 = 1 -> A -> B ->
    DPcoupl μ1 μ2 (λ _ _, True) 0 0.
  Proof.
    intros Hμ2 a0 b0 f g Hf Hg Hfg.
    destruct (LubC_correct f) as [H1 H2].
    destruct (GlbC_correct g) as [H3 H4].
    rewrite Rplus_0_r.
    apply (Rle_trans _ (SeriesC (λ a : A, μ1 a * (real (LubC f))))).
    {
      apply SeriesC_le'; auto.
      - intro a.
        apply Rmult_le_compat_l; auto.
        apply rbar_le_finite; auto.
        apply (Rbar_le_sandwich (f a) 1); auto.
        apply H2; auto.
        intro; apply Hf.
      - apply (ex_seriesC_le _ μ1); auto.
        intro a; specialize (Hf a); real_solver.
      - apply ex_seriesC_scal_r; auto.
    }
    rewrite SeriesC_scal_r.
    rewrite exp_0 Rmult_1_l.
    apply (Rle_trans _ (SeriesC (λ b : B, μ2 b * (real (GlbC g))))); last first.
    {
      apply SeriesC_le'; auto.
      - intro b.
        apply Rmult_le_compat_l; auto.
        apply finite_rbar_le.
        + apply (Rbar_le_sandwich 0 (g b)); auto.
          apply H4.
          apply Hg.
        + apply H3.
      - apply ex_seriesC_scal_r; auto.
      - apply (ex_seriesC_le _ μ2); auto.
        real_solver.
    }
    rewrite SeriesC_scal_r.
    (* Remaining: [SeriesC μ1 * LubC f <= SeriesC μ2 * GlbC g], proved
       factor by factor. *)
    apply Rmult_le_compat.
    - auto.
    - etrans. 1: apply (Hf a0).
      destruct (LubC f) eqn:lubf.
      + replace (f a0) with (real (Rbar.Finite (f a0))) ; simpl ; auto.
        eapply H1.
      + exfalso.
        specialize (H2 1).
        simpl in H2. apply H2. intros. apply Hf.
      + exfalso. apply H1. done.
    - transitivity 1; auto.
      lra.
    - assert (forall uf : R, (forall a, Rle (f a) uf) -> Rle (real (LubC f)) uf).
      {
        clear -Hf H1 H2 a0.
        destruct (LubC f) eqn:lubf.
        + intros. simpl. simpl in H1. specialize (H2 (Rbar.Finite uf)).
          apply H2. done.
        + simpl. intros.
          specialize (H2 (Rbar.Finite uf)).
          simpl in H2. exfalso. apply H2. done.
        + intros. simpl. simpl in H1. exfalso. apply H1. done.
      }
      apply H5.
      assert (forall lg : R, (forall b, Rle lg (g b)) -> Rle lg (real (GlbC g))).
      {
        clear -Hg H3 H4 b0.
        destruct (GlbC g) eqn:glbg.
        + intros. simpl. simpl in H3. specialize (H4 (Rbar.Finite lg)).
          apply H4. done.
        + intros. simpl. simpl in H3. exfalso. apply H3. done.
        + simpl. intros.
          specialize (H4 (Rbar.Finite lg)).
          simpl in H4. exfalso. apply H4. done.
      }
      intros.
      apply H6.
      intros. apply Hfg. done.
  Qed.

  (** Any DP coupling for [R] is also one for [R] restricted to pairs in the
      supports of [μ1] and [μ2]. The proof replaces [f] by [0] outside the
      support of [μ1] and [g] by [1] outside the support of [μ2]; this does
      not change either weighted sum. *)
  Lemma DPcoupl_pos_R R ε δ :
    DPcoupl μ1 μ2 R ε δ -> DPcoupl μ1 μ2 (λ a b, R a b /\ μ1 a > 0 /\ μ2 b > 0) ε δ.
  Proof.
    intros Hμ1μ2 f g Hf Hg Hfg.
    assert (SeriesC (λ a : A, μ1 a * f a) =
              SeriesC (λ a : A, μ1 a * (if bool_decide (μ1 a > 0) then f a else 0))) as ->.
    { apply SeriesC_ext; intro a.
      case_bool_decide; auto.
      assert (0 <= μ1 a); auto.
      assert (μ1 a = 0); nra.
    }
    assert (SeriesC (λ b : B, μ2 b * g b) =
              SeriesC (λ b : B, μ2 b * (if bool_decide (μ2 b > 0) then g b else 1))) as ->.
    { apply SeriesC_ext; intro b.
      case_bool_decide; auto.
      assert (0 <= μ2 b); auto.
      assert (μ2 b = 0); nra.
    }
    apply Hμ1μ2; auto.
    - intro a; specialize (Hf a); real_solver.
    - intro b; specialize (Hg b); real_solver.
    - intros a b Rab.
      specialize (Hf a).
      specialize (Hg b).
      specialize (Hfg a b).
      real_solver.
  Qed.

End DPcoupl.

(** Composition with an equality coupling on the left. Suppose [μ1] and
    [μ2] are coupled for [=] with budget [ε1], and [μ2] and [μ3] are coupled
    for [R] with budget [ε2]. Then [μ1] and [μ3] are coupled for [R] with
    budget [ε1 + ε2]. All additive slacks are [0]. *)
Lemma DPcoupl_eq_trans_l `{Countable A, Countable B} μ1 μ2 μ3 (R: A -> B -> Prop) ε1 ε2:
  (0 <= ε1) ->
  (0 <= ε2) ->
  DPcoupl μ1 μ2 (=) ε1 0 ->
  DPcoupl μ2 μ3 R ε2 0 ->
  DPcoupl μ1 μ3 R (ε1 + ε2) 0.
Proof.
  intros Hleq1 Hleq2 Heq HR f g Hf Hg Hfg.
  specialize (HR f g Hf Hg Hfg).
  (* Bound the [μ1]-sum via [Heq], then the resulting [μ2]-sum via [HR]. *)
  eapply Rle_trans; [apply Heq | ]; auto.
  - intros ? ? ->; lra.
  - do 2 rewrite Rplus_0_r.
    (* [exp (ε1 + ε2) = exp ε1 * exp ε2]; compare factor by factor. *)
    rewrite exp_plus Rmult_assoc.
    apply Rmult_le_compat => //.
    1: etrans. 2: by apply exp_pos_ge_1.
    2: apply SeriesC_ge_0'.
    + real_solver.
    + real_solver.
    + rewrite <- Rplus_0_r.
      apply HR.
Qed.

(** Composition with an equality coupling on the right. Suppose [μ1] and
    [μ2] are coupled for [R] with budget [ε1], and [μ2] and [μ3] are coupled
    for [=] with budget [ε2]. Then [μ1] and [μ3] are coupled for [R] with
    budget [ε1 + ε2]. All additive slacks are [0]. *)
Lemma DPcoupl_eq_trans_r `{Countable A, Countable B} μ1 μ2 μ3 (R: A -> B -> Prop) ε1 ε2 :
  (0 <= ε1) ->
  (0 <= ε2) ->
  DPcoupl μ1 μ2 R ε1 0 ->
  DPcoupl μ2 μ3 (=) ε2 0 ->
  DPcoupl μ1 μ3 R (ε1 + ε2) 0.
Proof.
  intros Hleq1 Hleq2 HR Heq f g Hf Hg Hfg.
  specialize (HR f g Hf Hg Hfg).
  (* First step: the [R]-coupling [HR]. *)
  eapply Rle_trans ; eauto.
  do 2 rewrite Rplus_0_r.
  (* [exp (ε1 + ε2) = exp ε1 * exp ε2]; the common factor [exp ε1] is
     non-negative, and the rest follows from [Heq]. *)
  rewrite exp_plus Rmult_assoc.
  apply Rmult_le_compat_l.
  1: etrans ; [| apply exp_pos_ge_1]. 1: lra. 1: lra.
  rewrite <- Rplus_0_r.
  apply Heq; eauto.
  intros; simplify_eq; lra.
Qed.

(** From pointwise to global equality couplings. Suppose that, for every
    [x], [μ] and [μ'] are DP-coupled for the relation relating only [x] to
    [x], with slack [δ x]. Then [μ] and [μ'] are coupled for [eq] with slack
    [SeriesC δ]. The proof compares the two sums term by term. At each [x]
    it uses [pw x] with test functions that vanish outside [x]. *)
Lemma DPcoupl_pweq `{Countable A} (μ μ' : distr A) ε δ (εpos : 0 <= ε) (δpos : forall a, 0 <= δ a)
                  (δconv : ex_seriesC δ)
                  (pw : forall x, DPcoupl μ μ' (λ a a', a = x /\ a' = x) ε (δ x)) :
  DPcoupl μ μ' eq ε (SeriesC δ).
Proof.
  intros ?????.
  (* Merge [exp ε * SeriesC _ + SeriesC δ] into a single sum. *)
  rewrite -SeriesC_scal_l.
  rewrite -SeriesC_plus; auto.
  2:{
    apply ex_seriesC_scal_l.
    apply (ex_seriesC_le _ μ'); auto.
    real_solver.
  }
  (* Compare the two sums term by term. *)
  eapply SeriesC_le.
  2:{
    apply ex_seriesC_plus; auto.
    apply ex_seriesC_scal_l.
    apply (ex_seriesC_le _ μ') => //.
    intros b ; specialize (Hg b) ; real_solver. }
  intros x.
  split. 1: apply Rmult_le_pos => // ; apply Hf.
  (* View the term at [x] as a sum with a single non-zero summand. *)
  cut (SeriesC (λ a, if bool_decide (a = x) then μ a * f x else 0) <=
         SeriesC (λ a, if bool_decide (a = x) then exp ε * μ' a * g x + δ x else 0)).
  { intros h. rewrite !SeriesC_singleton_dependent in h. rewrite -Rmult_assoc ; done. }
  specialize (pw x).
  replace (SeriesC (λ a : A, if bool_decide (a = x) then exp ε * μ' a * g x + δ x else 0))
    with (exp ε * SeriesC (λ a : A, if bool_decide (a = x) then μ' a * g x else 0) + δ x).
  2: do 2 rewrite SeriesC_singleton_dependent; lra.
  (* Test functions for [pw x], equal to [f x] and [g x] at [x] and 0
     elsewhere. *)
  set (f' := λ a, if bool_decide (a = x) then f x else 0).
  set (g' := λ a, if bool_decide (a = x) then g x else 0).
  opose proof (pw f' g' _ _ _) as pw'.
  1,2: intros ; subst f' g' => /= ; case_bool_decide ; (apply Hf || apply Hg || lra).
  {
    intros. subst f' g' => /=. repeat case_bool_decide. 4: done.
    - subst. by apply Hfg.
    - subst. exfalso. eauto.
    - apply Hg.
  }
  etrans.
  1: etrans.
  2: exact pw'. 1,2: right ; subst f' g'.
  2: apply Rplus_eq_compat_r; apply Rmult_eq_compat_l.
  all: apply SeriesC_ext ; intros ; case_bool_decide ; field.
Qed.

(** Variant of [DPcoupl_pweq] up to maps. Suppose that, for every [x], [μ]
    and [μ'] are coupled for [λ a b, f a = x /\ g b = x] with slack [δ x].
    Then they are coupled for [λ a b, f a = g b] with slack [SeriesC δ].
    The proof applies [DPcoupl_pweq] to [dmap f μ] and [dmap g μ'], then
    pulls the result back with [DPcoupl_map_inv]. *)
Lemma DPcoupl_pweq' `{Countable A} `{Countable B} `{Countable X}
  (μ : distr A) (μ' : distr B) (f:A -> X) g
  ε δ (εpos : 0 <= ε) (δpos : forall a, 0 <= δ a)
  (δconv : ex_seriesC δ)
  (pw : forall x, DPcoupl μ μ' (λ a b, f a = x /\ g b = x) ε (δ x)) :
  DPcoupl μ μ' (λ a b, f a = g b) ε (SeriesC δ).
Proof.
  cut (DPcoupl (dmap f μ) (dmap g μ') eq ε (SeriesC δ)); first by intros ?%DPcoupl_map_inv.
  apply DPcoupl_pweq; [done..|].
  (* Each pointwise coupling is pushed forward along [f] and [g]. *)
  intros.
  by apply DPcoupl_map.
Qed.