近似误差、估计误差(知乎解释)
hei,wei,tag
1.5,40,thin
1.5,50,fat
1.5,60,fat
1.6,40,thin
1.6,50,thin
1.6,60,fat
1.6,70,fat
1.7,50,thin
1.7,60,thin
1.7,70,fat
1.7,80,fat
1.8,60,thin
1.8,70,thin
1.8,80,fat
1.8,90,fat
1.9,80,thin
1.9,90,fat
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 8 17:21:14 2017
@author: jasonhaven
"""
import os
import numpy as np
import pandas as pd
import operator
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
def read_from_csv(file):
'''
file:文件绝对地址
功能:读入csv文件并解析出数据集和标签集
'''
pwd=os.getcwd()
os.chdir(os.path.dirname(file))
df=pd.read_csv('data.csv')
os.chdir(pwd)
datas=df.iloc[:,:2].astype(np.float).values
labels = df.iloc[0:,-1:].astype(np.str).values#加载类别标签部分
return datas,labels
def classify(instance,datas,labels,k):
'''
instance:新的实例特征向量
datas:训练数据集
labels:标签集
k:选择相邻的k个实例
'''
num=datas.shape[0]
#tile(A, reps)返回一个shape=reps的矩阵,矩阵的每个元素是A
diffMat = np.tile(instance, (num, 1)) - datas
#diffMat就是输入样本与每个训练样本的差值
square_diffMat = diffMat**2
#然后对其每个x和y的差值进行平方运算。
square_distances=square_diffMat.sum(axis=1)
#开平方根求出距离
distances=square_distances**0.5
#argsort函数返回的是关键字(数组值)从小到大的索引值
sorted_distances = distances.argsort()
class_count = {}
# 投票过程,就是统计前k个最近的样本所属类别包含的样本个数
for i in range(k):
# sortedDistIndicies[i]是第i个最相近的样本下标
voteIlabel = str(labels[sorted_distances[i]])
# 然后将票数增1
class_count[voteIlabel] = class_count.get(voteIlabel, 0) + 1
# 把分类结果进行排序,然后返回得票数最多的分类结果
sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
return sorted_class_count[0][0]
def draw(datas,labels):
plt.figure('KNN')
plt.title('KNN')
plt.xlabel('height')
plt.ylabel('weight')
green_patch=mpatches.Patch(color='green', label='thin')
red_patch=mpatches.Patch(color='red', label='fat')
handles=[red_patch,green_patch]
plt.legend(handles=handles)
for i,x in enumerate(datas):
if labels[i]=='thin':
plt.scatter(x[0],x[1],s=100,marker='.',c='g')
else:
plt.scatter(x[0],x[1],s=100,marker='x',c='r')
plt.show()
if __name__=='__main__':
#获取数据集
file='./data.csv'#data.csv : 身高,体重,标签
datas,labels=read_from_csv(file)
labels=list(labels)
#新实例
instance=[1.7,60]
k=2
#分类
label=classify(instance,datas,labels,k)
draw(datas,labels)
print("knn classify : %s's label is %s"%(str(instance),label))
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有