function [Z1,Z2,Apinv,Bpinv] = online_step_loreta1(A,B,Apinv,Bpinv,x1,x2,y,t)
% [Z1,Z2,Apinv,Bpinv] = online_step_loreta1(A,B,Apinv,Bpinv,x1,x2,y,t)
%
% This function implements Loreta-1 from the paper:
% "Online Learning in the Manifold of Low-rank Matrices", Uri Shalit,
% Daphna Weinshall and Gal Chechik, NIPS 2010
%
% Performs a retraction step given the gradient: Z1*Z2' = R_{AB'}(t*y*x1*x2')
% The gradient is represented as y*x1*x2' (a rank-1 matrix)
%
% Inputs:
% A,B - The low-rank factors of the current model. A is (n X k), B is (m X k)
% Apinv, Bpinv - The pseudo-inverses of A and B
% x1, x2 - The factors of the rank-1 gradient matrix. x1 is (n X 1), x2 is (m X 1)
% t - The step size
% y - The sign of the step, either -1 or 1
%
% Outputs:
% Z1, Z2 - The low-rank factors of the model after the retraction step
% Apinv, Bpinv - The pseudo-inverses of Z1, Z2 respectively

x2temp = t*y*x2;
temp1 = x2temp'*(Bpinv'*(Apinv*x1)); %scalar, since x1 and x2 are column vectors

UUTx1 = A*(Apinv*x1);
cA = UUTx1*(-(1/2)+(3/8)*temp1) + x1*(1-(1/2)*temp1); %nx1
dA = Bpinv*x2temp; %kx1
Z1 = A + cA*dA';

x2VVT = (x2temp'*B)*Bpinv;
cB = Apinv*x1; %kx1
dB = ((-1/2)+(3/8)*temp1)*x2VVT' + (1-(1/2)*temp1)*x2temp; %mx1
Z2 = B + dB*cB';

Apinv = rank1_pinv_update(A,Apinv,cA,dA);
Bpinv = rank1_pinv_update(B,Bpinv,dB,cB);
end

function [Apinv_new] = rank1_pinv_update(A,Apinv,c,d)
%
% A is nxk, c is nx1 and d is kx1
% Apinv_new is the pseudoinverse of A+c*d'
% ref.: Carl D. Meyer, "Generalized inversion of modified matrices"
beta = 1+d'*(Apinv*c);
v = Apinv*c;
n = Apinv'*d;
w = c-A*(Apinv*c);
%m = d-A'*(Apinv'*d); we deal only with full column rank matrices, so Apinv*A = I
%and m is always zero
% note: norm_w, norm_m, norm_v and norm_n hold squared Euclidean norms
norm_w = w'*w;
%norm_m = m'*m;
norm_m = 0; %A has full column rank, therefore norm_m is always zero
norm_v = v'*v;
norm_n = n'*n;
if abs(beta)>eps && norm_meps && abs(beta)eps && norm_m>eps
    G = -v*w'/norm_w - m*n'/norm_m + beta*m*w'/(norm_w*norm_m);
elseif norm_weps && abs(beta)eps
    G = m*(v'*Apinv)/beta - (beta/(norm_v*norm_m+beta^2))*(norm_v*m/beta+v)*((norm_m/beta)*(Apinv'*v)+n)';
elseif norm_w