diff --git a/experiment.py b/experiment.py index 990cccd..f9ca756 100644 --- a/experiment.py +++ b/experiment.py @@ -1,14 +1,23 @@ +import torch +import torch.nn as nn from tests.test_utils import * -from trace_commentor.check import Check def test(): - # @Commentor() - @Check() - def target(a, d=1, *b, c, k=1): - return a + k + @Commentor(_globals=globals()) + def target(): + x = torch.ones(4, 5) + for i in range(3): + x = x[..., None, :] - print(target(1, 2, 3, 4, c=5, k=2)) + a = torch.randn(309, 110, 3)[:100] + f = nn.Linear(3, 128) + b = f(a.reshape(-1, 3)).reshape(-1, 110, 128) + c = torch.concat((a, b), dim=-1) + + return c.flatten() + + target() test() diff --git a/tests/test_control_flow.py b/tests/test_control_flow.py index 7013ca1..4ce60f7 100644 --- a/tests/test_control_flow.py +++ b/tests/test_control_flow.py @@ -13,7 +13,6 @@ def test_constant(): x = 2 print(x == 2) """ - : print True : x == 2 None : print(x == 2) """ @@ -103,7 +102,6 @@ def test_for(): # odds.append(x) """ [] : odds - : odds.append 1 : x None : odds.append(x) """ @@ -129,7 +127,6 @@ def test_for(): # odds.append(x) """ [1] : odds - : odds.append 3 : x None : odds.append(x) """ @@ -155,7 +152,6 @@ def test_for(): # odds.append(x) """ [1, 3] : odds - : odds.append 5 : x None : odds.append(x) """ @@ -181,7 +177,6 @@ def test_for(): # odds.append(x) """ [1, 3, 5] : odds - : odds.append 7 : x None : odds.append(x) """ @@ -207,7 +202,6 @@ def test_for(): odds.append(x) """ [1, 3, 5, 7] : odds - : odds.append 9 : x None : odds.append(x) """ diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 3588c22..7ef6de8 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -46,7 +46,6 @@ def test_call_print(): def target(): print('This line will be printed.') """ - : print None : print('This line will be printed.') """ ''') diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 0000000..0fd971d --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn + +from test_utils import * + + +def test_torch(): + + @Commentor("", _globals=globals()) + def target(): + + x = torch.ones(4, 5) + for i in range(3): + x = x[..., None, :] + + a = torch.randn(309, 110, 3)[:100] + f = nn.Linear(3, 128) + b = f(a.reshape(-1, 3)).reshape(-1, 110, 128) + c = torch.concat((a, b), dim=-1) + + return c.flatten() + + asserteq_or_print( + target(), ''' def target(): + x = torch.ones(4, 5) + """ + [4, 5] : torch.ones(4, 5) + ---------- + [4, 5] : x + """ + for i in range(3): + + ###### !new iteration! ###### + """ + 0 : __REG__for_loop_iter_once + ---------- + 0 : i + """ + # x = x[..., None, :] + """ + [4, 5] : x + [4, 1, 5] : x[..., None, :] + ---------- + [4, 1, 5] : x + """ + + ###### !new iteration! ###### + """ + 1 : __REG__for_loop_iter_once + ---------- + 1 : i + """ + # x = x[..., None, :] + """ + [4, 1, 5] : x + [4, 1, 1, 5] : x[..., None, :] + ---------- + [4, 1, 1, 5] : x + """ + + ###### !new iteration! ###### + """ + 2 : __REG__for_loop_iter_once + ---------- + 2 : i + """ + x = x[..., None, :] + """ + [4, 1, 1, 5] : x + [4, 1, 1, 1, 5] : x[..., None, :] + ---------- + [4, 1, 1, 1, 5] : x + """ + a = torch.randn(309, 110, 3)[:100] + """ + [309, 110, 3] : torch.randn(309, 110, 3) + [100, 110, 3] : torch.randn(309, 110, 3)[:100] + ---------- + [100, 110, 3] : a + """ + f = nn.Linear(3, 128) + """ + ---------- + """ + b = f(a.reshape(-1, 3)).reshape(-1, 110, 128) + """ + [100, 110, 3] : a + [11000, 3] : a.reshape(-1, 3) + [11000, 128] : f(a.reshape(-1, 3)) + [100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128) + ---------- + [100, 110, 128] : b + """ + c = torch.concat((a, b), dim=-1) + """ + [100, 110, 3] : a + [100, 110, 128] : b + [100, 110, 131] : torch.concat((a, b), dim=-1) + ---------- + [100, 110, 131] : c + """ + return c.flatten() + """ + [100, 110, 131] : c + [1441000] : c.flatten() + """ +''') diff --git a/tests/test_variables.py b/tests/test_variables.py index 0f82584..422f2f5 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -14,7 +14,6 @@ def test_assign(): myint = 7 print(myint) """ - : print 7 : myint None : print(myint) """ diff --git a/trace_commentor/commentor.py b/trace_commentor/commentor.py index 8bb036b..7601dc2 100644 --- a/trace_commentor/commentor.py +++ b/trace_commentor/commentor.py @@ -9,14 +9,14 @@ from functools import wraps from . import handlers from . import formatters from . import flags -from .utils import sign, to_source, comment_to_file, isinstance_noexcept +from .utils import sign, to_source, comment_to_file class Commentor(object): def __init__(self, output="", _globals=dict(), fmt=[], check=True, _exit=True) -> None: self._locals = dict() - self._globals = dict().update(_globals) + self._globals = _globals self._return = None self._formatters = fmt + formatters.LIST self._lines = [] @@ -93,7 +93,11 @@ class Commentor(object): def eval(self, node: ast.Expr, format=True): src = node if type(node) == str else to_source(node) - obj = eval(src, self._globals, self._locals) + try: + obj = eval(src, self._globals, self._locals) + except Exception as e: + e.add_note(f"\tduring evaluating `{src}`") + raise e if not format: return obj @@ -108,9 +112,9 @@ class Commentor(object): def get_formatter(self, obj): for typ, fmt in self._formatters: - if isinstance_noexcept(obj, typ): + if isinstance(typ, type) and isinstance(obj, typ): return fmt - elif callable(typ) and typ(obj): + elif not isinstance(typ, type) and callable(typ) and typ(obj): return fmt else: return repr @@ -138,14 +142,14 @@ class Commentor(object): return self.__append(sign(line, 2)) def check_support(self): - unimpl = [] + unimpl = set() for node in ast.walk(self.root): if node.__class__ in flags.HANDLER_FREE_NODES: continue node_type = node.__class__.__name__ handler = getattr(handlers, node_type, None) if handler is None: - unimpl.append(node_type) + unimpl.add(node_type) if unimpl: print("Unsupported nodes: ", ", ".join(unimpl)) diff --git a/trace_commentor/flags.py b/trace_commentor/flags.py index be4769e..ee42a8f 100644 --- a/trace_commentor/flags.py +++ b/trace_commentor/flags.py @@ -23,6 +23,7 @@ ASSIGN_SILENT = [ ] HANDLER_FREE_NODES = [ + ast.UAdd, ast.USub, ast.Not, ast.Invert, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn, ast.Load, ast.Store, diff --git a/trace_commentor/formatters/__init__.py b/trace_commentor/formatters/__init__.py index b28efd2..7661074 100644 --- a/trace_commentor/formatters/__init__.py +++ b/trace_commentor/formatters/__init__.py @@ -1,7 +1,17 @@ -def fmt_callable(fn): - return "" +import types + +def silent(_): + return None + +def Tensor(tensor): + return "Tensor" in tensor.__class__.__name__ + +def fmt_Tensor(tensor): + return f"{list(tensor.shape)}" LIST = [ - (callable, fmt_callable), + (callable, silent), + (Tensor, fmt_Tensor), + (types.ModuleType, silent), ] diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index df80451..7b97318 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -1,6 +1,6 @@ from .definitions import FunctionDef, Return from .statements import Pass, Assign -from .expressions import Expr, BinOp, Call, Compare, Attribute +from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword from .literals import Constant, Tuple, List from .variables import Name from .control_flow import If, For, Continue, Break diff --git a/trace_commentor/handlers/expressions.py b/trace_commentor/handlers/expressions.py index 6bedc5e..864142f 100644 --- a/trace_commentor/handlers/expressions.py +++ b/trace_commentor/handlers/expressions.py @@ -1,16 +1,29 @@ +from ..utils import * + + def Expr(self, cmtor): cmtor.process(self.value) + +def UnaryOp(self, cmtor): + if type(self.operand) == ast.Constant: + return + cmtor.process(self.operand) + cmtor.append_comment(cmtor.eval(self)) + + def BinOp(self, cmtor): cmtor.process(self.left) cmtor.process(self.right) cmtor.append_comment(cmtor.eval(self)) + def Compare(self, cmtor): for cmp in self.comparators: cmtor.process(cmp) cmtor.append_comment(cmtor.eval(self)) + def Call(self, cmtor): cmtor.process(self.func) for arg in self.args: @@ -19,6 +32,26 @@ def Call(self, cmtor): cmtor.process(kwarg) cmtor.append_comment(cmtor.eval(self)) + def Attribute(self, cmtor): cmtor.process(self.value) cmtor.append_comment(cmtor.eval(self)) + + +def Subscript(self, cmtor): + cmtor.process(self.value) + cmtor.process(self.slice) + cmtor.append_comment(cmtor.eval(self)) + + +def Slice(self, cmtor): + if getattr(self, "lower", None) is not None: + cmtor.process(self.lower) + if getattr(self, "upper", None) is not None: + cmtor.process(self.upper) + if getattr(self, "step", None) is not None: + cmtor.process(self.step) + + +def keyword(self, cmtor): + cmtor.process(self.value) diff --git a/trace_commentor/utils.py b/trace_commentor/utils.py index c41e448..3b62ad7 100644 --- a/trace_commentor/utils.py +++ b/trace_commentor/utils.py @@ -53,10 +53,3 @@ def comment_to_file(code, file: str) -> bool: else: raise NotImplementedError(f"Unknown file protocal {file}") return True - - -def isinstance_noexcept(_obj, _class_or_tuple): - try: - return isinstance(_obj, _class_or_tuple) - except TypeError: - return False