How do you scale transformer context lengths over multiple machines?
Context length is such an important aspect of today’s AI race, and all the major players actively advertise it. Given how the matrix math works, how do people run inference on a transformer when the context is so long that it can't fit on one GPU or one machine?
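To make the memory problem concrete, here is a rough back-of-the-envelope estimate of the key/value (KV) cache a decoder keeps during inference, which grows linearly with context length. The model dimensions are hypothetical (roughly 7B-model territory), and the sketch assumes 16-bit values and no grouped-query attention.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_value: int = 2) -> int:
    """Approximate KV-cache size for a single sequence.

    Every layer stores one key vector and one value vector per head per
    token, hence the leading factor of 2.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value


# Hypothetical config: 32 layers, 32 KV heads, head_dim 128, fp16 values.
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
at_100k = kv_cache_bytes(32, 32, 128, seq_len=100_000)

print(f"{per_token / 2**20:.2f} MiB per token")
print(f"{at_100k / 2**30:.1f} GiB at 100k tokens")
```

That is before counting the model weights themselves, which is why long contexts can outgrow a single GPU quickly.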
Not sure how they do it specifically for LLMs, but you can use what's called model parallelism, or more specifically tensor parallelism, where you split a single layer across multiple GPUs or even multiple nodes.
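A minimal single-process sketch of the idea, assuming a column-wise split (one common tensor-parallel layout): the weight matrix W is cut into column blocks, each "device" multiplies the full input X by its own block, and concatenating the partial outputs column-wise reproduces X @ W. Real frameworks spread the shards across GPUs and communicate between them; here the "devices" are just list entries.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix multiply (rows of a times columns of b)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w: Matrix, num_shards: int) -> List[Matrix]:
    """Split w into num_shards contiguous column blocks (must divide evenly)."""
    cols = len(w[0])
    assert cols % num_shards == 0, "columns must divide evenly across shards"
    step = cols // num_shards
    return [[row[s * step:(s + 1) * step] for row in w] for s in range(num_shards)]


def column_parallel_matmul(x: Matrix, w: Matrix, num_shards: int) -> Matrix:
    # Each shard ("device") holds only its slice of W but sees all of X.
    partials = [matmul(x, shard) for shard in split_columns(w, num_shards)]
    # Gather step: concatenate the partial outputs along the column dimension.
    return [sum((p[i] for p in partials), []) for i in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0, 2.0, -1.0],
     [0.5, 1.0, 0.0, 3.0]]
assert column_parallel_matmul(x, w, num_shards=2) == matmul(x, w)
```

Each shard stores only 1/num_shards of the weights, which is the point when a layer is too big for one device.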
Under the hood, as far as I know, it's the same distributed matrix multiplication you'd do with MPI.
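The distributed-matmul part can be illustrated with the complementary row-wise split, again as a single-process sketch: W is cut into row blocks and X into the matching column blocks, each shard computes a partial product of the full output shape, and summing the partials gives X @ W. That summation is what an all-reduce collective (for example MPI_Allreduce with a sum operation) does across real processes, where every process ends up holding the result. The all_reduce_sum helper below is a hypothetical stand-in, not an MPI binding.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix multiply (rows of a times columns of b)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def all_reduce_sum(partials: List[Matrix]) -> Matrix:
    """Hypothetical stand-in for an all-reduce: element-wise sum of all shards."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)]
            for i in range(rows)]


def row_parallel_matmul(x: Matrix, w: Matrix, num_shards: int) -> Matrix:
    inner = len(w)
    assert inner % num_shards == 0, "inner dimension must divide evenly"
    step = inner // num_shards
    partials = []
    for s in range(num_shards):
        lo, hi = s * step, (s + 1) * step
        x_shard = [row[lo:hi] for row in x]   # matching columns of X
        w_shard = w[lo:hi]                    # this shard's rows of W
        partials.append(matmul(x_shard, w_shard))
    # Each partial has the full output shape; summing them recovers X @ W.
    return all_reduce_sum(partials)


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [1, 1], [2, -1]]
print(row_parallel_matmul(x, w, num_shards=2))  # [[12, 1], [28, 5]]
```

Integer inputs are used so the sharded and unsharded results compare exactly.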
I think DeepSpeed has bespoke transformer kernels that handle this specifically.
I'm a few months behind on the latest, but one approach used to be summarizing the history once the context started getting huge: summarize the earliest n-k messages and keep the last k as-is.
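A minimal sketch of that rolling-summary approach, assuming messages are plain strings and using a hypothetical summarize callback (in practice this would usually be another model call; here it is a trivial stub so the example runs). Only the earliest n-k messages are compressed; the last k are kept unchanged.

```python
from typing import Callable, List


def compact_history(messages: List[str], keep_last: int,
                    summarize: Callable[[List[str]], str]) -> List[str]:
    """Replace the earliest n - k messages with one summary; keep the last k."""
    if keep_last < 0:
        raise ValueError("keep_last must be non-negative")
    n = len(messages)
    if n <= keep_last:
        return list(messages)  # nothing old enough to summarize
    old, recent = messages[:n - keep_last], messages[n - keep_last:]
    return [summarize(old)] + recent


def stub_summarize(msgs: List[str]) -> str:
    # Hypothetical placeholder; a real system would ask a model to summarize.
    return f"[summary of {len(msgs)} earlier messages]"


history = [f"msg {i}" for i in range(1, 7)]
print(compact_history(history, keep_last=2, summarize=stub_summarize))
```

You would typically call this only once the context crosses some token budget, so short conversations are left untouched.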
Quantizing down to 8 bits seems to be one solution. TensorRT-LLM does this (though I think it requires an H100?), and ExLlama also does it on much weaker hardware.
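For a sense of what 8-bit quantization does, here is a generic symmetric absmax scheme: scale a block of values so the largest magnitude maps to 127, round to integers in [-127, 127], and multiply back by the scale to dequantize. This only illustrates the general idea. It is not the specific method TensorRT-LLM or ExLlama uses, since those have their own schemes. Storing one byte per value instead of two is what halves memory relative to fp16.

```python
from typing import List, Tuple


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric absmax quantization of one block to int8 codes plus a scale."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Clamp defensively in case float error pushes a value past 127.
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize_int8(codes: List[int], scale: float) -> List[float]:
    """Map int8 codes back to approximate floats."""
    return [c * scale for c in codes]


block = [0.12, -0.5, 0.33, 0.9, -0.07]
codes, scale = quantize_int8(block)
restored = dequantize_int8(codes, scale)

print(codes)
# Rounding error per value is at most half a quantization step.
print(max(abs(a - b) for a, b in zip(block, restored)) <= scale / 2 + 1e-12)
```

Real systems usually keep one scale per small block or channel rather than per tensor, so a single outlier doesn't wreck the precision of everything else.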