Skip to content

Commit 88350f8

Browse files
authored
fix: custom logits processor (#489)
Signed-off-by: Wallas Santos <[email protected]>
1 parent 440cc7a commit 88350f8

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/e2e/test_logits_processors.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010

1111
def test_custom_logits_processor(model: ModelInfo, backend, monkeypatch,
12-
warmup_shapes, cb):
12+
max_num_seqs, max_model_len, warmup_shapes,
13+
cb):
1314
'''
1415
Simple test to check if custom logits processors are being registered
1516
'''
@@ -41,11 +42,12 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
4142

4243
spyre_model = LLM(model=model.name,
4344
revision=model.revision,
44-
max_model_len=128,
45+
max_model_len=max_model_len,
46+
max_num_seqs=max_num_seqs,
4547
logits_processors=[DummyLogitsProcessor])
4648
prompt = "Hello Logits Processors"
4749
params = SamplingParams(max_tokens=5, temperature=0, logprobs=0)
4850

49-
spyre_model.generate(prompt, params)[0]
51+
spyre_model.generate(prompt, params)
5052

5153
assert has_invoked_logits_processor

0 commit comments

Comments
 (0)