function model = oasis(data, class_labels, parms)
%
% MODEL = OASIS(DATA, CLASS_LABELS, PARMS)
%
% Learns a d x d matrix W (initialized to eye(d)) by running the compiled
% OASIS kernel OASIS_C in consecutive batches of steps, optionally saving,
% symmetrizing and/or projecting W onto the PSD cone between batches.
%
% Output:
%  -- model.W        - d x d matrix, the learned matrix (called the
%                      Mahalanobis matrix in the original documentation)
%  -- model.loss_vec - a binary vector: was there an update at each iteration
%  -- model.parms    - parameters (including the default values used)
%
% Input:
%  - data         - N x d (full or sparse) matrix, each instance being a ROW
%  - class_labels - label of each data point (N x 1 vector). Any set of
%                   distinct label values is accepted; they are internally
%                   renumbered to 1..num_classes.
%  - parms        - (optional) structure with the following fields:
%    -- aggres     - the cutoff point on the size of the correction
%                    (default 0.1). The alias 'aggress' is also accepted.
%    -- rseed      - the random seed for data point selection (default 1)
%    -- do_sym     - whether to symmetrize the matrix every sym_every
%                    batches (default 0)
%    -- do_psd     - whether to project the matrix onto the PSD cone every
%                    psd_every batches, including symmetrizing it (default 0)
%    -- do_save    - whether to save the intermediate matrices. Note that
%                    saving happens before symmetrizing and/or PSD projection
%                    (default 0)
%    -- save_path  - directory used when do_save is set; batch k is saved as
%                    save_path/part_kkk.mat (default fullfile(pwd,'oasis_saves'))
%    -- num_of_total_steps - number of total steps the algorithm will run
%                    (default ceil(N^2/2))
%    -- num_of_inner_steps - number of steps per batch, i.e. between save
%                    points (default max(1, floor(num_of_total_steps/100)))
%    -- sym_every  - number of batches between symmetrizations when
%                    do_sym=1. The last batch is always symmetrized (default 1)
%    -- psd_every  - number of batches between PSD projections when
%                    do_psd=1. The last batch is always projected (default 1)
%
% (C) Gal Chechik, Uri Shalit 2008
%
% Code provided for academic use. For other use, please contact the
% authors.
%

if nargin < 3 || isempty(parms)
    parms = struct();
end

[N, dim] = size(data);
W = eye(dim);

% Sort instances by label so that every class occupies a contiguous block
% of rows; OASIS_C receives those blocks as (class_start, class_sizes).
[unused_values, inds] = sort(class_labels); %#ok
class_labels = class_labels(inds);
data = data(inds, :);

% Translate class labels to serial numbers 1,2,...
%------------------------
% Compare against the untouched labels and write into a copy: relabelling
% in place would merge classes whenever an original label value equals a
% serial number assigned earlier (e.g. labels {0,1} would all become 2).
% Copying first preserves the numeric class of the labels.
classes = unique(class_labels);   % unique() already returns sorted values
num_classes = length(classes);
serial_labels = class_labels;
for i = 1:num_classes
    serial_labels(class_labels == classes(i)) = i;
end
class_labels = serial_labels;

class_sizes = zeros(num_classes, 1);
class_start = zeros(num_classes, 1);
for k = 1:num_classes
    class_sizes(k) = sum(class_labels == k);
    class_start(k) = find(class_labels == k, 1, 'first');
end

% Initializing parms
%------------------------
if ~isfield(parms, 'aggres') && isfield(parms, 'aggress')
    % Accept the field name used by the old documentation.
    parms.aggres = parms.aggress;
end
if ~isfield(parms, 'aggres')
    parms.aggres = 0.1;
    fprintf('aggressivenesss value set of 0.1\n');
end
if ~isfield(parms, 'rseed')
    parms.rseed = 1;
end
if ~isfield(parms, 'do_save')
    parms.do_save = 0;
elseif parms.do_save
    % Any truthy do_save triggers saving below, so the path must exist.
    if ~isfield(parms, 'save_path')
        parms.save_path = fullfile(pwd, 'oasis_saves');
    end
    if ~exist(parms.save_path, 'dir')
        mkdir(parms.save_path);
    end
end
if ~isfield(parms, 'num_of_total_steps')
    parms.num_of_total_steps = ceil(N^2/2);
end
if ~isfield(parms, 'num_of_inner_steps')   % save_every_n_steps
    % At least one step per batch; otherwise num_batches below is Inf.
    parms.num_of_inner_steps = max(1, floor(parms.num_of_total_steps/100));
end
if ~isfield(parms, 'do_sym')
    parms.do_sym = 0;
end
if ~isfield(parms, 'do_psd')
    parms.do_psd = 0;
end
if parms.do_sym && ~isfield(parms, 'sym_every')
    parms.sym_every = 1;
end
if parms.do_psd && ~isfield(parms, 'psd_every')
    parms.psd_every = 1;
end

% Init
%------------------------
loss_steps = zeros(1, parms.num_of_total_steps);
num_batches = ceil(parms.num_of_total_steps / parms.num_of_inner_steps);
% Every batch runs num_of_inner_steps steps except the last, which runs
% whatever remains of num_of_total_steps.
steps_vec = repmat(parms.num_of_inner_steps, 1, num_batches);
if num_batches > 0
    steps_vec(end) = parms.num_of_total_steps - ...
        (num_batches-1) * parms.num_of_inner_steps;
end

% OASIS_C takes the data transposed (one instance per column); the
% transpose is loop-invariant, so compute it once.
data_t = full(data)';
clear data;

step_offset = 0;
for i_batch = 1:num_batches
    n_steps = steps_vec(i_batch);
    [W, l] = oasis_c(W, data_t, class_labels, class_start, class_sizes, ...
        n_steps, parms.aggres, parms.rseed, i_batch);
    % Index by the steps actually run in this batch (the last batch may be
    % shorter than num_of_inner_steps).
    loss_steps(step_offset+1 : step_offset+n_steps) = l;
    step_offset = step_offset + n_steps;

    if parms.do_save
        temp_filename = fullfile(parms.save_path, ...
            sprintf('part_%03.0f.mat', i_batch));
        save(temp_filename, 'W');
    end

    if parms.do_sym && ...
            (mod(i_batch, parms.sym_every) == 0 || i_batch == num_batches)
        W = 0.5*(W + W');
    end

    if parms.do_psd && ...
            (mod(i_batch, parms.psd_every) == 0 || i_batch == num_batches)
        % Symmetrize, then clip negative eigenvalues to zero.
        [V, D] = eig(0.5*(W + W'));
        W = V*max(D, 0)*V';
        clear V D;
    end
end

model.W = W;
model.loss_vec = loss_steps;
model.parms = parms;
end