
Transcript

Horace He: Building Machine Learning Systems for a Trillion Trillion Floating Point Operations


TITLE: Horace He: Building Machine Learning Systems for a Trillion Trillion Floating Point Operations
CHANNEL: Jane Street
DATE: 2024-12-09
---TRANSCRIPT---

So it’s my great pleasure to introduce our speaker tonight, Horace He. Horace graduated from Cornell in 2020, and he works at Meta on the PyTorch team, specifically at the intersection of compilers and machine learning. If you’ve used torch.compile, a feature in PyTorch that makes your model go something like two to four times faster with just one line of code, or if you’ve used FlexAttention, which lets researchers design fast attention kernels without leaving the world of Python, he’s the person responsible for both of those things. You should also check out his blog, which is pretty awesome, in particular the post “Making Deep Learning Go Brrrr From First Principles.” And without further ado, I’ll hand it over to Horace.

Thanks for the intro. Today I’m going to give a talk about building machine learning systems for a trillion trillion floating point operations. My name is Horace He, and I’m on the PyTorch compilers team at Meta. I think we live in pretty unprecedented times in terms of infrastructure build-out, which is nicely reflected in Nvidia’s stock price. It feels like every month we see a different headline about a new nuclear power plant from Microsoft, or a massive 300k-GPU cluster from xAI. When Amazon first bought a nuclear-powered data center it was big news, but now practically every cool AI company has its own nuclear data center. And of course this massive infrastructure build-out has also resulted in some pretty ludicrous fundraises from startups.
I remember back in 2016, if you were a startup worth a billion dollars, you were a unicorn; that was the mark of making it as a startup beyond your wildest dreams. But in 2024 you have to raise a billion dollars just to get started. Just to play the game you need a billion dollars, most of which goes to Nvidia. It’s kind of crazy to think that all of these billions of dollars are really just to do an absolutely insane number of floating point operations. Here’s a nice chart from Epoch AI showing the growth of training compute over time. These floating point operations are really just big matmuls done over and over, for millions of iterations, over months. Nowadays the leading-edge models are trained with about 1e26 floating point operations, which is roughly 100 trillion trillion floating point operations. Another effect you might have noticed: prior to 2016, if you searched Hacker News for “ML,” you’d get a lot of people asking about the ML family of languages. Nowadays you search “ML” on Hacker News and you get a very different type of article. And I think one of the things that gets missed here is that, with all these billions of dollars and yottaflops of operations, it’s easy to forget that these operations actually need to run somehow on real machines.
A modern stack might involve calling an LLM API, which calls PyTorch, which calls NCCL, Triton, CUDA, NVCC: all these different layers in the stack that somebody was involved in writing. I’m often reminded of the XKCD comic showing the state of modern digital infrastructure. You often forget about all the infrastructure that was involved in getting you where you are, but really we’re just building layers on top of layers. If you work in systems, like I do and like many of you do, I think you often think about your work a bit like a benevolent dictator of sorts: a lot of people build on top of your work, so if you can give a small improvement to millions of people, your small improvements can have a significant impact on the world. For example, I imagine Guido sitting on top of his cloud and eventually giving us no-GIL, or Faster CPython, or the walrus operator. This is often how I imagine infrastructure work, and it’s a lot of why I got into infra work in the first place. On the other hand, if you work in ML, things can sometimes feel a bit different. I recently came across a post on Threads where someone said their main gripe with building on top of LLM APIs is that no one is really engineering anything; you’re just chanting a prayer to this man-made demon deity to do your bidding.
This is very similar to the Shoggoth metaphor that has become very popular in deep learning circles. The idea is that we put all these matmuls and all this computation into producing a very weird alien intelligence, and then we RLHF it and provide it through a nice, convenient interface like ChatGPT. But if you think about it, we’re working with Shoggoths all the way down. Even if you’re working in systems and you’re not calling ML models, you still have a massive amount of infrastructure that you presumably don’t really understand, written by people you’ve probably never talked to. And the end result is exposed as a simple interface where you just import torch and, hopefully, run your code across 100k GPUs. As a result, I think there are some interesting ways in which ML models feel different from regular systems. One is that ML models are extremely simple, and as a result we have very high expectations for the performance of the code. You can trace this all the way back to a very nice essay called “The Bitter Lesson,” which I’d really recommend reading if you haven’t come across it. Its main observation is that, throughout the field’s history, clever ideas in machine learning have basically always lost out to simple ideas that scaled really well with Moore’s law. And I’m not really joking when I say that machine learning logic is exceedingly simple.
There’s a cool project from Andrej Karpathy called llama2.c, where he implemented Llama 2 in about 973 lines of C with no other dependencies: every single loop and every single matmul is implemented from scratch. With 973 lines you can’t make it very fast, but it does run, and I think it shows how fundamentally simple the models we’re spending all this compute on end up being. So although the problems themselves are extremely simple and, in some sense, very easy to optimize, the expectations for how well we optimize these matmuls are very high. One example is that the predominant metric for measuring your model’s performance in deep learning is called model FLOP utilization (MFU). This is basically the percentage of your GPU’s theoretical peak FLOPs that you actually achieve. If you think about it, this is an absurd metric to hit. If you measured any of your CPU code by this metric, you could only hit 100% if every single core of your CPU were issuing max-width SIMD instructions at every single moment. That’s the only way to hit 100% utilization. If you look at the code you presumably write, almost no CPU code is anywhere near 100% of peak FLOPs; it’s probably well under 1% most of the time. In machine learning, on the other hand, large-scale training typically hits around 50% of peak FLOPs. I think this shows that even though the overall problem is very simple, the difficulty goes into making your models hit this very high performance bar.
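To make the MFU definition concrete, here is a minimal sketch of how it is commonly computed. It assumes the widely used rule of thumb that training a dense transformer costs about 6 FLOPs per parameter per token (not something stated in the talk), and the model size, batch, GPU count, and peak FLOP/s below are hypothetical round numbers.

```python
# Minimal MFU sketch. Assumption (not from the talk): training a dense transformer
# costs roughly 6 FLOPs per parameter per token (forward + backward).


def model_flops_per_step(n_params: float, tokens_per_step: float) -> float:
    """Approximate model FLOPs for one training step."""
    return 6.0 * n_params * tokens_per_step


def mfu(n_params: float, tokens_per_step: float, step_time_s: float,
        n_gpus: int, peak_flops_per_gpu: float) -> float:
    """Model FLOP utilization: achieved model FLOP/s divided by hardware peak FLOP/s."""
    achieved = model_flops_per_step(n_params, tokens_per_step) / step_time_s
    return achieved / (n_gpus * peak_flops_per_gpu)


# Hypothetical numbers: 1B parameters, 1M tokens per step, 8 GPUs at 1e15 FLOP/s peak.
print(mfu(1e9, 1e6, step_time_s=1.5, n_gpus=8, peak_flops_per_gpu=1e15))  # 0.5
```

Note that only "useful" model FLOPs count in the numerator; extra work such as recomputation does not raise MFU.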
Another interesting observation about machine learning, which has exacerbated this a bit, is that the field has consolidated significantly over the last five to ten years. First, ten years ago you had a lot more architectures, a lot more variants, and a lot of different things people were trying. Nowadays transformers are the dominant architecture for everything: transformers for vision, transformers for language, transformers for audio. It’s all transformers. Second, instead of many different people training state-of-the-art (SOTA) models, you often have just a few companies training them. So we’ve gone from something like a monopoly, where one party provided the infra and many people used it, to something that in some ways feels more like a monopsony, where many people are trying to provide infra and, at the end of the day, only one buyer is actually running the training job. As a result, I think there are two general ways to think about getting performance from your systems. One is optimization: you have a compiler, you make the compiler better, and you improve the performance of everybody using it. The other is programming models. By analogy, if your system’s job is to chop down a tree, optimization means making it chop the tree down faster, while a programming model means providing tools so people can cut down the trees themselves.
To talk a bit about programming models, I think it’s illustrative to look at how ML frameworks have evolved in terms of the programming model they expose to users. In the early 2010s, the first ML framework that got a lot of popularity was Caffe. The way you expressed neural networks in Caffe was very declarative: you edited a Protobuf file, I think, and the Protobuf specified everything you needed to care about in your neural network. As you can imagine, programming in Protobufs is not very fun, and there are a lot of things you might want to do that you can’t express in them. So the natural next step was graph-builder-style APIs. This is what TensorFlow 1 looked like. The idea is that no human should have to write Protobufs by hand; instead you write a DSL of sorts that generates the Protobufs for you. However, even this DSL still has a lot of confusion in it: given the DSL, it’s not super clear how your code actually executes on the GPU. Finally, around 2016 or 2017, PyTorch started to become really successful, and the feature most emblematic of PyTorch was what’s called imperative, or eager, execution. I think it’s worth talking about why eager execution was so successful.
I think the main reason it was successful comes down to what the execution model looked like. In imperative, or eager, execution, you call a function, the GPU runs the function, and then your function finishes. That’s basically it. You call torch.matmul, and that’s the sequence of operations that happens. With a graph-mode approach, where a compiler interjects, you first define the function, the function gets converted into some intermediate IR, who knows what happens to your function, and then eventually it executes on the GPU. The first is a very simple execution model for people to understand. Another interesting thing to notice is that the eager model also describes how Python itself executes. And I think this is pretty illustrative of why Python has been so successful in machine learning. A funny statement you can make is that if I gave you a day to train a model, it would run faster if you wrote it in Python than if you wrote it in C++. You might argue this is an unfair comparison, because all the frameworks people have built are in Python. But I think the reason so many of these frameworks were built in Python is that Python is an exceedingly simple language, and therefore also a very growable one.
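A toy contrast between the two execution models described above, sketched in plain Python. The graph builder loosely imitates the define-then-run style of TensorFlow 1; `Node`, `GraphBuilder`, and `run` are hypothetical names for illustration, not any framework’s real API.

```python
import math


# Eager: every line runs immediately and produces a concrete value.
def eager_program(x):
    y = x * 2.0
    z = math.sin(y)
    return z + 1.0


# Graph builder (define-then-run): calls only record nodes; nothing executes
# until the graph is handed to run().
class Node:
    def __init__(self, op, inputs, const=None):
        self.op, self.inputs, self.const = op, inputs, const


class GraphBuilder:
    def __init__(self):
        self.nodes = []  # creation order is a valid topological order here

    def _add(self, op, inputs, const=None):
        node = Node(op, inputs, const)
        self.nodes.append(node)
        return node

    def placeholder(self):
        return self._add("input", [])

    def mul(self, a, c):
        return self._add("mul", [a], c)

    def sin(self, a):
        return self._add("sin", [a])

    def add(self, a, c):
        return self._add("add", [a], c)


def run(graph, output, feed):
    """Execute the recorded graph with a single input value `feed`."""
    values = {}
    for node in graph.nodes:
        if node.op == "input":
            values[node] = feed
        elif node.op == "mul":
            values[node] = values[node.inputs[0]] * node.const
        elif node.op == "sin":
            values[node] = math.sin(values[node.inputs[0]])
        elif node.op == "add":
            values[node] = values[node.inputs[0]] + node.const
    return values[output]


g = GraphBuilder()
x = g.placeholder()
out = g.add(g.sin(g.mul(x, 2.0)), 1.0)   # builds 4 nodes, computes nothing yet
print(len(g.nodes), run(g, out, 0.5) == eager_program(0.5))
```

Both versions compute the same value; the difference is only in when and where the work happens, which is exactly what makes the eager version easier to reason about.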
What I mean by growable is that it’s very easy for people to build their own infrastructure on top of Python without having to fight anything the language does, because the Python language itself does basically nothing. It doesn’t do any optimizations for you; it just takes your function and runs it. Historically, I think PyTorch sits at a very similar point in the design space: its execution model is very simple. Although that means it doesn’t automatically do a lot for you, it also means it’s very easy for people to build their own infrastructure and frameworks on top of PyTorch. Another important detail, especially about PyTorch when it first came out, is that this unoptimized execution didn’t actually sacrifice any performance at all. When people benchmarked PyTorch against TensorFlow or Caffe, PyTorch often wasn’t any slower. I think there are two main reasons for this. The first is that, back in the day, about 90% or more of your time was spent in matmuls, so there was basically nothing else to optimize. Matmuls here are matrix multiplications, which are usually provided by vendor libraries like cuBLAS or cuDNN. They come from Nvidia and they’re very heavily hand-optimized. So if 90% of your time is spent in matmuls, what else can you even do to optimize the performance of your neural network?
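The “what else can you even do?” argument is Amdahl’s law. A short sketch using the talk’s round numbers (90% of runtime in matmuls, and the 10x matmul speedup that comes up with Tensor Cores):

```python
# Amdahl's-law arithmetic for the "90% of time in matmuls" argument.
# The 90% share and the 10x factor are the talk's round numbers.


def amdahl(frac_sped_up: float, speedup: float) -> float:
    """End-to-end speedup when only `frac_sped_up` of the runtime gets `speedup`x faster."""
    return 1.0 / ((1.0 - frac_sped_up) + frac_sped_up / speedup)


# Make all non-matmul work (10% of runtime) infinitely fast: at most ~1.11x overall.
print(amdahl(0.10, float("inf")))
# Make only matmuls (90%) 10x faster: ~5.3x overall, and non-matmul work now takes
# 0.1 / (0.1 + 0.09), about 53%, of the remaining runtime.
print(amdahl(0.90, 10.0))
```

The same formula explains both halves of the story: while matmuls dominate, optimizing everything else barely matters; once matmuls get much faster, everything else suddenly matters a lot.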
The second important reason PyTorch’s performance was quite good is its asynchronous execution model. The idea is that you have a work queue on your GPU. The CPU is only responsible for scheduling work onto the queue, and the GPU executes work from the queue. The image that usually comes to mind for me is this GIF of Gromit laying down train track in front of a moving train. Gromit is Python, putting down track in front of the train, which is the GPU. As long as Gromit can lay down track faster than the train rolls along it, you can view Python as having zero overhead: it adds no extra cost compared to using a more efficient language like C++. So eager execution not only had a much easier-to-understand programming model for users, it was also basically just as fast as non-eager execution. Unfortunately, good things never last. In 2017 Nvidia introduced what are called Tensor Cores. If you’re unfamiliar with Tensor Cores, they’re hardware units on the GPU that only do matrix multiplications. I don’t mean this figuratively, in the sense that people often say GPUs are well suited to matmuls. I mean it very literally: there’s an assembly instruction that just does a mini matmul, and that’s how you interact with the Tensor Cores.
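A toy timing model of the Gromit-and-train picture. It assumes each op costs the CPU a fixed launch overhead to enqueue, the GPU runs queued ops strictly in order, and the queue never fills up; real CUDA streams are more complicated, and `simulate` is a hypothetical helper.

```python
def simulate(kernel_times, launch_overhead):
    """Return (total_time, gpu_utilization) for one in-order GPU work queue."""
    cpu_done = 0.0   # when the CPU has finished enqueuing the current op
    gpu_free = 0.0   # when the GPU finishes everything enqueued so far
    busy = 0.0
    for k in kernel_times:
        cpu_done += launch_overhead          # Python/framework cost to launch the op
        start = max(cpu_done, gpu_free)      # GPU waits if the track isn't laid yet
        gpu_free = start + k
        busy += k
    return gpu_free, busy / gpu_free


# CPU ahead of GPU: overhead (1 unit) < kernel time (10 units), so the GPU stays busy.
print(simulate([10.0] * 100, launch_overhead=1.0))
# CPU behind GPU: tiny kernels, so launch overhead dominates and the GPU mostly idles.
print(simulate([1.0] * 100, launch_overhead=10.0))
```

Both runs take the same wall-clock time, but in the first one almost all of it is useful GPU work, so Python’s overhead is effectively hidden.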
If you look at this plot of matmul FLOPs versus non-matmul FLOPs, you can really see when Nvidia realized deep learning was a big deal, because all of a sudden there’s a massive gap, about 10x on this log scale, between how fast matmuls are on the GPU and how fast literally anything else you want to run on the GPU is. Previously we said matmuls took about 90% of the time. If Nvidia speeds up matmuls by 10x but everything else stays the same speed, all of a sudden you’re spending a lot more of your time on non-matmul operations. So I think ML compilers emerged largely because of this change. One important detail about ML compilers, in terms of how they differ from the frameworks that came before, is that they still keep the eager programming model. Logically, the programming model exposed to users is that you’re still just writing Python code that executes line by line. The only difference is that, instead of actually executing line by line, we capture it into a graph in some manner. torch.compile does this in a pretty interesting way: it intercepts at the level of the Python bytecode interpreter, since Python exposes APIs that let you insert your own frame evaluation function.
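To make “write eager code, capture a graph” concrete, here is a minimal tracing sketch. It uses operator overloading, which is closer in spirit to torch.fx or JAX tracing than to torch.compile’s bytecode-level frame evaluation hook; `Graph`, `Proxy`, `user_fn`, and `run_graph` are hypothetical names.

```python
class Graph:
    def __init__(self):
        self.ops = []  # recorded (output_name, op, lhs_name, rhs) tuples


class Proxy:
    """Stands in for a tensor during tracing: arithmetic records ops instead of computing."""

    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def _record(self, op, other):
        out = Proxy(self.graph, f"v{len(self.graph.ops) + 1}")
        rhs = other.name if isinstance(other, Proxy) else other
        self.graph.ops.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    __radd__ = __add__   # add and mul are commutative, so operand order doesn't matter
    __rmul__ = __mul__


def user_fn(x):
    # Ordinary eager-style code: works on plain floats or on Proxy objects.
    y = x * 2.0
    return y + 1.0


def run_graph(graph, inputs):
    """Interpret the captured graph; a compiler could optimize it before this step."""
    env = dict(inputs)
    for out, op, lhs, rhs in graph.ops:
        a = env[lhs]
        b = env[rhs] if isinstance(rhs, str) else rhs
        env[out] = a + b if op == "add" else a * b
    return env


g = Graph()
result = user_fn(Proxy(g, "x"))    # tracing: records ops, computes nothing
print(g.ops)                       # two recorded ops: mul, then add
print(run_graph(g, {"x": 3.0})[result.name], user_fn(3.0))  # same value both ways
```

The user-facing code never changes; only what flows through it does, which is how a capture-based compiler can keep the eager programming model.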
So this looks very much like a traditional JIT for any other VM, except that this JIT is only meant for PyTorch programs. If you look at how things have evolved over time: originally you had frameworks like TensorFlow, or Caffe before it, where the programming model users wrote in was a graph-builder abstraction, and the execution model was also graph execution. Then you had PyTorch 1-style frameworks, where the user programming model switched to eager and the execution model was also eager: each operator executes one at a time. Finally, pretty much all modern ML frameworks use an imperative, eager programming model, but almost all of them also have some way to capture that program into a graph so they can perform optimizations. This includes jax.jit, tinygrad, and MLX; they all have different approaches to capturing the graph to optimize. Now that we’ve discussed how we got to ML compilers in the first place, I want to talk about what ML compilers actually do for you and what kinds of optimizations they perform. Generally speaking, the way I think about deep learning performance, or performance on GPUs in general, is that there are basically three things you can be spending your time on. The first is compute: time your GPU spends computing actual floating point operations.
The next is memory: time spent transferring your tensors within a GPU, across its various memory subsystems. And finally there’s overhead, which is everything else, like time your GPU spends idle. First let’s talk about compute. To a rough approximation, you can say that all runtime on your GPU is either compute or shuffling data. Data movement isn’t a “real” operation; from a theoretical point of view it’s a no-op. All it does is move data from one place to another place where it’s more convenient. A floating point operation is basically the only real thing a GPU can do. But I’d simplify this even further and say that nowadays all runtime is either matmuls or, essentially, shuffling data. If you look at the spec sheet for an H100 SXM GPU, you only have 67 teraflops of FP32 compute, but around a thousand teraflops of TF32 compute, which is basically matrix multiplication compute. What this means is that if you’re not doing matrix multiplications on your GPU, you’re only getting about 7% of your peak FLOPs. So by the metric I mentioned before, model FLOP utilization, even if your GPU were fully occupied doing non-matmul work, you could only ever reach about 7% FLOP utilization, which is much lower than the theoretical peak.
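The 7% figure is just 67 / 1000. A slightly more general back-of-the-envelope sketch asks what MFU ceiling you get when only some fraction of FLOPs cannot use Tensor Cores. It assumes the spec-sheet numbers quoted above, that each kind of FLOP runs at its unit’s peak, and that the two kinds never overlap in time; `mfu_ceiling` is a hypothetical helper.

```python
# Best-case MFU when some fraction of FLOPs can't use Tensor Cores.
# Assumptions: the figures quoted above (67 TFLOP/s FP32, ~1000 TFLOP/s TF32 Tensor
# Core), each kind of FLOP runs at its unit's peak, and the two never overlap.
FP32_TFLOPS = 67.0
TENSOR_TFLOPS = 1000.0


def mfu_ceiling(non_matmul_frac: float) -> float:
    """MFU measured against the Tensor Core peak, given the share of non-matmul FLOPs."""
    time = (1.0 - non_matmul_frac) / TENSOR_TFLOPS + non_matmul_frac / FP32_TFLOPS
    return 1.0 / (time * TENSOR_TFLOPS)


for frac in (0.0, 0.01, 0.05, 1.0):
    print(f"{frac:.0%} non-matmul FLOPs -> MFU ceiling {mfu_ceiling(frac):.1%}")
```

Under these assumptions, even a small share of non-matmul FLOPs pulls the ceiling down noticeably, which is one reason compilers work so hard on everything that isn’t a matmul.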
I have a brief interlude about an interesting case where these abstractions break down even more. Here’s a fun question: do the contents of your matrices affect matrix multiplication performance? If you’re familiar with performance engineering in general, there are a lot of ways data can impact performance. But matmuls avoid most of them: they have identical memory access patterns regardless of the data in your tensors, there’s no data-dependent control flow, and denormals don’t cause slowdowns on GPUs either. So here we take three tensors: one initialized with all zeros, one from a Gaussian distribution, and one from a uniform distribution between zero and one. Funnily enough, if you benchmark this, you find that there is a performance difference depending on the actual data within your tensors. When I first saw this, I was reminded of a tweet from a long time ago by someone you might know: I thought I knew how a GPU works, but I was very confused about what could possibly cause these performance differences between the different data.
The actual cause here is something called dynamic power. Most of you are probably familiar with the fact that when a CPU or GPU is under load it uses more power, and at some point it can throttle because it’s hitting the maximum power or heat it’s allowed. But that power doesn’t come from nowhere. It largely comes from what’s called dynamic, or switching, power: every time a transistor on the GPU switches from zero to one or from one to zero, it dissipates a little bit of power. So the total power usage of your GPU is roughly proportional to the total amount of switching going on. That’s why, if you’re multiplying all zeros, a lot of transistors don’t end up switching at all, so the GPU doesn’t consume much power and is much less throttled. If you actually measure this, you get very different performance for different distributions: normal, a checkerboard pattern, sparse, or ternary. The reason is that these patterns lead to more or fewer transistor flips, which leads to more or less power throttling, which leads to more or less performance.
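As a loose illustration of switching activity (not a power model), this sketch counts how many bits toggle when a stream of float32 values passes through a single 32-bit register. The register abstraction and the helper names are hypothetical; real GPU power depends on far more than one register’s toggles.

```python
import random
import struct


def float32_bits(x: float) -> int:
    """Bit pattern of x rounded to IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def toggles(values) -> int:
    """Total bit flips when `values` pass, in order, through one 32-bit register
    that starts at all zeros."""
    total, prev = 0, 0
    for v in values:
        bits = float32_bits(v)
        total += bin(prev ^ bits).count("1")
        prev = bits
    return total


random.seed(0)
n = 10_000
datasets = {
    "zeros": [0.0] * n,
    "gaussian": [random.gauss(0.0, 1.0) for _ in range(n)],
    "uniform": [random.random() for _ in range(n)],
}
for name, data in datasets.items():
    print(name, toggles(data) / n)  # average bit flips per value
```

All-zero data never toggles anything, while random data flips a large share of the bits on every value, which is the intuition behind the zero-tensor (and all-NaN) speedups.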
I remember someone once told me a funny story: they were training a machine learning model and benchmarking performance, and at some point their model would NaN, and they’d think, “Wow, my performance just got way better.” After I wrote an article about this, they messaged me to say it was very illuminating, because they’d been really confused about why their performance was improving. That’s why: if all your tensors are NaN, your transistors don’t need to do much flipping, so you’ll measure better performance. So that’s compute. The next thing your GPU can be spending a lot of its time on is memory, which is essentially the time spent transferring your tensors within your GPU. One thing to observe is this empirical plot from a paper on data movement, which breaks down the operations in a BERT-type model into what it calls tensor contractions (i.e., matrix multiplications), normalization operations, and element-wise operations. You can see that although matrix multiplications are responsible for 99.8% of the FLOPs, they’re only responsible for 61% of the runtime. So why are we spending nearly 40% of our runtime on operations that cumulatively account for only 0.2% of our FLOPs? The key is what’s called memory bandwidth cost. Here I’m assuming all of your data already lives on the GPU, occupying the GPU’s VRAM.
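One way to see why element-wise ops cost so much runtime for so few FLOPs is arithmetic intensity: FLOPs performed per byte moved to or from VRAM. A back-of-the-envelope sketch, assuming float32 (4 bytes per element) and that every input is read once and every output written once:

```python
# Arithmetic intensity = FLOPs performed per byte moved between VRAM and the chip.
# Assumptions: float32 (4 bytes per element), every input read once and every
# output written once (no caching effects).
BYTES_PER_ELEM = 4


def elementwise_add_intensity(n: int) -> float:
    flops = n                                  # one add per element
    bytes_moved = 3 * n * BYTES_PER_ELEM       # read a, read b, write out
    return flops / bytes_moved


def matmul_intensity(m: int, n: int, k: int) -> float:
    flops = 2 * m * n * k                      # a multiply and an add per term
    bytes_moved = (m * k + k * n + m * n) * BYTES_PER_ELEM
    return flops / bytes_moved


print(elementwise_add_intensity(1 << 20))      # 1/12 FLOP per byte
print(matmul_intensity(4096, 4096, 4096))      # hundreds of FLOPs per byte
```

A GPU with roughly 1000 TFLOP/s of matmul throughput and a few TB/s of memory bandwidth needs on the order of hundreds of FLOPs per byte to stay compute-bound, so an element-wise add is limited by memory bandwidth, not compute.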
But your GPU’s VRAM is not where the compute units are located. To actually do operations on a GPU, you need to move your data from VRAM to where the compute units are, which is your SRAM. I usually think of this like a factory that doesn’t have much space, and a warehouse located much further away that has a lot more space. To do any operations on those supplies, you need to move them from the warehouse to the factory and then back. The cost of moving data around like this is called the memory bandwidth cost, and it’s responsible for a lot of what your GPU spends its time doing. Imagine we do three operations on a GPU: maybe an add, then a ReLU, then a sin. What actually happens is that first the GPU sends the data from memory to the compute units, turns it from a square into a triangle, and then sends it all the way back. Then it sends the triangle from memory to the compute units again, does another operation, and sends it all the way back. And finally, you get the idea, it sends the circle from memory to the compute units again and then sends it all the way back. By default, whenever you run operations in PyTorch, say an add, then a multiply, then a cosine, this is exactly what happens on your GPU. You might think this is a very dumb thing to do.
And you would be correct. Why are we sending our triangle back from the factory to the warehouse, just to send it from the warehouse back to the factory again? So a very common optimization for GPU compilers, and what I'd call by far the most important optimization in a deep learning compiler, is called operator fusion. Instead of sending the data back and forth so much, operator fusion runs a single GPU kernel: you send the data to the factory once, do all of the operations, and then send the data back. Notably, this optimization is not really something you can do in eager mode. As I mentioned, in eager mode the execution model is very simple: you run an operation, and it executes that operation. If you want to do this optimization, that programming model is no longer sufficient. There are a lot of different ways to minimize memory movement, but at the end of the day operator fusion is the most important thing an ML compiler can do. There are also a lot of decisions that go into operator fusion that make it more or less effective. One example is the recomputation-versus-reuse trade-off. If you're familiar with register allocation, you often face a similar issue there: for a value held in a register, you can either store it to memory and load it back later, or recompute it from values that are already in registers.
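The warehouse-and-factory picture can be turned into a toy cost model. The sketch below is illustrative only: Python lists stand in for GPU buffers, a hypothetical `CountingMemory` class counts elements moved, and the add uses a made-up constant. It is not how a real compiler or kernel works, but it shows why fusing three element-wise ops cuts memory traffic by 3x while computing the same result.

```python
import math
from typing import Callable, List

Op = Callable[[float], float]


class CountingMemory:
    """Toy stand-in for GPU DRAM (the 'warehouse'): counts elements moved to/from compute."""

    def __init__(self) -> None:
        self.reads = 0
        self.writes = 0

    def load(self, buf: List[float]) -> List[float]:
        self.reads += len(buf)
        return list(buf)

    def store(self, buf: List[float]) -> List[float]:
        self.writes += len(buf)
        return list(buf)


def run_unfused(x: List[float], ops: List[Op], mem: CountingMemory) -> List[float]:
    """Eager-style: every op is its own 'kernel' that loads its input and stores its output."""
    for op in ops:
        tile = mem.load(x)
        x = mem.store([op(v) for v in tile])
    return x


def run_fused(x: List[float], ops: List[Op], mem: CountingMemory) -> List[float]:
    """Fused: load once, apply every op while the data is 'in the factory', store once."""
    tile = mem.load(x)
    for op in ops:
        tile = [op(v) for v in tile]
    return mem.store(tile)


# add (of a hypothetical constant), then ReLU, then sin
OPS: List[Op] = [lambda v: v + 1.0, lambda v: max(v, 0.0), math.sin]
```

With three ops over N elements, the unfused version reads and writes 3N elements each way, while the fused version moves N each way.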
We have a similar idea here: by doing some amount of recomputation, you can often significantly reduce your number of memory accesses. In this way, recomputation can not only reduce your peak memory usage, it can often also improve your actual runtime performance. I want to mention why this recomputation-versus-reuse optimization ends up being so important for deep learning performance. I think the reason is that the shape of machine learning programs looks quite unusual compared to a typical program. It's something of an axiom that most intermediates in a program are very short-lived: they're created and then destroyed shortly after. In machine learning this is not the case. A typical model first runs the forward pass, from layer 0 to layer 1 to layer 2 to layer 3 to layer 4. Then it needs to run what's called the backward pass, which runs the layers in reverse, from layer 4 back down to layer 0. In between the forward and backward passes, you need to save what are called intermediates, or activations. You have a lot of them, and they are often largely responsible for running into out-of-memory errors.
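The forward-then-reverse structure, and the memory it forces you to hold across the gap, can be sketched with a chain of scalar layers. This is a minimal illustration, assuming scalars stand in for activation tensors and using tanh layers as a hypothetical model; `grad_checkpointed` is the standard activation-checkpointing idea (keep a few inputs, recompute the rest during backward), not a specific PyTorch implementation. Note that this toy only shows the memory-for-compute trade; the runtime win the talk mentions comes from recomputing cheap values inside a fused kernel instead of paying a memory round-trip, which a scalar model can't capture.

```python
import math
from typing import Callable, List, Tuple

# A layer is (f, df/dx). Scalars stand in for activation tensors.
Layer = Tuple[Callable[[float], float], Callable[[float], float]]


def grad_save_all(x: float, layers: List[Layer]) -> Tuple[float, int, int]:
    """Keep every layer's input between forward and backward.
    Returns (dy/dx, activations kept, forward-layer evaluations)."""
    inputs: List[float] = []
    evals = 0
    for f, _ in layers:
        inputs.append(x)
        x = f(x)
        evals += 1
    g = 1.0
    for (_, df), x_in in zip(reversed(layers), reversed(inputs)):
        g *= df(x_in)  # chain rule, walking the layers in reverse
    return g, len(inputs), evals


def grad_checkpointed(x: float, layers: List[Layer], every: int) -> Tuple[float, int, int]:
    """Keep only every `every`-th layer input; recompute the others segment by segment
    during backward. Returns (dy/dx, checkpoints kept, forward-layer evaluations)."""
    checkpoints = {}
    evals = 0
    for i, (f, _) in enumerate(layers):
        if i % every == 0:
            checkpoints[i] = x
        x = f(x)
        evals += 1
    g = 1.0
    n = len(layers)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, n)
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):  # recompute inputs of layers start+1 .. end-1
            seg_inputs.append(layers[i][0](seg_inputs[-1]))
            evals += 1
        for i in reversed(range(start, end)):
            g *= layers[i][1](seg_inputs[i - start])
    return g, len(checkpoints), evals


TANH: Layer = (math.tanh, lambda v: 1.0 - math.tanh(v) ** 2)
```

For eight tanh layers checkpointed every four, the gradient is unchanged, the activations kept across the forward/backward gap drop from 8 to 2, and forward evaluations rise from 8 to 14.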
So this is a pretty unusual program structure, and it's caused by backpropagation and gradient descent. Finally, the last thing you can be spending your time on is overhead. Imagine that if poor Gromit can't put down the train tracks faster than the train runs along them, the train will sometimes be stuck waiting for him to lay the next piece of track. Here I have a profile trace where you can see that the bottom line, which is the GPU trace, is largely idle, mostly waiting for the CPU to schedule the next operation. There are a lot of ways to address this nowadays. One of the most powerful is CUDA Graphs, which is an NVIDIA-provided API. There are also other approaches; in a compiler, for example, you can generate a lower-overhead wrapper. So I've talked about ML compilers, what they can do to your program, and how they can be useful. But an interesting question comes up. I talked a lot about how we have a super-massive infrastructure build-out, how the programs are super simple, and how we've seen a lot of consolidation in what the architecture looks like. So a reasonable question is: if you only have one architecture and you're spending billions of dollars to train it, why do you even need a compiler? Why can't you just assign a group of people to optimize it by hand instead of leveraging a compiler? And in practice, a lot of the time people do not use compilers, for exactly this reason.
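Gromit and the train can be written as a tiny timeline model. This is a hypothetical cost model with made-up microsecond costs: the CPU issues one op at a time, and the GPU can start a kernel only once it has been issued and the previous kernel has finished. `replay_timeline` is a simplified stand-in for the idea behind CUDA Graphs (one launch, then kernels run back to back); real graph replay still has some per-kernel cost on the device.

```python
from typing import Tuple


def eager_timeline(n_ops: int, cpu_launch_us: float, gpu_kernel_us: float) -> Tuple[float, float]:
    """Toy model: the CPU issues ops one after another; the GPU can start a kernel only
    once it has been issued AND the previous kernel has finished.
    Returns (total time in us, fraction of that time the GPU is busy)."""
    cpu_t = 0.0
    gpu_t = 0.0
    for _ in range(n_ops):
        cpu_t += cpu_launch_us          # CPU finishes issuing this op
        start = max(cpu_t, gpu_t)       # GPU waits for the launch and for the previous kernel
        gpu_t = start + gpu_kernel_us
    return gpu_t, n_ops * gpu_kernel_us / gpu_t


def replay_timeline(n_ops: int, launch_us: float, gpu_kernel_us: float) -> Tuple[float, float]:
    """Hypothetical graph-replay model: one launch, then kernels run back to back."""
    total = launch_us + n_ops * gpu_kernel_us
    return total, n_ops * gpu_kernel_us / total
```

With 100 ops costing 10 us to launch but only 2 us to run, the eager model keeps the GPU busy about 20% of the time; swapping the costs (2 us launch, 10 us kernel) makes the GPU the bottleneck instead, which is the regime you want to be in.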
This section is going to talk a little about why I think that's the case and what some of the challenges are when it comes to using compilers in this setting. Disclaimer: I really do like compilers, and I'm going to say some kind of mean things about them in a bit. But to establish my credibility, I work on a team called PyTorch Compilers. To be clear, there are a lot of reasons why compilers can be very useful, in particular the notion of leverage: you do the optimization once in the compiler, and then everybody can take advantage of it. Compilers are also very fun to work on. That being said, I'm going to introduce my exciting new library, Horace's Exciting Library, abbreviated HEL. It has a couple of cool features you might be interested in. The first feature is that it doesn't always work. To address that, it has no documentation about when or why it will work, other than reading the library's implementation. And in exchange for that, when you update the library it may totally change which code works and which code doesn't, i.e., there are no backward-compatibility guarantees on whether your code works. So, are you interested in using my library? I'd guess most people would say no. One thing to note is that if "works" means "has the desired performance and applies the desired optimizations," this largely describes how compiler optimizations work. Compiler optimizations don't always work. There's often no real documentation on when a compiler optimization will trigger. And when you update your compiler, it may completely change when it does or does not apply these optimizations.
There's an article that was very influential to me, about a compiler called ISPC, and it has a section titled "auto-vectorization is not a programming model." The overall framing of the article is that the author wrote this compiler called ISPC, which you can think of as CUDA for Intel SIMD instructions. Throughout the article, he is constantly fighting against the Intel compiler team, which is where he worked. The Intel compiler team wanted to rely on auto-vectorization to get vectorization done instead of introducing a new programming model in ISPC. He identifies the problem with the auto-vectorizer: as long as vectorization can fail (and it will), then if you're a programmer who actually cares about what code the compiler generates for your program, you need to deeply understand the auto-vectorizer. Then, when it fails to vectorize code you want vectorized, you need to either poke it in the right way or change your program in the right way so that it works for you again. This is a horrible way to program. If any of you are really into using SIMD instructions, you probably don't trust the auto-vectorizer at all and mostly just write intrinsics. With a proper programming model, ideally the user is able to learn what does and does not work without being tied to the implementation.
Then a compiler implements that programming model, and the user can learn to rely on the optimization without needing to understand the compiler, only the programming model. One way to phrase this is that a compiler optimization that always works is just part of the programming model. For example, when you're writing SIMD intrinsics, you know that a SIMD intrinsic will always get mapped to a SIMD instruction. So that's part of your programming model, not really an optimization the compiler is doing. When I see people complaining about shoggoths (the meme of a monster wearing a friendly mask) and about working with them, I'm often reminded of compilers. A compiler is often just a large piece of code that has been worked on by a lot of very smart people and has a lot of tricky implementation details in it. Ideally, only the programming model ends up exposed to the user: the actual compiler implementation is completely hidden, and the user only deals with the nice programming model, which is usually just the language the compiler is compiling from. Unfortunately, when compilers fail, when they don't apply the optimization you want them to apply, the entire thing becomes exposed. This happens any time you're wrestling with the compiler, trying to understand why it did or did not inline your code, or things like that. To give some examples of cases where very nuanced details can leave the compiler struggling a lot, one of them is numerics in machine learning.
Floating-point arithmetic is generally a very cursed thing to deal with, and it has only gotten worse as NVIDIA, and not just NVIDIA but everybody in the industry, keeps pushing our data types to fewer and fewer bits. On the V100 they introduced 16-bit operations as kind of the default, on the A100 they introduced 8-bit operations, and on the B100 they're now pushing 4-bit floats and operations on 4-bit floating-point numbers. I think it's reasonable to question how this is even a floating-point number at this point. But this is what we're typically dealing with. Because of this low precision, and because numerics are so subtle, you often have very annoying numerical problems to deal with. A good example, and one that was very frustrating for me, was a NaN in a FlashAttention implementation. The underlying cause was something called an FMA, or fused multiply-add, which can actually be disastrously bad for your numerics. For folks who are unfamiliar with FMA, it takes A, B, and C and computes A * B + C as a single operation. One way FMA differs from doing the operations separately is that, in addition to being faster, it's usually computed in what people call infinite precision internally. That means the result of your FMA is the closest representable value to the true value of A * B + C. This differs from doing the operations separately, where there is typically a rounding step after the multiply, so the result can come out a bit different.
In this particular case, we're computing exp(a_i * b - max_scale), where max_scale is a single scalar value obtained by taking the maximum of a_i * b over all i. The idea is that exp is a very numerically unstable operation, so we want to make sure we're not taking any very large exponents. By subtracting out the maximum value, the largest exponent we can get here is zero. The compiler, in its wisdom, rewrites this as exp(fma(a_i, b, -max_scale)); it's smart enough to realize it can turn the subtraction into an addition of a negated value. But if you look at the numerics, max_scale is the rounded value of a_i * b, so for the maximum element we end up computing exp of the exact a_i * b minus the rounded value of a_i * b, instead of exactly exp(0), and that difference can be disastrously off. The end result is that we NaN with FMA turned on and don't NaN with FMA turned off, even though FMA is theoretically a precision-improving optimization. To summarize, the underlying cause is that our numerical properties relied on two separate computations of the same value being exactly identical. With FMA, the compiler can apply it to one branch but not the other, and in this case that led to NaNs. Another thing compilers often struggle with is what are called algebraic rewrites.
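The broken invariant can be reproduced without a GPU. The sketch below emulates FMA by computing `a * b + c` exactly with `fractions.Fraction` and rounding once (CPython's own `*` and `-` never contract into an FMA). The input values are hypothetical, chosen so that `b * b` must round. It shows only that the "largest exponent is exactly zero" property stops holding once one side is fused; it does not reproduce the specific NaN from the talk.

```python
from fractions import Fraction
from typing import List


def fma_emulated(a: float, b: float, c: float) -> float:
    """Emulate a fused multiply-add: compute a*b + c exactly (as a rational), round once.
    (Python 3.13+ also ships math.fma, which does this natively.)"""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


def exponent_args(a: List[float], b: float, fused: bool) -> List[float]:
    """Arguments passed to exp() in a max-subtracted softmax over scores a_i * b."""
    max_scale = max(ai * b for ai in a)  # the max is taken over *rounded* products
    if fused:
        # what a compiler contracting (a_i * b) - max_scale into fma(a_i, b, -max_scale) computes
        return [fma_emulated(ai, b, -max_scale) for ai in a]
    # separate multiply (rounded) then subtract: the same rounded product as the max
    return [ai * b - max_scale for ai in a]


B = 1.0 + 2.0 ** -52        # hypothetical values chosen so that B * B must round
A = [0.5, 1.0 + 2.0 ** -52]
```

Unfused, the largest argument is exactly 0.0. Fused, it becomes the product's rounding error, 2**-104: a tiny positive exponent that the code was written assuming could never exist.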
If you're familiar with FlashAttention, which fuses the attention operators together, it relies on a subtle rewrite that people often call online softmax. Softmax is typically implemented with a couple of global synchronizations (one reduction to compute the max, then another to compute the sum), but you can rewrite softmax in a way that removes one of those global synchronizations. This requires some amount of math. It's not very hard math, but compilers are generally very bad at math, so it's difficult for them to come up with this rewrite by themselves. I think FlashAttention is a very good example of a case where you need to think about the programming model you expose. Before this fused attention kernel came about, people typically wrote attention like this: one matmul, then a softmax, then another matmul. This was not super efficient, but it was as efficient as you could get, so people were happy with it. With FlashAttention, though, we want to fuse these three operations into a single kernel. So the difficulty is: what API do we expose for FlashAttention? One way you might tackle this is pattern matching: use a compiler to find the sequence matmul, softmax, matmul and pattern-match it into a single FlashAttention operator. This is a reasonable option, but the issue is that it becomes very frustrating to debug. For example, the user might change how they write softmax.
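Here is the online-softmax rewrite in its simplest form, as a plain-Python sketch. The standard version needs the max before it can start summing; the online version keeps a running max `m` and rescales the running sum `s` by `exp(m_old - m_new)` whenever the max grows, so max and sum come out of a single pass. FlashAttention applies the same rescaling trick to its output accumulator as well, which this sketch omits.

```python
import math
from typing import List


def softmax_two_pass(x: List[float]) -> List[float]:
    """Standard safe softmax: one reduction for the max, a second for the sum."""
    m = max(x)                           # reduction 1: max
    exps = [math.exp(v - m) for v in x]
    s = sum(exps)                        # reduction 2: sum (needs the finished max)
    return [e / s for e in exps]


def softmax_online(x: List[float]) -> List[float]:
    """Online softmax: a single pass computes the max and the sum together."""
    m = -math.inf
    s = 0.0
    for v in x:
        m_new = max(m, v)
        # rescale the partial sum to the new max, then add this element's term
        s = s * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return [math.exp(v - m) / s for v in x]
```

Both functions agree to floating-point tolerance, and both stay finite on large inputs such as `[1000.0, 1001.0]`, because every exponent is taken relative to a running or final max.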
They might reimplement softmax with their own implementation instead of the vendor-provided one, and all of a sudden your pattern no longer applies. Now they're sad because their memory suddenly blows up and their code is 3x slower. This is very frustrating for users to deal with. Instead, what PyTorch did in this case was say: these three operators need to be fused into one operator, so we'll just directly provide that single operator as an API you can call (scaled_dot_product_attention). This is one typical way to deal with programming models: introduce a slightly higher-level thing that does one very specific thing for users. But this is still kind of frustrating, and you might question whether it's actually good enough. When you consolidate multiple more primitive APIs into a single, more monolithic API, you often run into cases where the monolithic API can no longer represent all the things users want to do. With attention, we've seen people keep coming out with new attention variants: sliding window attention, ALiBi, paged attention, neighborhood attention, and so on. You look on Twitter and four new attention papers come out every single week. As a result, fused attention kernels keep accumulating new quirks. This one, from the FlashAttention repo, now has dropout, a softmax scale, a causal flag, a window size, a softcap, ALiBi slopes, and so on. Some of these were added in just the past couple of months, and they keep adding more.
Once you've added a quirk to an API or a programming model, you can't remove it anymore, because removing it would break what users are relying on. Even worse, even though people have been fairly aggressive about adding new quirks, it still isn't enough. You have all sorts of users constantly complaining that nobody has implemented FlashAttention for their pet attention variant, like "there's no FlashAttention for prefix LM." This bottom one is from a blog post by somebody who used to be at Google, complaining about the ecosystem outside of Google. So basically, a single monolithic operator is not always sufficient. We're in a situation where we don't really want to rely on a compiler to generate the kernel from scratch, but it's also painful to make modifications by hand. This might seem like a no-win situation where you must pick one: either unpredictability, by doing it as a compiler optimization, or a single monolithic API that's difficult for users to modify. But this is where you can be clever and come up with a new programming model that is neither of the ones users had before. One thing to notice is that this kind of custom kernel, which is difficult for a compiler to generate from scratch, can be decomposed into a handwritten, complicated FlashAttention kernel plus a bunch of trivial modifications from users that can be generated very mechanically. So here we have an API we introduced recently called FlexAttention.
I really like FlexAttention, and we've seen a very positive reception from the community, including from users who traditionally don't use compilers that much. One reason FlexAttention is so well liked, even though it relies on torch.compile to work, is that it's guaranteed to always produce a single fused attention kernel, with the same memory properties as a fused attention kernel. This means the user now has a programming model they can rely on: they can program against it, try a bunch of different variants that all fit within it, and they don't need to understand how the API is implemented under the hood. To give you a bit of a peek at what this API looks like, take sliding window attention and causal attention. Here you implement causal attention by checking whether your query position is greater than or equal to your KV position, and you implement the sliding window by checking whether the distance between them is less than your window size. Then you just AND these masks together to get sliding window causal attention. I think this is a good example of where compilers can provide a lot of value for users, even in these super handcrafted scenarios. Previously, prior to this attention API, you had a big cube of variants: masking, positional biases, and whether training or inference was supported. A bunch of these dots were filled in by users who had manually implemented attention kernels.
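The masks described above can be sketched in plain Python. The functions mirror the shape of a FlexAttention `mask_mod`, which takes `(b, h, q_idx, kv_idx)` and returns whether that query may attend to that key; the window size is a hypothetical value, and `and_masks` and `materialize` here are local helpers written for this sketch. In recent PyTorch releases the real API lives in `torch.nn.attention.flex_attention` (which also provides `and_masks`, `create_block_mask`, and `flex_attention`), and there mask functions operate on tensors, so they combine conditions with `&` rather than Python's `all(...)`.

```python
from typing import Callable, List

# Mirrors the shape of a FlexAttention mask_mod: (batch, head, q_idx, kv_idx) -> bool.
MaskMod = Callable[[int, int, int, int], bool]

SLIDING_WINDOW = 3  # hypothetical window size


def causal(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    return q_idx >= kv_idx


def sliding_window(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    return q_idx - kv_idx < SLIDING_WINDOW


def and_masks(*mods: MaskMod) -> MaskMod:
    """Local helper: a position is allowed only if every mask allows it."""
    def combined(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        return all(m(b, h, q_idx, kv_idx) for m in mods)
    return combined


sliding_window_causal = and_masks(causal, sliding_window)


def materialize(mask_mod: MaskMod, q_len: int, kv_len: int) -> List[List[bool]]:
    """Dense mask for inspection only."""
    return [[mask_mod(0, 0, q, kv) for kv in range(kv_len)] for q in range(q_len)]
```

For six positions and a window of 3, query 4 may attend to keys 2, 3, and 4: never to a future key, and never to one more than two steps back.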
But now with FlexAttention, every single one of these dots is filled in, so users can consistently rely on fused attention kernels regardless of the attention variant they're using. To give some examples of what users have been doing with FlexAttention: on the left is a fun attention mask from some guy on Twitter who had a bunch of molecular graphs of different sizes and was able to convert them into an attention mask that FlexAttention supports. This is a very weird mask that I definitely did not think about when developing FlexAttention. But when you develop a good programming model, users are often able to do things you never considered when you built the abstraction. Another analogy I sometimes think about when it comes to optimizations versus programming models is a famous quote from Grothendieck about tackling math problems. He said to imagine a math problem as a walnut. When you're trying to open the walnut, you can either hit it a bunch until it cracks open, or you can soak it in water, establishing new mental models for how to think about the problem, until eventually the walnut just opens by itself. I think about programming models very similarly.
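The talk doesn't show how that molecular-graph mask was built, so the following is only one plausible minimal version, written in the same plain-Python `mask_mod` shape as before. It assumes several molecules of different sizes are packed into one sequence, with a hypothetical `GRAPH_ID` table recording which molecule each token belongs to, and lets each token attend only within its own molecule (a block-diagonal mask). A real molecular mask might also restrict attention to bonded atoms; this sketch doesn't. In the real API, the lookup table would be a tensor captured by the mask function.

```python
from typing import Callable, List

# Hypothetical packing: tokens (atoms) from three molecules of different sizes in one sequence.
GRAPH_ID: List[int] = [0, 0, 0, 1, 1, 2, 2, 2, 2]


def same_graph(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    """Let a token attend only to tokens from its own molecule (a block-diagonal mask)."""
    return GRAPH_ID[q_idx] == GRAPH_ID[kv_idx]


def materialize(mask_mod: Callable[[int, int, int, int], bool], q_len: int, kv_len: int) -> List[List[bool]]:
    """Dense mask for inspection only."""
    return [[mask_mod(0, 0, q, kv) for kv in range(kv_len)] for q in range(q_len)]
```

The three molecules of sizes 3, 2, and 4 produce diagonal blocks of 9, 4, and 16 allowed entries, or 29 of the 81 total.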
Think about auto-vectorization from the Intel compiler folks' perspective. The author had this fun anecdote where the Intel people kept asking what happens when the CUDA compiler fails to vectorize. He was absolutely baffled by the question; they felt it was something they needed to understand in order to figure out how to make their own compiler more competitive. But the misunderstanding is that it's almost a nonsensical question to ask when the CUDA compiler fails to make your code parallel. If you've written your code in the CUDA programming model, it must be parallel; that's an axiomatic part of the programming model. There's no way to write a CUDA program that doesn't execute across multiple cores. You can write CUDA programs that are incorrect or that deadlock, but they're still guaranteed to run across multiple cores. This is because parallelism is inherent to the CUDA programming model, while it is not inherent to the programming model of auto-vectorization. I think ML compilers, and how they apply to optimizing GPUs in general, are already quite difficult, but you have a lot more issues when it comes to making GPUs run at scale. When it comes to getting GPUs to run at scale, there are a couple of interesting differences between distributed ML programs and traditional distributed systems. One of them is that traditional distributed systems are often just trying to scale QPS.
When you're trying to scale QPS, you have a lot of very small, interchangeable queries, and performance per unit of hardware is not that critical. Often you're willing to double the amount of hardware you use to get fault tolerance or higher uptime. In ML systems, on the other hand, we basically have a single query: you're just training your ML job. We have frequent global synchronizations across all of our hardware, and performance is extremely important, so much so that there's no way we could tolerate a 2x loss in performance for basically any reason. One way to think about how you can parallelize programs, and this is actually not specific to ML, is that you have a cube of computation to do. One dimension is the batch dimension: the different tasks you can perform. Another is the within-task dimension: the different operations within a single task. Finally, there's the time dimension: what your GPU is doing at any given point in time. You can think of data parallelism as splitting along the batch dimension, task parallelism as splitting along the within-task dimension, and pipeline parallelism as splitting along the time dimension. For data parallelism, the basic idea is pretty simple: we have a bunch of different tasks, so we have a lot of parallelism, and each GPU just handles a different task.
So we just put one task on each GPU and run it. This seems very nice and trivial, but the issue is that we have a massive synchronization after every step, because after every step the gradients need to be synchronized across all of your hardware. And you can't just avoid synchronizing, because there's also a mathematical constraint from the ML training side: you simply can't train with too large a batch size, or your model will not converge properly. There's another detail here: you can't just naively replicate your parameters, so people often use what's called fully sharded data parallelism (FSDP). The second kind of parallelism is task parallelism, commonly known as tensor parallelism. In this case, two GPUs split the same task. The main problem here that's somewhat specific to ML is that, because both GPUs are working on the same task, there often aren't obvious ways to overlap your communication with your computation. But because performance is so important, we still really want to overlap the communication, so people do a lot of involved things to try to improve the overlap. One of them is called async tensor parallelism. The general setup is that you have a communication op followed by a computation op. Usually, while you're doing the communication, the GPU can't do any computation, so this is wasted idle time on the GPU.
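A minimal data-parallel step can be sketched with a toy linear model. Each "rank" computes a gradient on its own shard of the batch, an all-reduce averages those gradients, and every replica applies the identical update. This is a single-process simulation under stated assumptions: the model is `y_hat = w * x` with mean-squared error, shards are equal-sized, and `all_reduce_mean` is a stand-in for a real collective such as NCCL's all-reduce. Sharding the parameters themselves, as FSDP does, is not modeled.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # (x, y)


def local_grad(w: float, shard: List[Example]) -> float:
    """Gradient of mean squared error for the toy model y_hat = w * x on one rank's shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> List[float]:
    """Stand-in for a collective all-reduce: every rank ends up holding the same average."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def data_parallel_step(w: float, batch: List[Example], world_size: int, lr: float) -> List[float]:
    """One synchronous step: shard the batch, compute local grads, all-reduce, update.
    Returns each replica's new weight (they should all agree)."""
    assert len(batch) % world_size == 0, "sketch assumes equal-sized shards"
    size = len(batch) // world_size
    shards = [batch[r * size:(r + 1) * size] for r in range(world_size)]
    grads = [local_grad(w, s) for s in shards]  # computed in parallel on real hardware
    synced = all_reduce_mean(grads)             # the per-step global synchronization
    return [w - lr * g for g in synced]
```

With equal shard sizes, the averaged gradient equals the full-batch gradient, so the replicas stay in lockstep with a single-device run. That lockstep is exactly what the per-step synchronization is paying for.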
The observation behind async tensor parallelism is that you can mini-pipeline your communication and your matmul. Even within a single tensor operation, there's often still a batch dimension you can split along. By micro-pipelining communication and computation this way, you can still get overlap even though you're doing tensor parallelism. Finally, the last kind of parallelism is pipeline parallelism, where you assign the first part of the task to the first GPU and the second part of the task to the second GPU. This differs from tensor (task) parallelism in that the two GPUs do not work on the same task at the same time, so you can think of the task as being sharded across the time dimension. There are a couple of additional wrinkles to pipeline parallelism that I think are unique to ML. One is that, once again, the frequent massive synchronizations prevent us from filling up the pipeline. The second is that backpropagation, the forward-and-backward pattern I mentioned before, adds a lot of very fun wrinkles to your pipeline schedule. This is a pipeline schedule that people have designed, where the blue boxes are the forward pass, the cyan boxes are the backward pass, and the green boxes are the backward pass split up even further.
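The benefit of micro-pipelining communication with a matmul can be estimated with a simple two-stage pipeline model. This is an idealized sketch, not a description of any particular implementation: it assumes communication and computation use independent resources and can fully overlap, that both split into `n_chunks` equal pieces, and that chunking adds no overhead. In practice, smaller matmuls run less efficiently, which caps how finely you'd want to split.

```python
def serial_time(comm: float, compute: float) -> float:
    """Communicate first, then run the matmul: nothing overlaps."""
    return comm + compute


def chunked_overlap_time(comm: float, compute: float, n_chunks: int) -> float:
    """Split the communication and the matmul into n equal chunks along a batch-like
    dimension, so the matmul on chunk i runs while chunk i+1 is still being transferred.
    Two-stage pipeline: first transfer + (n-1) steady-state steps + last matmul."""
    c = comm / n_chunks
    k = compute / n_chunks
    return c + (n_chunks - 1) * max(c, k) + k
```

With 8 time units each of communication and compute, four chunks bring the total from 16 down to 10. With a single chunk, the model reduces to the serial case.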
The main nuance here is that there are various points where your pipeline can choose to run either a forward micro-batch or a backward micro-batch. This choice is not something you need to deal with in a traditional pipeline-parallelism setting, and it leads to a lot of different optimizations that people do with pipeline parallelism. Putting this all together, this is a diagram from the Llama 3 paper showing what was done to train Llama 3. You can see that we're combining tensor (task) parallelism with data parallelism and pipeline parallelism, and there's also a fourth one thrown in here called context parallelism, which you can think of as just another form of task, or tensor, parallelism. Again, I think compilers have historically struggled quite a bit to keep up with the distributed optimizations people do. The main reason is that compilers are dumb and humans are smart. What I mean is that for any given set of parallelism schemes or configs, it's often pretty feasible to build an automatic model that lets the compiler decide how to parallelize your program. The issue is that most of the innovation in parallelism doesn't come from searching within your existing search space; it usually comes from expanding the search space along a new dimension. Compilers are not great at this, and it's been one of the struggles when people try to automate parallelism with compilers or general systems.
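Combining several parallelism kinds means laying ranks out on a multi-dimensional grid, where each GPU's peers for a given parallelism are the ranks that differ from it along only that axis. The sketch below uses hypothetical sizes and a common convention of putting tensor parallelism innermost, so its frequent communication stays between neighboring ranks. It doesn't claim to reproduce Llama 3's actual layout. In PyTorch, `torch.distributed.device_mesh.init_device_mesh` handles this bookkeeping in practice; the helpers here are written from scratch for illustration.

```python
from typing import Dict, List

# Hypothetical 4D layout, ordered outermost -> innermost (last axis varies fastest).
DIMS: Dict[str, int] = {"dp": 2, "pp": 2, "cp": 1, "tp": 2}


def rank_to_coords(rank: int, dims: Dict[str, int]) -> Dict[str, int]:
    """Decompose a flat rank into one coordinate per parallelism axis."""
    coords: Dict[str, int] = {}
    for name in reversed(list(dims)):
        coords[name] = rank % dims[name]
        rank //= dims[name]
    return {name: coords[name] for name in dims}


def coords_to_rank(coords: Dict[str, int], dims: Dict[str, int]) -> int:
    """Inverse of rank_to_coords (row-major over the axes in DIMS order)."""
    rank = 0
    for name, size in dims.items():
        rank = rank * size + coords[name]
    return rank


def group(rank: int, dims: Dict[str, int], axis: str) -> List[int]:
    """Ranks that differ from `rank` only along `axis`: its peers for that kind of parallelism."""
    coords = rank_to_coords(rank, dims)
    return [coords_to_rank({**coords, axis: i}, dims) for i in range(dims[axis])]
```

On this 8-GPU grid, rank 5 shares its tensor-parallel work with rank 4, synchronizes gradients with rank 1, and exchanges pipeline activations with rank 7.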
For example, one of the additional wrinkles that has shown up at certain scales is fault tolerance. When you're running on 10,000 or 20,000 GPUs, what we're doing is full global synchronizations across all of our GPUs. The issue is that GPUs can fail for all sorts of reasons: some failures are user-related, some are hardware-related, and some are networking-related. But basically, a single failure takes down your entire run, which is quite problematic. This table is from a paper Meta published about training Llama 3 and fault tolerance. One example of how this becomes an interesting issue: when you're training on 16,000 GPUs, you only get an error once every 1.8 hours. But if you scale up to 131,000 GPUs, you now get a failure roughly every 15 minutes. So something that might not be that problematic at 16,000-GPU scale or smaller turns into something very problematic at 131,000-GPU scale or higher. You can imagine a situation where you might not be able to complete even a single step before some GPU in your entire fleet fails. So to conclude, I think ML has basically become the single most important computational workload in the world over the last decade. Maybe I'm a bit biased, but judging by the amount of FLOPs and investment, I think you can make a strong case for it.
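A back-of-the-envelope sketch of why failures get worse with scale, assuming (as a simplification) that every GPU fails independently at the same constant rate. Under that assumption the job-level failure rate grows linearly with the GPU count, so the mean time between failures shrinks in proportion. The only inputs taken from the talk are the 1.8-hour figure at 16,000 GPUs and the 131,000-GPU target; the function names are illustrative.

```python
import math


def cluster_mtbf_hours(ref_mtbf_hours, ref_gpus, gpus):
    """Job-level mean time between failures, assuming each GPU fails
    independently at the same constant rate, so the job's failure rate
    grows linearly with the number of GPUs."""
    return ref_mtbf_hours * ref_gpus / gpus


def p_no_failure(duration_hours, mtbf_hours):
    """Probability that no GPU fails during a window of the given length,
    under the same constant-rate (exponential) failure model."""
    return math.exp(-duration_hours / mtbf_hours)


# Figures quoted in the talk: one failure every 1.8 hours at 16,000 GPUs.
mtbf_131k = cluster_mtbf_hours(1.8, 16_000, 131_000)
print(round(mtbf_131k * 60, 1))  # 13.2 (minutes)
print(p_no_failure(1.0, mtbf_131k))  # chance an hour-long window sees no failure
```

This naive model gives about 13 minutes, in the same ballpark as the "roughly 15 minutes" quoted above; real failures are not perfectly independent, so treat it as an intuition pump rather than a prediction.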
But I think the characteristics of this workload, both from a social point of view (the massive research effort required) and from a technical point of view, often mean that a lot of the traditional approaches to building systems or compilers don't directly apply. You can't just build a compiler and then have people optimize that compiler for five years without thinking about how the workloads evolving on top of it affect the compiler. So to me, the most interesting question about building systems here isn't really about building the right optimizations. It's about coming up with the right programming models for expressing your systems: programming models that let people do their own optimizations and build the next 100,000-GPU model. Thanks for coming to my talk. I hope you liked it.