#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

// Round n up to the next multiple of sizeof(double). dmalloc uses this to pad
// its row-pointer table so that the doubles stored after it start on a
// double-sized boundary. Division/multiplication is used instead of
// bit-masking, so sizeof(double) need not be a power of two. n is evaluated
// exactly once.
#define ALIGN_UP(n) \
  ((((n) + sizeof(double) - 1) / sizeof(double)) * sizeof(double))

// Number of timed passes over the matrix in test_a and test_b. Kept as a
// preprocessor macro because those functions test it with `#if LOOP != 1`.
#define LOOP 50

/*
 * Allocate an x-by-y matrix of doubles as ONE heap block that can be indexed
 * as result[i][j] and released with a single free(result).
 *
 * Layout of the block:
 *   [x row pointers][padding to a multiple of sizeof(double)][x*y doubles]
 * The doubles are stored row-major; result[i] points at the first element of
 * row i, so consecutive rows are y doubles apart.
 *
 * Prints the dimensions and the computed byte sizes to stdout. Aborts
 * (whether or not NDEBUG is defined) if a dimension is negative, the byte
 * count does not fit in size_t, or malloc fails; it never returns NULL.
 */
double** dmalloc(int x, int y)
{
  size_t rows;
  size_t cols;
  size_t prefix;
  size_t cells;
  size_t size;
  double** result;
  double* row;
  int i;

  if (x < 0 || y < 0) {
    fprintf(stderr, "dmalloc: negative dimension (%d x %d)\n", x, y);
    abort();
  }
  rows = (size_t)x;
  cols = (size_t)y;

  /* Do all size arithmetic in size_t and reject wrap-around up front: an int
     product x * y would overflow (undefined behavior) long before size_t. */
  if (rows > (SIZE_MAX - (sizeof(double) - 1)) / sizeof(double*)
      || (cols != 0 && rows > SIZE_MAX / cols / sizeof(double))) {
    fprintf(stderr, "dmalloc: %d x %d matrix is too large\n", x, y);
    abort();
  }
  /* Pad the pointer table so the doubles that follow it are aligned. */
  prefix = ALIGN_UP(rows * sizeof(double*));
  cells = rows * cols * sizeof(double);
  if (cells > SIZE_MAX - prefix) {
    fprintf(stderr, "dmalloc: %d x %d matrix is too large\n", x, y);
    abort();
  }
  size = prefix + cells;
  printf("X=%d Y=%d prefix_size=%zu total_size=%zu\n", x, y, prefix, size);

  /* malloc(0) may legitimately return NULL; ask for one byte when x == 0 so
     that NULL always means "out of memory". */
  result = malloc(size != 0 ? size : 1);
  if (result == NULL) {
    fprintf(stderr, "dmalloc: malloc(%zu) failed\n", size);
    abort();
  }

  row = (double*)((char*)result + prefix);
  for (i = 0; i < x; i++) {
    result[i] = row;
    row += y;
  }
  return result;
}
    
/* Read element (x, y) of a matrix stored as a table of row pointers. */
double access_a(double** a, int x, int y)
{
  double* const row = a[x];
  return row[y];
}
/* Read element (x, y) of a row-major matrix with Y columns held in one
   contiguous array. */
double access_b(double* b, int x, int Y, int y)
{
  const int offset = x * Y + y;
  return b[offset];
}

/*
 * Return end - start in microseconds (negative if end precedes start).
 * Computed in long long so that intervals longer than ~35 minutes, which
 * would overflow a 32-bit int count of microseconds, stay exact.
 */
static long long timeval_diff_usec(const struct timeval* start,
  const struct timeval* end)
{
  return ((long long)end->tv_sec - (long long)start->tv_sec) * 1000000LL
    + ((long long)end->tv_usec - (long long)start->tv_usec);
}

/*
 * Print the time elapsed between start and end to stdout as
 * "<prefix>: <seconds>.<microseconds> seconds". The prefix is left-justified
 * in an 8-character field, and the microseconds are zero-padded to six
 * digits so that 1.05 s prints as "1.050000" rather than the ambiguous
 * "1:50000".
 */
void print_delta(const char* prefix, const struct timeval* start,
  const struct timeval* end)
{
  long long delta = timeval_diff_usec(start, end);
  const char* sign = "";

  /* Emit the sign once, so a negative interval is not printed as "0.-5". */
  if (delta < 0) {
    sign = "-";
    delta = -delta;
  }
  printf("%-8s: %s%lld.%06lld seconds\n",
    prefix, sign, delta / 1000000, delta % 1000000);
}

/*
 * Benchmark the row-pointer layout: run LOOP passes that multiply every
 * element a[row][col] of an X-by-Y matrix by 1.5, and report the elapsed
 * wall-clock time through print_delta under the label "double**".
 */
void test_a(double** a, int X, int Y)
{
  struct timeval t0;
  struct timeval t1;

#if LOOP != 1
  /* Untimed warm-up: write every element once, so that the timed passes run
     over memory that has already been touched. */
  for (int row = 0; row < X; ++row) {
    for (int col = 0; col < Y; ++col) {
      a[row][col] = 3.14;
    }
  }
#endif

  gettimeofday(&t0, NULL);
  for (int pass = 0; pass < LOOP; ++pass) {
    for (int row = 0; row < X; ++row) {
      for (int col = 0; col < Y; ++col) {
        a[row][col] *= 1.5;
      }
    }
  }
  gettimeofday(&t1, NULL);

  print_delta("double**", &t0, &t1);
}

/*
 * Benchmark the flat row-major layout: element (row, col) of the X-by-Y
 * matrix lives at b[row * Y + col]. Run LOOP passes that multiply every
 * element by 1.5, and report the elapsed wall-clock time through print_delta
 * under the label "double*".
 */
void test_b(double* b, int X, int Y)
{
  struct timeval t0;
  struct timeval t1;

#if LOOP != 1
  /* Untimed warm-up: write every element once, so that the timed passes run
     over memory that has already been touched. */
  for (int row = 0; row < X; ++row) {
    for (int col = 0; col < Y; ++col) {
      b[row * Y + col] = 3.14;
    }
  }
#endif

  gettimeofday(&t0, NULL);
  for (int pass = 0; pass < LOOP; ++pass) {
    for (int row = 0; row < X; ++row) {
      for (int col = 0; col < Y; ++col) {
        b[row * Y + col] *= 1.5;
      }
    }
  }
  gettimeofday(&t1, NULL);

  print_delta("double*", &t0, &t1);
}

/*
 * Benchmark driver. It allocates the same X-by-Y matrix twice: once as a
 * row-pointer table (dmalloc) and once as a flat row-major array. It times
 * the same update loop on each, prints one sample element from each, and
 * releases both allocations. Returns EXIT_FAILURE if the flat array cannot
 * be allocated.
 */
int main(int argc, char **argv)
{
  enum { X = 1234, Y = 5678 };
  double** a;
  double* b;

  (void)argc;  /* command-line arguments are not used */
  (void)argv;

  a = dmalloc(X, Y);
  /* Compute the byte count in size_t rather than int. */
  b = malloc((size_t)X * Y * sizeof(*b));
  if (b == NULL) {
    fprintf(stderr, "malloc of %d x %d doubles failed\n", X, Y);
    free(a);
    return EXIT_FAILURE;
  }

  test_a(a, X, Y);
  test_b(b, X, Y);
  printf("a[27][42]     =%.3f\n", access_a(a, 27, 42));
  printf("b[27 * Y + 42]=%.3f\n", access_b(b, 27, Y, 42));

  free(a);  /* dmalloc returns a single malloc'd block, so one free suffices */
  free(b);
  return 0;
}
