diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d235d3f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=65.5.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "trace_commentor" # pip install +version = "2025.7.1" +description = "..." +readme = "README.md" # 可选, 如果你有 README.md 文件 +requires-python = ">=3.8" +license = {text = "Apache-2.0"} # 或者你的许可证 +authors = [ + {name = "Yuyao Huang (Sam)", email = "huangyuyao@outlook.com"}, +] +dependencies = [ + "astor", + "rich", +] + +[tool.setuptools] +py-modules = ["trace_commentor"] + +# 包含非 Python 文件 +# [tool.setuptools.package-data] +# "package.path" = ["*.json"] + +# 命令行程序入口 +# [project.scripts] +# script-name = "module.path:function" + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "pytest-mock", +] diff --git a/trace_commentor/formatters/__init__.py b/trace_commentor/formatters/__init__.py index 2fe6c06..40b9fb6 100644 --- a/trace_commentor/formatters/__init__.py +++ b/trace_commentor/formatters/__init__.py @@ -10,9 +10,14 @@ def Tensor(tensor): return "Tensor" in tensor.__class__.__name__ +def Parameter(parameter): + return "Parameter" in parameter.__class__.__name__ + + LIST = [ (callable, silent), (Tensor, desc), + (Parameter, desc), (types.ModuleType, silent), ((list, dict), desc), ] diff --git a/trace_commentor/formatters/desc.py b/trace_commentor/formatters/desc.py index 374b95c..8ebde39 100644 --- a/trace_commentor/formatters/desc.py +++ b/trace_commentor/formatters/desc.py @@ -24,6 +24,9 @@ def _apply(x, f): def desc(x): def _desc(_x): + if 'Parameter' in _x.__class__.__name__: + dtype = str(_x.data.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i") + return _print(f"Parameter({tuple(_x.data.shape)}, {dtype})") if "Tensor" in _x.__class__.__name__: dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i") return _print(f"Tensor({tuple(_x.shape)}, {dtype})") diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index ebe60e4..da9a0f7 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -1,6 +1,6 @@ from .definitions import FunctionDef, Return, Lambda from .statements import Pass, Assign, AnnAssign, AugAssign -from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp, Index +from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, ExtSlice, keyword, IfExp, Index from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr from .variables import Name from .control_flow import If, For, Continue, Break, With, Raise diff --git a/trace_commentor/handlers/expressions.py b/trace_commentor/handlers/expressions.py index 76f2fb2..f2e5927 100644 --- a/trace_commentor/handlers/expressions.py +++ b/trace_commentor/handlers/expressions.py @@ -55,6 +55,11 @@ def Slice(self, cmtor): cmtor.process(self.step) +def ExtSlice(self, cmtor): + for dim in self.dims: + cmtor.process(dim) + + def keyword(self, cmtor): cmtor.process(self.value)