not successfully configured for windows
This commit is contained in:
parent
4c801ddcdf
commit
11a5cc3719
BIN
conda/env.yml
Normal file
BIN
conda/env.yml
Normal file
Binary file not shown.
1
dev_env.sh
Normal file
1
dev_env.sh
Normal file
@ -0,0 +1 @@
|
||||
export PYTHONPATH=.
|
||||
37
tests/test_torch.py
Normal file
37
tests/test_torch.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from trace_commentor.parser import *
|
||||
|
||||
|
||||
def test_func_no_arg():
|
||||
|
||||
@analyse
|
||||
def target():
|
||||
|
||||
x = torch.ones(4, 5)
|
||||
for i in range(3):
|
||||
x = x[..., None, :]
|
||||
|
||||
a = torch.randn(309, 110, 3)[:100]
|
||||
f = nn.Linear(3, 128)
|
||||
b = f(a.reshape(-1, 3)).reshape(309, 110, 128)
|
||||
c = torch.concat((a, b), dim=-1)
|
||||
|
||||
return c.flatten()
|
||||
|
||||
print()
|
||||
target()
|
||||
|
||||
|
||||
def test_for_loop():
|
||||
|
||||
@analyse
|
||||
def target():
|
||||
a = 1
|
||||
for i in range(3):
|
||||
a += 1
|
||||
print(a)
|
||||
|
||||
print()
|
||||
target()
|
||||
0
trace_commentor/__init__.py
Normal file
0
trace_commentor/__init__.py
Normal file
15
trace_commentor/interpretor/README.md
Normal file
15
trace_commentor/interpretor/README.md
Normal file
@ -0,0 +1,15 @@
|
||||
https://docs.python.org/3/library/ast.html
|
||||
|
||||
- [ ] Literals
|
||||
- [ ] Constant
|
||||
- [ ] Variables
|
||||
- [ ] Expressions
|
||||
- [ ] Subscripting
|
||||
- [ ] Comprehensions
|
||||
- [ ] Statements
|
||||
- [ ] Imports
|
||||
- [ ] Control flow
|
||||
- [ ] Pattern matching
|
||||
- [ ] Type parameters
|
||||
- [ ] Function and class definitions
|
||||
- [ ] Async and await
|
||||
32
trace_commentor/parser.py
Normal file
32
trace_commentor/parser.py
Normal file
@ -0,0 +1,32 @@
|
||||
import inspect
|
||||
import ast
|
||||
import astor
|
||||
|
||||
# from .interpretor import exec_ast
|
||||
|
||||
|
||||
class Parser(object):
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, func, **_globals):
|
||||
|
||||
raw_lines, start_lineno = inspect.getsourcelines(func)
|
||||
indent_size = len(raw_lines[0]) - len(raw_lines[0].lstrip())
|
||||
unindented_source = ''.join([l[indent_size:] for l in raw_lines])
|
||||
root = ast.parse(unindented_source).body[0]
|
||||
|
||||
def proxy_func(*args, **kwargs):
|
||||
_locals = kwargs
|
||||
# exec_ast(root, _locals, _globals)
|
||||
import ipdb; ipdb.set_trace()
|
||||
root = ast.parse(unindented_source).body[0]
|
||||
ast.dump(root)
|
||||
...
|
||||
|
||||
return proxy_func
|
||||
|
||||
|
||||
analyse = Parser()
|
||||
Loading…
x
Reference in New Issue
Block a user