# (file-viewer header captured with this paste: 108 lines, 2.6 KiB, Python)
import torch
import torch.nn as nn

from test_utils import *


def test_torch():
    """Check Commentor's annotated listing for a small torch function.

    ``target`` is decorated with ``Commentor("<return>", _globals=globals())``.
    Its return value is compared with the expected listing below via
    ``asserteq_or_print``. Both names come in through
    ``from test_utils import *``.

    The expected listing repeats target's source lines verbatim. So target's
    body must stay exactly as written: no added comments and no
    reformatting, unless the expected text is updated to match.

    In the expected listing:

    * Each statement is followed by a block of ``<value> : <expression>``
      lines. The values are tensor shapes, or plain ints for the loop
      index.
    * Every loop iteration except the last shows the loop body commented
      out (``# x = x[..., None, :]``).
    * Only shapes and ints appear, so the unseeded ``torch.randn`` values do
      not affect the comparison.
    """

    @Commentor("<return>", _globals=globals())
    def target():

        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]

        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)

        return c.flatten()

    # NOTE(review): this file reached review with all leading indentation
    # stripped, and possibly with runs of spaces collapsed. The whitespace
    # inside the expected literal below was therefore reconstructed:
    #   - source lines are indented as in target's definition above;
    #   - the closing quotes sit at the call's indentation.
    # Confirm this against the actual Commentor output, e.g. from the
    # printed value on a mismatch, if asserteq_or_print does that
    # (check test_utils).
    asserteq_or_print(
        target(), '''    def target():
        x = torch.ones(4, 5)
        """
        [4, 5] : torch.ones(4, 5)
        ----------
        [4, 5] : x
        """
        for i in range(3):

            ###### !new iteration! ######
            """
            0 : __REG__for_loop_iter_once
            ----------
            0 : i
            """
            # x = x[..., None, :]
            """
            [4, 5] : x
            [4, 1, 5] : x[..., None, :]
            ----------
            [4, 1, 5] : x
            """

            ###### !new iteration! ######
            """
            1 : __REG__for_loop_iter_once
            ----------
            1 : i
            """
            # x = x[..., None, :]
            """
            [4, 1, 5] : x
            [4, 1, 1, 5] : x[..., None, :]
            ----------
            [4, 1, 1, 5] : x
            """

            ###### !new iteration! ######
            """
            2 : __REG__for_loop_iter_once
            ----------
            2 : i
            """
            x = x[..., None, :]
            """
            [4, 1, 1, 5] : x
            [4, 1, 1, 1, 5] : x[..., None, :]
            ----------
            [4, 1, 1, 1, 5] : x
            """
        a = torch.randn(309, 110, 3)[:100]
        """
        [309, 110, 3] : torch.randn(309, 110, 3)
        [100, 110, 3] : torch.randn(309, 110, 3)[:100]
        ----------
        [100, 110, 3] : a
        """
        f = nn.Linear(3, 128)
        """
        ----------
        """
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        """
        [100, 110, 3] : a
        [11000, 3] : a.reshape(-1, 3)
        [11000, 128] : f(a.reshape(-1, 3))
        [100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        ----------
        [100, 110, 128] : b
        """
        c = torch.concat((a, b), dim=-1)
        """
        [100, 110, 3] : a
        [100, 110, 128] : b
        [100, 110, 131] : torch.concat((a, b), dim=-1)
        ----------
        [100, 110, 131] : c
        """
        return c.flatten()
        """
        [100, 110, 131] : c
        [1441000] : c.flatten()
        """
    ''')