python - NumPy indexing with varying position


I have an array input_data of shape (A, B, C) and an array ind of shape (B,). I want to loop through the B axis and take the sum of elements C[B[i]] and C[B[i]+1], i.e. input_data[:, i, ind[i]] + input_data[:, i, ind[i]+1]. The desired output has shape (A, B). I have the following code, which works, but I feel it is inefficient because of the index-based looping through the B axis. Is there a more efficient method?

```python
import numpy as np

input_data = np.random.rand(2, 6, 10)
ind = [2, 3, 5, 6, 5, 4]

out = np.zeros((input_data.shape[0], input_data.shape[1]))
for i in range(len(ind)):
    # Two consecutive elements along the last axis, starting at ind[i]
    d = input_data[:, i, ind[i]:ind[i]+2]
    out[:, i] = np.sum(d, axis=1)
```

Edit, based on Divakar's answer:

```python
import timeit
import numpy as np

n = 1000

input_data = np.random.rand(10, n, 5000)
# Random start positions in [0, 4998], so ind[i] + 1 stays in range
ind = (4999 * np.random.rand(n)).astype(int)

def test_1():  # old loop-based method
    out = np.zeros((input_data.shape[0], input_data.shape[1]))
    for i in range(len(ind)):
        d = input_data[:, i, ind[i]:ind[i]+2]
        out[:, i] = np.sum(d, axis=1)
    return out

def test_2():
    extent = 2  # comes from the 2 in "ind[i]:ind[i]+2"
    m, n, r = input_data.shape
    idx = (np.arange(n)*r + ind)[:, None] + np.arange(extent)
    out1 = input_data.reshape(m, -1)[:, idx].reshape(m, n, -1).sum(2)
    return out1

print(timeit.timeit(stmt=test_1, number=1000))
print(timeit.timeit(stmt=test_2, number=1000))

print(np.all(test_1() == test_2(), keepdims=True))

# Output:
# 7.70429363482
# 0.392034666757
# [[ True]]
```

Here's a vectorized approach using linear indexing and broadcasting. We merge the last two axes of the input array, calculate the linear indices corresponding to those two axes, index with them and reshape back to 3D. Finally, we sum along the last axis to get the desired output. Implementation -

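```python
extent = 2  # comes from the 2 in "ind[i]:ind[i]+2"
m, n, r = input_data.shape
# Flat positions into the merged (B*C) axis: row i starts at i*r + ind[i]
idx = (np.arange(n)*r + ind)[:, None] + np.arange(extent)
# Gather with idx -> shape (m, n, extent), then sum over the window
out1 = input_data.reshape(m, -1)[:, idx].reshape(m, n, -1).sum(2)
```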
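To see what the linear indices look like, here is a small sketch on a tiny array. The shapes, the values of ind and the helper names linear_index_sum and loop_sum are illustrative assumptions, not part of the original answer; the helper builds idx the same way as the snippet above, and loop_sum is the loop from the question for comparison.

```python
import numpy as np

def linear_index_sum(input_data, ind, extent=2):
    # Hypothetical helper wrapping the linear-indexing approach
    m, n, r = input_data.shape
    # Row i of idx holds flat positions i*r + ind[i], ..., i*r + ind[i] + extent - 1
    idx = (np.arange(n) * r + np.asarray(ind))[:, None] + np.arange(extent)
    # Merge the last two axes, gather -> shape (m, n, extent), then sum the window
    return input_data.reshape(m, -1)[:, idx].reshape(m, n, -1).sum(2)

def loop_sum(input_data, ind, extent=2):
    # Reference implementation: the loop from the question
    out = np.zeros(input_data.shape[:2])
    for i in range(len(ind)):
        out[:, i] = input_data[:, i, ind[i]:ind[i] + extent].sum(axis=1)
    return out

# Tiny example: a 2 x 3 x 4 array holding 0..23
data = np.arange(24, dtype=float).reshape(2, 3, 4)
ind = [0, 2, 1]
m, n, r = data.shape
idx = (np.arange(n) * r + np.asarray(ind))[:, None] + np.arange(2)
print(idx)                          # rows: [0, 1], [6, 7], [9, 10]
print(linear_index_sum(data, ind))  # rows: [1, 13, 19], [25, 37, 43]
print(loop_sum(data, ind))          # same values as above
```

Row 1 of idx is [6, 7] because position ind[1] = 2 in row 1 sits at flat offset 1*4 + 2 = 6 once the last two axes are merged.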

If extent is always going to be 2, as stated in the question ("... sum of elements C[B[i]] and C[B[i]+1]"), you could simply do -

```python
m, n, r = input_data.shape
ind_arr = np.array(ind)
axis1_r = np.arange(n)
out2 = input_data[:, axis1_r, ind_arr] + input_data[:, axis1_r, ind_arr+1]
```
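For a general extent without merging axes by hand, NumPy's np.take_along_axis (available since NumPy 1.15) can gather along the last axis directly. This is a swapped-in alternative, not part of the original answer, and the helper name window_sum is hypothetical; a minimal sketch:

```python
import numpy as np

def window_sum(input_data, ind, extent=2):
    # Hypothetical helper: for each i on axis 1, sum the `extent` elements
    # starting at ind[i] along axis 2, for every entry on axis 0.
    # Index array of shape (1, B, extent); its axis 0 broadcasts against input_data.
    idx = np.asarray(ind)[None, :, None] + np.arange(extent)
    # picked[a, i, k] == input_data[a, i, ind[i] + k]  -> shape (A, B, extent)
    picked = np.take_along_axis(input_data, idx, axis=2)
    return picked.sum(axis=2)

input_data = np.random.rand(2, 6, 10)
ind = [2, 3, 5, 6, 5, 4]
out = window_sum(input_data, ind)
print(out.shape)  # (2, 6)
```

One difference from the flattened linear-index version: if ind[i] + extent exceeds C, indexing here raises an IndexError, whereas a flat index can silently spill into the next row of the merged axis.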
