# trace_commentor/experiment.py (source-view header: 24 lines, 449 B, Python)

import torch
import torch.nn as nn
from tests.test_utils import *
def test():
    """Run a small tensor-manipulation demo under the ``Commentor`` tracer.

    ``Commentor`` is not imported explicitly here; presumably it arrives via the
    star import from ``tests.test_utils`` (NOTE(review): confirm). It is given
    this module's globals (``globals()`` inside a function returns the module
    namespace), presumably so the tracer can resolve names such as ``torch`` and
    ``nn`` -- confirm against ``Commentor``.

    The body of ``target`` is the subject of the experiment. It is deliberately
    left without inline comments because the tracer may echo its source. Its
    steps, with the shapes fixed by the explicit constructors and reshapes, are:

    * ``x``: ``torch.ones(4, 5)``. Three rounds of ``x[..., None, :]`` each
      insert a size-1 axis before the last one, ending at shape
      ``(4, 1, 1, 1, 5)``.
    * ``a``: a ``(309, 110, 3)`` standard-normal tensor sliced to its first 100
      rows, giving ``(100, 110, 3)``.
    * ``b``: ``a`` flattened to ``(-1, 3)``, projected by ``nn.Linear(3, 128)``,
      and reshaped back to ``(100, 110, 128)``.
    * ``c``: ``a`` and ``b`` concatenated on the last axis, giving
      ``(100, 110, 131)``. ``target`` returns ``c`` flattened to 1-D.

    The return value of ``target()`` is discarded here. Any visible output
    presumably comes from ``Commentor`` itself (NOTE(review): confirm).
    """

    @Commentor(_globals=globals())
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]
        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)
        return c.flatten()

    target()
# Run the traced experiment only when this file is executed as a script,
# so that importing the module does not trigger the torch work as a side effect.
if __name__ == "__main__":
    test()