"""Run a small torch snippet through the ``Commentor`` decorator.

``target`` exercises indexing with ``...``/``None``, slicing, an
``nn.Linear`` layer, ``reshape``, ``torch.concat`` and ``flatten``.
"""
import torch
import torch.nn as nn

from tests.test_utils import *  # supplies Commentor (wildcard kept as-is)


def test():
    """Decorate ``target`` with ``Commentor`` and call it once.

    NOTE(review): presumably ``Commentor`` inspects/annotates the source of
    ``target`` (it is handed this module's ``globals()``); keep ``target``'s
    body verbatim so whatever it produces stays unchanged.

    Shapes implied by the code in ``target``:
      x: (4, 5) -> (4, 1, 1, 1, 5) after three ``x[..., None, :]`` steps
      a: (100, 110, 3)   (first 100 rows of a (309, 110, 3) tensor)
      b: (100, 110, 128) (Linear(3, 128) applied to the last axis of a)
      c: (100, 110, 131) (a and b concatenated on the last axis)
      returned: 1-D tensor of 100 * 110 * 131 = 1_441_000 elements
    """
    @Commentor(_globals=globals())
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]
        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)
        return c.flatten()

    target()


# Guard the direct call: if this module is imported by a test runner that
# then calls ``test`` itself (e.g. pytest collection), an unguarded call
# would run the test a second time as an import side effect.
if __name__ == "__main__":
    test()