Inférence par argmax

Idée

Une fois le réseau entraîné, on lui donné une nouvelle image et on regarde laquelle des dix sorties à la valeur la plus grande — c'est la classe prédite.

Pourquoi

Une fois le réseau entraîné, on lui donné une nouvelle image et on regarde laquelle des dix sorties à la valeur la plus grande — c'est la classe prédite. La confiance en chacune des autres classes n'a plus d'importance pour le verdict final.

Outil

Argmax est l'opération $\text{argmax}_i (a_i)$ qui renvoie l'indice du plus grand élément d'une suite finie — comme chercher le maximum d'une fonction réelle sur un ensemble fini, mais on retient l'argument plutôt que la valeur.

Formule

La fonction de prédiction est predict(X, parametres) = torch.argmax(forward(X, parametres), dim=0). Étant donné que forward retourne $A^{(2)} \in \mathbb{R}^{10 \times m}$ (10 probabilités par exemple), expliquer en mots ce que fait argmax(., dim=0) et donner la forme du tenseur retourné.

Piège

Convention features-en-lignes : sortie A2 a shape (10, N) (10 classes, N exemples). argmax(dim=0) retourne l'indice de classe pour chaque exemple — correct. argmax(dim=1) retourne, pour chaque classe, l'exemple le plus 'probable' — n'importe quoi. Sans dim, Pytorch fait l'argmax global (un seul entier sur tout le tenseur). Le notebook Cell 17 utilise dim=0 correctement.