### Python 3.x - I want to get a matrix with the value of a certain column removed from each row of the matrix with PyTorch (speedu

As an example, consider the following 5x5 matrix.

```python
>>> import torch
>>> x = torch.rand(25).reshape(5, 5)
>>> x
tensor([[0.9263, 0.6601, 0.3334, 0.6175, 0.6035],
        [0.2583, 0.6105, 0.3113, 0.2965, 0.9429],
        [0.0350, 0.9206, 0.8667, 0.8958, 0.7814],
        [0.8921, 0.8116, 0.7271, 0.1324, 0.8097],
        [0.7166, 0.7780, 0.5185, 0.4530, 0.4059]])
```

Also assume that I am given one column index for each row:

```python
>>> drop_indices = torch.tensor([1, 4, 0, 3, 2])
>>> drop_indices
tensor([1, 4, 0, 3, 2])
```

For each row `i`, I want to remove the value in column `drop_indices[i]`, which leaves a 5x4 matrix.
For example, from row 0 I want to remove 0.6601 (column 1), and from row 1 I want to remove 0.9429 (column 4).
In other words, I want to get the following matrix as a result.

```
tensor([[0.9263, 0.3334, 0.6175, 0.6035],
        [0.2583, 0.6105, 0.3113, 0.2965],
        [0.9206, 0.8667, 0.8958, 0.7814],
        [0.8921, 0.8116, 0.7271, 0.8097],
        [0.7166, 0.7780, 0.4530, 0.4059]])
```

I can do this with a `for` loop plus concatenation, but in practice I need it to run fast on a matrix of about 1000 rows and 5000 columns (it is part of a deep learning training loop).
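For reference, the loop-and-concatenate approach might look roughly like this. It is a sketch, not the asker's actual code: the function name `drop_one_per_row_loop` is hypothetical, and `torch.cat`/`torch.stack` are my choice of concatenation calls. It is useful as a correctness baseline, but it issues several small tensor operations per row from Python, so it scales poorly to 1000 rows.

```python
import torch

def drop_one_per_row_loop(x: torch.Tensor, drop_indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical baseline: remove column drop_indices[i] from row i with a Python loop."""
    rows = []
    for i in range(x.shape[0]):
        j = int(drop_indices[i])
        # Keep the parts left and right of column j, then rejoin them into one row
        rows.append(torch.cat([x[i, :j], x[i, j + 1:]]))
    # Stack the shortened rows back into an (n_rows, n_cols - 1) matrix
    return torch.stack(rows)

x = torch.rand(5, 5)
drop_indices = torch.tensor([1, 4, 0, 3, 2])
x_2 = drop_one_per_row_loop(x, drop_indices)
print(x_2.shape)
```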

What approach should I take? Thank you in advance.

I solved it myself as follows. Thank you very much.

```python
>>> import torch
>>> x = torch.rand(25).reshape(5, 5)
>>> drop_indices = torch.tensor([1, 4, 0, 3, 2])
>>> # The removed values, one per row (not needed to build x_2)
>>> drop_elements = x[torch.arange(5), drop_indices].reshape(5, 1)
>>> # Boolean mask that is True only at (i, drop_indices[i])
>>> mask = torch.zeros_like(x, dtype=torch.bool)
>>> mask[torch.arange(5), drop_indices] = True
>>> # Keep everything else (row-major order), then restore the 5x4 shape
>>> x_2 = x[~mask].reshape(5, 4)
>>> x
tensor([[0.3465, 0.5554, 0.9142, 0.1674, 0.1353],
        [0.9194, 0.2897, 0.7397, 0.5700, 0.9786],
        [0.5404, 0.5266, 0.4050, 0.4092, 0.1816],
        [0.9258, 0.0706, 0.9894, 0.8694, 0.3407],
        [0.8124, 0.0562, 0.1115, 0.4929, 0.2795]])
>>> drop_indices
tensor([1, 4, 0, 3, 2])
>>> x_2
tensor([[0.3465, 0.9142, 0.1674, 0.1353],
        [0.9194, 0.2897, 0.7397, 0.5700],
        [0.5266, 0.4050, 0.4092, 0.1816],
        [0.9258, 0.0706, 0.9894, 0.3407],
        [0.8124, 0.0562, 0.4929, 0.2795]])
```
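The mask idea generalizes to any shape, including the 1000x5000 case. Below is a sketch of two hypothetical helpers, `drop_one_per_row_mask` and `drop_one_per_row_gather`. The mask version follows the same steps as the solution (mark one position per row, index with `~mask`, reshape); the gather version is a swapped-in alternative that computes the kept column indices directly and reads them with `torch.gather`. Both assume `drop_indices` is an int64 tensor on the same device as `x`, with values in `[0, n_cols)`.

```python
import torch

def drop_one_per_row_mask(x: torch.Tensor, drop_indices: torch.Tensor) -> torch.Tensor:
    """Boolean-mask version: works for any (n_rows, n_cols) matrix."""
    n_rows, n_cols = x.shape
    # True exactly at (i, drop_indices[i]) for every row i
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask[torch.arange(n_rows, device=x.device), drop_indices] = True
    # Boolean indexing returns a flat tensor in row-major order;
    # each row contributes n_cols - 1 elements, so reshape restores the rows
    return x[~mask].reshape(n_rows, n_cols - 1)

def drop_one_per_row_gather(x: torch.Tensor, drop_indices: torch.Tensor) -> torch.Tensor:
    """Gather version: build the kept column indices, no boolean indexing."""
    n_rows, n_cols = x.shape
    # Candidate columns 0 .. n_cols - 2, repeated for every row
    cols = torch.arange(n_cols - 1, device=x.device).expand(n_rows, -1)
    # Shift every column at or after the dropped one right by 1, skipping it
    cols = cols + (cols >= drop_indices.unsqueeze(1)).long()
    return x.gather(1, cols)

# Usage at the size mentioned in the question
x = torch.rand(1000, 5000)
drop_indices = torch.randint(0, 5000, (1000,))
x_mask = drop_one_per_row_mask(x, drop_indices)
x_gather = drop_one_per_row_gather(x, drop_indices)
print(x_mask.shape, torch.equal(x_mask, x_gather))
```

On the design choice, hedged since I have not benchmarked it here: `x[~mask]` has an output size that depends on the mask's contents, so on a CUDA device PyTorch generally has to synchronize with the GPU to compute it, whereas `gather` has a fixed output shape. Both paths are differentiable with respect to `x`.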