// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_
#define DLIB_BOTTOM_uP_CLUSTER_Hh_

#include <queue>
#include <map>

#include "bottom_up_cluster_abstract.h"
#include "../algs.h"
#include "../matrix.h"
#include "../disjoint_subsets.h"
#include "../graph_utils.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace buc_impl
    {
        inline void merge_sets (
            matrix<double>& dists,
            unsigned long dest,
            unsigned long src
        )
        {
            for (long r = 0; r < dists.nr(); ++r)
                dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src));
        }

        struct compare_dist
        {
            bool operator() (
                const sample_pair& a,
                const sample_pair& b
            ) const
            {
                return a.distance() > b.distance();
            }
        };
    }

// ----------------------------------------------------------------------------------------

    template <
        typename EXP
        >
    unsigned long bottom_up_cluster (
        const matrix_exp<EXP>& dists_,
        std::vector<unsigned long>& labels,
        unsigned long min_num_clusters,
        double max_dist = std::numeric_limits<double>::infinity()
    )
    {
        matrix<double> dists = matrix_cast<double>(dists_);
        // make sure requires clause is not broken
        DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, 
            "\t unsigned long bottom_up_cluster()"
            << "\n\t Invalid inputs were given to this function."
            << "\n\t dists.nr(): " << dists.nr() 
            << "\n\t dists.nc(): " << dists.nc() 
            << "\n\t min_num_clusters: " << min_num_clusters 
            );

        using namespace buc_impl;

        labels.resize(dists.nr());
        disjoint_subsets sets;
        sets.set_size(dists.nr());
        if (labels.size() == 0)
            return 0;

        // push all the edges in the graph into a priority queue so the best edges to merge
        // come first.
        std::priority_queue<sample_pair, std::vector<sample_pair>, compare_dist> que;
        for (long r = 0; r < dists.nr(); ++r)
            for (long c = r+1; c < dists.nc(); ++c)
                que.push(sample_pair(r,c,dists(r,c)));

        // Now start merging nodes.
        for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter)
        {
            // find the next best thing to merge.
            double best_dist = que.top().distance();
            unsigned long a = sets.find_set(que.top().index1());
            unsigned long b = sets.find_set(que.top().index2());
            que.pop();
            // we have been merging and modifying the distances, so make sure this distance
            // is still valid and these guys haven't been merged already.
            while(a == b || best_dist < dists(a,b))
            {
                // Haven't merged it yet, so put it back in with updated distance for
                // reconsideration later.
                if (a != b)
                    que.push(sample_pair(a, b, dists(a, b)));

                best_dist = que.top().distance();
                a = sets.find_set(que.top().index1());
                b = sets.find_set(que.top().index2());
                que.pop();
            }


            // now merge these sets if the best distance is small enough
            if (best_dist > max_dist)
                break;
            unsigned long news = sets.merge_sets(a,b);
            unsigned long olds = (news==a)?b:a;
            merge_sets(dists, news, olds);
        }

        // figure out which cluster each element is in.  Also make sure the labels are
        // contiguous.
        std::map<unsigned long, unsigned long> relabel;
        for (unsigned long r = 0; r < labels.size(); ++r)
        {
            unsigned long l = sets.find_set(r);
            // relabel to make contiguous
            if (relabel.count(l) == 0)
            {
                unsigned long next = relabel.size();
                relabel[l] = next;
            }
            labels[r] = relabel[l];
        }


        return relabel.size();
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_