%Yaroslavsky and bilateral filter
%Adaptively selecting filtering parameters from the EM+ fitting result
%Use Quasi-Newton 1 for parameter fitting

%CSM fitting acceleration:
%GPMF: Grid-subsampled PMF derivation
%EFM:  Equal-frequency merging for reducing bins

function GPMF_EFM_yaro_bf_recursive_filtering(dim, epsilon_bd, T, test_image_v, test_noise_v, test_radius_v, test_neighbor_v)
%GPMF_EFM_YARO_BF_RECURSIVE_FILTERING Recursive Yaroslavsky / bilateral
%denoising whose range variance is re-estimated from a CSM fit before every pass.
%
%   For every selected (image, noise label, kernel radius, neighborhood type)
%   combination, the noisy image is filtered repeatedly:
%     1. GPMF: build the neighborhood-difference histogram of the current image
%        and merge adjacent bin pairs.
%     2. EFM:  reduce the histogram to T bins by equal-frequency merging.
%     3. EM+:  fit the CSM model. The fitted alpha*sigma^2 becomes the range
%        variance of the next filtering pass.
%   Iteration stops after ITERATION_MAX filtering passes. From the second fit
%   on, it also stops as soon as the fitted noise variance drops below
%   NVAR_STOP. Because the fit runs before the stop check, one row can hold up
%   to ITERATION_MAX+1 fit groups but at most ITERATION_MAX filter groups.
%
%   Inputs
%     dim             - model dimension forwarded to the CSM fitting helpers.
%                       Also used in the sqrt(2*(dim-1)) scale between sigma
%                       and the histogram abscissa.
%     epsilon_bd      - forwarded unchanged to csm_em_fit_psi_confinement_merge
%                       (presumably a bound on epsilon -- confirm there).
%     T               - number of bins kept by equal-frequency merging.
%     test_image_v    - image name(s) to process (cellstr or a single char
%                       string); entries not in Image_v are ignored.
%     test_noise_v    - noise label(s) to process, matched against Noise_v.
%     test_radius_v   - kernel radii. 0 selects the four-connected 3x3 cross;
%                       r > 0 selects a (2r+1)x(2r+1) support.
%     test_neighbor_v - 'Yaroslavsky' and/or 'Bilateral'.
%
%   Outputs (written under GPMF_EFM_yaro_bf_recursive_filtering/)
%     GPMF_EFM_T<T>_<type>_size<n>_result.csv  - filtering passes, final MSE/PSNR
%     GPMF_EFM_T<T>_<type>_size<n>_fitting.csv - statistics of every CSM fit
%     GPMF_EFM_T<T>_<type>_size<n>_filter.csv  - statistics of every filter pass
%     T<T>_fig/, T<T>_fit_fig/                 - denoised PNGs and fit plots
%                                                (when output_figure_flag == 1)

%============= Global Setting =====================

addpath('utility');
addpath('utility_csm');
addpath('utility_merge');

dir_name = 'GPMF_EFM_yaro_bf_recursive_filtering';

Image_v = {'lena','baboon','barbara','peppers','F16','house','kodim04','kodim08','kodim13','kodim19','kodim22','kodim23'};
Noise_v = {'N5', 'N10', 'N20', 'N40', 'N50', 'N0'};
sigma_v = [5 10 20 40 50 0];    % noise std for each entry of Noise_v (only used in plot titles)
% Radii used in the experiments: 2, 4, 6, 10 (5x5, 9x9, 13x13, 21x21) and 0 (four-connected)
Neighbor_v = {'Yaroslavsky','Bilateral'};

%=========== Global parameters ========================
% stop criteria for recursive processing
iteration_max = 3;   % maximum number of filtering passes
nvar_stop = 10;      % stop once the fitted noise variance falls below this
order_n = 5;         % forwarded to find_representative_s_poly

output_figure_flag = 1;
em_fit_verbose_flag = 1;

%=========== Parallel Computing Toolbox ========================
% use multi-threading (at most four workers)
start_worker_pool(4);

%============= Global Setting Done =====================

fig_dir = sprintf('T%d_fig',T);
fit_fig_dir = sprintf('T%d_fit_fig',T);

if ~exist(fullfile(dir_name,fig_dir),'dir')
    mkdir(fullfile(dir_name,fig_dir));
end
if ~exist(fullfile(dir_name,fit_fig_dir),'dir')
    mkdir(fullfile(dir_name,fit_fig_dir));
end

% CSV headers. One column group per recorded fit or filter pass; the fit
% runs once more than the filter (see help above).
fit_cols = '|,alpha,epsilon,H(P(s)),range variance,CSM noise variance,KLD-EFM,KLD,MAD noise variance,t_pmf,t_fit';
filter_cols = '|,range variance,MSE,PSNR,t_filter';
result_header = 'Filter,Support,Image,Noise,Total Iteration,MSE,PSNR';
fitting_header = ['Filter,Support,Image,Noise,' repeat_columns(fit_cols, iteration_max+1)];
filter_header = ['Filter,Support,Image,Noise,' repeat_columns(filter_cols, iteration_max)];

%============= Big Loop =====================

for big_loop_j = 1:length(Image_v)
    if ~name_selected(Image_v{big_loop_j}, test_image_v)
        continue;
    end

    for big_loop_i = 1:length(Noise_v)
        if ~name_selected(Noise_v{big_loop_i}, test_noise_v)
            continue;
        end

        %============= one image/sigma pair =====================
        this_image = Image_v{big_loop_j};
        this_noise = Noise_v{big_loop_i};
        this_sigma = sigma_v(big_loop_i);

        % Legacy seeding is kept on purpose: switching to rng() would change
        % the generated noise and break reproducibility of earlier results.
        randn('seed',big_loop_j*100+big_loop_i);
        [noisy_img,clean_img] = get_noisy_image(this_image,this_noise);

        for big_loop_k = 1:length(test_radius_v)
            this_radius = test_radius_v(big_loop_k);
            support = this_radius*2+1;

            for big_loop_m = 1:length(Neighbor_v)
                Neighbor_type = Neighbor_v{big_loop_m};
                if ~name_selected(Neighbor_type, test_neighbor_v)
                    continue;
                end

                coeff = build_filter_kernel(this_radius, Neighbor_type);

                csv_prefix = sprintf('GPMF_EFM_T%d_%s_size%d',T,Neighbor_type,support);
                text_fid_result = -1;
                text_fid_fitting = -1;
                text_fid_filter = -1;
                try
                    text_fid_result = open_csv(fullfile(dir_name,[csv_prefix '_result.csv']), result_header);
                    fprintf(text_fid_result,'%s,%d,%s,%s,',Neighbor_type,support,this_image,this_noise);

                    text_fid_fitting = open_csv(fullfile(dir_name,[csv_prefix '_fitting.csv']), fitting_header);
                    fprintf(text_fid_fitting,'%s,%d,%s,%s,',Neighbor_type,support,this_image,this_noise);

                    text_fid_filter = open_csv(fullfile(dir_name,[csv_prefix '_filter.csv']), filter_header);
                    fprintf(text_fid_filter,'%s,%d,%s,%s,',Neighbor_type,support,this_image,this_noise);

                    fprintf('\n');

                    input_image = noisy_img;
                    num_passes = 0;
                    for iter = 1:iteration_max+1 % iterative filtering loop
                        % == CSM fitting
                        tstart = tic;

                        % GPMF
                        pre_kde_pmf = get_chi_pmf_neighborhood_array_grid_opt(double(input_image),coeff);
                        % Halve the number of bins by summing adjacent pairs
                        % (pad to even length first). This speeds up fitting.
                        if mod(length(pre_kde_pmf),2) == 1
                            pre_kde_pmf(length(pre_kde_pmf)+1) = 0;
                        end
                        pre_kde_pmf = sum(reshape(pre_kde_pmf,2,length(pre_kde_pmf)/2));
                        num_bins = length(pre_kde_pmf);

                        % Bin k ends at ori_send_v(k). s_v is its midpoint and
                        % ori_len_v its width; the first bin is [0, 1.5].
                        s_v = [0.75 0.5+2*(1:num_bins-1)];
                        ori_len_v = [1.5 2*ones(1,num_bins-1)];
                        ori_send_v = 2*(1:num_bins) - 0.5;

                        % initial sigma guess taken from the histogram mode
                        [~, peak_idx] = max(pre_kde_pmf);
                        in_sigma = s_v(peak_idx)/sqrt(2*(dim-1));

                        t_pmf = toc(tstart);

                        tstart = tic;
                        % EFM
                        merge_count = T;
                        [pmf_merge, send_v] = equal_freq_merge(pre_kde_pmf, ori_send_v, merge_count);
                        s_length_v = send_v - [0 send_v(1:merge_count-1)];

                        s_merge_v = find_representative_s_poly(order_n, pre_kde_pmf, ori_send_v, s_v, ori_len_v, pmf_merge, send_v, s_length_v);

                        % EM+ fitting
                        [est_sigma, est_beta, est_epsilon, KLDmin] = csm_em_fit_psi_confinement_merge(dim, pmf_merge, s_merge_v, s_length_v, in_sigma, epsilon_bd, em_fit_verbose_flag);
                        est_alpha = est_beta*dim;
                        est_nvar = est_sigma*est_sigma;
                        est_range_variance = est_alpha*est_nvar;
                        t_fit = toc(tstart);

                        mad_nvar_2d = get_mad_nvar(input_image);

                        % KL divergence of the fitted model from the un-merged histogram
                        H_Q = pmf_entropy(pre_kde_pmf);
                        est_pmf_p = get_CSM_pmf(est_alpha,est_epsilon,sqrt(2*(dim-1))*est_sigma,s_v);
                        KLDori = pmf_cross_entropy(pre_kde_pmf,est_pmf_p) - H_Q;

                        fprintf('%s_size%d %s %s Iteration=%d: est_aa=%f est_eps=%f H_Q=%f est_rvar=%f est_nvar=%f KLD-EFM=%f KLD=%f est_nvar_mad_2d=%f t_pmf=%f t_fit=%f\n',Neighbor_type,support,this_image,this_noise,iter,est_alpha,est_epsilon,H_Q,est_range_variance,est_nvar,KLDmin,KLDori,mad_nvar_2d,t_pmf,t_fit);
                        fprintf(text_fid_fitting,'|,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,',est_alpha,est_epsilon,H_Q,est_range_variance,est_nvar,KLDmin,KLDori,mad_nvar_2d,t_pmf,t_fit);

                        % == stop check (the first pass always filters)
                        if (iter > 1) && ((est_nvar < nvar_stop) || (iter > iteration_max))
                            break;
                        end

                        % == Filtering with the fitted range variance
                        tstart = tic;
                        denoise_image = square_filter_distance_fast_opt(double(input_image),est_range_variance,coeff,1);
                        input_image = denoise_image;   % the next pass refits and filters this result
                        t_filter = toc(tstart);
                        num_passes = num_passes + 1;

                        mse = mean(mean(mean((denoise_image-double(clean_img)).^2)));
                        psnr_val = 10*log10(255*255/mse);

                        fprintf('%s_size%d %s %s Iteration=%d: est_rvar=%f mse=%f psnr=%f time=%f\n\n',Neighbor_type,support,this_image,this_noise,iter,est_range_variance,mse,psnr_val,t_filter);
                        fprintf(text_fid_filter,'|,%f,%f,%f,%f,',est_range_variance,mse,psnr_val,t_filter);

                        if output_figure_flag == 1
                            % fitting result: empirical vs merged vs fitted density
                            merge_s = [0 send_v(1:merge_count-1); send_v(1:merge_count)];
                            fs_merge = repmat(pmf_merge./s_length_v,2,1);

                            figure(1);
                            set(figure(1),'PaperUnits','inches','PaperPosition',1.2*[0 0 5 3]);
                            semilogy(s_v,pre_kde_pmf./ori_len_v,'r',merge_s(:),fs_merge(:),'k:',s_v,est_pmf_p./ori_len_v,'b--');
                            if this_radius == 0
                                string_title = sprintf('EFM-%d fitting; Color image gradient; %s', T, this_image);
                            else
                                string_title = sprintf('EFM-%d fitting; %s %dx%d; %s, \\sigma_n=%d; Iteration %d', T, Neighbor_type,support,support,this_image, this_sigma, iter);
                            end
                            title(string_title);

                            string_0 = sprintf('Empirical');
                            string_1 = sprintf('EFM-%d Merging',T);
                            string_2 = sprintf('EFM-%d Fitting \n(\\alpha = %.2f, \\epsilon =%.4f, \\sigma =%.2f, KLD-EFM = %.4f, KLD = %.4f)', T, est_alpha, est_epsilon, est_sigma, KLDmin,KLDori);
                            h_legend = legend(string_0,string_1,string_2);
                            set(h_legend,'FontSize',8,'Location','south');

                            % y-limits in whole decades covering both curves.
                            % The lower limit is pushed down by a quarter of
                            % the span (presumably to make room for the
                            % south-located legend).
                            min_empirical_pmf = -log10(min(pre_kde_pmf(pre_kde_pmf~=0)));
                            max_empirical_pmf = -log10(max(pre_kde_pmf));
                            min_estimated_pmf = -log10(min(est_pmf_p(est_pmf_p~=0)));
                            max_estimated_pmf = -log10(max(est_pmf_p));
                            bottom_pmf = ceil((min_empirical_pmf+min_estimated_pmf)/2);
                            top_pmf = floor(max([max_empirical_pmf max_estimated_pmf]));
                            bottom_pmf = ceil((bottom_pmf-top_pmf)/4 + bottom_pmf);
                            cur_axis = axis;
                            axis([cur_axis(1) cur_axis(2) 10^(-bottom_pmf) 10^(-top_pmf)]);

                            fig_filename = sprintf('GPMF_EFM_T%d_semilog_rgb_chi_%s_%s_%s_size%d_iter%d.png',T,this_image,this_noise,Neighbor_type,support,iter);
                            print(figure(1),'-r100','-dpng',fullfile(dir_name,fit_fig_dir,fig_filename));

                            % denoised result
                            denoise_image_out = uint8(denoise_image);
                            fig_filename = sprintf('GPMF_EFM_T%d_%s_%s_%s_size%d_iter%d.png',T,this_image,this_noise,Neighbor_type,support,iter);
                            imwrite(denoise_image_out,fullfile(dir_name,fig_dir,fig_filename),'png');
                        end
                    end

                    fprintf(text_fid_result,'%d,%f,%f\n',num_passes,mse,psnr_val);
                    fprintf(text_fid_fitting,'\n');
                    fprintf(text_fid_filter,'\n');
                catch err
                    % never leak the CSV handles when a helper throws
                    close_files([text_fid_result text_fid_fitting text_fid_filter]);
                    rethrow(err);
                end
                close_files([text_fid_result text_fid_fitting text_fid_filter]);
            end
        end
        %============= one image/sigma pair done =====================
    end
end

%============= Big Loop Done =====================

function tf = name_selected(name, selection)
%NAME_SELECTED True when NAME is one of the names in SELECTION.
%   SELECTION may be a cell array of strings or a single char string.
tf = any(strcmp(name, cellstr(selection)));

function coeff = build_filter_kernel(radius, neighbor_type)
%BUILD_FILTER_KERNEL Spatial weights of the filter support.
%   radius == 0           -> four-connected 3x3 cross
%   'Bilateral', radius>0 -> Gaussian weights with sigma_d = radius/2
%   otherwise             -> uniform (2*radius+1)x(2*radius+1) box (Yaroslavsky)
if radius == 0
    coeff = [0 1 0; 1 1 1; 0 1 0];
elseif strcmp('Bilateral', neighbor_type)
    sigma_d = radius/2;
    [dx, dy] = meshgrid(-radius:radius);
    coeff = exp(-(dx.^2 + dy.^2)/2/sigma_d^2);
else
    coeff = ones(2*radius+1,2*radius+1);
end

function header = repeat_columns(group, count)
%REPEAT_COLUMNS COUNT comma-separated copies of the column-group string GROUP.
if count < 1
    header = '';
    return;
end
header = [repmat([group ','], 1, count-1) group];

function fid = open_csv(file_path, header_line)
%OPEN_CSV Open FILE_PATH for appending. A new file gets HEADER_LINE first.
%   Raises an error instead of returning -1 when the file cannot be opened.
is_new = ~exist(file_path, 'file');
if is_new
    fid = fopen(file_path, 'w');
else
    fid = fopen(file_path, 'a');
end
if fid < 0
    error('GPMF_EFM_yaro_bf_recursive_filtering:openFailed', 'Cannot open "%s" for writing.', file_path);
end
if is_new
    fprintf(fid, '%s\n', header_line);
end

function close_files(file_ids)
%CLOSE_FILES fclose every identifier in FILE_IDS that refers to an opened file.
%   Identifiers <= 2 (unset placeholders or standard streams) are skipped.
for fid = file_ids(file_ids > 2)
    fclose(fid);
end

function start_worker_pool(max_workers)
%START_WORKER_POOL Open a pool of at most MAX_WORKERS workers unless one is running.
%   Uses gcp/parpool when available. Falls back to the legacy
%   matlabpool/findResource API, which newer releases no longer provide.
if exist('gcp', 'file')
    if isempty(gcp('nocreate'))
        pool_cluster = parcluster;
        parpool(pool_cluster, min(max_workers, pool_cluster.NumWorkers));
    end
elseif matlabpool('size') == 0
    jm = findResource('scheduler', 'configuration', defaultParallelConfig);
    matlabpool(min(max_workers, jm.ClusterSize));
end