% Dr. Durant Week 3 Lab Example, Pruning example based on:
% https://www.mathworks.com/help/deeplearning/ug/parameter-pruning-and-quantization-of-image-classification-network.html
% 2022-03-24, CSC4651/5651
%
% Purpose: load the trained week 2 network, convert it to a dlnetwork,
% zero out an increasing fraction of its learnable parameters using
% magnitude-based masks, and plot validation accuracy against sparsity.
%
% Relies on helper functions that are not defined in this section of the
% file: calculateMagnitudeScore, calculateMask, evaluateAccuracy and
% preprocessMiniBatch.

format compact % skip fewer lines to fit more output in console

% Adapted to work with CSC4651/5651 week 2 lab network
load hill_valley_alexSuccess.mat % get successful network in memory as trainedNetwork_1 (default name)

% Following the MATLAB example -- see the example to better understand what
% this code is doing...
lgraph = layerGraph(trainedNetwork_1.Layers);
% You'll get an error if your layer names are different
lgraph = removeLayers(lgraph,["prob","classoutput"]);
dlnet = dlnetwork(lgraph);
analyzeNetwork(dlnet)

% In MATLAB the semicolon (;) suppresses console output. It is omitted on
% the lines below on purpose, to display important values.
numTotalParams = sum(cellfun(@numel,dlnet.Learnables.Value))
numNonZeroPerParam = cellfun(@(w)nnz(extractdata(w)),dlnet.Learnables.Value)
initialSparsity = 1-(sum(numNonZeroPerParam)/numTotalParams)

%% Load data used in week 2
% This is for the .zip file. Change the name to use noisy data, etc.
imds = imageDatastore("hillvalley_plots", "IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain,imdsVal] = splitEachLabel(imds,0.7); % Duplicate split from week 2 lab
% Labels are captured from the plain imageDatastore before it is wrapped below.
trueLabels = imdsVal.Labels;
classes = categories(trueLabels)

%% Set up pruning sparsities to try
% Set to false to use the narrower, finer-grained sweep instead.
useTutorialScheme = true;
if useTutorialScheme
    % Original tutorial code -- 10 steps from 0.0 to 0.9
    numIterations = 10;
    targetSparsity = 0.9;
    iterationScheme = linspace(0.0,targetSparsity,numIterations);
else
    % Focus on the range of interest for your model, which might span just
    % a few percentage points
    numIterations = 20;
    iterationScheme = linspace(0.96, 0.98, numIterations);
end

%% Build the validation mini-batch queue
miniBatchSize = 16;
imdsVal.ReadSize % no ; to display the current value before changing it
imdsVal.ReadSize = miniBatchSize;

% deepNetworkDesigner
% TODO: Clean this up -- resize when source is first created
imdsValOrig = imdsVal;
% augmentedImageDatastore supersedes augmentedImageSource and takes the same
% (outputSize, imds) arguments. No ; so the resulting datastore is displayed.
imdsVal = augmentedImageDatastore([227 227 3], imdsValOrig)

mbqValidation = minibatchqueue(imdsVal,1,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFormat','SSCB',...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'OutputEnvironment','auto');

%% Accuracy of the original (unpruned) network
% The plot below draws this as the baseline, so it must be computed first.
% NOTE(review): assumes evaluateAccuracy leaves mbqValidation ready for
% reuse, as the pruning loop below already requires -- confirm in the helper.
accuracyOriginalNet = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)

%% Initialize Plot
figure
plot(100*iterationScheme([1,end]),100*accuracyOriginalNet*[1 1],'*-b','LineWidth',2,"Color","b")
ylim([0 100])
xlim(100*iterationScheme([1,end]))
xlabel("Sparsity (%)")
ylabel("Accuracy (%)")
legend("Original Accuracy","Location","southwest")
title("Pruning Accuracy")
grid on

%% Magnitude Pruning
% One mask per sparsity level; the first entry starts as all-true (keep
% everything) and is overwritten on the first loop iteration.
pruningMaskMagnitude = cell(1,numIterations);
pruningMaskMagnitude{1} = dlupdate(@(p)true(size(p)), dlnet.Learnables);

lineAccuracyPruningMagnitude = animatedline('Color','g','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Location","southwest")

% Compute magnitude scores once, from the unpruned network
scoresMagnitude = calculateMagnitudeScore(dlnet);

for idx = 1:numel(iterationScheme)
    % Always start from the unpruned network so each sparsity level is
    % applied independently rather than cumulatively.
    prunedNetMagnitude = dlnet;

    % Update the pruning mask
    pruningMaskMagnitude{idx} = calculateMask(scoresMagnitude,iterationScheme(idx));

    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskMagnitude{idx}.Value));
    sparsity = numPrunedParams/numTotalParams % no ; to display value

    % Apply pruning mask to network parameters
    prunedNetMagnitude.Learnables = dlupdate(@(W,M)W.*M, prunedNetMagnitude.Learnables, pruningMaskMagnitude{idx});

    % Compute validation accuracy on pruned network
    accuracyMagnitude = evaluateAccuracy(prunedNetMagnitude,mbqValidation,classes,trueLabels);

    % Display the pruning progress
    addpoints(lineAccuracyPruningMagnitude,100*sparsity,100*accuracyMagnitude)
    drawnow
end