Fleet  0.0.9
Inference in the LOT
MPI.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Install MPI via
4 // sudo apt install openmpi-bin openmpi-common openmpi-doc libopenmpi-dev
5 // Maybe see: https://www.slothparadise.com/use-mpi-without-nfs/
6 
7 
#include <cassert>
#include <limits>
#include <string>
#include <vector>
#include <mpi.h>
10 
/// Rank of the head node.
constexpr size_t MPI_HEAD_RANK{0};
12 
13 int mpi_rank() {
14  int r;
15  MPI_Comm_rank(MPI_COMM_WORLD, &r);
16  return r;
17 }
18 
19 int mpi_size() {
20  int s;
21  MPI_Comm_size(MPI_COMM_WORLD, &s);
22  return s;
23 }
24 
25 bool is_mpi_head() {
26  return mpi_rank() == MPI_HEAD_RANK;
27 }
28 
34 template<typename T>
35 void mpi_return(T& x) {
36  assert(mpi_rank() != MPI_HEAD_RANK && "*** Head rank cannot call mpi_return"); // the head can't return
37 
38  std::string v = x.serialize(); // convert to string
39 // CERR "MPI SENDING " TAB mpi_rank() TAB v ENDL;
40 
41  MPI_Send(v.data(), v.size(), MPI_CHAR, MPI_HEAD_RANK, 0, MPI_COMM_WORLD);
42 }
43 
51 template<typename T>
52 std::vector<T> mpi_gather() {
53 
54  assert(mpi_rank() == MPI_HEAD_RANK && "*** Cannot call mpi_gather unless you are head rank");
55 
56  int s = mpi_size();
57  std::vector<T> out;
58 
59  for(int r=0;r<s;r++) {
60  if(r != MPI_HEAD_RANK) {
61  // get the status with the message size
62  MPI_Status status;
63  auto v = MPI_Probe(r, 0, MPI_COMM_WORLD, &status);
64  assert(v == MPI_SUCCESS);
65 
66  // When probe returns, the status object has the size and other
67  // attributes of the incoming message. Get the message size
68  int sz; v = MPI_Get_count(&status, MPI_CHAR, &sz);
69  assert(v == MPI_SUCCESS);
70  assert(sz != MPI_UNDEFINED);
71  assert(sz >= 0);
72 
73  // Allocate a buffer to hold the incoming numbers
74  char* buf = new char[sz+1];
75 
76  // Now receive the message with the allocated buffer
77  v = MPI_Recv(buf, sz, MPI_CHAR, r, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
78  assert(v == MPI_SUCCESS);
79  buf[sz] = '\0'; // make null terminated
80 
81  std::string ret = buf;
82  out.push_back(T::deserialize(ret)); // note must give size -- not null terminated
83  delete[] buf;
84 
85 // CERR "MPI RECEIVING" TAB r TAB sz TAB out.rbegin()->size() TAB std::string(buf, sz) ENDL;
86  }
87  }
88 
89  return out;
90 }
91 
92 
93 
99 template<typename T>
100 std::vector<T> mpi_map(const size_t n) {
101  enum WORKER_TAGS { GIVE_ME, TAKE_THIS, DONE };
102 
103  if(mpi_rank() == MPI_HEAD_RANK) {
104 
105  // store all of the return values
106  std::vector<T> ret;
107 
108  for(size_t i=0;i<n;) {
109 
110  MPI_Status status;
111  auto v = MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
112 
113  if(status.MPI_TAG == TAKE_THIS) { // if they are sending data
114 
115  // Allocate a buffer to hold the incoming numbers
116  char* buf = new char[sz+1];
117 
118  // Now receive the message with the allocated buffer
119  v = MPI_Recv(buf, sz, MPI_CHAR, r, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
120  assert(v == MPI_SUCCESS);
121  buf[sz] = '\0'; // make null terminated
122 
123  std::string ret = buf;
124  out.push_back(T::deserialize(ret)); // note must give size -- not null terminated
125  delete[] buf;
126  }
127  else if(status.MPI_TAG == GIVE_ME) {
128  int[1] buf;
129  buf[0] = i;
130  MPI_send(buf, 1, MPI_INT, status.MPI_SOURCE, TAKE_THIS, MPI_COMM_WORLD);
131  i++;
132  }
133  else {
134  assert(false && "*** Bad tag");
135  }
136  }
137  }
138  else {
139  // It's a worker
140  while(true) {
141 
142  MPI_Status status;
143  auto v = MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
144 
145  }
146  }
147 }
148 
void mpi_return(T &x)
Return my results via MPI.
Definition: MPI.h:35
int mpi_size()
Definition: MPI.h:19
const size_t MPI_HEAD_RANK
Definition: MPI.h:11
int mpi_rank()
Definition: MPI.h:13
bool is_mpi_head()
Definition: MPI.h:25
std::vector< T > mpi_gather()
Reads all the MPI returns from mpi_return. NOTE that the output does not come with any order guarante...
Definition: MPI.h:52
std::vector< T > mpi_map(const size_t n)
Send an int n to each request. Each requestor can use this to index into some data which they are ass...
Definition: MPI.h:100