UnaryOp(); Subscript; Slice; keyword(); fmt_Tensor();

This commit is contained in:
Yuyao Huang 2024-04-22 21:08:41 +08:00
parent c620d0b014
commit 3583f02c5b
11 changed files with 181 additions and 32 deletions

View File

@ -1,14 +1,23 @@
import torch
import torch.nn as nn
from tests.test_utils import * from tests.test_utils import *
from trace_commentor.check import Check
def test(): def test():
# @Commentor() @Commentor(_globals=globals())
@Check() def target():
def target(a, d=1, *b, c, k=1): x = torch.ones(4, 5)
return a + k for i in range(3):
x = x[..., None, :]
print(target(1, 2, 3, 4, c=5, k=2)) a = torch.randn(309, 110, 3)[:100]
f = nn.Linear(3, 128)
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
c = torch.concat((a, b), dim=-1)
return c.flatten()
target()
test() test()

View File

@ -13,7 +13,6 @@ def test_constant():
x = 2 x = 2
print(x == 2) print(x == 2)
""" """
<callable> : print
True : x == 2 True : x == 2
None : print(x == 2) None : print(x == 2)
""" """
@ -103,7 +102,6 @@ def test_for():
# odds.append(x) # odds.append(x)
""" """
[] : odds [] : odds
<callable> : odds.append
1 : x 1 : x
None : odds.append(x) None : odds.append(x)
""" """
@ -129,7 +127,6 @@ def test_for():
# odds.append(x) # odds.append(x)
""" """
[1] : odds [1] : odds
<callable> : odds.append
3 : x 3 : x
None : odds.append(x) None : odds.append(x)
""" """
@ -155,7 +152,6 @@ def test_for():
# odds.append(x) # odds.append(x)
""" """
[1, 3] : odds [1, 3] : odds
<callable> : odds.append
5 : x 5 : x
None : odds.append(x) None : odds.append(x)
""" """
@ -181,7 +177,6 @@ def test_for():
# odds.append(x) # odds.append(x)
""" """
[1, 3, 5] : odds [1, 3, 5] : odds
<callable> : odds.append
7 : x 7 : x
None : odds.append(x) None : odds.append(x)
""" """
@ -207,7 +202,6 @@ def test_for():
odds.append(x) odds.append(x)
""" """
[1, 3, 5, 7] : odds [1, 3, 5, 7] : odds
<callable> : odds.append
9 : x 9 : x
None : odds.append(x) None : odds.append(x)
""" """

View File

@ -46,7 +46,6 @@ def test_call_print():
def target(): def target():
print('This line will be printed.') print('This line will be printed.')
""" """
<callable> : print
None : print('This line will be printed.') None : print('This line will be printed.')
""" """
''') ''')

107
tests/test_torch.py Normal file
View File

@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from test_utils import *
def test_torch():
@Commentor("<return>", _globals=globals())
def target():
x = torch.ones(4, 5)
for i in range(3):
x = x[..., None, :]
a = torch.randn(309, 110, 3)[:100]
f = nn.Linear(3, 128)
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
c = torch.concat((a, b), dim=-1)
return c.flatten()
asserteq_or_print(
target(), ''' def target():
x = torch.ones(4, 5)
"""
[4, 5] : torch.ones(4, 5)
----------
[4, 5] : x
"""
for i in range(3):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : i
"""
# x = x[..., None, :]
"""
[4, 5] : x
[4, 1, 5] : x[..., None, :]
----------
[4, 1, 5] : x
"""
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : i
"""
# x = x[..., None, :]
"""
[4, 1, 5] : x
[4, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 5] : x
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : i
"""
x = x[..., None, :]
"""
[4, 1, 1, 5] : x
[4, 1, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 1, 5] : x
"""
a = torch.randn(309, 110, 3)[:100]
"""
[309, 110, 3] : torch.randn(309, 110, 3)
[100, 110, 3] : torch.randn(309, 110, 3)[:100]
----------
[100, 110, 3] : a
"""
f = nn.Linear(3, 128)
"""
----------
"""
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
"""
[100, 110, 3] : a
[11000, 3] : a.reshape(-1, 3)
[11000, 128] : f(a.reshape(-1, 3))
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
----------
[100, 110, 128] : b
"""
c = torch.concat((a, b), dim=-1)
"""
[100, 110, 3] : a
[100, 110, 128] : b
[100, 110, 131] : torch.concat((a, b), dim=-1)
----------
[100, 110, 131] : c
"""
return c.flatten()
"""
[100, 110, 131] : c
[1441000] : c.flatten()
"""
''')

View File

@ -14,7 +14,6 @@ def test_assign():
myint = 7 myint = 7
print(myint) print(myint)
""" """
<callable> : print
7 : myint 7 : myint
None : print(myint) None : print(myint)
""" """

View File

@ -9,14 +9,14 @@ from functools import wraps
from . import handlers from . import handlers
from . import formatters from . import formatters
from . import flags from . import flags
from .utils import sign, to_source, comment_to_file, isinstance_noexcept from .utils import sign, to_source, comment_to_file
class Commentor(object): class Commentor(object):
def __init__(self, output="<stderr>", _globals=dict(), fmt=[], check=True, _exit=True) -> None: def __init__(self, output="<stderr>", _globals=dict(), fmt=[], check=True, _exit=True) -> None:
self._locals = dict() self._locals = dict()
self._globals = dict().update(_globals) self._globals = _globals
self._return = None self._return = None
self._formatters = fmt + formatters.LIST self._formatters = fmt + formatters.LIST
self._lines = [] self._lines = []
@ -93,7 +93,11 @@ class Commentor(object):
def eval(self, node: ast.Expr, format=True): def eval(self, node: ast.Expr, format=True):
src = node if type(node) == str else to_source(node) src = node if type(node) == str else to_source(node)
try:
obj = eval(src, self._globals, self._locals) obj = eval(src, self._globals, self._locals)
except Exception as e:
e.add_note(f"\tduring evaluating `{src}`")
raise e
if not format: if not format:
return obj return obj
@ -108,9 +112,9 @@ class Commentor(object):
def get_formatter(self, obj): def get_formatter(self, obj):
for typ, fmt in self._formatters: for typ, fmt in self._formatters:
if isinstance_noexcept(obj, typ): if isinstance(typ, type) and isinstance(obj, typ):
return fmt return fmt
elif callable(typ) and typ(obj): elif not isinstance(typ, type) and callable(typ) and typ(obj):
return fmt return fmt
else: else:
return repr return repr
@ -138,14 +142,14 @@ class Commentor(object):
return self.__append(sign(line, 2)) return self.__append(sign(line, 2))
def check_support(self): def check_support(self):
unimpl = [] unimpl = set()
for node in ast.walk(self.root): for node in ast.walk(self.root):
if node.__class__ in flags.HANDLER_FREE_NODES: if node.__class__ in flags.HANDLER_FREE_NODES:
continue continue
node_type = node.__class__.__name__ node_type = node.__class__.__name__
handler = getattr(handlers, node_type, None) handler = getattr(handlers, node_type, None)
if handler is None: if handler is None:
unimpl.append(node_type) unimpl.add(node_type)
if unimpl: if unimpl:
print("Unsupported nodes: ", ", ".join(unimpl)) print("Unsupported nodes: ", ", ".join(unimpl))

View File

@ -23,6 +23,7 @@ ASSIGN_SILENT = [
] ]
HANDLER_FREE_NODES = [ HANDLER_FREE_NODES = [
ast.UAdd, ast.USub, ast.Not, ast.Invert,
ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult,
ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,
ast.Load, ast.Store, ast.Load, ast.Store,

View File

@ -1,7 +1,17 @@
def fmt_callable(fn): import types
return "<callable>"
def silent(_):
return None
def Tensor(tensor):
return "Tensor" in tensor.__class__.__name__
def fmt_Tensor(tensor):
return f"{list(tensor.shape)}"
LIST = [ LIST = [
(callable, fmt_callable), (callable, silent),
(Tensor, fmt_Tensor),
(types.ModuleType, silent),
] ]

View File

@ -1,6 +1,6 @@
from .definitions import FunctionDef, Return from .definitions import FunctionDef, Return
from .statements import Pass, Assign from .statements import Pass, Assign
from .expressions import Expr, BinOp, Call, Compare, Attribute from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword
from .literals import Constant, Tuple, List from .literals import Constant, Tuple, List
from .variables import Name from .variables import Name
from .control_flow import If, For, Continue, Break from .control_flow import If, For, Continue, Break

View File

@ -1,16 +1,29 @@
from ..utils import *
def Expr(self, cmtor): def Expr(self, cmtor):
cmtor.process(self.value) cmtor.process(self.value)
def UnaryOp(self, cmtor):
if type(self.operand) == ast.Constant:
return
cmtor.process(self.operand)
cmtor.append_comment(cmtor.eval(self))
def BinOp(self, cmtor): def BinOp(self, cmtor):
cmtor.process(self.left) cmtor.process(self.left)
cmtor.process(self.right) cmtor.process(self.right)
cmtor.append_comment(cmtor.eval(self)) cmtor.append_comment(cmtor.eval(self))
def Compare(self, cmtor): def Compare(self, cmtor):
for cmp in self.comparators: for cmp in self.comparators:
cmtor.process(cmp) cmtor.process(cmp)
cmtor.append_comment(cmtor.eval(self)) cmtor.append_comment(cmtor.eval(self))
def Call(self, cmtor): def Call(self, cmtor):
cmtor.process(self.func) cmtor.process(self.func)
for arg in self.args: for arg in self.args:
@ -19,6 +32,26 @@ def Call(self, cmtor):
cmtor.process(kwarg) cmtor.process(kwarg)
cmtor.append_comment(cmtor.eval(self)) cmtor.append_comment(cmtor.eval(self))
def Attribute(self, cmtor): def Attribute(self, cmtor):
cmtor.process(self.value) cmtor.process(self.value)
cmtor.append_comment(cmtor.eval(self)) cmtor.append_comment(cmtor.eval(self))
def Subscript(self, cmtor):
cmtor.process(self.value)
cmtor.process(self.slice)
cmtor.append_comment(cmtor.eval(self))
def Slice(self, cmtor):
if getattr(self, "lower", None) is not None:
cmtor.process(self.lower)
if getattr(self, "upper", None) is not None:
cmtor.process(self.upper)
if getattr(self, "step", None) is not None:
cmtor.process(self.step)
def keyword(self, cmtor):
cmtor.process(self.value)

View File

@ -53,10 +53,3 @@ def comment_to_file(code, file: str) -> bool:
else: else:
raise NotImplementedError(f"Unknown file protocal {file}") raise NotImplementedError(f"Unknown file protocal {file}")
return True return True
def isinstance_noexcept(_obj, _class_or_tuple):
try:
return isinstance(_obj, _class_or_tuple)
except TypeError:
return False