# trace_commentor/tests/test_torch.py
# Snapshot: 2024-04-19 03:14:12 +08:00 (38 lines, 610 B, Python)

import torch
import torch.nn as nn
from trace_commentor.parser import *
def test_func_no_arg():
    """Trace a zero-argument function that builds and combines torch tensors.

    ``target`` takes no arguments and is wrapped by ``analyse``, which comes in
    through the star import from ``trace_commentor.parser``. Calling it
    exercises the tracer on:

    - tensor creation;
    - an indexing update inside a loop (``x`` goes from (4, 5) to
      (4, 1, 1, 1, 5));
    - an ``nn.Linear`` layer;
    - reshapes, concatenation and flattening.

    The test makes no assertions. It presumably passes whenever the call does
    not raise; confirm how ``analyse`` treats exceptions.

    ``a`` is sliced to its first 100 rows, giving shape (100, 110, 3). The
    linear layer's output must therefore be reshaped to (100, 110, 128).
    Reshaping to the pre-slice size (309, 110, 128) would not match the
    element count and would make ``reshape`` raise. Concatenating along
    ``dim=-1`` also requires ``b`` to share ``a``'s leading (100, 110) dims.
    The result ``c`` is (100, 110, 131) before flattening.
    """

    @analyse
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]
        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(100, 110, 128)
        c = torch.concat((a, b), dim=-1)
        return c.flatten()

    # Blank line, presumably to separate the traced output from pytest's own
    # output when run with ``-s``.
    print()
    target()
def test_for_loop():
    """Trace a function whose only control flow is a ``for`` loop.

    ``target`` sets ``a`` to 1, increments it once per iteration of
    ``range(3)``, then prints the result (4). It is wrapped by ``analyse``,
    which comes in through the star import from ``trace_commentor.parser``.
    What that decorator records or prints is defined there, not in this file.
    The test makes no assertions.
    """

    # NOTE(review): the original indentation was lost. Here ``print(a)`` is
    # placed after the loop; confirm against upstream.
    @analyse
    def target():
        a = 1
        for i in range(3):
            a += 1
        print(a)

    # Blank line, presumably to separate the traced output from pytest's own
    # output when run with ``-s``.
    print()
    target()