mpi_wrapper

This commit is contained in:
Joachim Schöberl
2019-02-11 21:37:00 +01:00
parent 1074593664
commit 9ced2f561f
4 changed files with 67 additions and 61 deletions

View File

@@ -494,20 +494,20 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
// .def(py::init<>("create empty mesh"))
.def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
.def(py::init( [] (int dim, NgMPI_Comm comm)
{
auto mesh = make_shared<Mesh>();
mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
mesh->SetCommunicator(comm);
mesh -> SetDimension(dim);
SetGlobalMesh(mesh); // for visualization
mesh -> SetGeometry (nullptr);
return mesh;
} ),
py::arg("dim")=3, py::arg("comm")=nullptr
py::arg("dim")=3, py::arg("comm")=NgMPI_Comm(ng_comm)
)
.def(NGSPickle<Mesh>())
.def_property_readonly("comm", [](const Mesh & amesh)
{ return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
.def_property_readonly("comm", [](const Mesh & amesh) -> NgMPI_Comm
{ return amesh.GetCommunicator(); },
"MPI-communicator the Mesh lives in")
/*
.def("__init__",
@@ -521,8 +521,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
*/
.def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
.def("Distribute", [](shared_ptr<Mesh> self, shared_ptr<PyMPI_Comm> pycomm) {
MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self->GetCommunicator();
.def("Distribute", [](shared_ptr<Mesh> self, NgMPI_Comm comm) {
self->SetCommunicator(comm);
if(MyMPI_GetNTasks(comm)==1) return self;
// if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
@@ -530,10 +529,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
if(MyMPI_GetId(comm)==0) self->Distribute();
else self->SendRecvMesh();
return self;
}, py::arg("comm")=nullptr)
.def("Receive", [](shared_ptr<PyMPI_Comm> pycomm) {
}, py::arg("comm")=NgMPI_Comm(ng_comm))
.def("Receive", [](NgMPI_Comm comm) {
auto mesh = make_shared<Mesh>();
mesh->SetCommunicator(pycomm->comm);
mesh->SetCommunicator(comm);
mesh->SendRecvMesh();
return mesh;
})
@@ -933,57 +932,34 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
return old;
}));
py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
.def_property_readonly ("rank", &PyMPI_Comm::Rank)
.def_property_readonly ("size", &PyMPI_Comm::Size)
// .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
// .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
py::class_<NgMPI_Comm> (m, "MPI_Comm")
.def_property_readonly ("rank", &NgMPI_Comm::Rank)
.def_property_readonly ("size", &NgMPI_Comm::Size)
#ifdef PARALLEL
.def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
.def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
.def("Barrier", [](NgMPI_Comm & c) { MPI_Barrier(c); })
.def("WTime", [](NgMPI_Comm & c) { return MPI_Wtime(); })
#else
.def("Barrier", [](PyMPI_Comm & c) { })
.def("WTime", [](PyMPI_Comm & c) { return -1.0; })
.def("Barrier", [](NgMPI_Comm & c) { })
.def("WTime", [](NgMPI_Comm & c) { return -1.0; })
#endif
.def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("SubComm", [](PyMPI_Comm & c, std::vector<int> proc_list) {
.def("Sum", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
.def("Min", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
.def("Max", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
.def("Sum", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
.def("Min", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
.def("Max", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
.def("Sum", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
.def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
.def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
.def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
Array<int> procs(proc_list.size());
for (int i = 0; i < procs.Size(); i++)
procs[i] = proc_list[i];
if (!procs.Contains(c.Rank()))
throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
return make_shared<PyMPI_Comm>(subcomm, true);
/*
Array<int> procs;
if (py::extract<py::list> (proc_list).check()) {
py::list pylist = py::extract<py::list> (proc_list)();
procs.SetSize(py::len(pyplist));
for (int i = 0; i < py::len(pylist); i++)
procs[i] = py::extract<int>(pylist[i])();
}
else {
throw Exception("SubComm needs a list!");
}
if(!procs.Size()) {
cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
}
else if(procs.Size()==2) {
throw Exception("Sorry, NGSolve cannot handle NP=2.");
}
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
return make_shared<PyMPI_Comm>(subcomm, true);
*/
MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
return NgMPI_Comm(subcomm, true);
}, py::arg("procs"));
;