KNN算法(K近邻算法)

目录

含义

代码

含义

knn(k近邻)算法就像它的名字一样,使用邻居而且是近的邻居来确定某个样本的类别,就好像物以类聚这种意思。

为了理解的更直观,观看下面的图片,模拟knn算法把中间的星星归类成正方形或则是圆形。

当k=3时,可以发现就是小圈的情况 找到最近的三个图形,两个正方形,一个圆形,那么此时星星就被划分到正方形的分类

当k=7的,找到最近的7个图形,有三个正方形,四个圆形,那么此时星星就被划分到圆形的类别

图例1-1

代码

from keras.datasets import mnist
import numpy as np
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化处理
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

print(x_train.shape, y_train.shape)

def distance(a, b):
    return np.sqrt(np.sum(np.square(a - b)))

class KNN:
    def __init__(self, k, label_num):
        self.k = k
        self.label_num = label_num # 类别的数量

    def fit(self, x_train, y_train):
        # 在类中保存训练数据
        self.x_train = x_train
        self.y_train = y_train

    def get_knn_indices(self, x):
        # 获取距离目标样本点最近的K个样本点的标签
        # 计算已知样本的距离
        dis = list(map(lambda a: distance(a, x), self.x_train))
        # 按距离从小到大排序,并得到对应的下标
        knn_indices = np.argsort(dis)
        # 取最近的K个
        knn_indices = knn_indices[:self.k]
        return knn_indices

    def get_label(self, x):
        # 对KNN方法的具体实现,观察K个近邻并使用np.argmax获取其中数量最多的类别
        knn_indices = self.get_knn_indices(x)
        # 类别计数
        label_statistic = np.zeros(shape=[self.label_num])
        for index in knn_indices:
            label = int(self.y_train[index])
            label_statistic[label] += 1
        # 返回数量最多的类别
        return np.argmax(label_statistic)

    def predict(self, x_test):
        # 预测样本 test_x 的类别
        predicted_test_labels = np.zeros(shape=[len(x_test)], dtype=int)
        for i, x in enumerate(x_test):
            predicted_test_labels[i] = self.get_label(x)
        return predicted_test_labels

knn=KNN(5, 10)
knn.fit(x_train, y_train)

sum=0.0
for i,v in enumerate(knn.predict(x_test[:100])):
    if(v==y_test[i]):
        sum+=1
print(sum/100)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值