% Quasi-Newton 1 (QN1) method for acceleration of EM
% max 15 iterations, keep the parameters with the minimum KLD, quadgk AbsTol 1e-5, RelTol 1e-3
% alpha = beta * dim;

function [est_sigma est_beta est_epsilon KLD] = csm_em_fit_psi_confinement(dim, pmf_v, s_v, in_sigma, epsilon_bd, verbose_flag)
% CSM_EM_FIT_PSI_CONFINEMENT  Fit (sigma, beta, epsilon) to a target pmf with
% EM accelerated by the quasi-Newton QN1 update, constrained to the set Psi.
%
% Psi (see "Psi range" below): beta_min <= beta <= beta_max,
% sigma >= sigma_min, epsilon_min <= epsilon <= epsilon_bd.
%
% After at most overall_iteration_limit EM/QN1 iterations, the parameter triple
% with the lowest KLD seen is kept. If its epsilon lies within delta_eps of
% epsilon_bd, a golden-section search over beta with epsilon fixed at
% epsilon_bd (sigma refit by csm_em_fit_sigma_only) may replace it.
%
% Inputs:
%   dim          - dimension used by the model integrands (dim>1 adds the
%                  s^(dim-1) factor in fun_g0/fun_e0; alpha = beta*dim).
%   pmf_v        - target probability mass at the points of s_v. It is combined
%                  element-wise with 1-by-length(s_v) row vectors, so it is
%                  presumably a row vector of the same length (TODO confirm).
%   s_v          - points at which the model is evaluated.
%   in_sigma     - initial sigma (raised to sigma_min if smaller).
%   epsilon_bd   - upper bound on epsilon (the initial epsilon is capped at it).
%   verbose_flag - optional; 1 prints per-iteration diagnostics (default 0).
%
% Outputs:
%   est_sigma, est_beta, est_epsilon - lowest-KLD parameters found.
%   KLD - pmf_cross_entropy(pmf_v, est_pmf) - pmf_entropy(pmf_v), where est_pmf
%         is the per-point integral f0 over w, normalised to sum 1 over s_v.

if(~exist('verbose_flag','var'))
  verbose_flag = 0; %no detailed output
end

% ===========  iteration parameters  =================

overall_iteration_limit = 15; % maximum number of EM/QN1 iterations

%initial theta (the initial sigma comes from the in_sigma argument)
in_beta = 1.0;
in_epsilon = 0.001;

%Psi range (fun_range uses strict comparisons, so values exactly on a bound are inside Psi)
beta_min = 1;
beta_max = 5;
epsilon_min = 0.00001;
sigma_min = 0.00001;

%convergence criteria
rvar_converge_threshold = 1e-3; % relative change of rvar = beta*dim*sigma^2 between iterates
KLD_threshold = 1e-5;

%epsilon-bounded estimation criteria
delta_eps = 1e-3; % boundary search runs when |est_epsilon - epsilon_bd| < delta_eps

%other parameters
beta_min_step_log = 13; % beta search step sizes are 2^0, 2^-1, ..., 2^-13
max_epsilon_step = 0.1; % cap on |delta epsilon| when a proposed step leaves Psi
max_beta_step = 1;      % cap on |delta beta| when a proposed step leaves Psi
ATOL = 1e-5; % quadgk tolerances for the per-sample integrals over w
RTOL = 1e-3;
intital_iteration = 1; % number of leading iterations that take the plain EM step
boudary_check_iteration = 6; % golden-section loop runs for 2:boudary_check_iteration

% The starting point must lie inside Psi. At iteration 1 the proposed step is
% zero (g starts at 0), so if the start were outside Psi the step repair below
% would divide by zero step components and turn every parameter into NaN
% (e.g. whenever epsilon_bd < in_epsilon).
in_epsilon = min(in_epsilon, epsilon_bd);
in_sigma = max(in_sigma, sigma_min);
% ===========  function definition  =================
% Integrands over the latent variable w, integrated over [est_epsilon, 1] below.
% fun_g0: base integrand; its integral f0 at each s is normalised over s_v
%         into the model pmf used for the KLD.
if dim>1
  fun_g0 = @(s,w,sigma,beta) (s ./ sigma ./ (2*(dim-1)).^0.5).^(dim-1) ./ sigma .* exp(-s.^2 ./ sigma.^2 / 2 .* w ./ (w+1)) .* (((w.^(-w) .* exp(w-1)).^beta) ./ (w+1).^0.5).^dim;
else
  fun_g0 = @(s,w,sigma,beta) 1 ./ sigma .* exp(-s.^2 ./ sigma.^2 / 2 .* w ./ (w+1)) .* (((w.^(-w) .* exp(w-1)).^beta) ./ (w+1).^0.5).^dim;
end

% fun_g1: g0 weighted by s^2*w/(w+1); its integral drives the sigma update.
fun_g1 = @(s,w,sigma,beta) s.^2 .* w ./ (w+1) .* fun_g0(s,w,sigma,beta);
% fun_g2: g0 weighted by w*(1-log(w)); its integral drives the beta update.
fun_g2 = @(s,w,sigma,beta) w .* (1-log(w)) .* fun_g0(s,w,sigma,beta);

% fun_N: beta-dependent weight over w; its integral N normalises fun_H.
% fun_H: w*(1-log(w)) under the normalised fun_N weight; its integral is H(beta).
fun_N = @(w,beta) ((w.^(-w) .* exp(w-1)).^beta ./ w.^(1/2)).^dim ;
fun_H = @(w,beta,N) ((w.^(-w) .* exp(w-1)).^beta ./ w.^(1/2)).^dim .*w .* (1-log(w))/N;

% fun_e0: fun_g0 without its trailing beta/dim power factor; evaluated at
% w = est_epsilon in the epsilon update.
if dim>1
  fun_e0 = @(s,w,sigma,beta) (s ./ sigma ./ (2*(dim-1)).^0.5).^(dim-1) ./ sigma .* exp(-s.^2 ./ sigma.^2 / 2 .* w ./ (w+1));
else
  fun_e0 = @(s,w,sigma,beta) 1 ./ sigma .* exp(-s.^2 ./ sigma.^2 / 2 .* w ./ (w+1));
end

% fun_range(theta), theta = [beta; sigma; epsilon]: true when theta is OUTSIDE Psi.
fun_range = @(theta) (theta(1)>beta_max) || (theta(1)<beta_min) || (theta(2)<sigma_min) || (theta(3)<epsilon_min) || (theta(3)>epsilon_bd);


% ===========  Iteration begins  =================
% theta = [beta; sigma; epsilon] is the current point.
% g is the EM step M(theta)-theta computed at theta (zero before iteration 1).
% A starts at -I, for which the QN1 step -A*g equals the plain EM step g.
est_sigma = in_sigma;
est_beta = in_beta;
est_epsilon = in_epsilon;
theta = [est_beta est_sigma est_epsilon]';
A = -eye(3);
g = [0 0 0]';

% best evaluated parameters so far and their KLD (100000 acts as +infinity)
best_cand = [est_beta est_sigma est_epsilon 100000];

for overall_iteration=1:overall_iteration_limit
    % proposed step: plain EM during the initial iterations, QN1 afterwards
    if overall_iteration > intital_iteration
      delta_theta = -A*g;
    else
      delta_theta = g;
    end
    new_theta = theta + delta_theta;
    
    if verbose_flag==1
        fprintf('cond(A)=%.3f; ',cond(A));
    end
    
    %repair the proposed step if it leaves Psi
    if fun_range(new_theta)
      if verbose_flag==1
        fprintf('Out of Psi! ');
      end

      %1) shrink the whole step (keeping its direction) so that the
      %   epsilon and beta components respect their caps
      if abs(delta_theta(3)) > max_epsilon_step
        delta_theta = delta_theta/abs(delta_theta(3))*max_epsilon_step;
      end
      
      if abs(delta_theta(1)) > max_beta_step
        delta_theta = delta_theta/abs(delta_theta(1))*max_beta_step;
      end
      
      new_theta = theta + delta_theta;
      
      %2) still outside: fall back to the plain EM step
      if fun_range(new_theta)
        delta_theta = g;
        new_theta = theta + delta_theta;
      end

      %3) still outside: shorten the step so it stops on the first violated
      %   bound reached along it. overshoot>0 marks the violated bounds;
      %   scale is the fraction of delta_theta at which each bound is hit.
      if fun_range(new_theta) %final constraint
        overshoot = [(new_theta(1)-beta_max)/theta(1) (beta_min-new_theta(1))/theta(1)  (sigma_min-new_theta(2))/theta(2) (epsilon_min-new_theta(3))/theta(3) (new_theta(3)-epsilon_bd)/theta(3)];
        scale = [(beta_max-theta(1))/delta_theta(1) (beta_min-theta(1))/delta_theta(1)  (sigma_min-theta(2))/delta_theta(2) (epsilon_min-theta(3))/delta_theta(3) (epsilon_bd-theta(3))/delta_theta(3)];
        scale = scale(overshoot > 0);
        min_scale = min(scale);
        delta_theta = delta_theta*min_scale; % shortened step ending on the bound
        new_theta = theta + delta_theta;
      end 
    else
      if verbose_flag==1
        fprintf('EM+       ! ');
      end
    end
    
    % evaluate at the proposed point; its KLD is what gets recorded below
    est_beta = new_theta(1);
    est_sigma = new_theta(2);
    est_epsilon = new_theta(3);

    % =========== fix epsilon, update sigma and beta ===============
        f0_v = zeros(1,length(s_v));
        f1_v = zeros(1,length(s_v));
        f2_v = zeros(1,length(s_v));
        
        %EM for updating sigma: per-sample integrals over w in [est_epsilon, 1]
        tstart=tic;
        for j=1:length(s_v)
            sj = s_v(j);
            
            fun_para_g0 = @(w) fun_g0(sj,w,est_sigma,est_beta);
            fun_para_g1 = @(w) fun_g1(sj,w,est_sigma,est_beta);
            fun_para_g2 = @(w) fun_g2(sj,w,est_sigma,est_beta);
        
            f0_v(j) = quadgk(fun_para_g0,est_epsilon,1,'AbsTol',ATOL,'RelTol',RTOL);
            f1_v(j) = quadgk(fun_para_g1,est_epsilon,1,'AbsTol',ATOL,'RelTol',RTOL);
            f2_v(j) = quadgk(fun_para_g2,est_epsilon,1,'AbsTol',ATOL,'RelTol',RTOL);
        end
        
        %avoid f0 = 0 in the ratios below (ori_f0_v keeps the raw values for the KLD)
        ori_f0_v = f0_v;
        f1_v(f0_v==0) = 0;
        f2_v(f0_v==0) = 0;
        f0_v(f0_v==0) = 0.1;
        
        est_var = sum(pmf_v .* f1_v ./ f0_v) / dim;
        pre_sigma = est_sigma;
        est_sigma = sqrt(est_var);
        t_sigma=toc(tstart);
        
        %beta update: find beta with H(beta) closest to hc by a step-halving
        %line search (steps 2^0 ... 2^-beta_min_step_log) inside
        %[beta_min, beta_max]. This assumes H(beta) is increasing in beta.
        %The N and H integrals use quadgk's default tolerances.
        tstart=tic;
        hc = sum(pmf_v .* f2_v ./ f0_v);
        
        this_beta = est_beta;
        N = quadgk(@(w) fun_N(w,this_beta),est_epsilon,1);
        this_h = quadgk(@(w) fun_H(w,this_beta,N),est_epsilon,1);
        
        if (hc > this_h) && (this_beta < beta_max) %search upward
           pre_h = this_h;
           pre_beta = this_beta;
           for i=0:beta_min_step_log
               beta_step = 2^(-i);
               % walk up from the last beta whose H was still below hc
               for test_beta = pre_beta + beta_step : beta_step : beta_max
                 this_beta = test_beta;
                 N = quadgk(@(w) fun_N(w,this_beta),est_epsilon,1);
                 this_h = quadgk(@(w) fun_H(w,this_beta,N),est_epsilon,1);
                 
                 if (hc < this_h) || (this_beta > beta_max)
                   break;
                 end
                 
                 pre_h = this_h;
                 pre_beta = this_beta;
               end
           end
           
           %keep whichever of the two bracketing betas has H closer to hc
           if abs(this_h-hc) > abs(pre_h-hc)
             this_beta = pre_beta;
           end
        elseif (hc < this_h) && (this_beta > beta_min) %search downward
           pre_h = this_h;
           pre_beta = this_beta;
           for i=0:beta_min_step_log
               beta_step = 2^(-i);
               % walk down from the last beta whose H was still above hc
               for test_beta = pre_beta - beta_step : -beta_step : beta_min
                 this_beta = test_beta;
                 N = quadgk(@(w) fun_N(w,this_beta),est_epsilon,1);
                 this_h = quadgk(@(w) fun_H(w,this_beta,N),est_epsilon,1);
                 
                 if (hc > this_h) || (this_beta < beta_min)
                   break;
                 end
                 
                 pre_h = this_h;
                 pre_beta = this_beta;
               end
           end
           
           %keep whichever of the two bracketing betas has H closer to hc
           if abs(this_h-hc) > abs(pre_h-hc)
             this_beta = pre_beta;
           end
        end

        pre_beta = est_beta;
        est_beta = this_beta;

        t_beta=toc(tstart);
    
    % =========== fix sigma and beta, update epsilon ===================
    % f0 is recomputed with the updated sigma and beta but the current epsilon.
        e0_v = zeros(1,length(s_v));
        
        tstart=tic;
          for j=1:length(s_v)
              sj = s_v(j);
              
              fun_para_g0 = @(w) fun_g0(sj,w,est_sigma,est_beta);
              f0_v(j) = quadgk(fun_para_g0,est_epsilon,1,'AbsTol',ATOL,'RelTol',RTOL);
          
              e0_v(j) = fun_e0(sj,est_epsilon,est_sigma,est_beta);
          end
          
          %avoid f0 = 0
          e0_v(f0_v==0) = 0;
          f0_v(f0_v==0) = 0.1;
          
          N = quadgk(@(w) fun_N(w,est_beta),est_epsilon,1);
          
          est_para = sum(pmf_v .* e0_v ./ f0_v) * N; 
          pre_epsilon = est_epsilon;  
          % NOTE(review): est_para <= 1 makes this negative (Inf at exactly 1).
          % Such values only enter through g. An Inf can become NaN in -A*g,
          % and NaN is not detected by fun_range (NaN comparisons are false).
          est_epsilon = 1/(est_para^(2/dim) - 1);
          %relax for convergence
          %if est_epsilon < pre_epsilon
          %  est_epsilon = pre_epsilon;
          %end
        t_epsilon = toc(tstart);
        
        %QN1: new_g = M(new_theta) - new_theta, the EM step at the evaluated point
        new_g = [est_beta est_sigma est_epsilon]' - [pre_beta pre_sigma pre_epsilon]';
        delta_g = new_g - g;
        
        %rank-one QN1 secant update of A. Skipped at iteration 1, where
        %delta_theta is zero (g starts at 0) and the update would divide by zero.
        if overall_iteration > 1
          delta_A = (delta_theta - A* delta_g) * delta_theta' * A / (delta_theta' * A * delta_g);
          A = A + delta_A;
        end
       
        pre_theta = theta;
        theta = new_theta;
        g = new_g;
        
        %report and record the evaluated point theta (not the EM update above)
        est_beta = theta(1);
        est_sigma = theta(2);
        est_epsilon = theta(3);

        rvar = est_beta*dim*est_sigma*est_sigma;

        %KLD of the model at theta, from the raw (unclamped) f0 values
        est_pmf_v = ori_f0_v/sum(ori_f0_v);
        KLD = pmf_cross_entropy(pmf_v,est_pmf_v) - pmf_entropy(pmf_v); 
        
        if(verbose_flag==1)
          fprintf('EM+ iter %d: KLD %f, est_sigma %f, est_beta %f, est_epsilon %f, rvar %f; t_sigma %f, t_beta %f, t_epsilon %f\n', overall_iteration, KLD, est_sigma, est_beta, est_epsilon, rvar, t_sigma, t_beta, t_epsilon);
        end
        
        if KLD < best_cand(4)
          best_cand = [est_beta est_sigma est_epsilon KLD];
        end
        
        pre_rvar = pre_theta(1)*dim*pre_theta(2)*pre_theta(2);

        %convergence checking: small relative change of rvar between successive
        %evaluated points, or a small enough KLD (never during the initial EM iterations)
        if (overall_iteration > intital_iteration) && (((abs(rvar-pre_rvar)/pre_rvar)<rvar_converge_threshold) || KLD < KLD_threshold)
          break;
        end

end

est_beta = best_cand(1);
est_sigma = best_cand(2);
est_epsilon = best_cand(3);
KLD = best_cand(4);


%boundary check: if the best epsilon is close to epsilon_bd, search over beta
%with epsilon fixed at epsilon_bd and keep the result only if it lowers KLD
if abs(est_epsilon - epsilon_bd) < delta_eps
   if(verbose_flag==1)
     fprintf('==========> Start epsilon-bounded estimation!\n');
   end

   %golden-section search over beta in [beta_min, beta_max]. Each probe refits
   %sigma with csm_em_fit_sigma_only, warm-started from a neighbouring probe's
   %sigma; its 4th output is the objective compared against KLD below.
   b_start = beta_min;
   b_end = beta_max;
   inv_phi = 2/(1+sqrt(5));
   x2 = b_start + (b_end-b_start)*inv_phi;
   x1 = b_start + (b_end-b_start)*(1-inv_phi);
   
   test_epsilon = epsilon_bd;
   test_sigma = est_sigma;
   
   [sigma_x1 now_beta now_epsilon new_fx1] = csm_em_fit_sigma_only(dim, pmf_v, s_v, test_sigma, x1, test_epsilon, verbose_flag, 10, 1e-2, 1e-5);
   test_sigma = sigma_x1;
   [sigma_x2 now_beta now_epsilon new_fx2] = csm_em_fit_sigma_only(dim, pmf_v, s_v, test_sigma, x2, test_epsilon, verbose_flag, 10, 1e-2, 1e-5);

   for boundar_loop=2:boudary_check_iteration
       if new_fx1 < new_fx2 %moving to left
          b_end = x2;
          
          x2 = x1;
          sigma_x2 = sigma_x1;
          new_fx2 = new_fx1;
          
          x1 = b_start + (b_end-b_start)*(1-inv_phi);
          test_sigma = sigma_x2;
          [sigma_x1 now_beta now_epsilon new_fx1] = csm_em_fit_sigma_only(dim, pmf_v, s_v, test_sigma, x1, test_epsilon, verbose_flag, 10, 1e-2, 1e-5);
       else %moving to right
          b_start = x1;
          
          x1 = x2;
          sigma_x1 = sigma_x2;
          new_fx1 = new_fx2;
          
          x2 = b_start + (b_end-b_start)*inv_phi;
          test_sigma = sigma_x1;
          [sigma_x2 now_beta now_epsilon new_fx2] = csm_em_fit_sigma_only(dim, pmf_v, s_v, test_sigma, x2, test_epsilon, verbose_flag, 10, 1e-2, 1e-5);

       end
   end

   %accept a golden-section probe only if it beats the current best KLD
   if new_fx1 < KLD
     est_beta = x1;
     est_sigma = sigma_x1;
     est_epsilon = epsilon_bd;
     KLD = new_fx1;
   end

   if new_fx2 < KLD
     est_beta = x2;
     est_sigma = sigma_x2;
     est_epsilon = epsilon_bd;
     KLD = new_fx2;
   end
end



if(verbose_flag==1)
  fprintf('==========> End: est_sigma %f, est_beta %f, est_epsilon %f, KLD %f\n', est_sigma, est_beta, est_epsilon, KLD);
end

