# trace_commentor/tests/test_torch.py

import torch
import torch.nn as nn
from test_utils import *
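# test_utils (wildcard import) is expected to provide the Commentor decorator
# and asserteq_or_print helper used below.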


def test_torch():
    @Commentor("<return>")
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]
        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)
        return c.flatten()

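    # Calling the decorated function is expected to yield target's source
    # annotated with docstrings recording the tensor shapes at each statement.
    # For reference, c.flatten() holds 100 * 110 * 131 = 1,441,000 elements,
    # matching the final annotation in the expected output below.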
    asserteq_or_print(
        target(), '''
def target():
    x = torch.ones(4, 5)
    """
    Tensor((4, 5), f32) : torch.ones(4, 5)
    ----------
    Tensor((4, 5), f32) : x
    """
    for i in range(3):
        x = x[..., None, :]
        """
        Tensor((4, 1, 1, 5), f32) : x
        Tensor((4, 1, 1, 1, 5), f32) : x[..., None, :]
        ----------
        Tensor((4, 1, 1, 1, 5), f32) : x
        """
    a = torch.randn(309, 110, 3)[:100]
    """
    Tensor((309, 110, 3), f32) : torch.randn(309, 110, 3)
    Tensor((100, 110, 3), f32) : torch.randn(309, 110, 3)[:100]
    ----------
    Tensor((100, 110, 3), f32) : a
    """
    f = nn.Linear(3, 128)
    """
    ----------
    """
    b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
    """
    Tensor((100, 110, 3), f32) : a
    Tensor((11000, 3), f32) : a.reshape(-1, 3)
    Tensor((11000, 128), f32) : f(a.reshape(-1, 3))
    Tensor((100, 110, 128), f32) : f(a.reshape(-1, ... pe(-1, 110, 128)
    ----------
    Tensor((100, 110, 128), f32) : b
    """
    c = torch.concat((a, b), dim=-1)
    """
    Tensor((100, 110, 3), f32) : a
    Tensor((100, 110, 128), f32) : b
    Tensor((100, 110, 131), f32) : torch.concat((a, b), dim=-1)
    ----------
    Tensor((100, 110, 131), f32) : c
    """
    return c.flatten()
    """
    Tensor((100, 110, 131), f32) : c
    Tensor((1441000,), f32) : c.flatten()
    """
''')