pybind11 trampoline class for NetgenGeometry

This commit is contained in:
Hochsteger, Matthias
2026-03-31 10:55:10 +02:00
parent e8d4788b6d
commit 7207519cad
2 changed files with 154 additions and 1 deletions

View File

@@ -118,6 +118,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
.def(py::self+Vec<3>())
.def(py::self-Vec<3>())
.def("__getitem__", [](Point<3>& self, int index) { return self[index]; })
.def("__setitem__", [](Point<3>& self, int index, double value) { self[index] = value; })
.def("__len__", [](Point<3>& /*unused*/) { return 3; })
;
@@ -764,7 +765,119 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
py::implicitly_convertible< int, PointIndex>();
py::class_<NetgenGeometry, shared_ptr<NetgenGeometry>> (m, "NetgenGeometry", py::dynamic_attr())
py::class_<PointGeomInfo>(m, "PointGeomInfo")
.def(py::init<>())
.def_readwrite("trignum", &PointGeomInfo::trignum)
.def_readwrite("u", &PointGeomInfo::u)
.def_readwrite("v", &PointGeomInfo::v)
;
py::class_<EdgePointGeomInfo>(m, "EdgePointGeomInfo")
.def(py::init<>())
.def_readwrite("edgenr", &EdgePointGeomInfo::edgenr)
.def_readwrite("dist", &EdgePointGeomInfo::dist)
.def_readwrite("u", &EdgePointGeomInfo::u)
.def_readwrite("v", &EdgePointGeomInfo::v)
;
class NetgenGeometryTrampoline : public NetgenGeometry {
public:
using NetgenGeometry::NetgenGeometry;
NetgenGeometryTrampoline() : NetgenGeometry()
{
static_assert( sizeof(NetgenGeometry)==sizeof(NetgenGeometryTrampoline), "Size of NetgenGeometry and NetgenGeometryTrampoline differ");
}
Vec<3> GetNormal (int surfind, const Point<3> &p, const PointGeomInfo *gi) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "GetNormal"))
return py::cast<Vec<3>> (overload(surfind, p, gi));
else
throw Exception ("GetNormal not implemented");
}
PointGeomInfo ProjectPoint(int surfind, Point<3> & p) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "ProjectPoint"))
return py::cast<PointGeomInfo> (overload(surfind, py::cast(p, py::return_value_policy::reference)));
else
throw Exception ("ProjectPoint not implemented");
}
void ProjectPointEdge(int surfind, int surfind2, Point<3> & p,
EdgePointGeomInfo* gi = nullptr) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "ProjectPointEdge"))
overload(surfind, surfind2,
py::cast(p, py::return_value_policy::reference),
py::cast(gi, py::return_value_policy::reference));
throw Exception ("ProjectPointEdge not implemented");
}
bool ProjectPointGI(int surfind, Point<3> & p,
PointGeomInfo & gi) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "ProjectPointGI"))
return py::cast<bool> (overload(surfind, p, gi));
else if (auto overload = pybind11::get_overload(this, "ProjectPoint"))
return py::cast<bool> (overload(surfind, py::cast(p, py::return_value_policy::reference)));
else
throw Exception ("Neither ProjectPointGI nor ProjectPoint implemented");
}
bool CalcPointGeomInfo(int surfind, PointGeomInfo& gi,
const Point<3> & p3) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "CalcPointGeomInfo"))
return py::cast<bool> (overload(surfind, py::cast(gi, py::return_value_policy::reference), p3));
else
throw Exception ("CalcPointGeomInfo not implemented");
}
void PointBetweenEdge(const Point<3> & p1, const Point<3> & p2,
double secpoint,
int surfi1, int surfi2,
const EdgePointGeomInfo & ap1,
const EdgePointGeomInfo & ap2,
Point<3> & newp,
EdgePointGeomInfo & newgi) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "PointBetweenEdge"))
overload(p1, p2, secpoint, surfi1, surfi2, ap1, ap2,
py::cast(newp, py::return_value_policy::reference),
py::cast(newgi, py::return_value_policy::reference));
else
throw Exception ("PointBetweenEdge not implemented");
}
void PointBetween(const Point<3> & p1, const Point<3> & p2,
double secpoint,
int surfi,
const PointGeomInfo & gi1,
const PointGeomInfo & gi2,
Point<3> & newp,
PointGeomInfo & newgi) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "PointBetween"))
overload(p1, p2, secpoint, surfi, gi1, gi2,
py::cast(newp, py::return_value_policy::reference),
py::cast(newgi, py::return_value_policy::reference));
else
throw Exception ("PointBetween not implemented");
}
Vec<3> GetTangent(const Point<3> & p, int surfi1,
int surfi2,
const EdgePointGeomInfo & egi) const override {
py::gil_scoped_acquire gil;
if (auto overload = pybind11::get_overload(this, "GetTangent"))
return py::cast<Vec<3>> (overload(p, surfi1, surfi2, egi));
throw Exception ("GetTangent not implemented");
}
};
py::class_<NetgenGeometry, shared_ptr<NetgenGeometry>, NetgenGeometryTrampoline> (m, "NetgenGeometry", py::dynamic_attr())
.def(py::init<> ())
.def("RestrictH", &NetgenGeometry::RestrictH)
;

View File

@@ -0,0 +1,40 @@
from netgen.occ import *
from netgen.meshing import NetgenGeometry, Mesh as NGMesh
import numpy as np
class UnitSphereGeometry(NetgenGeometry):
def midpoint(self, newp, p1, p2, secpoint=0.5):
p1 = np.array([p1[0], p1[1], p1[2]])
p2 = np.array([p2[0], p2[1], p2[2]])
p = p1 + secpoint * (p2 - p1)
self.project(newp, p)
def project(self, newp, p):
pt = np.array([p[0], p[1], p[2]])
pt /= np.linalg.norm(pt)
newp[0] = pt[0]
newp[1] = pt[1]
newp[2] = pt[2]
def PointBetweenEdge(self, p1, p2, secpoint, surfi1, surfi2, ep1, ep2, newp, newgi):
self.midpoint(newp, p1, p2, secpoint)
def PointBetween(self, p1, p2, secpoint, surfi, gi1, gi2, newp, newgi):
self.midpoint(newp, p1, p2, secpoint)
def ProjectPointGI(self, surfind, p, gi):
self.project(p, p)
return True
m = OCCGeometry(Sphere(Pnt(0, 0, 0), 1)).GenerateMesh()
my_geo = UnitSphereGeometry()
m.SetGeometry(my_geo)
m.Refine()
m.Curve(3)
from ngsolve import *
mesh = Mesh(m)
Draw(mesh)