Paper Title
DHP: Differentiable Meta Pruning via HyperNetworks
Paper Authors
Paper Abstract
Network pruning has been the driving force behind the acceleration of neural networks and the alleviation of model storage/transmission burden. With the advent of AutoML and neural architecture search (NAS), pruning has become a topic of automatic mechanisms and search-based architecture optimization. However, current automatic designs rely on either reinforcement learning or evolutionary algorithms. Due to the non-differentiability of those algorithms, the pruning procedure needs a long search stage before reaching convergence. To circumvent this problem, this paper introduces a differentiable pruning method via hypernetworks for automatic network pruning. The specifically designed hypernetworks take latent vectors as input and generate the weight parameters of the backbone network. The latent vectors control the output channels of the convolutional layers in the backbone network and act as handles for pruning those layers. By enforcing $\ell_1$ sparsity regularization on the latent vectors and using a proximal gradient solver, sparse latent vectors can be obtained. Passing the sparsified latent vectors through the hypernetworks, the corresponding slices of the generated weight parameters can be removed, achieving the effect of network pruning. The latent vectors of all the layers are pruned together, resulting in an automatic layer configuration. Extensive experiments are conducted on various networks for image classification, single-image super-resolution, and denoising, and the results validate the proposed method.
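To make the mechanism described in the abstract concrete, below is a minimal sketch, not the authors' exact hypernetwork design: a layer whose convolution weights are generated from two latent vectors (one per input-channel dimension, one per output-channel dimension), plus an $\ell_1$ proximal (soft-thresholding) step on those latent vectors. The names `ConvHyperLayer`, `proj`, and `proximal_l1_step` are illustrative assumptions, not identifiers from the paper.

```python
# Illustrative sketch only (assumed structure, not the paper's exact architecture):
# a hypernetwork-style layer maps per-layer latent vectors to convolution weights,
# and an l1 proximal step sparsifies the latent vectors so that zeroed entries
# correspond to removable weight slices (i.e., prunable channels).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvHyperLayer(nn.Module):
    """Generates a [c_out, c_in, k, k] weight tensor from two latent vectors."""

    def __init__(self, c_in, c_out, k=3):
        super().__init__()
        self.k = k
        # Latent vectors act as handles on the input/output channels.
        self.z_in = nn.Parameter(torch.ones(c_in))
        self.z_out = nn.Parameter(torch.ones(c_out))
        # Hypothetical weight generator: a shared learnable projection that the
        # latent vectors gate; the paper's hypernetwork is more elaborate.
        self.proj = nn.Parameter(torch.randn(c_out, c_in, k * k) * 0.01)

    def forward(self, x):
        # Outer product of the latent vectors scales every (out, in) channel pair;
        # a zeroed latent entry zeroes the corresponding slice of the weights.
        gate = self.z_out[:, None, None] * self.z_in[None, :, None]  # [c_out, c_in, 1]
        w = (gate * self.proj).reshape(-1, self.proj.shape[1], self.k, self.k)
        return F.conv2d(x, w, padding=self.k // 2)


def proximal_l1_step(latents, lr, lam):
    """Soft-thresholding: the proximal operator of the l1 penalty."""
    with torch.no_grad():
        for z in latents:
            z.copy_(torch.sign(z) * torch.clamp(z.abs() - lr * lam, min=0.0))


# Toy usage: one gradient step on the task loss, then a proximal step on the latents.
layer = ConvHyperLayer(c_in=3, c_out=8)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(2, 3, 16, 16)
loss = layer(x).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
proximal_l1_step([layer.z_in, layer.z_out], lr=0.1, lam=1e-2)
```

After training, latent entries driven to zero mark output channels whose generated weight slices can be dropped across adjacent layers, which is what yields the automatic per-layer pruning configuration described above.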