From 0281d68cb4adf89254bb260f7a20de54d93d3163 Mon Sep 17 00:00:00 2001 From: Yuyao Huang Date: Mon, 22 Apr 2024 15:52:10 +0800 Subject: [PATCH] function arguments parse --- DEVNOTE.md | 9 ++++ experiment.py | 11 ++--- tests/test_definitions.py | 26 ++++++++++ trace_commentor/commentor.py | 65 ++++++++++++------------- trace_commentor/handlers/definitions.py | 7 ++- trace_commentor/utils.py | 29 +++++++++++ 6 files changed, 105 insertions(+), 42 deletions(-) create mode 100644 DEVNOTE.md diff --git a/DEVNOTE.md b/DEVNOTE.md new file mode 100644 index 0000000..1002b41 --- /dev/null +++ b/DEVNOTE.md @@ -0,0 +1,9 @@ +## How to add a new node support to TraceCommentor + +- Read about the target node information (https://docs.python.org/3/library/ast.html). +- Add a test case. +- Write the implementation and pass all tests. + - echo the original code. (automatically handled by the parent scope node if not in `flags.APPEND_SOURCE_BY_THEMSELVES`) + - expression should be evaluated and add comment to each recursively by correct order. (use `cmtor.eval(...)`) + - statements should be executed. (use `cmtor.exec(...)`) + - use `pytest -s -q -k` to test new test case first, use `DEBUG` and `PRINT` environment variable to help debugging. diff --git a/experiment.py b/experiment.py index d5eb021..25660d0 100644 --- a/experiment.py +++ b/experiment.py @@ -4,15 +4,10 @@ from tests.test_utils import * def test(): @Commentor() - def target(): - # return only odd numbers - 1,3,5,7,9 - for x in range(10): - # Check if x is even - if x % 2 == 0: - continue - print(x) + def target(a, d=1, *b, c, k=1): + return a + k - print(target()) + print(target(1, 2, 3, 4, c=5, k=2)) test() diff --git a/tests/test_definitions.py b/tests/test_definitions.py index 7d11a7d..a78551d 100644 --- a/tests/test_definitions.py +++ b/tests/test_definitions.py @@ -34,3 +34,29 @@ def test_return(): 2 : a + 1 """ ''') + + + + +def test_args(): + + @Commentor("") + def target(a, d=1, *b, c, k=1): + return a + k + + asserteq_or_print(target(1, 2, 3, 4, c=5, k=2), ''' + def target(a, d=1, *b, c, k=1): + """ + 1 : a + 2 : d + (3, 4) : b + 5 : c + 2 : k + """ + return a + k + """ + 1 : a + 2 : k + 3 : a + k + """ +''') diff --git a/trace_commentor/commentor.py b/trace_commentor/commentor.py index 32efac0..9dce81c 100644 --- a/trace_commentor/commentor.py +++ b/trace_commentor/commentor.py @@ -5,20 +5,18 @@ import rich from inspect import getfullargspec from functools import wraps -from io import IOBase from . import handlers from . import formatters from . import flags -from .utils import sign, to_source -from rich.syntax import Syntax +from .utils import sign, to_source, comment_to_file class Commentor(object): - def __init__(self, output="", fmt=[]) -> None: + def __init__(self, output="", _globals=dict(), fmt=[]) -> None: self._locals = dict() - self._globals = dict() + self._globals = dict().update(_globals) self._return = None self._formatters = fmt + formatters.LIST self._lines = [] @@ -34,6 +32,7 @@ class Commentor(object): self.indent = len(raw_lines[0]) - len(raw_lines[0].lstrip()) unindented_source = ''.join([l[self.indent:] for l in raw_lines]) self.root = ast.parse(unindented_source).body[0] + pt = getfullargspec(func) if flags.DEBUG: with open("debug.log", "wt") as f: @@ -42,35 +41,39 @@ class Commentor(object): @wraps(func) def proxy_func(*args, **kwargs): # input { - self._locals = kwargs + self._locals = dict() + + # args specified + for target, value in zip(pt.args, args): + self._locals[target] = value + + # args defaults + if pt.defaults is not None: + for target, value in zip(pt.args[-len(pt.defaults):], pt.defaults): + self._locals.setdefault(target, value) + + # varargs + if pt.varargs is not None: + self._locals[pt.varargs] = args[len(pt.args):] + + # kwargs specified + self._locals.update(kwargs) + + # kwargs default + if pt.kwonlydefaults is not None: + for target, value in pt.kwonlydefaults.items(): + self._locals.setdefault(target, value) + # } self.process(self.root) # output { - code = "\n".join(self._lines) - if self.file == "": - return code - elif self.file == "": - if sys.stderr.isatty(): - syntax = Syntax(code, "python") - rich.print(syntax, file=sys.stderr) - else: - rich.print(code, file=sys.stderr) - elif self.file == "": - if sys.stdout.isatty(): - syntax = Syntax(code, "python") - rich.print(syntax, file=sys.stdout) - else: - rich.print(code, file=sys.stdout) - elif isinstance(self.file, IOBase): - rich.print(code, file=self.file) - elif type(self.file) == str: - with open(self.file, "wt") as f: - rich.print(code, file=f) + comments = "\n".join(self._lines) + if comment_to_file(comments, file=self.file): + return self._return else: - raise NotImplementedError(f"Unknown file protocal {self.file}") - return self._return + return comments # } return proxy_func @@ -83,7 +86,7 @@ class Commentor(object): return handler(node, self, *args, **kwargs) def eval(self, node: ast.Expr, format=True): - src = to_source(node) + src = node if type(node) == str else to_source(node) obj = eval(src, self._globals, self._locals) if not format: return obj @@ -125,7 +128,3 @@ class Commentor(object): self.__append('"""') if line is not None: return self.__append(sign(line, 2)) - - def typeset(self): - return "\n".join( - (c[1] * " " + l for l, c in zip(self._lines, self._lines_category))) diff --git a/trace_commentor/handlers/definitions.py b/trace_commentor/handlers/definitions.py index bcbd0aa..88034db 100644 --- a/trace_commentor/handlers/definitions.py +++ b/trace_commentor/handlers/definitions.py @@ -2,9 +2,14 @@ from .. import flags from ..utils import to_source def FunctionDef(self, cmtor): - cmtor.append_source(f"def {self.name}():") + cmtor.append_source(f"def {self.name}({to_source(self.args)}):") cmtor.indent += flags.INDENT + if self is cmtor.root: + for arg in cmtor._locals: + cmtor.append_comment(cmtor.eval(arg)) + cmtor.append_source() + for stmt in self.body: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: diff --git a/trace_commentor/utils.py b/trace_commentor/utils.py index 6d31aa5..dd42b9b 100644 --- a/trace_commentor/utils.py +++ b/trace_commentor/utils.py @@ -2,6 +2,10 @@ import ast import astor import inspect import os +import sys +import rich +from rich.syntax import Syntax +from io import IOBase from . import flags @@ -20,3 +24,28 @@ def to_source(node): if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")": src = src[1:-1] return src + + +def comment_to_file(code, file: str) -> bool: + if file == "": + return False + elif file == "": + if sys.stderr.isatty(): + syntax = Syntax(code, "python") + rich.print(syntax, file=sys.stderr) + else: + rich.print(code, file=sys.stderr) + elif file == "": + if sys.stdout.isatty(): + syntax = Syntax(code, "python") + rich.print(syntax, file=sys.stdout) + else: + rich.print(code, file=sys.stdout) + elif isinstance(file, IOBase): + rich.print(code, file=file) + elif type(file) == str: + with open(file, "wt") as f: + rich.print(code, file=f) + else: + raise NotImplementedError(f"Unknown file protocal {file}") + return True