%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% [y_pred, rho, EV] = GPred4L2L(xt,yt,gt,F,hypers,xv,yv,gv)
%
% inputs : xt,     inputs used for training (#inputs x #trainpat x #parallel tasks)
%          yt,     outputs used for training (#trainpat x #parallel tasks) (outputs are assumed 1D)
%          gt,     missing value indicator (1=present, 0=missing)
%          F,      feature vectors for all of the tasks (#features x #tasks)
%          hypers, hyperparameters struct obtained through e.g. TC4L2L
%          xv,yv,gv, the x, y and g of the test set (xv is used for prediction; yv, gv for its quality)
%
% outputs: y_pred, predicted value for y on xv
%          rho(i,j) = prob task j is in cluster i
%          EV,     explained variance on the test set (xv,yv)
%
% Made by Bart Bakker, bartb@mbfys.kun.nl
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

function [y_pred, rho, EV] = GPred4L2L(xt,yt,gt,F,hypers,xv,yv,gv)
% Predict the outputs on xv from the training set (xt,yt,gt) and the
% hyperparameters in hypers, and optionally score the prediction on yv.
%
%   xt, yt, gt : training inputs, outputs and presence indicator (1=present)
%   F          : task feature vectors, used to weight the priors via hypers.q
%   hypers     : struct; the fields B, D and q are read here and the whole
%                struct is passed to bmksq_to_vec
%   xv         : inputs to predict on
%   yv         : (optional) test outputs; when omitted or empty, EV = 0
%   gv         : (optional) test presence indicator; when omitted or empty
%                every test pattern is treated as present
%
%   y_pred     : predictions on xv
%   rho        : rho(i,j) = prob task j is in cluster i (all ones when there
%                is only one prior)
%   EV         : explained variance on (xv,yv), or 0 when yv is not given

[ninp,npat,ntask] = size(xt);

% gv is needed to build the test design matrix even when no score is
% requested, so fall back to "everything present" if the caller left it out.
if (nargin < 8) || isempty(gv)
  gv = ones(size(xv,2), size(xv,3));
end

% Zero out the missing training entries.
xt = xt.*add_reps(gt,0,ninp);
yt = yt.*gt;

Nn  = sum(gt);      % number of present training patterns per column of gt
Syy = sum(yt.^2);   % sum of squared (masked) training outputs per column

B = hypers.B;
D = hypers.D;
q = hypers.q;

[nhid, ninp]    = size(B);
[nprior, nfeat] = size(q);

% Disabled alternative kept from the original source:
% Z   = tanh(Tx(B,xt) + reshape(D*gt(:)',nhid,npat,ntask));
% Z   = [Z;gt];
%
% Aml = Tx(Tinv(Tx(Z,Z,2,3)), Tx(Z,yt,[2 1],[3 2]),1,[3 2]);

vec = bmksq_to_vec(hypers)';
ind = make_ind(nhid,ninp,nprior,nprior,1,nfeat);

% Evaluate e_bmksq once per prior, each with its own slice of the
% hyperparameter vector.
for i = 1:nprior
  [loglik(i,:),iR1(i,:),R2(i,:)] = e_bmksq(vec(ind(:,i)), xt, yt, gt, Syy, nhid, Nn);
end
iR1 = reshape(iR1',nhid+1,nhid+1,ntask,nprior);
R2  = reshape(R2',nhid+1,ntask,nprior);

if (nprior > 1)
  loglik = -loglik;

  % Feature-dependent prior weights, normalised over the priors.
  q   = exp(-q*F)./(ones(nprior,1)*sum(exp(-q*F)));
  % Subtract the per-task mean before exponentiating; the common factor
  % cancels in the normalisation below.
  rho = exp(loglik-ones(nprior,1)*mean(loglik));
  rho = rho.*q;
  rho = rho./(ones(nprior,1)*sum(rho,1));

  Amp = Tx(rho,Tx(iR1,R2,[2 1],[3 2;4 3]),[1 3],2);
else
  % Only one cluster: every task belongs to it with probability 1. Without
  % this assignment a caller requesting rho would get an unassigned-output
  % error.
  rho = ones(1,ntask);

  Amp = Tx(iR1,R2,[2 1],[3 2]);
end

% From here on the sizes refer to the test set.
[ninp,npat,ntask] = size(xv);

xv = xv.*add_reps(gv,0,ninp);

% Hidden-unit activations on the test inputs; the extra last row carries
% the presence indicator.
Z             = tanh(Tx(B,xv) + reshape(D*gv(:)',nhid,npat,ntask));
Z(nhid+1,:,:) = gv;

y_pred = Tx(Amp,Z,1,[2 3]);

if (nargin > 6) && ~isempty(yv)
  yv = yv.*gv;
  % Mean squared error over the present test patterns ...
  EV = sumsqr(y_pred - yv)/sum(gv(:));
  % ... relative to the variance of the present test outputs. The mask must
  % be gv: yv has the size of gv, not of [gt; gv].
  EV = 1 - EV/std(yv(gv ~= 0))^2;
else
  EV = 0;
end






