diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 11fe169cf..5899c3687 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -887,12 +887,9 @@ def generate(
 
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
-            longest_prefix = 0
-            for a, b in zip(self._input_ids, tokens[:-1]):
-                if a == b:
-                    longest_prefix += 1
-                else:
-                    break
+            longest_prefix = Llama.longest_token_prefix(
+                self._input_ids, tokens[:-1]
+            )
             if longest_prefix > 0:
                 if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
                     reset = False
@@ -1313,10 +1310,10 @@ def logit_bias_processor(
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.input_ids.tolist(), prompt_tokens
+                    cache_item.input_ids, prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self._input_ids.tolist(), prompt_tokens
+                    self._input_ids, prompt_tokens
                 )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
@@ -2251,14 +2248,16 @@ def logits_to_logprobs(
         return subtract_maxs - out
 
     @staticmethod
-    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
-        longest_prefix = 0
-        for _a, _b in zip(a, b):
-            if _a == _b:
-                longest_prefix += 1
-            else:
-                break
-        return longest_prefix
+    def longest_token_prefix(
+        a: Union[Sequence[int], npt.NDArray[np.intc]],
+        b: Union[Sequence[int], npt.NDArray[np.intc]],
+    ) -> int:
+        n = min(len(a), len(b))
+        if n == 0:
+            return 0
+        eq = np.asarray(a[:n]) == np.asarray(b[:n])
+        mismatch = np.argmin(eq)
+        return int(n) if eq[mismatch] else int(mismatch)
 
     @classmethod
     def from_pretrained(