[Machine Learning] kNN Code Implementation (kd-tree)
Published: 2019-06-07


For the detailed description of the algorithm, see Chapter 3 of 《统计学习方法》 (Statistical Learning Methods). The implementation below builds a kd-tree over a small 2-D training set and then searches it for the nearest neighbor of a query point.

//
//  main.cpp
//  kNN
//
//  Created by feng on 15/10/24.
//  Copyright © 2015 ttcn. All rights reserved.
//

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;

template <typename T>
struct KdTree {
    // ctor
    KdTree() : parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}

    // Is this node empty?
    bool isEmpty() { return root.empty(); }

    // Is this node a leaf?
    bool isLeaf() { return !root.empty() && !leftChild && !rightChild; }

    // Is this node the root of the whole tree?
    bool isRoot() { return !isEmpty() && !parent; }

    // Is this node its parent's left child?
    bool isLeft() { return parent->leftChild->root == root; }

    // Is this node its parent's right child?
    bool isRight() { return parent->rightChild->root == root; }

    // Data point stored at this node
    vector<T> root;

    // Parent node
    KdTree<T> *parent;

    // Left child
    KdTree<T> *leftChild;

    // Right child
    KdTree<T> *rightChild;
};

/**
 * Matrix transpose.
 *
 * @param matrix original matrix
 *
 * @return the transpose of the original matrix
 */
template <typename T>
vector<vector<T>> transpose(const vector<vector<T>> &matrix) {
    size_t rows = matrix.size();
    size_t cols = matrix[0].size();
    vector<vector<T>> trans(cols, vector<T>(rows, 0));
    for (size_t i = 0; i < cols; ++i) {
        for (size_t j = 0; j < rows; ++j) {
            trans[i][j] = matrix[j][i];
        }
    }
    return trans;
}

/**
 * Find the median value.
 *
 * @param vec input values
 *
 * @return the median of the values
 */
template <typename T>
T findMiddleValue(vector<T> vec) {
    sort(vec.begin(), vec.end());
    size_t pos = vec.size() / 2;
    return vec[pos];
}

/**
 * Build the kd-tree recursively.
 *
 * @param tree  current kd-tree node
 * @param data  data matrix (one sample per row)
 * @param depth depth of the current node
 *
 * @return void
 */
template <typename T>
void buildKdTree(KdTree<T> *tree, vector<vector<T>> &data, size_t depth) {
    // Number of input samples
    size_t samplesNum = data.size();

    if (samplesNum == 0) {
        return;
    }

    if (samplesNum == 1) {
        tree->root = data[0];
        return;
    }

    // Dimension (number of attributes) of each sample
    size_t k = data[0].size();
    vector<vector<T>> transData = transpose(data);

    // Choose the splitting attribute for this depth and use its median as the split value
    size_t splitAttributeIndex = depth % k;
    vector<T> splitAttributes = transData[splitAttributeIndex];
    T splitValue = findMiddleValue(splitAttributes);

    vector<vector<T>> leftSubset;
    vector<vector<T>> rightSubset;

    for (size_t i = 0; i < samplesNum; ++i) {
        if (splitAttributes[i] == splitValue && tree->isEmpty()) {
            tree->root = data[i];
        } else if (splitAttributes[i] < splitValue) {
            leftSubset.push_back(data[i]);
        } else {
            rightSubset.push_back(data[i]);
        }
    }

    tree->leftChild = new KdTree<T>;
    tree->leftChild->parent = tree;
    tree->rightChild = new KdTree<T>;
    tree->rightChild->parent = tree;
    buildKdTree(tree->leftChild, leftSubset, depth + 1);
    buildKdTree(tree->rightChild, rightSubset, depth + 1);
}

/**
 * Print the kd-tree recursively.
 *
 * @param tree  kd-tree node
 * @param depth current depth
 *
 * @return void
 */
template <typename T>
void printKdTree(const KdTree<T> *tree, size_t depth) {
    for (size_t i = 0; i < depth; ++i) {
        cout << "\t";
    }

    for (size_t i = 0; i < tree->root.size(); ++i) {
        cout << tree->root[i] << " ";
    }
    cout << endl;

    if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
        return;
    } else {
        if (tree->leftChild) {
            for (size_t i = 0; i < depth + 1; ++i) {
                cout << "\t";
            }
            cout << "left : ";
            printKdTree(tree->leftChild, depth + 1);
        }

        cout << endl;

        if (tree->rightChild) {
            for (size_t i = 0; i < depth + 1; ++i) {
                cout << "\t";
            }
            cout << "right : ";
            printKdTree(tree->rightChild, depth + 1);
        }
        cout << endl;
    }
}

/**
 * Euclidean distance between two points (squared; the square root is
 * unnecessary when the value is only used for comparisons).
 *
 * @param p1 point 1
 * @param p2 point 2
 *
 * @return squared Euclidean distance between the points
 */
template <typename T>
T calDistance(const vector<T> &p1, const vector<T> &p2) {
    T res = 0;
    for (size_t i = 0; i < p1.size(); ++i) {
        res += pow(p1[i] - p2[i], 2);
    }
    return res;
}

/**
 * Search the kd-tree for the nearest neighbor of a target point.
 *
 * @param tree kd-tree
 * @param goal point to classify
 *
 * @return the nearest neighbor
 */
template <typename T>
vector<T> searchNearestNeighbor(KdTree<T> *tree, const vector<T> &goal) {
    // Dimension (number of attributes) of each point
    size_t k = tree->root.size();
    // Splitting-dimension counter
    size_t d = 0;
    KdTree<T> *currentTree = tree;
    vector<T> currentNearest = currentTree->root;
    // Descend to the leaf region that contains the target point
    while (!currentTree->isLeaf()) {
        size_t index = d % k;
        if (currentTree->rightChild->isEmpty() || goal[index] < currentTree->root[index]) {
            currentTree = currentTree->leftChild;
        } else {
            currentTree = currentTree->rightChild;
        }
        ++d;
    }
    currentNearest = currentTree->root;
    T currentDistance = calDistance(goal, currentTree->root);

    // The sibling region that may still contain a closer point
    KdTree<T> *searchDistrict;
    if (currentTree->isLeft()) {
        if (currentTree->parent->rightChild == nullptr) {
            searchDistrict = currentTree;
        } else {
            searchDistrict = currentTree->parent->rightChild;
        }
    } else {
        searchDistrict = currentTree->parent->leftChild;
    }

    // Backtrack towards the root, checking sibling regions along the way
    while (searchDistrict->parent != nullptr) {
        // Distance from the target to the parent's splitting hyperplane
        T districtDistance = abs(goal[(d + 1) % k] - searchDistrict->parent->root[(d + 1) % k]);

        // calDistance returns squared distances, so square the hyperplane distance before comparing;
        // only if the hyperplane is closer than the current best can this region hold a better point
        if (districtDistance * districtDistance < currentDistance) {
            T parentDistance = calDistance(goal, searchDistrict->parent->root);
            if (parentDistance < currentDistance) {
                currentDistance = parentDistance;
                currentTree = searchDistrict->parent;
                currentNearest = currentTree->root;
            }

            if (!searchDistrict->isEmpty()) {
                T rootDistance = calDistance(goal, searchDistrict->root);
                if (rootDistance < currentDistance) {
                    currentDistance = rootDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }

            if (searchDistrict->leftChild != nullptr && !searchDistrict->leftChild->isEmpty()) {
                T leftDistance = calDistance(goal, searchDistrict->leftChild->root);
                if (leftDistance < currentDistance) {
                    currentDistance = leftDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }

            if (searchDistrict->rightChild != nullptr && !searchDistrict->rightChild->isEmpty()) {
                T rightDistance = calDistance(goal, searchDistrict->rightChild->root);
                if (rightDistance < currentDistance) {
                    currentDistance = rightDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }
        }

        // Move up: switch to the sibling of the parent, or to the parent itself at the top
        if (searchDistrict->parent->parent != nullptr) {
            searchDistrict = searchDistrict->parent->isLeft()
                                 ? searchDistrict->parent->parent->rightChild
                                 : searchDistrict->parent->parent->leftChild;
        } else {
            searchDistrict = searchDistrict->parent;
        }
        ++d;
    }

    return currentNearest;
}

int main(int argc, const char *argv[]) {
    vector<vector<double>> trainDataSet{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
    KdTree<double> *kdTree = new KdTree<double>;
    buildKdTree(kdTree, trainDataSet, 0);
    printKdTree(kdTree, 0);

    vector<double> goal{3, 4.5};
    vector<double> nearestNeighbor = searchNearestNeighbor(kdTree, goal);

    for (auto i : nearestNeighbor) {
        cout << i << " ";
    }
    cout << endl;

    return 0;
}
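As a quick sanity check on searchNearestNeighbor, the kd-tree result can be compared against a brute-force linear scan of the training set. The following is a minimal sketch and not part of the original post: the helper name bruteForceNearest is hypothetical, and it reuses calDistance, so it has to be placed after that function in main.cpp.

template <typename T>
vector<T> bruteForceNearest(const vector<vector<T>> &data, const vector<T> &goal) {
    // Scan every training point and keep the one with the smallest squared distance.
    vector<T> best = data[0];
    T bestDistance = calDistance(goal, data[0]);
    for (const auto &point : data) {
        T dist = calDistance(goal, point);
        if (dist < bestDistance) {
            bestDistance = dist;
            best = point;
        }
    }
    return best;
}

Calling bruteForceNearest(trainDataSet, goal) in main and comparing it with the kd-tree answer makes an easy regression test; for the training set and query point above, both should report (2, 3). Assuming the listing is saved as main.cpp, it should build with any C++11 compiler, e.g. g++ -std=c++11 main.cpp -o knn.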

 

Reposted from: https://www.cnblogs.com/skycore/p/4908873.html
