import numpy as np
from sklearn.datasets import make_blobs # 用于生成测试数据
import matplotlib.pyplot as plt
class SimpleKMeans:
    """Minimal NumPy implementation of Lloyd's k-means clustering.

    Parameters
    ----------
    n_clusters : int, default 3
        Number of clusters (and therefore centroids) to fit.
    max_iter : int, default 100
        Upper bound on the number of assign/update iterations.
    tol : float, default 1e-4
        Iteration stops once every centroid coordinate moves by less than
        this amount between two consecutive iterations.
    random_state : int or None, default None
        Seed for the centroid initialisation. ``None`` draws from NumPy's
        global random state, so ``np.random.seed`` keeps working as before.

    Attributes
    ----------
    centroids : ndarray of shape (n_clusters, n_features), or None before fit
    labels : ndarray of shape (n_samples,), or None before fit
    inertia : sum of squared distances of the training samples to their
        assigned centroid, or None before fit
    """

    def __init__(self, n_clusters=3, max_iter=100, tol=1e-4, random_state=None):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.centroids = None
        self.labels = None
        self.inertia = None

    def _distances(self, X):
        """Return the (n_clusters, n_samples) Euclidean distance matrix."""
        # centroids[:, np.newaxis] has shape (k, 1, d); broadcasting against
        # X of shape (n, d) yields (k, n, d), reduced over the feature axis.
        diff = X - self.centroids[:, np.newaxis]
        return np.sqrt((diff ** 2).sum(axis=2))

    def _initialize_centroids(self, X):
        """Pick ``n_clusters`` distinct samples of X as the initial centroids.

        Raises
        ------
        ValueError
            If ``n_clusters`` is smaller than 1 or larger than the number
            of samples.
        """
        n_samples = X.shape[0]
        if not 1 <= self.n_clusters <= n_samples:
            raise ValueError(
                f"n_clusters={self.n_clusters} must be between 1 and "
                f"the number of samples ({n_samples})."
            )
        # Fall back to the global NumPy RNG when no seed was requested.
        if self.random_state is None:
            rng = np.random
        else:
            rng = np.random.RandomState(self.random_state)
        idx = rng.choice(n_samples, size=self.n_clusters, replace=False)
        self.centroids = X[idx]

    def _assign_clusters(self, X):
        """Store, for every sample, the index of its nearest centroid."""
        self.labels = np.argmin(self._distances(X), axis=0)

    def _update_centroids(self, X):
        """Return new centroids computed as the mean of each cluster.

        A cluster that currently has no members keeps its previous
        centroid; taking the mean of an empty slice would produce NaN and
        poison every later distance computation.
        """
        new_centroids = np.array(self.centroids, dtype=float)  # fresh copy
        for i in range(self.n_clusters):
            members = X[self.labels == i]
            if len(members) > 0:
                new_centroids[i] = members.mean(axis=0)
        return new_centroids

    def fit(self, X):
        """Run k-means on X and store centroids, labels and inertia.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        SimpleKMeans
            The fitted instance, to allow call chaining.

        Raises
        ------
        ValueError
            If X is not two-dimensional or ``n_clusters`` is out of range.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), got ndim={X.ndim}."
            )
        self._initialize_centroids(X)
        for iteration in range(self.max_iter):
            # _update_centroids returns a new array, so no copy is needed.
            old_centroids = self.centroids
            self._assign_clusters(X)
            self.centroids = self._update_centroids(X)
            # Converged once no centroid coordinate moved by tol or more.
            if np.all(np.abs(self.centroids - old_centroids) < self.tol):
                print(f"Converged after {iteration + 1} iterations.")
                break
        # The last update moved the centroids after the last assignment;
        # re-assign so labels and inertia match the stored centroids.
        self._assign_clusters(X)
        self._compute_inertia(X)
        return self

    def _compute_inertia(self, X):
        """Compute and store the within-cluster sum of squared distances."""
        # centroids[labels] lines up each sample with its own centroid.
        residuals = X - self.centroids[self.labels]
        self.inertia = np.sum(residuals ** 2)

    def predict(self, X):
        """Return the index of the nearest centroid for each sample in X.

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet.
        """
        if self.centroids is None:
            raise RuntimeError(
                "SimpleKMeans instance is not fitted yet; call fit() first."
            )
        X = np.asarray(X, dtype=float)
        return np.argmin(self._distances(X), axis=0)
# Demo / test code
def main():
    """Cluster synthetic blob data with SimpleKMeans and save a scatter plot.

    Generates 300 two-dimensional samples around 4 centres, fits a
    4-cluster model, writes the coloured scatter plot (centroids marked
    with red crosses) to ``./kmeans_result.png`` and prints the inertia.
    """
    n_groups = 4

    # Synthetic data; the ground-truth labels are not needed here.
    samples, _ = make_blobs(n_samples=300, centers=n_groups, cluster_std=0.60, random_state=0)

    # Build and train the model.
    model = SimpleKMeans(n_clusters=n_groups)
    model.fit(samples)

    # Plot the samples coloured by assigned cluster, then the centroids.
    xs, ys = samples[:, 0], samples[:, 1]
    plt.scatter(xs, ys, c=model.labels, s=50, cmap='viridis')
    centers = model.centroids
    plt.scatter(centers[:, 0], centers[:, 1], c='red', marker='x', s=200, linewidths=3)
    plt.title('KMeans Clustering Results')

    # The figure is written to disk instead of being shown interactively.
    plt.savefig("./kmeans_result.png")
    print(f"Final Inertia: {model.inertia}")
if __name__ == "__main__":
    # Broadcasting shape check: subtract a (4, 1, 2) array from a (300, 2)
    # array of random floats, reduce the squared differences over the last
    # axis, and print the shape of every intermediate result.
    # NOTE(review): main() is defined in this file but is not called here.
    points = np.random.random((300, 2))
    centers = np.random.random((4, 1, 2))
    diff = points - centers
    squared = (diff ** 2).sum(axis=2)
    print(points.shape, centers.shape, diff.shape, squared.shape)