Skip to content

Commit

Permalink
Lower backtrack_limit to fail earlier for invalid input
Browse files Browse the repository at this point in the history
  • Loading branch information
Lőrinc committed Feb 12, 2024
1 parent 21c5688 commit 019de85
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
16 changes: 15 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::num::NonZeroU64;
use std::thread;

use fancy_regex::Regex;
use fancy_regex::RegexBuilder;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pyclass;
Expand Down Expand Up @@ -417,7 +418,7 @@ impl CoreBPE {
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> PyResult<Self> {
let regex = Regex::new(pattern)
let regex = RegexBuilder::new(pattern).backtrack_limit(100_000).build()
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

let special_regex = {
Expand Down Expand Up @@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {

#[cfg(test)]
mod tests {
use fancy_regex::RegexBuilder;
use rustc_hash::FxHashMap as HashMap;

use crate::{byte_pair_split, Rank};
Expand All @@ -596,4 +598,16 @@ mod tests {
let res = byte_pair_split(b"abab", &ranks);
assert_eq!(res, vec![b"ab", b"ab"]);
}

/// Checks that a regex built with a very small `backtrack_limit` reports an
/// error from `is_match` rather than a plain match/no-match answer.
///
/// The pattern `(a|b|ab)*(?=c)` offers several ways to consume each "ab"
/// pair, and the input repeats that pair 100 times before the final "c".
/// NOTE(review): the look-ahead presumably routes matching through
/// fancy_regex's backtracking engine, which is what makes the limit of 10
/// relevant here — confirm against the fancy_regex documentation.
#[test]
fn test_effect_of_backtrack_limit() {
    // `build()` already yields an owned `Regex`; no clone of the temporary is needed.
    let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
        .backtrack_limit(10)
        .build()
        .expect("Failed to build regex");

    let input = "ab".repeat(100) + "c";
    assert!(regex.is_match(&input).is_err(), "Should throw");
}
}
3 changes: 1 addition & 2 deletions tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES


@pytest.mark.skip(reason="Takes a really long time to finish, but was added to reproduce a crash.")
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
enc = make_enc()
for c in ["^", "0", "a", "'s"]: # TODO " ", "\n" are still failing
print(f"Validating `{c}`")

big_value = c * 1_000_000
big_value = c * 100_000
assert big_value == enc.decode(enc.encode(big_value))

big_value = " " + big_value
Expand Down

0 comments on commit 019de85

Please sign in to comment.