Vad är Google JAX? Allt du behöver veta

Google JAX eller Just After Execution är ett ramverk utvecklat av Google för att påskynda maskininlärningsuppgifter.

Du kan betrakta det som ett bibliotek för Python, som hjälper till med snabbare exekvering av uppgifter, vetenskaplig beräkning, funktionstransformationer, djupinlärning, neurala nätverk och mycket mer.

Om Google JAX

Det mest grundläggande beräkningspaketet i Python är NumPy-paketet som har alla funktioner som aggregering, vektoroperationer, linjär algebra, n-dimensionell array och matrismanipulationer och många andra avancerade funktioner.

Tänk om vi ytterligare kunde påskynda beräkningarna som utförs med NumPy – särskilt för stora datamängder?

Har vi något som skulle kunna fungera lika bra på olika typer av processorer som en GPU eller TPU, utan några kodändringar?

Vad sägs om om systemet kunde utföra komponerbara funktionstransformationer automatiskt och mer effektivt?

Google JAX är ett bibliotek (eller ramverk, som Wikipedia säger) som gör just det och kanske mycket mer. Den byggdes för att optimera prestanda och effektivt utföra maskininlärning (ML) och djupinlärningsuppgifter. Google JAX tillhandahåller följande transformationsfunktioner som gör den unik från andra ML-bibliotek och hjälper till med avancerad vetenskaplig beräkning för djupinlärning och neurala nätverk:

  • Automatisk differentiering
  • Auto vektorisering
  • Automatisk parallellisering
  • Just-in-time (JIT) sammanställning

Google JAX unika funktioner

Alla transformationer använder XLA (Accelerated Linear Algebra) för högre prestanda och minnesoptimering. XLA är en domänspecifik optimerande kompilatormotor som utför linjär algebra och accelererar TensorFlow-modeller. Att använda XLA ovanpå din Python-kod kräver inga betydande kodändringar!

Låt oss utforska i detalj var och en av dessa funktioner.

Funktioner i Google JAX

Google JAX kommer med viktiga komponerbara transformationsfunktioner för att förbättra prestanda och utföra djupinlärningsuppgifter mer effektivt. Till exempel, automatisk differentiering för att få gradienten för en funktion och hitta derivator av valfri ordning. På samma sätt, automatisk parallellisering och JIT för att utföra flera uppgifter parallellt. Dessa transformationer är nyckeln till applikationer som robotik, spel och till och med forskning.

En komponerbar transformationsfunktion är en ren funktion som omvandlar en uppsättning data till en annan form. De kallas komponerbara eftersom de är fristående (dvs. dessa funktioner har inga beroenden med resten av programmet) och är tillståndslösa (dvs samma indata kommer alltid att resultera i samma utdata).

Y(x) = T: (f(x))

I ovanstående ekvation är f(x) den ursprungliga funktionen på vilken en transformation tillämpas. Y(x) är den resulterande funktionen efter att transformationen har tillämpats.

Till exempel, om du har en funktion som heter ’total_bill_amt’ och du vill ha resultatet som en funktionstransform, kan du helt enkelt använda den transformation du vill, låt oss säga gradient (gradient):

grad_total_bill = grad(total_bill_amt)

Genom att transformera numeriska funktioner med funktioner som grad() kan vi enkelt få deras högre ordningsderivator, som vi kan använda i stor utsträckning i optimeringsalgoritmer för djupinlärning som gradientnedstigning, vilket gör algoritmerna snabbare och effektivare. På liknande sätt, genom att använda jit(), kan vi kompilera Python-program just-in-time (lata).

#1. Automatisk differentiering

Python använder autograd-funktionen för att automatiskt skilja mellan NumPy och inbyggd Python-kod. JAX använder en modifierad version av autograd (dvs. grad) och kombinerar XLA (Accelerated Linear Algebra) för att utföra automatisk differentiering och hitta derivator av valfri ordning för GPU (Graphic Processing Units) och TPU (Tensor Processing Units).]

Snabbnotering om TPU, GPU och CPU: CPU eller Central Processing Unit hanterar alla operationer på datorn. GPU är en extra processor som förbättrar datorkraften och kör avancerade operationer. TPU är en kraftfull enhet speciellt utvecklad för komplexa och tunga arbetsbelastningar som AI och algoritmer för djupinlärning.

På samma sätt som autograd-funktionen, som kan differentiera genom loopar, rekursioner, förgreningar och så vidare, använder JAX funktionen grad() för gradienter i omvänt läge (backpropagation). Vi kan också differentiera en funktion till vilken ordning som helst med grad:

grad(grad(grad(sin θ))) (1.0)

Automatisk differentiering av högre ordning

Som vi nämnde tidigare är grad ganska användbar för att hitta de partiella derivatorna av en funktion. Vi kan använda en partiell derivata för att beräkna gradientnedgången för en kostnadsfunktion med avseende på de neurala nätverksparametrarna vid djupinlärning för att minimera förluster.

Beräknar partiell derivata

Anta att en funktion har flera variabler, x, y och z. Att hitta derivatan av en variabel genom att hålla de andra variablerna konstanta kallas en partiell derivata. Anta att vi har en funktion,

f(x,y,z) = x + 2y + z2

Exempel för att visa partiell derivata

Den partiella derivatan av x kommer att vara ∂f/∂x, vilket berättar hur en funktion ändras för en variabel när andra är konstanta. Om vi ​​utför detta manuellt måste vi skriva ett program för att differentiera, tillämpa det för varje variabel och sedan beräkna gradientnedgången. Detta skulle bli en komplex och tidskrävande affär för flera variabler.

Automatisk differentiering bryter ner funktionen i en uppsättning elementära operationer, som +, -, *, / eller sin, cos, tan, exp, etc., och tillämpar sedan kedjeregeln för att beräkna derivatan. Vi kan göra detta i både framåt- och backläge.

Detta är inte det! Alla dessa beräkningar sker så snabbt (tänk på en miljon beräkningar liknande ovanstående och den tid det kan ta!). XLA tar hand om hastigheten och prestandan.

#2. Accelererad linjär algebra

Låt oss ta den föregående ekvationen. Utan XLA kommer beräkningen att ta tre (eller fler) kärnor, där varje kärna kommer att utföra en mindre uppgift. Till exempel,

Kärna k1 –> x * 2y (multiplikation)

k2 –> x * 2y + z (tillägg)

k3 –> Reduktion

Om samma uppgift utförs av XLA:n, tar en enda kärna hand om alla mellanliggande operationer genom att fusionera dem. De mellanliggande resultaten av elementära operationer streamas istället för att lagra dem i minnet, vilket sparar minne och ökar hastigheten.

#3. Just-in-time sammanställning

JAX använder internt XLA-kompilatorn för att öka exekveringshastigheten. XLA kan öka hastigheten på CPU, GPU och TPU. Allt detta är möjligt med JIT-kodexekveringen. För att använda detta kan vi använda jit via import:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Ett annat sätt är att dekorera jit över funktionsdefinitionen:

@jit
def my_function(x):
	…………some lines of code

Denna kod är mycket snabbare eftersom transformationen kommer att returnera den kompilerade versionen av koden till den som ringer i stället för att använda Python-tolken. Detta är särskilt användbart för vektorindata, som arrayer och matriser.

Detsamma gäller för alla befintliga python-funktioner. Till exempel funktioner från NumPy-paketet. I det här fallet bör vi importera jax.numpy som jnp snarare än NumPy:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

När du väl har gjort detta ersätter JAX-arrayobjektet som kallas DeviceArray standardmatrisen NumPy. DeviceArray är lat – värdena hålls i acceleratorn tills de behövs. Detta innebär också att JAX-programmet inte väntar på att resultaten ska återgå till det anropande (Python) programmet, och följaktligen efter en asynkron sändning.

#4. Automatisk vektorisering (vmap)

I en typisk maskininlärningsvärld har vi datauppsättningar med en miljon eller fler datapunkter. Troligtvis skulle vi utföra några beräkningar eller manipulationer på var och en av eller de flesta av dessa datapunkter – vilket är en mycket tid- och minneskrävande uppgift! Om du till exempel vill hitta kvadraten för var och en av datapunkterna i datamängden är det första du skulle tänka på att skapa en loop och ta kvadraten en efter en – argh!

Om vi ​​skapar dessa punkter som vektorer, kan vi göra alla rutor på en gång genom att utföra vektor- eller matrismanipulationer på datapunkterna med vår favorit NumPy. Och om ditt program kunde göra detta automatiskt – kan du begära något mer? Det är precis vad JAX gör! Den kan automatiskt vektorisera alla dina datapunkter så att du enkelt kan utföra alla operationer på dem – vilket gör dina algoritmer mycket snabbare och effektivare.

JAX använder vmap-funktionen för autovektorisering. Tänk på följande array:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Genom att göra bara ovanstående kommer kvadratmetoden att köras för varje punkt i arrayen. Men om du gör följande:

vmap(jnp.square(x))

Metodkvadraten kommer endast att köras en gång eftersom datapunkterna nu vektoriseras automatiskt med vmap-metoden innan funktionen exekveras, och looping pressas ner till den elementära operationsnivån – vilket resulterar i en matrismultiplikation snarare än skalär multiplikation, vilket ger bättre prestanda .

#5. SPMD-programmering (pmap)

SPMD – eller Single Program Multiple Data-programmering är viktigt i djupinlärningssammanhang – du använder ofta samma funktioner på olika uppsättningar data som finns på flera GPU:er eller TPU:er. JAX har en funktion som heter pump, som möjliggör parallell programmering på flera GPU:er eller vilken accelerator som helst. Precis som JIT kommer program som använder pmap att kompileras av XLA och exekveras samtidigt över systemen. Denna automatiska parallellisering fungerar för både framåt- och bakåtberäkningar.

Hur fungerar pmap

Vi kan också tillämpa flera transformationer på en gång i valfri ordning på vilken funktion som helst som:

pmap(vmap(jit(grad (f(x)))))

Flera komponerbara transformationer

Begränsningar för Google JAX

Google JAX-utvecklare har väl tänkt på att påskynda algoritmer för djupinlärning samtidigt som de har introducerat alla dessa fantastiska transformationer. De vetenskapliga beräkningsfunktionerna och paketen är på linje med NumPy, så du behöver inte oroa dig för inlärningskurvan. JAX har dock följande begränsningar:

  • Google JAX är fortfarande i de tidiga utvecklingsstadierna, och även om dess huvudsakliga syfte är prestandaoptimering, ger det inte mycket nytta för CPU-beräkningar. NumPy verkar prestera bättre, och att använda JAX kanske bara ökar omkostnaderna.
  • JAX är fortfarande i sin forskning eller tidiga stadier och behöver mer finjustering för att nå infrastrukturstandarderna för ramverk som TensorFlow, som är mer etablerade och har mer fördefinierade modeller, projekt med öppen källkod och läromedel.
  • Från och med nu har JAX inte stöd för Windows operativsystem – du skulle behöva en virtuell maskin för att få det att fungera.
  • JAX fungerar bara på rena funktioner – de som inte har några biverkningar. För funktioner med biverkningar kanske JAX inte är ett bra alternativ.

Hur man installerar JAX i din Python-miljö

Om du har python-installation på ditt system och vill köra JAX på din lokala dator (CPU), använd följande kommandon:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Om du vill köra Google JAX på en GPU eller TPU, följ instruktionerna på GitHub JAX sida. För att ställa in Python, besök python officiella nedladdningar sida.

Slutsats

Google JAX är utmärkt för att skriva effektiva algoritmer för djupinlärning, robotik och forskning. Trots begränsningarna används den flitigt med andra ramverk som Haiku, Lin och många fler. Du kommer att kunna uppskatta vad JAX gör när du kör program och se tidsskillnaderna i exekvering av kod med och utan JAX. Du kan börja med att läsa officiell Google JAX-dokumentationvilket är ganska omfattande.