fin TP3 SVM

This commit is contained in:
Titouan Labourdette 2021-11-26 20:29:02 +01:00
parent 3471776951
commit 13bac40fa4
6 changed files with 1460 additions and 752 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"id": "530f620c",
"metadata": {},
"outputs": [],
@ -16,12 +16,13 @@
"from matplotlib import pyplot as plt\n",
"from sklearn.model_selection import KFold\n",
"import time\n",
"import statistics"
"import statistics\n",
"from sklearn import metrics"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "68b6a517",
"metadata": {},
"outputs": [],
@ -863,11 +864,46 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "98107e41",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Matrice de confusion K-NN :\n",
" [[51 0 0 0 0 1 0 0 0 0]\n",
" [ 0 56 0 0 0 0 0 0 0 0]\n",
" [ 3 1 45 1 0 0 1 1 0 0]\n",
" [ 0 1 1 35 0 1 0 1 1 1]\n",
" [ 0 3 0 0 48 0 0 0 0 2]\n",
" [ 0 1 0 1 0 38 0 0 0 0]\n",
" [ 0 0 0 0 0 2 44 0 0 0]\n",
" [ 0 2 0 0 3 0 0 47 0 0]\n",
" [ 2 0 0 0 0 3 1 0 42 2]\n",
" [ 0 0 0 0 4 1 0 1 2 50]]\n"
]
}
],
"source": [
"### Create vector of 5000 random indexes\n",
"rand_indexes = np.random.randint(70000, size=5000)\n",
"### Load data with the previous vector\n",
"data = mnist.data[rand_indexes]\n",
"# print(\"Dataset : \", data)\n",
"target = mnist.target[rand_indexes]\n",
"\n",
"# Split the dataset\n",
"xtrain, xtest, ytrain, ytest = model_selection.train_test_split(data, target,train_size=0.9)\n",
"\n",
"# Training on xtrain,ytrain\n",
"clf = neighbors.KNeighborsClassifier(n_neighbors=3,p=2,n_jobs=1)\n",
"clf.fit(xtrain, ytrain)\n",
"# Predicting on xtest\n",
"pred = clf.predict(xtest)\n",
"print(\"Matrice de confusion K-NN :\\n\", metrics.confusion_matrix(ytest, pred))"
]
},
{
"cell_type": "code",
@ -894,7 +930,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.8"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long