%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% [hypers, tt] = TC4L2L(xt,yt,gt,nhid,nprior,B,D)
%
% inputs : xt,     inputs used for training (#inputs x #trainpat x #parallel tasks)
%          yt,     outputs used for training (#trainpat x #parallel tasks) (outputs are assumed 1D)
%          gt,     missing-value indicator, same size as yt (1=present, 0=missing)
%          nhid,   #hidden units, not counting bias
%          nprior, #clusters/priors desired
%          B,      initial input-to-hidden weights (optional, e.g. when results on the same data 
%                    but with different nprior are available)
%          D,      initial hidden-unit biases (optional as well)
%
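% outputs: hypers, cell array of structures: hypers{k} holds the optimized hyperparameters
%                  after iteration k; the last entry repeats the final estimate.
%                  Each structure has the fields:
%                    B: [nhid x ninp]
%                    M: [nhid+1 x nprior],          means of each cluster's distribution over 
%                                                   hidden-to-output weights
%                    D: [nhid x 1]
%                    K: [nhid+1 x nhid+1 x nprior], covariances of the same distributions
%                    S: [1],                        output variance
%                    q: [nprior x 1],               a priori cluster assignment probability
%                                                   (enters the fit through a softmax)
%          tt,     CPU time in seconds (from cputime) elapsed at the end of each iteration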
%
% Made by Bart Bakker, bartb@mbfys.kun.nl
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [hypers, tt] = TC4L2L(x,y,g,nhid,nprior,B,D,niter)
% TC4L2L  Fit cluster-specific priors over the hidden-to-output weights of one
% network per task. The fit alternates two steps:
%   - an optimisation of the hyperparameter vector by frprmn on e_em/grad_em;
%   - an update of rho(i,j), the posterior probability that task j belongs to
%     cluster i.
%
% Inputs and outputs are described in the comment header at the top of this file.
% Additional notes:
%   B, D   may be omitted or passed as [] to keep the initialisation from
%          init_bmksq; either one may be supplied without the other.
%   niter  (optional, default 101) number of iterations. The default reproduces
%          the original stopping rule (stop once iter > 100).
%
% hypers{k} (k = 1..niter) is the hyperparameter structure after iteration k, and
% hypers{niter+1} repeats the final estimate. tt(k) is the CPU time (seconds)
% elapsed at the end of iteration k.

if nargin < 8 || isempty(niter)
  niter = 101;
end

t = cputime;

[ninp,npat,ntask] = size(x);

% Zero out the missing patterns in inputs and outputs.
x = x.*add_reps(g,0,ninp);
y = y.*g;

% Per-task statistics. The explicit dimension keeps these per task even when
% there is a single pattern (npat == 1), where a bare sum() would sum over tasks.
Nn  = sum(g,1);
Syy = sum(y.^2,1);

hyper = init_bmksq(ninp,nhid,ntask,nprior);
if nargin > 5 && ~isempty(B)
  hyper.B = B;
end
if nargin > 6 && ~isempty(D)
  hyper.D = D;
end

% Hidden-unit activations for every pattern and task, plus a constant row of
% ones for the output bias.
Z = [tanh(Tx(hyper.B,x) + add_reps(add_reps(hyper.D,1,npat),2,ntask)); ones(1,npat,ntask)];

%%%  A = max likelihood hidden-to-output weights, one column per task  %%%

A = Tx(Tinv(Tx(Z,Z,2,3)), Tx(Z,y,[2 1],[3 2]),1,[3 2]);

%%%  Initial clustering based on Euclidean distances between the tasks' A's  %%%

% Pairwise comparison of all columns of A: element (k,j,l) of the reshaped
% array is A(k,l) - A(k,j), reduced over k by sqsum.
dd = sqsum(reshape(repmat(A,ntask,1) - repmat(A(:),1,ntask), [nhid+1 ntask ntask]).^2);

[c_size,c_tasks] = Cluster(dd,'ward',1,nprior);

% Members of cluster i are c_tasks(c_first(i):c_last(i)).
% NOTE(review): assumes Cluster lists the tasks grouped contiguously per cluster,
% in cluster order, with c_size(i) members each -- confirm against Cluster.
c_size  = c_size(:)';
c_last  = cumsum(c_size);
c_first = [1, c_last(1:end-1)+1];
for i = 1:nprior
  hyper.M(:,i) = mean(A(:,c_tasks(c_first(i):c_last(i))),2);
end
hyper.q = log(c_size');

vec = bmksq_to_vec(hyper)';

%%%  ind(:,i) gives the positions in vec of the hyperparameters of cluster i  %%%

ind = make_ind(nhid,ninp,nprior,nprior,1);

rho = tc4l2l_responsibilities(vec, ind, x, y, g, Syy, nhid, Nn, nprior, ntask);

tt     = zeros(1,niter);
hypers = cell(1,niter+1);
for iter = 1:niter

  vec = frprmn('e_em', 'grad_em', vec, 5, 1e-2, 1e-2, [], ind, x, y, g, Syy, nhid, Nn, rho);
  rho = tc4l2l_responsibilities(vec, ind, x, y, g, Syy, nhid, Nn, nprior, ntask);

  %%%  check convergence (disabled: a fixed number of iterations is run)  %%%
  % Re-enabling this also requires, before the loop,
  %   Variance = add_reps(add_reps(eye(nhid+1),2,ntask),2,nprior);
  %   State    = zeros(nhid+1,ntask,nprior);
  % and taking iR1, R2 as extra outputs of tc4l2l_responsibilities.

% LastVariance = Variance;
% LastState    = State;

% Variance = reshape(iR1',nhid+1,nhid+1,ntask,nprior);
% State    = Tx(Variance, reshape(R2',nhid+1,ntask,nprior), [2 1], [3 2; 4 3]);

% KL = KLdist(LastState,LastVariance,State,Variance)*rho(:)/ntask;

% fprintf('KL = %f\n',KL);
% converged = (KL < 1e-2);

  hypers{iter} = vec_to_bmksq(vec, ninp, nprior, nhid, ntask, nprior, 1);
  tt(iter)     = cputime - t;
  fprintf(1,'iter = %i\n',iter);
end

hypers{niter+1} = vec_to_bmksq(vec, ninp, nprior, nhid, ntask, nprior, 1);


function [rho, loglik, iR1, R2] = tc4l2l_responsibilities(vec, ind, x, y, g, Syy, nhid, Nn, nprior, ntask)
% TC4L2L_RESPONSIBILITIES  Posterior cluster-assignment probabilities.
%
% rho(i,j) = P( task j in cluster i | data )
%          ~ P( data j | task j in cluster i ) * P( task j in cluster i )
%          = exp(loglik(i,j))                  * q(i)
% normalised over i. Here q is the softmax of the last nprior entries of vec.
%
% The computation is done in the log domain with a per-column maximum
% subtracted. This avoids overflow in exp() and 0/0 after normalisation. It also
% behaves correctly for a single cluster (nprior == 1), where a bare max() would
% reduce over tasks instead of clusters.
%
% loglik, iR1 and R2 are also returned so the disabled KL convergence check in
% TC4L2L can use them.

loglik = zeros(nprior, ntask);
for i = 1:nprior
  [loglik(i,:),iR1(i,:),R2(i,:)] = e_bmksq(vec(ind(:,i)), x, y, g, Syy, nhid, Nn);
end
% The first output of e_bmksq is negated to obtain the log-likelihood.
loglik = -loglik;

% log q: log-softmax of the mixing parameters, computed stably.
a    = vec(end-nprior+1:end);
a    = a(:);
amax = max(a);
logq = a - amax - log(sum(exp(a - amax)));

logpost = loglik + repmat(logq,1,ntask);
rho = exp(logpost - repmat(max(logpost,[],1),nprior,1));
rho = rho./repmat(sum(rho,1),nprior,1);









