Kunjan Patel

Seattle, WA

Summary

I am a Machine Learning Software Engineer with experience in JAX and XLA high-performance computing at Google, specializing in image and diffusion models. I also have experience contributing to OSS projects and working in the OSS community. I have a deep understanding of accelerator architecture (TPU/GPU), the multivariable calculus behind AI, and low-level kernel programming on accelerators. I am passionate about ML performance and stay active in the field outside of work. I have given talks at EleutherAI's ML performance reading group; the recordings are on their YouTube channel. I am also active on several Discord servers, such as GPU MODE and JAX, where I help answer questions about XLA, TPUs, JAX, and Pallas and learn from the community.

Overview

6 years of professional experience

Work History

Machine Learning Software Engineer

Google
10.2021 - Current

I work on model bring-up and optimization, mainly for OSS diffusion models and some MoE models. My team's main responsibility, and mine, is implementing the forward pass and training pipeline in JAX and optimizing their performance on TPUs and GPUs.

  • Worked in partnership with NVIDIA's MLEs to optimize SDXL and Flux models for H100 and H200 GPUs.
  • This involved performing roofline analysis, calculating MFU, and examining Perfetto and xprof traces to find bottlenecks.
  • For SDXL, I focused on optimizing the U-Net, which involved experimenting with a custom Pallas kernel for convolutions.
  • Other improvements included closing gaps between train steps, correcting the MFU calculation, and reducing long JAX jit compile times caused by constant folding.
  • For Flux, this involved analyzing xprof traces to find and fix shardings, removing unexpected activation all-gathers under FSDP parallelism.

Worked on optimizing WAN 2.1 diffusion model training and inference on latest.

  • Read the WAN training technical report and began experiments to support long context lengths. Ran experiments with multiple sharding strategies (tensor parallelism, context parallelism, sequence parallelism, hybrid parallelism).
  • Experimented with OSS ring attention in lax and JAX, identified performance issues, and implemented a fused ring and flash attention kernel in Pallas.

Previously worked on the Distributed Inference team.

  • Experimented with and developed a solution for disaggregated inference.
  • Contributor to vLLM.
  • Worked on a smart router for LoRA; did performance analysis of dynamic LoRA switching and roofline analysis of the GMMV and BGMV Punica kernels in vLLM.


Software Engineer

Tanium
05.2021 - 10.2021

Worked on the Tanium Cloud team; wrote Bazel libraries to support the Tanium Cloud product.

Software Engineer

vArmour
10.2019 - 04.2021

Joined as a new grad and implemented a Kafka pipeline that ingests AWS logs and uses a logical model to create firewall policies.

Education

Bachelor of Science - Computer Science

University of California, Los Angeles, Los Angeles, CA
05.2019

Skills

  • JAX

  • Pallas

  • PyTorch

  • Golang

  • Multivariable Calculus

  • Distributed Systems

  • High-performance computing

  • Hardware-software co-design

Timeline

Machine Learning Software Engineer - Google
10.2021 - Current
Software Engineer - Tanium
05.2021 - 10.2021
Software Engineer - vArmour
10.2019 - 04.2021
University of California, Los Angeles - Bachelor of Science, Computer Science