forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_sources.py
105 lines (86 loc) · 3.85 KB
/
_sources.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, Optional, Tuple, List, NamedTuple
from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Look up the source of ``obj`` through the ``inspect`` module.

    Args:
        obj: object whose source should be retrieved
        error_msg: optional extra text, appended on its own line to the
            message of the ``OSError`` raised when the source is unavailable

    Returns:
        A ``(sourcelines, file_lineno, filename)`` tuple. The first two
        entries come from ``inspect.getsourcelines`` and the last one from
        ``inspect.getsourcefile``.

    Raises:
        OSError: if ``inspect`` raises ``OSError`` while locating the source
    """
    # Bound up front so the name exists even if getsourcefile raises.
    filename = None
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        details = [
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        ]
        if error_msg:
            details.append(error_msg)
        # Chain the original exception so the underlying cause stays visible.
        raise OSError('\n'.join(details)) from e

    return sourcelines, file_lineno, filename
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    Align every source line with the indentation of the function definition.

    The leading whitespace of the first ``def`` line is determined, and every
    other line that does not already start with that whitespace gets it
    prepended. Lines that already start with it are left untouched. This
    allows for comments and continued string literals that are at a lower
    indentation than the rest of the code.

    Args:
        sourcelines: function source code, separated into lines by
            the newline character

    Returns:
        A new list of source lines that have been correctly aligned. If no
        line holds a ``def`` keyword (e.g. the source of a lambda), the lines
        are returned unchanged.
    """
    def is_def_line(line):
        # `def` has to be the whole keyword, not the start of a longer
        # identifier such as `default=...` inside a multi-line decorator.
        stripped = line.lstrip()
        following = stripped[3:4]
        return stripped.startswith("def") and not (following.isalnum() or following == "_")

    # Index of the line containing the function definition, if any
    def_idx = next((i for i, line in enumerate(sourcelines) if is_def_line(line)), None)
    if def_idx is None:
        # Nothing to align against; previously this path hit an unbound local.
        return list(sourcelines)

    fn_def = sourcelines[def_idx]
    # The leading whitespace of the `def` line
    whitespace = fn_def[:len(fn_def) - len(fn_def.lstrip())]

    def realign(line):
        # Only lines lacking the `def` indentation need it prepended.
        return line if line.startswith(whitespace) else whitespace + line

    # Re-align everything before and after the `def`, keeping the `def` itself
    aligned_prefix = [realign(s) for s in sourcelines[:def_idx]]
    aligned_suffix = [realign(s) for s in sourcelines[def_idx + 1:]]
    return aligned_prefix + [fn_def] + aligned_suffix
# SourceRangeFactory subclass that additionally keeps metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
    """
    ``SourceRangeFactory`` that also records ``uses_true_division``,
    ``filename`` and ``funcname`` as instance attributes.
    """

    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None):
        """
        Args:
            source: forwarded to ``SourceRangeFactory``
            filename: forwarded to ``SourceRangeFactory`` and stored on the instance
            file_lineno: forwarded to ``SourceRangeFactory``
            leading_whitespace_len: forwarded to ``SourceRangeFactory``
            uses_true_division: stored on the instance (defaults to True)
            funcname: stored on the instance (defaults to None)
        """
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        # Extra metadata that the base factory does not keep for us
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
    """
    Build a ``SourceContext`` from the given positional arguments.

    Results are memoized in an unbounded ``lru_cache``, so calls with equal
    arguments return the same instance; all arguments must be hashable.
    """
    return SourceContext(*args)
def fake_range():
    """
    Return ``make_raw_range(0, 1)`` of a ``SourceContext`` created from an
    empty source string, no filename, and zero line number / whitespace length.
    """
    # NOTE(review): presumably a placeholder range for code without real
    # source; confirm against callers.
    empty_ctx = SourceContext('', None, 0, 0)
    return empty_ctx.make_raw_range(0, 1)
class ParsedDef(NamedTuple):
    """
    Bundle returned by ``parse_def``: the parsed AST of a function together
    with the source information it was produced from.
    """

    # Module node obtained from ast.parse on the dedented source
    ast: ast.Module
    # Source context built for the function
    ctx: SourceContext
    # Normalized (not dedented) source text of the function
    source: str
    # File the source came from, if one could be determined
    filename: Optional[str]
    # Line number reported by inspect for the start of the source
    file_lineno: int
def parse_def(fn):
    """
    Fetch the source of ``fn``, parse it, and package the result.

    Args:
        fn: function whose source should be parsed; its ``__name__`` is
            recorded in the source context

    Returns:
        A ``ParsedDef`` holding the AST, the source context, the normalized
        source text, the filename and the starting line number.

    Raises:
        OSError: propagated from ``get_source_lines_and_file`` when the
            source is unavailable
        RuntimeError: if the parsed source is not exactly one top-level
            function definition
    """
    call_stack = ErrorReport.call_stack()
    raw_lines, file_lineno, filename = get_source_lines_and_file(fn, call_stack)
    source = ''.join(normalize_source_lines(raw_lines))

    # ast.parse needs the definition to start at column 0
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    body = py_ast.body
    is_single_function = len(body) == 1 and isinstance(body[0], ast.FunctionDef)
    if not is_single_function:
        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")

    # Number of characters dedent() stripped from the first line
    first_line = source.partition('\n')[0]
    first_dedented_line = dedent_src.partition('\n')[0]
    leading_whitespace_len = len(first_line) - len(first_dedented_line)

    # The positional True fills SourceContext's uses_true_division parameter
    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True, fn.__name__)
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)