Browse Source

Use mean shift instead of affinity propagation

Arnaud Vergnet 2 years ago
parent
commit
a41d8033c5
4 changed files with 99 additions and 110 deletions
  1. 13
    5
      mydatalib.py
  2. 13
    11
      myplotlib.py
  3. 0
    94
      tp4-affinity.py
  4. 73
    0
      tp4-mean-shift.py

+ 13
- 5
mydatalib.py View File

@@ -31,7 +31,7 @@ def apply_kmeans(data, k: int = 3, init="k-means++"):
31 31
     model = cluster.KMeans(n_clusters=k, init=init)
32 32
     model.fit(data)
33 33
     tps2 = time.time()
34
-    return (model, round((tps2 - tps1)*1000, 2))
34
+    return model, round((tps2 - tps1) * 1000, 2)
35 35
 
36 36
 
37 37
 def apply_agglomerative_clustering(data, k: int = 3, linkage="complete"):
@@ -40,7 +40,7 @@ def apply_agglomerative_clustering(data, k: int = 3, linkage="complete"):
40 40
         n_clusters=k, affinity='euclidean', linkage=linkage)
41 41
     model.fit(data)
42 42
     tps2 = time.time()
43
-    return (model, round((tps2 - tps1)*1000, 2))
43
+    return model, round((tps2 - tps1) * 1000, 2)
44 44
 
45 45
 
46 46
 def apply_DBSCAN(data, eps, min_pts):
@@ -48,7 +48,15 @@ def apply_DBSCAN(data, eps, min_pts):
48 48
     model = cluster.DBSCAN(eps=eps, min_samples=min_pts)
49 49
     model.fit(data)
50 50
     tps2 = time.time()
51
-    return (model, round((tps2 - tps1)*1000, 2))
51
+    return model, round((tps2 - tps1) * 1000, 2)
52
+
53
+
54
+def apply_mean_shift(data, bandwidth: float):
55
+    tps1 = time.time()
56
+    model = cluster.MeanShift(bandwidth=bandwidth)
57
+    model.fit(data)
58
+    tps2 = time.time()
59
+    return model, round((tps2 - tps1) * 1000, 2)
52 60
 
53 61
 
54 62
 def evaluate(data, model):
@@ -56,6 +64,6 @@ def evaluate(data, model):
56 64
         silh = metrics.silhouette_score(data, model.labels_)
57 65
         davies = metrics.davies_bouldin_score(data, model.labels_)
58 66
         calinski = metrics.calinski_harabasz_score(data, model.labels_)
59
-        return (silh, davies, calinski)
67
+        return silh, davies, calinski
60 68
     except ValueError:
61
-        return (None, None, None)
69
+        return None, None, None

+ 13
- 11
myplotlib.py View File

@@ -22,7 +22,7 @@ def print_3d_data(data,
22 22
     f2 = data[:, 2]  # tous les éléments de la troisième colonne
23 23
     fig = plt.figure()
24 24
     ax = fig.gca(projection='3d')  # Affichage en 3D
25
-    if (c is None):
25
+    if c is None:
26 26
         ax.scatter(f0, f1, f2, label='Courbe',
27 27
                    marker='d')
28 28
         plt.title("Données initiales : " + dataset_name)
@@ -35,8 +35,8 @@ def print_3d_data(data,
35 35
     ax.set_ylabel('Y')
36 36
     ax.set_zlabel('Z')
37 37
     plt.tight_layout()
38
-    if (save):
39
-        if (c is None):
38
+    if save:
39
+        if c is None:
40 40
             save_path = "IMG/DATA_VISUALISATION/"
41 41
             if not os.path.exists(save_path):
42 42
                 os.makedirs(save_path)
@@ -54,24 +54,26 @@ def print_3d_data(data,
54 54
 def print_2d_data(data,
55 55
                   dataset_name: str = "",
56 56
                   method_name: str = "",
57
-                  k: int = 0,
57
+                  k = 0,
58 58
                   stop: bool = True,
59 59
                   save: bool = False,
60 60
                   c=None):
61 61
     f0 = data[:, 0]  # tous les élements de la première colonne
62 62
     f1 = data[:, 1]  # tous les éléments de la deuxième colonne
63 63
     plt.figure()
64
+    # utilisation d'une décimale si float
65
+    k_str = str(round(k, 1)) if isinstance(k, float) else str(k)
64 66
     # plt.hist2d(f0, f1)
65
-    if (c is None):
67
+    if c is None:
66 68
         plt.scatter(f0, f1, s=8)
67 69
         plt.title("Données initiales : " + dataset_name)
68 70
     else:
69 71
         plt.scatter(f0, f1, c=c, s=8)
70
-        plt.title("Graphique de " + str(k) + " clusters avec la méthode " +
72
+        plt.title("Graphique de " + k_str + " clusters avec la méthode " +
71 73
                   method_name + " sur le jeu de données " + dataset_name)
72 74
 
73
-    if (save):
74
-        if (c is None):
75
+    if save:
76
+        if c is None:
75 77
             save_path = "IMG/DATA_VISUALISATION/"
76 78
             if not os.path.exists(save_path):
77 79
                 os.makedirs(save_path)
@@ -80,7 +82,7 @@ def print_2d_data(data,
80 82
             save_path = "IMG/" + method_name + "/" + dataset_name + "/CLUSTERS/"
81 83
             if not os.path.exists(save_path):
82 84
                 os.makedirs(save_path)
83
-            plt.savefig(save_path + "k=" + str(k) + ".png")
85
+            plt.savefig(save_path + "k=" + k_str + ".png")
84 86
         plt.close()
85 87
     else:
86 88
         plt.show(block=stop)
@@ -101,7 +103,7 @@ def print_1d_data(x, y,
101 103
               method_name + " sur les données " + dataset_name)
102 104
     plt.xlabel(x_name + " (" + x_unit + ")")
103 105
     plt.ylabel(y_name + " (" + y_unit + ")")
104
-    if (save):
106
+    if save:
105 107
         save_path = "IMG/" + method_name + "/" + dataset_name + "/EVALUATION/"
106 108
         if not os.path.exists(save_path):
107 109
             os.makedirs(save_path)
@@ -126,7 +128,7 @@ def print_dendrogramme(data,
126 128
                    show_leaf_counts=False)
127 129
     plt.title("Dendrogramme du jeu de données " +
128 130
               dataset_name + " avec le linkage " + linkage)
129
-    if (save):
131
+    if save:
130 132
         save_path = "IMG/DENDROGRAMME/" + linkage + "/"
131 133
         if not os.path.exists(save_path):
132 134
             os.makedirs(save_path)

+ 0
- 94
tp4-affinity.py View File

@@ -1,94 +0,0 @@
1
-#!/usr/bin/env python3
2
-# -*- coding: utf-8 -*-
3
-"""
4
-Created on Wed Dec  8 16:07:28 2021
5
-
6
-@author: pfaure
7
-"""
8
-
9
-from sklearn.neighbors import NearestNeighbors
10
-import numpy as np
11
-
12
-from myplotlib import print_1d_data, print_2d_data
13
-from mydatalib import extract_data_2d, scale_data, apply_DBSCAN, evaluate
14
-
15
-path = './artificial/'
16
-dataset_name = "banana"
17
-save = True
18
-
19
-print("-----------------------------------------------------------")
20
-print("     Chargement du dataset : " + dataset_name)
21
-data = extract_data_2d(path + dataset_name)
22
-print_2d_data(data, dataset_name=dataset_name +
23
-              "_brutes", stop=False, save=save)
24
-
25
-print("-----------------------------------------------------------")
26
-print("     Mise à l'échelle")
27
-data_scaled = scale_data(data)
28
-print_2d_data(data_scaled, dataset_name=dataset_name +
29
-              "_scaled", stop=False, save=save)
30
-
31
-print("-----------------------------------------------------------")
32
-print("     Calcul du voisinage")
33
-n = 50
34
-neighbors = NearestNeighbors(n_neighbors=n)
35
-neighbors.fit(data)
36
-distances, indices = neighbors.kneighbors(data)
37
-distances = list(map(lambda x: sum(x[1:n-1])/(len(x)-1), distances))
38
-print(distances)
39
-distances = np.sort(distances, axis=0)
40
-print(distances)
41
-print_1d_data(distances, range(1, len(distances)+1), x_name="distance_moyenne",
42
-              y_name="nombre_de_points", stop=False, save=False)
43
-
44
-
45
-print("-----------------------------------------------------------")
46
-print("     Création clusters : DBSCAN")
47
-params = []
48
-for i in range(1, 20):
49
-    params += [(i/100, 5)]
50
-durations = []
51
-silouettes = []
52
-daviess = []
53
-calinskis = []
54
-clusters = []
55
-noise_points = []
56
-for (distance, min_pts) in params:
57
-    # Application du clustering agglomeratif
58
-    (model, duration) = apply_DBSCAN(data, distance, min_pts)
59
-    cl_pred = model.labels_
60
-    # Affichage des clusters# Affichage des clusters
61
-    print_2d_data(data_scaled, dataset_name=dataset_name,
62
-                  method_name="DBSCAN-Eps=" +
63
-                  str(distance)+"-Minpt="+str(min_pts),
64
-                  k=0, stop=False, save=save, c=cl_pred)
65
-    # Evaluation de la solution de clustering
66
-    (silouette, davies, calinski) = evaluate(data_scaled, model)
67
-    # Enregistrement des valeurs
68
-    durations += [duration]
69
-    silouettes += [silouette]
70
-    daviess += [davies]
71
-    calinskis += [calinski]
72
-    clusters += [len(set(cl_pred)) - (1 if -1 in cl_pred else 0)]
73
-    noise_points += [list(cl_pred).count(-1)]
74
-
75
-# Affichage des résultats
76
-params = [str(i) for i in params]
77
-print_1d_data(params, durations, x_name="(eps,min_pts)",
78
-              y_name="temps_de_calcul", y_unit="ms", dataset_name=dataset_name,
79
-              method_name="DBSCAN", stop=False, save=save)
80
-print_1d_data(params, silouettes, x_name="(eps,min_pts)",
81
-              y_name="coeficient_de_silhouette", dataset_name=dataset_name,
82
-              method_name="DBSCAN", stop=False, save=save)
83
-print_1d_data(params, daviess, x_name="(eps,min_pts)",
84
-              y_name="coeficient_de_Davies", dataset_name=dataset_name,
85
-              method_name="DBSCAN", stop=False, save=save)
86
-print_1d_data(params, calinskis, x_name="(eps,min_pts)",
87
-              y_name="coeficient_de_Calinski", dataset_name=dataset_name,
88
-              method_name="DBSCAN", stop=False, save=save)
89
-print_1d_data(params, clusters, x_name="(eps,min_pts)",
90
-              y_name="nombre_de_clusters", dataset_name=dataset_name,
91
-              method_name="DBSCAN", stop=False, save=save)
92
-print_1d_data(params, noise_points, x_name="(eps,min_pts)",
93
-              y_name="points_de_bruit", dataset_name=dataset_name,
94
-              method_name="DBSCAN", stop=False, save=save)

+ 73
- 0
tp4-mean-shift.py View File

@@ -0,0 +1,73 @@
1
+#!/usr/bin/env python3
2
+# -*- coding: utf-8 -*-
3
+"""
4
+Created on Wed Dec  8 16:07:28 2021
5
+
6
+@author: pfaure
7
+"""
8
+from numpy import arange
9
+from sklearn.neighbors import NearestNeighbors
10
+import numpy as np
11
+
12
+from myplotlib import print_1d_data, print_2d_data
13
+from mydatalib import extract_data_2d, scale_data, apply_mean_shift, evaluate
14
+
15
+path = './artificial/'
16
+dataset_name = "xclara"
17
+method_name = "mean-shift"
18
+save = True
19
+
20
+print("-----------------------------------------------------------")
21
+print("     Chargement du dataset : " + dataset_name)
22
+data = extract_data_2d(path + dataset_name)
23
+print_2d_data(data, dataset_name=dataset_name +
24
+              "_brutes", stop=False, save=save)
25
+
26
+print("-----------------------------------------------------------")
27
+print("     Mise à l'échelle")
28
+data_scaled = scale_data(data)
29
+print_2d_data(data_scaled, dataset_name=dataset_name +
30
+              "_scaled", stop=False, save=save)
31
+
32
+# Apply mean shift for several bandwidth values
33
+# and evaluate each resulting clustering
34
+
35
+k_max = 2
36
+
37
+k = []
38
+durations = []
39
+silouettes = []
40
+daviess = []
41
+calinskis = []
42
+for bandwidth in arange(0.1, k_max, 0.1):
43
+    # Application du clustering
44
+    (model, duration) = apply_mean_shift(
45
+        data_scaled, bandwidth=bandwidth)
46
+    # Affichage des clusters
47
+    print_2d_data(data_scaled, dataset_name=dataset_name,
48
+                  method_name=method_name, k=bandwidth,
49
+                  stop=False, save=save, c=model.labels_)
50
+    # Evaluation de la solution de clustering
51
+    (silouette, davies, calinski) = evaluate(data_scaled, model)
52
+    # Enregistrement des valeurs
53
+    k += [bandwidth]
54
+    durations += [duration]
55
+    silouettes += [silouette]
56
+    daviess += [davies]
57
+    calinskis += [calinski]
58
+
59
+# Affichage des résultats
60
+print_1d_data(k, k, x_name="k", y_name="k", dataset_name=dataset_name,
61
+              method_name=method_name, stop=False, save=save)
62
+print_1d_data(k, durations, x_name="k", y_name="temps_de_calcul", y_unit="ms",
63
+              dataset_name=dataset_name,
64
+              method_name=method_name, stop=False, save=save)
65
+print_1d_data(k, silouettes, x_name="k", y_name="coeficient_de_silhouette",
66
+              dataset_name=dataset_name,
67
+              method_name=method_name, stop=False, save=save)
68
+print_1d_data(k, daviess, x_name="k", y_name="coeficient_de_Davies",
69
+              dataset_name=dataset_name,
70
+              method_name=method_name, stop=False, save=save)
71
+print_1d_data(k, calinskis, x_name="k", y_name="coeficient_de_Calinski",
72
+              dataset_name=dataset_name,
73
+              method_name=method_name, stop=False, save=save)

Loading…
Cancel
Save