diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 1024fb85b..ac8c02513 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -24,6 +24,8 @@ ) import jinja2 +from jinja2 import nodes +from jinja2.ext import Extension from jinja2.sandbox import ImmutableSandboxedEnvironment import numpy as np @@ -191,6 +193,16 @@ def __call__( ) -> ChatFormatterResponse: ... +class _GenerationTagIgnore(Extension): + """Pass-through for HuggingFace's ``{% generation %}`` chat-template tag.""" + + tags = {"generation"} + + def parse(self, parser: jinja2.parser.Parser) -> List[nodes.Node]: + parser.stream.skip(1) # discard the 'generation' tag-name token + return parser.parse_statements(("name:endgeneration",), drop_needle=True) + + class Jinja2ChatFormatter(ChatFormatter): def __init__( self, @@ -213,6 +225,7 @@ def __init__( loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True, + extensions=[_GenerationTagIgnore], ).from_string(self.template) @staticmethod diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index 18c7279cf..16852a472 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -92,3 +92,26 @@ def test_hf_tokenizer_config_str_to_chat_formatter(): ) assert chat_formatter_respoonse.prompt == ("[INST] Hello, world! [/INST]") + + +def test_generation_tag_is_ignored() -> None: + """HuggingFace chat templates use {% generation %}/{% endgeneration %} to + mark training-time loss spans. At inference the tags must be no-ops or + affected GGUFs (SmolLM3 and similar) fail to load with TemplateSyntaxError. + """ + template = ( + "{% for message in messages %}" + "{% generation %}{{ message['role'] }}: {{ message['content'] }}{% endgeneration %}" + "{% endfor %}" + ) + chat_formatter = llama_chat_format.Jinja2ChatFormatter( + template=template, + eos_token="", + bos_token="", + ) + response = chat_formatter( + messages=[ + ChatCompletionRequestUserMessage(role="user", content="hi"), + ] + ) + assert "user: hi" in response.prompt