Inférence par argmax

Formule

La fonction de prédiction est predict(X, parametres) = torch.argmax(forward(X, parametres), dim=0). Étant donné que forward retourne $A^{(2)} \in \mathbb{R}^{10 \times m}$ (10 probabilités par exemple), expliquer en mots ce que fait argmax(., dim=0) et donner la forme du tenseur retourné.