Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions lotus/models/cross_encoder_reranker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from sentence_transformers import CrossEncoder

from lotus.models.reranker import Reranker
from lotus.types import RerankerOutput
Expand All @@ -20,7 +19,19 @@ def __init__(
max_batch_size: int = 64,
):
self.max_batch_size: int = max_batch_size
self.model = CrossEncoder(model, device=device) # type: ignore # CrossEncoder has wrong type stubs
self._model_name = model
self._device = device
self._model = None # Initialize model as None for lazy loading

@property
def model(self):
    """Return the cross-encoder, building it the first time it is requested.

    The instance is constructed from ``self._model_name`` and
    ``self._device`` and then stored in ``self._model`` so later accesses
    reuse the same object.
    """
    loaded = self._model
    if loaded is not None:
        # Already constructed on an earlier access; reuse it.
        return loaded

    # Import deferred to this point so that sentence_transformers is only
    # imported when the model is actually needed.
    from sentence_transformers import CrossEncoder

    loaded = CrossEncoder(self._model_name, device=self._device)  # type: ignore # CrossEncoder has wrong type stubs
    self._model = loaded
    return loaded

def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput:
results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size, show_progress_bar=False)
Expand Down
Loading