Compare commits

...

2 Commits

Author SHA1 Message Date
Yuyao Huang (Sam)
0bb3135cb7 add support to "Starred" expressions 2025-07-16 21:31:32 +08:00
Yuyao Huang (Sam)
e0e0216312 feat(trace_commentor): 添加对ast.ExtSlice和nn.Parameter的支持
- 在`formatters/__init__.py`中添加了`Parameter`函数,用于判断对象是否为`Parameter`类型。
- 在`formatters/desc.py`中的`desc`函数里添加了对`Parameter`对象的描述逻辑,包括数据类型和形状的格式化输出。
- 在`handlers/__init__.py`中将`Index`替换为`ExtSlice`,以支持扩展切片的处理。
- 在`handlers/expressions.py`中添加了`ExtSlice`函数,用于处理扩展切片表达式。
2025-07-04 12:57:12 +08:00
6 changed files with 56 additions and 1 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@ __pycache__
*.log *.log
*.log.py *.log.py
.vscode-upload.json .vscode-upload.json
dist
*.egg-info

36
pyproject.toml Normal file
View File

@ -0,0 +1,36 @@
[build-system]
requires = ["setuptools>=65.5.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "trace_commentor" # pip install
version = "2025.7.1"
description = "..."
readme = "README.md" # 可选, 如果你有 README.md 文件
requires-python = ">=3.8"
license = {text = "Apache-2.0"} # 或者你的许可证
authors = [
{name = "Yuyao Huang (Sam)", email = "huangyuyao@outlook.com"},
]
dependencies = [
"astor",
"rich",
]
[tool.setuptools]
py-modules = ["trace_commentor"]
# 包含非 Python 文件
# [tool.setuptools.package-data]
# "package.path" = ["*.json"]
# 命令行程序入口
# [project.scripts]
# script-name = "module.path:function"
[project.optional-dependencies]
dev = [
"pytest",
"pytest-cov",
"pytest-mock",
]

View File

@ -10,9 +10,14 @@ def Tensor(tensor):
return "Tensor" in tensor.__class__.__name__ return "Tensor" in tensor.__class__.__name__
def Parameter(parameter):
return "Parameter" in parameter.__class__.__name__
LIST = [ LIST = [
(callable, silent), (callable, silent),
(Tensor, desc), (Tensor, desc),
(Parameter, desc),
(types.ModuleType, silent), (types.ModuleType, silent),
((list, dict), desc), ((list, dict), desc),
] ]

View File

@ -24,6 +24,9 @@ def _apply(x, f):
def desc(x): def desc(x):
def _desc(_x): def _desc(_x):
if 'Parameter' in _x.__class__.__name__:
dtype = str(_x.data.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
return _print(f"Parameter({tuple(_x.data.shape)}, {dtype})")
if "Tensor" in _x.__class__.__name__: if "Tensor" in _x.__class__.__name__:
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i") dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
return _print(f"Tensor({tuple(_x.shape)}, {dtype})") return _print(f"Tensor({tuple(_x.shape)}, {dtype})")

View File

@ -1,6 +1,6 @@
from .definitions import FunctionDef, Return, Lambda from .definitions import FunctionDef, Return, Lambda
from .statements import Pass, Assign, AnnAssign, AugAssign from .statements import Pass, Assign, AnnAssign, AugAssign
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp, Index from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, ExtSlice, Starred, keyword, IfExp, Index
from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
from .variables import Name from .variables import Name
from .control_flow import If, For, Continue, Break, With, Raise from .control_flow import If, For, Continue, Break, With, Raise

View File

@ -55,6 +55,15 @@ def Slice(self, cmtor):
cmtor.process(self.step) cmtor.process(self.step)
def ExtSlice(self, cmtor):
for dim in self.dims:
cmtor.process(dim)
def Starred(self, cmtor):
cmtor.process(self.value)
def keyword(self, cmtor): def keyword(self, cmtor):
cmtor.process(self.value) cmtor.process(self.value)