/*
 * Copyright (c) 2001-2002 The Trustees of Indiana University.  
 *                         All rights reserved.
 * Copyright (c) 1998-2001 University of Notre Dame. 
 *                         All rights reserved.
 * Copyright (c) 1994-1998 The Ohio State University.  
 *                         All rights reserved.
 * 
 * This file is part of the LAM/MPI software package.  For license
 * information, see the LICENSE file in the top level directory of the
 * LAM/MPI source distribution.
 * 
 * $HEADER$
 *
 * $Id: spawn.c,v 1.12 2002/10/09 20:55:48 brbarret Exp $
 *
 * Program to test MPI_Comm_spawn with simple arguments.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "mpi.h"
#include "lamtest_error.h"


/* Parent side: runs when this executable was started by mpirun (no
   parent communicator); spawns copies of itself and talks to them. */
static void do_parent(char *cmd, int rank, int count);
/* Child side: runs when this executable was spawned; validates the
   argv it received and talks back to the parent group. */
static void do_target(char *argv0, char *argv1, char *argv2,
		      MPI_Comm parent_inter);
/* Pairwise exchange of ranks between every pair of processes in an
   intracommunicator. */
static void all_to_all(MPI_Comm intra);
/* Merges an intercommunicator once more, frees the merged result, and
   optionally frees the intercommunicator itself. */
static void free_inter(MPI_Comm inter, int do_free);


/* Arguments handed to the spawned children via MPI_Comm_spawn in
   do_parent(); do_target() compares its argv[1]/argv[2] against these.
   Declared non-const because they are stored into a char *[] argv. */
static char *cmd_argv1 = "this is argv 1";
static char *cmd_argv2 = "this is argv 2";
/* Message tag used for every send/receive in all_to_all(). */
static int tag = 201;


int 
main(int argc, char *argv[])
{
  MPI_Comm parent;
  int size;
  int rank;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  /* The same executable plays both roles: it is launched by mpirun as
     the parent and later spawned as the child.  A non-null parent
     communicator tells us which role we are playing this time. */

  parent = MPI_COMM_NULL;
  MPI_Comm_get_parent(&parent);

  if (parent == MPI_COMM_NULL) {
    do_parent(argv[0], rank, size);
  } else {
    /* Hand the child role whatever extra arguments we received; a
       missing one is passed as NULL. */
    char *first = (argc > 1) ? argv[1] : NULL;
    char *second = (argc > 2) ? argv[2] : NULL;

    do_target(argv[0], first, second, parent);
  }

  /* Both roles finish the same way */

  MPI_Finalize();
  return 0;
}


/*
 * Parent role: verify that every rank in MPI_COMM_WORLD can open `cmd`,
 * then collectively spawn `count` copies of it (passing cmd_argv1 and
 * cmd_argv2 as their arguments), check every per-process spawn error
 * code, and exchange ranks with the children over a merged
 * intracommunicator.
 *
 *   cmd   - path of the executable to spawn (argv[0] of this program)
 *   rank  - this process's rank in MPI_COMM_WORLD (used for messages)
 *   count - overwritten below with the number of ranks that found `cmd`;
 *           once that equals the world size it is also the number of
 *           children spawned.
 *
 * If some rank cannot open `cmd`, rank 0 prints a warning and every rank
 * returns without spawning (the test is skipped, not failed).
 */
static void
do_parent(char *cmd, int rank, int count)
{
  int *errcode;
  int err;
  int i;
  char *spawn_argv[3];
  MPI_Comm child_inter;
  MPI_Comm intra;
  FILE *fp;
  int found;
  int size;

  /* Ensure we have 3 processes */

  lamtest_check_size(__FILE__, __LINE__, 3, 1);

  spawn_argv[0] = cmd_argv1;
  spawn_argv[1] = cmd_argv2;
  spawn_argv[2] = NULL;

  /* First, see if cmd exists on all ranks */

  fp = fopen(cmd, "r");
  if (fp == NULL)
    found = 0;
  else {
    fclose(fp);
    found = 1;
  }
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  /* After this, count = number of ranks that could open cmd. */
  MPI_Allreduce(&found, &count, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
  if (count != size) {
    if (rank == 0)
      lamtest_warning(__FILE__, __LINE__, 
		      "Not all ranks were able to find:\n\t\"%s\"\n"
		      "You probably don't have a uniform filesystem...?\n"
		      "So I'll skip this test, but not call it a failure.\n",
		      cmd);
    return;
  }
  
  /* Every rank found cmd (count == size): spawn that many children */

  errcode = malloc(sizeof *errcode * count);
  if (errcode == NULL)
    lamtest_error(__FILE__, __LINE__, "Doh!  Rank %d was not able to allocate enough memory.  MPI test aborted!\n", rank);

  /* Pre-fill every slot with a non-MPI_SUCCESS sentinel so that a slot
     the spawn does not write is reported rather than read while
     uninitialized.  (The original memset covered only `count` bytes,
     not `count` ints.) */
  for (i = 0; i < count; i++)
    errcode[i] = -1;

  /* Let MPI_Comm_spawn report failures through its return value instead
     of aborting, so the checks below can describe them. */
  MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
  err = MPI_Comm_spawn(cmd, spawn_argv, count, MPI_INFO_NULL, 0,
		       MPI_COMM_WORLD, &child_inter, errcode); 

  for (i = 0; i < count; i++)
    if (errcode[i] != MPI_SUCCESS)
      lamtest_error(__FILE__, __LINE__, 
		    "ERROR: MPI_Comm_spawn returned errcode[%d] = %d\n", 
		    i, errcode[i]);
  if (err != MPI_SUCCESS)
      lamtest_error(__FILE__, __LINE__, 
		    "ERROR: MPI_Comm_spawn returned errcode = %d\n", err);

  /* Now do a simple ping pong to everyone in the child */

  MPI_Intercomm_merge(child_inter, 0, &intra);
  all_to_all(intra);
  MPI_Comm_free(&intra);

  /* Clean up */

  free_inter(child_inter, 1);
  free(errcode);
}


/*
 * Child role: check that the spawned process received cmd_argv1 and
 * cmd_argv2 as its argv[1] and argv[2], then merge the parent
 * intercommunicator and run the pairwise exchange with the parent group.
 *
 *   argv0  - this program's argv[0] (unused)
 *   argv1  - argv[1], or NULL if the process received no such argument
 *   argv2  - argv[2], or NULL if the process received no such argument
 *   parent - intercommunicator returned by MPI_Comm_get_parent; it is
 *            not freed here (free_inter is called with do_free == 0).
 */
static void
do_target(char *argv0, char *argv1, char *argv2, MPI_Comm parent)
{
  int rank;
  MPI_Comm intra;

  (void) argv0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  /* Check that we got the argv that we expected.  main() passes NULL
     for a missing argument; strcmp() or "%s" on NULL is undefined
     behavior, so a missing argument is reported separately. */

  if (argv1 == NULL)
    lamtest_error(__FILE__, __LINE__, "ERROR: Spawn target rank %d got no argv[1] when expecting \"%s\"\n",
		  rank, cmd_argv1);
  else if (strcmp(argv1, cmd_argv1) != 0)
    lamtest_error(__FILE__, __LINE__, "ERROR: Spawn target rank %d got argv[1]=\"%s\" when expecing \"%s\"\n",
		  rank, argv1, cmd_argv1);

  if (argv2 == NULL)
    lamtest_error(__FILE__, __LINE__, "ERROR: Spawn target rank %d got no argv[2] when expecting \"%s\"\n",
		  rank, cmd_argv2);
  else if (strcmp(argv2, cmd_argv2) != 0)
    lamtest_error(__FILE__, __LINE__, "ERROR: Spawn target rank %d got argv[2]=\"%s\" when expecing \"%s\"\n",
		  rank, argv2, cmd_argv2);

  /* Now merge it down to an intra and do a simple all-to-all to
     everyone in the parent */

  MPI_Intercomm_merge(parent, 0, &intra);
  all_to_all(intra);
  MPI_Comm_free(&intra);

  free_inter(parent, 0);
}


/*
 * Exchange ranks with every other process in `intra`, one peer at a
 * time.  For each pair, the lower-ranked process receives first and the
 * higher-ranked process sends first, so the blocking calls always pair
 * up.  Each process sends its own rank, and the value received from a
 * peer must equal that peer's rank; otherwise lamtest_error is called.
 */
static void
all_to_all(MPI_Comm intra)
{
  int me, nprocs, peer;
  int received;

  MPI_Comm_size(intra, &nprocs);
  MPI_Comm_rank(intra, &me);

  for (peer = 0; peer < nprocs; peer++) {
    if (peer == me)
      continue;

    received = -1;
    if (peer > me) {
      /* The higher-ranked peer sends first, so listen first. */
      MPI_Recv(&received, 1, MPI_INT, peer, tag, intra, MPI_STATUS_IGNORE);
      MPI_Send(&me, 1, MPI_INT, peer, tag, intra);
    } else {
      /* The lower-ranked peer is listening, so talk first. */
      MPI_Send(&me, 1, MPI_INT, peer, tag, intra);
      MPI_Recv(&received, 1, MPI_INT, peer, tag, intra, MPI_STATUS_IGNORE);
    }

    if (received != peer)
      lamtest_error(__FILE__, __LINE__, "ERROR: rank %d got message %d from comm rank %d; expected %d\n", me, received, peer, peer);
  }
}


/*
 * Merge `inter` into a temporary intracommunicator and free it right
 * away, then free `inter` itself when `do_free` is nonzero.  Both
 * MPI_Intercomm_merge and MPI_Comm_free are collective, so the parent
 * and child sides must both call this.
 *
 *   inter   - intercommunicator to merge (and possibly free)
 *   do_free - nonzero to free `inter`; only the local copy of the handle
 *             is reset, the caller's variable is left untouched.
 *
 * The original queried MPI_Comm_size(inter) twice into a local that was
 * never read; those dead calls have been removed.
 */
static void
free_inter(MPI_Comm inter, int do_free)
{
  MPI_Comm intra;

  MPI_Intercomm_merge(inter, 0, &intra);
  MPI_Comm_free(&intra);

  if (do_free)
    MPI_Comm_free(&inter);
}  
