Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Decrease the size of the generated stencils and the runtime JIT code. Patch by Diego Russo.
80 changes: 54 additions & 26 deletions Python/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,20 @@ mark_executable(unsigned char *memory, size_t size)

// JIT compiler stuff: /////////////////////////////////////////////////////////

#define SYMBOL_MASK_WORDS 4
#define GOT_SLOT_SIZE sizeof(uintptr_t)
#define SYMBOL_MASK_WORDS 8

typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];

typedef struct {
unsigned char *mem;
symbol_mask mask;
size_t size;
} trampoline_state;
} symbol_state;

typedef struct {
trampoline_state trampolines;
symbol_state trampolines;
symbol_state got_symbols;
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
} jit_state;

Expand Down Expand Up @@ -205,6 +207,33 @@ set_bits(uint32_t *loc, uint8_t loc_start, uint64_t value, uint8_t value_start,
// - x86_64-unknown-linux-gnu:
// - https://github.com/llvm/llvm-project/blob/main/lld/ELF/Arch/X86_64.cpp


// Get the symbol slot memory location for a given symbol ordinal.
static unsigned char *
get_symbol_slot(int ordinal, symbol_state *state, int size)
{
const uint32_t symbol_mask = 1U << (ordinal % 32);
const uint32_t state_mask = state->mask[ordinal / 32];
assert(symbol_mask & state_mask);

// Count the number of set bits in the symbol mask lower than ordinal
size_t index = _Py_popcount32(state_mask & (symbol_mask - 1));
for (int i = 0; i < ordinal / 32; i++) {
index += _Py_popcount32(state->mask[i]);
}

unsigned char *slot = state->mem + index * size;
assert((size_t)(index + 1) * size <= state->size);
return slot;
}

// Return the address of the GOT slot for the requested symbol ordinal.
static uintptr_t
got_symbol_address(int ordinal, jit_state *state)
{
return (uintptr_t)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE);
}

// Many of these patches are "relaxing", meaning that they can rewrite the
// code they're patching to be more efficient (like turning a 64-bit memory
// load into a 32-bit immediate load). These patches have an "x" in their name.
Expand Down Expand Up @@ -447,6 +476,7 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
patch_32r(location, value);
}

void patch_got_symbol(jit_state *state, int ordinal);
void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);
void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state);

Expand All @@ -465,23 +495,13 @@ void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *st
#define DATA_ALIGN 1
#endif

// Get the trampoline memory location for a given symbol ordinal.
static unsigned char *
get_trampoline_slot(int ordinal, jit_state *state)
// Populate the GOT entry for the given symbol ordinal with its resolved address.
void
patch_got_symbol(jit_state *state, int ordinal)
{
const uint32_t symbol_mask = 1 << (ordinal % 32);
const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32];
assert(symbol_mask & trampoline_mask);

// Count the number of set bits in the trampoline mask lower than ordinal
int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1));
for (int i = 0; i < ordinal / 32; i++) {
index += _Py_popcount32(state->trampolines.mask[i]);
}

unsigned char *trampoline = state->trampolines.mem + index * TRAMPOLINE_SIZE;
assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
return trampoline;
uint64_t value = (uintptr_t)symbols_map[ordinal];
unsigned char *location = (unsigned char *)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE);
patch_64(location, value);
}

// Generate and patch AArch64 trampolines. The symbols to jump to are stored
Expand All @@ -501,8 +521,7 @@ patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
}

// Out of range - need a trampoline
uint32_t *p = (uint32_t *)get_trampoline_slot(ordinal, state);

uint32_t *p = (uint32_t *)get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE);

/* Generate the trampoline
0: 58000048 ldr x8, 8
Expand Down Expand Up @@ -532,7 +551,7 @@ patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state)
}

// Out of range - need a trampoline
unsigned char *trampoline = get_trampoline_slot(ordinal, state);
unsigned char *trampoline = get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE);

/* Generate the trampoline (14 bytes, padded to 16):
0: ff 25 00 00 00 00 jmp *(%rip)
Expand Down Expand Up @@ -574,21 +593,26 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
}
group = &stencil_groups[_FATAL_ERROR];
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
// Calculate the size of the trampolines required by the whole trace
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
}
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.got_symbols.mask); i++) {
state.got_symbols.size += _Py_popcount32(state.got_symbols.mask[i]) * GOT_SLOT_SIZE;
}
// Round up to the nearest page:
size_t page_size = get_page_size();
assert((page_size & (page_size - 1)) == 0);
size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1));
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + padding;
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1));
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding;
unsigned char *memory = jit_alloc(total_size);
if (memory == NULL) {
return -1;
Expand All @@ -598,6 +622,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
OPT_STAT_ADD(jit_code_size, code_size);
OPT_STAT_ADD(jit_trampoline_size, state.trampolines.size);
OPT_STAT_ADD(jit_data_size, data_size);
OPT_STAT_ADD(jit_got_size, state.got_symbols.size);
OPT_STAT_ADD(jit_padding_size, padding);
OPT_HIST(total_size, trace_total_memory_hist);
// Update the offsets of each instruction:
Expand All @@ -608,6 +633,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
unsigned char *code = memory;
state.trampolines.mem = memory + code_size;
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
state.got_symbols.mem = data + data_size;
assert(trace[0].opcode == _START_EXECUTOR || trace[0].opcode == _COLD_EXIT || trace[0].opcode == _COLD_DYNAMIC_EXIT);
for (size_t i = 0; i < length; i++) {
const _PyUOpInstruction *instruction = &trace[i];
Expand Down Expand Up @@ -649,19 +675,21 @@ compile_trampoline(void)
code_size += group->code_size;
data_size += group->data_size;
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
// Round up to the nearest page:
size_t page_size = get_page_size();
assert((page_size & (page_size - 1)) == 0);
size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1));
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + padding;
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1));
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding;
unsigned char *memory = jit_alloc(total_size);
if (memory == NULL) {
return NULL;
}
unsigned char *code = memory;
state.trampolines.mem = memory + code_size;
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
state.got_symbols.mem = data + data_size;
// Compile the shim, which handles converting between the native
// calling convention and the calling convention used by jitted code
// (which may be different for efficiency reasons).
Expand Down
121 changes: 97 additions & 24 deletions Tools/jit/_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ class HoleValue(enum.Enum):
HoleValue.CODE: "(uintptr_t)code",
HoleValue.DATA: "(uintptr_t)data",
HoleValue.EXECUTOR: "(uintptr_t)executor",
HoleValue.GOT: "",
# These should all have been turned into DATA values by process_relocations:
# HoleValue.GOT: "",
HoleValue.OPARG: "instruction->oparg",
HoleValue.OPERAND0: "instruction->operand0",
HoleValue.OPERAND0_HI: "(instruction->operand0 >> 32)",
Expand All @@ -115,6 +115,24 @@ class HoleValue(enum.Enum):
HoleValue.ZERO: "",
}

_AARCH64_GOT_RELOCATIONS = {
"R_AARCH64_ADR_GOT_PAGE",
"R_AARCH64_LD64_GOT_LO12_NC",
"ARM64_RELOC_GOT_LOAD_PAGE21",
"ARM64_RELOC_GOT_LOAD_PAGEOFF12",
"IMAGE_REL_ARM64_PAGEBASE_REL21",
"IMAGE_REL_ARM64_PAGEOFFSET_12L",
"IMAGE_REL_ARM64_PAGEOFFSET_12A",
}

_X86_GOT_RELOCATIONS = {
"R_X86_64_GOTPCRELX",
"R_X86_64_REX_GOTPCRELX",
"X86_64_RELOC_GOT",
"X86_64_RELOC_GOT_LOAD",
"IMAGE_REL_AMD64_REL32",
}


@dataclasses.dataclass
class Hole:
Expand All @@ -133,6 +151,8 @@ class Hole:
# ...plus this addend:
addend: int
need_state: bool = False
custom_location: str = ""
custom_value: str = ""
func: str = dataclasses.field(init=False)
# Convenience method:
replace = dataclasses.replace
Expand Down Expand Up @@ -170,16 +190,22 @@ def fold(self, other: typing.Self, body: bytearray) -> typing.Self | None:

def as_c(self, where: str) -> str:
"""Dump this hole as a call to a patch_* function."""
location = f"{where} + {self.offset:#x}"
value = _HOLE_EXPRS[self.value]
if self.symbol:
if value:
value += " + "
value += f"(uintptr_t)&{self.symbol}"
if _signed(self.addend) or not value:
if value:
value += " + "
value += f"{_signed(self.addend):#x}"
if self.custom_location:
location = self.custom_location
else:
location = f"{where} + {self.offset:#x}"
if self.custom_value:
value = self.custom_value
else:
value = _HOLE_EXPRS[self.value]
if self.symbol:
if value:
value += " + "
value += f"(uintptr_t)&{self.symbol}"
if _signed(self.addend) or not value:
if value:
value += " + "
value += f"{_signed(self.addend):#x}"
if self.need_state:
return f"{self.func}({location}, {value}, state);"
return f"{self.func}({location}, {value});"
Expand Down Expand Up @@ -219,8 +245,11 @@ class StencilGroup:
symbols: dict[int | str, tuple[HoleValue, int]] = dataclasses.field(
default_factory=dict, init=False
)
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
_jit_symbol_table: dict[str, int] = dataclasses.field(
default_factory=dict, init=False
)
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
_got_entries: set[int] = dataclasses.field(default_factory=set, init=False)

def convert_labels_to_relocations(self) -> None:
for name, hole_plus in self.symbols.items():
Expand Down Expand Up @@ -270,13 +299,39 @@ def process_relocations(self, known_symbols: dict[str, int]) -> None:
self._trampolines.add(ordinal)
hole.addend = ordinal
hole.symbol = None
elif (
hole.kind in _AARCH64_GOT_RELOCATIONS | _X86_GOT_RELOCATIONS
and hole.symbol
and "_JIT_" not in hole.symbol
and hole.value is HoleValue.GOT
):
if hole.symbol in known_symbols:
ordinal = known_symbols[hole.symbol]
else:
ordinal = len(known_symbols)
known_symbols[hole.symbol] = ordinal
self._got_entries.add(ordinal)
self.data.pad(8)
for stencil in [self.code, self.data]:
for hole in stencil.holes:
if hole.value is HoleValue.GOT:
assert hole.symbol is not None
hole.value = HoleValue.DATA
hole.addend += self._global_offset_table_lookup(hole.symbol)
if "_JIT_" in hole.symbol:
# Relocations for local symbols
hole.value = HoleValue.DATA
hole.addend += self._jit_symbol_table_lookup(hole.symbol)
else:
_ordinal = known_symbols[hole.symbol]
_custom_value = f"got_symbol_address({_ordinal:#x}, state)"
if hole.kind in _X86_GOT_RELOCATIONS:
# When patching on x86, subtract the addend -4
# that is used to compute the 32 bit RIP relative
# displacement to the GOT entry
_custom_value = (
f"got_symbol_address({_ordinal:#x}, state) - 4"
)
hole.addend = _ordinal
hole.custom_value = _custom_value
hole.symbol = None
elif hole.symbol in self.symbols:
hole.value, addend = self.symbols[hole.symbol]
Expand All @@ -289,16 +344,19 @@ def process_relocations(self, known_symbols: dict[str, int]) -> None:
raise ValueError(
f"Add PyAPI_FUNC(...) or PyAPI_DATA(...) to declaration of {hole.symbol}!"
)
self._emit_jit_symbol_table()
self._emit_global_offset_table()
self.code.holes.sort(key=lambda hole: hole.offset)
self.data.holes.sort(key=lambda hole: hole.offset)

def _global_offset_table_lookup(self, symbol: str) -> int:
return len(self.data.body) + self._got.setdefault(symbol, 8 * len(self._got))
def _jit_symbol_table_lookup(self, symbol: str) -> int:
return len(self.data.body) + self._jit_symbol_table.setdefault(
symbol, 8 * len(self._jit_symbol_table)
)

def _emit_global_offset_table(self) -> None:
def _emit_jit_symbol_table(self) -> None:
got = len(self.data.body)
for s, offset in self._got.items():
for s, offset in self._jit_symbol_table.items():
if s in self.symbols:
value, addend = self.symbols[s]
symbol = None
Expand All @@ -322,20 +380,35 @@ def _emit_global_offset_table(self) -> None:
)
self.data.body.extend([0] * 8)

def _get_trampoline_mask(self) -> str:
def _emit_global_offset_table(self) -> None:
for hole in self.code.holes:
if hole.value is HoleValue.GOT:
_got_hole = Hole(0, "R_X86_64_64", hole.value, None, hole.addend)
_got_hole.func = "patch_got_symbol"
_got_hole.custom_location = "state"
if _got_hole not in self.data.holes:
self.data.holes.append(_got_hole)

def _get_symbol_mask(self, ordinals: set[int]) -> str:
bitmask: int = 0
trampoline_mask: list[str] = []
for ordinal in self._trampolines:
symbol_mask: list[str] = []
for ordinal in ordinals:
bitmask |= 1 << ordinal
while bitmask:
word = bitmask & ((1 << 32) - 1)
trampoline_mask.append(f"{word:#04x}")
symbol_mask.append(f"{word:#04x}")
bitmask >>= 32
return "{" + (", ".join(trampoline_mask) or "0") + "}"
return "{" + (", ".join(symbol_mask) or "0") + "}"

def _get_trampoline_mask(self) -> str:
return self._get_symbol_mask(self._trampolines)

def _get_got_mask(self) -> str:
return self._get_symbol_mask(self._got_entries)

def as_c(self, opname: str) -> str:
"""Dump this hole as a StencilGroup initializer."""
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}, {self._get_got_mask()}}}"


def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
Expand Down
1 change: 1 addition & 0 deletions Tools/jit/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _dump_footer(
yield " size_t code_size;"
yield " size_t data_size;"
yield " symbol_mask trampoline_mask;"
yield " symbol_mask got_mask;"
yield "} StencilGroup;"
yield ""
yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"
Expand Down
Loading