1 /**
2 This file contains functions for performing hierarchical clustering, and
3 can be used for drawing heatmaps and, eventually, dendrograms.
4 
5 Bugs:  Not very efficient, though it probably doesn't need to be because
6        the use case is visualizations, and all the information has to fit
7        reasonably on the visualization.  Therefore, N will always be fairly
8        small.
9 
10 Copyright (C) 2011 David Simcha
11 
12 License:
13 
14 Boost Software License - Version 1.0 - August 17th, 2003
15 
16 Permission is hereby granted, free of charge, to any person or organization
17 obtaining a copy of the software and accompanying documentation covered by
18 this license (the "Software") to use, reproduce, display, distribute,
19 execute, and transmit the Software, and to prepare derivative works of the
20 Software, and to permit third-parties to whom the Software is furnished to
21 do so, all subject to the following:
22 
23 The copyright notices in the Software and this entire statement, including
24 the above license grant, this restriction and the following disclaimer,
25 must be included in all copies of the Software, in whole or in part, and
26 all derivative works of the Software, unless such copies or derivative
27 works are solely in the form of machine-executable object code generated by
28 a source language processor.
29 
30 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
31 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
32 FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
33 SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
34 FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
35 ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
36 DEALINGS IN THE SOFTWARE.
37  */
38 module plot2kill.hierarchical;
39 
40 import plot2kill.util, std.typetuple;
41 
/**
Arithmetic mean of stuff.  This is the default linkage used by
hierarchicalCluster (average linkage).  Returns NaN when stuff is empty,
since the sum 0.0 is divided by a length of 0.
*/
double mean(double[] stuff) {
    double total = 0.0;
    foreach(value; stuff) {
        total += value;
    }
    return total / stuff.length;
}
46 
/**
Euclidean distance: the square root of the sum of squared element-wise
differences between a and b.  The ranges are walked in lock-step and
iteration stops as soon as either one is exhausted.
*/
double euclidean(R1, R2)(R1 a, R2 b)
if(allSatisfy!(isInputRange, R1, R2) &&
   is(ElementType!R1 : double) &&
   is(ElementType!R2 : double))
{
    double sumOfSquares = 0.0;

    for(; !a.empty && !b.empty; a.popFront(), b.popFront()) {
        immutable delta = a.front - b.front;
        sumOfSquares += delta ^^ 2;
    }

    return sqrt(sumOfSquares);
}
56 
/**
A tree for defining hierarchical clusters.

Leaf nodes have both left and right set to null and carry the index (and,
optionally, the name) of a single sample.  Internal nodes have both children
set and represent the merger of those two clusters.
*/
struct Cluster {
    // Used by hierarchicalCluster to create both leaf and merged nodes.
    private this
    (Cluster* left, Cluster* right, double dist, size_t index) {
        this.left = left;
        this.right = right;
        this.distance = dist;
        this.index = index;
    }

    /// The left child, or null if this is a leaf node.
    Cluster* left;

    /// The right child, or null if this is a leaf node.
    Cluster* right;

    /**
    The linkage distance between left and right at the time they were merged.
    Leaf nodes built by hierarchicalCluster have a distance of NaN.
    */
    double distance;

    /**
    The index of the data w.r.t. matrix, if this is a leaf node, or size_t.max
    if this is not a leaf node.
    */
    size_t index = size_t.max;

    /// The name of the sample, populated only for leaf nodes.
    string name;

    /// True if this cluster doesn't have children.
    bool isLeaf() @property pure nothrow const {
        // Children are always assigned together, so they must be either both
        // null or both non-null.  nLeafNodes and opApply rely on this.
        assert((left is null) == (right is null));
        return left is null;
    }

    // Tracks distances to other clusters that have already been computed,
    // keyed by the address of the other cluster.  Is always null once
    // hierarchicalCluster returns, because it's no longer needed.
    private double[Cluster*] distCache;

    /*
    Linkage distance between this cluster and rhs.  Every pairwise distance
    between a leaf under this and a leaf under rhs is read from distances,
    which is indexed as distances[larger index][smaller index].  The
    collected values are summarized by linkage, and the result is memoized
    in distCache under &rhs.
    */
    private double calculateDistance
    (alias linkage)(ref Cluster rhs, double[][] distances) {
        // Single AA lookup instead of an `in` test followed by indexing.
        if(auto cached = &rhs in distCache) {
            return *cached;
        }

        auto app = appender!(double[])();

        // Recursively descend until both sides are leaves, then record the
        // precomputed distance between those two samples.
        void addDists(ref Cluster a, ref Cluster b) {
            if(a.isLeaf) {
                if(b.isLeaf) {
                    auto index1 = max(a.index, b.index);
                    auto index2 = min(a.index, b.index);
                    assert(index1 != index2);
                    app.put(distances[index1][index2]);
                } else {
                    addDists(a, *(b.left));
                    addDists(a, *(b.right));
                }
            } else {
                addDists(*(a.left), b);
                addDists(*(a.right), b);
            }
        }

        addDists(this, rhs);
        auto ret = linkage(app.data);
        distCache[&rhs] = ret;
        return ret;
    }

    /// Iterate over the leaf nodes, visiting the left subtree before the right.
    int opApply(int delegate(ref Cluster) dg) {
        int res;

        if(isLeaf) {
            res = dg(this);
            return res;
        }

        assert(left);
        assert(right);

        res = left.opApply(dg);
        if(res) return res;

        res = right.opApply(dg);
        return res;
    }

    /// The number of leaf nodes in this cluster.
    int nLeafNodes() const pure nothrow @property {
        if(isLeaf) return 1;
        return left.nLeafNodes + right.nLeafNodes;
    }
}
154 
/**
Selects which axis of the data matrix hierarchicalCluster treats as the
samples to be clustered.
*/
enum ClusterBy {
    /// Each row of the matrix is one sample.
    rows = 0,

    /// Each column of the matrix is one sample.
    columns = 1
}
165 
/**
Perform hierarchical clustering.  matrix must be rectangular and represents
the data matrix.  distance is the distance metric, which must be a function
that accepts two equal-length input ranges of doubles.  linkage is the linkage
function, which must accept a double[] that represents all possible pairwise
distances between two clusters and return a summary of these distances.

clusterBy indicates whether the rows or the columns of the matrix should
be clustered.

names is an optional string array of names, one per sample.  If provided,
this information will be placed in the Cluster objects, allowing samples
to be tracked by name.

Returns:  The root of the cluster tree.  Each leaf carries the index of its
sample (row or column) and, if names was provided, its name.  Each internal
node's distance is the linkage distance at which its children were merged.

Throws:  Exception (via enforce) if there are no samples to cluster, if
matrix is not rectangular, or if names is non-empty and its length differs
from the number of samples.
*/
Cluster* hierarchicalCluster(alias distance = euclidean, alias linkage = mean)(
    double[][] matrix, ClusterBy clusterBy, string[] names = null
) {
    enforce(matrix.length > 0, "Cannot cluster zero elements.");
    foreach(i; 1..matrix.length) {
        enforce(matrix[i].length == matrix[0].length,
            "matrix must be rectangular for hierarchicalCluster.");
    }

    Cluster*[] clusters;

    // Make distance matrix.  Only the strict lower triangle is stored:
    // distances[i][j], for j < i, is the distance between samples i and j.
    double[][] distances;
    if(clusterBy == ClusterBy.rows) {
        clusters = new Cluster*[matrix.length];
        distances = new double[][matrix.length];

        enforce(names.length == 0 || names.length == matrix.length,
            "names.length must be equal to matrix.length for " ~
            "hierarchical clustering by row.");

        foreach(i; 0..clusters.length) {
            distances[i] = new double[i];

            foreach(j; 0..i) {
                distances[i][j] = distance(matrix[i], matrix[j]);
            }
        }
    } else {
        assert(clusterBy == ClusterBy.columns);

        // With zero columns there is nothing to cluster; without this check
        // clusters[0] at the end would be out of bounds.
        enforce(matrix[0].length > 0, "Cannot cluster zero elements.");

        clusters = new Cluster*[matrix[0].length];
        distances = new double[][matrix[0].length];

        enforce(names.length == 0 || names.length == matrix[0].length,
            "names.length must be equal to matrix[0].length for " ~
            "hierarchical clustering by column.");

        foreach(i; 0..clusters.length) {
            distances[i] = new double[i];

            foreach(j; 0..i) {
                distances[i][j] = distance(
                    transversal(matrix, i),
                    transversal(matrix, j)
                );
            }
        }
    }

    foreach(i; 0..clusters.length) {
        clusters[i] = new Cluster(null, null, double.nan, i);
        if(names.length) clusters[i].name = names[i];
    }

    while(clusters.length > 1) {
        // Find min dist pair.  Default to the pair (0, 1) so that two distinct
        // clusters are always merged, even when no distance compares less
        // than infinity (e.g. every distance is NaN or infinite).  Otherwise
        // both indices would stay 0 and cluster 0 would be merged with itself
        // and then dropped by the remove() below.
        size_t minPair1 = 0, minPair2 = 1;
        double minDist = double.infinity;
        bool foundSmaller = false;

        foreach(i; 0..clusters.length) foreach(j; i + 1..clusters.length) {
            immutable dist =
                clusters[i].calculateDistance!(linkage)(*clusters[j], distances);

            if(dist < minDist) {
                minPair1 = i;
                minPair2 = j;
                minDist = dist;
                foundSmaller = true;
            }
        }

        if(!foundSmaller) {
            // Record the real linkage distance of the default pair; it was
            // already computed and cached in the search above.
            minDist = clusters[0].calculateDistance!(linkage)(
                *clusters[1], distances);
        }

        // Clean up excess distCache stuff, let it get GC'd.
        clusters[minPair1].distCache = null;
        clusters[minPair2].distCache = null;

        // Other clusters must forget distances to the two clusters being
        // merged, since those clusters no longer exist at the top level.
        foreach(cluster; clusters) {
            if(clusters[minPair1] in cluster.distCache) {
                cluster.distCache.remove(clusters[minPair1]);
            }

            if(clusters[minPair2] in cluster.distCache) {
                cluster.distCache.remove(clusters[minPair2]);
            }
        }

        clusters[minPair1] = new Cluster(
            clusters[minPair1], clusters[minPair2], minDist, size_t.max);

        // minPair2 > minPair1, so removing it leaves the new node in place.
        clusters = clusters.remove(minPair2);
    }

    distances[] = null;  // Make sure it gets gc'd.
    clusters[0].distCache = null;
    return clusters[0];
}