kiln_ai.adapters.chat

 1from .chat_formatter import (
 2    BasicChatMessage,
 3    ChatCompletionMessageIncludingLiteLLM,
 4    ChatFormatter,
 5    ChatMessage,
 6    ChatStrategy,
 7    MultiturnFormatter,
 8    ToolCallMessage,
 9    ToolResponseMessage,
10    get_chat_formatter,
11)
12from .chat_utils import build_tool_call_messages
13
14__all__ = [
15    "BasicChatMessage",
16    "ChatCompletionMessageIncludingLiteLLM",
17    "ChatFormatter",
18    "ChatMessage",
19    "ChatStrategy",
20    "MultiturnFormatter",
21    "ToolCallMessage",
22    "ToolResponseMessage",
23    "build_tool_call_messages",
24    "get_chat_formatter",
25]
@dataclass
class BasicChatMessage:
26@dataclass
27class BasicChatMessage:
28    role: Literal["system", "assistant", "user"]
29    content: Optional[str]
BasicChatMessage(role: Literal['system', 'assistant', 'user'], content: Optional[str])
role: Literal['system', 'assistant', 'user']
content: Optional[str]
ChatCompletionMessageIncludingLiteLLM = typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]
class ChatFormatter(abc.ABC):
 88class ChatFormatter(ABC):
 89    def __init__(
 90        self,
 91        system_message: str,
 92        user_input: InputType,
 93        thinking_instructions: str | None = None,
 94    ) -> None:
 95        self.system_message = system_message
 96        self.user_input = user_input
 97        self.thinking_instructions = thinking_instructions
 98        self._messages: List[ChatMessage] = []
 99        self._state = "start"
100        self._intermediate_outputs: Dict[str, str] = {}
101
102    @property
103    def messages(self) -> List[ChatMessage]:
104        return list(self._messages)
105
106    def append_messages(self, messages: Sequence[ChatMessage]) -> None:
107        """Append messages to the internal messages list."""
108        self._messages.extend(messages)
109
110    def message_dicts(self) -> List[dict]:
111        result = []
112        for m in self._messages:
113            msg_dict = {"role": m.role, "content": m.content}
114            if isinstance(m, ToolCallMessage):
115                msg_dict["tool_calls"] = m.tool_calls
116            elif isinstance(m, ToolResponseMessage):
117                msg_dict["tool_call_id"] = m.tool_call_id
118            result.append(msg_dict)
119        return result
120
121    def intermediate_outputs(self) -> Dict[str, str]:
122        """Get the intermediate outputs from the chat formatter."""
123        return self._intermediate_outputs
124
125    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
126        """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation."""
127        return []
128
129    @abstractmethod
130    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
131        """Advance the conversation and return the next messages if any."""
132        raise NotImplementedError

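Abstract base class for chat formatters. It stores the system message, the user input, and optional thinking instructions, and accumulates the formatted messages (exposed via `messages` and `message_dicts()`). Subclasses implement `next_turn` to produce the conversation one turn at a time. (The class derives from `abc.ABC`, the helper class that provides a standard way to create an ABC using inheritance.)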

system_message
user_input
thinking_instructions
messages: List[Union[BasicChatMessage, ToolCallMessage, ToolResponseMessage]]
102    @property
103    def messages(self) -> List[ChatMessage]:
104        return list(self._messages)
def append_messages( self, messages: Sequence[Union[BasicChatMessage, ToolCallMessage, ToolResponseMessage]]) -> None:
106    def append_messages(self, messages: Sequence[ChatMessage]) -> None:
107        """Append messages to the internal messages list."""
108        self._messages.extend(messages)

Append messages to the internal messages list.

def message_dicts(self) -> List[dict]:
110    def message_dicts(self) -> List[dict]:
111        result = []
112        for m in self._messages:
113            msg_dict = {"role": m.role, "content": m.content}
114            if isinstance(m, ToolCallMessage):
115                msg_dict["tool_calls"] = m.tool_calls
116            elif isinstance(m, ToolResponseMessage):
117                msg_dict["tool_call_id"] = m.tool_call_id
118            result.append(msg_dict)
119        return result
def intermediate_outputs(self) -> Dict[str, str]:
121    def intermediate_outputs(self) -> Dict[str, str]:
122        """Get the intermediate outputs from the chat formatter."""
123        return self._intermediate_outputs

Get the intermediate outputs from the chat formatter.

def initial_messages( self) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]:
125    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
126        """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation."""
127        return []

Messages to seed the conversation. Empty for fresh runs; prior trace for continuation.

@abstractmethod
def next_turn( self, previous_output: str | None = None) -> Optional[kiln_ai.adapters.chat.chat_formatter.ChatTurn]:
129    @abstractmethod
130    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
131        """Advance the conversation and return the next messages if any."""
132        raise NotImplementedError

Advance the conversation and return the next messages if any.

ChatMessage = typing.Union[BasicChatMessage, ToolCallMessage, ToolResponseMessage]
class ChatStrategy(builtins.str, enum.Enum):
65class ChatStrategy(str, Enum):
66    """Strategy for how a chat is structured."""
67
68    # Single turn, immediately return the answer
69    single_turn = "final_only"
70    # Two turn, first turn is the thinking, second turn is the answer. Legacy format - used for old fine tunes but not new trains.
71    two_message_cot_legacy = "final_and_intermediate"
72    # Two turn, first turn is the thinking, second turn is the answer. New format - used for new trains.
73    two_message_cot = "two_message_cot"
74    # Single turn, with both the thinking and the answer in the same message, using R1-style thinking format in <think> tags
75    single_turn_r1_thinking = "final_and_intermediate_r1_compatible"

Strategy for how a chat is structured.

single_turn = <ChatStrategy.single_turn: 'final_only'>
two_message_cot_legacy = <ChatStrategy.two_message_cot_legacy: 'final_and_intermediate'>
two_message_cot = <ChatStrategy.two_message_cot: 'two_message_cot'>
single_turn_r1_thinking = <ChatStrategy.single_turn_r1_thinking: 'final_and_intermediate_r1_compatible'>
class MultiturnFormatter(kiln_ai.adapters.chat.ChatFormatter):
275class MultiturnFormatter(ChatFormatter):
276    """
277    Formatter for continuing a multi-turn conversation with prior trace.
278    Takes prior_trace (existing conversation) and appends the new user message.
279    Produces a single turn: the new user message. Tool calls and multi-turn
280    model responses are handled by _run_model_turn's internal loop.
281
282    When user_input is a dict or list with tool_call_id keys, the input is
283    treated as tool call results (role "tool") rather than a user message.
284    This supports resuming after a return_on_tool_call interrupt.
285    """
286
287    def __init__(
288        self,
289        prior_trace: list[ChatCompletionMessageParam],
290        user_input: InputType,
291    ) -> None:
292        super().__init__(
293            system_message="",
294            user_input=user_input,
295            thinking_instructions=None,
296        )
297        self._prior_trace = prior_trace
298
299    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
300        """Messages to seed the conversation (prior trace)."""
301        return list(self._prior_trace)
302
303    @property
304    def _is_tool_result(self) -> bool:
305        """Return True if user_input looks like one or more tool call results."""
306        input = self.user_input
307        if isinstance(input, dict):
308            return "tool_call_id" in input
309        if isinstance(input, list):
310            return bool(input) and all(
311                isinstance(item, dict) and "tool_call_id" in item for item in input
312            )
313        return False
314
315    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
316        if self._state == "start":
317            self._state = "awaiting_final"
318            if self._is_tool_result:
319                if isinstance(self.user_input, dict):
320                    raw_items: list[dict] = [self.user_input]
321                else:
322                    raw_items = list(self.user_input)  # type: ignore[arg-type]
323                msgs: list[ChatMessage] = [
324                    ToolResponseMessage(
325                        role="tool",
326                        content=str(item.get("content", "")),
327                        tool_call_id=item["tool_call_id"],
328                        is_error=item.get("is_error"),
329                        error_message=item.get("error_message"),
330                        kiln_task_tool_data=item.get("kiln_task_tool_data"),
331                    )
332                    for item in raw_items
333                ]
334                self._messages.extend(msgs)
335                return ChatTurn(messages=msgs, final_call=True)
336            else:
337                # prior trace is already in the messages list and contains system and so on, we only need
338                # to append the latest new user message
339                user_msg = BasicChatMessage(
340                    "user", format_user_message(self.user_input)
341                )
342                self._messages.append(user_msg)
343                return ChatTurn(messages=[user_msg], final_call=True)
344
345        if self._state == "awaiting_final":
346            if previous_output is None:
347                raise ValueError("previous_output required for final step")
348            self._messages.append(BasicChatMessage("assistant", previous_output))
349            self._state = "done"
350            return None
351
352        return None

Formatter for continuing a multi-turn conversation from a prior trace. Takes prior_trace (the existing conversation), returns it from initial_messages(), and appends the new user message. Produces a single turn: the new user message. Tool calls and multi-turn model responses are handled by _run_model_turn's internal loop.

When user_input is either a dict containing a tool_call_id key or a non-empty list in which every item is such a dict, the input is treated as tool call results (role "tool") rather than a user message. This supports resuming after a return_on_tool_call interrupt.

MultiturnFormatter( prior_trace: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]], user_input: Union[Dict[str, Any], List[Any], str])
287    def __init__(
288        self,
289        prior_trace: list[ChatCompletionMessageParam],
290        user_input: InputType,
291    ) -> None:
292        super().__init__(
293            system_message="",
294            user_input=user_input,
295            thinking_instructions=None,
296        )
297        self._prior_trace = prior_trace
def initial_messages( self) -> list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam, litellm.types.utils.Message]]:
299    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
300        """Messages to seed the conversation (prior trace)."""
301        return list(self._prior_trace)

Messages to seed the conversation (prior trace).

def next_turn( self, previous_output: str | None = None) -> Optional[kiln_ai.adapters.chat.chat_formatter.ChatTurn]:
315    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
316        if self._state == "start":
317            self._state = "awaiting_final"
318            if self._is_tool_result:
319                if isinstance(self.user_input, dict):
320                    raw_items: list[dict] = [self.user_input]
321                else:
322                    raw_items = list(self.user_input)  # type: ignore[arg-type]
323                msgs: list[ChatMessage] = [
324                    ToolResponseMessage(
325                        role="tool",
326                        content=str(item.get("content", "")),
327                        tool_call_id=item["tool_call_id"],
328                        is_error=item.get("is_error"),
329                        error_message=item.get("error_message"),
330                        kiln_task_tool_data=item.get("kiln_task_tool_data"),
331                    )
332                    for item in raw_items
333                ]
334                self._messages.extend(msgs)
335                return ChatTurn(messages=msgs, final_call=True)
336            else:
337                # prior trace is already in the messages list and contains system and so on, we only need
338                # to append the latest new user message
339                user_msg = BasicChatMessage(
340                    "user", format_user_message(self.user_input)
341                )
342                self._messages.append(user_msg)
343                return ChatTurn(messages=[user_msg], final_call=True)
344
345        if self._state == "awaiting_final":
346            if previous_output is None:
347                raise ValueError("previous_output required for final step")
348            self._messages.append(BasicChatMessage("assistant", previous_output))
349            self._state = "done"
350            return None
351
352        return None

Advance the conversation and return the next messages if any.

@dataclass
class ToolCallMessage:
32@dataclass
33class ToolCallMessage:
34    """Assistant message with tool calls for chat formatting"""
35
36    role: Literal["assistant"]
37    tool_calls: List[ChatCompletionMessageToolCallParam]
38    content: Optional[str] = None

Assistant message with tool calls for chat formatting

ToolCallMessage( role: Literal['assistant'], tool_calls: List[openai.types.chat.chat_completion_message_function_tool_call_param.ChatCompletionMessageFunctionToolCallParam], content: Optional[str] = None)
role: Literal['assistant']
tool_calls: List[openai.types.chat.chat_completion_message_function_tool_call_param.ChatCompletionMessageFunctionToolCallParam]
content: Optional[str] = None
@dataclass
class ToolResponseMessage:
41@dataclass
42class ToolResponseMessage:
43    """Tool response message for chat formatting"""
44
45    role: Literal["tool"]
46    content: str
47    tool_call_id: str
48    is_error: Optional[bool] = None
49    error_message: Optional[str] = None
50    kiln_task_tool_data: Optional[str] = None

Tool response message for chat formatting

ToolResponseMessage( role: Literal['tool'], content: str, tool_call_id: str, is_error: Optional[bool] = None, error_message: Optional[str] = None, kiln_task_tool_data: Optional[str] = None)
role: Literal['tool']
content: str
tool_call_id: str
is_error: Optional[bool] = None
error_message: Optional[str] = None
kiln_task_tool_data: Optional[str] = None
def build_tool_call_messages( trace: list[typing.Union[openai.types.chat.chat_completion_developer_message_param.ChatCompletionDeveloperMessageParam, openai.types.chat.chat_completion_system_message_param.ChatCompletionSystemMessageParam, openai.types.chat.chat_completion_user_message_param.ChatCompletionUserMessageParam, kiln_ai.utils.open_ai_types.ChatCompletionAssistantMessageParamWrapper, kiln_ai.utils.open_ai_types.ChatCompletionToolMessageParamWrapper, openai.types.chat.chat_completion_function_message_param.ChatCompletionFunctionMessageParam]] | None) -> list[typing.Union[ToolCallMessage, ToolResponseMessage]]:
11def build_tool_call_messages(
12    trace: list[ChatCompletionMessageParam] | None,
13) -> list[Union[ToolCallMessage, ToolResponseMessage]]:
14    """
15    Extract tool call and tool response messages from a trace. It's based off the OpenAI schema.
16
17    Args:
18        trace: The trace of the task run in OpenAI format
19
20    Returns:
21        List of ToolCallMessage and ToolResponseMessage objects extracted from the trace
22    """
23    if trace is None:
24        return []
25
26    tool_messages: list[Union[ToolCallMessage, ToolResponseMessage]] = []
27
28    for message in trace:
29        role = message.get("role")
30
31        if role == "assistant" and "tool_calls" in message:
32            tool_calls = message.get("tool_calls")
33            if tool_calls:
34                content = message.get("content")
35                tool_messages.append(
36                    ToolCallMessage(
37                        role="assistant",
38                        tool_calls=tool_calls,
39                        content=extract_text_from_content(content),
40                    )
41                )
42        elif role == "tool":
43            content = message.get("content")
44            tool_call_id = message.get("tool_call_id")
45
46            if tool_call_id is None:
47                raise ValueError("Tool call ID is required for tool response messages")
48            if content is None:
49                raise ValueError("Content is required for tool response messages")
50
51            if not isinstance(content, str):
52                content = str(content)
53
54            tool_messages.append(
55                ToolResponseMessage(
56                    role="tool",
57                    content=content,
58                    tool_call_id=tool_call_id,
59                )
60            )
61
62    return tool_messages

Extract tool call and tool response messages from a trace. It is based on the OpenAI schema.

Args: `trace` — the trace of the task run in OpenAI format; passing `None` returns an empty list.

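Returns: a list, in trace order, of `ToolCallMessage` objects (one per assistant message with non-empty `tool_calls`) and `ToolResponseMessage` objects (one per `tool` message) extracted from the trace.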

def get_chat_formatter( strategy: ChatStrategy, system_message: str, user_input: Union[Dict[str, Any], List[Any], str], thinking_instructions: str | None = None) -> ChatFormatter:
355def get_chat_formatter(
356    strategy: ChatStrategy,
357    system_message: str,
358    user_input: InputType,
359    thinking_instructions: str | None = None,
360) -> ChatFormatter:
361    match strategy:
362        case ChatStrategy.single_turn:
363            return SingleTurnFormatter(system_message, user_input)
364        case ChatStrategy.two_message_cot_legacy:
365            return TwoMessageCotLegacyFormatter(
366                system_message, user_input, thinking_instructions
367            )
368        case ChatStrategy.two_message_cot:
369            return TwoMessageCotFormatter(
370                system_message, user_input, thinking_instructions
371            )
372        case ChatStrategy.single_turn_r1_thinking:
373            return SingleTurnR1ThinkingFormatter(system_message, user_input)
374        case _:
375            raise_exhaustive_enum_error(strategy)