This commit is contained in:
Yuyao Huang 2024-04-19 18:51:21 +08:00
parent b86e2e5613
commit 6e8c8e1998
6 changed files with 62 additions and 5 deletions

15
experiment.py Normal file
View File

@ -0,0 +1,15 @@
from tests.test_utils import *
@Commentor(fmt=[
(type(None), lambda o: None),
])
def function():
mystring = 'hello'
print(mystring)
mystring = "hello"
print(mystring)
return 1
print(function())

View File

@ -11,3 +11,26 @@ def test_function_def():
def target(): def target():
pass pass
''') ''')
def test_return():
with closing(StringIO()) as f:
@Commentor(f)
def target():
a = 1
return a + 1
assert target() == 2
asserteq_or_print(
f.getvalue(), '''
def target():
a = 1
return a + 1
"""
1 : a
2 : a + 1
"""
''')

View File

@ -1,7 +1,11 @@
from io import StringIO
from contextlib import closing
from trace_commentor import flags, Commentor from trace_commentor import flags, Commentor
def asserteq_or_print(value, ground_truth): def asserteq_or_print(value, ground_truth):
if flags.DEBUG or flags.PRINT: if flags.DEBUG or flags.PRINT:
print(value) print(value)
else: else:
assert value == ground_truth.strip("\n"), "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth.strip("\n"), ">>>>>>>> GROUND\n"]) value = value.strip("\n").rstrip(" ")
ground_truth = ground_truth.strip("\n").rstrip(" ")
assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"])

View File

@ -5,6 +5,7 @@ import rich
from inspect import getfullargspec from inspect import getfullargspec
from functools import wraps from functools import wraps
from io import IOBase
from . import handlers from . import handlers
from . import formatters from . import formatters
@ -51,9 +52,13 @@ class Commentor(object):
rich.print(code, file=sys.stderr) rich.print(code, file=sys.stderr)
elif self.file == "<stdout>": elif self.file == "<stdout>":
rich.print(code, file=sys.stdout) rich.print(code, file=sys.stdout)
else: elif isinstance(self.file, IOBase):
rich.print(code, file=self.file)
elif type(self.file) == str:
with open(self.file, "wt") as f: with open(self.file, "wt") as f:
rich.print(code, file=f) rich.print(code, file=f)
else:
raise NotImplementedError(f"Unknown file protocal {self.file}")
return self._return return self._return
# } # }
@ -66,11 +71,16 @@ class Commentor(object):
raise NotImplementedError(f"Unknown how to handle {node_type} node.") raise NotImplementedError(f"Unknown how to handle {node_type} node.")
return handler(node, self) return handler(node, self)
def eval(self, node: ast.Expr): def eval(self, node: ast.Expr, format=True):
src = to_source(node) src = to_source(node)
obj = eval(src, self._globals, self._locals) obj = eval(src, self._globals, self._locals)
if not format:
return obj
fmt = self.get_formatter(obj) fmt = self.get_formatter(obj)
return f"{fmt(obj)} : {src}" fmt_obj = fmt(obj)
if fmt_obj is not None:
return f"{fmt(obj)} : {src}"
def exec(self, node: ast.stmt): def exec(self, node: ast.stmt):
src = to_source(node) src = to_source(node)

View File

@ -1,4 +1,4 @@
from .definitions import FunctionDef from .definitions import FunctionDef, Return
from .statements import Pass, Assign from .statements import Pass, Assign
from .expressions import Expr, BinOp, Call from .expressions import Expr, BinOp, Call
from .literals import Constant from .literals import Constant

View File

@ -15,3 +15,8 @@ def FunctionDef(self, cmtor):
cmtor.append_source() cmtor.append_source()
cmtor.indent -= flags.INDENT cmtor.indent -= flags.INDENT
def Return(self, cmtor):
cmtor.process(self.value)
cmtor._return = cmtor.eval(self.value, format=False)