%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% [hypers, tt] = TG4L2L(xt,yt,gt,Ft,nhid,nprior,B,D)
%
% inputs : xt,     inputs used for training (#inputs x #trainpat x #parallel tasks)
%          yt,     outputs used for training (#trainpat x #parallel tasks) (outputs are assumed 1D)
%          gt,     missing value indicator (1=present, 0=missing), same size as yt
%          Ft,     task features (#features x #tasks)
%          nhid,   #hidden units, not counting bias
%          nprior, #clusters/priors desired
%          B,      initial input-to-hidden weights (optional, e.g. when results on the same data 
%                    but with different nprior are available)
%          D,      initial input bias (just as optional)
%
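% outputs: hypers, cell array with one entry per iteration (the last entry holds the final
%                  estimate); each entry is a structure containing the optimized hyperparameters: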
%                    B: [nhid x ninp]
%                    M: [nhid+1 x nprior],          means for each cluster's distribution over 
%                                                   hidden-to-output weights
%                    D: [nhid x 1]
%                    K: [nhid+1 x nhid+1 x nprior], variance of same
%                    S: [1],                        output variance
%                    q: [nprior x nfeat],           softmax over clusters of -q * Ft(:,i) = a priori
%                                                   cluster assignment probabilities for task i
%          tt,     elapsed CPU time (cputime) measured at the end of each iteration
%
% Made by Bart Bakker, bartb@mbfys.kun.nl
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [hypers, tt] = TG4L2L(x,y,g,F,nhid,nprior,B,D)
% Iteratively optimizes the hyperparameter vector of the task-gating model.
%
% Each iteration (1) calls frprmn on the objective 'e_emG' / gradient
% 'grad_emG' with the current cluster responsibilities rho held fixed, and
% (2) recomputes rho from the per-cluster log-likelihoods (e_bmksq) and the
% feature-dependent cluster priors softmax(-q*F).
%
% Inputs
%   x      [ninp x npat x ntask]  training inputs
%   y      [npat x ntask]         training outputs
%   g      [npat x ntask]         missing-value indicator (1 = present, 0 = missing)
%   F      [nfeat x ntask]        task features
%   nhid   number of hidden units (bias not counted)
%   nprior number of clusters / priors
%   B      (optional) initial input-to-hidden weights; omitted or [] -> init_gating default
%   D      (optional) initial hidden bias;             omitted or [] -> init_gating default
%
% Outputs
%   hypers cell array; hypers{k} is the hyperparameter structure after
%          iteration k. The last entry repeats the final iteration's result
%          (kept for backward compatibility with existing callers).
%   tt     tt(k) = CPU time elapsed since entry at the end of iteration k.

t = cputime;

[ninp,npat,ntask] = size(x); %#ok<ASGLU>  (npat is not needed below)
nfeat             = size(F,1);

% Zero out the entries flagged as missing so they do not contribute.
x = x.*add_reps(g,0,ninp);
y = y.*g;

% Per-task pattern counts and output sums of squares. The explicit dimension
% keeps these per-task even when there is only a single training pattern.
Nn  = sum(g,1);
Syy = sum(y.^2,1);

hyper = init_gating(ninp,nhid,ntask,nprior,nprior,1,nfeat);

% B and D are independently optional: supplying B without D no longer fails
% on an undefined D, and an empty value means "keep the default".
if (nargin > 6) && ~isempty(B)
  hyper.B = B;
end
if (nargin > 7) && ~isempty(D)
  hyper.D = D;
end

vec = bmksq_to_vec(hyper)';

%%%  ind(:,i) gives the indices into vec of the hyperparameters for cluster i  %%%
ind = make_ind(nhid,ninp,nprior,nprior,1,nfeat);

%%% rho(i,j) = P( task j in cluster i | data )
%%%          ~ P( data j | task j in cluster i ) * P( task j in cluster i )
%%%               \= exp(loglik(i,j))                \= q(i,j)
rho = cluster_posterior(vec, ind, hyper.q, F, x, y, g, Syy, nhid, Nn, nprior);

% Fixed iteration budget (the original loop ran until iter > 100, i.e. 101 passes).
max_iter = 101;
hypers   = cell(1,max_iter+1);
tt       = zeros(1,max_iter);

for iter = 1:max_iter

  vec = frprmn('e_emG', 'grad_emG', vec, 5, 1e-2, 1e-2, [], ind, x, y, g, F, Syy, nhid, Nn, rho);

  % The gating parameters q occupy the last nprior*nfeat entries of vec.
  q   = reshape(vec(end-nprior*nfeat+1:end),nprior,nfeat);
  rho = cluster_posterior(vec, ind, q, F, x, y, g, Syy, nhid, Nn, nprior);

  %%%  convergence check (disabled; a fixed iteration count is used instead)  %%%
  % To re-enable, initialise before the loop
  %   Variance = add_reps(add_reps(eye(nhid+1),2,ntask),2,nprior);
  %   State    = zeros(nhid+1,ntask,nprior);
  % request [rho,iR1,R2] from cluster_posterior, and then:
  %   LastVariance = Variance;
  %   LastState    = State;
  %   Variance = reshape(iR1',nhid+1,nhid+1,ntask,nprior);
  %   State    = Tx(Variance, reshape(R2',nhid+1,ntask,nprior), [2 1], [3 2; 4 3]);
  %   KL = KLdist(LastState,LastVariance,State,Variance)*rho(:)/ntask;
  %   fprintf('KL = %f\n',KL);
  %   converged = (KL < 1e-2);

  hypers{iter} = vec_to_bmksq(vec, ninp, nprior, nhid, ntask, nprior, 1);
  tt(iter)     = cputime - t;
  fprintf(1,'iter = %i\n',iter);
end

% Trailing copy of the final estimate, preserved from the original interface.
hypers{max_iter+1} = vec_to_bmksq(vec, ninp, nprior, nhid, ntask, nprior, 1);


function [rho, iR1, R2] = cluster_posterior(vec, ind, q, F, x, y, g, Syy, nhid, Nn, nprior)
% Posterior cluster responsibilities rho [nprior x ntask] for the current
% hyperparameters. iR1 and R2 are the second and third outputs of e_bmksq,
% stacked per cluster; they are only needed by the disabled KL check.

for i = 1:nprior
  [nll(i,:),iR1(i,:),R2(i,:)] = e_bmksq(vec(ind(:,i)), x, y, g, Syy, nhid, Nn); %#ok<AGROW>
end
loglik = -nll;   % the caller's convention: loglik is the negated first output of e_bmksq

% Cluster priors: softmax over clusters of -q*F. Subtracting the column-wise
% maximum leaves the result unchanged but prevents overflow in exp.
a     = -q*F;
a     = a - repmat(max(a,[],1),nprior,1);
prior = exp(a);
prior = prior./repmat(sum(prior,1),nprior,1);

% Explicit dimension 1: with nprior == 1 a bare max/sum would reduce across
% tasks instead of across clusters.
rho = exp(loglik - repmat(max(loglik,[],1),nprior,1)).*prior;
rho = rho./repmat(sum(rho,1),nprior,1);









