Assign();Name();Call()
This commit is contained in:
parent
5017ce2736
commit
b86e2e5613
BIN
conda/env.yml
BIN
conda/env.yml
Binary file not shown.
@ -3,7 +3,7 @@ from test_utils import *
|
||||
|
||||
def test_function_def():
|
||||
|
||||
@Commentor()
|
||||
@Commentor("<return>")
|
||||
def target():
|
||||
pass
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from test_utils import asserteq_or_print
|
||||
|
||||
def test_binop():
|
||||
|
||||
@Commentor()
|
||||
@Commentor("<return>")
|
||||
def target():
|
||||
1 + 1
|
||||
|
||||
@ -20,7 +20,7 @@ def test_binop():
|
||||
|
||||
def test_binop_cascade():
|
||||
|
||||
@Commentor()
|
||||
@Commentor("<return>")
|
||||
def target():
|
||||
1 + 1 + 1
|
||||
|
||||
@ -33,3 +33,19 @@ def test_binop_cascade():
|
||||
3 : 1 + 1 + 1
|
||||
"""
|
||||
''')
|
||||
|
||||
|
||||
def test_call_print():
|
||||
|
||||
@Commentor("<return>")
|
||||
def target():
|
||||
print("This line will be printed.")
|
||||
|
||||
asserteq_or_print(
|
||||
target(), '''
|
||||
def target():
|
||||
print('This line will be printed.')
|
||||
"""
|
||||
None : print('This line will be printed.')
|
||||
"""
|
||||
''')
|
||||
|
||||
@ -3,12 +3,11 @@ from test_utils import *
|
||||
|
||||
def test_constant():
|
||||
|
||||
@Commentor()
|
||||
@Commentor("<return>")
|
||||
def target():
|
||||
1
|
||||
|
||||
asserteq_or_print(target(),
|
||||
'''
|
||||
|
||||
asserteq_or_print(target(), '''
|
||||
def target():
|
||||
1
|
||||
''')
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from trace_commentor.parser import *
|
||||
|
||||
|
||||
def test_func_no_arg():
|
||||
|
||||
@analyse
|
||||
def target():
|
||||
|
||||
x = torch.ones(4, 5)
|
||||
for i in range(3):
|
||||
x = x[..., None, :]
|
||||
|
||||
a = torch.randn(309, 110, 3)[:100]
|
||||
f = nn.Linear(3, 128)
|
||||
b = f(a.reshape(-1, 3)).reshape(309, 110, 128)
|
||||
c = torch.concat((a, b), dim=-1)
|
||||
|
||||
return c.flatten()
|
||||
|
||||
print()
|
||||
target()
|
||||
|
||||
|
||||
def test_for_loop():
|
||||
|
||||
@analyse
|
||||
def target():
|
||||
a = 1
|
||||
for i in range(3):
|
||||
a += 1
|
||||
print(a)
|
||||
|
||||
print()
|
||||
target()
|
||||
20
tests/test_variables.py
Normal file
20
tests/test_variables.py
Normal file
@ -0,0 +1,20 @@
|
||||
from test_utils import *
|
||||
|
||||
|
||||
def test_assign():
|
||||
|
||||
@Commentor("<return>")
|
||||
def target():
|
||||
myint = 7
|
||||
print(myint)
|
||||
|
||||
asserteq_or_print(
|
||||
target(), '''
|
||||
def target():
|
||||
myint = 7
|
||||
print(myint)
|
||||
"""
|
||||
7 : myint
|
||||
None : print(myint)
|
||||
"""
|
||||
''')
|
||||
@ -1,5 +1,7 @@
|
||||
import inspect
|
||||
import ast
|
||||
import inspect
|
||||
import sys
|
||||
import rich
|
||||
|
||||
from inspect import getfullargspec
|
||||
from functools import wraps
|
||||
@ -12,13 +14,15 @@ from .utils import sign, to_source
|
||||
|
||||
class Commentor(object):
|
||||
|
||||
def __init__(self, _formatters=[]) -> None:
|
||||
def __init__(self, output="<stderr>", fmt=[]) -> None:
|
||||
self._locals = dict()
|
||||
self._globals = dict()
|
||||
self._formatters = _formatters + formatters.LIST
|
||||
self._return = None
|
||||
self._formatters = fmt + formatters.LIST
|
||||
self._lines = []
|
||||
self.indent = 0
|
||||
self.state = flags.SOURCE
|
||||
self.file = output
|
||||
|
||||
def __call__(self, func):
|
||||
|
||||
@ -28,14 +32,30 @@ class Commentor(object):
|
||||
self.root = ast.parse(unindented_source).body[0]
|
||||
|
||||
if flags.DEBUG:
|
||||
with open("test.log", "wt") as f:
|
||||
with open("debug.log", "wt") as f:
|
||||
print(ast.dump(self.root, indent=4), file=f)
|
||||
|
||||
@wraps(func)
|
||||
def proxy_func(*args, **kwargs):
|
||||
# input {
|
||||
self._locals = kwargs
|
||||
# }
|
||||
|
||||
self.process(self.root)
|
||||
return "\n".join(self._lines)
|
||||
|
||||
# output {
|
||||
code = "\n".join(self._lines)
|
||||
if self.file == "<return>":
|
||||
return code
|
||||
elif self.file == "<stderr>":
|
||||
rich.print(code, file=sys.stderr)
|
||||
elif self.file == "<stdout>":
|
||||
rich.print(code, file=sys.stdout)
|
||||
else:
|
||||
with open(self.file, "wt") as f:
|
||||
rich.print(code, file=f)
|
||||
return self._return
|
||||
# }
|
||||
|
||||
return proxy_func
|
||||
|
||||
@ -51,6 +71,10 @@ class Commentor(object):
|
||||
obj = eval(src, self._globals, self._locals)
|
||||
fmt = self.get_formatter(obj)
|
||||
return f"{fmt(obj)} : {src}"
|
||||
|
||||
def exec(self, node: ast.stmt):
|
||||
src = to_source(node)
|
||||
exec(src, self._globals, self._locals)
|
||||
|
||||
def get_formatter(self, obj):
|
||||
for typ, fmt in self._formatters:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from .definitions import FunctionDef
|
||||
from .statements import Pass
|
||||
from .expressions import Expr, BinOp
|
||||
from .statements import Pass, Assign
|
||||
from .expressions import Expr, BinOp, Call
|
||||
from .literals import Constant
|
||||
from .variables import Name
|
||||
|
||||
@ -5,3 +5,10 @@ def BinOp(self, cmtor):
|
||||
cmtor.process(self.left)
|
||||
cmtor.process(self.right)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
def Call(self, cmtor):
|
||||
for arg in self.args:
|
||||
cmtor.process(arg)
|
||||
for kwarg in self.keywords:
|
||||
cmtor.process(kwarg)
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
|
||||
@ -1,2 +1,6 @@
|
||||
def Pass(self, cmtor):
|
||||
pass
|
||||
|
||||
def Assign(self, cmtor):
|
||||
cmtor.process(self.value)
|
||||
cmtor.exec(self)
|
||||
|
||||
2
trace_commentor/handlers/variables.py
Normal file
2
trace_commentor/handlers/variables.py
Normal file
@ -0,0 +1,2 @@
|
||||
def Name(self, cmtor):
|
||||
cmtor.append_comment(cmtor.eval(self))
|
||||
Loading…
x
Reference in New Issue
Block a user