论文标题

提高应用于动态约束的优化问题的梯度下降算法的效率

Improving the Efficiency of Gradient Descent Algorithms Applied to Optimization Problems with Dynamical Constraints

论文作者

Matei, Ion, Zhenirovskyy, Maksym, de Kleer, Johan, Maxwell, John

论文摘要

我们介绍了两个块坐标下降算法,以解决使用普通微分方程(ODE)作为动态约束的优化问题。该算法无需实施直接或伴随的灵敏度分析方法来评估损失功能梯度。它们是由原始问题重新制定,作为与平等限制的等效优化问题。该算法自然遵循旨在根据ode求解器恢复梯度定位算法的步骤,该算法明确地说明了ODE解决方案的灵敏度。在我们的第一个提出的算法中,我们避免通过将ODE求解器集成为隐式约束的序列来明确解决ODE。在我们的第二个算法中,我们使用ODE求解器来重置ODE解决方案,但没有直接使用伴随灵敏度分析方法。这两种算法都接受小批量实施,并从基于GPU的并行化中显示出显着的效率。当应用于学习Cucker-Smale模型的参数时,我们演示了算法的性能。将算法与基于具有敏感性分析能力的ODE求解器的梯度下降算法进行比较,使用Pytorch和JAX实现,具有敏感性分析能力。实验结果表明,所提出的算法至少比Pytorch实现快4倍,并且比JAX实现快至少16倍。对于大版本的Cucker-Smale模型,JAX实现的速度比基于灵敏度分析的实现快数千倍。此外,我们的算法在培训和测试数据上都会产生更准确的结果。对于实施实时参数估计(例如诊断算法)的算法,计算效率的提高至关重要。

We introduce two block coordinate descent algorithms for solving optimization problems with ordinary differential equations (ODEs) as dynamical constraints. The algorithms do not need to implement direct or adjoint sensitivity analysis methods to evaluate loss function gradients. They results from reformulation of the original problem as an equivalent optimization problem with equality constraints. The algorithms naturally follow from steps aimed at recovering the gradient-decent algorithm based on ODE solvers that explicitly account for sensitivity of the ODE solution. In our first proposed algorithm we avoid explicitly solving the ODE by integrating the ODE solver as a sequence of implicit constraints. In our second algorithm, we use an ODE solver to reset the ODE solution, but no direct are adjoint sensitivity analysis methods are used. Both algorithm accepts mini-batch implementations and show significant efficiency benefits from GPU-based parallelization. We demonstrate the performance of the algorithms when applied to learning the parameters of the Cucker-Smale model. The algorithms are compared with gradient descent algorithms based on ODE solvers endowed with sensitivity analysis capabilities, for various number of state size, using Pytorch and Jax implementations. The experimental results demonstrate that the proposed algorithms are at least 4x faster than the Pytorch implementations, and at least 16x faster than Jax implementations. For large versions of the Cucker-Smale model, the Jax implementation is thousands of times faster than the sensitivity analysis-based implementation. In addition, our algorithms generate more accurate results both on training and test data. Such gains in computational efficiency is paramount for algorithms that implement real time parameter estimations, such as diagnosis algorithms.

扫码加入交流群

加入微信交流群

微信交流群二维码

扫码加入学术交流群,获取更多资源