小编典典

如何恢复平坦的Numpy数组的原始索引?

python

我有一个多维numpy数组,试图将其粘贴到熊猫数据框中。我想展平数组,并创建一个反映预展平数组索引的熊猫索引。

请注意,我使用3D来缩小示例,但我想将其推广到至少4D

A = np.random.rand(2,3,4)
array([[[ 0.43793885,  0.40078139,  0.48078691,  0.05334248],
    [ 0.76331509,  0.82514441,  0.86169078,  0.86496111],
    [ 0.75572665,  0.80860943,  0.79995337,  0.63123724]],

   [[ 0.20648946,  0.57042315,  0.71777265,  0.34155005],
    [ 0.30843717,  0.39381407,  0.12623462,  0.93481552],
    [ 0.3267771 ,  0.64097038,  0.30405215,  0.57726629]]])

df = pd.DataFrame(A.flatten())

我正在尝试生成像这样的x / y / z列:

           A  z  y  x
0   0.437939  0  0  0
1   0.400781  0  0  1
2   0.480787  0  0  2
3   0.053342  0  0  3
4   0.763315  0  1  0
5   0.825144  0  1  1
6   0.861691  0  1  2
7   0.864961  0  1  3
...
21  0.640970  1  2  1
22  0.304052  1  2  2
23  0.577266  1  2  3

我尝试使用进行设置,np.meshgrid但在某处出错:

dimnames = ['z', 'y', 'x']
ranges   = [ np.arange(x) for x in A.shape ]
ix       = [ x.flatten()  for x in np.meshgrid(*ranges) ]
for name, col in zip(dimnames, ix):
    df[name] = col
df = df.set_index(dimnames).squeeze()

这个结果看起来有些明智,但是索引是错误的:

df
z  y  x
0  0  0    0.437939
      1    0.400781
      2    0.480787
      3    0.053342
1  0  0    0.763315
      1    0.825144
      2    0.861691
      3    0.864961
0  1  0    0.755727
      1    0.808609
      2    0.799953
      3    0.631237
1  1  0    0.206489
      1    0.570423
      2    0.717773
      3    0.341550
0  2  0    0.308437
      1    0.393814
      2    0.126235
      3    0.934816
1  2  0    0.326777
      1    0.640970
      2    0.304052
      3    0.577266

print A[0,1,0]
0.76331508999999997

print print df.loc[0,1,0]
0.75572665000000006

如何创建索引列以反映的形状A


阅读 245

收藏
2021-01-20

共1个答案

小编典典

您可以使用pd.MultiIndex.from_product:

import numpy as np
import pandas as pd
import string

def using_multiindex(A, columns):
    shape = A.shape
    index = pd.MultiIndex.from_product([range(s)for s in shape], names=columns)
    df = pd.DataFrame({'A': A.flatten()}, index=index).reset_index()
    return df

A = np.array([[[ 0.43793885,  0.40078139,  0.48078691,  0.05334248],
    [ 0.76331509,  0.82514441,  0.86169078,  0.86496111],
    [ 0.75572665,  0.80860943,  0.79995337,  0.63123724]],

   [[ 0.20648946,  0.57042315,  0.71777265,  0.34155005],
    [ 0.30843717,  0.39381407,  0.12623462,  0.93481552],
    [ 0.3267771 ,  0.64097038,  0.30405215,  0.57726629]]])

df = using_multiindex(A, list('ZYX'))

产量

    Z  Y  X         A
0   0  0  0  0.437939
1   0  0  1  0.400781
2   0  0  2  0.480787
3   0  0  3  0.053342
...
21  1  2  1  0.640970
22  1  2  2  0.304052
23  1  2  3  0.577266

或者,如果性能是重中之重,请考虑使用senderle的cartesian_product。(请参见下面的代码。)

这是形状为(100,100,100)的A的基准:

In [321]: %timeit  using_cartesian_product(A, columns)
100 loops, best of 3: 13.8 ms per loop

In [318]: %timeit using_multiindex(A, columns)
10 loops, best of 3: 35.6 ms per loop

In [320]: %timeit indices_merged_arr_generic(A, columns)
10 loops, best of 3: 29.1 ms per loop

In [319]: %timeit using_product(A)
1 loop, best of 3: 461 ms per loop

这是我用于基准测试的设置:

import numpy as np
import pandas as pd
import functools
import itertools as IT
import string
product = IT.product

def cartesian_product_broadcasted(*arrays):
    """
    http://stackoverflow.com/a/11146645/190597 (senderle)
    """
    broadcastable = np.ix_(*arrays)
    broadcasted = np.broadcast_arrays(*broadcastable)
    dtype = np.result_type(*arrays)
    rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
    out = np.empty(rows * cols, dtype=dtype)
    start, end = 0, rows
    for a in broadcasted:
        out[start:end] = a.reshape(-1)
        start, end = end, end + rows
    return out.reshape(cols, rows).T

def using_cartesian_product(A, columns):
    shape = A.shape
    coords = cartesian_product_broadcasted(*[np.arange(s, dtype='int') for s in shape])
    df = pd.DataFrame(coords, columns=columns)
    df['A'] = A.flatten()
    return df

def using_multiindex(A, columns):
    shape = A.shape
    index = pd.MultiIndex.from_product([range(s)for s in shape], names=columns)
    df = pd.DataFrame({'A': A.flatten()}, index=index).reset_index()
    return df
2021-01-20