
As an example, consider the following 5x5 matrix.

import torch
x = torch.rand(25).reshape(5, 5)
x
>>>
tensor([[0.9263, 0.6601, 0.3334, 0.6175, 0.6035],
        [0.2583, 0.6105, 0.3113, 0.2965, 0.9429],
        [0.0350, 0.9206, 0.8667, 0.8958, 0.7814],
        [0.8921, 0.8116, 0.7271, 0.1324, 0.8097],
        [0.7166, 0.7780, 0.5185, 0.4530, 0.4059]])


Also assume that I am given one column index per row.

drop_indices = torch.tensor([1, 4, 0, 3, 2])
drop_indices
>>>
tensor([1, 4, 0, 3, 2])


Now, for each row i, I want to erase the element at column drop_indices[i], leaving a 5x4 matrix.
That is, 0.6601 (column index 1) is removed from the first row, 0.9429 (column index 4) from the second row, and so on.
In other words, I want to get the following matrix as the result.

tensor([[0.9263, 0.3334, 0.6175, 0.6035],
        [0.2583, 0.6105, 0.3113, 0.2965],
        [0.9206, 0.8667, 0.8958, 0.7814],
        [0.8921, 0.8116, 0.7271, 0.8097],
        [0.7166, 0.7780, 0.4530, 0.4059]])

I could loop over the rows with a for loop and rebuild the result with concat, but in practice this needs to run fast on a matrix of about 1000 rows by 5000 columns. (It is part of a deep learning training loop.)

What approach should I take? Thank you very much.

  • Answer # 1

    I solved it myself as follows. Thank you very much.

    >>> import torch
    >>> x = torch.rand(25).reshape(5, 5)
    >>> drop_indices = torch.tensor([1, 4, 0, 3, 2])
    >>> drop_elements = x[torch.arange(5), drop_indices].reshape(5, 1)
    >>> mask = x == drop_elements
    >>> x_2 = x[~mask].reshape(5, 4)
    >>> x
    tensor([[0.3465, 0.5554, 0.9142, 0.1674, 0.1353],
            [0.9194, 0.2897, 0.7397, 0.5700, 0.9786],
            [0.5404, 0.5266, 0.4050, 0.4092, 0.1816],
            [0.9258, 0.0706, 0.9894, 0.8694, 0.3407],
            [0.8124, 0.0562, 0.1115, 0.4929, 0.2795]])
    >>> drop_indices
    tensor([1, 4, 0, 3, 2])
    >>> x_2
    tensor([[0.3465, 0.9142, 0.1674, 0.1353],
            [0.9194, 0.2897, 0.7397, 0.5700],
            [0.5266, 0.4050, 0.4092, 0.1816],
            [0.9258, 0.0706, 0.9894, 0.3407],
            [0.8124, 0.0562, 0.4929, 0.2795]])
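
    One caveat with this solution: `mask = x == drop_elements` matches by value, not by position. If a row happens to contain the same value in two columns, the mask flags both and the final reshape fails. A sketch of the same idea with the mask built directly from the indices, which avoids that edge case (variable names here mirror the answer above but are otherwise my own choice):

```python
import torch

# Same setup as the question: one column index to drop per row.
x = torch.rand(25).reshape(5, 5)
drop_indices = torch.tensor([1, 4, 0, 3, 2])

# Compare every column index against the per-row drop index:
# mask[i, j] is True exactly when j != drop_indices[i].
mask = torch.arange(x.size(1)) != drop_indices.unsqueeze(1)  # shape (5, 5)

# Boolean indexing keeps True entries in row-major order,
# so each row loses exactly one element and the reshape is safe.
x_2 = x[mask].reshape(x.size(0), x.size(1) - 1)  # shape (5, 4)
```

    Because the mask depends only on the indices, each row always keeps exactly columns minus one elements, even when values repeat, and the same code should work unchanged at 1000x5000.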