机器学习算法:k近邻

发布时间:2020-06-17 05:28:07 作者:BoyTNT
来源:网络 阅读:1426

前言:

最近在研究机器学习,过程中的心得体会会记录到blog里,文章与代码均为原创。会不定期龟速更新,注意这不是正式的教程,因为本人也是初学者,但是估计C#版本的代码能帮到一些刚入门的同学去理解复杂的公式。


------------------------ 我是分割线 ------------------------


k近邻(k-Nearest Neighbor,KNN)算法,应该是机器学习里最基础的算法,其核心思想是:给定一个未知分类的样本,如果与它最相似的k个已知样本中的多数属于某一个分类,那么这个未知样本也属于这个分类


所谓相似,是指两个样本之间的欧氏距离小,其计算公式为:

机器学习算法:k近邻

其中Xi为样本X的第i个特征。


k近邻算法的优点在于实现简单,缺点在于时间和空间复杂度高。


上C#版代码,这里取k=1,即只根据最相近的一个点确定分类:

首先是DataVector,包含N维数据和分类标签,用于表示一个样本。

using System;

namespace MachineLearning
{
    /// <summary>
    /// 数据向量
    /// </summary>
    /// <typeparam name="T"></typeparam>
    public class DataVector<T>
    {
        /// <summary>
        /// N维数据
        /// </summary>
        public T[] Data { get; private set; }
        /// <summary>
        /// 分类标签
        /// </summary>
        public string Label { get; set; }

        /// <summary>
        /// 构造
        /// </summary>
        /// <param name="dimension">数据维度</param>
        public DataVector(int dimension)
        {
            Data = new T[dimension];
        }
        
        public int Dimension
        {
            get { return this.Data.Length; }
        }
    }
}


然后是核心算法:

using System;
using System.Collections.Generic;

namespace MachineLearning
{
    /// <summary>
    /// k近邻法
    /// </summary>
    public class NearestNeighbour
    {
        private int m_K;
        private List<DataVector<double>> m_TrainingSet;
        
        public NearestNeighbour(int k = 1)
        {
            m_K = k;
        }
        
        /// <summary>
        /// 训练
        /// </summary>
        /// <param name="trainingSet"></param>
        public void Train(List<DataVector<double>> trainingSet)
        {
            m_TrainingSet = trainingSet;
        }

        /// <summary>
        /// 分类
        /// </summary>
        /// <param name="vector"></param>
        /// <returns></returns>
        public string Classify(DataVector<double> vector)
        {
            //K=1时可简化处理提高效率
            if(m_K == 1)
            {
                double minDist = double.PositiveInfinity;
                int targetIndex = -1;
                for(int i = 0;i < m_TrainingSet.Count;i++)
                {
                    //计算距离
                    double distance = ComputeDistance(vector, m_TrainingSet[i], minDist);

                    //找最小值
                    if(distance < minDist)
                    {
                        minDist = distance;
                        targetIndex = i;
                    }
                }
            
                return m_TrainingSet[targetIndex].Label;
            }
            else
            {
                var dict = new SortedDictionary<double, string>();
                
                for(int i = 0;i < m_TrainingSet.Count;i++)
                {
                    //计算距离并记录
                    double distance = ComputeDistance(vector, m_TrainingSet[i]);
                    dict[distance] = m_TrainingSet[i].Label;
                }
                
                //找最多的Label
                var labels = new List<string>();
                int count = 0;
                foreach(var label in dict.Values)
                {
                    labels.Add(label);
                    if(++count > m_K - 1)
                        break;
                }

                return GetMajorLabel(labels);
            }
        }
    
        /// <summary>
        /// 计算距离
        /// </summary>
        /// <param name="v1"></param>
        /// <param name="v2"></param>
        /// <param name="minValue"></param>
        /// <returns></returns>
        private double ComputeDistance(DataVector<double> v1, DataVector<double> v2, double minValue = double.PositiveInfinity)
        {
            double distance = 0.0;
            minValue = minValue * minValue;
            for(int i = 0;i < v1.Data.Length;++i)
            {
                double diff = v1.Data[i] - v2.Data[i];
                distance += diff * diff;
            
                //如果当前累加的距离已经大于给定的最小值,不用继续计算了
                if(distance > minValue)
                    return double.PositiveInfinity;
            }

            return Math.Sqrt(distance);
        }
        
        /// <summary>
        /// 取多数
        /// </summary>
        /// <param name="dataSet"></param>
        /// <returns></returns>
        private string GetMajorLabel(List<string> labels)
        {
            var dict = new Dictionary<string, int>();
            foreach(var item in labels)
            {
                if(!dict.ContainsKey(item))
                    dict[item] = 0;
                dict[item]++;
            }

            string label = string.Empty;
            int count = -1;
            foreach(var key in dict.Keys)
            {
                if(dict[key] > count)
                {
                    label = key;
                    count = dict[key];
                }
            }
            
            return label;
        }
    }
}


需要注意的是,计算距离时,数量级大的维度会对距离影响大,因此大多数情况下,不能直接计算,要对原始数据做归一化,并根据重要性进行加权。归一化可以使用公式:value = (old-min)/(max-min),其中old是原始值,max是所有数据的最大值,min是所有数据的最小值。这样计算得到的value将落在0至1的区间上。


这个算法太简单,暂时不上测试代码了,有时间再补吧。


推荐阅读:
  1. 【学习笔记】K近邻归类算法
  2. 机器学习算法

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

机器学习 机器学习算法 k近邻

上一篇:如何巧用SSH突破防火墙

下一篇:W-6-2 配置SQL Server数据库

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》