Assign();Name();Call()
This commit is contained in:
parent
5017ce2736
commit
b86e2e5613
BIN
conda/env.yml
BIN
conda/env.yml
Binary file not shown.
@ -3,7 +3,7 @@ from test_utils import *
|
|||||||
|
|
||||||
def test_function_def():
|
def test_function_def():
|
||||||
|
|
||||||
@Commentor()
|
@Commentor("<return>")
|
||||||
def target():
|
def target():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from test_utils import asserteq_or_print
|
|||||||
|
|
||||||
def test_binop():
|
def test_binop():
|
||||||
|
|
||||||
@Commentor()
|
@Commentor("<return>")
|
||||||
def target():
|
def target():
|
||||||
1 + 1
|
1 + 1
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ def test_binop():
|
|||||||
|
|
||||||
def test_binop_cascade():
|
def test_binop_cascade():
|
||||||
|
|
||||||
@Commentor()
|
@Commentor("<return>")
|
||||||
def target():
|
def target():
|
||||||
1 + 1 + 1
|
1 + 1 + 1
|
||||||
|
|
||||||
@ -33,3 +33,19 @@ def test_binop_cascade():
|
|||||||
3 : 1 + 1 + 1
|
3 : 1 + 1 + 1
|
||||||
"""
|
"""
|
||||||
''')
|
''')
|
||||||
|
|
||||||
|
|
||||||
|
def test_call_print():
|
||||||
|
|
||||||
|
@Commentor("<return>")
|
||||||
|
def target():
|
||||||
|
print("This line will be printed.")
|
||||||
|
|
||||||
|
asserteq_or_print(
|
||||||
|
target(), '''
|
||||||
|
def target():
|
||||||
|
print('This line will be printed.')
|
||||||
|
"""
|
||||||
|
None : print('This line will be printed.')
|
||||||
|
"""
|
||||||
|
''')
|
||||||
|
|||||||
@ -3,12 +3,11 @@ from test_utils import *
|
|||||||
|
|
||||||
def test_constant():
|
def test_constant():
|
||||||
|
|
||||||
@Commentor()
|
@Commentor("<return>")
|
||||||
def target():
|
def target():
|
||||||
1
|
1
|
||||||
|
|
||||||
asserteq_or_print(target(),
|
asserteq_or_print(target(), '''
|
||||||
'''
|
|
||||||
def target():
|
def target():
|
||||||
1
|
1
|
||||||
''')
|
''')
|
||||||
|
|||||||
@ -1,37 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from trace_commentor.parser import *
|
|
||||||
|
|
||||||
|
|
||||||
def test_func_no_arg():
|
|
||||||
|
|
||||||
@analyse
|
|
||||||
def target():
|
|
||||||
|
|
||||||
x = torch.ones(4, 5)
|
|
||||||
for i in range(3):
|
|
||||||
x = x[..., None, :]
|
|
||||||
|
|
||||||
a = torch.randn(309, 110, 3)[:100]
|
|
||||||
f = nn.Linear(3, 128)
|
|
||||||
b = f(a.reshape(-1, 3)).reshape(309, 110, 128)
|
|
||||||
c = torch.concat((a, b), dim=-1)
|
|
||||||
|
|
||||||
return c.flatten()
|
|
||||||
|
|
||||||
print()
|
|
||||||
target()
|
|
||||||
|
|
||||||
|
|
||||||
def test_for_loop():
|
|
||||||
|
|
||||||
@analyse
|
|
||||||
def target():
|
|
||||||
a = 1
|
|
||||||
for i in range(3):
|
|
||||||
a += 1
|
|
||||||
print(a)
|
|
||||||
|
|
||||||
print()
|
|
||||||
target()
|
|
||||||
20
tests/test_variables.py
Normal file
20
tests/test_variables.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from test_utils import *
|
||||||
|
|
||||||
|
|
||||||
|
def test_assign():
|
||||||
|
|
||||||
|
@Commentor("<return>")
|
||||||
|
def target():
|
||||||
|
myint = 7
|
||||||
|
print(myint)
|
||||||
|
|
||||||
|
asserteq_or_print(
|
||||||
|
target(), '''
|
||||||
|
def target():
|
||||||
|
myint = 7
|
||||||
|
print(myint)
|
||||||
|
"""
|
||||||
|
7 : myint
|
||||||
|
None : print(myint)
|
||||||
|
"""
|
||||||
|
''')
|
||||||
@ -1,5 +1,7 @@
|
|||||||
import inspect
|
|
||||||
import ast
|
import ast
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
import rich
|
||||||
|
|
||||||
from inspect import getfullargspec
|
from inspect import getfullargspec
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@ -12,13 +14,15 @@ from .utils import sign, to_source
|
|||||||
|
|
||||||
class Commentor(object):
|
class Commentor(object):
|
||||||
|
|
||||||
def __init__(self, _formatters=[]) -> None:
|
def __init__(self, output="<stderr>", fmt=[]) -> None:
|
||||||
self._locals = dict()
|
self._locals = dict()
|
||||||
self._globals = dict()
|
self._globals = dict()
|
||||||
self._formatters = _formatters + formatters.LIST
|
self._return = None
|
||||||
|
self._formatters = fmt + formatters.LIST
|
||||||
self._lines = []
|
self._lines = []
|
||||||
self.indent = 0
|
self.indent = 0
|
||||||
self.state = flags.SOURCE
|
self.state = flags.SOURCE
|
||||||
|
self.file = output
|
||||||
|
|
||||||
def __call__(self, func):
|
def __call__(self, func):
|
||||||
|
|
||||||
@ -28,14 +32,30 @@ class Commentor(object):
|
|||||||
self.root = ast.parse(unindented_source).body[0]
|
self.root = ast.parse(unindented_source).body[0]
|
||||||
|
|
||||||
if flags.DEBUG:
|
if flags.DEBUG:
|
||||||
with open("test.log", "wt") as f:
|
with open("debug.log", "wt") as f:
|
||||||
print(ast.dump(self.root, indent=4), file=f)
|
print(ast.dump(self.root, indent=4), file=f)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def proxy_func(*args, **kwargs):
|
def proxy_func(*args, **kwargs):
|
||||||
|
# input {
|
||||||
self._locals = kwargs
|
self._locals = kwargs
|
||||||
|
# }
|
||||||
|
|
||||||
self.process(self.root)
|
self.process(self.root)
|
||||||
return "\n".join(self._lines)
|
|
||||||
|
# output {
|
||||||
|
code = "\n".join(self._lines)
|
||||||
|
if self.file == "<return>":
|
||||||
|
return code
|
||||||
|
elif self.file == "<stderr>":
|
||||||
|
rich.print(code, file=sys.stderr)
|
||||||
|
elif self.file == "<stdout>":
|
||||||
|
rich.print(code, file=sys.stdout)
|
||||||
|
else:
|
||||||
|
with open(self.file, "wt") as f:
|
||||||
|
rich.print(code, file=f)
|
||||||
|
return self._return
|
||||||
|
# }
|
||||||
|
|
||||||
return proxy_func
|
return proxy_func
|
||||||
|
|
||||||
@ -51,6 +71,10 @@ class Commentor(object):
|
|||||||
obj = eval(src, self._globals, self._locals)
|
obj = eval(src, self._globals, self._locals)
|
||||||
fmt = self.get_formatter(obj)
|
fmt = self.get_formatter(obj)
|
||||||
return f"{fmt(obj)} : {src}"
|
return f"{fmt(obj)} : {src}"
|
||||||
|
|
||||||
|
def exec(self, node: ast.stmt):
|
||||||
|
src = to_source(node)
|
||||||
|
exec(src, self._globals, self._locals)
|
||||||
|
|
||||||
def get_formatter(self, obj):
|
def get_formatter(self, obj):
|
||||||
for typ, fmt in self._formatters:
|
for typ, fmt in self._formatters:
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .definitions import FunctionDef
|
from .definitions import FunctionDef
|
||||||
from .statements import Pass
|
from .statements import Pass, Assign
|
||||||
from .expressions import Expr, BinOp
|
from .expressions import Expr, BinOp, Call
|
||||||
from .literals import Constant
|
from .literals import Constant
|
||||||
|
from .variables import Name
|
||||||
|
|||||||
@ -5,3 +5,10 @@ def BinOp(self, cmtor):
|
|||||||
cmtor.process(self.left)
|
cmtor.process(self.left)
|
||||||
cmtor.process(self.right)
|
cmtor.process(self.right)
|
||||||
cmtor.append_comment(cmtor.eval(self))
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|
||||||
|
def Call(self, cmtor):
|
||||||
|
for arg in self.args:
|
||||||
|
cmtor.process(arg)
|
||||||
|
for kwarg in self.keywords:
|
||||||
|
cmtor.process(kwarg)
|
||||||
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
|
|||||||
@ -1,2 +1,6 @@
|
|||||||
def Pass(self, cmtor):
|
def Pass(self, cmtor):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def Assign(self, cmtor):
|
||||||
|
cmtor.process(self.value)
|
||||||
|
cmtor.exec(self)
|
||||||
|
|||||||
2
trace_commentor/handlers/variables.py
Normal file
2
trace_commentor/handlers/variables.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
def Name(self, cmtor):
|
||||||
|
cmtor.append_comment(cmtor.eval(self))
|
||||||
Loading…
x
Reference in New Issue
Block a user