#include <mpi.h>

/****************************
Assignment part 1:

star_Reduce as written here works correctly only when type == MPI_DOUBLE
and binop == MPI_SUM.  Rewrite it so that it works correctly
when type is MPI_DOUBLE or MPI_INT and when binop is MPI_SUM, MPI_PROD, or
MPI_MAX.
****************************/
int star_Reduce(void* local, void* result, int size, MPI_Datatype type, MPI_Op binop, int root, MPI_Comm comm)
{
	int r,p; 
	MPI_Status status;
	MPI_Comm_size(comm, &p);
	MPI_Comm_rank(comm, &r);
	if (r != root)
	    MPI_Send(local, size, type, root, 0, comm);
	else
	{
	    double* other_local_d = new double[size];
	    double* result_d = static_cast<double*>(result);
	    double* local_d = static_cast<double*>(local);
	    for (int i = 0; i < size; ++i) result_d[i] = local_d[i];
	    for (int rr = 0; rr < p; ++rr) if (rr != root)
	    {  
		MPI_Recv(other_local_d, size, type, rr/*MPI_ANY_SOURCE*/, 0, comm, &status);
	        for (int i = 0; i < size; ++i) result_d[i] += other_local_d[i];
	    }
	}
}

/****************************
Assignment part 2:

tree_Reduce as written here works correctly only when type == MPI_DOUBLE,
binop == MPI_SUM, and root == 0.  Rewrite it so that it works correctly
when type is MPI_DOUBLE or MPI_INT and when binop is MPI_SUM, MPI_PROD, or
MPI_MAX, and when root is any rank in [0..p).
****************************/
int tree_Reduce(void* local, void* result, int size, MPI_Datatype type, MPI_Op binop, int root, MPI_Comm comm)
{
	int r,p; 
	MPI_Status status;
	MPI_Comm_size(comm, &p);
	MPI_Comm_rank(comm, &r);
	double* local_result = new double[size];
	double* local_d = static_cast<double*>(local);
	for(int i = 0; i < size; ++i) local_result[i] = local_d[i];

	double* other_result = new double[size];

        // k = highest power of 2 which is less than p.
        int k = 1;  for (; k < p; k *= 2); k /= 2;

	for (; k > 0; k /= 2)
	{   if (r < k && r+k < p)
            {
	        MPI_Recv(other_result, size, type, r+k, 0, comm, &status);
	        for(int i = 0; i < size; ++i) local_result[i] += other_result[i];
            }
	    else if (k <= r && r < 2*k)
	        MPI_Send(local_result, size, type, r-k, 0, comm);
	    else 
		;// fugeddaboudit
	}
	if (r == 0) 
	{
	    double* result_d = static_cast<double*>(result);
	    for(int i = 0; i < size; ++i) result_d[i] = local_result[i];
	}
}
