Skip to content

Commit c21de7c

Browse files
committed
GH-142305: JIT: Deduplicating GOT symbols in the trace
1 parent 4085ff7 commit c21de7c

File tree

3 files changed

+151
-50
lines changed

3 files changed

+151
-50
lines changed

Python/jit.c

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,20 @@ mark_executable(unsigned char *memory, size_t size)
130130

131131
// JIT compiler stuff: /////////////////////////////////////////////////////////
132132

133-
#define SYMBOL_MASK_WORDS 4
133+
#define GOT_SLOT_SIZE sizeof(uintptr_t)
134+
#define SYMBOL_MASK_WORDS 8
134135

135136
typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];
136137

137138
typedef struct {
138139
unsigned char *mem;
139140
symbol_mask mask;
140141
size_t size;
141-
} trampoline_state;
142+
} symbol_state;
142143

143144
typedef struct {
144-
trampoline_state trampolines;
145+
symbol_state trampolines;
146+
symbol_state got_symbols;
145147
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
146148
} jit_state;
147149

@@ -205,6 +207,33 @@ set_bits(uint32_t *loc, uint8_t loc_start, uint64_t value, uint8_t value_start,
205207
// - x86_64-unknown-linux-gnu:
206208
// - https://github.com/llvm/llvm-project/blob/main/lld/ELF/Arch/X86_64.cpp
207209

210+
211+
// Get the symbol slot memory location for a given symbol ordinal.
212+
static unsigned char *
213+
get_symbol_slot(int ordinal, symbol_state *state, int size)
214+
{
215+
const uint32_t symbol_mask = 1U << (ordinal % 32);
216+
const uint32_t state_mask = state->mask[ordinal / 32];
217+
assert(symbol_mask & state_mask);
218+
219+
// Count the number of set bits in the symbol mask lower than ordinal
220+
size_t index = _Py_popcount32(state_mask & (symbol_mask - 1));
221+
for (int i = 0; i < ordinal / 32; i++) {
222+
index += _Py_popcount32(state->mask[i]);
223+
}
224+
225+
unsigned char *slot = state->mem + index * size;
226+
assert((size_t)(index + 1) * size <= state->size);
227+
return slot;
228+
}
229+
230+
// Return the address of the GOT slot for the requested symbol ordinal.
231+
static uintptr_t
232+
got_symbol_address(int ordinal, jit_state *state)
233+
{
234+
return (uintptr_t)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE);
235+
}
236+
208237
// Many of these patches are "relaxing", meaning that they can rewrite the
209238
// code they're patching to be more efficient (like turning a 64-bit memory
210239
// load into a 32-bit immediate load). These patches have an "x" in their name.
@@ -447,6 +476,7 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
447476
patch_32r(location, value);
448477
}
449478

479+
void patch_got_symbol(jit_state *state, int ordinal);
450480
void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);
451481
void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state);
452482

@@ -465,23 +495,13 @@ void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *st
465495
#define DATA_ALIGN 1
466496
#endif
467497

468-
// Get the trampoline memory location for a given symbol ordinal.
469-
static unsigned char *
470-
get_trampoline_slot(int ordinal, jit_state *state)
498+
// Populate the GOT entry for the given symbol ordinal with its resolved address.
499+
void
500+
patch_got_symbol(jit_state *state, int ordinal)
471501
{
472-
const uint32_t symbol_mask = 1 << (ordinal % 32);
473-
const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32];
474-
assert(symbol_mask & trampoline_mask);
475-
476-
// Count the number of set bits in the trampoline mask lower than ordinal
477-
int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1));
478-
for (int i = 0; i < ordinal / 32; i++) {
479-
index += _Py_popcount32(state->trampolines.mask[i]);
480-
}
481-
482-
unsigned char *trampoline = state->trampolines.mem + index * TRAMPOLINE_SIZE;
483-
assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
484-
return trampoline;
502+
uint64_t value = (uintptr_t)symbols_map[ordinal];
503+
unsigned char *location = (unsigned char *)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE);
504+
patch_64(location, value);
485505
}
486506

487507
// Generate and patch AArch64 trampolines. The symbols to jump to are stored
@@ -501,8 +521,7 @@ patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
501521
}
502522

503523
// Out of range - need a trampoline
504-
uint32_t *p = (uint32_t *)get_trampoline_slot(ordinal, state);
505-
524+
uint32_t *p = (uint32_t *)get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE);
506525

507526
/* Generate the trampoline
508527
0: 58000048 ldr x8, 8
@@ -532,7 +551,7 @@ patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state)
532551
}
533552

534553
// Out of range - need a trampoline
535-
unsigned char *trampoline = get_trampoline_slot(ordinal, state);
554+
unsigned char *trampoline = get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE);
536555

537556
/* Generate the trampoline (14 bytes, padded to 16):
538557
0: ff 25 00 00 00 00 jmp *(%rip)
@@ -574,21 +593,26 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
574593
code_size += group->code_size;
575594
data_size += group->data_size;
576595
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
596+
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
577597
}
578598
group = &stencil_groups[_FATAL_ERROR];
579599
code_size += group->code_size;
580600
data_size += group->data_size;
581601
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
602+
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
582603
// Calculate the size of the trampolines and GOT entries required by the whole trace
583604
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
584605
state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
585606
}
607+
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.got_symbols.mask); i++) {
608+
state.got_symbols.size += _Py_popcount32(state.got_symbols.mask[i]) * GOT_SLOT_SIZE;
609+
}
586610
// Round up to the nearest page:
587611
size_t page_size = get_page_size();
588612
assert((page_size & (page_size - 1)) == 0);
589613
size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1));
590-
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
591-
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + padding;
614+
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1));
615+
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding;
592616
unsigned char *memory = jit_alloc(total_size);
593617
if (memory == NULL) {
594618
return -1;
@@ -598,6 +622,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
598622
OPT_STAT_ADD(jit_code_size, code_size);
599623
OPT_STAT_ADD(jit_trampoline_size, state.trampolines.size);
600624
OPT_STAT_ADD(jit_data_size, data_size);
625+
OPT_STAT_ADD(jit_got_size, state.got_symbols.size);
601626
OPT_STAT_ADD(jit_padding_size, padding);
602627
OPT_HIST(total_size, trace_total_memory_hist);
603628
// Update the offsets of each instruction:
@@ -608,6 +633,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
608633
unsigned char *code = memory;
609634
state.trampolines.mem = memory + code_size;
610635
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
636+
state.got_symbols.mem = data + data_size;
611637
assert(trace[0].opcode == _START_EXECUTOR || trace[0].opcode == _COLD_EXIT || trace[0].opcode == _COLD_DYNAMIC_EXIT);
612638
for (size_t i = 0; i < length; i++) {
613639
const _PyUOpInstruction *instruction = &trace[i];
@@ -649,19 +675,21 @@ compile_trampoline(void)
649675
code_size += group->code_size;
650676
data_size += group->data_size;
651677
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
678+
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
652679
// Round up to the nearest page:
653680
size_t page_size = get_page_size();
654681
assert((page_size & (page_size - 1)) == 0);
655682
size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1));
656-
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
657-
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + padding;
683+
size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1));
684+
size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding;
658685
unsigned char *memory = jit_alloc(total_size);
659686
if (memory == NULL) {
660687
return NULL;
661688
}
662689
unsigned char *code = memory;
663690
state.trampolines.mem = memory + code_size;
664691
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
692+
state.got_symbols.mem = data + data_size;
665693
// Compile the shim, which handles converting between the native
666694
// calling convention and the calling convention used by jitted code
667695
// (which may be different for efficiency reasons).

Tools/jit/_stencils.py

Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ class HoleValue(enum.Enum):
100100
HoleValue.CODE: "(uintptr_t)code",
101101
HoleValue.DATA: "(uintptr_t)data",
102102
HoleValue.EXECUTOR: "(uintptr_t)executor",
103+
HoleValue.GOT: "",
103104
# These should all have been turned into DATA values by process_relocations:
104-
# HoleValue.GOT: "",
105105
HoleValue.OPARG: "instruction->oparg",
106106
HoleValue.OPERAND0: "instruction->operand0",
107107
HoleValue.OPERAND0_HI: "(instruction->operand0 >> 32)",
@@ -115,6 +115,23 @@ class HoleValue(enum.Enum):
115115
HoleValue.ZERO: "",
116116
}
117117

118+
_AARCH64_GOT_RELOCATIONS = {
119+
"R_AARCH64_ADR_GOT_PAGE",
120+
"R_AARCH64_LD64_GOT_LO12_NC",
121+
"ARM64_RELOC_GOT_LOAD_PAGE21",
122+
"ARM64_RELOC_GOT_LOAD_PAGEOFF12",
123+
"IMAGE_REL_ARM64_PAGEBASE_REL21",
124+
"IMAGE_REL_ARM64_PAGEOFFSET_12L",
125+
"IMAGE_REL_ARM64_PAGEOFFSET_12A",
126+
}
127+
128+
_X86_GOT_RELOCATIONS = {
129+
"R_X86_64_GOTPCRELX",
130+
"R_X86_64_REX_GOTPCRELX",
131+
"X86_64_RELOC_GOT",
132+
"X86_64_RELOC_GOT_LOAD",
133+
}
134+
118135

119136
@dataclasses.dataclass
120137
class Hole:
@@ -133,6 +150,8 @@ class Hole:
133150
# ...plus this addend:
134151
addend: int
135152
need_state: bool = False
153+
custom_location: str = ""
154+
custom_value: str = ""
136155
func: str = dataclasses.field(init=False)
137156
# Convenience method:
138157
replace = dataclasses.replace
@@ -170,16 +189,22 @@ def fold(self, other: typing.Self, body: bytearray) -> typing.Self | None:
170189

171190
def as_c(self, where: str) -> str:
172191
"""Dump this hole as a call to a patch_* function."""
173-
location = f"{where} + {self.offset:#x}"
174-
value = _HOLE_EXPRS[self.value]
175-
if self.symbol:
176-
if value:
177-
value += " + "
178-
value += f"(uintptr_t)&{self.symbol}"
179-
if _signed(self.addend) or not value:
180-
if value:
181-
value += " + "
182-
value += f"{_signed(self.addend):#x}"
192+
if self.custom_location:
193+
location = self.custom_location
194+
else:
195+
location = f"{where} + {self.offset:#x}"
196+
if self.custom_value:
197+
value = self.custom_value
198+
else:
199+
value = _HOLE_EXPRS[self.value]
200+
if self.symbol:
201+
if value:
202+
value += " + "
203+
value += f"(uintptr_t)&{self.symbol}"
204+
if _signed(self.addend) or not value:
205+
if value:
206+
value += " + "
207+
value += f"{_signed(self.addend):#x}"
183208
if self.need_state:
184209
return f"{self.func}({location}, {value}, state);"
185210
return f"{self.func}({location}, {value});"
@@ -219,8 +244,11 @@ class StencilGroup:
219244
symbols: dict[int | str, tuple[HoleValue, int]] = dataclasses.field(
220245
default_factory=dict, init=False
221246
)
222-
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
247+
_jit_symbol_table: dict[str, int] = dataclasses.field(
248+
default_factory=dict, init=False
249+
)
223250
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
251+
_got_entries: set[int] = dataclasses.field(default_factory=set, init=False)
224252

225253
def convert_labels_to_relocations(self) -> None:
226254
for name, hole_plus in self.symbols.items():
@@ -270,13 +298,39 @@ def process_relocations(self, known_symbols: dict[str, int]) -> None:
270298
self._trampolines.add(ordinal)
271299
hole.addend = ordinal
272300
hole.symbol = None
301+
elif (
302+
hole.kind in _AARCH64_GOT_RELOCATIONS | _X86_GOT_RELOCATIONS
303+
and hole.symbol
304+
and "_JIT_" not in hole.symbol
305+
and hole.value is HoleValue.GOT
306+
):
307+
if hole.symbol in known_symbols:
308+
ordinal = known_symbols[hole.symbol]
309+
else:
310+
ordinal = len(known_symbols)
311+
known_symbols[hole.symbol] = ordinal
312+
self._got_entries.add(ordinal)
273313
self.data.pad(8)
274314
for stencil in [self.code, self.data]:
275315
for hole in stencil.holes:
276316
if hole.value is HoleValue.GOT:
277317
assert hole.symbol is not None
278-
hole.value = HoleValue.DATA
279-
hole.addend += self._global_offset_table_lookup(hole.symbol)
318+
if "_JIT_" in hole.symbol:
319+
# Relocations for local symbols
320+
hole.value = HoleValue.DATA
321+
hole.addend += self._jit_symbol_table_lookup(hole.symbol)
322+
else:
323+
_ordinal = known_symbols[hole.symbol]
324+
_custom_value = f"got_symbol_address({_ordinal:#x}, state)"
325+
if hole.kind in _X86_GOT_RELOCATIONS:
326+
# x86 GOT relocations carry a -4 addend for the 32-bit
327+
# RIP-relative displacement; apply it here by subtracting 4
328+
# from the address of the GOT entry
329+
_custom_value = (
330+
f"got_symbol_address({_ordinal:#x}, state) - 4"
331+
)
332+
hole.addend = _ordinal
333+
hole.custom_value = _custom_value
280334
hole.symbol = None
281335
elif hole.symbol in self.symbols:
282336
hole.value, addend = self.symbols[hole.symbol]
@@ -289,16 +343,19 @@ def process_relocations(self, known_symbols: dict[str, int]) -> None:
289343
raise ValueError(
290344
f"Add PyAPI_FUNC(...) or PyAPI_DATA(...) to declaration of {hole.symbol}!"
291345
)
346+
self._emit_jit_symbol_table()
292347
self._emit_global_offset_table()
293348
self.code.holes.sort(key=lambda hole: hole.offset)
294349
self.data.holes.sort(key=lambda hole: hole.offset)
295350

296-
def _global_offset_table_lookup(self, symbol: str) -> int:
297-
return len(self.data.body) + self._got.setdefault(symbol, 8 * len(self._got))
351+
def _jit_symbol_table_lookup(self, symbol: str) -> int:
352+
return len(self.data.body) + self._jit_symbol_table.setdefault(
353+
symbol, 8 * len(self._jit_symbol_table)
354+
)
298355

299-
def _emit_global_offset_table(self) -> None:
356+
def _emit_jit_symbol_table(self) -> None:
300357
got = len(self.data.body)
301-
for s, offset in self._got.items():
358+
for s, offset in self._jit_symbol_table.items():
302359
if s in self.symbols:
303360
value, addend = self.symbols[s]
304361
symbol = None
@@ -322,20 +379,35 @@ def _emit_global_offset_table(self) -> None:
322379
)
323380
self.data.body.extend([0] * 8)
324381

325-
def _get_trampoline_mask(self) -> str:
382+
def _emit_global_offset_table(self) -> None:
    """Add ``patch_got_symbol`` holes to the data stencil.

    For each hole in ``self.code.holes`` whose value is ``HoleValue.GOT``,
    build a hole that reuses its value and addend, uses ``patch_got_symbol``
    as its function and ``state`` as its location, and append it to
    ``self.data.holes`` unless an equal hole is already there.
    """
    got_holes = (h for h in self.code.holes if h.value is HoleValue.GOT)
    for source in got_holes:
        patch = Hole(0, "R_X86_64_64", source.value, None, source.addend)
        patch.func = "patch_got_symbol"
        patch.custom_location = "state"
        if patch in self.data.holes:
            # An equal patch hole was already queued; do not add it twice.
            continue
        self.data.holes.append(patch)
390+
391+
def _get_symbol_mask(self, ordinals: set[int]) -> str:
326392
bitmask: int = 0
327-
trampoline_mask: list[str] = []
328-
for ordinal in self._trampolines:
393+
symbol_mask: list[str] = []
394+
for ordinal in ordinals:
329395
bitmask |= 1 << ordinal
330396
while bitmask:
331397
word = bitmask & ((1 << 32) - 1)
332-
trampoline_mask.append(f"{word:#04x}")
398+
symbol_mask.append(f"{word:#04x}")
333399
bitmask >>= 32
334-
return "{" + (", ".join(trampoline_mask) or "0") + "}"
400+
return "{" + (", ".join(symbol_mask) or "0") + "}"
401+
402+
def _get_trampoline_mask(self) -> str:
403+
return self._get_symbol_mask(self._trampolines)
404+
405+
def _get_got_mask(self) -> str:
406+
return self._get_symbol_mask(self._got_entries)
335407

336408
def as_c(self, opname: str) -> str:
337409
"""Dump this hole as a StencilGroup initializer."""
338-
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
410+
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}, {self._get_got_mask()}}}"
339411

340412

341413
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:

Tools/jit/_writer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _dump_footer(
2020
yield " size_t code_size;"
2121
yield " size_t data_size;"
2222
yield " symbol_mask trampoline_mask;"
23+
yield " symbol_mask got_mask;"
2324
yield "} StencilGroup;"
2425
yield ""
2526
yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"

0 commit comments

Comments
 (0)