/*
 * Copyright (c) 2001-2003 The Trustees of Indiana University.
 *                         All rights reserved.
 * Copyright (c) 1998-2001 University of Notre Dame.
 *                         All rights reserved.
 * Copyright (c) 1994-1998 The Ohio State University.
 *                         All rights reserved.
 *
 * This file is part of the LAM/MPI software package.  For license
 * information, see the LICENSE file in the top level directory of the
 * LAM/MPI source distribution.
 *
 * $HEADER$
 *
 * $Id: comm_join.c,v 1.6 2003/06/27 01:51:06 jsquyres Exp $
 *
 * Program to test MPI_Comm_join 
 */

#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <netinet/in.h>
#include <netdb.h>
#include <arpa/inet.h>

#include <mpi.h>
#include "lamtest_error.h"

#define SERVER_PORT 3100


/*
 * Test MPI_Comm_join between two MPI processes.
 *
 * Rank 0 plays the TCP server: it listens on SERVER_PORT, sends its
 * host name to rank 1 over MPI_COMM_WORLD, accepts one connection and
 * calls MPI_Comm_join on the accepted socket.  Rank 1 plays the client:
 * it receives the host name, connects to SERVER_PORT on that host and
 * calls MPI_Comm_join on its connected socket.  Both sides then
 * exchange one integer over the resulting intercommunicator and compare
 * it against the value they sent.
 *
 * Failures are reported through lamtest_error() (declared in
 * lamtest_error.h).  NOTE(review): the error paths release their
 * resources first and then fall through, so this code presumably relies
 * on lamtest_error() not returning -- confirm against its definition.
 *
 * Returns 0 after MPI_Finalize.
 */
int
main(int argc, char** argv)
{
  struct sockaddr_in client_addr, server_addr;
  struct hostent* h;
  int rank;
  int server_socket, client_socket, newsd;
  socklen_t client_length;  /* accept() takes a socklen_t *, not an int * */
  MPI_Comm intercomm;
  char *host_name;
  int host_len, on;
  int send_data = 17, recv_data = -1;

  MPI_Init(&argc, &argv);

  /* NOTE(review): presumably verifies the job has exactly 2 processes;
     confirm against the lamtest helper's definition. */
  lamtest_check_size(__FILE__, __LINE__, 2, 1);

  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  if (rank == 0) {
    /* server code */
    server_socket = socket(AF_INET, SOCK_STREAM, 0);
    if (server_socket < 0)
      lamtest_error(__FILE__, __LINE__,
                    "Error: server failed to create a socket\n");

    /* reuse the port address; this is best effort, so a failure here is
       deliberately ignored (bind() below reports any real problem) */
    on = 1;
    (void) setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR, &on,
                      sizeof(on));

    /* bind to a server port; zero the structure first so that padding
       and platform-specific fields do not carry stack garbage */
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(SERVER_PORT);

    if (bind(server_socket, (struct sockaddr*) &server_addr,
             sizeof(server_addr)) < 0) {
      close(server_socket);
      lamtest_error(__FILE__, __LINE__,
                    "Error: server failed to bind a socket\n");
    }

    /* start listening BEFORE the host name is sent, so that the client
       cannot try to connect before the socket is ready */
    if (listen(server_socket, 3) < 0) {
      close(server_socket);
      lamtest_error(__FILE__, __LINE__,
                    "Error: server failed to listen on a socket\n");
    }

    host_len = 255;
    host_name = (char *) malloc(host_len * sizeof(char));
    if (host_name == NULL) {
      close(server_socket);
      lamtest_error(__FILE__, __LINE__,
                    "Error: server failed to allocate host name buffer\n");
    }
    if (gethostname(host_name, host_len - 1) != 0) {
      close(server_socket);
      free(host_name);
      lamtest_error(__FILE__, __LINE__,
                    "Error: server failed to get its host name\n");
    }
    /* gethostname() need not NUL-terminate a truncated name */
    host_name[host_len - 1] = '\0';
    /* length sent to the client includes the terminating NUL */
    host_len = (int) strlen(host_name) + 1;

    /* send length of host name to rank 1 */
    MPI_Send(&host_len, 1, MPI_INT, 1, 100, MPI_COMM_WORLD);

    /* send host name to rank 1 */
    MPI_Send(host_name, host_len, MPI_CHAR, 1, 100, MPI_COMM_WORLD);

    client_length = sizeof(client_addr);
    newsd = accept(server_socket, (struct sockaddr*) &client_addr,
                   &client_length);
    if (newsd < 0) {
      close(server_socket);
      free(host_name);
      lamtest_error(__FILE__, __LINE__, "Error: server failed to accept\n");
    }

    /* join with client */
    MPI_Comm_join(newsd, &intercomm);
    free(host_name);
    close(newsd);
    close(server_socket);

    /* test the newly generated comm: send first, then receive (the
       client does the opposite); destination/source 0 on intercomm is
       the peer process in the remote group */
    MPI_Comm_rank(intercomm, &rank);
    MPI_Send(&send_data, 1, MPI_INT, 0, 123, intercomm);
    MPI_Recv(&recv_data, 1, MPI_INT, 0, 321, intercomm, MPI_STATUS_IGNORE);
    if (send_data != recv_data)
      lamtest_error(__FILE__, __LINE__, "Error: server check data mismatch\n");
    MPI_Comm_free(&intercomm);
  } else if (rank == 1) {
    /* client code */

    client_socket = socket(AF_INET, SOCK_STREAM, 0);
    if (client_socket < 0)
      lamtest_error(__FILE__, __LINE__,
                    "ERROR: Client failed to create a socket\n");

    /* get the host name length from rank 0 */
    MPI_Recv(&host_len, 1, MPI_INT, 0, 100, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    if (host_len <= 0) {
      close(client_socket);
      lamtest_error(__FILE__, __LINE__,
                    "ERROR: Client received an invalid host name length\n");
    }

    host_name = (char *) malloc(((size_t) host_len + 1) * sizeof(char));
    if (host_name == NULL) {
      close(client_socket);
      lamtest_error(__FILE__, __LINE__,
                    "ERROR: Client failed to allocate host name buffer\n");
    }

    /* get the host name from rank 0 */
    MPI_Recv(host_name, host_len, MPI_CHAR, 0, 100, MPI_COMM_WORLD,
             MPI_STATUS_IGNORE);
    /* the buffer has one spare byte; terminate it unconditionally so
       gethostbyname() never reads past the received data */
    host_name[host_len] = '\0';

    h = gethostbyname(host_name);
    if (h == NULL) {
      close(client_socket);
      free(host_name);
      lamtest_error(__FILE__, __LINE__,
                    "ERROR: Client could not gethostbyname properly\n");
    }

    /* the address is copied into a struct sockaddr_in, so it must be an
       IPv4 address that fits into sin_addr */
    if (h->h_addrtype != AF_INET || h->h_addr_list[0] == NULL ||
        h->h_length <= 0 ||
        (size_t) h->h_length > sizeof(server_addr.sin_addr)) {
      close(client_socket);
      free(host_name);
      lamtest_error(__FILE__, __LINE__,
                    "ERROR: Client got an unusable address for the server\n");
    }

    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = h->h_addrtype;
    memcpy(&server_addr.sin_addr, h->h_addr_list[0], (size_t) h->h_length);
    server_addr.sin_port = htons(SERVER_PORT);

    if (connect(client_socket, (struct sockaddr*) &server_addr,
                sizeof(server_addr)) < 0) {
      close(client_socket);
      free(host_name);
      lamtest_error(__FILE__, __LINE__,
                    "ERROR: Client could not connect() properly\n");
    }

    /* join with server */
    MPI_Comm_join(client_socket, &intercomm);
    free(host_name);
    close(client_socket);

    /* test newly generated comm: receive first, then send (mirror image
       of the server side) */
    MPI_Comm_rank(intercomm, &rank);
    MPI_Recv(&recv_data, 1, MPI_INT, 0, 123, intercomm, MPI_STATUS_IGNORE);
    MPI_Send(&send_data, 1, MPI_INT, 0, 321, intercomm);
    if (send_data != recv_data)
      lamtest_error(__FILE__, __LINE__, "Error: client check data mismatch\n");
    MPI_Comm_free(&intercomm);
  }

  /* All done */

  MPI_Finalize();
  return 0;
}
