1 /** 2 This file contains functions for performing hierarchical clustering, and 3 can be used for drawing heatmaps and, eventually, dendrograms. 4 5 Bugs: Not very efficient, though it probably doesn't need to be because 6 the use case is visualizations, and all the information has to fit 7 reasonably on the visualization. Therefore, N will always be fairly 8 small. 9 10 Copyright (C) 2011 David Simcha 11 12 License: 13 14 Boost Software License - Version 1.0 - August 17th, 2003 15 16 Permission is hereby granted, free of charge, to any person or organization 17 obtaining a copy of the software and accompanying documentation covered by 18 this license (the "Software") to use, reproduce, display, distribute, 19 execute, and transmit the Software, and to prepare derivative works of the 20 Software, and to permit third-parties to whom the Software is furnished to 21 do so, all subject to the following: 22 23 The copyright notices in the Software and this entire statement, including 24 the above license grant, this restriction and the following disclaimer, 25 must be included in all copies of the Software, in whole or in part, and 26 all derivative works of the Software, unless such copies or derivative 27 works are solely in the form of machine-executable object code generated by 28 a source language processor. 29 30 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 31 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 32 FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 33 SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 34 FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 35 ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 36 DEALINGS IN THE SOFTWARE. 37 */ 38 module plot2kill.hierarchical; 39 40 import plot2kill.util, std.typetuple; 41 42 /// Used for mean linkage. 
double mean(double[] stuff) {
    // An empty array yields 0.0 / 0, i.e. NaN.  hierarchicalCluster never
    // passes an empty array here, since every cluster has at least one leaf.
    return reduce!"a + b"(0.0, stuff) / stuff.length;
}

/**
Euclidean distance between two ranges of doubles.

The ranges are walked in lockstep with zip, so if their lengths differ the
surplus elements of the longer range are ignored.
*/
double euclidean(R1, R2)(R1 a, R2 b)
if(allSatisfy!(isInputRange, R1, R2) && is(ElementType!R1 : double) &&
   is(ElementType!R2 : double)) {

    return sqrt(
        reduce!"a + (b[0] - b[1]) ^^ 2"(0.0, zip(a, b))
    );
}

/**
A tree for defining hierarchical clusters.  Leaf nodes stand for single
samples (rows or columns of the clustered matrix); every other node is the
merge of its left and right children.
*/
struct Cluster {
    private this
    (Cluster* left, Cluster* right, double dist, size_t index) {
        this.left = left;
        this.right = right;
        this.distance = dist;
        this.index = index;
    }

    /// Left child, or null for a leaf node.
    Cluster* left;

    /// Right child, or null for a leaf node.
    Cluster* right;

    /**
    The linkage distance at which left and right were merged.  Leaf nodes
    created by hierarchicalCluster have a distance of NaN.
    */
    double distance;

    /**
    The index of the data w.r.t. matrix, if this is a leaf node, or size_t.max
    if this is not a leaf node.
    */
    size_t index = size_t.max;

    /// The name of the sample, populated only for leaf nodes.
    string name;

    /// True if this cluster doesn't have children.
    bool isLeaf() @property pure nothrow const {
        // A node has either both children or neither; a half-built node
        // would make nLeafNodes/opApply dereference a null pointer.
        assert((left is null) == (right is null));
        return left is null;
    }

    // Tracks distances to other clusters that have already been computed,
    // keyed by the address of the other cluster.  Is always null once
    // hierarchicalCluster returns, because it's no longer needed.
    private double[Cluster*] distCache;

    /*
    Linkage distance between this cluster and rhs: gathers the entries of the
    lower-triangular distances matrix for every (leaf of this, leaf of rhs)
    pair and summarizes them with linkage.  The result is memoized in this
    cluster's distCache under &rhs.
    */
    private double calculateDistance
    (alias linkage)(ref Cluster rhs, double[][] distances) {
        if(&rhs in distCache) {
            return distCache[&rhs];
        }

        auto app = appender!(double[])();

        void addDists(ref Cluster a, ref Cluster b) {
            if(a.isLeaf) {
                if(b.isLeaf) {
                    // distances[i] only holds entries for j < i.
                    auto index1 = max(a.index, b.index);
                    auto index2 = min(a.index, b.index);
                    assert(index1 != index2);
                    app.put(distances[index1][index2]);
                } else {
                    addDists(a, *(b.left));
                    addDists(a, *(b.right));
                }
            } else {
                addDists(*(a.left), b);
                addDists(*(a.right), b);
            }
        }

        addDists(this, rhs);
        auto ret = linkage(app.data);
        distCache[&rhs] = ret;
        return ret;
    }

    /// Iterate over the leaf nodes, left subtree before right subtree.
    int opApply(int delegate(ref Cluster) dg) {
        int res;

        if(isLeaf) {
            res = dg(this);
            return res;
        }

        assert(left);
        assert(right);

        res = left.opApply(dg);
        if(res) return res;

        res = right.opApply(dg);
        return res;
    }

    /// The number of leaf nodes in this cluster.
    int nLeafNodes() const pure nothrow @property {
        if(isLeaf) return 1;
        return left.nLeafNodes + right.nLeafNodes;
    }
}

/**
Cluster by rows or columns?
*/
enum ClusterBy {
    ///
    rows,

    ///
    columns
}

/**
Perform hierarchical clustering.  matrix must be rectangular and represents
the data matrix.  distance is the distance metric, which must be a function
that accepts two equal-length input ranges of doubles.  linkage is the linkage
function, which must accept a double[] that represents all possible pairwise
distances between two clusters and return a summary of these distances.

clusterBy indicates whether the rows or the columns of the matrix should
be clustered.

names is an optional string array of names, one per sample.  If provided,
this information will be placed in the Cluster objects, allowing samples
to be tracked by name.

At each step the two clusters with the smallest linkage distance are merged;
ties go to the first pair in scan order.  Pairs whose distance is NaN are
merged only when no pair has a non-NaN distance, so NaNs in the data never
cause samples to be dropped from the tree.

Returns:  The root of the cluster tree.  Leaves carry index (and name, if
names was supplied); each internal node's distance is the linkage distance
at which its two children were merged.

Throws:  Exception if matrix has no rows, is not rectangular, has no columns
when clustering by columns, or if names is non-empty and its length does not
match the number of samples being clustered.
*/
Cluster* hierarchicalCluster(alias distance = euclidean, alias linkage = mean)(
    double[][] matrix, ClusterBy clusterBy, string[] names = null
) {
    enforce(matrix.length > 0, "Cannot cluster zero elements.");
    foreach(i; 1..matrix.length) {
        enforce(matrix[i].length == matrix[0].length,
            "matrix must be rectangular for hierarchicalCluster.");
    }

    Cluster*[] clusters;

    // Make distance matrix.  It is lower-triangular: distances[i] has length
    // i, and distances[i][j] (j < i) is the distance between samples i and j.
    double[][] distances;
    if(clusterBy == ClusterBy.rows) {
        clusters = new Cluster*[matrix.length];
        distances = new double[][matrix.length];

        enforce(names.length == 0 || names.length == matrix.length,
            "names.length must be equal to matrix.length for " ~
            "hierarchical clustering by row.");

        foreach(i; 0..clusters.length) {
            distances[i] = new double[i];

            foreach(j; 0..i) {
                distances[i][j] = distance(matrix[i], matrix[j]);
            }
        }
    } else {
        assert(clusterBy == ClusterBy.columns);

        // Without this, a matrix of empty rows yields zero clusters and the
        // final clusters[0] access would fail with a RangeError.
        enforce(matrix[0].length > 0, "Cannot cluster zero elements.");

        clusters = new Cluster*[matrix[0].length];
        distances = new double[][matrix[0].length];

        enforce(names.length == 0 || names.length == matrix[0].length,
            "names.length must be equal to matrix[0].length for " ~
            "hierarchical clustering by column.");

        foreach(i; 0..clusters.length) {
            distances[i] = new double[i];

            foreach(j; 0..i) {
                distances[i][j] = distance(
                    transversal(matrix, i),
                    transversal(matrix, j)
                );
            }
        }
    }

    foreach(i; 0..clusters.length) {
        clusters[i] = new Cluster(null, null, double.nan, i);
        if(names.length) clusters[i].name = names[i];
    }

    while(clusters.length > 1) {
        // Find min dist pair.  minDist starts as NaN so the first pair is
        // always accepted, which guarantees two *distinct* clusters are
        // merged even if every distance is NaN or infinite.  Previously both
        // indices stayed 0 in that case, the cluster was merged with itself
        // and the new node was then removed, silently dropping samples.
        size_t minPair1 = 0, minPair2 = 1;
        double minDist = double.nan;

        foreach(i; 0..clusters.length) foreach(j; i + 1..clusters.length) {
            immutable dist =
                clusters[i].calculateDistance!(linkage)(*clusters[j], distances);

            // NaN compares false against everything, so a NaN minDist must be
            // replaced explicitly or it would shadow later non-NaN distances.
            if(dist < minDist || isNaN(minDist)) {
                minPair1 = i;
                minPair2 = j;
                minDist = dist;
            }
        }

        // Clean up excess distCache stuff, let it get GC'd.
        clusters[minPair1].distCache = null;
        clusters[minPair2].distCache = null;

        foreach(cluster; clusters) {
            if(clusters[minPair1] in cluster.distCache) {
                cluster.distCache.remove(clusters[minPair1]);
            }

            if(clusters[minPair2] in cluster.distCache) {
                cluster.distCache.remove(clusters[minPair2]);
            }
        }

        clusters[minPair1] = new Cluster(
            clusters[minPair1], clusters[minPair2], minDist, size_t.max);

        // minPair2 > minPair1, so removing it leaves the merged node in place.
        clusters = clusters.remove(minPair2);
    }

    distances[] = null;  // Make sure it gets gc'd.
    clusters[0].distCache = null;
    return clusters[0];
}