trace_commentor/tests/test_torch.py
2024-04-23 18:11:15 +08:00

108 lines
2.6 KiB
Python

import torch
import torch.nn as nn
from test_utils import *
def test_torch() -> None:
    """Check that ``Commentor`` annotates a small torch program with tensor shapes.

    ``target`` is decorated with ``@Commentor("<return>")``. The result of
    calling it is passed to ``asserteq_or_print`` together with the expected
    annotated listing. ``asserteq_or_print`` comes from ``test_utils`` and is
    not shown here.

    Judging by that listing:

    - each traced statement is followed by a block giving the shapes of the
      evaluated sub-expressions above ``----------`` and the shapes of the
      assigned names below it;
    - every ``for`` iteration is shown separately, with the statements of
      earlier iterations rendered as comments.

    The listing records only shapes (plus the loop counter), so the random
    values from ``torch.randn`` and the randomly initialised ``nn.Linear``
    weights should not affect the comparison.

    NOTE(review): the listing repeats ``target``'s statements verbatim.
    Editing ``target`` (including adding comments or a docstring inside it)
    will therefore likely require updating the expected text as well.
    """

    @Commentor("<return>")
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]
        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)
        return c.flatten()

    # The expected text is presumably compared against Commentor's output
    # as-is (asserteq_or_print lives in test_utils), so its whitespace and
    # line order are significant; keep it in sync with ``target`` above.
    asserteq_or_print(
        target(), ''' def target():
x = torch.ones(4, 5)
"""
[4, 5] : torch.ones(4, 5)
----------
[4, 5] : x
"""
for i in range(3):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : i
"""
# x = x[..., None, :]
"""
[4, 5] : x
[4, 1, 5] : x[..., None, :]
----------
[4, 1, 5] : x
"""
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : i
"""
# x = x[..., None, :]
"""
[4, 1, 5] : x
[4, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 5] : x
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : i
"""
x = x[..., None, :]
"""
[4, 1, 1, 5] : x
[4, 1, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 1, 5] : x
"""
a = torch.randn(309, 110, 3)[:100]
"""
[309, 110, 3] : torch.randn(309, 110, 3)
[100, 110, 3] : torch.randn(309, 110, 3)[:100]
----------
[100, 110, 3] : a
"""
f = nn.Linear(3, 128)
"""
----------
"""
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
"""
[100, 110, 3] : a
[11000, 3] : a.reshape(-1, 3)
[11000, 128] : f(a.reshape(-1, 3))
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
----------
[100, 110, 128] : b
"""
c = torch.concat((a, b), dim=-1)
"""
[100, 110, 3] : a
[100, 110, 128] : b
[100, 110, 131] : torch.concat((a, b), dim=-1)
----------
[100, 110, 131] : c
"""
return c.flatten()
"""
[100, 110, 131] : c
[1441000] : c.flatten()
"""
''')