UnaryOp(); Subscript; Slice; keyword(); fmt_Tensor();
This commit is contained in:
parent
c620d0b014
commit
3583f02c5b
@ -1,14 +1,23 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from tests.test_utils import *
|
from tests.test_utils import *
|
||||||
from trace_commentor.check import Check
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
|
|
||||||
# @Commentor()
|
@Commentor(_globals=globals())
|
||||||
@Check()
|
def target():
|
||||||
def target(a, d=1, *b, c, k=1):
|
x = torch.ones(4, 5)
|
||||||
return a + k
|
for i in range(3):
|
||||||
|
x = x[..., None, :]
|
||||||
|
|
||||||
print(target(1, 2, 3, 4, c=5, k=2))
|
a = torch.randn(309, 110, 3)[:100]
|
||||||
|
f = nn.Linear(3, 128)
|
||||||
|
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||||
|
c = torch.concat((a, b), dim=-1)
|
||||||
|
|
||||||
|
return c.flatten()
|
||||||
|
|
||||||
|
target()
|
||||||
|
|
||||||
|
|
||||||
test()
|
test()
|
||||||
|
|||||||
@ -13,7 +13,6 @@ def test_constant():
|
|||||||
x = 2
|
x = 2
|
||||||
print(x == 2)
|
print(x == 2)
|
||||||
"""
|
"""
|
||||||
<callable> : print
|
|
||||||
True : x == 2
|
True : x == 2
|
||||||
None : print(x == 2)
|
None : print(x == 2)
|
||||||
"""
|
"""
|
||||||
@ -103,7 +102,6 @@ def test_for():
|
|||||||
# odds.append(x)
|
# odds.append(x)
|
||||||
"""
|
"""
|
||||||
[] : odds
|
[] : odds
|
||||||
<callable> : odds.append
|
|
||||||
1 : x
|
1 : x
|
||||||
None : odds.append(x)
|
None : odds.append(x)
|
||||||
"""
|
"""
|
||||||
@ -129,7 +127,6 @@ def test_for():
|
|||||||
# odds.append(x)
|
# odds.append(x)
|
||||||
"""
|
"""
|
||||||
[1] : odds
|
[1] : odds
|
||||||
<callable> : odds.append
|
|
||||||
3 : x
|
3 : x
|
||||||
None : odds.append(x)
|
None : odds.append(x)
|
||||||
"""
|
"""
|
||||||
@ -155,7 +152,6 @@ def test_for():
|
|||||||
# odds.append(x)
|
# odds.append(x)
|
||||||
"""
|
"""
|
||||||
[1, 3] : odds
|
[1, 3] : odds
|
||||||
<callable> : odds.append
|
|
||||||
5 : x
|
5 : x
|
||||||
None : odds.append(x)
|
None : odds.append(x)
|
||||||
"""
|
"""
|
||||||
@ -181,7 +177,6 @@ def test_for():
|
|||||||
# odds.append(x)
|
# odds.append(x)
|
||||||
"""
|
"""
|
||||||
[1, 3, 5] : odds
|
[1, 3, 5] : odds
|
||||||
<callable> : odds.append
|
|
||||||
7 : x
|
7 : x
|
||||||
None : odds.append(x)
|
None : odds.append(x)
|
||||||
"""
|
"""
|
||||||
@ -207,7 +202,6 @@ def test_for():
|
|||||||
odds.append(x)
|
odds.append(x)
|
||||||
"""
|
"""
|
||||||
[1, 3, 5, 7] : odds
|
[1, 3, 5, 7] : odds
|
||||||
<callable> : odds.append
|
|
||||||
9 : x
|
9 : x
|
||||||
None : odds.append(x)
|
None : odds.append(x)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -46,7 +46,6 @@ def test_call_print():
|
|||||||
def target():
|
def target():
|
||||||
print('This line will be printed.')
|
print('This line will be printed.')
|
||||||
"""
|
"""
|
||||||
<callable> : print
|
|
||||||
None : print('This line will be printed.')
|
None : print('This line will be printed.')
|
||||||
"""
|
"""
|
||||||
''')
|
''')
|
||||||
|
|||||||
107
tests/test_torch.py
Normal file
107
tests/test_torch.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from test_utils import *
|
||||||
|
|
||||||
|
|
||||||
|
def test_torch():
|
||||||
|
|
||||||
|
@Commentor("<return>", _globals=globals())
|
||||||
|
def target():
|
||||||
|
|
||||||
|
x = torch.ones(4, 5)
|
||||||
|
for i in range(3):
|
||||||
|
x = x[..., None, :]
|
||||||
|
|
||||||
|
a = torch.randn(309, 110, 3)[:100]
|
||||||
|
f = nn.Linear(3, 128)
|
||||||
|
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||||
|
c = torch.concat((a, b), dim=-1)
|
||||||
|
|
||||||
|
return c.flatten()
|
||||||
|
|
||||||
|
asserteq_or_print(
|
||||||
|
target(), ''' def target():
|
||||||
|
x = torch.ones(4, 5)
|
||||||
|
"""
|
||||||
|
[4, 5] : torch.ones(4, 5)
|
||||||
|
----------
|
||||||
|
[4, 5] : x
|
||||||
|
"""
|
||||||
|
for i in range(3):
|
||||||
|
|
||||||
|
###### !new iteration! ######
|
||||||
|
"""
|
||||||
|
0 : __REG__for_loop_iter_once
|
||||||
|
----------
|
||||||
|
0 : i
|
||||||
|
"""
|
||||||
|
# x = x[..., None, :]
|
||||||
|
"""
|
||||||
|
[4, 5] : x
|
||||||
|
[4, 1, 5] : x[..., None, :]
|
||||||
|
----------
|
||||||
|
[4, 1, 5] : x
|
||||||
|
"""
|
||||||
|
|
||||||
|
###### !new iteration! ######
|
||||||
|
"""
|
||||||
|
1 : __REG__for_loop_iter_once
|
||||||
|
----------
|
||||||
|
1 : i
|
||||||
|
"""
|
||||||
|
# x = x[..., None, :]
|
||||||
|
"""
|
||||||
|
[4, 1, 5] : x
|
||||||
|
[4, 1, 1, 5] : x[..., None, :]
|
||||||
|
----------
|
||||||
|
[4, 1, 1, 5] : x
|
||||||
|
"""
|
||||||
|
|
||||||
|
###### !new iteration! ######
|
||||||
|
"""
|
||||||
|
2 : __REG__for_loop_iter_once
|
||||||
|
----------
|
||||||
|
2 : i
|
||||||
|
"""
|
||||||
|
x = x[..., None, :]
|
||||||
|
"""
|
||||||
|
[4, 1, 1, 5] : x
|
||||||
|
[4, 1, 1, 1, 5] : x[..., None, :]
|
||||||
|
----------
|
||||||
|
[4, 1, 1, 1, 5] : x
|
||||||
|
"""
|
||||||
|
a = torch.randn(309, 110, 3)[:100]
|
||||||
|
"""
|
||||||
|
[309, 110, 3] : torch.randn(309, 110, 3)
|
||||||
|
[100, 110, 3] : torch.randn(309, 110, 3)[:100]
|
||||||
|
----------
|
||||||
|
[100, 110, 3] : a
|
||||||
|
"""
|
||||||
|
f = nn.Linear(3, 128)
|
||||||
|
"""
|
||||||
|
----------
|
||||||
|
"""
|
||||||
|
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||||
|
"""
|
||||||
|
[100, 110, 3] : a
|
||||||
|
[11000, 3] : a.reshape(-1, 3)
|
||||||
|
[11000, 128] : f(a.reshape(-1, 3))
|
||||||
|
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
|
||||||
|
----------
|
||||||
|
[100, 110, 128] : b
|
||||||
|
"""
|
||||||
|
c = torch.concat((a, b), dim=-1)
|
||||||
|
"""
|
||||||
|
[100, 110, 3] : a
|
||||||
|
[100, 110, 128] : b
|
||||||
|
[100, 110, 131] : torch.concat((a, b), dim=-1)
|
||||||
|
----------
|
||||||
|
[100, 110, 131] : c
|
||||||
|
"""
|
||||||
|
return c.flatten()
|
||||||
|
"""
|
||||||
|
[100, 110, 131] : c
|
||||||
|
[1441000] : c.flatten()
|
||||||
|
"""
|
||||||
|
''')
|
||||||
@ -14,7 +14,6 @@ def test_assign():
|
|||||||
myint = 7
|
myint = 7
|
||||||
print(myint)
|
print(myint)
|
||||||
"""
|
"""
|
||||||
<callable> : print
|
|
||||||
7 : myint
|
7 : myint
|
||||||
None : print(myint)
|
None : print(myint)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -9,14 +9,14 @@ from functools import wraps
|
|||||||
from . import handlers
|
from . import handlers
|
||||||
from . import formatters
|
from . import formatters
|
||||||
from . import flags
|
from . import flags
|
||||||
from .utils import sign, to_source, comment_to_file, isinstance_noexcept
|
from .utils import sign, to_source, comment_to_file
|
||||||
|
|
||||||
|
|
||||||
class Commentor(object):
|
class Commentor(object):
|
||||||
|
|
||||||
def __init__(self, output="<stderr>", _globals=dict(), fmt=[], check=True, _exit=True) -> None:
|
def __init__(self, output="<stderr>", _globals=dict(), fmt=[], check=True, _exit=True) -> None:
|
||||||
self._locals = dict()
|
self._locals = dict()
|
||||||
self._globals = dict().update(_globals)
|
self._globals = _globals
|
||||||
self._return = None
|
self._return = None
|
||||||
self._formatters = fmt + formatters.LIST
|
self._formatters = fmt + formatters.LIST
|
||||||
self._lines = []
|
self._lines = []
|
||||||
@ -93,7 +93,11 @@ class Commentor(object):
|
|||||||
|
|
||||||
def eval(self, node: ast.Expr, format=True):
|
def eval(self, node: ast.Expr, format=True):
|
||||||
src = node if type(node) == str else to_source(node)
|
src = node if type(node) == str else to_source(node)
|
||||||
obj = eval(src, self._globals, self._locals)
|
try:
|
||||||
|
obj = eval(src, self._globals, self._locals)
|
||||||
|
except Exception as e:
|
||||||
|
e.add_note(f"\tduring evaluating `{src}`")
|
||||||
|
raise e
|
||||||
if not format:
|
if not format:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@ -108,9 +112,9 @@ class Commentor(object):
|
|||||||
|
|
||||||
def get_formatter(self, obj):
|
def get_formatter(self, obj):
|
||||||
for typ, fmt in self._formatters:
|
for typ, fmt in self._formatters:
|
||||||
if isinstance_noexcept(obj, typ):
|
if isinstance(typ, type) and isinstance(obj, typ):
|
||||||
return fmt
|
return fmt
|
||||||
elif callable(typ) and typ(obj):
|
elif not isinstance(typ, type) and callable(typ) and typ(obj):
|
||||||
return fmt
|
return fmt
|
||||||
else:
|
else:
|
||||||
return repr
|
return repr
|
||||||
@ -138,14 +142,14 @@ class Commentor(object):
|
|||||||
return self.__append(sign(line, 2))
|
return self.__append(sign(line, 2))
|
||||||
|
|
||||||
def check_support(self):
|
def check_support(self):
|
||||||
unimpl = []
|
unimpl = set()
|
||||||
for node in ast.walk(self.root):
|
for node in ast.walk(self.root):
|
||||||
if node.__class__ in flags.HANDLER_FREE_NODES:
|
if node.__class__ in flags.HANDLER_FREE_NODES:
|
||||||
continue
|
continue
|
||||||
node_type = node.__class__.__name__
|
node_type = node.__class__.__name__
|
||||||
handler = getattr(handlers, node_type, None)
|
handler = getattr(handlers, node_type, None)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
unimpl.append(node_type)
|
unimpl.add(node_type)
|
||||||
|
|
||||||
if unimpl:
|
if unimpl:
|
||||||
print("Unsupported nodes: ", ", ".join(unimpl))
|
print("Unsupported nodes: ", ", ".join(unimpl))
|
||||||
|
|||||||
@ -23,6 +23,7 @@ ASSIGN_SILENT = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
HANDLER_FREE_NODES = [
|
HANDLER_FREE_NODES = [
|
||||||
|
ast.UAdd, ast.USub, ast.Not, ast.Invert,
|
||||||
ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult,
|
ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.MatMult,
|
||||||
ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,
|
ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,
|
||||||
ast.Load, ast.Store,
|
ast.Load, ast.Store,
|
||||||
|
|||||||
@ -1,7 +1,17 @@
|
|||||||
def fmt_callable(fn):
|
import types
|
||||||
return "<callable>"
|
|
||||||
|
def silent(_):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def Tensor(tensor):
|
||||||
|
return "Tensor" in tensor.__class__.__name__
|
||||||
|
|
||||||
|
def fmt_Tensor(tensor):
|
||||||
|
return f"{list(tensor.shape)}"
|
||||||
|
|
||||||
|
|
||||||
LIST = [
|
LIST = [
|
||||||
(callable, fmt_callable),
|
(callable, silent),
|
||||||
|
(Tensor, fmt_Tensor),
|
||||||
|
(types.ModuleType, silent),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from .definitions import FunctionDef, Return
|
from .definitions import FunctionDef, Return
|
||||||
from .statements import Pass, Assign
|
from .statements import Pass, Assign
|
||||||
from .expressions import Expr, BinOp, Call, Compare, Attribute
|
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword
|
||||||
from .literals import Constant, Tuple, List
|
from .literals import Constant, Tuple, List
|
||||||
from .variables import Name
|
from .variables import Name
|
||||||
from .control_flow import If, For, Continue, Break
|
from .control_flow import If, For, Continue, Break
|
||||||
|
|||||||
@ -1,16 +1,29 @@
|
|||||||
|
from ..utils import *
|
||||||
|
|
||||||
|
|
||||||
def Expr(self, cmtor):
|
def Expr(self, cmtor):
|
||||||
cmtor.process(self.value)
|
cmtor.process(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
def UnaryOp(self, cmtor):
|
||||||
|
if type(self.operand) == ast.Constant:
|
||||||
|
return
|
||||||
|
cmtor.process(self.operand)
|
||||||
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
|
||||||
def BinOp(self, cmtor):
|
def BinOp(self, cmtor):
|
||||||
cmtor.process(self.left)
|
cmtor.process(self.left)
|
||||||
cmtor.process(self.right)
|
cmtor.process(self.right)
|
||||||
cmtor.append_comment(cmtor.eval(self))
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
|
||||||
def Compare(self, cmtor):
|
def Compare(self, cmtor):
|
||||||
for cmp in self.comparators:
|
for cmp in self.comparators:
|
||||||
cmtor.process(cmp)
|
cmtor.process(cmp)
|
||||||
cmtor.append_comment(cmtor.eval(self))
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
|
||||||
def Call(self, cmtor):
|
def Call(self, cmtor):
|
||||||
cmtor.process(self.func)
|
cmtor.process(self.func)
|
||||||
for arg in self.args:
|
for arg in self.args:
|
||||||
@ -19,6 +32,26 @@ def Call(self, cmtor):
|
|||||||
cmtor.process(kwarg)
|
cmtor.process(kwarg)
|
||||||
cmtor.append_comment(cmtor.eval(self))
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
|
||||||
def Attribute(self, cmtor):
|
def Attribute(self, cmtor):
|
||||||
cmtor.process(self.value)
|
cmtor.process(self.value)
|
||||||
cmtor.append_comment(cmtor.eval(self))
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
|
||||||
|
def Subscript(self, cmtor):
|
||||||
|
cmtor.process(self.value)
|
||||||
|
cmtor.process(self.slice)
|
||||||
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
|
||||||
|
def Slice(self, cmtor):
|
||||||
|
if getattr(self, "lower", None) is not None:
|
||||||
|
cmtor.process(self.lower)
|
||||||
|
if getattr(self, "upper", None) is not None:
|
||||||
|
cmtor.process(self.upper)
|
||||||
|
if getattr(self, "step", None) is not None:
|
||||||
|
cmtor.process(self.step)
|
||||||
|
|
||||||
|
|
||||||
|
def keyword(self, cmtor):
|
||||||
|
cmtor.process(self.value)
|
||||||
|
|||||||
@ -53,10 +53,3 @@ def comment_to_file(code, file: str) -> bool:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown file protocal {file}")
|
raise NotImplementedError(f"Unknown file protocal {file}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def isinstance_noexcept(_obj, _class_or_tuple):
|
|
||||||
try:
|
|
||||||
return isinstance(_obj, _class_or_tuple)
|
|
||||||
except TypeError:
|
|
||||||
return False
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user