#include<stdio.h>
#include<math.h>
#include<mpi.h>

#define nmax 128
#define pmax 10

/* Generates this process's slab of the 5-point stencil matrix.
 *
 * Unknowns are numbered row by row in grid rows of length n; each process
 * owns n/np consecutive grid rows, i.e. dim = n*(n/np) unknowns.  Row i of
 * matrA holds the coefficients that multAv applies to the neighbours of
 * unknown i:
 *   [0] previous grid row, [1] left, [2] centre (4.0), [3] right,
 *   [4] next grid row.
 * Coefficients of neighbours that fall outside the domain are set to 0.
 *
 * matrA : output, at least n*(n/np) rows of 5 coefficients
 * n     : number of unknowns per grid row
 * np    : number of processes
 * myid  : rank of the calling process (0 .. np-1)
 */
void mat_gen(double matrA[][5], int n, int np, int myid)
{
  int rows = n / np;     /* grid rows owned by this process */
  int dim = n * rows;    /* unknowns owned by this process  */
  int i;

  /* Nothing owned: returning here also keeps the boundary loops below
     from indexing before matrA[0]. */
  if (n <= 0 || rows <= 0)
    return;

  for (i = 0; i < dim; i++) {
    matrA[i][0] = -1.;
    matrA[i][1] = -1.;
    matrA[i][2] = 4.;
    matrA[i][3] = -1.;
    matrA[i][4] = -1.;
  }

  /* The first and last unknown of every grid row have no left/right
     neighbour. */
  for (i = 0; i < rows; i++) {
    matrA[i * n][1] = 0.;
    matrA[(i + 1) * n - 1][3] = 0.;
  }

  /* The first grid row of rank 0 has no previous row. */
  if (myid == 0) {
    for (i = 0; i < n; i++)
      matrA[i][0] = 0.;
  }
  /* The last grid row of the last rank has no next row. */
  if (myid == np - 1) {
    for (i = dim - n; i < dim; i++)
      matrA[i][4] = 0.;
  }
}

double multuv(u,v,n,np,myid,otm)
  int n, np, myid, otm;
  double u[(nmax+2)*nmax],v[(nmax+2)*nmax];
/* Presmyata skalarnoto proizvedenie na u i v s np procesora.
   Vseki ot procesorite sadarja chast ot u i ot v  i presmyata 
   saotvetnata suma, sled koeto ya izprashta na proces 0.  
   Krainiat rezultat se izprashta vav vsichki procesori. */
{ int i, istart, iend;
  
  double res=0.;
  double prod=0.;
  double vprod[pmax];
  MPI_Comm com=MPI_COMM_WORLD;

  istart=otm*n;
  iend=(n/np+otm)*n;

  for(i=istart; i<iend; i++)
  {  prod=prod+u[i]*v[i];
  } 
  MPI_Gather(&prod,1,MPI_DOUBLE,vprod,1,MPI_DOUBLE,0, com);

  if (myid==0) 
  {  for(i=0; i<np; i++)
     {  res=res+vprod[i];
     }
  }  
  MPI_Bcast(&res, 1, MPI_DOUBLE, 0, com);

  return  res;
}  

/* Exchanges one grid row of ghost values with each neighbouring rank.
 *
 * The owned unknowns occupy vv[n .. n+dim-1].  vv[0 .. n-1] receives the
 * last owned row of rank myid-1, and vv[dim+n .. dim+2n-1] receives the
 * first owned row of rank myid+1.  Even ranks send before receiving and
 * odd ranks receive before sending.  Each blocking send therefore meets a
 * matching receive (same tags as before: 1..4).
 */
static void multAv_exchange_halo(double vv[], int n, int np, int myid,
                                 int dim)
{
  const int etypes = 1, etyper = 2, btypes = 3, btyper = 4;
  MPI_Comm com = MPI_COMM_WORLD;
  MPI_Status status;
  int has_prev = (myid != 0);
  int has_next = (myid != np - 1);

  if (myid % 2 == 0) {
    if (has_next) {
      MPI_Send(&vv[dim], n, MPI_DOUBLE, myid + 1, etypes, com);
      MPI_Recv(&vv[dim + n], n, MPI_DOUBLE, myid + 1, btyper, com, &status);
    }
    if (has_prev) {
      MPI_Send(&vv[n], n, MPI_DOUBLE, myid - 1, btypes, com);
      MPI_Recv(&vv[0], n, MPI_DOUBLE, myid - 1, etyper, com, &status);
    }
  } else {
    /* odd ranks always have a previous rank */
    MPI_Recv(&vv[0], n, MPI_DOUBLE, myid - 1, etypes, com, &status);
    MPI_Send(&vv[n], n, MPI_DOUBLE, myid - 1, btyper, com);
    if (has_next) {
      MPI_Recv(&vv[dim + n], n, MPI_DOUBLE, myid + 1, btypes, com, &status);
      MPI_Send(&vv[dim], n, MPI_DOUBLE, myid + 1, etyper, com);
    }
  }
}

/* Computes rv = A * vv for this process's slab of the 5-point operator
 * built by mat_gen.
 *
 * matrA : n*(n/np) rows of coefficients
 *         [prev row, left, centre, right, next row]
 * vv    : input vector; owned values at vv[n .. n+dim-1], with one ghost
 *         grid row on each side.  The ghost rows are filled here by the
 *         halo exchange.
 * rv    : output; result for owned unknown i is stored at rv[otm*n + i]
 * n     : unknowns per grid row
 * np    : number of processes
 * myid  : rank of the caller
 * otm   : row offset applied to rv only (vv always uses an offset of one
 *         grid row)
 *
 * dim is n*(n/np), matching mat_gen and multuv.  The previous n*n/np
 * differed when n % np != 0 and read uninitialized matrix rows.
 */
void multAv(double matrA[][5], double vv[], double rv[], int n, int np,
            int myid, int otm)
{
  int dim = n * (n / np);
  int first = otm * n;
  int i, k;
  double sum;

  /* dim is identical on every rank, so either all ranks return here or
     none do. */
  if (dim <= 0)
    return;

  multAv_exchange_halo(vv, n, np, myid, dim);

  for (i = 0; i < dim; i++) {
    k = i + n;  /* position of unknown i inside vv */
    sum = matrA[i][2] * vv[k] + matrA[i][3] * vv[k + 1]
        + matrA[i][1] * vv[k - 1];
    /* Terms are added in the original order: centre, right, left, next
       row, previous row.  Neighbours outside the global domain are
       skipped rather than multiplied by a zero coefficient. */
    if (myid != np - 1 || i < dim - n)
      sum += matrA[i][4] * vv[k + n];
    if (myid != 0 || i >= n)
      sum += matrA[i][0] * vv[k - n];
    rv[first + i] = sum;
  }
}


/* Reads the grid size n on rank 0, builds the distributed stencil matrix,
 * and times one distributed matrix-vector product.
 *
 * n must satisfy 1 <= n <= nmax, and it must be divisible by the number of
 * processes.  Each rank owns n/np whole grid rows, and the fixed-size
 * arrays below hold at most nmax rows.  Invalid input makes every rank
 * finalize MPI and return 1.
 */
int main(int argc, char **argv)
{
  int myid, np, n = 0, otm;
  double eps = 0.;  /* tolerance for the disabled solver; not read yet */
  double t1, t2;
  /* static: about 900 KB in total (nmax = 128), which may exceed default
     stack limits on some platforms. */
  static double matrA[nmax * nmax][5], xx[(nmax + 2) * nmax],
      vv[(nmax + 2) * nmax];
  int i, dim;
  MPI_Comm com = MPI_COMM_WORLD;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(com, &np);
  MPI_Comm_rank(com, &myid);

  if (myid == 0) {
    printf("Enter the dimension of the vectors: \n");
    fflush(stdout);  /* stdout may be fully buffered under mpirun */
    if (scanf("%d", &n) != 1 || n < 1 || n > nmax || n % np != 0) {
      fprintf(stderr,
              "n must be an integer in [1, %d] divisible by the number "
              "of processes (%d)\n", nmax, np);
      n = 0;  /* tells every rank to stop */
    }
    /* printf("Enter eps: \n");
       scanf("%lf",&eps); */
  }
  MPI_Bcast(&n, 1, MPI_INT, 0, com);
  MPI_Bcast(&eps, 1, MPI_DOUBLE, 0, com);

  if (n == 0) {
    MPI_Finalize();
    return 1;
  }

  dim = n * (n / np);
  otm = 1;

  /* owned values plus one ghost grid row on each side */
  for (i = 0; i < dim + 2 * n; i++)
    xx[i] = 1.;
  mat_gen(matrA, n, np, myid);

  t1 = MPI_Wtime();
  /* iter=st_de(matrA,b, xx, n,np,myid, eps); */
  multAv(matrA, xx, vv, n, np, myid, otm);
  t2 = MPI_Wtime();

  printf("myid= %d ; TIME=%g \n", myid, t2 - t1);

  /* Debug dump of xx and vv, one rank at a time:
  { int j;
    for(j=0; j<np; j++)
    {  if(myid==j)
       {   for(i=0; i<((n+2)*n);i++)
           {   printf("myid= %d ; xx[%d] = %g, vv=%g \n", myid, i, xx[i],vv[i]);
           }
       }
    }
  }
  */
  MPI_Finalize();
  return 0;
}
