我正在尝试开发一个穷人的NBody模拟,我想通过只考虑最近的邻居来加快交互的速度。似乎最简单的方法是将所有的点坐标绑定到一个粗略的网格中,然后循环遍历包含在最近的相邻网格点中的点。
最初,我打算做一个quadTree来查询最近的邻居,但是数据在一次又一次迭代中是相当动态的,我希望处理具有超过一百万个坐标点的模拟,所以对我来说,基于网格的解决方案似乎是一个更好的起点。
我目前有以下可用的代码:
import numpy as np
cimport numpy as np
from cython.view cimport array as cvarray
cimport cython
from libc.math cimport sinh, cosh, sin, cos, acos, exp, sqrt, fabs, M_PI, floor
from collections import defaultdict
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
cdef DTYPE_t pi = 3.141592653589793
# Define Grid Structure
@cython.cdivision(True)
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def gridStructure(DTYPE_t[:,:,:] X, Py_ssize_t ref, int xSize, int ySize):
Grid = defaultdict(int)
cdef Py_ssize_t ii,jj
cdef Py_ssize_t N = X.shape[1]
cdef DTYPE_t[:] maxes = np.max(X[ref,:,:],axis=0)
cdef DTYPE_t[:] mines = np.min(X[ref,:,:],axis=0)
cdef Py_ssize_t pointX, pointY, index
for ii in range(0,N):
pointX = int(floor((X[ref,ii,0]-mines[0])/maxes[0]*xSize))
pointY = int(floor((X[ref,ii,1]-mines[1])/maxes[1]*ySize))
index = pointX + xSize*pointY
if index in Grid.keys():
Grid[index].append(ii)
else:
Grid[index] = [ii]
return Grid
我使用python的默认dict结构来存储键(一个整数网格索引)和一组指向内存视图中坐标的整数。可以安全地假设网格是稀疏的;坐标通常具有异构分布。我认为我的结构足够简单,我应该使用基于C的结构而不是dict,这样我就可以摆脱python包装器。
最终,这个函数不会被用户直接调用,所以在这一层我不需要任何python交互。
目前,代码可以在大约一秒钟内将1000000个点绑定到256x256网格,对于潜在的实时穷人的模拟,有可能更快地执行此操作吗?任何关于适当的数据结构以加速此函数的建议都将不胜感激。我基本上想返回一些数据结构,在这里我可以调用一个基于整数的键,并返回一组与该键相关的点。谢谢!
发布于 2019-02-14 21:45:58
好吧,我有一个更好的解决方案。我在这里找到了这篇文章:Efficient (and well explained) implementation of a Quadtree for 2D collision detection
他在第二个答案的开头谈到了如何在网格上做逻辑。由于网格的大小是固定的,sim中的元素数量也是固定的,所以如果我只是预先分配内存,速度会提高20倍。下面是新的网格函数:
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport floor
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
def gridStructure(DTYPE_t[:,:,:] X, np.int_t ref, np.int_t xSize, np.int_t[:] grid, np.int_t[:] eList):
cdef Py_ssize_t N = X.shape[1]
cdef np.int_t[:] lastEl = np.zeros(xSize*xSize, dtype=np.int)
cdef DTYPE_t[:] maxes = np.max(X[ref,:,:],axis=0)
cdef DTYPE_t[:] mines = np.min(X[ref,:,:],axis=0)
cdef DTYPE_t widthX = max(maxes[0]-mines[0],maxes[1]-mines[1])
cdef DTYPE_t ratio = xSize/widthX
cdef np.int_t index,pointX,pointY
cdef np.int_t ii,zz
for ii in range(0,N):
pointX = int(floor((X[ref,ii,0]-mines[0])*ratio*0.99))
pointY = int(floor((X[ref,ii,1]-mines[1])*ratio*0.99))
index = int(pointX + xSize*pointY)
if grid[index] == -1:
grid[index] = ii
lastEl[index] = ii
else:
zz = lastEl[index]
eList[zz] = ii
lastEl[index] = ii
return 0
256x256网格中的1,000,000个点在50毫秒内。我现在可以接受这一点。
代码假设网格和元素列表(eList)的元素在输入时使用-1填充。
编辑:我删除了原始答案,因为我之前发布的解决方案是错误的。我不得不添加一个额外的内存缓冲区来存储每个索引的最后一个指针位置。
https://stackoverflow.com/questions/54642652
复制