feat(trace_commentor): 添加对ast.ExtSlice和nn.Parameter的支持
- 在`formatters/__init__.py`中添加了`Parameter`函数,用于判断对象是否为`Parameter`类型。 - 在`formatters/desc.py`中的`desc`函数里添加了对`Parameter`对象的描述逻辑,包括数据类型和形状的格式化输出。 - 在`handlers/__init__.py`中将`Index`替换为`ExtSlice`,以支持扩展切片的处理。 - 在`handlers/expressions.py`中添加了`ExtSlice`函数,用于处理扩展切片表达式。
This commit is contained in:
parent
963a65fa71
commit
e0e0216312
36
pyproject.toml
Normal file
36
pyproject.toml
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=65.5.0", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "trace_commentor" # pip install
|
||||||
|
version = "2025.7.1"
|
||||||
|
description = "..."
|
||||||
|
readme = "README.md" # 可选, 如果你有 README.md 文件
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
license = {text = "Apache-2.0"} # 或者你的许可证
|
||||||
|
authors = [
|
||||||
|
{name = "Yuyao Huang (Sam)", email = "huangyuyao@outlook.com"},
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"astor",
|
||||||
|
"rich",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
py-modules = ["trace_commentor"]
|
||||||
|
|
||||||
|
# 包含非 Python 文件
|
||||||
|
# [tool.setuptools.package-data]
|
||||||
|
# "package.path" = ["*.json"]
|
||||||
|
|
||||||
|
# 命令行程序入口
|
||||||
|
# [project.scripts]
|
||||||
|
# script-name = "module.path:function"
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest",
|
||||||
|
"pytest-cov",
|
||||||
|
"pytest-mock",
|
||||||
|
]
|
||||||
@ -10,9 +10,14 @@ def Tensor(tensor):
|
|||||||
return "Tensor" in tensor.__class__.__name__
|
return "Tensor" in tensor.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
def Parameter(parameter):
|
||||||
|
return "Parameter" in parameter.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
LIST = [
|
LIST = [
|
||||||
(callable, silent),
|
(callable, silent),
|
||||||
(Tensor, desc),
|
(Tensor, desc),
|
||||||
|
(Parameter, desc),
|
||||||
(types.ModuleType, silent),
|
(types.ModuleType, silent),
|
||||||
((list, dict), desc),
|
((list, dict), desc),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -24,6 +24,9 @@ def _apply(x, f):
|
|||||||
|
|
||||||
def desc(x):
|
def desc(x):
|
||||||
def _desc(_x):
|
def _desc(_x):
|
||||||
|
if 'Parameter' in _x.__class__.__name__:
|
||||||
|
dtype = str(_x.data.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
|
||||||
|
return _print(f"Parameter({tuple(_x.data.shape)}, {dtype})")
|
||||||
if "Tensor" in _x.__class__.__name__:
|
if "Tensor" in _x.__class__.__name__:
|
||||||
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
|
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
|
||||||
return _print(f"Tensor({tuple(_x.shape)}, {dtype})")
|
return _print(f"Tensor({tuple(_x.shape)}, {dtype})")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from .definitions import FunctionDef, Return, Lambda
|
from .definitions import FunctionDef, Return, Lambda
|
||||||
from .statements import Pass, Assign, AnnAssign, AugAssign
|
from .statements import Pass, Assign, AnnAssign, AugAssign
|
||||||
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp, Index
|
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, ExtSlice, keyword, IfExp, Index
|
||||||
from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
|
from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
|
||||||
from .variables import Name
|
from .variables import Name
|
||||||
from .control_flow import If, For, Continue, Break, With, Raise
|
from .control_flow import If, For, Continue, Break, With, Raise
|
||||||
|
|||||||
@ -55,6 +55,11 @@ def Slice(self, cmtor):
|
|||||||
cmtor.process(self.step)
|
cmtor.process(self.step)
|
||||||
|
|
||||||
|
|
||||||
|
def ExtSlice(self, cmtor):
|
||||||
|
for dim in self.dims:
|
||||||
|
cmtor.process(dim)
|
||||||
|
|
||||||
|
|
||||||
def keyword(self, cmtor):
|
def keyword(self, cmtor):
|
||||||
cmtor.process(self.value)
|
cmtor.process(self.value)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user