mirror of
https://github.com/NGSolve/netgen.git
synced 2026-05-31 01:06:10 +08:00
pybind11 trampoline class for NetgenGeometry
This commit is contained in:
@@ -118,6 +118,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
.def(py::self+Vec<3>())
|
||||
.def(py::self-Vec<3>())
|
||||
.def("__getitem__", [](Point<3>& self, int index) { return self[index]; })
|
||||
.def("__setitem__", [](Point<3>& self, int index, double value) { self[index] = value; })
|
||||
.def("__len__", [](Point<3>& /*unused*/) { return 3; })
|
||||
;
|
||||
|
||||
@@ -764,7 +765,119 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
|
||||
py::implicitly_convertible< int, PointIndex>();
|
||||
|
||||
py::class_<NetgenGeometry, shared_ptr<NetgenGeometry>> (m, "NetgenGeometry", py::dynamic_attr())
|
||||
py::class_<PointGeomInfo>(m, "PointGeomInfo")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("trignum", &PointGeomInfo::trignum)
|
||||
.def_readwrite("u", &PointGeomInfo::u)
|
||||
.def_readwrite("v", &PointGeomInfo::v)
|
||||
;
|
||||
|
||||
py::class_<EdgePointGeomInfo>(m, "EdgePointGeomInfo")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("edgenr", &EdgePointGeomInfo::edgenr)
|
||||
.def_readwrite("dist", &EdgePointGeomInfo::dist)
|
||||
.def_readwrite("u", &EdgePointGeomInfo::u)
|
||||
.def_readwrite("v", &EdgePointGeomInfo::v)
|
||||
;
|
||||
|
||||
class NetgenGeometryTrampoline : public NetgenGeometry {
|
||||
public:
|
||||
using NetgenGeometry::NetgenGeometry;
|
||||
NetgenGeometryTrampoline() : NetgenGeometry()
|
||||
{
|
||||
static_assert( sizeof(NetgenGeometry)==sizeof(NetgenGeometryTrampoline), "Size of NetgenGeometry and NetgenGeometryTrampoline differ");
|
||||
}
|
||||
|
||||
Vec<3> GetNormal (int surfind, const Point<3> &p, const PointGeomInfo *gi) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "GetNormal"))
|
||||
return py::cast<Vec<3>> (overload(surfind, p, gi));
|
||||
else
|
||||
throw Exception ("GetNormal not implemented");
|
||||
}
|
||||
PointGeomInfo ProjectPoint(int surfind, Point<3> & p) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "ProjectPoint"))
|
||||
return py::cast<PointGeomInfo> (overload(surfind, py::cast(p, py::return_value_policy::reference)));
|
||||
else
|
||||
throw Exception ("ProjectPoint not implemented");
|
||||
}
|
||||
|
||||
void ProjectPointEdge(int surfind, int surfind2, Point<3> & p,
|
||||
EdgePointGeomInfo* gi = nullptr) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "ProjectPointEdge"))
|
||||
overload(surfind, surfind2,
|
||||
py::cast(p, py::return_value_policy::reference),
|
||||
py::cast(gi, py::return_value_policy::reference));
|
||||
throw Exception ("ProjectPointEdge not implemented");
|
||||
}
|
||||
|
||||
bool ProjectPointGI(int surfind, Point<3> & p,
|
||||
PointGeomInfo & gi) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "ProjectPointGI"))
|
||||
return py::cast<bool> (overload(surfind, p, gi));
|
||||
else if (auto overload = pybind11::get_overload(this, "ProjectPoint"))
|
||||
return py::cast<bool> (overload(surfind, py::cast(p, py::return_value_policy::reference)));
|
||||
else
|
||||
throw Exception ("Neither ProjectPointGI nor ProjectPoint implemented");
|
||||
}
|
||||
|
||||
bool CalcPointGeomInfo(int surfind, PointGeomInfo& gi,
|
||||
const Point<3> & p3) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "CalcPointGeomInfo"))
|
||||
return py::cast<bool> (overload(surfind, py::cast(gi, py::return_value_policy::reference), p3));
|
||||
else
|
||||
throw Exception ("CalcPointGeomInfo not implemented");
|
||||
}
|
||||
|
||||
void PointBetweenEdge(const Point<3> & p1, const Point<3> & p2,
|
||||
double secpoint,
|
||||
int surfi1, int surfi2,
|
||||
const EdgePointGeomInfo & ap1,
|
||||
const EdgePointGeomInfo & ap2,
|
||||
Point<3> & newp,
|
||||
EdgePointGeomInfo & newgi) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "PointBetweenEdge"))
|
||||
overload(p1, p2, secpoint, surfi1, surfi2, ap1, ap2,
|
||||
py::cast(newp, py::return_value_policy::reference),
|
||||
py::cast(newgi, py::return_value_policy::reference));
|
||||
else
|
||||
throw Exception ("PointBetweenEdge not implemented");
|
||||
}
|
||||
|
||||
void PointBetween(const Point<3> & p1, const Point<3> & p2,
|
||||
double secpoint,
|
||||
int surfi,
|
||||
const PointGeomInfo & gi1,
|
||||
const PointGeomInfo & gi2,
|
||||
Point<3> & newp,
|
||||
PointGeomInfo & newgi) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "PointBetween"))
|
||||
overload(p1, p2, secpoint, surfi, gi1, gi2,
|
||||
py::cast(newp, py::return_value_policy::reference),
|
||||
py::cast(newgi, py::return_value_policy::reference));
|
||||
else
|
||||
throw Exception ("PointBetween not implemented");
|
||||
}
|
||||
|
||||
Vec<3> GetTangent(const Point<3> & p, int surfi1,
|
||||
int surfi2,
|
||||
const EdgePointGeomInfo & egi) const override {
|
||||
py::gil_scoped_acquire gil;
|
||||
if (auto overload = pybind11::get_overload(this, "GetTangent"))
|
||||
return py::cast<Vec<3>> (overload(p, surfi1, surfi2, egi));
|
||||
throw Exception ("GetTangent not implemented");
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
py::class_<NetgenGeometry, shared_ptr<NetgenGeometry>, NetgenGeometryTrampoline> (m, "NetgenGeometry", py::dynamic_attr())
|
||||
.def(py::init<> ())
|
||||
.def("RestrictH", &NetgenGeometry::RestrictH)
|
||||
;
|
||||
|
||||
|
||||
40
py_tutorials/python_geometry.py
Normal file
40
py_tutorials/python_geometry.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from netgen.occ import *
|
||||
from netgen.meshing import NetgenGeometry, Mesh as NGMesh
|
||||
import numpy as np
|
||||
|
||||
|
||||
class UnitSphereGeometry(NetgenGeometry):
|
||||
def midpoint(self, newp, p1, p2, secpoint=0.5):
|
||||
p1 = np.array([p1[0], p1[1], p1[2]])
|
||||
p2 = np.array([p2[0], p2[1], p2[2]])
|
||||
p = p1 + secpoint * (p2 - p1)
|
||||
self.project(newp, p)
|
||||
|
||||
def project(self, newp, p):
|
||||
pt = np.array([p[0], p[1], p[2]])
|
||||
pt /= np.linalg.norm(pt)
|
||||
newp[0] = pt[0]
|
||||
newp[1] = pt[1]
|
||||
newp[2] = pt[2]
|
||||
|
||||
def PointBetweenEdge(self, p1, p2, secpoint, surfi1, surfi2, ep1, ep2, newp, newgi):
|
||||
self.midpoint(newp, p1, p2, secpoint)
|
||||
|
||||
def PointBetween(self, p1, p2, secpoint, surfi, gi1, gi2, newp, newgi):
|
||||
self.midpoint(newp, p1, p2, secpoint)
|
||||
|
||||
def ProjectPointGI(self, surfind, p, gi):
|
||||
self.project(p, p)
|
||||
return True
|
||||
|
||||
|
||||
m = OCCGeometry(Sphere(Pnt(0, 0, 0), 1)).GenerateMesh()
|
||||
my_geo = UnitSphereGeometry()
|
||||
m.SetGeometry(my_geo)
|
||||
m.Refine()
|
||||
m.Curve(3)
|
||||
|
||||
from ngsolve import *
|
||||
|
||||
mesh = Mesh(m)
|
||||
Draw(mesh)
|
||||
Reference in New Issue
Block a user