UnaryOp(); Subscript; Slice; keyword(); fmt_Tensor();

This commit is contained in:
Yuyao Huang 2024-04-22 21:08:41 +08:00
parent c620d0b014
commit 3583f02c5b
11 changed files with 181 additions and 32 deletions

View File

@ -1,14 +1,23 @@
import torch
import torch.nn as nn
from tests.test_utils import *
from trace_commentor.check import Check
def test():
# @Commentor()
@Check()
def target(a, d=1, *b, c, k=1):
return a + k
@Commentor(_globals=globals())
def target():
x = torch.ones(4, 5)
for i in range(3):
x = x[..., None, :]
print(target(1, 2, 3, 4, c=5, k=2))
a = torch.randn(309, 110, 3)[:100]
f = nn.Linear(3, 128)
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
c = torch.concat((a, b), dim=-1)
return c.flatten()
target()
test()

View File

@ -13,7 +13,6 @@ def test_constant():
x = 2
print(x == 2)
"""
<callable> : print
True : x == 2
None : print(x == 2)
"""
@ -103,7 +102,6 @@ def test_for():
# odds.append(x)
"""
[] : odds
<callable> : odds.append
1 : x
None : odds.append(x)
"""
@ -129,7 +127,6 @@ def test_for():
# odds.append(x)
"""
[1] : odds
<callable> : odds.append
3 : x
None : odds.append(x)
"""
@ -155,7 +152,6 @@ def test_for():
# odds.append(x)
"""
[1, 3] : odds
<callable> : odds.append
5 : x
None : odds.append(x)
"""
@ -181,7 +177,6 @@ def test_for():
# odds.append(x)
"""
[1, 3, 5] : odds
<callable> : odds.append
7 : x
None : odds.append(x)
"""
@ -207,7 +202,6 @@ def test_for():
odds.append(x)
"""
[1, 3, 5, 7] : odds
<callable> : odds.append
9 : x
None : odds.append(x)
"""

View File

@ -46,7 +46,6 @@ def test_call_print():
def target():
print('This line will be printed.')
"""
<callable> : print
None : print('This line will be printed.')
"""
''')

107
tests/test_torch.py Normal file
View File

@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from test_utils import *
def test_torch():
@Commentor("<return>", _globals=globals())
def target():
x = torch.ones(4, 5)
for i in range(3):
x = x[..., None, :]
a = torch.randn(309, 110, 3)[:100]
f = nn.Linear(3, 128)
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
c = torch.concat((a, b), dim=-1)
return c.flatten()
asserteq_or_print(
target(), ''' def target():
x = torch.ones(4, 5)
"""
[4, 5] : torch.ones(4, 5)
----------
[4, 5] : x
"""
for i in range(3):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : i
"""
# x = x[..., None, :]
"""
[4, 5] : x
[4, 1, 5] : x[..., None, :]
----------
[4, 1, 5] : x
"""
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : i
"""
# x = x[..., None, :]
"""
[4, 1, 5] : x
[4, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 5] : x
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : i
"""
x = x[..., None, :]
"""
[4, 1, 1, 5] : x
[4, 1, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 1, 5] : x
"""
a = torch.randn(309, 110, 3)[:100]
"""
[309, 110, 3] : torch.randn(309, 110, 3)
[100, 110, 3] : torch.randn(309, 110, 3)[:100]
----------
[100, 110, 3] : a
"""
f = nn.Linear(3, 128)
"""
----------
"""
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
"""
[100, 110, 3] : a
[11000, 3] : a.reshape(-1, 3)
[11000, 128] : f(a.reshape(-1, 3))
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
----------
[100, 110, 128] : b
"""
c = torch.concat((a, b), dim=-1)
"""
[100, 110, 3] : a
[100, 110, 128] : b
[100, 110, 131] : torch.concat((a, b), dim=-1)
----------
[100, 110, 131] : c
"""
return c.flatten()
"""
[100, 110, 131] : c
[1441000] : c.flatten()
"""
''')

View File

@ -14,7 +14,6 @@ def test_assign():
myint = 7
print(myint)
"""
<callable> : print
7 : myint
None : print(myint)
"""

View File

@ -9,14 +9,14 @@ from functools import wraps
from . import handlers
from . import formatters
from . import flags
from .utils import sign, to_source, comment_to_file, isinstance_noexcept
from .utils import sign, to_source, comment_to_file
class Commentor(object):
def __init__(self, output="<stderr>", _globals=dict(), fmt=[], check=True, _exit=True) -> None:
self._locals = dict()
self._globals = dict().update(_globals)
self._globals = _globals
self._return = None
self._formatters = fmt + formatters.LIST
self._lines = []
@ -93,7 +93,11 @@ class Commentor(object):
def eval(self, node: ast.Expr, format=True):
src = node if type(node) == str else to_source(node)
obj = eval(src, self._globals, self._locals)
try:
obj = eval(src, self._globals, self._locals)
except Exception as e:
e.add_note(f"\tduring evaluating `{src}`")
raise e
if not format:
return obj
@ -108,9 +112,9 @@ class Commentor(object):
def get_formatter(self, obj):
for typ, fmt in self._formatters:
if isinstance_noexcept(obj, typ):
if isinstance(typ, type) and isinstance(obj, typ):
return fmt
elif callable(typ) and typ(obj):
elif not isinstance(typ, type) and callable(typ) and typ(obj):
return fmt
else:
return repr
@ -138,14 +142,14 @@ class Commentor(object):
return self.__append(sign(line, 2))
def check_support(self):
unimpl = []
unimpl = set()
for node in ast.walk(self.root):
if node.__class__ in flags.HANDLER_FREE_NODES:
continue
node_type = node.__class__.__name__
handler = getattr(handlers, node_type, None)
if handler is None:
unimpl.append(node_type)
unimpl.add(node_type)
if unimpl:
print("Unsupported nodes: ", ", ".join(unimpl))

View File

@ -23,6 +23,7 @@ ASSIGN_SILENT = [
]
HANDLER_FREE_NODES = [
ast.UAdd, ast.USub, ast.Not, ast.Invert,
ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult,
ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,
ast.Load, ast.Store,

View File

@ -1,7 +1,17 @@
def fmt_callable(fn):
return "<callable>"
import types
def silent(_):
return None
def Tensor(tensor):
return "Tensor" in tensor.__class__.__name__
def fmt_Tensor(tensor):
return f"{list(tensor.shape)}"
LIST = [
(callable, fmt_callable),
(callable, silent),
(Tensor, fmt_Tensor),
(types.ModuleType, silent),
]

View File

@ -1,6 +1,6 @@
from .definitions import FunctionDef, Return
from .statements import Pass, Assign
from .expressions import Expr, BinOp, Call, Compare, Attribute
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword
from .literals import Constant, Tuple, List
from .variables import Name
from .control_flow import If, For, Continue, Break

View File

@ -1,16 +1,29 @@
from ..utils import *
def Expr(self, cmtor):
cmtor.process(self.value)
def UnaryOp(self, cmtor):
if type(self.operand) == ast.Constant:
return
cmtor.process(self.operand)
cmtor.append_comment(cmtor.eval(self))
def BinOp(self, cmtor):
cmtor.process(self.left)
cmtor.process(self.right)
cmtor.append_comment(cmtor.eval(self))
def Compare(self, cmtor):
for cmp in self.comparators:
cmtor.process(cmp)
cmtor.append_comment(cmtor.eval(self))
def Call(self, cmtor):
cmtor.process(self.func)
for arg in self.args:
@ -19,6 +32,26 @@ def Call(self, cmtor):
cmtor.process(kwarg)
cmtor.append_comment(cmtor.eval(self))
def Attribute(self, cmtor):
cmtor.process(self.value)
cmtor.append_comment(cmtor.eval(self))
def Subscript(self, cmtor):
cmtor.process(self.value)
cmtor.process(self.slice)
cmtor.append_comment(cmtor.eval(self))
def Slice(self, cmtor):
if getattr(self, "lower", None) is not None:
cmtor.process(self.lower)
if getattr(self, "upper", None) is not None:
cmtor.process(self.upper)
if getattr(self, "step", None) is not None:
cmtor.process(self.step)
def keyword(self, cmtor):
cmtor.process(self.value)

View File

@ -53,10 +53,3 @@ def comment_to_file(code, file: str) -> bool:
else:
raise NotImplementedError(f"Unknown file protocal {file}")
return True
def isinstance_noexcept(_obj, _class_or_tuple):
try:
return isinstance(_obj, _class_or_tuple)
except TypeError:
return False