JAX Ecosystem

Friday, May 20, 2022

JAX 라이브러리 관련 자료들을 읽고 정리합니다.

Summary

  • JAX: PyTorch, Tensor, MXNet보다 빠르고, NumPy 같은 array based 계산과 여러 파이썬 수학 기능들을 편리하게, 병렬 처리 가능하도록 사용할 수 있도록 구현한 라이브러리. Google Research가 개발하였음.
  1. Differentiation: Python 함수가 주어지면 grad 기능을 통해 자동으로 미분식을 알아내고,
  2. Vectorisation: vamp, pmap등의 기능을 통해 병렬화가 가능하며,
  3. JIT-compilation: CPU, GPU, TPU 등, 사용 가능한 accelerators를 쉽게 조절 할 수 있도록, 그리고 빠른 컴퓨팅 속도를 위해 just-in-time compile (C++ 처럼 컴파일 후 실행하는게 아닌, 실제 실행하는 시점에 컴파일)할 수 있도록 XLA 를 사용
  • Haiku: JAX의 함수형 패러다임을 활용하면서도 OOP 기반 neural network 모델을 편리하게 사용할 수 있도록 구현한 라이브러리
  • 이 외에도 JAX 관련 라이브러리로는 Optax (gradient 변환, optimizer 관련), RLax (RL 관련), Chex (실험 테스팅 관련), Jraph (GNN 관련) 등이 있음

Comments

  • 정리해보자면, 기존엔 cpu만 사용할 수 있었던 numpy계산을 gpu 위에서도 가능하도록 구현했고(물론 cupy가 있기는 했지만), torch나 tensorflow라는 프레임워크 위에서 한정된 실험을 하는 것이 아니라 단순 파이썬 코드 혹은 numpy 코드를 가지고도 ML 실험이 가능하도록 구현된 라이브러리라고 할 수 있음
  • 파이썬, Numpy 기본 기능들만 가지고도 ML 실험이 매우 빠르게 가능하도록 한 것이 의미가 있다고 생각하고, JAX를 사용해보지는 않았지만 최근에 Google, DeepMind 코드에서 JAX와 Haiku가 매우매우 자주 보이는 데에는 이유가 있을 것이라는 생각이 듦

References

Data Augmentation for Deep Learning

PaLM: Pathway Language Model