下面这个用于 DBSCAN 的函数可能会有所帮助。我编写它来遍历超参数 eps 和 min_samples 的各种组合,并加入了最小簇数和最大簇数这两个可选参数。由于 DBSCAN 是无监督算法,因此我没有加入评估参数。
def dbscan_grid_search(X_data, lst, clst_count, eps_space = 0.5,
                       min_samples_space = 5, min_clust = 0, max_clust = 10):
    """
    Performs a hyperparameter grid search for DBSCAN.

    Every (eps, min_samples) combination is fitted on X_data. When the
    number of clusters found (noise excluded) lies within
    [min_clust, max_clust], the combination is recorded in lst and the
    per-label point counts are recorded in clst_count. Both lists are
    mutated in place; the function returns None and prints a summary.

    DBSCAN (sklearn.cluster.DBSCAN) must be available in the module scope.

    Parameters:
        * X_data = data used to fit the DBSCAN instance
        * lst = a list to store the results of the grid search; each
          result is appended as [eps, min_samples, n_clusters]
        * clst_count = a list to store, for each recorded result, a
          Counter mapping every label (including the noise label -1)
          to the number of points assigned to it
        * eps_space = the values to try for the eps parameter (an
          iterable, or a single value)
        * min_samples_space = the values to try for the min_samples
          parameter (an iterable, or a single value)
        * min_clust = the minimum number of clusters required after each
          search iteration in order for a result to be appended to the lst
        * max_clust = the maximum number of clusters allowed after each
          search iteration in order for a result to be appended to the lst

    Example:
        # Loading Libraries
        from sklearn import datasets
        from sklearn.cluster import DBSCAN
        from sklearn.preprocessing import StandardScaler
        import numpy as np

        # Loading iris dataset
        iris = datasets.load_iris()
        X = iris.data[:, :]
        y = iris.target

        # Scaling X data
        dbscan_scaler = StandardScaler()
        dbscan_scaler.fit(X)
        dbscan_X_scaled = dbscan_scaler.transform(X)

        # Setting empty lists in global environment
        dbscan_clusters = []
        cluster_count = []

        # Inputting function parameters
        dbscan_grid_search(X_data = dbscan_X_scaled,
                           lst = dbscan_clusters,
                           clst_count = cluster_count,
                           eps_space = np.arange(0.1, 5, 0.1),
                           min_samples_space = np.arange(1, 50, 1),
                           min_clust = 3,
                           max_clust = 6)
    """
    # Importing counter to count the amount of data in each cluster
    from collections import Counter

    # The defaults are single numbers, so scalars must be accepted too.
    # Materialising the inner space also keeps one-shot iterators from
    # being exhausted after the first eps value.
    eps_values = _as_value_list(eps_space)
    min_samples_values = _as_value_list(min_samples_space)

    # Starting a tally of total iterations
    n_iterations = 0

    # Looping over each combination of hyperparameters
    for eps_val in eps_values:
        for samples_val in min_samples_values:
            dbscan_grid = DBSCAN(eps = eps_val,
                                 min_samples = samples_val)

            # Fit the model and get one cluster label per sample
            clusters = dbscan_grid.fit_predict(X = X_data)

            # Counting the amount of data in each cluster
            label_counts = Counter(clusters)

            # Number of distinct labels, not counting the noise label -1.
            # (Summing the label values would give a wrong count.)
            distinct_labels = {int(label) for label in clusters}
            distinct_labels.discard(-1)
            n_clusters = len(distinct_labels)

            # Increasing the iteration tally with each run of the loop
            n_iterations += 1

            # Appending to the caller's lists each time the n_clusters
            # criteria is reached
            if min_clust <= n_clusters <= max_clust:
                lst.append([eps_val,
                            samples_val,
                            n_clusters])
                clst_count.append(label_counts)

    # Printing grid search summary information
    print(f"""Search Complete. \nYour list is now of length {len(lst)}. """)
    print(f"""Hyperparameter combinations checked: {n_iterations}. \n""")


def _as_value_list(space):
    """Return space as a list, wrapping a single non-iterable value."""
    try:
        return list(space)
    except TypeError:
        # Plain numbers (and 0-d arrays) cannot be iterated: treat them
        # as a search space containing exactly one value.
        return [space]