// Closest pair of points.  -bds 10/06
#include <math.h>

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>
using namespace std;

// Point: a 2-D location held as (x, y) -- .first is the x coordinate and
// .second is the y coordinate.
typedef std::pair<double, double> Point;

// utility functions //

// compare points by x-coord or by y-coord, output points. 
// True when p lies strictly to the left of q; the y coordinate is ignored.
bool xLess(const Point& p, const Point& q)
{
  return q.first > p.first;
}
// True when p lies strictly below q; the x coordinate is ignored.
bool yLess(const Point& p, const Point& q)
{
  return q.second > p.second;
}
// Writes p as "(x, y)" and hands back the same stream so output can be chained.
ostream& operator<< (ostream& out, const Point& p)
{
  return out << "(" << p.first << ", " << p.second << ")";
}

// squaring, distance between points, min in linearly ordered set.
// Returns x squared.
double sqr(double x)
{
  const double product = x * x;
  return product;
}
// Euclidean distance between points a and b.
double dist(const Point& a, const Point& b)
{
  const double dx = a.first - b.first;
  const double dy = a.second - b.second;
  return sqrt(sqr(dx) + sqr(dy));
}
// Returns a reference to the smaller of a and b. When a < b is false (ties,
// or unordered values) b is returned. Both arguments must be modifiable
// lvalues because the parameters are non-const references.
template<typename key>
key& min(key& a, key& b)
{
  if (a < b) return a;
  return b;
}


/* d <- closestPair(begin, end, pointPair)

precondition: begin and end are random access iterators over a rearrangeable
sequence of Points.  The sequence's first element is *begin, and end references
the first position beyond the sequence (hence it is invalid to dereference end).

postcondition:  The shortest distance between any two points in the sequence
referenced by [begin..end) is returned and pointPair is set to a pair of points
which have that distance.  If fewer than 2 points are in the sequence, positive
infinity is returned and pointPair is left unchanged (its value is meaningless).

The algorithm can and does permute the order of the points in the sequence.

The runtime is O(n lg(n)^2), because every level of the recursion sorts its
whole range.
*/

template<class iterator>
double closestPair(iterator begin, iterator end, pair<Point, Point>& pointPair)
{
  int n = end - begin; // length of the sequence.

  // Step 0. base cases.  //

  // A true infinity is required here: a finite sentinel such as 1e9 can be
  // smaller than a real distance (e.g. rand() coordinates on platforms with a
  // large RAND_MAX), and the caller would then keep the sentinel together
  // with a meaningless point pair.
  if (n < 2) return std::numeric_limits<double>::infinity();
  if (n == 2)
  {
    pointPair.first = begin[0];
    pointPair.second = begin[1];
    return dist(begin[0], begin[1]);
  }

  // Step 1. // Sort by x coordinate
  sort(begin, end, xLess);

  // Step 2, 3. // recursive calls on left half, right half.
  pair<Point, Point> pL, pR;
  double dL = closestPair(begin, begin + n/2, pL);
  double dR = closestPair(begin + n/2, end, pR);

  // Step 4. // d is min dist within either side, pp the corresponding points.
  // The pair is chosen together with its distance (ties favour the left
  // side), so no == comparison of doubles is needed.
  double d = dL;
  pair<Point, Point> pp = pL;
  if (dR < dL) { d = dR; pp = pR; }

  /* Now d is the minimum distance between two points which are either both on
  the left side or both on the right side.  It remains to see if there is a
  closer pair, one of which is on the left and one on the right.  To check for
  this, we first reduce attention to points within distance d of the dividing
  line, then use the observation that if any point in this strip has a neighbor
  closer than d in the strip (and not downwards from the point), then that
  neighbor is within the next seven points going upwards (i.e. as sorted by
  y-coordinate).
  */

  // Step 5. Trim to vertical strip of width 2d around the median x coordinate.
  // The recursive calls re-sorted parts of each half by y, so sort by x again.
  sort(begin, end, xLess);

  const double midX = begin[n/2].first;
  const double minX = midX - d;
  const double maxX = midX + d;

  // Grow the strip leftwards from the median, never stepping before begin
  // (decrementing an iterator past the start of its range is undefined).
  iterator newBegin = begin + n/2;
  while (newBegin != begin && (newBegin - 1)->first >= minX) { --newBegin; }

  // Grow the strip rightwards; newEnd stops on the first point beyond maxX.
  iterator newEnd = begin + n/2;
  while (newEnd != end && newEnd->first <= maxX) { ++newEnd; }

  // Step 6. Sort by y coordinate.
  sort(newBegin, newEnd, yLess);

  // Step 7. Check distances of each point with 7 next points up the strip.
  for (iterator i = newBegin; i < newEnd; ++i)
  {
    // End of the "next 7" window, clamped without ever forming i + 8 beyond
    // newEnd (iterators past the end of a container are undefined).
    iterator windowEnd = (newEnd - i > 8) ? i + 8 : newEnd;
    for (iterator j = i + 1; j < windowEnd; ++j) // j takes at most 7 values.
    {
      double dij = dist(*i, *j);
      if (dij < d) { d = dij; pp.first = *i; pp.second = *j; }
    }
  }
  pointPair = pp;
  return d;
}
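/* Runtime analysis.  Let T(n) denote the run time.  Examining step 0,
we see that T(n) < c, for some constant c, when n <= 2.

In steps 2,3 we have recursive calls with sequences of half the length 
and each other step is using O(n lg(n)) or less time.

Thus we have T(n) <= 2T(n/2) + c' n lg(n), for some constant c'.
This has a solution T(n) in O(n lg(n)^2).  It can be improved to O(n lg(n))
by sorting by x only once up front and having each recursive call return its
range sorted by y (merging the two y-sorted halves in linear time), so that no
level of the recursion needs an O(n lg(n)) sort.
*/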

// test program //
int main(int ac, char* av[] )
{
	int n = 16;
	if (ac != 2) 
		cerr << "usage: " << av[0] << "n, where n is the number of random points to be in the set." << endl;
	else
		n = atoi(av[1]);
		
	// making n = k^2 points.
	int k = sqrt(n);
	vector<Point> V;
	for (double i = 0; i < k; ++i)
 	  for (double j = 0; j < k; ++j)
	    V.push_back(Point(rand(),rand()));

	// find the closest pair.
	pair<Point, Point> pp;
	double d = closestPair(V.begin(), V.end(), pp); 
	cout << "min dist " << d << " is achieved between points "; 
	cout << pp.first << " and  " << pp.second << " of the " << k*k << " points." << endl;
}

