Hur man använder NumPy argmax()-funktionen i Python

By rik

I den här genomgången kommer du att lära dig att utnyttja NumPy-funktionen argmax() för att lokalisera positionen (indexet) för det största värdet i numeriska datamängder, så kallade arrayer.

NumPy är ett vitalt bibliotek inom Python för vetenskapliga beräkningar. Det erbjuder N-dimensionella arrayer, vilka är mer effektiva än vanliga Python-listor. En ofta utförd uppgift när man jobbar med NumPy-arrayer är att identifiera det största värdet inom arrayen. Ibland kan det vara nödvändigt att veta var exakt det största värdet återfinns.

Funktionen argmax() hjälper dig att hitta detta index i både en- och flerdimensionella arrayer. Låt oss dyka djupare och se hur det fungerar.

Hur man identifierar positionen (indexet) för det största elementet i en NumPy-array

För att följa med i den här guiden behöver du ha Python och NumPy installerat på din dator. Du kan testa koden genom att starta en Python REPL eller skapa en Jupyter Notebook.

Börja med att importera NumPy och ge det aliaset np, vilket är standard.

import numpy as np

Du kan använda funktionen NumPy max() för att ta reda på det maximala värdet i en array (eventuellt längs en specifik axel).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Utdata
10

I detta exempel ger np.max(array_1) resultatet 10, vilket stämmer.

Tänk dig att du vill veta var det maximala värdet finns i arrayen. Du kan använda följande tvåstegsprocess:

  • Lokalisera det största värdet.
  • Hitta indexet för det största värdet.

I array_1 återfinns det maximala värdet 10 på index 4 (med start på index 0). Det första elementet har index 0, det andra index 1, och så vidare.

För att hitta indexet kan du använda funktionen NumPy where(). np.where(villkor) ger en lista med alla index där villkoret är sant.

Du måste då extrahera elementet vid det första indexet från listan. För att hitta var det maximala värdet finns, sätter vi villkoret till array_1==10, eftersom 10 är det maximala värdet i array_1.

print(int(np.where(array_1==10)[0]))

# Utdata
4

Vi har använt np.where() enbart med villkoret, men det är inte det mest rekommenderade sättet att använda funktionen på.

📑 Notera: NumPy where()-funktion:
np.where(villkor, x, y) returnerar:

  • Element från x när villkoret är sant, och
  • Element från y när villkoret är falskt.

Genom att kombinera funktionerna np.max() och np.where() kan vi först identifiera det maximala värdet och sedan dess position.

Men istället för den här tvåstegsprocessen kan du direkt använda NumPy argmax() för att få indexet för det maximala elementet i arrayen.

Syntax för NumPy argmax()-funktionen

Den generella syntaxen för att använda NumPy argmax() är:

np.argmax(array, axis, out)
# Vi har importerat numpy med aliaset np

I syntaxen ovan:

  • array är vilken som helst giltig NumPy-array.
  • axis är en valfri parameter. När du arbetar med flerdimensionella arrayer kan du använda axis-parametern för att hitta indexet för det maximala värdet längs en viss dimension.
  • out är också en valfri parameter. Du kan ge out-parametern en NumPy-array för att spara resultatet av funktionen argmax().

Notera: Från och med NumPy version 1.22.0 finns ytterligare en parameter keepdims. När vi anger axis i funktionsanropet argmax(), reduceras arrayens dimension längs denna axel. Om man sätter keepdims till True garanterar man att den returnerade utdatan har samma form som den ursprungliga arrayen.

Använda NumPy argmax() för att lokalisera indexet för det största elementet

#1. Låt oss använda NumPy argmax() för att hitta indexet för det maximala värdet i array_1.

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# Utdata
4

Funktionen argmax() returnerar 4, vilket är korrekt! ✅

#2. Om vi omdefinierar array_1 så att 10 förekommer två gånger, returnerar argmax() endast indexet för den första förekomsten.

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# Utdata
4

I fortsättningen kommer vi att använda elementen i array_1 som vi definierade i exempel #1.

Använda NumPy argmax() för att hitta indexet för det största elementet i en 2D-array

Låt oss omforma array_1 till en tvådimensionell array med två rader och fyra kolumner.

array_2 = array_1.reshape(2,4)
print(array_2)

# Utdata
[[ 1  5  7  2]
 [10  9  8  4]]

I en tvådimensionell array indikerar axel 0 raderna, och axel 1 kolumnerna. NumPy-arrayer använder index som börjar på noll. Därför är indexen för raderna och kolumnerna i array_2 enligt följande:

Låt oss nu anropa argmax()array_2.

print(np.argmax(array_2))

# Utdata
4

Även om vi anropar argmax() på den tvådimensionella arrayen returneras fortfarande 4. Det är samma som för den endimensionella arrayen array_1.

Varför händer det? 🤔

Det beror på att vi inte har angett något värde för axis-parametern. Som standard, när axis inte är inställt, returnerar argmax() indexet för det största elementet i arrayen som om den vore utplattad.

Vad är en utplattad array? Om det finns en N-dimensionell array med formen d1 x d2 x … x dN, där d1, d2, upp till dN är storleken på arrayen längs de N dimensionerna, så är den utplattade arrayen en lång endimensionell array av storlek d1 * d2 * … * dN.

För att se hur den utplattade versionen av array_2 ser ut, kan du anropa metoden flatten(), som visas nedan:

array_2.flatten()

# Utdata
array([ 1,  5,  7,  2, 10,  9,  8,  4])

Index för det största elementet längs raderna (axis = 0)

Låt oss fortsätta med att hitta indexet för det största värdet längs raderna (axis = 0).

np.argmax(array_2,axis=0)

# Utdata
array([1, 1, 1, 1])

Detta resultat kan vara lite svårt att förstå, men låt oss reda ut hur det fungerar.

Vi har satt axis-parametern till noll (axis = 0) eftersom vi vill hitta indexet för det största elementet längs raderna. Funktionen argmax() returnerar radnumret där det maximala elementet finns, för varje kolumn.

Låt oss visualisera detta för bättre förståelse.

Från diagrammet ovan och utmatningen av argmax() har vi följande:

  • För den första kolumnen (index 0), är det maximala värdet 10 i den andra raden (index = 1).
  • För den andra kolumnen (index 1), är det maximala värdet 9 i den andra raden (index = 1).
  • För den tredje och fjärde kolumnen (index 2 och 3), är de maximala värdena 8 och 4 i den andra raden (index = 1).

Det är därför utdata blir array([1, 1, 1, 1]) – eftersom det maximala värdet längs raderna finns i den andra raden för alla kolumner.

Index för det största elementet längs kolumnerna (axis = 1)

Låt oss nu använda argmax() för att hitta indexet för det största värdet längs kolumnerna.

Kör följande kodavsnitt och observera resultatet.

np.argmax(array_2,axis=1)
array([2, 0])

Kan du analysera resultatet?

Vi har satt axis = 1 för att beräkna indexet för det maximala värdet längs kolumnerna.

Funktionen argmax() returnerar, för varje rad, kolumnnumret där det maximala värdet finns.

Här är en visuell förklaring:

Från diagrammet ovan och utmatningen av argmax() har vi följande:

  • För den första raden (index 0), är det maximala värdet 7 i den tredje kolumnen (index = 2).
  • För den andra raden (index 1), är det maximala värdet 10 i den första kolumnen (index = 0).

Förhoppningsvis förstår du nu vad resultatet array([2, 0]) betyder.

Använda den valfria out-parametern i NumPy argmax()

Du kan använda den valfria parametern out i NumPy argmax() för att spara resultatet i en NumPy-array.

Låt oss skapa en array med nollor för att lagra resultatet av det tidigare anropet till argmax() – för att hitta indexet för det maximala värdet längs kolumnerna (axis = 1).

out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]

Låt oss nu återgå till exemplet där vi hittar indexet för det maximala elementet längs kolumnerna (axis = 1) och sätta out till out_arr som vi definierade ovan.

np.argmax(array_2,axis=1,out=out_arr)

Vi ser att Python tolken genererar ett TypeError eftersom out_arr som standard initierades till en array av flyttal.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except TypeError:

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

När du sätter parametern out till en array är det viktigt att kontrollera att arrayen har rätt form och datatyp. Eftersom arrayindex alltid är heltal, bör vi sätta parametern dtype till int när vi definierar utdata-arrayen.

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# Utdata
[0 0]

Vi kan nu anropa argmax()-funktionen med både axis– och out-parametrarna. Denna gång fungerar det utan fel.

np.argmax(array_2,axis=1,out=out_arr)

Utmatningen från argmax() är nu tillgänglig i arrayen out_arr.

print(out_arr)
# Utdata
[2 0]

Slutsats

Jag hoppas att den här guiden hjälpte dig att förstå hur du använder NumPy argmax()-funktionen. Du kan testa kodexemplen i en Jupyter Notebook.

Låt oss repetera vad vi har lärt oss.

  • Funktionen NumPy argmax() returnerar indexet för det maximala elementet i en array. Om det maximala elementet förekommer mer än en gång i en array a, returnerar np.argmax(a) indexet för den första förekomsten av elementet.
  • När du arbetar med flerdimensionella arrayer, kan du använda den valfria axis-parametern för att få indexet för det maximala elementet längs en specifik axel. Till exempel, i en tvådimensionell array: genom att sätta axis = 0 och axis = 1 kan du få indexet för det maximala elementet längs raderna respektive kolumnerna.
  • Om du vill spara det returnerade värdet i en annan array kan du sätta den valfria out-parametern till den önskade arrayen. Utdata-arrayen bör dock ha kompatibel form och datatyp.

Kolla gärna in den djupgående guiden om Python sets.