import torch
import torch.nn as nn

# Expected to supply `Commentor` and `asserteq_or_print`, which are used below
# but not defined in this file.
from test_utils import *


def test_torch():
    """Golden-output test for the `Commentor` decorator on PyTorch code.

    `target` is decorated with `Commentor`, and the result of calling the
    decorated `target()` is compared by `asserteq_or_print` against an
    annotated copy of target's own source (the triple-single-quoted literal).

    In that expected text, each statement is followed by a triple-quoted block:
      * shapes of the evaluated sub-expressions (or the value, for the
        integer loop counter),
      * a `----------` separator,
      * the shapes of the names the statement assigns.
    The `for` loop is unrolled, one `###### !new iteration! ######` marker per
    iteration. In the earlier iterations the loop body is shown commented out
    (`# x = x[..., None, :]`).

    Only shapes and loop indices appear in the expected text. The random
    contents of `torch.randn` and the `nn.Linear` weights therefore do not
    affect the comparison.
    """

    # NOTE(review): the expected text is a verbatim copy of target's body, so
    # Commentor presumably reads target's source. Editing target (including
    # adding comments inside it) likely requires updating the expected string.
    # The meaning of the first Commentor argument ("") is not visible here.
    # `_globals=globals()` presumably lets Commentor resolve `torch` / `nn`.
    # Confirm both in test_utils.
    @Commentor("", _globals=globals())
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]
        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)
        return c.flatten()

    asserteq_or_print(
        target(),
        # Expected annotated source of `target` (golden output).
        '''
def target():
    x = torch.ones(4, 5)
    """
    [4, 5] : torch.ones(4, 5)
    ----------
    [4, 5] : x
    """
    for i in range(3):
        ###### !new iteration! ######
        """
        0 : __REG__for_loop_iter_once
        ----------
        0 : i
        """
        # x = x[..., None, :]
        """
        [4, 5] : x
        [4, 1, 5] : x[..., None, :]
        ----------
        [4, 1, 5] : x
        """
        ###### !new iteration! ######
        """
        1 : __REG__for_loop_iter_once
        ----------
        1 : i
        """
        # x = x[..., None, :]
        """
        [4, 1, 5] : x
        [4, 1, 1, 5] : x[..., None, :]
        ----------
        [4, 1, 1, 5] : x
        """
        ###### !new iteration! ######
        """
        2 : __REG__for_loop_iter_once
        ----------
        2 : i
        """
        x = x[..., None, :]
        """
        [4, 1, 1, 5] : x
        [4, 1, 1, 1, 5] : x[..., None, :]
        ----------
        [4, 1, 1, 1, 5] : x
        """
    a = torch.randn(309, 110, 3)[:100]
    """
    [309, 110, 3] : torch.randn(309, 110, 3)
    [100, 110, 3] : torch.randn(309, 110, 3)[:100]
    ----------
    [100, 110, 3] : a
    """
    f = nn.Linear(3, 128)
    """
    ----------
    """
    b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
    """
    [100, 110, 3] : a
    [11000, 3] : a.reshape(-1, 3)
    [11000, 128] : f(a.reshape(-1, 3))
    [100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
    ----------
    [100, 110, 128] : b
    """
    c = torch.concat((a, b), dim=-1)
    """
    [100, 110, 3] : a
    [100, 110, 128] : b
    [100, 110, 131] : torch.concat((a, b), dim=-1)
    ----------
    [100, 110, 131] : c
    """
    return c.flatten()
    """
    [100, 110, 131] : c
    [1441000] : c.flatten()
    """
''')