% Written by Tobias Grafke (grafke@cims.nyu.edu), based on the
% publication
%
%    T. Grafke, T. Schäfer, and E. Vanden-Eijnden,
%    "Long-Lasting Effects of Small Random Perturbations on Dynamical
%       Systems: Theoretical and Computational Tools" (2016)
%
% available at http://arxiv.org/abs/1604.03818.
%
% Computes the minimizer of the geometric action via the simplified
% geometric minimum action method (sgMAM), for Hamiltonians whose
% second p-derivative H_pp is a diagonal matrix.
%
% Usage:
%  x = sgmam(x_initial, params)
%
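% x_initial is an Nt-by-Nx array holding one point of the initial path per
% row; its first and last rows are the fixed endpoints. Returns the
% minimizer x, an array of the same size, reparametrized to uniform
% arclength (measured with the weights returned by invH_pp). Here, params
% is a data structure containing the fields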
%  params.H:          The Hamiltonian as function handle of x and p
%  params.H_p:        The p-derivative of the Hamiltonian as a
%                     function handle of x and p
%  params.H_x:        The x-derivative of the Hamiltonian as a
%                     function handle of x and p
%  params.invH_pp:    The inverse of the second derivative of the
%                     Hamiltonian with respect to p, as a function of x
%                     and p. In this subroutine, it is assumed to be a
%                     diagonal matrix, so only the diagonal is to be
%                     passed as function handle of x and p that
%                     returns a vector
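%  params.iterations: (optional) Total number of iterations (default 1000)
%  params.plotevery:  (optional) Do a diagnostic plot every <plotevery>
%                     iterations (default 100)
%  params.epsilon:    (optional) Relaxation step length (default 0.1)
%  params.xs:         (optional) saddle point, marked in the diagnostic
%                     trajectory plot for two-dimensional problems. The
%                     initial and final points xa, xb are not read from
%                     params; they are the first and last rows of x_initial.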

function x = sgmam(x_initial, params)
   % SGMAM  Relax a discrete path towards the minimum action path.
   %
   %   x = sgmam(x_initial, params)
   %
   %   x_initial is an Nt-by-Nx array with one path point per row (Nt >= 3).
   %   Its first and last rows are the fixed endpoints xa and xb; only the
   %   interior rows are updated. The returned x has the same size and is
   %   reparametrized after every iteration to uniform arclength, with
   %   segment lengths weighted by invH_pp.
   %
   %   params fields (function handles take x and p, both Nt-by-Nx):
   %     H                 Hamiltonian; it is combined with row-wise sums,
   %                       so presumably returns an Nt-by-1 column
   %     H_p, H_x          p- and x-derivatives of H (same size as x)
   %     invH_pp           diagonal of H_pp^{-1} at every path point,
   %                       used elementwise with x (same size as x)
   %     iterations        (optional, default 1000) outer iterations
   %     plotevery         (optional, default 100) timing message and
   %                       diagnostic plots every <plotevery> iterations
   %     epsilon           (optional, default 0.1) relaxation step length
   %     xs                (optional) saddle point marked in the 2-D plot
   %     inneriterations   (optional, default 100) maximum number of
   %                       iterations of the momentum solve
   %     innertol          (optional, default 1e-10) the momentum solve
   %                       stops once max|H(x,p)| < innertol (after at
   %                       least 4 iterations)

   iterations      = field_or_default(params, 'iterations', 1000);
   plotevery       = field_or_default(params, 'plotevery', 100);
   epsilon         = field_or_default(params, 'epsilon', 1e-1);
   xsaddle         = field_or_default(params, 'xs', []);
   % at least one inner iteration is required to define lambda
   inneriterations = max(1, field_or_default(params, 'inneriterations', 100));
   innertol        = field_or_default(params, 'innertol', 1e-10);

   H = params.H; H_p = params.H_p; H_x = params.H_x; invH_pp = params.invH_pp;

   x = x_initial;
   [Nt, Nx] = size(x);
   if Nt < 3
       error('sgmam:tooFewPoints', ...
             'x_initial needs at least 3 rows (path points), got %d.', Nt);
   end
   xa = x(1,:); xb = x(end,:);
   s = linspace(0, 1, Nt);   % uniform target parametrization

   % Central difference along the path (row direction); the two endpoint
   % rows are set to zero.
   cdiff = @(v) [zeros(1,Nx); 0.5*(v(3:end,:) - v(1:end-2,:)); zeros(1,Nx)];

   % The implicit step solves for the n = Nt-2 interior points. T is the
   % tridiagonal [-1 2 -1] stencil on those points; it does not change
   % between iterations, so it is built once, directly in sparse form
   % (instead of forming dense diag() matrices every time).
   n = Nt - 2;
   T = spdiags(repmat([-1 2 -1], n, 1), -1:1, n, n);

   tic
   for iter = 1:iterations
       xdot = cdiff(x);

       % conjugate momentum p and multiplier lambda (both zero at endpoints)
       [p, lambda] = sgmam_momentum(x, xdot, H, H_p, invH_pp, ...
                                    inneriterations, innertol);

       % Alternative: Direct computation, only correct for quadratic
       % Hamiltonian, where H_pp does not depend on p
       %bx = H_p(x,0*x);
       %iH = invH_pp(x,0*x);
       %lambda = sqrt(sum(bx.^2.*iH,2)./sum(xdot.^2.*iH,2));
       %lambda(1)=0; lambda(end)=0;
       %p = iH.*(repmat(lambda,1,Nx).*xdot - bx);

       pdot = cdiff(p);
       Hx = H_x(x, p);
       iH = invH_pp(x, p);
       xdotdot = cdiff(xdot);
       lam2 = lambda.^2;

       % Semi-implicit update (the explicit variant would be
       % x = x + epsilon*(repmat(lambda,1,Nx).*pdot + Hx)).
       % For every dof the interior points solve A*x_new = rhs with
       % A = I + diag(c)*T and c = epsilon*iH.*lambda.^2, i.e.
       %   x_new - c.*(second difference of x_new)
       %     = x + epsilon*(lambda.*pdot + Hx - iH.*lambda.^2.*xdotdot).
       % All dofs share lambda but may have different H_pp^{-1}, so each
       % column is solved separately.
       for dof = 1:Nx
           c = epsilon*iH(2:end-1,dof).*lam2(2:end-1);
           rhs = x(2:end-1,dof) + epsilon.*(lambda(2:end-1).*pdot(2:end-1,dof) ...
                 + Hx(2:end-1,dof) - iH(2:end-1,dof).*lam2(2:end-1).*xdotdot(2:end-1,dof));
           % stencil contributions of the fixed endpoints
           rhs(1)   = rhs(1)   + c(1)*xa(dof);
           rhs(end) = rhs(end) + c(end)*xb(dof);
           A = speye(n) + spdiags(c, 0, n, n)*T;
           x(2:end-1,dof) = A\rhs;
       end

       % reparametrize to uniform arclength, with segment lengths weighted
       % by iH
       seg = sqrt(sum(iH(2:end,:).*(x(2:end,:) - x(1:end-1,:)).^2, 2));
       alpha = [0; cumsum(seg)];
       alpha = alpha/alpha(end);
       x = interp1(alpha, x, s, 'linear');

       if mod(iter, plotevery) == 0
           % toc is not an integer, so it is printed with %f, not %d
           fprintf('%d iterations took %.2f seconds\n', plotevery, toc);
           tic
           sgmam_plot(x, p, lambda, xdot, xa, xb, xsaddle);
       end
   end
end

function v = field_or_default(params, name, default)
   % Return params.(name) if that field exists, otherwise default.
   if isfield(params, name)
       v = params.(name);
   else
       v = default;
   end
end

function [p, lambda] = sgmam_momentum(x, xdot, H, H_p, invH_pp, maxit, tol)
   % Iteratively solve for the conjugate momentum p along the path.
   %
   % Each step sets p <- p + H_pp^{-1}.*(mu.*xdot - H_p(x,p)), where mu^2 is
   % the value for which the second-order Taylor expansion of H(x,.) about
   % the current p vanishes at the new p (exact when H is quadratic in p),
   % clamped at 0. The iteration stops after maxit steps, or once
   % max|H(x,p)| < tol after at least 4 steps. The returned p has zeroed
   % endpoint rows; lambda is the last mu with zeroed endpoints.
   Nx = size(x, 2);
   p = 0*x;
   for k = 1:maxit
       H_ = H(x, p);
       Hp_ = H_p(x, p);
       invHpp_ = invH_pp(x, p);
       % Rows with xdot == 0 (the endpoints) divide by zero here; those rows
       % are overwritten below (assuming H, H_p, invH_pp act row by row).
       mu = sqrt(max((sum(Hp_.*invHpp_.*Hp_,2) - 2*H_)./sum(xdot.*invHpp_.*xdot,2), 0));
       p = p + invHpp_.*(repmat(mu,1,Nx).*xdot - Hp_);
       if max(abs(H(x, p))) < tol && k > 3
           break
       end
   end
   p(1,:) = 0; p(end,:) = 0;
   lambda = mu; lambda(1) = 0; lambda(end) = 0;
end

function sgmam_plot(x, p, lambda, xdot, xa, xb, xsaddle)
   % Draw diagnostic figures 1-4: the path, the momentum p, lambda, and
   % abs(sum(xdot.*p,2))*Nt (labelled dS) along the path.
   [Nt, Nx] = size(x);
   if Nx == 2
       % endpoints, plus the saddle point only when one was supplied
       if isempty(xsaddle)
           marks = [xa; xb];
       else
           marks = [xa; xsaddle(1), xsaddle(2); xb];
       end
       figure(1)
       plot(x(:,1), x(:,2), 'x-', marks(:,1), marks(:,2), 'ro')
       title('Trajectory')
       xlabel('x')
       ylabel('y')

       figure(2); clf
       plot(p(:,1), p(:,2), 'x-')
       title('P')
       xlabel('px')
       ylabel('py')
   else
       figure(1)
       plot(x, 'x-')
       title('Trajectory')

       figure(2)
       plot(p, 'x-')
       title('p')
   end

   figure(3)
   plot(lambda, 'x-')
   title('\lambda')
   xlabel('s')
   ylabel('\lambda')

   figure(4)
   semilogy(abs(sum(xdot.*p,2))*Nt, 'x-')
   title('Action')
   xlabel('s')
   ylabel('dS')

   drawnow
end