具体描述见《统计学习方法》第三章。
//
//  main.cpp
//  kNN
//
//  Created by feng on 15/10/24.
//  Copyright © 2015 ttcn. All rights reserved.
//
//  kd-tree construction and nearest-neighbour search (k = 1), following
//  Chapter 3 of Li Hang's "Statistical Learning Methods".
//

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

/**
 * One node of a kd-tree.
 *
 * `root` holds the node's sample point (empty for an empty node). The node
 * owns its children (allocated with `new` by buildKdTree) and frees them in
 * its destructor; `parent` is a non-owning back-link.
 */
template <typename T>
struct KdTree {
    // ctor
    KdTree() : parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}

    // Recursively frees the owned subtrees.
    ~KdTree() {
        delete leftChild;
        delete rightChild;
    }

    // A shallow copy would make two nodes own (and later delete) the same
    // children, so copying is disabled.
    KdTree(const KdTree &) = delete;
    KdTree &operator=(const KdTree &) = delete;

    // Whether this node holds no point.
    bool isEmpty() const { return root.empty(); }

    // Whether this node holds a point and has no children.
    bool isLeaf() const { return !root.empty() && !leftChild && !rightChild; }

    // Whether this node holds a point and has no parent.
    bool isRoot() const { return !isEmpty() && !parent; }

    // Whether this node is its parent's left child (false for the root).
    // Compares addresses, so duplicate points are handled correctly.
    bool isLeft() const { return parent != nullptr && parent->leftChild == this; }

    // Whether this node is its parent's right child (false for the root).
    bool isRight() const { return parent != nullptr && parent->rightChild == this; }

    // The sample point stored at this node.
    std::vector<T> root;

    // Parent node (non-owning); nullptr for the tree root.
    KdTree *parent;

    // Left subtree (owned): points whose split coordinate is < this node's.
    KdTree *leftChild;

    // Right subtree (owned): points whose split coordinate is >= this node's.
    KdTree *rightChild;
};

/**
 * Matrix transpose.
 *
 * @param matrix the source matrix; every row must have the same length
 *
 * @return the transpose of `matrix` (empty if `matrix` is empty)
 */
template <typename T>
std::vector<std::vector<T>> transpose(const std::vector<std::vector<T>> &matrix) {
    if (matrix.empty()) {
        return {};
    }
    const std::size_t rows = matrix.size();
    const std::size_t cols = matrix[0].size();
    std::vector<std::vector<T>> trans(cols, std::vector<T>(rows));
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            trans[c][r] = matrix[r][c];
        }
    }
    return trans;
}

/**
 * Median (upper median for even sizes) of a sequence.
 *
 * @param vec the values; must be non-empty. Taken by value because it is
 *            partially reordered.
 *
 * @return the element that would sit at index size()/2 after sorting
 */
template <typename T>
T findMiddleValue(std::vector<T> vec) {
    const std::size_t pos = vec.size() / 2;
    // nth_element gives the same element as a full sort, in linear time.
    std::nth_element(vec.begin(), vec.begin() + static_cast<std::ptrdiff_t>(pos), vec.end());
    return vec[pos];
}

/**
 * Recursively builds a kd-tree.
 *
 * At depth d the points are split on coordinate d % k around the median
 * value. The first point carrying the median value becomes this node's
 * root. Points strictly below the median go left; the rest go right.
 * Children are only created for non-empty subsets.
 *
 * @param tree  node to fill in; any previously built subtrees are freed
 * @param data  sample points; all must have the same dimension k
 * @param depth depth of `tree` (0 for the tree root)
 *
 * @return void
 */
template <typename T>
void buildKdTree(KdTree<T> *tree, const std::vector<std::vector<T>> &data, std::size_t depth) {
    if (tree == nullptr || data.empty()) {
        return;
    }

    // Discard any subtree from a previous build so it is not leaked.
    delete tree->leftChild;
    tree->leftChild = nullptr;
    delete tree->rightChild;
    tree->rightChild = nullptr;

    if (data.size() == 1) {
        tree->root = data[0];
        return;
    }

    // Dimension (number of attributes) of each sample point.
    const std::size_t k = data[0].size();
    if (k == 0) {
        // Zero-dimensional points: there is no coordinate to split on.
        tree->root = data[0];
        return;
    }

    // Pick the split axis and its median value. Only one column is needed,
    // so extract it directly instead of transposing the whole matrix.
    const std::size_t axis = depth % k;
    std::vector<T> column;
    column.reserve(data.size());
    for (const auto &point : data) {
        column.push_back(point[axis]);
    }
    const T splitValue = findMiddleValue(std::move(column));

    std::vector<std::vector<T>> leftSubset;
    std::vector<std::vector<T>> rightSubset;
    bool rootAssigned = false;
    for (const auto &point : data) {
        if (!rootAssigned && point[axis] == splitValue) {
            tree->root = point;
            rootAssigned = true;
        } else if (point[axis] < splitValue) {
            leftSubset.push_back(point);
        } else {
            rightSubset.push_back(point);
        }
    }

    // The median is always one of the values, so the root was assigned and
    // each subset is strictly smaller than `data`: the recursion terminates.
    if (!leftSubset.empty()) {
        tree->leftChild = new KdTree<T>;
        tree->leftChild->parent = tree;
        buildKdTree(tree->leftChild, leftSubset, depth + 1);
    }
    if (!rightSubset.empty()) {
        tree->rightChild = new KdTree<T>;
        tree->rightChild->parent = tree;
        buildKdTree(tree->rightChild, rightSubset, depth + 1);
    }
}

/**
 * Recursively prints a kd-tree, indenting each level with tabs.
 *
 * @param tree  subtree to print (nothing is printed for nullptr)
 * @param depth current depth, used for indentation
 *
 * @return void
 */
template <typename T>
void printKdTree(const KdTree<T> *tree, std::size_t depth) {
    if (tree == nullptr) {
        return;
    }

    std::cout << std::string(depth, '\t');
    for (const T &value : tree->root) {
        std::cout << value << " ";
    }
    std::cout << '\n';

    if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
        return;
    }

    if (tree->leftChild != nullptr) {
        std::cout << std::string(depth + 1, '\t') << "left : ";
        printKdTree(tree->leftChild, depth + 1);
    }
    std::cout << '\n';

    if (tree->rightChild != nullptr) {
        std::cout << std::string(depth + 1, '\t') << "right : ";
        printKdTree(tree->rightChild, depth + 1);
    }
    std::cout << '\n';
}

/**
 * SQUARED Euclidean distance between two points.
 *
 * The square root is omitted because callers only compare distances. Only
 * the first min(p1.size(), p2.size()) coordinates are used, so
 * mismatched sizes cannot read out of bounds.
 *
 * @param p1 point 1
 * @param p2 point 2
 *
 * @return sum over i of (p1[i] - p2[i])^2
 */
template <typename T>
T calDistance(const std::vector<T> &p1, const std::vector<T> &p2) {
    const std::size_t dims = std::min(p1.size(), p2.size());
    T res{};
    for (std::size_t i = 0; i < dims; ++i) {
        const T diff = p1[i] - p2[i];
        res += diff * diff;  // cheaper and exact compared to pow(diff, 2)
    }
    return res;
}

/**
 * Recursive helper for searchNearestNeighbor.
 *
 * First searches the side of the splitting hyperplane that contains
 * `goal`, then considers this node. It visits the other side only if the
 * hyperplane is closer than the best match found so far.
 *
 * @param node         subtree to search (may be nullptr)
 * @param goal         query point
 * @param depth        depth of `node`; its split axis is depth % k
 * @param k            point dimension
 * @param best         in/out: best node found so far (nullptr if none yet)
 * @param bestDistance in/out: squared distance from goal to best->root
 */
template <typename T>
void searchNearestNeighborImpl(const KdTree<T> *node, const std::vector<T> &goal,
                               std::size_t depth, std::size_t k,
                               const KdTree<T> *&best, T &bestDistance) {
    if (node == nullptr || node->isEmpty()) {
        return;
    }

    const std::size_t axis = depth % k;
    // Mirrors buildKdTree: the left subtree holds coordinates strictly below
    // the node's, the right subtree holds coordinates >= the node's.
    const bool goLeft = goal[axis] < node->root[axis];
    const KdTree<T> *nearSide = goLeft ? node->leftChild : node->rightChild;
    const KdTree<T> *farSide = goLeft ? node->rightChild : node->leftChild;

    searchNearestNeighborImpl(nearSide, goal, depth + 1, k, best, bestDistance);

    const T distance = calDistance(goal, node->root);
    if (best == nullptr || distance < bestDistance) {
        best = node;
        bestDistance = distance;
    }

    // calDistance is squared, so compare against the squared distance from
    // goal to the splitting hyperplane.
    const T planeOffset = goal[axis] - node->root[axis];
    if (planeOffset * planeOffset < bestDistance) {
        searchNearestNeighborImpl(farSide, goal, depth + 1, k, best, bestDistance);
    }
}

/**
 * Finds the nearest neighbour of a query point in a kd-tree.
 *
 * @param tree root of a tree built by buildKdTree with depth 0
 * @param goal query point; must have at least as many coordinates as the
 *             tree's points
 *
 * @return the stored point closest to `goal` (Euclidean distance), or an
 *         empty vector if the tree is empty/null or `goal` is too short
 */
template <typename T>
std::vector<T> searchNearestNeighbor(KdTree<T> *tree, const std::vector<T> &goal) {
    if (tree == nullptr || tree->isEmpty()) {
        return {};
    }

    // Number of attributes per point; non-zero because the root is non-empty.
    const std::size_t k = tree->root.size();
    if (goal.size() < k) {
        return {};
    }

    const KdTree<T> *best = nullptr;
    T bestDistance{};
    searchNearestNeighborImpl<T>(tree, goal, 0, k, best, bestDistance);
    // The root itself is always visited, so `best` is non-null here.
    return best->root;
}

int main() {
    std::vector<std::vector<double>> trainDataSet{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};

    // The root node owns the whole tree and frees it when main returns.
    KdTree<double> kdTree;
    buildKdTree(&kdTree, trainDataSet, 0);
    printKdTree(&kdTree, 0);

    const std::vector<double> goal{3, 4.5};
    const std::vector<double> nearestNeighbor = searchNearestNeighbor(&kdTree, goal);

    for (double value : nearestNeighbor) {
        std::cout << value << " ";
    }
    std::cout << std::endl;

    return 0;
}