Long Context로 인한 Large KV Cache의 문제점과 해결 방안: Part I-KV cache의 메모리 요구량

daewoo kim
9 min readFeb 4, 2024

Auto-regressive 모델이란 이전 단계의 출력들을 이용하여 다음 단계의 출력을 예측하는 모델이다. GPT는 auto-regressive 모델로 이전에 생성된 토큰를 기반으로 다음 토큰을 생성한다. GPT는 이전 토큰 생성 시 발생된 중간값인 activations(e.g. KV cache)를 캐싱하여 이전 토큰 값을 재계산하기 위한 GPU의 FLOPS를 절감하는 대신, KV cache을 위한 추가적인 메모리 공간이 필요하다. LLM의 context window size가 증가할수록 KV cache의 크기 또한 선형적으로 증가하므로 context window size는 메모리 용량에 제한을 받는다. 본 포스트는 LLM이 long context를 지원할 경우 KV cache 메모리 요구량이 급격하게 증가하면서 발생하는 메모리 용량 증가 문제에 대해서 소개한다.

LLM의 Context Window Size의 증가

필자가 예전 포스트(LLM의 Context Window Size가 크다고 좋은 것일까?)에서 설명했던 것과 같이 LLM의 context window size는 LLM의 활용성을 높이는데 중요한 역할을 담당한다. 이로 인해 gpt-3.5-turbo(ChatGPT)가 최초 출시한 2022년 말엔 context window size가 4K tokens을 지원하였지만 2023년 11월에 발표된 gpt-4-turbo는 128K tokens를 지원하기 시작하였다. 즉, GPT의 context window size는 무려 1년만에 32배 증가하였다. 그러나 LLM가 long context를 지원할수록 GPU의 FLOPS와 메모리 사용량을 급격하게 증가시켜 production 수준에서 리소스 부족 문제를 발생시킨다.

주요 LLM별 context window size 비교(2024/01 현재. by author)

KV cache란?

KV Cache란 토큰 생성 시 계산되는 Key/Value 텐서를 GPU 메모리에 저장한 후 재사용하는 것으로 이전 토큰의 Key/Value 텐서를 재계산되는 것을 막아 연산량을 줄이는 방법이다.

KV cache가 연산량을 줄이는 방법 (ref: https://medium.com/@joaolages/kv-caching-explained-276520203249)

KV Caching은 compute & memory trade-off의 대표적인 예로 컴퓨팅 양을 줄이는 대신 생성된 Key 텐서와 Value 텐서를 버리지 않고 저장해야 하기 때문에 메모리 사용량이 증가한다. KV Caching에 필요한 메모리 양은 context window size와 batch size에 의해 결정된다.

KV Cache의 메모리 요구량 계산

KV cache(MHA)의 메모리 요구량은 다음과 같이 계산식으로 계산할 수 있다.

KV cache(MHA)의 메모리 요구량 (by author)

LLaMA2의 모델 specification을 이용하여 MHA(Multi-Head Attention)의 KV Cache의 메모리 요구량을 계산하면 다음과 같이 batch size와 sequence length(=context window length)의 식으로 구성됨을 확인할 수 있다.

(Note: LLAMA2–70B은 원래 GQA(Grouped-Query Attention)를 사용하였다. 본 포스트에서는 KV cache가 요구하는 메모리 양이 매우 큼을 확인하기 위해 GQA로 최적화하기 전인 MHA를 사용한 경우를 고려하였다.)

LLaMA2의 모델 버전별 KV cache 메모리 요구량(FP16기준. by author)

아래 그림은 LLaMA2-70B가 MHA를 사용하였을 때 sequence length와 batch size별 KV cache의 메모리 요구량의 변화를 나타낸 것이다. LLaMA2는 기본적으로 sequence length=4K를 지원하므로 batch 당 KV cache의 메모리 요구량은 1.25GB인 것을 알 수 있다. 만일 LLaMA2의 sequence length를 128K로 증가시킨다면 KV cache를 위해 batch당 무려 40GB가 필요하다!! 특히 long context 조건에서KV cache의 메모리 요구량은 batch size에 비례하여 급격하게 증가함을 알 수 있다. 예를들어 sequence length=128K & batch size=32인 경우 KV cache는 1TB이 넘는 것을 알 수 있다.

LLaMA2–70B의 sequence length별 KV Cache 용량 (by author)

LLM의 sequence length는 지속적으로 커지고 있어 LLM serving 시 long context를 처리하는 것은 매우 중요한 이슈가 되고 있다. 아래 그림은 single node(A100–80GB x8. 총 GPU 메모리 용량: 640GB )를 기준으로 LLaMA2–70B 모델 serving에 필요한 메모리 용량(weight + KV cache)을 나타낸 것이다. Sequence length가 4K일 경우, single node 수준에서 batch size=256을 처리할 수 있는 반면(weight + KV cache:460GB < 총 GPU 메모리 용량: 640GB), sequence length가 128K일 경우, singe node 수준에서 겨우 batch size=8을 처리할 수 밖에 없다. 이와 같은 결과를 통해 다음과 같은 사실을 알 수 있다.

  1. Long Context와 Large batch size인 조건에서 KV cache가 weight보다 훨씬 더 많은메모리를 소비한다.
  2. KV cache의 메모리 소비량이 매우 커질 경우, 추론 시 GPU 메모리 용량이 bottleneck으로 작용할 수 있음을 의미한다.
LLaMA2–70B의 Weight + KV cache의 메모리 용량 (by author)

LLM serving 시 sequence length와 batch size결정하기

LLM의 weight는 고정된 값인 반면 KV cache는 sequence length와 batch size에 따라 변화한다. GPU 메모리는 대부분 weight와 KV cache로 채워지며 GPU 메모리 용량에 따라 지원하는 sequence length와 batch size가 결정된다. 따라서 토큰 당 KV cache의 용량을 알 수 있다면 GPU 메모리 용량에 따른 지원 가능한 sequence length과 batch size를 계산할 수 있다. LLaMA2–13B(MHA)을 single GPU(A100-80GB)에서 서비스한다면 지원 가능한 sequence length와 batch size는 다음과 같다.

Single GPU에서 serving할 수 있는 최대 sequence length와 최대 batch size (LLaMA2–13B기준. FP16) (by author)

(1) 최대 sequence length

  • batch size=1일 경우, 0.82MB*1*seq_len=54GB이므로 최대 sequence length = (approx.) 65854이다.

(2) 최대 batch size

  • sequence length=4K일 경우, 0.82MB*batch_size*4096=54GB이므로 batch size = (approx.) 16이다.

만일 A100 GPU(80GB)가 아닌 A100 GPU(40GB)를 사용한다면 메모리 용량의 제약으로 최대 sequence length와 최대 batch size는 모두 1/4 가량 줄어드는 것을 확인할 수 있다.

GPU 메모리 용량(40GB vs 80GB)에 따른 LLaMA2–13B의 max. sequence length와 max. batch size의 변화 (by author)

결론

LLM이 long context를 지원할 경우, KV cache가 급격하게 커지면서 GPU 메모리 용량에 따라 추론 시 LLM의 최대 sequence length와 최대 batch size가 결정되는 것을 확인하였다. Long context를 처리하거나 생성을 해야 할 때 GPU 메모리 부족 문제는 batch 처리를 어렵게 만들어 하드웨어 효율성을 낮추는 문제를 초래한다. 이러한 관점에서 제한된 GPU 메모리 용량를 효율적으로 사용하기 위해 다음과 같은 여러가지 방법을 사용할 수 있다.

  • 모델 weight의 memory footprint를 줄이는 방법 (e.g. quantization)
  • KV cache의 memory footprint를 줄이는 방법(e.g. GQA, MQA)
  • Model Parallelism을 사용하여 모델을 여러 GPU로 분할 처리하는 방법 (e.g. tenosr parallelism 등)

다음 포스트에서는 LLM inference 최적화를 위해 KV cache의 memory footprint를 줄이는 방법에 대해서 알아보도록 할 예정이다.

레퍼런스

[1] Mastering LLM Techniques: Inference Optimization

[2] EfficientML.ai Lecture 12 — Transformer and LLM Part-1

[3] EfficientML.ai Lecture 13 — Transformer and LLM Part-2

[4] Efficient Large Language Model Inference

[5] Transformers KV Caching Explained

--

--

daewoo kim

AI developer & Author | Working@semiconductor-industry. I write and share about what I learn.