Torch - Query matrix with another matrix


I have an m x n tensor (tensor 1) and a k x 2 tensor (tensor 2), and I wish to extract the values of tensor 1 using indices given by the rows of tensor 2. For example:

```
tensor1
  1   2   3   4   5
  6   7   8   9  10
 11  12  13  14  15
 16  17  18  19  20
[torch.DoubleTensor of size 4x5]

tensor2
 2  1
 3  5
 1  1
 4  3
[torch.DoubleTensor of size 4x2]
```

and the function would yield:

6 15 1 18 
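To pin down the intended semantics, here is a minimal plain-Lua sketch that uses nested tables instead of Torch tensors (an assumption made only so the snippet runs without Torch installed). The function name `gather_pairs` is hypothetical. Each row of the index list is a 1-based (row, column) pair, and the result collects `t[row][col]` in order:

```lua
-- Hypothetical reference implementation on plain Lua tables (no Torch needed).
-- t is an m x n nested table; pairs_list is a k x 2 nested table of 1-based indices.
local function gather_pairs(t, pairs_list)
    local res = {}
    for i, p in ipairs(pairs_list) do
        res[i] = t[p[1]][p[2]]   -- row p[1], column p[2]
    end
    return res
end

-- Build the 4x5 example matrix holding 1..20 in row-major order.
local tensor1 = {}
for r = 1, 4 do
    tensor1[r] = {}
    for c = 1, 5 do
        tensor1[r][c] = (r - 1) * 5 + c
    end
end

local tensor2 = { {2, 1}, {3, 5}, {1, 1}, {4, 3} }

local result = gather_pairs(tensor1, tensor2)
print(table.concat(result, " "))  -- 6 15 1 18
```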

The first solution that comes to mind is to simply loop through the indices and pick the corresponding values:

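```lua
function get_elems_simple(tensor, indices)
    local res = torch.Tensor(indices:size(1)):typeAs(tensor)
    local i = 0
    -- apply() calls the function once per element of res, in order
    res:apply(
        function ()
            i = i + 1
            -- row i of indices, turned into a storage, addresses one element of tensor
            return tensor[indices[i]:clone():storage()]
        end)
    return res
end
```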

Here `tensor[indices[i]:clone():storage()]` is just a generic way to pick an element from a multi-dimensional tensor. In the k-dimensional case it is analogous to `tensor[{indices[i][1], ... , indices[i][k]}]`.
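As a rough plain-Lua analogy (nested tables standing in for a Torch tensor, which is an assumption for illustration only), picking an element with a k-element index list simply descends one dimension per index. The helper `get_nd` is hypothetical:

```lua
-- Hypothetical helper: look up t[idx[1]][idx[2]]...[idx[k]] in a nested table,
-- mirroring what tensor[{i1, ..., ik}] does for a k-dimensional Torch tensor.
local function get_nd(t, idx)
    local v = t
    for d = 1, #idx do
        v = v[idx[d]]   -- descend one dimension per index
    end
    return v
end

-- 2 x 3 x 4 nested table whose entries encode their own coordinates.
local cube = {}
for i = 1, 2 do
    cube[i] = {}
    for j = 1, 3 do
        cube[i][j] = {}
        for l = 1, 4 do
            cube[i][j][l] = i * 100 + j * 10 + l
        end
    end
end

print(get_nd(cube, {2, 3, 4}))   -- 234
print(get_nd(cube, {1, 2, 1}))   -- 121
```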

This method works fine if you don't have to extract many values. The bottleneck is the :apply method: it cannot use many optimization techniques, such as SIMD instructions, because the function it executes is a black box. The job can be done far more efficiently: the :index method does exactly what is needed, but only for a one-dimensional tensor, so the multi-dimensional target and index tensors need to be flattened first:

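```lua
-- Convert k-dimensional, 1-based indices (one row per element) into
-- 1-based offsets into the row-major flattened tensor.
function flatten_indices(sp_indices, shape)
    sp_indices = sp_indices - 1          -- 0-based offset in each dimension
    local n_elem, n_dim = sp_indices:size(1), sp_indices:size(2)
    local flat_ind = torch.LongTensor(n_elem):fill(1)  -- start at 1 (1-based result)

    local mult = 1                       -- stride of dimension d
    for d = n_dim, 1, -1 do
        flat_ind:add(sp_indices[{{}, d}] * mult)
        mult = mult * shape[d]
    end
    return flat_ind
end

function get_elems_efficient(tensor, sp_indices)
    local flat_indices = flatten_indices(sp_indices, tensor:size())
    local flat_tensor = tensor:view(-1)  -- view(-1) requires a contiguous tensor
    return flat_tensor:index(1, flat_indices)
end
```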
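The arithmetic inside flatten_indices is ordinary row-major addressing: a 1-based index (i_1, ..., i_k) maps to 1 + sum over d of (i_d - 1) * stride_d, where stride_d is the product of the sizes of all dimensions after d. The loop runs from the last dimension to the first so that `mult` accumulates exactly that product. The sketch below replays the same loop on plain Lua tables (`flatten_one` is a hypothetical name and no Torch is involved). Because the 4x5 example holds 1..20 in row-major order, each flat index equals the stored value:

```lua
-- Hypothetical plain-Lua version of the flatten_indices loop for one index tuple.
-- idx and shape are arrays of equal length; indices are 1-based, as in Torch.
local function flatten_one(idx, shape)
    local flat = 1      -- start at 1 because Lua/Torch indexing is 1-based
    local mult = 1      -- stride of dimension d (product of the later sizes)
    for d = #shape, 1, -1 do
        flat = flat + (idx[d] - 1) * mult
        mult = mult * shape[d]
    end
    return flat
end

local shape = {4, 5}
local tensor2 = { {2, 1}, {3, 5}, {1, 1}, {4, 3} }

local flat = {}
for i, p in ipairs(tensor2) do
    flat[i] = flatten_one(p, shape)
end
print(table.concat(flat, " "))            -- 6 15 1 18

-- The same loop handles any number of dimensions, e.g. shape 2x3x4:
print(flatten_one({2, 3, 4}, {2, 3, 4}))  -- 24 (the last element)
print(flatten_one({2, 1, 1}, {2, 3, 4}))  -- 13 (first element of the second 3x4 slab)
```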

The difference is drastic:

```lua
n = 500000
k = 100
a = torch.rand(n, k)
ind = torch.LongTensor(n, 2)
ind[{{}, 1}]:random(1, n)
ind[{{}, 2}]:random(1, k)

elems1 = get_elems_simple(a, ind)      -- 4.53 sec
elems2 = get_elems_efficient(a, ind)   -- 0.05 sec

print(torch.all(elems1:eq(elems2)))    -- true
```
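The final check in the benchmark confirms that both methods pick the same elements. A small plain-Lua analogue of that check (no Torch; the sizes, seed and variable names are arbitrary assumptions) compares a nested (row, column) lookup against a lookup into a row-major flat copy for many random index pairs:

```lua
-- Hypothetical plain-Lua analogue of the benchmark's final comparison:
-- gather random (row, col) pairs by nested lookup and by flat row-major lookup.
math.randomseed(42)

local n, k, count = 50, 10, 1000

-- n x k nested table plus a flat copy laid out in row-major order.
local nested, flatarr = {}, {}
for r = 1, n do
    nested[r] = {}
    for c = 1, k do
        local v = math.random()
        nested[r][c] = v
        flatarr[(r - 1) * k + c] = v
    end
end

local all_equal = true
for _ = 1, count do
    local r, c = math.random(1, n), math.random(1, k)
    local by_pair = nested[r][c]                -- per-element lookup, as in the simple version
    local by_flat = flatarr[(r - 1) * k + c]    -- flat-offset lookup, as in the efficient version
    if by_pair ~= by_flat then
        all_equal = false
    end
end
print(all_equal)   -- true
```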
