24 lines
449 B
Python
24 lines
449 B
Python
import torch
|
|
import torch.nn as nn
|
|
from tests.test_utils import *
|
|
|
|
def test():
    """Smoke-test ``Commentor`` on a short chain of tensor operations.

    Defines ``target``, which mixes repeated ``None``-indexing, slicing,
    reshaping, an ``nn.Linear`` layer and ``torch.concat``. It wraps ``target``
    with ``Commentor`` (brought in by ``from tests.test_utils import *``) and
    calls it once. Nothing is asserted here. Any checking or output is left to
    ``Commentor``.

    Shapes produced by the ops in ``target``:
      * ``x``: (4, 5) -> (4, 1, 5) -> (4, 1, 1, 5) -> (4, 1, 1, 1, 5)
      * ``a``: (309, 110, 3) sliced to (100, 110, 3)
      * ``b``: (11000, 3) -> Linear -> (11000, 128) -> (100, 110, 128)
      * ``c``: concat on the last dim -> (100, 110, 131), flattened on return

    ``x`` does not feed the return value. Presumably it exists only to give
    ``Commentor`` an indexing chain to process; confirm against
    tests/test_utils.
    """

    # NOTE(review): the body of ``target`` is the input under test. Assuming
    # Commentor reads or re-runs the function's source, keep these statements
    # (and any comments inside them) unchanged unless the expected Commentor
    # output is updated too.
    # ``globals()`` hands Commentor this module's namespace, where ``torch``
    # and ``nn`` are bound. Presumably it uses this to resolve those names.
    @Commentor(_globals=globals())
    def target():
        x = torch.ones(4, 5)
        for i in range(3):
            x = x[..., None, :]

        a = torch.randn(309, 110, 3)[:100]
        f = nn.Linear(3, 128)
        b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
        c = torch.concat((a, b), dim=-1)

        return c.flatten()

    target()
|
|
|
|
|
|
# Run the demo only when this file is executed directly (e.g. `python file.py`
# or `python -m ...`). An unguarded call would also fire on import. A pytest
# run would then execute test() once during collection and again as the
# collected test function, and any exception would show up as a collection
# error instead of a test failure.
if __name__ == "__main__":
    test()
|