● The Gradient 📅 14/10/2023 at 17:30

Neural algorithmic reasoning

👤 Petar Veličković
In this article, we will talk about classical computation: the kind of computation typically found in an undergraduate Computer Science course on Algorithms and Data Structures [1]. Think shortest path-finding, sorting, clever ways to break problems down into simpler problems, incredible ways to organise data for efficient retrieval and updates. Of course, given The Gradient’s focus on Artificial Intelligence, we will not stop there; we will also investigate how to capture such computation with deep neural networks.

Why capture classical computation?

Maybe it’s worth starting by clarifying where my vested interest in classical computation comes from. Competitive programming (the art of solving problems by rapidly writing programs that need to terminate in a given amount of time, and within certain memory constraints) was a highly popular activity in my secondary school. For me, it was truly the gateway into Computer Science, and I trust the story is similar for many machine learning practitioners and researchers today. I have been able to win several medals at international programming competitions, such as the Northwestern Europe Regionals of the ACM-ICPC, the top-tier Computer Science competition for university students. Hopefully, my successes in competitive programming also give me some credentials to write about this topic.

While this should make clear why I care about classical computation, why should we all care?
To arrive at this answer, let us ponder some of the key properties that classical algorithms have:

- They are provably correct, and we can often have strong guarantees about the resources (time or memory) required for the computation to terminate.
- They offer strong generalisation: while algorithms are often devised by observing several small-scale example inputs, once implemented, they will work without fault on inputs that are significantly larger than, or distributionally different from, such examples.
- By design, they are interpretable and compositional: their (pseudo)code representation makes it much easier to reason about what the computation is actually doing, and one can easily recompose various computations together through subroutines to achieve different capabilities.

Taken together, the absence of these properties is exactly what plagues modern deep neural networks the most: you can rarely guarantee their accuracy, they often collapse on out-of-distribution inputs, and they are notorious as black boxes, with compounding errors that can hinder compositionality.

However, it is exactly those skills that are important for making AI instructive and useful to humans! For example, to have an AI system that reliably and instructively teaches a concept to a human, the quality of its output should not depend on minor details of the input, and it should be able to generalise that concept to novel situations. Arguably, these skills are also a missing key step on the road to building generally-intelligent agents.
Therefore, if we are able to make any strides towards capturing traits of classical computation in deep neural networks, this is likely to be a very fruitful pursuit.

First impressions: Algorithmic alignment

My journey into the field started with an interesting, but seemingly inconsequential question:

Can neural networks learn to execute classical algorithms?

This can be seen as a good way to benchmark to what extent certain neural networks are capable of behaving algorithmically; arguably, if a system can produce the outputs of a certain computation, it has “captured” it. Further, learning to execute provides a uniquely well-built environment for evaluating machine learning models:

- It offers an infinite data source, as we can generate arbitrary amounts of inputs;
- It requires complex data manipulation, making it a challenging task for deep learning;
- It has a clearly specified target function, simplifying interpretability analyses.

When we started studying this area in 2019, we really did not think more of it than a very neat benchmark, but it was certainly becoming a very lively area. Concurrently with our efforts, a team from MIT tried tackling a more ambitious, but strongly related problem:

What makes a neural network better (or worse) at fitting certain (algorithmic) tasks?

The landmark paper, “What Can Neural Networks Reason About?” [2], established a mathematical foundation for why an architecture might be better for a task (in terms of sample complexity: the number of training examples needed to reduce validation loss below epsilon). The authors’ main theorem states that better algorithmic alignment leads to better generalisation. Rigorously defining algorithmic alignment is out of scope of this text, but it can be very intuitively visualised:

Here, we can see a visual decomposition of how a graph neural network (GNN) [3] aligns with the classical Bellman-Ford [4] algorithm for shortest path-finding.
Specifically, Bellman-Ford maintains its current estimate of how far away each node is
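
As a minimal sketch of the Bellman-Ford update discussed here, the Python code below keeps one distance estimate per node and repeatedly improves it by taking a minimum over candidate distances arriving along incoming edges. The function name `bellman_ford`, the `(u, v, weight)` edge-list format, and the example graph are illustrative assumptions, not taken from the article or the cited paper. The comments mark where the per-edge candidate and the per-node minimum loosely resemble a GNN's message and aggregation steps.

```python
import math


def bellman_ford(num_nodes, edges, source):
    """Return shortest-path distance estimates from `source`.

    `edges` is a list of (u, v, weight) triples for directed edges u -> v.
    Assumes the graph has no negative-weight cycles.
    """
    # One scalar estimate per node: infinity until a path is found.
    dist = [math.inf] * num_nodes
    dist[source] = 0.0

    # With synchronous updates, num_nodes - 1 rounds are enough.
    for _ in range(num_nodes - 1):
        new_dist = list(dist)
        for u, v, w in edges:
            # "Message" along the edge: distance to v if we go through u.
            candidate = dist[u] + w
            # "Aggregation" at v: keep the minimum over all candidates.
            if candidate < new_dist[v]:
                new_dist[v] = candidate
        if new_dist == dist:
            break  # No estimate changed, so we have converged early.
        dist = new_dist
    return dist


# Hypothetical 4-node example graph.
edges = [(0, 1, 4.0), (0, 2, 1.0), (2, 1, 2.0), (1, 3, 5.0)]
distances = bellman_ford(4, edges, source=0)
print(distances)
```

In this example the direct edge from node 0 to node 1 costs 4, but the route through node 2 costs 3, so the estimate for node 1 is lowered in the second round and the improvement then propagates to node 3.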