UnaryOp(); Subscript; Slice; keyword(); fmt_Tensor();
This commit is contained in:
parent
c620d0b014
commit
3583f02c5b
@ -1,14 +1,23 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tests.test_utils import *
|
||||
from trace_commentor.check import Check
|
||||
|
||||
def test():
|
||||
|
||||
# @Commentor()
|
||||
@Check()
|
||||
def target(a, d=1, *b, c, k=1):
|
||||
return a + k
|
||||
@Commentor(_globals=globals())
|
||||
def target():
|
||||
x = torch.ones(4, 5)
|
||||
for i in range(3):
|
||||
x = x[..., None, :]
|
||||
|
||||
print(target(1, 2, 3, 4, c=5, k=2))
|
||||
a = torch.randn(309, 110, 3)[:100]
|
||||
f = nn.Linear(3, 128)
|
||||
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||
c = torch.concat((a, b), dim=-1)
|
||||
|
||||
return c.flatten()
|
||||
|
||||
target()
|
||||
|
||||
|
||||
test()
|
||||
|
||||
@ -13,7 +13,6 @@ def test_constant():
|
||||
x = 2
|
||||
print(x == 2)
|
||||
"""
|
||||
<callable> : print
|
||||
True : x == 2
|
||||
None : print(x == 2)
|
||||
"""
|
||||
@ -103,7 +102,6 @@ def test_for():
|
||||
# odds.append(x)
|
||||
"""
|
||||
[] : odds
|
||||
<callable> : odds.append
|
||||
1 : x
|
||||
None : odds.append(x)
|
||||
"""
|
||||
@ -129,7 +127,6 @@ def test_for():
|
||||
# odds.append(x)
|
||||
"""
|
||||
[1] : odds
|
||||
<callable> : odds.append
|
||||
3 : x
|
||||
None : odds.append(x)
|
||||
"""
|
||||
@ -155,7 +152,6 @@ def test_for():
|
||||
# odds.append(x)
|
||||
"""
|
||||
[1, 3] : odds
|
||||
<callable> : odds.append
|
||||
5 : x
|
||||
None : odds.append(x)
|
||||
"""
|
||||
@ -181,7 +177,6 @@ def test_for():
|
||||
# odds.append(x)
|
||||
"""
|
||||
[1, 3, 5] : odds
|
||||
<callable> : odds.append
|
||||
7 : x
|
||||
None : odds.append(x)
|
||||
"""
|
||||
@ -207,7 +202,6 @@ def test_for():
|
||||
odds.append(x)
|
||||
"""
|
||||
[1, 3, 5, 7] : odds
|
||||
<callable> : odds.append
|
||||
9 : x
|
||||
None : odds.append(x)
|
||||
"""
|
||||
|
||||
@ -46,7 +46,6 @@ def test_call_print():
|
||||
def target():
|
||||
print('This line will be printed.')
|
||||
"""
|
||||
<callable> : print
|
||||
None : print('This line will be printed.')
|
||||
"""
|
||||
''')
|
||||
|
||||
107
tests/test_torch.py
Normal file
107
tests/test_torch.py
Normal file
@ -0,0 +1,107 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from test_utils import *
|
||||
|
||||
|
||||
def test_torch():
|
||||
|
||||
@Commentor("<return>", _globals=globals())
|
||||
def target():
|
||||
|
||||
x = torch.ones(4, 5)
|
||||
for i in range(3):
|
||||
x = x[..., None, :]
|
||||
|
||||
a = torch.randn(309, 110, 3)[:100]
|
||||
f = nn.Linear(3, 128)
|
||||
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||
c = torch.concat((a, b), dim=-1)
|
||||
|
||||
return c.flatten()
|
||||
|
||||
asserteq_or_print(
|
||||
target(), ''' def target():
|
||||
x = torch.ones(4, 5)
|
||||
"""
|
||||
[4, 5] : torch.ones(4, 5)
|
||||
----------
|
||||
[4, 5] : x
|
||||
"""
|
||||
for i in range(3):
|
||||
|
||||
###### !new iteration! ######
|
||||
"""
|
||||
0 : __REG__for_loop_iter_once
|
||||
----------
|
||||
0 : i
|
||||
"""
|
||||
# x = x[..., None, :]
|
||||
"""
|
||||
[4, 5] : x
|
||||
[4, 1, 5] : x[..., None, :]
|
||||
----------
|
||||
[4, 1, 5] : x
|
||||
"""
|
||||
|
||||
###### !new iteration! ######
|
||||
"""
|
||||
1 : __REG__for_loop_iter_once
|
||||
----------
|
||||
1 : i
|
||||
"""
|
||||
# x = x[..., None, :]
|
||||
"""
|
||||
[4, 1, 5] : x
|
||||
[4, 1, 1, 5] : x[..., None, :]
|
||||
----------
|
||||
[4, 1, 1, 5] : x
|
||||
"""
|
||||
|
||||
###### !new iteration! ######
|
||||
"""
|
||||
2 : __REG__for_loop_iter_once
|
||||
----------
|
||||
2 : i
|
||||
"""
|
||||
x = x[..., None, :]
|
||||
"""
|
||||
[4, 1, 1, 5] : x
|
||||
[4, 1, 1, 1, 5] : x[..., None, :]
|
||||
----------
|
||||
[4, 1, 1, 1, 5] : x
|
||||
"""
|
||||
a = torch.randn(309, 110, 3)[:100]
|
||||
"""
|
||||
[309, 110, 3] : torch.randn(309, 110, 3)
|
||||
[100, 110, 3] : torch.randn(309, 110, 3)[:100]
|
||||
----------
|
||||
[100, 110, 3] : a
|
||||
"""
|
||||
f = nn.Linear(3, 128)
|
||||
"""
|
||||
----------
|
||||
"""
|
||||
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||
"""
|
||||
[100, 110, 3] : a
|
||||
[11000, 3] : a.reshape(-1, 3)
|
||||
[11000, 128] : f(a.reshape(-1, 3))
|
||||
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||
----------
|
||||
[100, 110, 128] : b
|
||||
"""
|
||||
c = torch.concat((a, b), dim=-1)
|
||||
"""
|
||||
[100, 110, 3] : a
|
||||
[100, 110, 128] : b
|
||||
[100, 110, 131] : torch.concat((a, b), dim=-1)
|
||||
----------
|
||||
[100, 110, 131] : c
|
||||
"""
|
||||
return c.flatten()
|
||||
"""
|
||||
[100, 110, 131] : c
|
||||
[1441000] : c.flatten()
|
||||
"""
|
||||
''')
|
||||
@ -14,7 +14,6 @@ def test_assign():
|
||||
myint = 7
|
||||
print(myint)
|
||||
"""
|
||||
<callable> : print
|
||||
7 : myint
|
||||
None : print(myint)
|
||||
"""
|
||||
|
||||
@ -9,14 +9,14 @@ from functools import wraps
|
||||
from . import handlers
|
||||
from . import formatters
|
||||
from . import flags
|
||||
from .utils import sign, to_source, comment_to_file, isinstance_noexcept
|
||||
from .utils import sign, to_source, comment_to_file
|
||||
|
||||
|
||||
class Commentor(object):
|
||||
|
||||
def __init__(self, output="<stderr>", _globals=dict(), fmt=[], check=True, _exit=True) -> None:
|
||||
self._locals = dict()
|
||||
self._globals = dict().update(_globals)
|
||||
self._globals = _globals
|
||||
self._return = None
|
||||
self._formatters = fmt + formatters.LIST
|
||||
self._lines = []
|
||||
@ -93,7 +93,11 @@ class Commentor(object):
|
||||
|
||||
def eval(self, node: ast.Expr, format=True):
|
||||
src = node if type(node) == str else to_source(node)
|
||||
try:
|
||||
obj = eval(src, self._globals, self._locals)
|
||||
except Exception as e:
|
||||
e.add_note(f"\tduring evaluating `{src}`")
|
||||
raise e
|
||||
if not format:
|
||||
return obj
|
||||
|
||||
@ -108,9 +112,9 @@ class Commentor(object):
|
||||
|
||||
def get_formatter(self, obj):
|
||||
for typ, fmt in self._formatters:
|
||||
if isinstance_noexcept(obj, typ):
|
||||
if isinstance(typ, type) and isinstance(obj, typ):
|
||||
return fmt
|
||||
elif callable(typ) and typ(obj):
|
||||
elif not isinstance(typ, type) and callable(typ) and typ(obj):
|
||||
return fmt
|
||||
else:
|
||||
return repr
|
||||
@ -138,14 +142,14 @@ class Commentor(object):
|
||||
return self.__append(sign(line, 2))
|
||||
|
||||
def check_support(self):
|
||||
unimpl = []
|
||||
unimpl = set()
|
||||
for node in ast.walk(self.root):
|
||||
if node.__class__ in flags.HANDLER_FREE_NODES:
|
||||
continue
|
||||
node_type = node.__class__.__name__
|
||||
handler = getattr(handlers, node_type, None)
|
||||
if handler is None:
|
||||
unimpl.append(node_type)
|
||||
unimpl.add(node_type)
|
||||
|
||||
if unimpl:
|
||||
print("Unsupported nodes: ", ", ".join(unimpl))
|
||||
|
||||
@ -23,6 +23,7 @@ ASSIGN_SILENT = [
|
||||
]
|
||||
|
||||
HANDLER_FREE_NODES = [
|
||||
ast.UAdd, ast.USub, ast.Not, ast.Invert,
|
||||
ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult,
|
||||
ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,
|
||||
ast.Load, ast.Store,
|
||||
|
||||
@ -1,7 +1,17 @@
|
||||
def fmt_callable(fn):
|
||||
return "<callable>"
|
||||
import types
|
||||
|
||||
def silent(_):
|
||||
return None
|
||||
|
||||
def Tensor(tensor):
|
||||
return "Tensor" in tensor.__class__.__name__
|
||||
|
||||
def fmt_Tensor(tensor):
|
||||
return f"{list(tensor.shape)}"
|
||||
|
||||
|
||||
LIST = [
|
||||
(callable, fmt_callable),
|
||||
(callable, silent),
|
||||
(Tensor, fmt_Tensor),
|
||||
(types.ModuleType, silent),
|
||||
]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from .definitions import FunctionDef, Return
|
||||
from .statements import Pass, Assign
|
||||
from .expressions import Expr, BinOp, Call, Compare, Attribute
|
||||
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword
|
||||
from .literals import Constant, Tuple, List
|
||||
from .variables import Name
|
||||
from .control_flow import If, For, Continue, Break
|
||||
|
||||
@ -1,16 +1,29 @@
|
||||
from ..utils import *
|
||||
|
||||
|
||||
def Expr(self, cmtor):
|
||||
cmtor.process(self.value)
|
||||
|
||||
|
||||
def UnaryOp(self, cmtor):
|
||||
if type(self.operand) == ast.Constant:
|
||||
return
|
||||
cmtor.process(self.operand)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
|
||||
def BinOp(self, cmtor):
|
||||
cmtor.process(self.left)
|
||||
cmtor.process(self.right)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
|
||||
def Compare(self, cmtor):
|
||||
for cmp in self.comparators:
|
||||
cmtor.process(cmp)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
|
||||
def Call(self, cmtor):
|
||||
cmtor.process(self.func)
|
||||
for arg in self.args:
|
||||
@ -19,6 +32,26 @@ def Call(self, cmtor):
|
||||
cmtor.process(kwarg)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
|
||||
def Attribute(self, cmtor):
|
||||
cmtor.process(self.value)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
|
||||
def Subscript(self, cmtor):
|
||||
cmtor.process(self.value)
|
||||
cmtor.process(self.slice)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
|
||||
def Slice(self, cmtor):
|
||||
if getattr(self, "lower", None) is not None:
|
||||
cmtor.process(self.lower)
|
||||
if getattr(self, "upper", None) is not None:
|
||||
cmtor.process(self.upper)
|
||||
if getattr(self, "step", None) is not None:
|
||||
cmtor.process(self.step)
|
||||
|
||||
|
||||
def keyword(self, cmtor):
|
||||
cmtor.process(self.value)
|
||||
|
||||
@ -53,10 +53,3 @@ def comment_to_file(code, file: str) -> bool:
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown file protocal {file}")
|
||||
return True
|
||||
|
||||
|
||||
def isinstance_noexcept(_obj, _class_or_tuple):
|
||||
try:
|
||||
return isinstance(_obj, _class_or_tuple)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user