This commit is contained in:
Waylon Walker 2022-03-31 20:20:07 -05:00
commit 38355d2442
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9083 changed files with 1225834 additions and 0 deletions

View file

@ -0,0 +1,54 @@
"""Encode valid C string literals from Python strings.
If a character is not allowed in C string literals, it is either emitted
as a simple escape sequence (e.g. '\\n'), or an octal escape sequence
with exactly three digits ('\\oXXX'). Question marks are escaped to
prevent trigraphs in the string literal from being interpreted. Note
that '\\?' is an invalid escape sequence in Python.
Consider the string literal "AB\\xCDEF". As one would expect, Python
parses it as ['A', 'B', 0xCD, 'E', 'F']. However, the C standard
specifies that all hexadecimal digits immediately following '\\x' will
be interpreted as part of the escape sequence. Therefore, it is
unexpectedly parsed as ['A', 'B', 0xCDEF].
Emitting ("AB\\xCD" "EF") would avoid this behaviour. However, we opt
for simplicity and use octal escape sequences instead. They do not
suffer from the same issue as they are defined to parse at most three
octal digits.
"""
import string
from typing_extensions import Final
# Fallback representation: every byte value as a three-digit octal escape.
CHAR_MAP: Final = list(map("\\{:03o}".format, range(256)))

# Printable characters stand for themselves.
# It is safe to use string.printable as it always uses the C locale.
for c in string.printable:
    CHAR_MAP[ord(c)] = c

# Simple escape sequences are applied after the printable pass so that they
# take priority over both the octal and the literal representation.
for c in ('\'', '"', '\\', 'a', 'b', 'f', 'n', 'r', 't', 'v'):
    escaped = '\\{}'.format(c)
    # Let Python decode the escape to learn which character it denotes.
    decoded = escaped.encode('ascii').decode('unicode_escape')
    CHAR_MAP[ord(decoded)] = escaped

# '\?' cannot go through the loop above: it is not a valid Python escape.
CHAR_MAP[ord('?')] = r'\?'
def encode_bytes_as_c_string(b: bytes) -> str:
    """Return the contents of a C string literal for `b`, without quotes.

    Every byte of `b` is replaced by its CHAR_MAP entry and the pieces are
    concatenated in order.
    """
    pieces = [CHAR_MAP[byte] for byte in b]
    return ''.join(pieces)
def c_string_initializer(value: bytes) -> str:
    """Create an initializer for a C char[] / char * variable from bytes.

    The encoded contents are wrapped in double quotes. For example, if
    value is b'foo', the result is '"foo"'.
    """
    body = encode_bytes_as_c_string(value)
    return '"' + body + '"'

View file

@ -0,0 +1,888 @@
"""Utilities for emitting C code."""
from mypy.backports import OrderedDict
from typing import List, Set, Dict, Optional, Callable, Union, Tuple
import sys
from mypyc.common import (
REG_PREFIX, ATTR_PREFIX, STATIC_PREFIX, TYPE_PREFIX, NATIVE_PREFIX,
FAST_ISINSTANCE_MAX_SUBCLASSES, use_vectorcall
)
from mypyc.ir.ops import BasicBlock, Value
from mypyc.ir.rtypes import (
RType, RTuple, RInstance, RUnion, RPrimitive,
is_float_rprimitive, is_bool_rprimitive, is_int_rprimitive, is_short_int_rprimitive,
is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive,
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive,
int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive,
is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive, is_bytes_rprimitive
)
from mypyc.ir.func_ir import FuncDecl
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
from mypyc.namegen import NameGenerator, exported_name
from mypyc.sametype import is_same_type
from mypyc.codegen.literals import Literals
class HeaderDeclaration:
    """A C declaration, with an optional definition.

    Used to generate declarations in header files and (optionally)
    definitions in source files.

    Attributes:
      decl: C source lines of the declaration (always stored as a list).
      defn: Optionally, C source lines of a definition.
      dependencies: Names of any objects that must be declared prior.
      is_type: Whether the declaration is of a C type. (C types will be
          declared in external header files and not marked 'extern'.)
      needs_export: Whether the declared object needs to be exported to
          other modules in the linking table.
    """

    def __init__(self,
                 decl: Union[str, List[str]],
                 defn: Optional[List[str]] = None,
                 *,
                 dependencies: Optional[Set[str]] = None,
                 is_type: bool = False,
                 needs_export: bool = False
                 ) -> None:
        # A single string is wrapped so that decl is always a list of lines.
        if isinstance(decl, str):
            decl = [decl]
        self.decl = decl
        self.defn = defn
        # An empty or missing set is replaced by a fresh empty set.
        self.dependencies = dependencies if dependencies else set()
        self.is_type = is_type
        self.needs_export = needs_export
class EmitterContext:
    """Shared emitter state for a compilation group."""

    def __init__(self,
                 names: NameGenerator,
                 group_name: Optional[str] = None,
                 group_map: Optional[Dict[str, Optional[str]]] = None,
                 ) -> None:
        """Set up the shared emitter state.

        Args:
            names: The name generator to use
            group_name: Current group name
            group_map: Map from module names to group name
        """
        self.names = names
        self.group_name = group_name
        self.group_map = group_map if group_map else {}

        # Counter shared by the generated temporary and label names.
        self.temp_counter = 0

        # Groups that this group depends on
        self.group_deps: Set[str] = set()

        # Declarations and definitions emitted at the top of the C file.
        # Entries can be added at any time during the emit phase. The map
        # goes from a C identifier to whatever that identifier declares;
        # currently it is used for structs, keyed by struct name, with the
        # declaration containing the struct body.
        self.declarations: Dict[str, HeaderDeclaration] = OrderedDict()

        self.literals = Literals()
class ErrorHandler:
    """Base class describing how unbox/cast operations handle errors."""


class AssignHandler(ErrorHandler):
    """On error, assign an error value."""


class GotoHandler(ErrorHandler):
    """On error, goto the given label."""

    def __init__(self, label: str) -> None:
        # Name of the C label to jump to.
        self.label = label


class ReturnHandler(ErrorHandler):
    """On error, return a constant value."""

    def __init__(self, value: str) -> None:
        # C expression to return.
        self.value = value
class Emitter:
"""Helper for C code generation."""
def __init__(self,
context: EmitterContext,
value_names: Optional[Dict[Value, str]] = None,
capi_version: Optional[Tuple[int, int]] = None,
) -> None:
self.context = context
self.capi_version = capi_version or sys.version_info[:2]
self.names = context.names
self.value_names = value_names or {}
self.fragments: List[str] = []
self._indent = 0
# Low-level operations
def indent(self) -> None:
    """Increase the current indentation by four spaces."""
    self._indent = self._indent + 4
def dedent(self) -> None:
    """Decrease the current indentation by four spaces.

    Asserts that the indentation does not become negative.
    """
    self._indent = self._indent - 4
    assert self._indent >= 0
def label(self, label: BasicBlock) -> str:
    """Return the C label name for a basic block."""
    block_id = label.label
    return 'CPyL%s' % block_id
def reg(self, reg: Value) -> str:
    """Return REG_PREFIX followed by the name recorded for `reg` in value_names."""
    name = self.value_names[reg]
    return REG_PREFIX + name
def attr(self, name: str) -> str:
    """Return ATTR_PREFIX followed by the attribute name."""
    prefixed = ATTR_PREFIX + name
    return prefixed
def emit_line(self, line: str = '') -> None:
    """Append one line, indented to the current level and newline-terminated.

    A line starting with '}' is dedented before it is written; a line
    ending with '{' increases the indentation for the lines that follow.
    """
    closes_block = line.startswith('}')
    if closes_block:
        self.dedent()
    padding = self._indent * ' '
    self.fragments.append(padding + line + '\n')
    opens_block = line.endswith('{')
    if opens_block:
        self.indent()
def emit_lines(self, *lines: str) -> None:
    """Emit each of the given lines, in order, through emit_line."""
    for text in lines:
        self.emit_line(text)
def emit_label(self, label: Union[BasicBlock, str]) -> None:
    """Append a C label, written without indentation.

    A string is used as the label text directly; anything else is
    converted with self.label().
    """
    text = label if isinstance(label, str) else self.label(label)
    # Extra semicolon prevents an error when the next line declares a tempvar
    self.fragments.append('{}: ;\n'.format(text))
def emit_from_emitter(self, emitter: 'Emitter') -> None:
    """Append all fragments of another emitter to this emitter's output."""
    other_fragments = emitter.fragments
    self.fragments.extend(other_fragments)
def emit_printf(self, fmt: str, *args: str) -> None:
    """Emit a printf() call followed by a flush of stdout.

    Newlines in `fmt` are written as the two-character sequence backslash-n
    so that the format fits in a C string literal. `args` are C expressions
    passed as the remaining printf arguments.
    """
    escaped = fmt.replace('\n', '\\n')
    call_args = ['"%s"' % escaped]
    call_args.extend(args)
    self.emit_line('printf(%s);' % ', '.join(call_args))
    self.emit_line('fflush(stdout);')
def temp_name(self) -> str:
    """Return a fresh temporary name, advancing the context's shared counter."""
    ctx = self.context
    ctx.temp_counter += 1
    return '__tmp%d' % ctx.temp_counter
def new_label(self) -> str:
    """Return a fresh label name, advancing the context's shared counter."""
    ctx = self.context
    ctx.temp_counter += 1
    return '__LL%d' % ctx.temp_counter
def get_module_group_prefix(self, module_name: str) -> str:
    """Get the group prefix for a module (relative to the current group).

    The prefix should be prepended to the object name whenever
    accessing an object from this module.

    If the module lives in the current compilation group, there is
    no prefix. But if it lives in a different group (and hence a separate
    extension module), objects from it are accessed indirectly via an
    export table.

    For example, for code in group `a` to call a function `bar` in group `b`,
    it would need to do `exports_b.CPyDef_bar(...)`, while code that is
    also in group `b` can simply do `CPyDef_bar(...)`.

    Thus the prefix for a module in group `b` is 'exports_b.' if the current
    group is *not* b and just '' if it is. A module with no group recorded
    in the group map also gets ''. Returning a non-empty prefix records the
    target group as a dependency of the current group.
    """
    ctx = self.context
    target_group_name = ctx.group_map.get(module_name)
    if not target_group_name or target_group_name == ctx.group_name:
        return ''
    ctx.group_deps.add(target_group_name)
    return 'exports_{}.'.format(exported_name(target_group_name))
def get_group_prefix(self, obj: Union[ClassIR, FuncDecl]) -> str:
    """Get the group prefix for an object, based on the module it belongs to."""
    # The prefix rules are those of get_module_group_prefix().
    module = obj.module_name
    return self.get_module_group_prefix(module)
def static_name(self, id: str, module: Optional[str], prefix: str = STATIC_PREFIX) -> str:
    """Create name of a C static variable.

    These are used for literals and imported modules, among other
    things.

    The caller should ensure that the (id, module) pair cannot
    overlap with other calls to this method within a compilation
    group.
    """
    if module:
        lib_prefix = self.get_module_group_prefix(module)
    else:
        lib_prefix = ''
    # A static reached through the export table is a pointer, so it must
    # also be dereferenced.
    if lib_prefix:
        star_maybe = '*'
    else:
        star_maybe = ''
    suffix = self.names.private_name(module or '', id)
    return '{}{}{}{}'.format(star_maybe, lib_prefix, prefix, suffix)
def type_struct_name(self, cl: ClassIR) -> str:
    """Return the static name for a class, built with TYPE_PREFIX."""
    class_name = cl.name
    module = cl.module_name
    return self.static_name(class_name, module, prefix=TYPE_PREFIX)
def ctype(self, rtype: RType) -> str:
    """Return the C type string stored on the given type."""
    c_type = rtype._ctype
    return c_type
def ctype_spaced(self, rtype: RType) -> str:
    """Return the C type, followed by a space unless it ends with '*'."""
    base = self.ctype(rtype)
    ends_with_star = base[-1] == '*'
    return base if ends_with_star else base + ' '
def c_undefined_value(self, rtype: RType) -> str:
    """Return the C expression used as the undefined value of a type.

    Boxed types use NULL, unboxed primitives use their own c_undefined
    value, and tuples use the name of their undefined-value constant.
    """
    if not rtype.is_unboxed:
        return 'NULL'
    if isinstance(rtype, RPrimitive):
        return rtype.c_undefined
    if isinstance(rtype, RTuple):
        return self.tuple_undefined_value(rtype)
    assert False, rtype
def c_error_value(self, rtype: RType) -> str:
    """Return the C error value of a type (same as its undefined value)."""
    error_value = self.c_undefined_value(rtype)
    return error_value
def native_function_name(self, fn: FuncDecl) -> str:
    """Return NATIVE_PREFIX followed by the C name of the function."""
    c_name = fn.cname(self.names)
    return '{}{}'.format(NATIVE_PREFIX, c_name)
def tuple_c_declaration(self, rtuple: RTuple) -> List[str]:
    """Return the C lines declaring the struct for a tuple type.

    The declaration is wrapped in an #ifndef/#define guard keyed on the
    struct name, has one field f0, f1, ... per item, and is followed by a
    static constant holding the tuple's undefined value.
    """
    struct_name = rtuple.struct_name
    lines = [
        '#ifndef MYPYC_DECLARED_{}'.format(struct_name),
        '#define MYPYC_DECLARED_{}'.format(struct_name),
        'typedef struct {} {{'.format(struct_name),
    ]
    if len(rtuple.types) == 0:  # empty tuple
        # Empty tuples contain a flag so that they can still indicate
        # error values.
        lines.append('int empty_struct_error_flag;')
    else:
        lines.extend(
            '{}f{};'.format(self.ctype_spaced(item_type), index)
            for index, item_type in enumerate(rtuple.types)
        )
    lines.append('}} {};'.format(struct_name))

    undefined_parts = self.tuple_undefined_value_helper(rtuple)
    lines.append('static {} {} = {{ {} }};'.format(
        self.ctype(rtuple), self.tuple_undefined_value(rtuple), ''.join(undefined_parts)))
    lines.append('#endif')
    lines.append('')
    return lines
def use_vectorcall(self) -> bool:
    """Return the module-level use_vectorcall() result for this emitter's C API version."""
    version = self.capi_version
    return use_vectorcall(version)
def emit_undefined_attr_check(self, rtype: RType, attr_expr: str,
                              compare: str,
                              unlikely: bool = False) -> None:
    """Emit the opening 'if (...) {' comparing an attribute to its undefined value.

    Args:
        rtype: Type of the attribute
        attr_expr: C expression for the attribute
        compare: C comparison operator placed between the two operands
        unlikely: If True, wrap the condition in unlikely(...)
    """
    if isinstance(rtype, RTuple):
        # Tuples are compared through a single representative field.
        condition = self.tuple_undefined_check_cond(
            rtype, attr_expr, self.c_undefined_value, compare)
        check = '({})'.format(condition)
    else:
        undefined = self.c_undefined_value(rtype)
        check = '({} {} {})'.format(attr_expr, compare, undefined)
    if unlikely:
        check = '(unlikely{})'.format(check)
    self.emit_line('if {} {{'.format(check))
def tuple_undefined_check_cond(
        self, rtuple: RTuple, tuple_expr_in_c: str,
        c_type_compare_val: Callable[[RType], str], compare: str) -> str:
    """Return a C condition comparing a tuple against a per-type value.

    Only one field is compared: the error flag for an empty tuple,
    otherwise the first item, descending through nested tuples until a
    non-tuple first item (or an empty tuple) is reached.

    Args:
        rtuple: The tuple type
        tuple_expr_in_c: C expression for the tuple value
        c_type_compare_val: Maps an item type to the C value to compare with
        compare: C comparison operator
    """
    current = rtuple
    expr = tuple_expr_in_c
    while True:
        if len(current.types) == 0:
            # empty tuple
            return '{}.empty_struct_error_flag {} {}'.format(
                expr, compare, c_type_compare_val(int_rprimitive))
        first_item = current.types[0]
        if not isinstance(first_item, RTuple):
            return '{}.f0 {} {}'.format(
                expr, compare, c_type_compare_val(first_item))
        # Step into the nested tuple stored in the first field.
        current = first_item
        expr = expr + '.f0'
def tuple_undefined_value(self, rtuple: RTuple) -> str:
    """Return the name of the C constant holding a tuple type's undefined value."""
    unique_id = rtuple.unique_id
    return 'tuple_undefined_' + unique_id
def tuple_undefined_value_helper(self, rtuple: RTuple) -> List[str]:
    """Return the pieces of the C initializer for a tuple's undefined value.

    The pieces are meant to be concatenated. Items are separated by ', '
    and the pieces of a nested tuple are wrapped in '{ ' and ' }'.
    """
    # see tuple_c_declaration()
    if len(rtuple.types) == 0:
        return [self.c_undefined_value(int_rprimitive)]
    pieces = []
    for position, item in enumerate(rtuple.types):
        if position > 0:
            pieces.append(', ')
        if isinstance(item, RTuple):
            pieces.append('{ ')
            pieces.extend(self.tuple_undefined_value_helper(item))
            pieces.append(' }')
        else:
            pieces.append(self.c_undefined_value(item))
    return pieces
# Higher-level operations
def declare_tuple_struct(self, tuple_type: RTuple) -> None:
    """Register the struct declaration for a tuple type, unless already present.

    The declaration is stored in the context under the struct name and
    lists the struct names of directly nested tuple items as dependencies.
    """
    struct_name = tuple_type.struct_name
    declarations = self.context.declarations
    if struct_name in declarations:
        return
    # XXX other types might eventually need similar behavior
    nested_structs = {
        item.struct_name for item in tuple_type.types if isinstance(item, RTuple)
    }
    declarations[struct_name] = HeaderDeclaration(
        self.tuple_c_declaration(tuple_type),
        dependencies=nested_structs,
        is_type=True,
    )
def emit_inc_ref(self, dest: str, rtype: RType, *, rare: bool = False) -> None:
    """Increment reference count of C expression `dest`.

    For composite unboxed structures (e.g. tuples) recursively
    increment reference counts for each component.

    If rare is True, optimize for code size and compilation speed.

    Args:
        dest: C expression whose reference count is incremented
        rtype: Type of `dest`
        rare: If True, use the non-inlined form for int values; this is
            also applied to the components of tuples
    """
    if is_int_rprimitive(rtype):
        if rare:
            self.emit_line('CPyTagged_IncRef(%s);' % dest)
        else:
            self.emit_line('CPyTagged_INCREF(%s);' % dest)
    elif isinstance(rtype, RTuple):
        for i, item_type in enumerate(rtype.types):
            # Forward 'rare' so nested components honor it too, matching
            # what emit_dec_ref does.
            self.emit_inc_ref('{}.f{}'.format(dest, i), item_type, rare=rare)
    elif not rtype.is_unboxed:
        # Always inline, since this is a simple op
        self.emit_line('CPy_INCREF(%s);' % dest)
    # Otherwise assume it's an unboxed, pointerless value and do nothing.
def emit_dec_ref(self,
                 dest: str,
                 rtype: RType,
                 *,
                 is_xdec: bool = False,
                 rare: bool = False) -> None:
    """Decrement reference count of C expression `dest`.

    For composite unboxed structures (e.g. tuples) recursively
    decrement reference counts for each component.

    If rare is True, optimize for code size and compilation speed.

    Args:
        dest: C expression whose reference count is decremented
        rtype: Type of `dest`
        is_xdec: If True, emit the 'X' variant of the operation
        rare: If True, use the non-inlined variant
    """
    x_prefix = 'X' if is_xdec else ''
    if is_int_rprimitive(rtype):
        # The all-caps variant is the inlined one.
        template = 'CPyTagged_%sDecRef(%s);' if rare else 'CPyTagged_%sDECREF(%s);'
        self.emit_line(template % (x_prefix, dest))
    elif isinstance(rtype, RTuple):
        for index, component_type in enumerate(rtype.types):
            self.emit_dec_ref(
                '{}.f{}'.format(dest, index), component_type, is_xdec=is_xdec, rare=rare)
    elif not rtype.is_unboxed:
        # The all-caps variant is the inlined one.
        template = 'CPy_%sDecRef(%s);' if rare else 'CPy_%sDECREF(%s);'
        self.emit_line(template % (x_prefix, dest))
    # Otherwise assume it's an unboxed, pointerless value and do nothing.
def pretty_name(self, typ: RType) -> str:
    """Return a display name for a type.

    Optional types are rendered as '<value type> or None', recursively;
    any other type is rendered with str().
    """
    inner = optional_value_type(typ)
    if inner is None:
        return str(typ)
    return '%s or None' % self.pretty_name(inner)
def emit_cast(self,
src: str,
dest: str,
typ: RType,
*,
declare_dest: bool = False,
error: Optional[ErrorHandler] = None,
raise_exception: bool = True,
optional: bool = False,
src_type: Optional[RType] = None,
likely: bool = True) -> None:
"""Emit code for casting a value of given type.
Somewhat strangely, this supports unboxed types but only
operates on boxed versions. This is necessary to properly
handle types such as Optional[int] in compatibility glue.
By default, assign NULL (error value) to dest if the value has
an incompatible type and raise TypeError. These can be customized
using 'error' and 'raise_exception'.
Always copy/steal the reference in 'src'.
Most branches emit the same shape of code: a type check on 'src',
an assignment of 'src' to 'dest' when it passes, and the error
handling code in the else branch.
Args:
src: Name of source C variable
dest: Name of target C variable
typ: Type of value
declare_dest: If True, also declare the variable 'dest'
error: What happens on error
raise_exception: If True, also raise TypeError on failure
optional: If True, first test 'src' against NULL and in that case
assign the error value to 'dest' instead of performing the
type check (not supported for RTuple targets)
src_type: Type of 'src', if known; when it is an optional type
whose value type is the same as 'typ', only a check against
Py_None is emitted
likely: If the cast is likely to succeed (can be False for unions)
"""
# Build the C statement(s) that run when the type check fails.
error = error or AssignHandler()
if isinstance(error, AssignHandler):
handle_error = '%s = NULL;' % dest
elif isinstance(error, GotoHandler):
handle_error = 'goto %s;' % error.label
else:
assert isinstance(error, ReturnHandler)
handle_error = 'return %s;' % error.value
if raise_exception:
raise_exc = 'CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src)
err = raise_exc + handle_error
else:
err = handle_error
# Special case casting *from* optional
if src_type and is_optional_type(src_type) and not is_object_rprimitive(typ):
value_type = optional_value_type(src_type)
assert value_type is not None
if is_same_type(value_type, typ):
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
check = '({} != Py_None)'
if likely:
check = '(likely{})'.format(check)
self.emit_arg_check(src, dest, typ, check.format(src), optional)
self.emit_lines(
' {} = {};'.format(dest, src),
'else {',
err,
'}')
return
# TODO: Verify refcount handling.
# Primitive types that are checked with a single <prefix>_Check() call.
if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ)
or is_str_rprimitive(typ) or is_range_rprimitive(typ) or is_float_rprimitive(typ)
or is_int_rprimitive(typ) or is_bool_rprimitive(typ) or is_bit_rprimitive(typ)):
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
if is_list_rprimitive(typ):
prefix = 'PyList'
elif is_dict_rprimitive(typ):
prefix = 'PyDict'
elif is_set_rprimitive(typ):
prefix = 'PySet'
elif is_str_rprimitive(typ):
prefix = 'PyUnicode'
elif is_range_rprimitive(typ):
prefix = 'PyRange'
elif is_float_rprimitive(typ):
prefix = 'CPyFloat'
elif is_int_rprimitive(typ):
prefix = 'PyLong'
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
prefix = 'PyBool'
else:
assert False, 'unexpected primitive type'
check = '({}_Check({}))'
if likely:
check = '(likely{})'.format(check)
self.emit_arg_check(src, dest, typ, check.format(prefix, src), optional)
self.emit_lines(
' {} = {};'.format(dest, src),
'else {',
err,
'}')
elif is_bytes_rprimitive(typ):
# The bytes check accepts both bytes and bytearray objects.
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
check = '(PyBytes_Check({}) || PyByteArray_Check({}))'
if likely:
check = '(likely{})'.format(check)
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
self.emit_lines(
' {} = {};'.format(dest, src),
'else {',
err,
'}')
elif is_tuple_rprimitive(typ):
# Unlike the other branches, dest is declared with the type's own C type.
if declare_dest:
self.emit_line('{} {};'.format(self.ctype(typ), dest))
check = '(PyTuple_Check({}))'
if likely:
check = '(likely{})'.format(check)
self.emit_arg_check(src, dest, typ,
check.format(src), optional)
self.emit_lines(
' {} = {};'.format(dest, src),
'else {',
err,
'}')
elif isinstance(typ, RInstance):
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
concrete = all_concrete_classes(typ.class_ir)
# If there are too many concrete subclasses or we can't find any
# (meaning the code ought to be dead or we aren't doing global opts),
# fall back to a normal typecheck.
# Otherwise check all the subclasses.
if not concrete or len(concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1:
check = '(PyObject_TypeCheck({}, {}))'.format(
src, self.type_struct_name(typ.class_ir))
else:
# Build a format template with one exact-type comparison per
# concrete class; the placeholders are filled in below.
full_str = '(Py_TYPE({src}) == {targets[0]})'
for i in range(1, len(concrete)):
full_str += ' || (Py_TYPE({src}) == {targets[%d]})' % i
if len(concrete) > 1:
full_str = '(%s)' % full_str
check = full_str.format(
src=src, targets=[self.type_struct_name(ir) for ir in concrete])
if likely:
check = '(likely{})'.format(check)
self.emit_arg_check(src, dest, typ, check, optional)
self.emit_lines(
' {} = {};'.format(dest, src),
'else {',
err,
'}')
elif is_none_rprimitive(typ):
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
check = '({} == Py_None)'
if likely:
check = '(likely{})'.format(check)
self.emit_arg_check(src, dest, typ,
check.format(src), optional)
self.emit_lines(
' {} = {};'.format(dest, src),
'else {',
err,
'}')
elif is_object_rprimitive(typ):
# No type check is emitted for object; src is assigned as is.
if declare_dest:
self.emit_line('PyObject *{};'.format(dest))
self.emit_arg_check(src, dest, typ, '', optional)
self.emit_line('{} = {};'.format(dest, src))
if optional:
self.emit_line('}')
elif isinstance(typ, RUnion):
self.emit_union_cast(src, dest, typ, declare_dest, err, optional, src_type)
elif isinstance(typ, RTuple):
assert not optional
self.emit_tuple_cast(src, dest, typ, declare_dest, err, src_type)
else:
assert False, 'Cast not implemented: %s' % typ
def emit_union_cast(self,
                    src: str,
                    dest: str,
                    typ: RUnion,
                    declare_dest: bool,
                    err: str,
                    optional: bool,
                    src_type: Optional[RType]) -> None:
    """Emit cast to a union type.

    The arguments are similar to emit_cast. A cast is attempted for each
    union item in turn, without raising; the first one that leaves a
    non-NULL value in 'dest' jumps past the failure code 'err'.
    """
    if declare_dest:
        self.emit_line('PyObject *{};'.format(dest))
    success_label = self.new_label()
    if optional:
        # A NULL source is accepted: store the error value and skip the casts.
        self.emit_line('if ({} == NULL) {{'.format(src))
        self.emit_line('{} = {};'.format(dest, self.c_error_value(typ)))
        self.emit_line('goto {};'.format(success_label))
        self.emit_line('}')
    jump_if_cast_ok = 'if ({} != NULL) goto {};'.format(dest, success_label)
    for member_type in typ.items:
        self.emit_cast(src,
                       dest,
                       member_type,
                       declare_dest=False,
                       raise_exception=False,
                       optional=False,
                       likely=False)
        self.emit_line(jump_if_cast_ok)
    # Handle cast failure.
    self.emit_line(err)
    self.emit_label(success_label)
def emit_tuple_cast(self, src: str, dest: str, typ: RTuple, declare_dest: bool,
                    err: str, src_type: Optional[RType]) -> None:
    """Emit cast to a tuple type.

    The arguments are similar to emit_cast. The emitted code checks that
    'src' is a tuple of the expected length, casts each item in turn, and
    leaves NULL in 'dest' as soon as one of the checks fails.
    """
    if declare_dest:
        self.emit_line('PyObject *{};'.format(dest))
    # This reuse of the variable is super dodgy. We don't even
    # care about the values except to check whether they are
    # invalid.
    done_label = self.new_label()
    size_check = (
        'if (unlikely(!(PyTuple_Check({r}) && PyTuple_GET_SIZE({r}) == {size}))) {{'.format(
            r=src, size=len(typ.types)))
    self.emit_lines(
        size_check,
        '{} = NULL;'.format(dest),
        'goto {};'.format(done_label),
        '}')
    bail_out = 'if ({} == NULL) goto {};'.format(dest, done_label)
    for index, item_type in enumerate(typ.types):
        # Since we did the checks above this should never fail
        self.emit_cast('PyTuple_GET_ITEM({}, {})'.format(src, index),
                       dest,
                       item_type,
                       declare_dest=False,
                       raise_exception=False,
                       optional=False)
        self.emit_line(bail_out)
    self.emit_line('{} = {};'.format(dest, src))
    self.emit_label(done_label)
def emit_arg_check(self, src: str, dest: str, typ: RType, check: str, optional: bool) -> None:
    """Emit the opening part of a check on an argument value.

    Args:
        src: Name of source C variable
        dest: Name of target C variable
        typ: Type of value, used to pick the error value
        check: Parenthesized C condition to emit after 'if', or '' if no
            type check is needed
        optional: If True, first emit a branch that assigns the error
            value to 'dest' when 'src' is NULL

    The caller emits whatever follows the check. Nothing is emitted when
    'optional' is False and 'check' is empty.
    """
    if optional:
        self.emit_line('if ({} == NULL) {{'.format(src))
        self.emit_line('{} = {};'.format(dest, self.c_error_value(typ)))
    if check != '':
        self.emit_line('{}if {}'.format('} else ' if optional else '', check))
    elif optional:
        # Close the NULL branch before opening the else branch; a bare
        # 'else {' would leave the generated braces unbalanced.
        self.emit_line('} else {')
def emit_unbox(self,
src: str,
dest: str,
typ: RType,
*,
declare_dest: bool = False,
error: Optional[ErrorHandler] = None,
raise_exception: bool = True,
optional: bool = False,
borrow: bool = False) -> None:
"""Emit code for unboxing a value of given type (from PyObject *).
By default, assign error value to dest if the value has an
incompatible type and raise TypeError. These can be customized
using 'error' and 'raise_exception'.
Generate a new reference unless 'borrow' is True.
Supported types are int/short int, bool/bit, None and RTuple; any
other type triggers an assertion failure.
Args:
src: Name of source C variable
dest: Name of target C variable
typ: Type of value
declare_dest: If True, also declare the variable 'dest'
error: What happens on error
raise_exception: If True, also raise TypeError on failure
optional: If True, first test 'src' against NULL and in that case
assign the error value to 'dest' instead of unboxing
borrow: If True, create a borrowed reference
"""
# Build the C statement(s) that run when the value has the wrong type.
error = error or AssignHandler()
# TODO: Verify refcount handling.
if isinstance(error, AssignHandler):
failure = '%s = %s;' % (dest, self.c_error_value(typ))
elif isinstance(error, GotoHandler):
failure = 'goto %s;' % error.label
else:
assert isinstance(error, ReturnHandler)
failure = 'return %s;' % error.value
if raise_exception:
raise_exc = 'CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src)
failure = raise_exc + failure
if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
if declare_dest:
self.emit_line('CPyTagged {};'.format(dest))
self.emit_arg_check(src, dest, typ, '(likely(PyLong_Check({})))'.format(src),
optional)
if borrow:
self.emit_line(' {} = CPyTagged_BorrowFromObject({});'.format(dest, src))
else:
self.emit_line(' {} = CPyTagged_FromObject({});'.format(dest, src))
self.emit_line('else {')
self.emit_line(failure)
self.emit_line('}')
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
# Whether we are borrowing or not makes no difference.
if declare_dest:
self.emit_line('char {};'.format(dest))
# Here the check is inverted: the failure code comes first and the
# conversion sits in the else branch.
self.emit_arg_check(src, dest, typ, '(unlikely(!PyBool_Check({}))) {{'.format(src),
optional)
self.emit_line(failure)
self.emit_line('} else')
conversion = '{} == Py_True'.format(src)
self.emit_line(' {} = {};'.format(dest, conversion))
elif is_none_rprimitive(typ):
# Whether we are borrowing or not makes no difference.
if declare_dest:
self.emit_line('char {};'.format(dest))
self.emit_arg_check(src, dest, typ, '(unlikely({} != Py_None)) {{'.format(src),
optional)
self.emit_line(failure)
self.emit_line('} else')
self.emit_line(' {} = 1;'.format(dest))
elif isinstance(typ, RTuple):
self.declare_tuple_struct(typ)
if declare_dest:
self.emit_line('{} {};'.format(self.ctype(typ), dest))
# HACK: The error handling for unboxing tuples is busted
# and instead of fixing it I am just wrapping it in the
# cast code which I think is right. This is not good.
if optional:
self.emit_line('if ({} == NULL) {{'.format(src))
self.emit_line('{} = {};'.format(dest, self.c_error_value(typ)))
self.emit_line('} else {')
# The tuple cast validates the value into a temporary; only whether the
# temporary ends up NULL is used below.
cast_temp = self.temp_name()
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, err='', src_type=None)
self.emit_line('if (unlikely({} == NULL)) {{'.format(cast_temp))
# self.emit_arg_check(src, dest, typ,
# '(!PyTuple_Check({}) || PyTuple_Size({}) != {}) {{'.format(
# src, src, len(typ.types)), optional)
self.emit_line(failure) # TODO: Decrease refcount?
self.emit_line('} else {')
if not typ.types:
self.emit_line('{}.empty_struct_error_flag = 0;'.format(dest))
# Fill in the struct fields f0, f1, ... from the tuple items.
for i, item_type in enumerate(typ.types):
temp = self.temp_name()
# emit_tuple_cast above checks the size, so this should not fail
self.emit_line('PyObject *{} = PyTuple_GET_ITEM({}, {});'.format(temp, src, i))
temp2 = self.temp_name()
# Unbox or check the item.
if item_type.is_unboxed:
self.emit_unbox(temp,
temp2,
item_type,
raise_exception=raise_exception,
error=error,
declare_dest=True,
borrow=borrow)
else:
if not borrow:
self.emit_inc_ref(temp, object_rprimitive)
self.emit_cast(temp, temp2, item_type, declare_dest=True)
self.emit_line('{}.f{} = {};'.format(dest, i, temp2))
self.emit_line('}')
if optional:
self.emit_line('}')
else:
assert False, 'Unboxing not implemented: %s' % typ
def emit_box(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
             can_borrow: bool = False) -> None:
    """Emit code for boxing a value of given type.

    Generate a simple assignment if no boxing is needed.

    The source reference count is stolen for the result (no need to decref afterwards).

    Args:
        src: C expression for the value to box
        dest: Name of target C variable
        typ: Type of the value in 'src'
        declare_dest: If True, also declare 'dest' (as PyObject *)
        can_borrow: If True, skip the reference count increment for
            boxed bool and None values
    """
    # TODO: Always generate a new reference (if a reference type)
    declaration = 'PyObject *' if declare_dest else ''
    if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
        # Steal the existing reference if it exists.
        self.emit_line('{}{} = CPyTagged_StealAsObject({});'.format(declaration, dest, src))
    elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
        # N.B: bool is special cased to produce a borrowed value
        # after boxing, so we don't need to increment the refcount
        # when this comes directly from a Box op.
        self.emit_lines('{}{} = {} ? Py_True : Py_False;'.format(declaration, dest, src))
        if not can_borrow:
            self.emit_inc_ref(dest, object_rprimitive)
    elif is_none_rprimitive(typ):
        # N.B: None is special cased to produce a borrowed value
        # after boxing, so we don't need to increment the refcount
        # when this comes directly from a Box op.
        self.emit_lines('{}{} = Py_None;'.format(declaration, dest))
        if not can_borrow:
            self.emit_inc_ref(dest, object_rprimitive)
    elif is_int32_rprimitive(typ):
        self.emit_line('{}{} = PyLong_FromLong({});'.format(declaration, dest, src))
    elif is_int64_rprimitive(typ):
        self.emit_line('{}{} = PyLong_FromLongLong({});'.format(declaration, dest, src))
    elif isinstance(typ, RTuple):
        self.declare_tuple_struct(typ)
        self.emit_line('{}{} = PyTuple_New({});'.format(declaration, dest, len(typ.types)))
        self.emit_line('if (unlikely({} == NULL))'.format(dest))
        self.emit_line(' CPyError_OutOfMemory();')
        # TODO: Fail if dest is None
        for i, item_type in enumerate(typ.types):
            # NOTE(review): this tests the tuple type itself rather than the
            # item type -- confirm which one is intended.
            if not typ.is_unboxed:
                # The emitted statement must be terminated with ');',
                # otherwise the generated C does not parse.
                self.emit_line('PyTuple_SET_ITEM({}, {}, {}.f{});'.format(dest, i, src, i))
            else:
                inner_name = self.temp_name()
                self.emit_box('{}.f{}'.format(src, i), inner_name, item_type,
                              declare_dest=True)
                self.emit_line('PyTuple_SET_ITEM({}, {}, {});'.format(dest, i, inner_name))
    else:
        assert not typ.is_unboxed
        # Type is boxed -- trivially just assign.
        self.emit_line('{}{} = {};'.format(declaration, dest, src))
def emit_error_check(self, value: str, rtype: RType, failure: str) -> None:
    """Emit code for checking a native function return value for uncaught exception.

    The emitted code is an 'if' comparing 'value' with the error value of
    'rtype', with 'failure' as its body. Nothing is emitted for empty
    tuple types.
    """
    if isinstance(rtype, RTuple):
        if len(rtype.types) == 0:
            return  # empty tuples can't fail.
        condition = self.tuple_undefined_check_cond(rtype, value, self.c_error_value, '==')
        self.emit_line('if ({}) {{'.format(condition))
    else:
        error_value = self.c_error_value(rtype)
        self.emit_line('if ({} == {}) {{'.format(value, error_value))
    self.emit_lines(failure, '}')
def emit_gc_visit(self, target: str, rtype: RType) -> None:
    """Emit code for GC visiting a C variable reference.

    Assume that 'target' represents a C expression that refers to a
    struct member, such as 'self->x'. Tuples are visited field by field.
    """
    if not rtype.is_refcounted:
        # Not refcounted -> no pointers -> no GC interaction.
        return
    if isinstance(rtype, RPrimitive) and rtype.name == 'builtins.int':
        # The visit is guarded by CPyTagged_CheckLong().
        self.emit_line('if (CPyTagged_CheckLong({})) {{'.format(target))
        self.emit_line('Py_VISIT(CPyTagged_LongAsObject({}));'.format(target))
        self.emit_line('}')
        return
    if isinstance(rtype, RTuple):
        for index, field_type in enumerate(rtype.types):
            self.emit_gc_visit('{}.f{}'.format(target, index), field_type)
        return
    if self.ctype(rtype) == 'PyObject *':
        # The simplest case.
        self.emit_line('Py_VISIT({});'.format(target))
        return
    assert False, 'emit_gc_visit() not implemented for %s' % repr(rtype)
def emit_gc_clear(self, target: str, rtype: RType) -> None:
    """Emit code for clearing a C attribute reference for GC.

    Assume that 'target' represents a C expression that refers to a
    struct member, such as 'self->x'. Tuples are cleared field by field.
    """
    if not rtype.is_refcounted:
        # Not refcounted -> no pointers -> no GC interaction.
        return
    if isinstance(rtype, RPrimitive) and rtype.name == 'builtins.int':
        # Copy the value to a temporary and reset the member to the
        # undefined value before dropping the reference.
        self.emit_line('if (CPyTagged_CheckLong({})) {{'.format(target))
        self.emit_line('CPyTagged __tmp = {};'.format(target))
        self.emit_line('{} = {};'.format(target, self.c_undefined_value(rtype)))
        self.emit_line('Py_XDECREF(CPyTagged_LongAsObject(__tmp));')
        self.emit_line('}')
        return
    if isinstance(rtype, RTuple):
        for index, field_type in enumerate(rtype.types):
            self.emit_gc_clear('{}.f{}'.format(target, index), field_type)
        return
    if self.ctype(rtype) == 'PyObject *' and self.c_undefined_value(rtype) == 'NULL':
        # The simplest case.
        self.emit_line('Py_CLEAR({});'.format(target))
        return
    assert False, 'emit_gc_clear() not implemented for %s' % repr(rtype)

View file

@ -0,0 +1,949 @@
"""Code generation for native classes and related wrappers."""
from typing import Optional, List, Tuple, Dict, Callable, Mapping, Set
from mypy.backports import OrderedDict
from mypyc.common import PREFIX, NATIVE_PREFIX, REG_PREFIX, use_fastcall
from mypyc.codegen.emit import Emitter, HeaderDeclaration, ReturnHandler
from mypyc.codegen.emitfunc import native_function_header
from mypyc.codegen.emitwrapper import (
generate_dunder_wrapper, generate_hash_wrapper, generate_richcompare_wrapper,
generate_bool_wrapper, generate_get_wrapper, generate_len_wrapper,
generate_set_del_item_wrapper, generate_contains_wrapper, generate_bin_op_wrapper
)
from mypyc.ir.rtypes import RType, RTuple, object_rprimitive
from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD
from mypyc.ir.class_ir import ClassIR, VTableEntries
from mypyc.sametype import is_same_type
from mypyc.namegen import NameGenerator
def native_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Return NATIVE_PREFIX followed by the C name of `fn`; `cl` is not used."""
    c_name = fn.cname(emitter.names)
    return '{}{}'.format(NATIVE_PREFIX, c_name)
def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Return the slot value for a method reached through its PREFIX-named function.

    The result is PREFIX followed by the C name of 'fn'.  'cl' is not
    used; the parameter list matches the SlotGenerator signature.
    """
    c_name = fn.cname(emitter.names)
    return f'{PREFIX}{c_name}'
# We maintain a table from dunder function names to struct slots they
# correspond to and functions that generate a wrapper (if necessary)
# and return the function name to stick in the slot.
# TODO: Add remaining dunder methods
# A slot generator takes (class, method, emitter) and returns the C
# expression to store in the slot.
SlotGenerator = Callable[[ClassIR, FuncIR, Emitter], str]
# Maps a dunder method name to a (C struct field name, slot generator) pair.
SlotTable = Mapping[str, Tuple[str, SlotGenerator]]
# Slots that are fields of the type object itself (tp_*).
SLOT_DEFS: SlotTable = {
# The lambdas postpone the name lookup until the generator is called:
# generate_init_for_class and generate_call_wrapper are defined further
# down in this module.
'__init__': ('tp_init', lambda c, t, e: generate_init_for_class(c, t, e)),
'__call__': ('tp_call', lambda c, t, e: generate_call_wrapper(c, t, e)),
'__str__': ('tp_str', native_slot),
'__repr__': ('tp_repr', native_slot),
'__next__': ('tp_iternext', native_slot),
'__iter__': ('tp_iter', native_slot),
'__hash__': ('tp_hash', generate_hash_wrapper),
'__get__': ('tp_descr_get', generate_get_wrapper),
}
# Slots stored in the PyMappingMethods side table (see SIDE_TABLES).
# __setitem__ and __delitem__ share the mp_ass_subscript slot;
# generate_slots() generates one wrapper per slot and reuses it.
AS_MAPPING_SLOT_DEFS: SlotTable = {
'__getitem__': ('mp_subscript', generate_dunder_wrapper),
'__setitem__': ('mp_ass_subscript', generate_set_del_item_wrapper),
'__delitem__': ('mp_ass_subscript', generate_set_del_item_wrapper),
'__len__': ('mp_length', generate_len_wrapper),
}
# Slots stored in the PySequenceMethods side table (see SIDE_TABLES).
AS_SEQUENCE_SLOT_DEFS: SlotTable = {
'__contains__': ('sq_contains', generate_contains_wrapper),
}
# Slots stored in the PyNumberMethods side table (see SIDE_TABLES).
# Each binary operator and its reflected variant (e.g. __add__ and
# __radd__) map to the same slot; generate_slots() generates a single
# wrapper per slot, visiting the forward method first (see slot_key()).
AS_NUMBER_SLOT_DEFS: SlotTable = {
# Unary operators and conversions.
'__bool__': ('nb_bool', generate_bool_wrapper),
'__neg__': ('nb_negative', generate_dunder_wrapper),
'__invert__': ('nb_invert', generate_dunder_wrapper),
'__int__': ('nb_int', generate_dunder_wrapper),
'__float__': ('nb_float', generate_dunder_wrapper),
# Binary operators, forward and reflected.
'__add__': ('nb_add', generate_bin_op_wrapper),
'__radd__': ('nb_add', generate_bin_op_wrapper),
'__sub__': ('nb_subtract', generate_bin_op_wrapper),
'__rsub__': ('nb_subtract', generate_bin_op_wrapper),
'__mul__': ('nb_multiply', generate_bin_op_wrapper),
'__rmul__': ('nb_multiply', generate_bin_op_wrapper),
'__mod__': ('nb_remainder', generate_bin_op_wrapper),
'__rmod__': ('nb_remainder', generate_bin_op_wrapper),
'__truediv__': ('nb_true_divide', generate_bin_op_wrapper),
'__rtruediv__': ('nb_true_divide', generate_bin_op_wrapper),
'__floordiv__': ('nb_floor_divide', generate_bin_op_wrapper),
'__rfloordiv__': ('nb_floor_divide', generate_bin_op_wrapper),
'__lshift__': ('nb_lshift', generate_bin_op_wrapper),
'__rlshift__': ('nb_lshift', generate_bin_op_wrapper),
'__rshift__': ('nb_rshift', generate_bin_op_wrapper),
'__rrshift__': ('nb_rshift', generate_bin_op_wrapper),
'__and__': ('nb_and', generate_bin_op_wrapper),
'__rand__': ('nb_and', generate_bin_op_wrapper),
'__or__': ('nb_or', generate_bin_op_wrapper),
'__ror__': ('nb_or', generate_bin_op_wrapper),
'__xor__': ('nb_xor', generate_bin_op_wrapper),
'__rxor__': ('nb_xor', generate_bin_op_wrapper),
'__matmul__': ('nb_matrix_multiply', generate_bin_op_wrapper),
'__rmatmul__': ('nb_matrix_multiply', generate_bin_op_wrapper),
# In-place operators.
'__iadd__': ('nb_inplace_add', generate_dunder_wrapper),
'__isub__': ('nb_inplace_subtract', generate_dunder_wrapper),
'__imul__': ('nb_inplace_multiply', generate_dunder_wrapper),
'__imod__': ('nb_inplace_remainder', generate_dunder_wrapper),
'__itruediv__': ('nb_inplace_true_divide', generate_dunder_wrapper),
'__ifloordiv__': ('nb_inplace_floor_divide', generate_dunder_wrapper),
'__ilshift__': ('nb_inplace_lshift', generate_dunder_wrapper),
'__irshift__': ('nb_inplace_rshift', generate_dunder_wrapper),
'__iand__': ('nb_inplace_and', generate_dunder_wrapper),
'__ior__': ('nb_inplace_or', generate_dunder_wrapper),
'__ixor__': ('nb_inplace_xor', generate_dunder_wrapper),
'__imatmul__': ('nb_inplace_matrix_multiply', generate_dunder_wrapper),
}
# Slots stored in the PyAsyncMethods side table (see SIDE_TABLES).
AS_ASYNC_SLOT_DEFS: SlotTable = {
'__await__': ('am_await', native_slot),
'__aiter__': ('am_aiter', native_slot),
'__anext__': ('am_anext', native_slot),
}
# Each entry is (table name, C struct type, slot definitions).
# generate_class() emits one static struct of the given type per non-empty
# table and stores its address in the type object field 'tp_<table name>'.
SIDE_TABLES = [
('as_mapping', 'PyMappingMethods', AS_MAPPING_SLOT_DEFS),
('as_sequence', 'PySequenceMethods', AS_SEQUENCE_SLOT_DEFS),
('as_number', 'PyNumberMethods', AS_NUMBER_SLOT_DEFS),
('as_async', 'PyAsyncMethods', AS_ASYNC_SLOT_DEFS),
]
# Slots that need to always be filled in because they don't get
# inherited right.
# generate_slots() fills these even when the method is defined on a base
# class rather than on the class being generated.
ALWAYS_FILL = {
'__hash__',
}
def generate_call_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Return the value for the tp_call slot of a class defining __call__.

    When the emitter reports vectorcall support (PEP 590) the slot is
    'PyVectorcall_Call'; otherwise it is the legacy wrapper name
    produced by wrapper_slot().
    """
    if not emitter.use_vectorcall():
        # On older Pythons use the legacy wrapper.
        return wrapper_slot(cl, fn, emitter)
    # Use vectorcall wrapper if supported (PEP 590).
    return 'PyVectorcall_Call'
def slot_key(attr: str) -> str:
    """Return the sort key used to order dunder method names.

    Names beginning with '__r' (other than '__rshift__') and
    '__delitem__' get an 'x' prefix so that they sort after all the
    other names ('x' > '_').  Every other name is its own key.
    """
    starts_like_reverse_op = attr != '__rshift__' and attr.startswith('__r')
    if starts_like_reverse_op or attr == '__delitem__':
        return 'x' + attr
    return attr
def generate_slots(cl: ClassIR, table: SlotTable, emitter: Emitter) -> Dict[str, str]:
    """Compute the slot values for the dunder methods of 'cl' listed in 'table'.

    A slot is filled when the class has the method and either defines it
    itself or the method name is in ALWAYS_FILL.  When several methods
    map to the same slot, the generator runs only for the first one (in
    slot_key() order) and its result is reused for the others.

    Returns:
        An ordered mapping from C struct field name to the value to
        store in it.
    """
    fields: Dict[str, str] = OrderedDict()
    generated: Dict[str, str] = {}
    # Sort for determinism on Python 3.5
    for dunder in sorted(table, key=slot_key):
        slot, generator = table[dunder]
        method_cls = cl.get_method_and_class(dunder)
        if not method_cls:
            continue
        if method_cls[1] != cl and dunder not in ALWAYS_FILL:
            continue
        if slot not in generated:
            # Generate new wrapper; later methods for this slot reuse it.
            generated[slot] = generator(cl, method_cls[0], emitter)
        fields[slot] = generated[slot]
    return fields
def generate_class_type_decl(cl: ClassIR, c_emitter: Emitter,
                             external_emitter: Emitter,
                             emitter: Emitter) -> None:
    """Register the header declarations needed for a class.

    Always declares the exported 'PyTypeObject *' variable for the
    class.  For extension classes the object struct is declared too
    (through 'external_emitter'), and unless the class is a trait or has
    a builtin base, the native constructor is declared as well.
    """
    context = c_emitter.context
    type_var = emitter.type_struct_name(cl)
    type_decl = f'PyTypeObject *{emitter.type_struct_name(cl)};'
    context.declarations[type_var] = HeaderDeclaration(type_decl, needs_export=True)

    # If this is a non-extension class, all we want is the type object decl.
    if not cl.is_ext_class:
        return

    generate_object_struct(cl, external_emitter)

    if cl.is_trait or cl.builtin_base:
        return
    ctor_decl = f'{native_function_header(cl.ctor, emitter)};'
    ctor_name = emitter.native_function_name(cl.ctor)
    context.declarations[ctor_name] = HeaderDeclaration(ctor_decl, needs_export=True)
def generate_class(cl: ClassIR, module: str, emitter: Emitter) -> None:
    """Generate C code for a class.

    This is the main entry point to the module.  It collects the
    PyTypeObject fields for 'cl', emits the functions and tables those
    fields refer to, and then emits the static '<type>_template_'
    PyTypeObject initializer.  The order of the calls below is the order
    of the generated C code.

    Args:
        cl: The class to generate code for.
        module: Not used in the body of this function.
        emitter: Receives the generated C lines.
    """
    name = cl.name
    prefix = cl.name_prefix(emitter.names)

    # C identifiers of the per-class functions and tables.
    setup_name = f'{prefix}_setup'
    new_name = f'{prefix}_new'
    members_name = f'{prefix}_members'
    getseters_name = f'{prefix}_getseters'
    vtable_name = f'{prefix}_vtable'
    traverse_name = f'{prefix}_traverse'
    clear_name = f'{prefix}_clear'
    dealloc_name = f'{prefix}_dealloc'
    methods_name = f'{prefix}_methods'
    vtable_setup_name = f'{prefix}_trait_vtable_setup'

    fields: Dict[str, str] = OrderedDict()
    fields['tp_name'] = f'"{name}"'

    generate_full = not (cl.is_trait or cl.builtin_base)
    needs_getseters = cl.needs_getseters or not cl.is_generated

    if not cl.builtin_base:
        fields['tp_new'] = new_name

    if generate_full:
        fields['tp_dealloc'] = f'(destructor){prefix}_dealloc'
        fields['tp_traverse'] = f'(traverseproc){prefix}_traverse'
        fields['tp_clear'] = f'(inquiry){prefix}_clear'
    if needs_getseters:
        fields['tp_getset'] = getseters_name
    fields['tp_methods'] = methods_name

    # Shorthand for emitting an empty separator line.
    blank_line = emitter.emit_line

    blank_line()

    # If the class has a method to initialize default attribute
    # values, we need to call it during initialization.
    defaults_fn = cl.get_method('__mypyc_defaults_setup')

    # If there is a __init__ method, we'll use it in the native constructor.
    init_fn = cl.get_method('__init__')

    # Fill out slots in the type object from dunder methods.
    fields.update(generate_slots(cl, SLOT_DEFS, emitter))

    # Fill out dunder methods that live in tables hanging off the side.
    for table_name, struct_type, slot_defs in SIDE_TABLES:
        slots = generate_slots(cl, slot_defs, emitter)
        if not slots:
            continue
        table_struct_name = generate_side_table_for_class(
            cl, table_name, struct_type, slots, emitter)
        fields[f'tp_{table_name}'] = f'&{table_struct_name}'

    richcompare_name = generate_richcompare_wrapper(cl, emitter)
    if richcompare_name:
        fields['tp_richcompare'] = richcompare_name

    # If the class inherits from python, make space for a __dict__
    struct_name = cl.struct_name(emitter.names)
    if cl.builtin_base:
        base_size = f'sizeof({cl.builtin_base})'
    elif cl.is_trait:
        base_size = 'sizeof(PyObject)'
    else:
        base_size = f'sizeof({struct_name})'

    # Since our types aren't allocated using type() we need to
    # populate these fields ourselves if we want them to have correct
    # values. PyType_Ready will inherit the offsets from tp_base but
    # that isn't what we want.

    # XXX: there is no reason for the __weakref__ stuff to be mixed up with __dict__
    if not cl.has_dict:
        fields['tp_basicsize'] = base_size
    else:
        # __dict__ lives right after the struct and __weakref__ lives right after that
        # TODO: They should get members in the struct instead of doing this nonsense.
        weak_offset = f'{base_size} + sizeof(PyObject *)'
        emitter.emit_lines(
            f'PyMemberDef {members_name}[] = {{',
            f'{{"__dict__", T_OBJECT_EX, {base_size}, 0, NULL}},',
            f'{{"__weakref__", T_OBJECT_EX, {weak_offset}, 0, NULL}},',
            '{0}',
            '};',
        )

        fields['tp_members'] = members_name
        fields['tp_basicsize'] = f'{base_size} + 2*sizeof(PyObject *)'
        fields['tp_dictoffset'] = base_size
        fields['tp_weaklistoffset'] = weak_offset

    if generate_full:
        # Declare setup method that allocates and initializes an object. type is the
        # type of the class being initialized, which could be another class if there
        # is an interpreted subclass.
        emitter.emit_line(f'static PyObject *{setup_name}(PyTypeObject *type);')
        assert cl.ctor is not None
        emitter.emit_line(native_function_header(cl.ctor, emitter) + ';')

        blank_line()
        generate_new_for_class(cl, new_name, vtable_name, setup_name, emitter)
        blank_line()
        generate_traverse_for_class(cl, traverse_name, emitter)
        blank_line()
        generate_clear_for_class(cl, clear_name, emitter)
        blank_line()
        generate_dealloc_for_class(cl, dealloc_name, clear_name, emitter)
        blank_line()

        shadow_vtable_name: Optional[str] = None
        if cl.allow_interpreted_subclasses:
            shadow_vtable_name = generate_vtables(
                cl, vtable_setup_name + "_shadow", vtable_name + "_shadow", emitter, shadow=True
            )
            blank_line()
        # From here on vtable_name is the expression referring to the
        # vtable, which may differ from the array name.
        vtable_name = generate_vtables(cl, vtable_setup_name, vtable_name, emitter, shadow=False)
        blank_line()
    if needs_getseters:
        generate_getseter_declarations(cl, emitter)
        blank_line()
        generate_getseters_table(cl, getseters_name, emitter)
        blank_line()

    if cl.is_trait:
        generate_new_for_trait(cl, new_name, emitter)

    generate_methods_table(cl, methods_name, emitter)
    blank_line()

    flags = ['Py_TPFLAGS_DEFAULT', 'Py_TPFLAGS_HEAPTYPE', 'Py_TPFLAGS_BASETYPE']
    if generate_full:
        flags.append('Py_TPFLAGS_HAVE_GC')
    if cl.has_method('__call__') and emitter.use_vectorcall():
        call_struct = cl.struct_name(emitter.names)
        fields['tp_vectorcall_offset'] = f'offsetof({call_struct}, vectorcall)'
        flags.append('_Py_TPFLAGS_HAVE_VECTORCALL')
    fields['tp_flags'] = ' | '.join(flags)

    # Emit the type object template with designated initializers.
    template_type = emitter.type_struct_name(cl)
    emitter.emit_line(f"static PyTypeObject {template_type}_template_ = {{")
    emitter.emit_line("PyVarObject_HEAD_INIT(NULL, 0)")
    for field, value in fields.items():
        emitter.emit_line(f".{field} = {value},")
    emitter.emit_line("};")
    pointer_type = emitter.type_struct_name(cl)
    emitter.emit_line(
        f"static PyTypeObject *{pointer_type}_template = &{pointer_type}_template_;")

    blank_line()
    if generate_full:
        generate_setup_for_class(
            cl, setup_name, defaults_fn, vtable_name, shadow_vtable_name, emitter)
        blank_line()
        generate_constructor_for_class(
            cl, cl.ctor, init_fn, setup_name, vtable_name, emitter)
        blank_line()
    if needs_getseters:
        generate_getseters(cl, emitter)
def getter_name(cl: ClassIR, attribute: str, names: NameGenerator) -> str:
    """Return the private C name of the getter function for an attribute of 'cl'."""
    module_name = cl.module_name
    local_name = f'{cl.name}_get{attribute}'
    return names.private_name(module_name, local_name)
def setter_name(cl: ClassIR, attribute: str, names: NameGenerator) -> str:
    """Return the private C name of the setter function for an attribute of 'cl'."""
    module_name = cl.module_name
    local_name = f'{cl.name}_set{attribute}'
    return names.private_name(module_name, local_name)
def generate_object_struct(cl: ClassIR, emitter: Emitter) -> None:
    """Register the C struct typedef holding instances of 'cl'.

    The struct starts with PyObject_HEAD and the vtable pointer, then an
    optional vectorcall pointer, then one member per distinct
    (name, type) attribute of the non-trait classes in the base MRO,
    walking the MRO in reverse.  Tuple-typed attributes also get their
    tuple struct declared.
    """
    seen_attrs: Set[Tuple[str, RType]] = set()
    lines: List[str] = [
        'typedef struct {',
        'PyObject_HEAD',
        'CPyVTableItem *vtable;',
    ]
    if cl.has_method('__call__') and emitter.use_vectorcall():
        lines.append('vectorcallfunc vectorcall;')

    for base in reversed(cl.base_mro):
        if base.is_trait:
            continue
        for attr, rtype in base.attributes.items():
            key = (attr, rtype)
            if key in seen_attrs:
                continue
            member_type = emitter.ctype_spaced(rtype)
            member_name = emitter.attr(attr)
            lines.append(f'{member_type}{member_name};')
            seen_attrs.add(key)

            if isinstance(rtype, RTuple):
                emitter.declare_tuple_struct(rtype)

    lines.append(f'}} {cl.struct_name(emitter.names)};')
    lines.append('')
    declaration = HeaderDeclaration(lines, is_type=True)
    emitter.context.declarations[cl.struct_name(emitter.names)] = declaration
def generate_vtables(base: ClassIR,
                     vtable_setup_name: str,
                     vtable_name: str,
                     emitter: Emitter,
                     shadow: bool) -> str:
    """Emit the vtable arrays of a class and the function that populates them.

    Both the primary vtable and the vtables for implemented traits are
    covered.  The trait entries are placed in front of the primary
    vtable, three items per trait:

        {
            CPyType_T1,         // pointer to type object
            C_T1_trait_vtable,  // pointer to array of method pointers
            C_T1_offset_table,  // pointer to array of attribute offsets
            CPyType_T2,
            C_T2_trait_vtable,
            C_T2_offset_table,
            ...
        }

    The method implementations are calculated at the end of IR pass,
    attribute offsets are {offsetof(native__C, _x1), offsetof(native__C, _y1), ...}.

    To account for both dynamic loading and dynamic class creation,
    vtables are populated dynamically at class creation time, so only
    empty array definitions are emitted here, together with a setup
    function that fills them in.

    If shadow is True, generate "shadow vtables" that point to the
    shadow glue methods (which should dispatch via the Python C-API).

    Returns the expression to use to refer to the vtable, which might be
    different than the name, if there are trait vtables.
    """

    def trait_vtable_name(trait: ClassIR) -> str:
        base_prefix = base.name_prefix(emitter.names)
        trait_prefix = trait.name_prefix(emitter.names)
        suffix = '_shadow' if shadow else ''
        return f'{base_prefix}_{trait_prefix}_trait_vtable{suffix}'

    def trait_offset_table_name(trait: ClassIR) -> str:
        base_prefix = base.name_prefix(emitter.names)
        trait_prefix = trait.name_prefix(emitter.names)
        return f'{base_prefix}_{trait_prefix}_offset_table'

    # Emit array definitions with enough space for all the entries
    main_size = max(1, len(base.vtable_entries) + 3 * len(base.trait_vtables))
    emitter.emit_line(f'static CPyVTableItem {vtable_name}[{main_size}];')

    for trait, vtable in base.trait_vtables.items():
        # Trait methods entry (vtable index -> method implementation).
        methods_array = trait_vtable_name(trait)
        methods_size = max(1, len(vtable))
        emitter.emit_line(f'static CPyVTableItem {methods_array}[{methods_size}];')
        # Trait attributes entry (attribute number in trait -> offset in actual struct).
        offsets_array = trait_offset_table_name(trait)
        offsets_size = max(1, len(trait.attributes))
        emitter.emit_line(f'static size_t {offsets_array}[{offsets_size}];')

    # Emit vtable setup function
    emitter.emit_line('static bool')
    emitter.emit_line(f'{NATIVE_PREFIX}{vtable_setup_name}(void)')
    emitter.emit_line('{')

    if base.allow_interpreted_subclasses and not shadow:
        # The regular setup function also runs the shadow setup function.
        emitter.emit_line(f'{NATIVE_PREFIX}{vtable_setup_name}_shadow();')

    subtables = []
    for trait, vtable in base.trait_vtables.items():
        name = trait_vtable_name(trait)
        offset_name = trait_offset_table_name(trait)
        generate_vtable(vtable, name, emitter, [], shadow)
        generate_offset_table(offset_name, emitter, trait, base)
        subtables.append((trait, name, offset_name))

    generate_vtable(base.vtable_entries, vtable_name, emitter, subtables, shadow)

    emitter.emit_line('return 1;')
    emitter.emit_line('}')

    if subtables:
        # Skip past the three leading items of every trait.
        return f"{vtable_name} + {len(subtables) * 3}"
    return vtable_name
def generate_offset_table(trait_offset_table_name: str,
                          emitter: Emitter,
                          trait: ClassIR,
                          cl: ClassIR) -> None:
    """Generate attribute offset row of a trait vtable.

    Emits a local '<name>_scratch' array with one offsetof() into the
    struct of 'cl' per attribute of 'trait', then a memcpy() of it into
    the array called 'trait_offset_table_name'.
    """
    table = trait_offset_table_name
    emitter.emit_line(f'size_t {table}_scratch[] = {{')
    for attr in trait.attributes:
        struct = cl.struct_name(emitter.names)
        member = emitter.attr(attr)
        emitter.emit_line(f'offsetof({struct}, {member}),')
    if not trait.attributes:
        # This is for msvc.
        emitter.emit_line('0')
    emitter.emit_line('};')
    emitter.emit_line(f'memcpy({table}, {table}_scratch, sizeof({table}));')
def generate_vtable(entries: VTableEntries,
                    vtable_name: str,
                    emitter: Emitter,
                    subtables: List[Tuple[ClassIR, str, str]],
                    shadow: bool) -> None:
    """Emit the statements that fill in one vtable array.

    A local '<vtable_name>_scratch' array is emitted with the trait
    subtable triples first (if any) followed by one function pointer per
    entry, and is then copied into 'vtable_name' with memcpy().  When
    'shadow' is true and an entry has a shadow method, the shadow
    method's name is used for that entry.
    """
    emitter.emit_line(f'CPyVTableItem {vtable_name}_scratch[] = {{')
    if subtables:
        emitter.emit_line('/* Array of trait vtables */')
        for trait, table, offset_table in subtables:
            trait_type = emitter.type_struct_name(trait)
            emitter.emit_line(
                f'(CPyVTableItem){trait_type}, (CPyVTableItem){table}, '
                f'(CPyVTableItem){offset_table},')
        emitter.emit_line('/* Start of real vtable */')

    for entry in entries:
        method = entry.method
        if shadow and entry.shadow_method:
            method = entry.shadow_method
        # The group prefix always comes from the regular method's declaration.
        group_prefix = emitter.get_group_prefix(entry.method.decl)
        method_c_name = method.cname(emitter.names)
        emitter.emit_line(f'(CPyVTableItem){group_prefix}{NATIVE_PREFIX}{method_c_name},')

    # msvc doesn't allow empty arrays; maybe allowing them at all is an extension?
    if not entries:
        emitter.emit_line('NULL')
    emitter.emit_line('};')
    emitter.emit_line(
        f'memcpy({vtable_name}, {vtable_name}_scratch, sizeof({vtable_name}));')
def generate_setup_for_class(cl: ClassIR,
                             func_name: str,
                             defaults_fn: Optional[FuncIR],
                             vtable_name: str,
                             shadow_vtable_name: Optional[str],
                             emitter: Emitter) -> None:
    """Generate a native function that allocates an instance of a class.

    The emitted function allocates through 'type->tp_alloc', installs
    the vtable (the shadow vtable when one exists and 'type' is not this
    class's own type object), sets the vectorcall pointer if applicable,
    stores the undefined value in every attribute, and finally calls
    'defaults_fn' when one is given.
    """
    emitter.emit_line('static PyObject *')
    emitter.emit_line(f'{func_name}(PyTypeObject *type)')
    emitter.emit_line('{')
    emitter.emit_line(f'{cl.struct_name(emitter.names)} *self;')
    emitter.emit_line(
        f'self = ({cl.struct_name(emitter.names)} *)type->tp_alloc(type, 0);')
    emitter.emit_line('if (self == NULL)')
    emitter.emit_line(' return NULL;')

    if not shadow_vtable_name:
        emitter.emit_line(f'self->vtable = {vtable_name};')
    else:
        emitter.emit_line(f'if (type != {emitter.type_struct_name(cl)}) {{')
        emitter.emit_line(f'self->vtable = {shadow_vtable_name};')
        emitter.emit_line('} else {')
        emitter.emit_line(f'self->vtable = {vtable_name};')
        emitter.emit_line('}')

    if cl.has_method('__call__') and emitter.use_vectorcall():
        call_c_name = cl.method_decl('__call__').cname(emitter.names)
        emitter.emit_line(f'self->vectorcall = {PREFIX}{call_c_name};')

    for base in reversed(cl.base_mro):
        for attr, rtype in base.attributes.items():
            member = emitter.attr(attr)
            undefined = emitter.c_undefined_value(rtype)
            emitter.emit_line(f'self->{member} = {undefined};')

    # Initialize attributes to default values, if necessary
    if defaults_fn is not None:
        defaults_c_name = defaults_fn.cname(emitter.names)
        emitter.emit_lines(
            f'if ({NATIVE_PREFIX}{defaults_c_name}((PyObject *)self) == 0) {{',
            'Py_DECREF(self);',
            'return NULL;',
            '}')

    emitter.emit_line('return (PyObject *)self;')
    emitter.emit_line('}')
def generate_constructor_for_class(cl: ClassIR,
                                   fn: FuncDecl,
                                   init_fn: Optional[FuncIR],
                                   setup_name: str,
                                   vtable_name: str,
                                   emitter: Emitter) -> None:
    """Generate a native function that allocates and initializes an instance of a class.

    The emitted function calls the setup function, then the native
    __init__ if 'init_fn' is given.  Without a native __init__ but with
    more than one constructor argument, 'tp_init' of the type object is
    invoked instead.  'vtable_name' is not used in the body.
    """
    header = native_function_header(fn, emitter)
    emitter.emit_line(f'{header}')
    emitter.emit_line('{')
    emitter.emit_line(f'PyObject *self = {setup_name}({emitter.type_struct_name(cl)});')
    emitter.emit_line('if (self == NULL)')
    emitter.emit_line(' return NULL;')

    arg_names = ['self']
    arg_names.extend(REG_PREFIX + arg.name for arg in fn.sig.args)
    args = ', '.join(arg_names)

    if init_fn is not None:
        group_prefix = emitter.get_group_prefix(init_fn.decl)
        init_c_name = init_fn.cname(emitter.names)
        emitter.emit_line(f'char res = {group_prefix}{NATIVE_PREFIX}{init_c_name}({args});')
        emitter.emit_line('if (res == 2) {')
        emitter.emit_line('Py_DECREF(self);')
        emitter.emit_line('return NULL;')
        emitter.emit_line('}')
    # If there is a nontrivial ctor that we didn't define, invoke it via tp_init
    elif len(fn.sig.args) > 1:
        type_object = emitter.type_struct_name(cl)
        emitter.emit_line(f'int res = {type_object}->tp_init({args});')
        emitter.emit_line('if (res < 0) {')
        emitter.emit_line('Py_DECREF(self);')
        emitter.emit_line('return NULL;')
        emitter.emit_line('}')

    emitter.emit_line('return self;')
    emitter.emit_line('}')
def generate_init_for_class(cl: ClassIR,
                            init_fn: FuncIR,
                            emitter: Emitter) -> str:
    """Generate an init function suitable for use as tp_init.

    tp_init needs to be a function that returns an int, and our
    __init__ methods return a PyObject. Translate NULL to -1,
    everything else to 0.

    Returns:
        The name of the emitted function.
    """
    func_name = f'{cl.name_prefix(emitter.names)}_init'

    emitter.emit_line('static int')
    emitter.emit_line(f'{func_name}(PyObject *self, PyObject *args, PyObject *kwds)')
    emitter.emit_line('{')
    init_c_name = init_fn.cname(emitter.names)
    emitter.emit_line(
        f'return {PREFIX}{init_c_name}(self, args, kwds) != NULL ? 0 : -1;')
    emitter.emit_line('}')

    return func_name
def generate_new_for_class(cl: ClassIR,
                           func_name: str,
                           vtable_name: str,
                           setup_name: str,
                           emitter: Emitter) -> None:
    """Emit the tp_new-style function of a class.

    The emitted function delegates to the setup function.  Unless the
    class allows interpreted subclasses, it first raises TypeError when
    'type' is not the class's own type object.  'vtable_name' is not
    used in the body.
    """
    emitter.emit_line('static PyObject *')
    emitter.emit_line(f'{func_name}(PyTypeObject *type, PyObject *args, PyObject *kwds)')
    emitter.emit_line('{')
    # TODO: Check and unbox arguments
    if not cl.allow_interpreted_subclasses:
        own_type = emitter.type_struct_name(cl)
        emitter.emit_line(f'if (type != {own_type}) {{')
        emitter.emit_line(
            'PyErr_SetString(PyExc_TypeError, "interpreted classes cannot inherit from compiled");'
        )
        emitter.emit_line('return NULL;')
        emitter.emit_line('}')

    emitter.emit_line(f'return {setup_name}(type);')
    emitter.emit_line('}')
def generate_new_for_trait(cl: ClassIR,
                           func_name: str,
                           emitter: Emitter) -> None:
    """Emit the tp_new-style function of a trait, which always fails.

    The emitted function sets a TypeError and returns NULL; the message
    depends on whether 'type' is the trait's own type object.
    """
    own_type_check = f'if (type != {emitter.type_struct_name(cl)}) {{'
    emitter.emit_line('static PyObject *')
    emitter.emit_line(f'{func_name}(PyTypeObject *type, PyObject *args, PyObject *kwds)')
    emitter.emit_line('{')
    emitter.emit_line(own_type_check)
    emitter.emit_line(
        'PyErr_SetString(PyExc_TypeError, '
        '"interpreted classes cannot inherit from compiled traits");'
    )
    emitter.emit_line('} else {')
    emitter.emit_line(
        'PyErr_SetString(PyExc_TypeError, "traits may not be directly created");'
    )
    emitter.emit_line('}')
    emitter.emit_line('return NULL;')
    emitter.emit_line('}')
def generate_traverse_for_class(cl: ClassIR,
                                func_name: str,
                                emitter: Emitter) -> None:
    """Emit function that performs cycle GC traversal of an instance.

    Every attribute of every class in the base MRO (walked in reverse)
    is visited; if the class has a __dict__, the two pointer-sized slots
    placed right after the struct are visited as well.
    """
    self_struct = cl.struct_name(emitter.names)
    emitter.emit_line('static int')
    emitter.emit_line(f'{func_name}({self_struct} *self, visitproc visit, void *arg)')
    emitter.emit_line('{')
    for base in reversed(cl.base_mro):
        for attr, rtype in base.attributes.items():
            emitter.emit_gc_visit(f'self->{emitter.attr(attr)}', rtype)
    if cl.has_dict:
        struct_name = cl.struct_name(emitter.names)
        # __dict__ lives right after the struct and __weakref__ lives right after that
        dict_slot = f'*((PyObject **)((char *)self + sizeof({struct_name})))'
        weakref_slot = (
            f'*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({struct_name})))')
        emitter.emit_gc_visit(dict_slot, object_rprimitive)
        emitter.emit_gc_visit(weakref_slot, object_rprimitive)
    emitter.emit_line('return 0;')
    emitter.emit_line('}')
def generate_clear_for_class(cl: ClassIR,
                             func_name: str,
                             emitter: Emitter) -> None:
    """Emit function that clears the references held by an instance.

    Every attribute of every class in the base MRO (walked in reverse)
    is cleared; if the class has a __dict__, the two pointer-sized slots
    placed right after the struct are cleared as well.
    """
    self_struct = cl.struct_name(emitter.names)
    emitter.emit_line('static int')
    emitter.emit_line(f'{func_name}({self_struct} *self)')
    emitter.emit_line('{')
    for base in reversed(cl.base_mro):
        for attr, rtype in base.attributes.items():
            emitter.emit_gc_clear(f'self->{emitter.attr(attr)}', rtype)
    if cl.has_dict:
        struct_name = cl.struct_name(emitter.names)
        # __dict__ lives right after the struct and __weakref__ lives right after that
        dict_slot = f'*((PyObject **)((char *)self + sizeof({struct_name})))'
        weakref_slot = (
            f'*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({struct_name})))')
        emitter.emit_gc_clear(dict_slot, object_rprimitive)
        emitter.emit_gc_clear(weakref_slot, object_rprimitive)
    emitter.emit_line('return 0;')
    emitter.emit_line('}')
def generate_dealloc_for_class(cl: ClassIR,
                               dealloc_func_name: str,
                               clear_func_name: str,
                               emitter: Emitter) -> None:
    """Emit the deallocation function of a class.

    The emitted function untracks the object from the GC, calls the
    clear function and frees the object through 'tp_free', with the last
    two steps wrapped in the CPy_TRASHCAN_BEGIN/END macros.
    """
    self_struct = cl.struct_name(emitter.names)
    emitter.emit_line('static void')
    emitter.emit_line(f'{dealloc_func_name}({self_struct} *self)')
    emitter.emit_line('{')
    emitter.emit_line('PyObject_GC_UnTrack(self);')
    # The trashcan is needed to handle deep recursive deallocations
    emitter.emit_line(f'CPy_TRASHCAN_BEGIN(self, {dealloc_func_name})')
    emitter.emit_line(f'{clear_func_name}(self);')
    emitter.emit_line('Py_TYPE(self)->tp_free((PyObject *)self);')
    emitter.emit_line('CPy_TRASHCAN_END(self)')
    emitter.emit_line('}')
def generate_methods_table(cl: ClassIR,
                           name: str,
                           emitter: Emitter) -> None:
    """Emit the PyMethodDef array called 'name' for the methods of 'cl'.

    Property getters and setters are left out.  When the class defines
    neither __setstate__ nor __getstate__, default entries for both are
    appended before the sentinel.
    """
    emitter.emit_line(f'static PyMethodDef {name}[] = {{')
    for fn in cl.methods.values():
        if fn.decl.is_prop_setter or fn.decl.is_prop_getter:
            continue
        emitter.emit_line(f'{{"{fn.name}",')
        emitter.emit_line(f' (PyCFunction){PREFIX}{fn.cname(emitter.names)},')

        calling_convention = (
            'METH_FASTCALL' if use_fastcall(emitter.capi_version) else 'METH_VARARGS')
        flags = [calling_convention, 'METH_KEYWORDS']
        kind = fn.decl.kind
        if kind == FUNC_STATICMETHOD:
            flags.append('METH_STATIC')
        elif kind == FUNC_CLASSMETHOD:
            flags.append('METH_CLASS')

        joined_flags = ' | '.join(flags)
        emitter.emit_line(f' {joined_flags}, NULL}},')

    # Provide a default __getstate__ and __setstate__
    if not (cl.has_method('__setstate__') or cl.has_method('__getstate__')):
        emitter.emit_lines(
            '{"__setstate__", (PyCFunction)CPyPickle_SetState, METH_O, NULL},',
            '{"__getstate__", (PyCFunction)CPyPickle_GetState, METH_NOARGS, NULL},',
        )

    emitter.emit_line('{NULL} /* Sentinel */')
    emitter.emit_line('};')
def generate_side_table_for_class(cl: ClassIR,
                                  name: str,
                                  type: str,
                                  slots: Dict[str, str],
                                  emitter: Emitter) -> Optional[str]:
    """Emit a static side-table struct for 'cl' and return its C name.

    Args:
        cl: The class the table belongs to.
        name: Table name; it is appended to the class's name prefix.
        type: C struct type of the table.
        slots: Mapping from struct field name to the value stored in it.
        emitter: Receives the generated C lines.
    """
    table_c_name = f'{cl.name_prefix(emitter.names)}_{name}'
    emitter.emit_line(f'static {type} {table_c_name} = {{')
    for field, value in slots.items():
        emitter.emit_line(f".{field} = {value},")
    emitter.emit_line("};")
    return table_c_name
def generate_getseter_declarations(cl: ClassIR, emitter: Emitter) -> None:
    """Emit forward declarations of the getter and setter functions of 'cl'.

    Attributes (skipped for traits) get both a getter and a setter
    declaration; properties get a getter declaration and, if a setter
    exists, a setter declaration.
    """
    if not cl.is_trait:
        for attr in cl.attributes:
            attr_getter = getter_name(cl, attr, emitter.names)
            emitter.emit_line('static PyObject *')
            emitter.emit_line(
                f'{attr_getter}({cl.struct_name(emitter.names)} *self, void *closure);')
            attr_setter = setter_name(cl, attr, emitter.names)
            emitter.emit_line('static int')
            emitter.emit_line(
                f'{attr_setter}({cl.struct_name(emitter.names)} *self, '
                'PyObject *value, void *closure);')

    for prop in cl.properties:
        # Generate getter declaration
        prop_getter = getter_name(cl, prop, emitter.names)
        emitter.emit_line('static PyObject *')
        emitter.emit_line(
            f'{prop_getter}({cl.struct_name(emitter.names)} *self, void *closure);')

        # Generate property setter declaration if a setter exists
        if not cl.properties[prop][1]:
            continue
        prop_setter = setter_name(cl, prop, emitter.names)
        emitter.emit_line('static int')
        emitter.emit_line(
            f'{prop_setter}({cl.struct_name(emitter.names)} *self, '
            'PyObject *value, void *closure);')
def generate_getseters_table(cl: ClassIR,
                             name: str,
                             emitter: Emitter) -> None:
    """Emit the PyGetSetDef array called 'name' for 'cl'.

    Attributes (skipped for traits) get an entry with a getter and a
    setter; properties get an entry whose setter is NULL when the
    property has no setter.  The array ends with a sentinel entry.
    """
    emitter.emit_line(f'static PyGetSetDef {name}[] = {{')
    if not cl.is_trait:
        for attr in cl.attributes:
            emitter.emit_line(f'{{"{attr}",')
            attr_getter = getter_name(cl, attr, emitter.names)
            attr_setter = setter_name(cl, attr, emitter.names)
            emitter.emit_line(f' (getter){attr_getter}, (setter){attr_setter},')
            emitter.emit_line(' NULL, NULL},')
    for prop in cl.properties:
        emitter.emit_line(f'{{"{prop}",')
        emitter.emit_line(f' (getter){getter_name(cl, prop, emitter.names)},')

        has_setter = cl.properties[prop][1]
        if not has_setter:
            emitter.emit_line('NULL, NULL, NULL},')
            continue
        emitter.emit_line(f' (setter){setter_name(cl, prop, emitter.names)},')
        emitter.emit_line('NULL, NULL},')

    emitter.emit_line('{NULL} /* Sentinel */')
    emitter.emit_line('};')
def generate_getseters(cl: ClassIR, emitter: Emitter) -> None:
    """Emit the getter and setter function definitions of 'cl'.

    Attribute getters/setters (skipped for traits) are separated from
    each other by empty lines; each property function is preceded by an
    empty line.
    """
    if not cl.is_trait:
        for index, (attr, rtype) in enumerate(cl.attributes.items()):
            if index > 0:
                # Separator between consecutive attributes.
                emitter.emit_line('')
            generate_getter(cl, attr, rtype, emitter)
            emitter.emit_line('')
            generate_setter(cl, attr, rtype, emitter)
    for prop, (getter, setter) in cl.properties.items():
        rtype = getter.sig.ret_type
        emitter.emit_line('')
        generate_readonly_getter(cl, prop, rtype, getter, emitter)
        if not setter:
            continue
        # args[0] is the receiver; args[1] is the assigned value.
        arg_type = setter.sig.args[1].type
        emitter.emit_line('')
        generate_property_setter(cl, prop, arg_type, setter, emitter)
def generate_getter(cl: ClassIR,
                    attr: str,
                    rtype: RType,
                    emitter: Emitter) -> None:
    """Emit the getter function for an attribute of 'cl'.

    The emitted function raises AttributeError when the attribute holds
    the undefined value; otherwise it increments the reference count of
    the stored value, boxes it and returns the result.
    """
    field_expr = f'self->{emitter.attr(attr)}'
    emitter.emit_line('static PyObject *')
    getter_c_name = getter_name(cl, attr, emitter.names)
    emitter.emit_line(
        f'{getter_c_name}({cl.struct_name(emitter.names)} *self, void *closure)')
    emitter.emit_line('{')

    # Opens an 'if' block that is closed by the '}' emitted below.
    emitter.emit_undefined_attr_check(rtype, field_expr, '==', unlikely=True)
    emitter.emit_line('PyErr_SetString(PyExc_AttributeError,')
    emitter.emit_line(f' "attribute {attr!r} of {cl.name!r} undefined");')
    emitter.emit_line('return NULL;')
    emitter.emit_line('}')

    emitter.emit_inc_ref(field_expr, rtype)
    emitter.emit_box(field_expr, 'retval', rtype, declare_dest=True)
    emitter.emit_line('return retval;')
    emitter.emit_line('}')
def generate_setter(cl: ClassIR,
                    attr: str,
                    rtype: RType,
                    emitter: Emitter) -> None:
    """Emit the setter function for an attribute of 'cl'.

    For attributes that are not deletable the emitted function rejects
    a NULL value with AttributeError.  For refcounted types the old
    value is released if it is defined.  The new value is unboxed or
    cast as required by 'rtype' and stored; for deletable attributes a
    NULL value stores the undefined value instead.
    """
    field = emitter.attr(attr)
    field_expr = f'self->{field}'

    emitter.emit_line('static int')
    setter_c_name = setter_name(cl, attr, emitter.names)
    emitter.emit_line(
        f'{setter_c_name}({cl.struct_name(emitter.names)} *self, '
        'PyObject *value, void *closure)')
    emitter.emit_line('{')

    deletable = cl.is_deletable(attr)
    if not deletable:
        emitter.emit_line('if (value == NULL) {')
        emitter.emit_line('PyErr_SetString(PyExc_AttributeError,')
        emitter.emit_line(f' "{cl.name!r} object attribute {attr!r} cannot be deleted");')
        emitter.emit_line('return -1;')
        emitter.emit_line('}')

    if rtype.is_refcounted:
        # Release the previous value; the check opens an 'if' block
        # that is closed by the '}' below.
        emitter.emit_undefined_attr_check(rtype, field_expr, '!=')
        emitter.emit_dec_ref(field_expr, rtype)
        emitter.emit_line('}')

    if deletable:
        emitter.emit_line('if (value != NULL) {')

    # Convert 'value' into a C variable 'tmp' of the attribute's type.
    if rtype.is_unboxed:
        emitter.emit_unbox('value', 'tmp', rtype, error=ReturnHandler('-1'), declare_dest=True)
    elif is_same_type(rtype, object_rprimitive):
        emitter.emit_line('PyObject *tmp = value;')
    else:
        emitter.emit_cast('value', 'tmp', rtype, declare_dest=True)
        emitter.emit_lines('if (!tmp)',
                           ' return -1;')

    emitter.emit_inc_ref('tmp', rtype)
    emitter.emit_line(f'self->{field} = tmp;')

    if deletable:
        undefined = emitter.c_undefined_value(rtype)
        emitter.emit_line('} else')
        emitter.emit_line(f' self->{field} = {undefined};')

    emitter.emit_line('return 0;')
    emitter.emit_line('}')
def generate_readonly_getter(cl: ClassIR,
                             attr: str,
                             rtype: RType,
                             func_ir: FuncIR,
                             emitter: Emitter) -> None:
    """Emit the getter function for a property of 'cl'.

    The emitted function calls the native function of 'func_ir' with
    self.  An unboxed result is boxed before being returned; any other
    result is returned as is.
    """
    emitter.emit_line('static PyObject *')
    getter_c_name = getter_name(cl, attr, emitter.names)
    emitter.emit_line(
        f'{getter_c_name}({cl.struct_name(emitter.names)} *self, void *closure)')
    emitter.emit_line('{')

    if not rtype.is_unboxed:
        emitter.emit_line(
            f'return {NATIVE_PREFIX}{func_ir.cname(emitter.names)}((PyObject *) self);')
    else:
        result_type = emitter.ctype_spaced(rtype)
        emitter.emit_line(
            f'{result_type}retval = {NATIVE_PREFIX}{func_ir.cname(emitter.names)}'
            '((PyObject *) self);')
        emitter.emit_box('retval', 'retbox', rtype, declare_dest=True)
        emitter.emit_line('return retbox;')

    emitter.emit_line('}')
def generate_property_setter(cl: ClassIR,
                             attr: str,
                             arg_type: RType,
                             func_ir: FuncIR,
                             emitter: Emitter) -> None:
    """Emit the setter function for a property of 'cl'.

    The emitted function unboxes 'value' first when 'arg_type' is
    unboxed, calls the native function of 'func_ir' with self and the
    value, and returns 0.  The result of the native call is not
    inspected by the emitted code.
    """
    emitter.emit_line('static int')
    setter_c_name = setter_name(cl, attr, emitter.names)
    emitter.emit_line(
        f'{setter_c_name}({cl.struct_name(emitter.names)} *self, '
        'PyObject *value, void *closure)')
    emitter.emit_line('{')

    if arg_type.is_unboxed:
        emitter.emit_unbox('value', 'tmp', arg_type, error=ReturnHandler('-1'),
                           declare_dest=True)
        c_argument = 'tmp'
    else:
        c_argument = 'value'
    native_c_name = func_ir.cname(emitter.names)
    emitter.emit_line(f'{NATIVE_PREFIX}{native_c_name}((PyObject *) self, {c_argument});')

    emitter.emit_line('return 0;')
    emitter.emit_line('}')

View file

@ -0,0 +1,618 @@
"""Code generation for native function bodies."""
from typing import Union, Optional
from typing_extensions import Final
from mypyc.common import (
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
)
from mypyc.codegen.emit import Emitter
from mypyc.ir.ops import (
OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
BasicBlock, Value, MethodCall, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE,
RaiseStandardError, CallC, LoadGlobal, Truncate, IntOp, LoadMem, GetElementPtr,
LoadAddress, ComparisonOp, SetMem, Register, LoadLiteral, AssignMulti, KeepAlive
)
from mypyc.ir.rtypes import (
RType, RTuple, RArray, is_tagged, is_int32_rprimitive, is_int64_rprimitive, RStruct,
is_pointer_rprimitive, is_int_rprimitive
)
from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD, all_values
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.analysis.blockfreq import frequently_executed_blocks
# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS = False
def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
    """Return the C function pointer type of a native function.

    The result has the form ``ret (*)(arg1, arg2, ...)``, with ``void`` as
    the parameter list when the joined argument types are empty.
    """
    arg_ctypes = [emitter.ctype(arg.type) for arg in fn.args]
    params = ', '.join(arg_ctypes) or 'void'
    ret_ctype = emitter.ctype(fn.ret_type)
    return '%s (*)(%s)' % (ret_ctype, params)
def native_function_header(fn: FuncDecl, emitter: Emitter) -> str:
    """Return the C header (return type, name and parameters) of a native function.

    Each parameter is its spaced C type followed by REG_PREFIX and the
    argument name; the parameter list is ``void`` when it would be empty.
    """
    params = [
        '{}{}{}'.format(emitter.ctype_spaced(arg.type), REG_PREFIX, arg.name)
        for arg in fn.sig.args
    ]
    ret_type = emitter.ctype_spaced(fn.sig.ret_type)
    name = emitter.native_function_name(fn)
    return '{}{}({})'.format(ret_type, name, ', '.join(params) or 'void')
def generate_native_function(fn: FuncIR,
                             emitter: Emitter,
                             source_path: str,
                             module_name: str) -> None:
    """Emit the C definition of a native function into ``emitter``.

    Declarations and statements are collected in two separate emitters and
    appended to ``emitter`` in that order once the whole body is generated.
    """
    declarations = Emitter(emitter.context)
    names = generate_names_for_ir(fn.arg_regs, fn.blocks)
    body = Emitter(emitter.context, names)
    visitor = FunctionEmitterVisitor(body, declarations, source_path, module_name)

    declarations.emit_line('{} {{'.format(native_function_header(fn.decl, emitter)))
    body.indent()

    for value in all_values(fn.arg_regs, fn.blocks):
        value_type = value.type
        if isinstance(value_type, RTuple):
            emitter.declare_tuple_struct(value_type)
        # RArray values are declared on first assignment, and the arguments
        # are already declared by the function header.
        if isinstance(value_type, RArray) or value in fn.arg_regs:
            continue
        declarations.emit_line('{}{}{};'.format(
            emitter.ctype_spaced(value_type), REG_PREFIX, names[value]))

    # Before we emit the blocks, give them all labels
    blocks = fn.blocks
    for index, block in enumerate(blocks):
        block.label = index

    common = frequently_executed_blocks(fn.blocks[0])

    last_index = len(blocks) - 1
    for index, block in enumerate(blocks):
        visitor.rare = block not in common
        following = blocks[index + 1] if index < last_index else None
        body.emit_label(block)
        # Lets the visitor skip jumps to the block that is emitted next.
        visitor.next_block = following
        for op in block.ops:
            op.accept(visitor)

    body.emit_line('}')

    emitter.emit_from_emitter(declarations)
    emitter.emit_from_emitter(body)
class FunctionEmitterVisitor(OpVisitor[None]):
    """Visitor that emits C code for the IR ops of a native function body.

    Statements are written to ``emitter``; declarations that need to be
    emitted separately from the statements are written to ``declarations``.
    """

    def __init__(self,
                 emitter: Emitter,
                 declarations: Emitter,
                 source_path: str,
                 module_name: str) -> None:
        """Store the output emitters and the source location information."""
        self.emitter = emitter
        self.names = emitter.names
        self.declarations = declarations
        self.source_path = source_path
        self.module_name = module_name
        self.literals = emitter.context.literals
        # Passed on to inc/dec ref emission; set by the driver per block.
        self.rare = False
        # The block emitted right after the current one (None for the last).
        self.next_block: Optional[BasicBlock] = None

    def temp_name(self) -> str:
        """Return a fresh temporary name from the emitter."""
        return self.emitter.temp_name()

    def _annotation_comment(self, value: object) -> str:
        """Return ' /* repr(value) */', or '' if the repr is unsafe in a C comment.

        The repr is rejected when it contains a comment delimiter or a NUL
        character.
        """
        text = repr(value)
        if any(x in text for x in ('/*', '*/', '\0')):
            return ''
        return ' /* %s */' % text

    def visit_goto(self, op: Goto) -> None:
        """Emit a goto, unless the target is the block emitted next."""
        if op.label is not self.next_block:
            self.emit_line('goto %s;' % self.label(op.label))

    def visit_branch(self, op: Branch) -> None:
        """Emit a conditional jump, falling through to the next block when possible."""
        true, false = op.true, op.false
        negated = op.negated
        negated_rare = False
        if true is self.next_block and op.traceback_entry is None:
            # Switch true/false since it avoids an else block.
            true, false = false, true
            negated = not negated
            negated_rare = True

        neg = '!' if negated else ''
        cond = ''
        if op.op == Branch.BOOL:
            expr_result = self.reg(op.value)
            cond = '{}{}'.format(neg, expr_result)
        elif op.op == Branch.IS_ERROR:
            typ = op.value.type
            compare = '!=' if negated else '=='
            if isinstance(typ, RTuple):
                # TODO: What about empty tuple?
                cond = self.emitter.tuple_undefined_check_cond(typ,
                                                               self.reg(op.value),
                                                               self.c_error_value,
                                                               compare)
            else:
                cond = '{} {} {}'.format(self.reg(op.value),
                                         compare,
                                         self.c_error_value(typ))
        else:
            assert False, "Invalid branch"

        # For error checks, tell the compiler the branch is unlikely
        if op.traceback_entry is not None or op.rare:
            if not negated_rare:
                cond = 'unlikely({})'.format(cond)
            else:
                cond = 'likely({})'.format(cond)

        if false is self.next_block:
            if op.traceback_entry is None:
                self.emit_line('if ({}) goto {};'.format(cond, self.label(true)))
            else:
                self.emit_line('if ({}) {{'.format(cond))
                self.emit_traceback(op)
                self.emit_lines(
                    'goto %s;' % self.label(true),
                    '}'
                )
        else:
            self.emit_line('if ({}) {{'.format(cond))
            self.emit_traceback(op)
            self.emit_lines(
                'goto %s;' % self.label(true),
                '} else',
                ' goto %s;' % self.label(false)
            )

    def visit_return(self, op: Return) -> None:
        """Emit a return of the op's value."""
        value_str = self.reg(op.value)
        self.emit_line('return %s;' % value_str)

    def visit_tuple_set(self, op: TupleSet) -> None:
        """Emit assignments that fill in the fields of a tuple struct."""
        dest = self.reg(op)
        tuple_type = op.tuple_type
        self.emitter.declare_tuple_struct(tuple_type)
        if len(op.items) == 0:  # empty tuple
            self.emit_line('{}.empty_struct_error_flag = 0;'.format(dest))
        else:
            for i, item in enumerate(op.items):
                self.emit_line('{}.f{} = {};'.format(dest, i, self.reg(item)))
        self.emit_inc_ref(dest, tuple_type)

    def visit_assign(self, op: Assign) -> None:
        """Emit a plain assignment, skipping self-assignments."""
        dest = self.reg(op.dest)
        src = self.reg(op.src)
        # clang whines about self assignment (which we might generate
        # for some casts), so don't emit it.
        if dest != src:
            self.emit_line('%s = %s;' % (dest, src))

    def visit_assign_multi(self, op: AssignMulti) -> None:
        """Emit the declaration and initializer of an RArray value."""
        typ = op.dest.type
        assert isinstance(typ, RArray)
        dest = self.reg(op.dest)
        # RArray values can only be assigned to once, so we can always
        # declare them on initialization.
        self.emit_line('%s%s[%d] = {%s};' % (
            self.emitter.ctype_spaced(typ.item_type),
            dest,
            len(op.src),
            ', '.join(self.reg(s) for s in op.src)))

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        """Emit an assignment of the error value of the op's type."""
        if isinstance(op.type, RTuple):
            # Tuples are initialized through a temporary struct whose
            # items are all set to their undefined values.
            values = [self.c_undefined_value(item) for item in op.type.types]
            tmp = self.temp_name()
            self.emit_line('%s %s = { %s };' % (self.ctype(op.type), tmp, ', '.join(values)))
            self.emit_line('%s = %s;' % (self.reg(op), tmp))
        else:
            self.emit_line('%s = %s;' % (self.reg(op),
                                         self.c_error_value(op.type)))

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Emit a load of a literal from the CPyStatics array."""
        index = self.literals.literal_index(op.value)
        ann = self._annotation_comment(op.value)
        if not is_int_rprimitive(op.type):
            self.emit_line('%s = CPyStatics[%d];%s' % (self.reg(op), index, ann))
        else:
            self.emit_line('%s = (CPyTagged)CPyStatics[%d] | 1;%s' % (
                self.reg(op), index, ann))

    def get_attr_expr(self, obj: str, op: Union[GetAttr, SetAttr], decl_cl: ClassIR) -> str:
        """Generate attribute accessor for normal (non-property) access.

        This either has a form like obj->attr_name for attributes defined in non-trait
        classes, and *(obj + attr_offset) for attributes defined by traits. We also
        insert all necessary C casts here.
        """
        cast = '({} *)'.format(op.class_type.struct_name(self.emitter.names))
        if decl_cl.is_trait and op.class_type.class_ir.is_trait:
            # For pure trait access find the offset first, offsets
            # are ordered by attribute position in the cl.attributes dict.
            # TODO: pre-calculate the mapping to make this faster.
            trait_attr_index = list(decl_cl.attributes).index(op.attr)
            # TODO: reuse these names somehow?
            offset = self.emitter.temp_name()
            self.declarations.emit_line('size_t {};'.format(offset))
            self.emitter.emit_line('{} = {};'.format(
                offset,
                'CPy_FindAttrOffset({}, {}, {})'.format(
                    self.emitter.type_struct_name(decl_cl),
                    '({}{})->vtable'.format(cast, obj),
                    trait_attr_index,
                )
            ))
            attr_cast = '({} *)'.format(self.ctype(op.class_type.attr_type(op.attr)))
            return '*{}((char *){} + {})'.format(attr_cast, obj, offset)
        else:
            # Cast to something non-trait. Note: for this to work, all struct
            # members for non-trait classes must obey monotonic linear growth.
            if op.class_type.class_ir.is_trait:
                assert not decl_cl.is_trait
                cast = '({} *)'.format(decl_cl.struct_name(self.emitter.names))
            return '({}{})->{}'.format(
                cast, obj, self.emitter.attr(op.attr)
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        """Emit an attribute read, via the vtable for properties or a struct access."""
        dest = self.reg(op)
        obj = self.reg(op.obj)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        if cl.get_method(op.attr):
            # Properties are essentially methods, so use vtable access for them.
            version = '_TRAIT' if cl.is_trait else ''
            self.emit_line('%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */' % (
                dest,
                version,
                obj,
                self.emitter.type_struct_name(rtype.class_ir),
                rtype.getter_index(op.attr),
                rtype.struct_name(self.names),
                self.ctype(rtype.attr_type(op.attr)),
                op.attr))
        else:
            # Otherwise, use direct or offset struct access.
            attr_expr = self.get_attr_expr(obj, op, decl_cl)
            self.emitter.emit_line('{} = {};'.format(dest, attr_expr))
            self.emitter.emit_undefined_attr_check(
                attr_rtype, attr_expr, '==', unlikely=True
            )
            exc_class = 'PyExc_AttributeError'
            self.emitter.emit_line(
                'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
                    exc_class, repr(op.attr), repr(cl.name)))
            if attr_rtype.is_refcounted:
                self.emitter.emit_line('} else {')
                self.emitter.emit_inc_ref(attr_expr, attr_rtype)
            self.emitter.emit_line('}')

    def visit_set_attr(self, op: SetAttr) -> None:
        """Emit an attribute write, via the vtable for properties or a struct access."""
        dest = self.reg(op)
        obj = self.reg(op.obj)
        src = self.reg(op.src)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        if cl.get_method(op.attr):
            # Again, use vtable access for properties...
            version = '_TRAIT' if cl.is_trait else ''
            self.emit_line('%s = CPY_SET_ATTR%s(%s, %s, %d, %s, %s, %s); /* %s */' % (
                dest,
                version,
                obj,
                self.emitter.type_struct_name(rtype.class_ir),
                rtype.setter_index(op.attr),
                src,
                rtype.struct_name(self.names),
                self.ctype(rtype.attr_type(op.attr)),
                op.attr))
        else:
            # ...and struct access for normal attributes.
            attr_expr = self.get_attr_expr(obj, op, decl_cl)
            if attr_rtype.is_refcounted:
                # Release the old value, but only if the attribute is defined.
                self.emitter.emit_undefined_attr_check(attr_rtype, attr_expr, '!=')
                self.emitter.emit_dec_ref(attr_expr, attr_rtype)
                self.emitter.emit_line('}')
            # This steals the reference to src, so we don't need to increment the arg
            self.emitter.emit_lines(
                '{} = {};'.format(attr_expr, src),
                '{} = 1;'.format(dest),
            )

    # Maps the namespace of a static to the prefix used for its C name.
    PREFIX_MAP: Final = {
        NAMESPACE_STATIC: STATIC_PREFIX,
        NAMESPACE_TYPE: TYPE_PREFIX,
        NAMESPACE_MODULE: MODULE_PREFIX,
    }

    def visit_load_static(self, op: LoadStatic) -> None:
        """Emit a load of a static, with an optional comment annotation."""
        dest = self.reg(op)
        prefix = self.PREFIX_MAP[op.namespace]
        name = self.emitter.static_name(op.identifier, op.module_name, prefix)
        if op.namespace == NAMESPACE_TYPE:
            name = '(PyObject *)%s' % name
        ann = self._annotation_comment(op.ann) if op.ann else ''
        self.emit_line('%s = %s;%s' % (dest, name, ann))

    def visit_init_static(self, op: InitStatic) -> None:
        """Emit an assignment to a static and take a reference to the value."""
        value = self.reg(op.value)
        prefix = self.PREFIX_MAP[op.namespace]
        name = self.emitter.static_name(op.identifier, op.module_name, prefix)
        if op.namespace == NAMESPACE_TYPE:
            value = '(PyTypeObject *)%s' % value
        self.emit_line('%s = %s;' % (name, value))
        self.emit_inc_ref(name, op.value.type)

    def visit_tuple_get(self, op: TupleGet) -> None:
        """Emit a read of one tuple struct field and take a reference to it."""
        dest = self.reg(op)
        src = self.reg(op.src)
        self.emit_line('{} = {}.f{};'.format(dest, src, op.index))
        self.emit_inc_ref(dest, op.type)

    def get_dest_assign(self, dest: Value) -> str:
        """Return 'reg = ' for a non-void value and '' for a void one."""
        if not dest.is_void:
            return self.reg(dest) + ' = '
        else:
            return ''

    def visit_call(self, op: Call) -> None:
        """Call native function."""
        dest = self.get_dest_assign(op)
        args = ', '.join(self.reg(arg) for arg in op.args)
        lib = self.emitter.get_group_prefix(op.fn)
        cname = op.fn.cname(self.names)
        self.emit_line('%s%s%s%s(%s);' % (dest, lib, NATIVE_PREFIX, cname, args))

    def visit_method_call(self, op: MethodCall) -> None:
        """Call native method."""
        dest = self.get_dest_assign(op)
        obj = self.reg(op.obj)

        rtype = op.receiver_type
        class_ir = rtype.class_ir
        name = op.method
        method = rtype.class_ir.get_method(name)
        assert method is not None

        # Can we call the method directly, bypassing vtable?
        is_direct = class_ir.is_method_final(name)

        # The first argument gets omitted for static methods and
        # turned into the class for class methods
        obj_args = (
            [] if method.decl.kind == FUNC_STATICMETHOD else
            ['(PyObject *)Py_TYPE({})'.format(obj)] if method.decl.kind == FUNC_CLASSMETHOD else
            [obj])
        args = ', '.join(obj_args + [self.reg(arg) for arg in op.args])
        mtype = native_function_type(method, self.emitter)
        version = '_TRAIT' if rtype.class_ir.is_trait else ''
        if is_direct:
            # Directly call method, without going through the vtable.
            lib = self.emitter.get_group_prefix(method.decl)
            self.emit_line('{}{}{}{}({});'.format(
                dest, lib, NATIVE_PREFIX, method.cname(self.names), args))
        else:
            # Call using vtable.
            method_idx = rtype.method_index(name)
            self.emit_line('{}CPY_GET_METHOD{}({}, {}, {}, {}, {})({}); /* {} */'.format(
                dest, version, obj, self.emitter.type_struct_name(rtype.class_ir),
                method_idx, rtype.struct_name(self.names), mtype, args, op.method))

    def visit_inc_ref(self, op: IncRef) -> None:
        """Emit a reference count increment for the op's source value."""
        src = self.reg(op.src)
        self.emit_inc_ref(src, op.src.type)

    def visit_dec_ref(self, op: DecRef) -> None:
        """Emit a reference count decrement for the op's source value."""
        src = self.reg(op.src)
        self.emit_dec_ref(src, op.src.type, is_xdec=op.is_xdec)

    def visit_box(self, op: Box) -> None:
        """Emit boxing of the source value into the op's register."""
        self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

    def visit_cast(self, op: Cast) -> None:
        """Emit a cast of the source value to the op's type."""
        self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type,
                               src_type=op.src.type)

    def visit_unbox(self, op: Unbox) -> None:
        """Emit unboxing of the source value into the op's register."""
        self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)

    def visit_unreachable(self, op: Unreachable) -> None:
        """Emit a call that marks the location as unreachable."""
        self.emitter.emit_line('CPy_Unreachable();')

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        """Emit code that sets a standard exception and stores 0 in the op's register."""
        if op.value is not None:
            if isinstance(op.value, str):
                # The message is embedded in a C string literal. Escape
                # backslashes first so that the backslashes introduced by
                # the later replacements are not doubled.
                # TODO: Other control characters are still emitted verbatim.
                message = (op.value.replace('\\', '\\\\')
                           .replace('"', '\\"')
                           .replace('\n', '\\n')
                           .replace('\r', '\\r'))
                self.emitter.emit_line(
                    'PyErr_SetString(PyExc_{}, "{}");'.format(op.class_name, message))
            elif isinstance(op.value, Value):
                self.emitter.emit_line(
                    'PyErr_SetObject(PyExc_{}, {});'.format(op.class_name,
                                                            self.emitter.reg(op.value)))
            else:
                assert False, 'op value type must be either str or Value'
        else:
            self.emitter.emit_line('PyErr_SetNone(PyExc_{});'.format(op.class_name))
        self.emitter.emit_line('{} = 0;'.format(self.reg(op)))

    def visit_call_c(self, op: CallC) -> None:
        """Emit a call to a C function by name."""
        # get_dest_assign() already yields '' for void ops.
        dest = self.get_dest_assign(op)
        args = ', '.join(self.reg(arg) for arg in op.args)
        self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args))

    def visit_truncate(self, op: Truncate) -> None:
        """Emit a truncation as a plain assignment."""
        dest = self.reg(op)
        value = self.reg(op.src)
        # For the C backend the generated code is a straight assignment.
        self.emit_line("{} = {};".format(dest, value))

    def visit_load_global(self, op: LoadGlobal) -> None:
        """Emit a load of a C global, with an optional comment annotation."""
        dest = self.reg(op)
        ann = self._annotation_comment(op.ann) if op.ann else ''
        self.emit_line('%s = %s;%s' % (dest, op.identifier, ann))

    def visit_int_op(self, op: IntOp) -> None:
        """Emit a binary integer operation."""
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        self.emit_line('%s = %s %s %s;' % (dest, lhs, op.op_str[op.op], rhs))

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        """Emit an integer comparison, casting operands to force signedness."""
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        lhs_cast = ""
        rhs_cast = ""
        if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
            # Always signed comparison op
            lhs_cast = self.emit_signed_int_cast(op.lhs.type)
            rhs_cast = self.emit_signed_int_cast(op.rhs.type)
        elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
            # Always unsigned comparison op
            lhs_cast = self.emit_unsigned_int_cast(op.lhs.type)
            rhs_cast = self.emit_unsigned_int_cast(op.rhs.type)
        elif isinstance(op.lhs, Integer) and op.lhs.value < 0:
            # Force signed ==/!= with negative operand
            rhs_cast = self.emit_signed_int_cast(op.rhs.type)
        elif isinstance(op.rhs, Integer) and op.rhs.value < 0:
            # Force signed ==/!= with negative operand
            lhs_cast = self.emit_signed_int_cast(op.lhs.type)
        self.emit_line('%s = %s%s %s %s%s;' % (dest, lhs_cast, lhs,
                                               op.op_str[op.op], rhs_cast, rhs))

    def visit_load_mem(self, op: LoadMem) -> None:
        """Emit a load through a pointer, typed by the op's result type."""
        dest = self.reg(op)
        src = self.reg(op.src)
        # TODO: we shouldn't dereference to type that are pointer type so far
        type = self.ctype(op.type)
        self.emit_line('%s = *(%s *)%s;' % (dest, type, src))

    def visit_set_mem(self, op: SetMem) -> None:
        """Emit a store through a pointer, typed by the op's destination type."""
        dest = self.reg(op.dest)
        src = self.reg(op.src)
        dest_type = self.ctype(op.dest_type)
        # clang whines about self assignment (which we might generate
        # for some casts), so don't emit it.
        if dest != src:
            self.emit_line('*(%s *)%s = %s;' % (dest_type, dest, src))

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        """Emit code that takes the address of a struct field."""
        dest = self.reg(op)
        src = self.reg(op.src)
        # TODO: support tuple type
        assert isinstance(op.src_type, RStruct)
        assert op.field in op.src_type.names, "Invalid field name."
        self.emit_line('%s = (%s)&((%s *)%s)->%s;' % (dest, op.type._ctype, op.src_type.name,
                                                      src, op.field))

    def visit_load_address(self, op: LoadAddress) -> None:
        """Emit code that takes the address of a register or a named C object."""
        typ = op.type
        dest = self.reg(op)
        src = self.reg(op.src) if isinstance(op.src, Register) else op.src
        self.emit_line('%s = (%s)&%s;' % (dest, typ._ctype, src))

    def visit_keep_alive(self, op: KeepAlive) -> None:
        """Emit nothing for a KeepAlive op."""
        # This is a no-op.
        pass

    # Helpers

    def label(self, label: BasicBlock) -> str:
        """Return the C label of a basic block."""
        return self.emitter.label(label)

    def reg(self, reg: Value) -> str:
        """Return the C expression for a value.

        Integer values are rendered as literals (NULL for a zero of pointer
        type), with suffixes added for large magnitudes; all other values
        use the emitter's register name.
        """
        if isinstance(reg, Integer):
            val = reg.value
            if val == 0 and is_pointer_rprimitive(reg.type):
                return "NULL"
            s = str(val)
            if val >= (1 << 31):
                # Avoid overflowing signed 32-bit int
                s += 'ULL'
            elif val == -(1 << 63):
                # Avoid overflowing C integer literal
                s = '(-9223372036854775807LL - 1)'
            elif val <= -(1 << 31):
                s += 'LL'
            return s
        else:
            return self.emitter.reg(reg)

    def ctype(self, rtype: RType) -> str:
        """Return the C type of an RType, as given by the emitter."""
        return self.emitter.ctype(rtype)

    def c_error_value(self, rtype: RType) -> str:
        """Return the C error value of an RType, as given by the emitter."""
        return self.emitter.c_error_value(rtype)

    def c_undefined_value(self, rtype: RType) -> str:
        """Return the C undefined value of an RType, as given by the emitter."""
        return self.emitter.c_undefined_value(rtype)

    def emit_line(self, line: str) -> None:
        """Emit one line to the statement emitter."""
        self.emitter.emit_line(line)

    def emit_lines(self, *lines: str) -> None:
        """Emit several lines to the statement emitter."""
        self.emitter.emit_lines(*lines)

    def emit_inc_ref(self, dest: str, rtype: RType) -> None:
        """Emit a reference count increment, passing on the current 'rare' flag."""
        self.emitter.emit_inc_ref(dest, rtype, rare=self.rare)

    def emit_dec_ref(self, dest: str, rtype: RType, is_xdec: bool) -> None:
        """Emit a reference count decrement, passing on the current 'rare' flag."""
        self.emitter.emit_dec_ref(dest, rtype, is_xdec=is_xdec, rare=self.rare)

    def emit_declaration(self, line: str) -> None:
        """Emit one line to the declarations emitter."""
        self.declarations.emit_line(line)

    def emit_traceback(self, op: Branch) -> None:
        """Emit a CPy_AddTraceback call if the branch has a traceback entry."""
        if op.traceback_entry is not None:
            globals_static = self.emitter.static_name('globals', self.module_name)
            self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
                # Backslashes in the path must be escaped in the C string.
                self.source_path.replace("\\", "\\\\"),
                op.traceback_entry[0],
                op.traceback_entry[1],
                globals_static))
            if DEBUG_ERRORS:
                self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

    def emit_signed_int_cast(self, type: RType) -> str:
        """Return the cast prefix that forces a signed comparison, or ''."""
        if is_tagged(type):
            return '(Py_ssize_t)'
        else:
            return ''

    def emit_unsigned_int_cast(self, type: RType) -> str:
        """Return the cast prefix that forces an unsigned comparison, or ''."""
        if is_int32_rprimitive(type):
            return '(uint32_t)'
        elif is_int64_rprimitive(type):
            return '(uint64_t)'
        else:
            return ''

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,822 @@
"""Generate CPython API wrapper functions for native functions.
The wrapper functions are used by the CPython runtime when calling
native functions from interpreted code, and when the called function
can't be determined statically in compiled code. They validate, match,
unbox and type check function arguments, and box return values as
needed. All wrappers accept and return 'PyObject *' (boxed) values.
The wrappers aren't used for most calls between two native functions
or methods in a single compilation unit.
"""
from typing import List, Dict, Optional, Sequence
from mypy.nodes import ArgKind, ARG_POS, ARG_OPT, ARG_NAMED_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2
from mypy.operators import op_methods_to_symbols, reverse_op_methods, reverse_op_method_names
from mypyc.common import PREFIX, NATIVE_PREFIX, DUNDER_PREFIX, use_vectorcall
from mypyc.codegen.emit import Emitter, ErrorHandler, GotoHandler, AssignHandler, ReturnHandler
from mypyc.ir.rtypes import (
RType, RInstance, is_object_rprimitive, is_int_rprimitive, is_bool_rprimitive,
object_rprimitive
)
from mypyc.ir.func_ir import FuncIR, RuntimeArg, FUNC_STATICMETHOD
from mypyc.ir.class_ir import ClassIR
from mypyc.namegen import NameGenerator
# Generic vectorcall wrapper functions (Python 3.7+)
#
# A wrapper function has a signature like this:
#
# PyObject *fn(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
#
# The function takes a self object, pointer to an array of arguments,
# the number of positional arguments, and a tuple of keyword argument
# names (that are stored starting in args[nargs]).
#
# It returns the returned object, or NULL on an exception.
#
# These are more efficient than legacy wrapper functions, since
# usually no tuple or dict objects need to be created for the
# arguments. Vectorcalls also use pre-constructed str objects for
# keyword argument names and other pre-computed information, instead
# of processing the argument format string on each call.
def wrapper_function_header(fn: FuncIR, names: NameGenerator) -> str:
    """Return header of a vectorcall wrapper function.

    The wrapper name is PREFIX followed by the C name of ``fn``.
    """
    wrapper_name = '{}{}'.format(PREFIX, fn.cname(names))
    params = 'PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames'
    return 'PyObject *{}({})'.format(wrapper_name, params)
def generate_traceback_code(fn: FuncIR,
                            emitter: Emitter,
                            source_path: str,
                            module_name: str) -> str:
    """Return a C statement that adds a traceback entry for ``fn``.

    This is used when argument processing fails in a wrapper. Unlike
    traceback frames added for exceptions seen in IR, the entry is produced
    even if there is no `traceback_name`, since the error originated in
    the wrapper and needs to show up in the traceback.
    """
    globals_static = emitter.static_name('globals', module_name)
    # Backslashes in the path must be escaped in the C string literal.
    escaped_path = source_path.replace("\\", "\\\\")
    frame_name = fn.traceback_name or fn.name
    return 'CPy_AddTraceback("%s", "%s", %d, %s);' % (
        escaped_path, frame_name, fn.line, globals_static)
def make_arg_groups(args: List[RuntimeArg]) -> Dict[ArgKind, List[RuntimeArg]]:
    """Group arguments by kind.

    Every member of ArgKind gets an entry, which is an empty list when no
    argument has that kind. Arguments keep their relative order.
    """
    groups: Dict[ArgKind, List[RuntimeArg]] = {}
    for kind in ArgKind:
        groups[kind] = [arg for arg in args if arg.kind == kind]
    return groups
def reorder_arg_groups(groups: Dict[ArgKind, List[RuntimeArg]]) -> List[RuntimeArg]:
    """Reorder argument groups to match their order in a format string.

    The order is: positional, optional, optional keyword-only, required
    keyword-only.
    """
    ordered: List[RuntimeArg] = []
    for kind in (ARG_POS, ARG_OPT, ARG_NAMED_OPT, ARG_NAMED):
        ordered.extend(groups[kind])
    return ordered
def make_static_kwlist(args: List[RuntimeArg]) -> str:
    """Return a C declaration of a 0-terminated array of argument names."""
    entries = ['"{}"'.format(arg.name) for arg in args]
    entries.append('0')
    return 'static const char * const kwlist[] = {' + ', '.join(entries) + '};'
def make_format_string(func_name: Optional[str], groups: Dict[ArgKind, List[RuntimeArg]]) -> str:
    """Return a format string that specifies the accepted arguments.

    The format string is an extended subset of what is supported by
    PyArg_ParseTupleAndKeywords(). Only the type 'O' is used, and we
    also support some extensions:

    - Required keyword-only arguments are introduced after '@'
    - If the function receives *args or **kwargs, we add a '%' prefix

    Each group requires the previous groups' delimiters to be present
    first.

    These are used by both vectorcall and legacy wrapper functions.
    """
    # A delimiter is needed whenever its own group or any later group is
    # non-empty.
    has_required_kwonly = bool(groups[ARG_NAMED])
    has_kwonly = bool(groups[ARG_NAMED_OPT]) or has_required_kwonly
    has_optional = bool(groups[ARG_OPT]) or has_kwonly

    pieces = []
    if groups[ARG_STAR] or groups[ARG_STAR2]:
        pieces.append('%')
    pieces.append('O' * len(groups[ARG_POS]))
    if has_optional:
        pieces.append('|' + 'O' * len(groups[ARG_OPT]))
    if has_kwonly:
        pieces.append('$' + 'O' * len(groups[ARG_NAMED_OPT]))
    if has_required_kwonly:
        pieces.append('@' + 'O' * len(groups[ARG_NAMED]))
    if func_name is not None:
        pieces.append(':{}'.format(func_name))
    return ''.join(pieces)
def generate_wrapper_function(fn: FuncIR,
                              emitter: Emitter,
                              source_path: str,
                              module_name: str) -> None:
    """Generate a CPython-compatible vectorcall wrapper for a native function.

    The wrapper parses the incoming arguments into 'obj_<name>' variables
    and then hands over to generate_wrapper_core() for the rest of the
    wrapper body.
    """
    emitter.emit_line('{} {{'.format(wrapper_function_header(fn, emitter.names)))

    # If fn is a method, then the first argument is a self param
    real_args = list(fn.args)
    if fn.class_name and fn.decl.kind != FUNC_STATICMETHOD:
        self_arg = real_args.pop(0)
        emitter.emit_line('PyObject *obj_{} = self;'.format(self_arg.name))

    # Need to order args as: required, optional, kwonly optional, kwonly required
    # This is because CPyArg_ParseStackAndKeywords format string requires
    # them grouped in that way.
    groups = make_arg_groups(real_args)
    reordered_args = reorder_arg_groups(groups)

    emitter.emit_line(make_static_kwlist(reordered_args))
    fmt = make_format_string(fn.name, groups)
    # Define the arguments the function accepts (but no types yet)
    emitter.emit_line('static CPyArg_Parser parser = {{"{}", kwlist, 0}};'.format(fmt))

    for arg in real_args:
        initializer = ' = NULL' if arg.optional else ''
        emitter.emit_line('PyObject *obj_{}{};'.format(arg.name, initializer))

    star_args = groups[ARG_STAR]
    star2_args = groups[ARG_STAR2]
    cleanups = ['CPy_DECREF(obj_{});'.format(arg.name)
                for arg in star_args + star2_args]

    arg_ptrs: List[str] = []
    if star_args or star2_args:
        # When either is present, both slots are passed, using NULL for
        # the missing one.
        for group in (star_args, star2_args):
            arg_ptrs.append('&obj_{}'.format(group[0].name) if group else 'NULL')
    arg_ptrs.extend('&obj_{}'.format(arg.name) for arg in reordered_args)

    if fn.name == '__call__' and use_vectorcall(emitter.capi_version):
        nargs = 'PyVectorcall_NARGS(nargs)'
    else:
        nargs = 'nargs'

    # Special case some common signatures
    num_args = len(real_args)
    num_positional = len(groups[ARG_POS])
    if num_args == 0:
        # No args
        parse_fn = 'CPyArg_ParseStackAndKeywordsNoArgs'
    elif num_args == 1 and num_positional == 1:
        # Single positional arg
        parse_fn = 'CPyArg_ParseStackAndKeywordsOneArg'
    elif num_args == num_positional + len(groups[ARG_OPT]):
        # No keyword-only args, *args or **kwargs
        parse_fn = 'CPyArg_ParseStackAndKeywordsSimple'
    else:
        parse_fn = 'CPyArg_ParseStackAndKeywords'

    extra_args = ''.join(', ' + ptr for ptr in arg_ptrs)
    emitter.emit_lines(
        'if (!{}(args, {}, kwnames, &parser{})) {{'.format(parse_fn, nargs, extra_args),
        'return NULL;',
        '}')
    traceback_code = generate_traceback_code(fn, emitter, source_path, module_name)
    generate_wrapper_core(fn, emitter, groups[ARG_OPT] + groups[ARG_NAMED_OPT],
                          cleanups=cleanups,
                          traceback_code=traceback_code)

    emitter.emit_line('}')
# Legacy generic wrapper functions
#
# These take a self object, a Python tuple of positional arguments,
# and a dict of keyword arguments. These are a lot slower than
# vectorcall wrappers, especially in calls involving keyword
# arguments.
def legacy_wrapper_function_header(fn: FuncIR, names: NameGenerator) -> str:
    """Return header of a legacy wrapper function.

    The wrapper takes a self object, an args object and a kw object, and
    is named PREFIX followed by the C name of ``fn``.
    """
    wrapper_name = '{}{}'.format(PREFIX, fn.cname(names))
    return 'PyObject *{}(PyObject *self, PyObject *args, PyObject *kw)'.format(wrapper_name)
def generate_legacy_wrapper_function(fn: FuncIR,
                                     emitter: Emitter,
                                     source_path: str,
                                     module_name: str) -> None:
    """Generates a CPython-compatible legacy wrapper for a native function.

    The wrapper parses the incoming arguments into 'obj_<name>' variables
    and then hands over to generate_wrapper_core() for the rest of the
    wrapper body.
    """
    emitter.emit_line('{} {{'.format(legacy_wrapper_function_header(fn, emitter.names)))

    # If fn is a method, then the first argument is a self param
    real_args = list(fn.args)
    if fn.class_name and fn.decl.kind != FUNC_STATICMETHOD:
        self_arg = real_args.pop(0)
        emitter.emit_line('PyObject *obj_{} = self;'.format(self_arg.name))

    # Need to order args as: required, optional, kwonly optional, kwonly required
    # This is because CPyArg_ParseTupleAndKeywords format string requires
    # them grouped in that way.
    groups = make_arg_groups(real_args)
    reordered_args = reorder_arg_groups(groups)

    emitter.emit_line(make_static_kwlist(reordered_args))
    for arg in real_args:
        initializer = ' = NULL' if arg.optional else ''
        emitter.emit_line('PyObject *obj_{}{};'.format(arg.name, initializer))

    star_args = groups[ARG_STAR]
    star2_args = groups[ARG_STAR2]
    cleanups = ['CPy_DECREF(obj_{});'.format(arg.name)
                for arg in star_args + star2_args]

    arg_ptrs: List[str] = []
    if star_args or star2_args:
        # When either is present, both slots are passed, using NULL for
        # the missing one.
        for group in (star_args, star2_args):
            arg_ptrs.append('&obj_{}'.format(group[0].name) if group else 'NULL')
    arg_ptrs.extend('&obj_{}'.format(arg.name) for arg in reordered_args)

    format_string = make_format_string(None, groups)
    extra_args = ''.join(', ' + ptr for ptr in arg_ptrs)
    emitter.emit_lines(
        'if (!CPyArg_ParseTupleAndKeywords(args, kw, "{}", "{}", kwlist{})) {{'.format(
            format_string, fn.name, extra_args),
        'return NULL;',
        '}')
    traceback_code = generate_traceback_code(fn, emitter, source_path, module_name)
    generate_wrapper_core(fn, emitter, groups[ARG_OPT] + groups[ARG_NAMED_OPT],
                          cleanups=cleanups,
                          traceback_code=traceback_code)

    emitter.emit_line('}')
# Specialized wrapper functions
def generate_dunder_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __dunder__ method and return its name.

    The wrapper lets the method fit into the mapping protocol slot, which
    specifically means that the arguments are taken as PyObject * and the
    result is returned as PyObject *.
    """
    generator = WrapperGenerator(cl, emitter)
    generator.set_target(fn)
    # Emit the wrapper pieces in order: header, arguments, call, footer.
    generator.emit_header()
    generator.emit_arg_processing()
    generator.emit_call()
    generator.finish()
    return generator.wrapper_name()
def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for a native binary dunder method.

    The same wrapper that handles the forward method (e.g. __add__) also handles
    the corresponding reverse method (e.g. __radd__), if defined.

    Both arguments and the return value are PyObject *. Returns the name of
    the generated wrapper.
    """
    gen = WrapperGenerator(cl, emitter)
    gen.set_target(fn)
    gen.arg_names = ['left', 'right']
    # Fetch the name up front: the helpers below may retarget the generator.
    wrapper_name = gen.wrapper_name()

    gen.emit_header()
    method_name = fn.name
    if method_name not in reverse_op_methods and method_name in reverse_op_method_names:
        # There's only a reverse operator method.
        generate_bin_op_reverse_only_wrapper(emitter, gen)
        return wrapper_name

    fn_rev = cl.get_method(reverse_op_methods[method_name])
    if fn_rev is None:
        # There's only a forward operator method.
        generate_bin_op_forward_only_wrapper(fn, emitter, gen)
        return wrapper_name

    # There's both a forward and a reverse operator method.
    generate_bin_op_both_wrappers(cl, fn, fn_rev, emitter, gen)
    return wrapper_name
def generate_bin_op_forward_only_wrapper(fn: FuncIR,
                                         emitter: Emitter,
                                         gen: 'WrapperGenerator') -> None:
    """Emit the wrapper body for a class defining only the forward operator method."""
    type_failure = GotoHandler('typefail')
    gen.emit_arg_processing(error=type_failure, raise_exception=False)
    gen.emit_call(not_implemented_handler='goto typefail;')
    gen.emit_error_handling()
    emitter.emit_label('typefail')
    # If some argument has an incompatible type, treat this the same as
    # returning NotImplemented, and try to call the reverse operator method.
    #
    # Note that in normal Python you'd use an explicit return of
    # NotImplemented instead, but that doesn't generally work here, since
    # the body won't be executed at all if there is an argument type check
    # failure.
    #
    # The recommended way is to still use a type check in the
    # body. This will only be used in interpreted mode:
    #
    #    def __add__(self, other: int) -> Foo:
    #        if not isinstance(other, int):
    #            return NotImplemented
    #        ...
    reverse_name = reverse_op_methods[fn.name]
    emitter.emit_line('_Py_IDENTIFIER({});'.format(reverse_name))
    symbol = op_methods_to_symbols[fn.name]
    emitter.emit_line(
        'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
            symbol, reverse_name))
    gen.finish()
def generate_bin_op_reverse_only_wrapper(emitter: Emitter,
                                         gen: 'WrapperGenerator') -> None:
    """Emit the wrapper body for a class defining only the reverse operator method.

    The generated code returns NotImplemented if argument processing
    jumps to the 'typefail' label.
    """
    # The reverse method takes its operands in swapped order.
    gen.arg_names = ['right', 'left']
    type_failure = GotoHandler('typefail')
    gen.emit_arg_processing(error=type_failure, raise_exception=False)
    gen.emit_call()
    gen.emit_error_handling()
    emitter.emit_label('typefail')
    for line in ('Py_INCREF(Py_NotImplemented);', 'return Py_NotImplemented;'):
        emitter.emit_line(line)
    gen.finish()
def generate_bin_op_both_wrappers(cl: ClassIR,
                                  fn: FuncIR,
                                  fn_rev: FuncIR,
                                  emitter: Emitter,
                                  gen: 'WrapperGenerator') -> None:
    """Emit the body of a binary operator wrapper with forward and reverse methods.

    The emitted C first tries the forward method when the left operand is an
    instance of the class. On an argument type failure (or the NotImplemented
    handler) it continues at 'typefail', where the reverse method is tried if
    the right operand is an instance of the class; otherwise the call is
    delegated to CPy_CallReverseOpMethod. 'typefail2' returns NotImplemented.
    """
    # Similar to the forward-only case, we can't perfectly match Python
    # semantics. In regular Python code you'd return NotImplemented if the
    # operand has the wrong type, but in compiled code we'll never get to
    # execute the type check.

    def open_instance_guard(operand: str) -> None:
        # Opens a C 'if' block checking that the operand is an instance of cl.
        struct_name = emitter.type_struct_name(cl)
        emitter.emit_line(
            'if (PyObject_IsInstance({}, (PyObject *){})) {{'.format(operand, struct_name))

    # Forward method: only attempted when the left operand has our type.
    open_instance_guard('obj_left')
    gen.emit_arg_processing(error=GotoHandler('typefail'), raise_exception=False)
    gen.emit_call(not_implemented_handler='goto typefail;')
    gen.emit_error_handling()
    emitter.emit_line('}')

    # Reverse method: retarget the generator and swap the operand names.
    emitter.emit_label('typefail')
    open_instance_guard('obj_right')
    gen.set_target(fn_rev)
    gen.arg_names = ['right', 'left']
    gen.emit_arg_processing(error=GotoHandler('typefail2'), raise_exception=False)
    gen.emit_call()
    gen.emit_error_handling()
    emitter.emit_line('} else {')
    emitter.emit_line(f'_Py_IDENTIFIER({fn_rev.name});')
    op_symbol = op_methods_to_symbols[fn.name]
    emitter.emit_line(
        'return CPy_CallReverseOpMethod(obj_left, obj_right, '
        f'"{op_symbol}", &PyId_{fn_rev.name});')
    emitter.emit_line('}')

    emitter.emit_label('typefail2')
    emitter.emit_line('Py_INCREF(Py_NotImplemented);')
    emitter.emit_line('return Py_NotImplemented;')
    gen.finish()
# Maps each rich comparison dunder method name to the name of the C constant
# that generate_richcompare_wrapper emits as the matching 'case' label of its
# 'switch (op)' statement.
RICHCOMPARE_OPS = {
    '__lt__': 'Py_LT',
    '__gt__': 'Py_GT',
    '__le__': 'Py_LE',
    '__ge__': 'Py_GE',
    '__eq__': 'Py_EQ',
    '__ne__': 'Py_NE',
}
def generate_richcompare_wrapper(cl: ClassIR, emitter: Emitter) -> Optional[str]:
    """Generate one richcompare wrapper that dispatches on the comparison op.

    Return the name of the wrapper, or None (emitting nothing) if the class
    has none of the methods listed in RICHCOMPARE_OPS.
    """
    # Sort for determinism on Python 3.5
    supported = sorted(op_name for op_name in RICHCOMPARE_OPS if cl.has_method(op_name))
    if not supported:
        return None

    wrapper = f'{DUNDER_PREFIX}_RichCompare_{cl.name_prefix(emitter.names)}'
    emitter.emit_line(
        f'static PyObject *{wrapper}(PyObject *obj_lhs, PyObject *obj_rhs, int op) {{')
    emitter.emit_line('switch (op) {')
    for op_name in supported:
        emitter.emit_line(f'case {RICHCOMPARE_OPS[op_name]}: {{')
        method = cl.get_method(op_name)
        assert method is not None
        generate_wrapper_core(method, emitter, arg_names=['lhs', 'rhs'])
        emitter.emit_line('}')
    # Close the switch; after it the emitted C returns NotImplemented.
    for c_line in ('}',
                   'Py_INCREF(Py_NotImplemented);',
                   'return Py_NotImplemented;',
                   '}'):
        emitter.emit_line(c_line)
    return wrapper
def generate_get_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __get__ method and return its name.

    The emitted C replaces a NULL 'instance' with Py_None before calling
    the native function.
    """
    wrapper = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}'
    emitter.emit_line(
        f'static PyObject *{wrapper}'
        '(PyObject *self, PyObject *instance, PyObject *owner) {')
    emitter.emit_line('instance = instance ? instance : Py_None;')
    native_name = fn.cname(emitter.names)
    emitter.emit_line(f'return {NATIVE_PREFIX}{native_name}(self, instance, owner);')
    emitter.emit_line('}')
    return wrapper
def generate_hash_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __hash__ method and return its name.

    The emitted C converts the native result to Py_ssize_t, returns -1 on
    error, and maps a computed value of -1 to -2.
    """
    wrapper = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}'
    emitter.emit_line(f'static Py_ssize_t {wrapper}(PyObject *self) {{')

    ret_ctype = emitter.ctype_spaced(fn.ret_type)
    group_prefix = emitter.get_group_prefix(fn.decl)
    native_name = fn.cname(emitter.names)
    emitter.emit_line(
        f'{ret_ctype}retval = {group_prefix}{NATIVE_PREFIX}{native_name}(self);')
    emitter.emit_error_check('retval', fn.ret_type, 'return -1;')

    # Pick the C conversion function based on the native return type.
    if is_int_rprimitive(fn.ret_type):
        converter = 'CPyTagged_AsSsize_t'
    else:
        converter = 'PyLong_AsSsize_t'
    emitter.emit_line(f'Py_ssize_t val = {converter}(retval);')
    emitter.emit_dec_ref('retval', fn.ret_type)

    for c_line in ('if (PyErr_Occurred()) return -1;',
                   # We can't return -1 from a hash function..
                   'if (val == -1) return -2;',
                   'return val;',
                   '}'):
        emitter.emit_line(c_line)
    return wrapper
def generate_len_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __len__ method and return its name.

    The emitted C converts the native result to Py_ssize_t and returns -1
    on error.
    """
    wrapper = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}'
    emitter.emit_line(f'static Py_ssize_t {wrapper}(PyObject *self) {{')

    ret_ctype = emitter.ctype_spaced(fn.ret_type)
    group_prefix = emitter.get_group_prefix(fn.decl)
    native_name = fn.cname(emitter.names)
    emitter.emit_line(
        f'{ret_ctype}retval = {group_prefix}{NATIVE_PREFIX}{native_name}(self);')
    emitter.emit_error_check('retval', fn.ret_type, 'return -1;')

    # Pick the C conversion function based on the native return type.
    if is_int_rprimitive(fn.ret_type):
        converter = 'CPyTagged_AsSsize_t'
    else:
        converter = 'PyLong_AsSsize_t'
    emitter.emit_line(f'Py_ssize_t val = {converter}(retval);')
    emitter.emit_dec_ref('retval', fn.ret_type)

    for c_line in ('if (PyErr_Occurred()) return -1;',
                   'return val;',
                   '}'):
        emitter.emit_line(c_line)
    return wrapper
def generate_bool_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __bool__ method and return its name.

    Only a bool return type is supported; anything else trips an assertion.
    """
    wrapper = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}'
    emitter.emit_line(f'static int {wrapper}(PyObject *self) {{')

    ret_ctype = emitter.ctype_spaced(fn.ret_type)
    native_name = fn.cname(emitter.names)
    emitter.emit_line(f'{ret_ctype}val = {NATIVE_PREFIX}{native_name}(self);')
    emitter.emit_error_check('val', fn.ret_type, 'return -1;')
    # Supporting other return types wouldn't be that hard, but it seems
    # unimportant, and getting error handling and unboxing right would be
    # fiddly. (And way easier to do in IR!)
    assert is_bool_rprimitive(fn.ret_type), "Only bool return supported for __bool__"
    emitter.emit_line('return val;')
    emitter.emit_line('}')
    return wrapper
def generate_del_item_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __delitem__ and return its name.

    This is only called from a combined __delitem__/__setitem__ wrapper.
    """
    wrapper = f'{DUNDER_PREFIX}__delitem__{cl.name_prefix(emitter.names)}'
    params = ', '.join(f'PyObject *obj_{arg.name}' for arg in fn.args)
    emitter.emit_line(f'static int {wrapper}({params}) {{')
    # The shared helper emits the rest of the body, including the closing brace.
    generate_set_del_item_wrapper_inner(fn, emitter, fn.args)
    return wrapper
def generate_set_del_item_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for native __setitem__ method (also works for __delitem__).

    This is used with the mapping protocol slot. Arguments are taken as *PyObjects and we
    return a negative C int on error.

    Create a separate wrapper function for __delitem__ as needed and have the
    __setitem__ wrapper call it if the value is NULL. Return the name
    of the outer (__setitem__) wrapper.

    Args:
        cl: Class that the wrapper is generated for.
        fn: The native __setitem__ or __delitem__ method.
        emitter: Receives the generated C code.
    """
    method_cls = cl.get_method_and_class('__delitem__')
    del_name = None
    # Only generate a native __delitem__ wrapper if the method is defined
    # directly in this class (not inherited).
    if method_cls and method_cls[1] == cl:
        # Generate a separate wrapper for __delitem__
        del_name = generate_del_item_wrapper(cl, method_cls[0], emitter)

    args = fn.args
    if fn.name == '__delitem__':
        # Add an extra argument for value that we expect to be NULL.
        args = list(args) + [RuntimeArg('___value', object_rprimitive, ARG_POS)]

    # The outer wrapper is always named after __setitem__, even when fn is
    # __delitem__.
    name = '{}{}{}'.format(DUNDER_PREFIX, '__setitem__', cl.name_prefix(emitter.names))
    input_args = ', '.join('PyObject *obj_{}'.format(arg.name) for arg in args)
    emitter.emit_line('static int {name}({input_args}) {{'.format(
        name=name,
        input_args=input_args,
    ))

    # First check if this is __delitem__
    emitter.emit_line('if (obj_{} == NULL) {{'.format(args[2].name))
    if del_name is not None:
        # We have a native implementation, so call it
        emitter.emit_line('return {}(obj_{}, obj_{});'.format(del_name,
                                                             args[0].name,
                                                             args[1].name))
    else:
        # Try to call superclass method instead
        emitter.emit_line(
            'PyObject *super = CPy_Super(CPyModule_builtins, obj_{});'.format(args[0].name))
        emitter.emit_line('if (super == NULL) return -1;')
        emitter.emit_line(
            'PyObject *result = PyObject_CallMethod(super, "__delitem__", "O", obj_{});'.format(
                args[1].name))
        emitter.emit_line('Py_DECREF(super);')
        emitter.emit_line('Py_XDECREF(result);')
        emitter.emit_line('return result == NULL ? -1 : 0;')
    # Closes the 'if (value == NULL)' block opened above.
    emitter.emit_line('}')

    method_cls = cl.get_method_and_class('__setitem__')
    if method_cls and method_cls[1] == cl:
        # Native __setitem__ defined in this class. The helper also emits the
        # closing brace of the wrapper function.
        generate_set_del_item_wrapper_inner(fn, emitter, args)
    else:
        emitter.emit_line(
            'PyObject *super = CPy_Super(CPyModule_builtins, obj_{});'.format(args[0].name))
        emitter.emit_line('if (super == NULL) return -1;')
        emitter.emit_line('PyObject *result;')

        if method_cls is None and cl.builtin_base is None:
            # No __setitem__ was found and there is no builtin base class, so
            # the emitted C raises TypeError.
            msg = "'{}' object does not support item assignment".format(cl.name)
            emitter.emit_line(
                'PyErr_SetString(PyExc_TypeError, "{}");'.format(msg))
            emitter.emit_line('result = NULL;')
        else:
            # A base class may have __setitem__
            emitter.emit_line(
                'result = PyObject_CallMethod(super, "__setitem__", "OO", obj_{}, obj_{});'.format(
                    args[1].name, args[2].name))
        emitter.emit_line('Py_DECREF(super);')
        emitter.emit_line('Py_XDECREF(result);')
        emitter.emit_line('return result == NULL ? -1 : 0;')
        # This branch has to close the wrapper function itself.
        emitter.emit_line('}')
    return name
def generate_set_del_item_wrapper_inner(fn: FuncIR, emitter: Emitter,
                                        args: Sequence[RuntimeArg]) -> None:
    """Emit the body and closing brace of a __setitem__/__delitem__ wrapper.

    Each argument is checked (jumping to 'fail' on error), the native
    function is called, and the emitted C returns 0 on success or -1 on
    failure.
    """
    for arg in args:
        generate_arg_check(arg.name, arg.type, emitter, GotoHandler('fail'))
    call_args = ', '.join(f'arg_{arg.name}' for arg in args)
    ret_ctype = emitter.ctype_spaced(fn.ret_type)
    native_name = fn.cname(emitter.names)
    emitter.emit_line(f'{ret_ctype}val = {NATIVE_PREFIX}{native_name}({call_args});')
    emitter.emit_error_check('val', fn.ret_type, 'goto fail;')
    emitter.emit_dec_ref('val', fn.ret_type)
    emitter.emit_line('return 0;')
    emitter.emit_label('fail')
    for c_line in ('return -1;', '}'):
        emitter.emit_line(c_line)
def generate_contains_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for a native __contains__ method and return its name.

    A bool result is returned directly; any other result is converted with
    PyObject_IsTrue in the emitted C.
    """
    wrapper = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}'
    emitter.emit_line(f'static int {wrapper}(PyObject *self, PyObject *obj_item) {{')
    generate_arg_check('item', fn.args[1].type, emitter, ReturnHandler('-1'))

    ret_ctype = emitter.ctype_spaced(fn.ret_type)
    native_name = fn.cname(emitter.names)
    emitter.emit_line(f'{ret_ctype}val = {NATIVE_PREFIX}{native_name}(self, arg_item);')
    emitter.emit_error_check('val', fn.ret_type, 'return -1;')

    if not is_bool_rprimitive(fn.ret_type):
        emitter.emit_line('int boolval = PyObject_IsTrue(val);')
        emitter.emit_dec_ref('val', fn.ret_type)
        emitter.emit_line('return boolval;')
    else:
        emitter.emit_line('return val;')
    emitter.emit_line('}')
    return wrapper
# Helpers
def generate_wrapper_core(fn: FuncIR,
                          emitter: Emitter,
                          optional_args: Optional[List[RuntimeArg]] = None,
                          arg_names: Optional[List[str]] = None,
                          cleanups: Optional[List[str]] = None,
                          traceback_code: Optional[str] = None) -> None:
    """Generates the core part of a wrapper function for a native function.

    This expects each argument as a PyObject * named obj_{arg} as a precondition.
    It converts the PyObject *s to the necessary types, checking and unboxing if
    necessary, makes the call, then boxes the result if necessary and returns it.

    Falsy optional parameters leave the generator's defaults in place:
    argument names taken from fn, no cleanups, no optional arguments and no
    traceback code.
    """
    gen = WrapperGenerator(None, emitter)
    gen.set_target(fn)
    # A fresh generator already carries the defaults, so only override
    # what the caller actually supplied.
    if arg_names:
        gen.arg_names = arg_names
    if cleanups:
        gen.cleanups = cleanups
    if optional_args:
        gen.optional_args = optional_args
    if traceback_code:
        gen.traceback_code = traceback_code

    # gen.error() picks 'goto fail' when cleanups or traceback code are
    # present, and a plain 'return NULL' otherwise.
    gen.emit_arg_processing(error=gen.error())
    gen.emit_call()
    gen.emit_error_handling()
def generate_arg_check(name: str,
                       typ: RType,
                       emitter: Emitter,
                       error: Optional[ErrorHandler] = None,
                       *,
                       optional: bool = False,
                       raise_exception: bool = True) -> None:
    """Insert a runtime check for argument and unbox if necessary.

    The object is named PyObject *obj_{}. This is expected to generate
    a value of name arg_{} (unboxed if necessary). For each primitive a runtime
    check ensures the correct type.

    If no error handler is given, an AssignHandler is used.
    """
    error = error or AssignHandler()
    src = f'obj_{name}'
    dest = f'arg_{name}'

    if typ.is_unboxed:
        # Borrow when unboxing to avoid reference count manipulation.
        emitter.emit_unbox(src,
                           dest,
                           typ,
                           declare_dest=True,
                           raise_exception=raise_exception,
                           error=error,
                           borrow=True,
                           optional=optional)
        return

    if not is_object_rprimitive(typ):
        emitter.emit_cast(src,
                          dest,
                          typ,
                          declare_dest=True,
                          raise_exception=raise_exception,
                          error=error,
                          optional=optional)
        return

    # Object is trivial since any object is valid
    if not optional:
        emitter.emit_line(f'PyObject *{dest} = {src};')
        return
    # A missing (NULL) optional argument is replaced by the type's error value.
    emitter.emit_line(f'PyObject *{dest};')
    emitter.emit_line(f'if ({src} == NULL) {{')
    emitter.emit_line(f'{dest} = {emitter.c_error_value(typ)};')
    emitter.emit_lines('} else {', f'{dest} = {src}; ', '}')
class WrapperGenerator:
    """Helper that simplifies the generation of wrapper functions."""

    # TODO: Use this for more wrappers

    def __init__(self, cl: Optional[ClassIR], emitter: Emitter) -> None:
        self.cl = cl
        self.emitter = emitter
        # C statements emitted after the native call and again on the error path.
        self.cleanups: List[str] = []
        # Arguments that are allowed to be missing.
        self.optional_args: List[RuntimeArg] = []
        # C code emitted on the error path, if non-empty.
        self.traceback_code = ''

    def set_target(self, fn: FuncIR) -> None:
        """Set the wrapped function.

        It's fine to modify the attributes initialized here later to customize
        the wrapper function.
        """
        self.target_name = fn.name
        self.target_cname = fn.cname(self.emitter.names)
        self.arg_names = [arg.name for arg in fn.args]
        self.args = fn.args[:]
        self.ret_type = fn.ret_type

    def wrapper_name(self) -> str:
        """Return the name of the wrapper function."""
        class_part = self.cl.name_prefix(self.emitter.names) if self.cl else ''
        return f'{DUNDER_PREFIX}{self.target_name}{class_part}'

    def use_goto(self) -> bool:
        """Do we use a goto for error handling (instead of straight return)?"""
        return bool(self.cleanups or self.traceback_code)

    def emit_header(self) -> None:
        """Emit the function header of the wrapper implementation."""
        params = ', '.join(f'PyObject *obj_{arg_name}' for arg_name in self.arg_names)
        self.emitter.emit_line(f'static PyObject *{self.wrapper_name()}({params}) {{')

    def emit_arg_processing(self,
                            error: Optional[ErrorHandler] = None,
                            raise_exception: bool = True) -> None:
        """Emit validation and unboxing of arguments."""
        error = error or self.error()
        for arg_name, arg in zip(self.arg_names, self.args):
            if arg.kind in (ARG_STAR, ARG_STAR2):
                # Suppress the argument check for *args/**kwargs, since we
                # know it must be right.
                typ = object_rprimitive
            else:
                typ = arg.type
            generate_arg_check(arg_name,
                               typ,
                               self.emitter,
                               error,
                               raise_exception=raise_exception,
                               optional=arg in self.optional_args)

    def emit_call(self, not_implemented_handler: str = '') -> None:
        """Emit call to the wrapper function.

        If not_implemented_handler is non-empty, use this C code to handle
        a NotImplemented return value (if it's possible based on the return type).
        """
        call_args = ', '.join(f'arg_{arg_name}' for arg_name in self.arg_names)
        ret_type = self.ret_type
        emitter = self.emitter
        call = f'{NATIVE_PREFIX}{self.target_cname}({call_args})'

        if ret_type.is_unboxed or self.use_goto():
            # TODO: The Py_RETURN macros return the correct PyObject * with
            #       reference count handling. Are they relevant?
            emitter.emit_line(f'{emitter.ctype_spaced(ret_type)}retval = {call};')
            emitter.emit_lines(*self.cleanups)
            result_var = 'retval'
            if ret_type.is_unboxed:
                emitter.emit_error_check('retval', ret_type, 'return NULL;')
                emitter.emit_box('retval', 'retbox', ret_type, declare_dest=True)
                result_var = 'retbox'
            emitter.emit_line(f'return {result_var};')
        elif not_implemented_handler and not isinstance(ret_type, RInstance):
            # The return value type may overlap with NotImplemented.
            emitter.emit_line(f'PyObject *retbox = {call};')
            emitter.emit_lines('if (retbox == Py_NotImplemented) {',
                               not_implemented_handler,
                               '}',
                               'return retbox;')
        else:
            emitter.emit_line(f'return {call};')
        # TODO: Tracebacks?

    def error(self) -> ErrorHandler:
        """Figure out how to deal with errors in the wrapper."""
        if self.use_goto():
            # We'll have a label at the end with error handling code.
            return GotoHandler('fail')
        # Nothing special needs to be done to handle errors, so just return.
        return ReturnHandler('NULL')

    def emit_error_handling(self) -> None:
        """Emit error handling block at the end of the wrapper, if needed."""
        emitter = self.emitter
        if not self.use_goto():
            return
        emitter.emit_label('fail')
        emitter.emit_lines(*self.cleanups)
        if self.traceback_code:
            emitter.emit_line(self.traceback_code)
        emitter.emit_line('return NULL;')

    def finish(self) -> None:
        """Emit the closing brace of the wrapper function."""
        self.emitter.emit_line('}')

View file

@ -0,0 +1,279 @@
from typing import Dict, List, Union, Tuple, Any, cast
from typing_extensions import Final
# Supported Python literal types. All tuple items must have supported
# literal types as well, but we can't represent the type precisely.
LiteralValue = Union[str, bytes, int, bool, float, complex, Tuple[object, ...], None]
# Some literals are singletons and handled specially (None, False and True).
# They always take the first literal indexes: None is 0, False is 1 and
# True is 2 (see Literals.literal_index).
NUM_SINGLETONS: Final = 3
class Literals:
    """Collection of literal values used in a compilation group and related helpers."""

    def __init__(self) -> None:
        # Each dict maps value to literal index (0, 1, ...)
        self.str_literals: Dict[str, int] = {}
        self.bytes_literals: Dict[bytes, int] = {}
        self.int_literals: Dict[int, int] = {}
        self.float_literals: Dict[float, int] = {}
        self.complex_literals: Dict[complex, int] = {}
        self.tuple_literals: Dict[Tuple[object, ...], int] = {}

    def record_literal(self, value: LiteralValue) -> None:
        """Ensure that the literal value is available in generated code."""
        if value is None or value is True or value is False:
            # These are special cased and always present
            return
        # setdefault evaluates len() before inserting, so a new value gets
        # the next free index and an existing value keeps its index.
        if isinstance(value, str):
            self.str_literals.setdefault(value, len(self.str_literals))
        elif isinstance(value, bytes):
            self.bytes_literals.setdefault(value, len(self.bytes_literals))
        elif isinstance(value, int):
            self.int_literals.setdefault(value, len(self.int_literals))
        elif isinstance(value, float):
            self.float_literals.setdefault(value, len(self.float_literals))
        elif isinstance(value, complex):
            self.complex_literals.setdefault(value, len(self.complex_literals))
        elif isinstance(value, tuple):
            known_tuples = self.tuple_literals
            if value in known_tuples:
                return
            # Items are recorded before the tuple itself gets its index.
            for item in value:
                self.record_literal(cast(Any, item))
            known_tuples[value] = len(known_tuples)
        else:
            assert False, 'invalid literal: %r' % value

    def literal_index(self, value: LiteralValue) -> int:
        """Return the index to the literals array for given value."""
        # The array contains first None and booleans, followed by all str values,
        # followed by bytes values, etc.
        if value is None:
            return 0
        if value is False:
            return 1
        if value is True:
            return 2
        tables: List[Tuple[type, Dict[Any, int]]] = [
            (str, self.str_literals),
            (bytes, self.bytes_literals),
            (int, self.int_literals),
            (float, self.float_literals),
            (complex, self.complex_literals),
            (tuple, self.tuple_literals),
        ]
        offset = NUM_SINGLETONS
        for literal_type, table in tables:
            if isinstance(value, literal_type):
                return offset + table[value]
            # Skip past all literals of this kind.
            offset += len(table)
        assert False, 'invalid literal: %r' % value

    def num_literals(self) -> int:
        """Return the total number of literals, including the singletons."""
        # The first three are for None, True and False
        tables = (self.str_literals, self.bytes_literals, self.int_literals,
                  self.float_literals, self.complex_literals, self.tuple_literals)
        return NUM_SINGLETONS + sum(len(table) for table in tables)

    # The following methods return the C encodings of literal values
    # of different types

    def encoded_str_values(self) -> List[bytes]:
        return _encode_str_values(self.str_literals)

    def encoded_int_values(self) -> List[bytes]:
        return _encode_int_values(self.int_literals)

    def encoded_bytes_values(self) -> List[bytes]:
        return _encode_bytes_values(self.bytes_literals)

    def encoded_float_values(self) -> List[str]:
        return _encode_float_values(self.float_literals)

    def encoded_complex_values(self) -> List[str]:
        return _encode_complex_values(self.complex_literals)

    def encoded_tuple_values(self) -> List[str]:
        """Encode tuple values into a C array.

        The format of the result is like this:

           <number of tuples>
           <length of the first tuple>
           <literal index of first item>
           ...
           <literal index of last item>
           <length of the second tuple>
           ...
        """
        tuples = self.tuple_literals
        tuple_by_index = {index: items for items, index in tuples.items()}
        count = len(tuples)
        encoded = [str(count)]
        for position in range(count):
            items = tuple_by_index[position]
            encoded.append(str(len(items)))
            encoded.extend(str(self.literal_index(cast(Any, item))) for item in items)
        return encoded
def _encode_str_values(values: Dict[str, int]) -> List[bytes]:
    """Encode str values, ordered by literal index, into chunks of bytes.

    Each chunk starts with the variable-length encoded number of strings it
    holds, followed by the encoded strings. A chunk is closed before an
    entry that would push its payload past 70 bytes. An empty bytes value
    is appended after the last chunk.
    """
    text_by_index = {index: text for text, index in values.items()}
    chunks: List[bytes] = []
    pending: List[bytes] = []
    pending_size = 0
    for position in range(len(values)):
        item = format_str_literal(text_by_index[position])
        item_size = len(item)
        if pending_size > 0 and pending_size + item_size > 70:
            chunks.append(format_int(len(pending)) + b''.join(pending))
            pending, pending_size = [], 0
        pending.append(item)
        pending_size += item_size
    if pending:
        chunks.append(format_int(len(pending)) + b''.join(pending))
    chunks.append(b'')
    return chunks
def _encode_bytes_values(values: Dict[bytes, int]) -> List[bytes]:
    """Encode bytes values, ordered by literal index, into chunks of bytes.

    Each chunk starts with the variable-length encoded number of entries it
    holds; each entry is its encoded length followed by the raw bytes. A
    chunk is closed before an entry that would push its payload past 70
    bytes. An empty bytes value is appended after the last chunk.
    """
    data_by_index = {index: data for data, index in values.items()}
    chunks: List[bytes] = []
    pending: List[bytes] = []
    pending_size = 0
    for position in range(len(values)):
        data = data_by_index[position]
        length_prefix = format_int(len(data))
        item_size = len(length_prefix) + len(data)
        if pending_size > 0 and pending_size + item_size > 70:
            chunks.append(format_int(len(pending)) + b''.join(pending))
            pending, pending_size = [], 0
        pending.append(length_prefix + data)
        pending_size += item_size
    if pending:
        chunks.append(format_int(len(pending)) + b''.join(pending))
    chunks.append(b'')
    return chunks
def format_int(n: int) -> bytes:
    """Format an integer using a variable-length binary encoding.

    Values below 128 are a single byte. Larger values are split into 7-bit
    digits, most significant first; every byte except the last has its
    highest bit set to signal that more digits follow.
    """
    if n < 128:
        return bytes([n])
    # Collect the 7-bit digits least significant first, then flip the order.
    digits = []
    remaining = n
    while remaining > 0:
        digits.append(remaining & 0x7f)
        remaining >>= 7
    digits.reverse()
    encoded = [digit | 0x80 for digit in digits[:-1]]
    encoded.append(digits[-1])
    return bytes(encoded)
def format_str_literal(s: str) -> bytes:
    """Return the UTF-8 encoding of s, prefixed with its encoded byte length."""
    payload = s.encode('utf-8')
    length_prefix = format_int(len(payload))
    return length_prefix + payload
def _encode_int_values(values: Dict[int, int]) -> List[bytes]:
    """Encode int values into C strings.

    Values are stored in base 10 and separated by 0 bytes. Each chunk
    starts with the variable-length encoded number of values it holds. A
    chunk is closed before a value that would push its digit count past
    70 (separators are not counted). An empty bytes value is appended
    after the last chunk.
    """
    number_by_index = {index: number for number, index in values.items()}
    chunks: List[bytes] = []
    pending: List[bytes] = []
    pending_size = 0
    for position in range(len(values)):
        digits = b'%d' % number_by_index[position]
        if pending_size > 0 and pending_size + len(digits) > 70:
            chunks.append(format_int(len(pending)) + b'\0'.join(pending))
            pending, pending_size = [], 0
        pending.append(digits)
        pending_size += len(digits)
    if pending:
        chunks.append(format_int(len(pending)) + b'\0'.join(pending))
    chunks.append(b'')
    return chunks
def float_to_c(x: float) -> str:
    """Return C literal representation of a float value.

    Infinities and NaN have no numeric literal syntax in C, so they are
    mapped to the INFINITY and NAN macros. Every other value uses Python's
    str() representation.
    """
    s = str(x)
    if s == 'inf':
        return 'INFINITY'
    elif s == '-inf':
        return '-INFINITY'
    elif s == 'nan':
        # Without this branch the bare identifier 'nan' would be returned,
        # which is not a valid C expression.
        return 'NAN'
    return s
def _encode_float_values(values: Dict[float, int]) -> List[str]:
    """Encode float values into C array items.

    The result contains the number of values followed by the individual
    values, ordered by literal index.
    """
    number_by_index = {index: number for number, index in values.items()}
    count = len(values)
    encoded = [str(count)]
    encoded.extend(float_to_c(number_by_index[position]) for position in range(count))
    return encoded
def _encode_complex_values(values: Dict[complex, int]) -> List[str]:
    """Encode complex values into C array items.

    The result contains the number of values followed by pairs of doubles
    (real part, then imaginary part) representing the complex numbers,
    ordered by literal index.
    """
    number_by_index = {index: number for number, index in values.items()}
    count = len(values)
    encoded = [str(count)]
    for position in range(count):
        number = number_by_index[position]
        encoded.extend((float_to_c(number.real), float_to_c(number.imag)))
    return encoded