K-Means 聚类算法实现:Python 代码详解及可视化
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
np.random.seed(2)
# 计算两点之间距离
def distance(pt1, pt2):
return np.sum((pt1 - pt2) ** 2)
# 计算当前各个中心点中离给定点pt最近的一个
def classify(pt, centers):
temp = [distance(pt, centers[i]) for i in range(len(centers))]
index = np.argmin(temp)
dist = temp[index]
return (index, dist)
# 将每个点归入到离它最近的中心点类别下
########## Begin ##########
for i in range(M):
category, dist = classify(X[i], centers)
categories[category].append(X[i])
# 计算每个点到离它最近的中心点的平均距离
cost_new = 0
for i in range(K):
for pt in categories[i]:
cost_new += distance(pt, centers[i])
cost_new /= M
# 根据当前的类别划分,重新计算每个类别下新的中心点
calc_centers(categories, centers)
iter = 0
# 循环计算新的中心点,直到误差小于一定的阈值或达到最大循环次数
while abs(cost - cost_new) > epsilon and iter < max_iter:
cost = cost_new
categories = [[] for i in range(K)]
# 将每个点归入到离它最近的中心点类别下
for i in range(M):
category, dist = classify(X[i], centers)
categories[category].append(X[i])
# 计算每个点到离它最近的中心点的平均距离
cost_new = 0
for i in range(K):
for pt in categories[i]:
cost_new += distance(pt, centers[i])
cost_new /= M
# 根据当前的类别划分,重新计算每个类别下新的中心点
calc_centers(categories, centers)
iter += 1
print("经过",iter,"次循环,质心计算完成...")
########## End ##########
M = 100
K = 4
########### Begin ##########
# 围绕K个中心点,生成M个随机二维数据点
X, y = make_blobs(n_samples=M, centers=K, random_state=1)
########## End ##########
cost = 1e10 # 初始误差,设置为很大
epsilon = 1e-8
max_iter = 100
categories = [[] for i in range(K)]
# 随机选择K个中心点作为初始点
init_indecies = np.random.randint(0, M, K)
centers = X[init_indecies]
########## Begin ##########
# 计算分类
for i in range(M):
category, dist = classify(X[i], centers)
categories[category].append(X[i])
# 重新计算各类中心点
calc_centers(categories, centers)
# 再次计算
cost_new = 0
for i in range(K):
for pt in categories[i]:
cost_new += distance(pt, centers[i])
cost_new /= M
########## End ##########
print("经过",iter,"次循环,质心计算完成...")
yPredicts = np.zeros(M)
# 计算每个样本所属的类别
for i in np.arange(M):
category, _ = classify(X[i], centers)
yPredicts[i] = category
# 查看原始分类情况
plt.figure(1)
plt.title("Origin Classification")
plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
plt.savefig('/data/workspace/myshixun/step1/picture1/pic1.jpg')
# 查看K-Means分类结果
plt.figure(2)
plt.title("KMeans Classification")
plt.scatter(X[:, 0], X[:, 1], c=yPredicts, s=30, cmap=plt.cm.Paired)
# plt.show()
plt.savefig('/data/workspace/myshixun/step1/picture1/pic2.jpg')
原文地址: https://www.cveoy.top/t/topic/onQy 著作权归作者所有。请勿转载和采集!