function [KL,KLM,KLV] = KLdist(mean1, inv_var1, mean2, inv_var2)
% KLDIST  KL divergence KL(N1 || N2) between Gaussians given in precision form.
%
%   [KL,KLM,KLV] = KLdist(mean1, inv_var1, mean2, inv_var2)
%
%   N1 = N(mean1, inv(inv_var1)),   N2 = N(mean2, inv(inv_var2)).
%
%   KL = KLM + KLV, where
%     KLM = 0.5 * (mean1-mean2)' * inv_var2 * (mean1-mean2)          (mean term)
%     KLV = 0.5 * ( log|inv_var1| - N - log|inv_var2|
%                   + trace(inv(inv_var1) * inv_var2) )               (covariance term)
%
%   Inputs
%     mean1, mean2        N-by-1 column vectors, or arrays of size N-by-S1-by-S2-...
%                         holding a batch of S = prod([S1 S2 ...]) mean vectors
%                         (one per column after reshaping to N-by-S). A 1-by-S
%                         row vector is therefore a batch of S 1-D Gaussians.
%     inv_var1, inv_var2  Precision matrices: N-by-N, or N-by-N-by-S1-by-...
%                         with one matrix per batch element.
%
%   Outputs
%     KL, KLM, KLV        1-by-S row vectors (scalars when S == 1).

sz = size(mean1);
N = sz(1);
S = prod(sz(2:end));   % number of distributions in the batch (1 for a column)

% Reshape each input separately so differing-but-compatible layouts line up.
d  = reshape(mean1, N, S) - reshape(mean2, N, S);
P1 = reshape(inv_var1, N, N, S);
P2 = reshape(inv_var2, N, N, S);

KLM = zeros(1, S);
KLV = zeros(1, S);
for i = 1:S
  Pa = P1(:,:,i);
  Pb = P2(:,:,i);
  di = d(:,i);
  KLM(i) = (di' * Pb * di) / 2;
  % trace(inv(Pa)*Pb) equals trace(sqrtm(Pb)*inv(Pa)*sqrtm(Pb)) by the cyclic
  % property of the trace; a linear solve avoids sqrtm and an explicit inverse.
  KLV(i) = (local_logdet(Pa) - N - local_logdet(Pb) + trace(Pa \ Pb)) / 2;
end
KL = KLM + KLV;


function ld = local_logdet(P)
% LOCAL_LOGDET  log(det(P)). Uses the Cholesky factor when P is Hermitian
% positive definite, which avoids the overflow/underflow of det() for larger
% N; otherwise falls back to log(det(P)).
ld = [];
if isequal(P, P')
  [R, p] = chol(P);
  if p == 0
    ld = 2 * sum(log(diag(R)));
  end
end
if isempty(ld)
  ld = log(det(P));
end
