kiln_ai.adapters.chat

from .chat_formatter import (
    BasicChatMessage,
    ChatFormatter,
    ChatMessage,
    ChatStrategy,
    ToolCallMessage,
    ToolResponseMessage,
    get_chat_formatter,
)
from .chat_utils import build_tool_call_messages

__all__ = [
    "BasicChatMessage",
    "ChatFormatter",
    "ChatMessage",
    "ChatStrategy",
    "ToolCallMessage",
    "ToolResponseMessage",
    "build_tool_call_messages",
    "get_chat_formatter",
]
@dataclass
class BasicChatMessage:
@dataclass
class BasicChatMessage:
    role: Literal["system", "assistant", "user"]
    content: Optional[str]
Basic chat message with a role (system, assistant, or user) and optional text content, used for chat formatting.

BasicChatMessage(role: Literal['system', 'assistant', 'user'], content: Optional[str])
role: Literal['system', 'assistant', 'user']
content: Optional[str]
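
A minimal sketch of constructing BasicChatMessage values directly; the role strings are the literal values the dataclass accepts, and content may be None.

from kiln_ai.adapters.chat import BasicChatMessage

# A system prompt and a user message for a single-turn chat.
system_msg = BasicChatMessage(role="system", content="You are a helpful assistant.")
user_msg = BasicChatMessage(role="user", content="Summarise this document.")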
class ChatFormatter(abc.ABC):
class ChatFormatter(ABC):
    def __init__(
        self,
        system_message: str,
        user_input: InputType,
        thinking_instructions: str | None = None,
    ) -> None:
        self.system_message = system_message
        self.user_input = user_input
        self.thinking_instructions = thinking_instructions
        self._messages: List[ChatMessage] = []
        self._state = "start"
        self._intermediate_outputs: Dict[str, str] = {}

    @property
    def messages(self) -> List[ChatMessage]:
        return list(self._messages)

    def append_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Append messages to the internal messages list."""
        self._messages.extend(messages)

    def message_dicts(self) -> List[dict]:
        result = []
        for m in self._messages:
            msg_dict = {"role": m.role, "content": m.content}
            if isinstance(m, ToolCallMessage):
                msg_dict["tool_calls"] = m.tool_calls
            elif isinstance(m, ToolResponseMessage):
                msg_dict["tool_call_id"] = m.tool_call_id
            result.append(msg_dict)
        return result

    def intermediate_outputs(self) -> Dict[str, str]:
        """Get the intermediate outputs from the chat formatter."""
        return self._intermediate_outputs

    @abstractmethod
    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        """Advance the conversation and return the next messages if any."""
        raise NotImplementedError

Abstract base class for chat formatters. A formatter holds the system message, user input, and optional thinking instructions, and builds the conversation turn by turn via next_turn().

system_message
user_input
thinking_instructions
messages: List[Union[BasicChatMessage, ToolCallMessage, ToolResponseMessage]]
def append_messages(self, messages: Sequence[Union[BasicChatMessage, ToolCallMessage, ToolResponseMessage]]) -> None:

Append messages to the internal messages list.

def message_dicts(self) -> List[dict]:

Return the messages as plain dicts (role and content), adding tool_calls for assistant tool-call messages and tool_call_id for tool responses.

def intermediate_outputs(self) -> Dict[str, str]:

Get the intermediate outputs from the chat formatter.

@abstractmethod
def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:

Advance the conversation and return the next messages if any.
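
Typical usage is to obtain a concrete formatter via get_chat_formatter (documented below) and call next_turn repeatedly, feeding each model output back in. A minimal sketch, assuming the conversation ends when next_turn returns None and that call_model is a hypothetical stand-in for your own model client:

from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter

formatter = get_chat_formatter(
    strategy=ChatStrategy.single_turn,
    system_message="You are a helpful assistant.",
    user_input="Summarise the release notes.",
)

turn = formatter.next_turn()  # first turn: no previous output yet
while turn is not None:
    # Send the accumulated conversation to the model; call_model is hypothetical.
    model_output = call_model(formatter.message_dicts())
    turn = formatter.next_turn(model_output)

print(formatter.message_dicts())         # the full conversation as dicts
print(formatter.intermediate_outputs())  # e.g. chain-of-thought, if the strategy produced any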

ChatMessage = typing.Union[BasicChatMessage, ToolCallMessage, ToolResponseMessage]
class ChatStrategy(builtins.str, enum.Enum):
class ChatStrategy(str, Enum):
    """Strategy for how a chat is structured."""

    # Single turn: return the final answer immediately
    single_turn = "final_only"
    # Two messages: the first is the thinking, the second is the answer. Legacy format, kept for existing fine-tunes; not used for new training runs.
    two_message_cot_legacy = "final_and_intermediate"
    # Two messages: the first is the thinking, the second is the answer. New format, used for new training runs.
    two_message_cot = "two_message_cot"
    # Single turn with both the thinking and the answer in one message, using R1-style <think> tags
    single_turn_r1_thinking = "final_and_intermediate_r1_compatible"

Strategy for how a chat is structured.

single_turn = <ChatStrategy.single_turn: 'final_only'>
two_message_cot_legacy = <ChatStrategy.two_message_cot_legacy: 'final_and_intermediate'>
two_message_cot = <ChatStrategy.two_message_cot: 'two_message_cot'>
single_turn_r1_thinking = <ChatStrategy.single_turn_r1_thinking: 'final_and_intermediate_r1_compatible'>
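
Because ChatStrategy is a str-backed Enum, its stored string values round-trip back to members, which is useful when loading persisted configuration:

from kiln_ai.adapters.chat import ChatStrategy

assert ChatStrategy("final_only") is ChatStrategy.single_turn
assert ChatStrategy.two_message_cot.value == "two_message_cot"
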
@dataclass
class ToolCallMessage:
@dataclass
class ToolCallMessage:
    """Assistant message with tool calls for chat formatting"""

    role: Literal["assistant"]
    tool_calls: List[ChatCompletionMessageToolCallParam]
    content: Optional[str] = None

Assistant message with tool calls for chat formatting

ToolCallMessage(role: Literal['assistant'], tool_calls: List[ChatCompletionMessageFunctionToolCallParam], content: Optional[str] = None)
role: Literal['assistant']
tool_calls: List[ChatCompletionMessageFunctionToolCallParam]
content: Optional[str] = None
@dataclass
class ToolResponseMessage:
@dataclass
class ToolResponseMessage:
    """Tool response message for chat formatting"""

    role: Literal["tool"]
    content: str
    tool_call_id: str

Tool response message for chat formatting

ToolResponseMessage(role: Literal['tool'], content: str, tool_call_id: str)
role: Literal['tool']
content: str
tool_call_id: str
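
A sketch of building a matched tool call and tool response by hand. The tool_calls entry follows the OpenAI function tool call param shape; the id, tool name, and arguments below are illustrative.

from kiln_ai.adapters.chat import ToolCallMessage, ToolResponseMessage

call = ToolCallMessage(
    role="assistant",
    tool_calls=[
        {
            "id": "call_abc123",  # illustrative id
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
    content=None,
)
response = ToolResponseMessage(
    role="tool",
    content='{"temp_c": 18}',
    tool_call_id="call_abc123",  # must match the id on the corresponding call
)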
def build_tool_call_messages(trace: list[ChatCompletionMessageParam] | None) -> list[Union[ToolCallMessage, ToolResponseMessage]]:
def build_tool_call_messages(
    trace: list[ChatCompletionMessageParam] | None,
) -> list[Union[ToolCallMessage, ToolResponseMessage]]:
    """
    Extract tool call and tool response messages from a trace, based on the OpenAI schema.

    Args:
        trace: The trace of the task run in OpenAI format

    Returns:
        List of ToolCallMessage and ToolResponseMessage objects extracted from the trace
    """
    if trace is None:
        return []

    tool_messages: list[Union[ToolCallMessage, ToolResponseMessage]] = []

    for message in trace:
        role = message.get("role")

        if role == "assistant" and "tool_calls" in message:
            tool_calls = message.get("tool_calls")
            if tool_calls:
                content = message.get("content")
                tool_messages.append(
                    ToolCallMessage(
                        role="assistant",
                        tool_calls=tool_calls,
                        content=extract_text_from_content(content),
                    )
                )
        elif role == "tool":
            content = message.get("content")
            tool_call_id = message.get("tool_call_id")

            if tool_call_id is None:
                raise ValueError("Tool call ID is required for tool response messages")
            if content is None:
                raise ValueError("Content is required for tool response messages")

            if not isinstance(content, str):
                content = str(content)

            tool_messages.append(
                ToolResponseMessage(
                    role="tool",
                    content=content,
                    tool_call_id=tool_call_id,
                )
            )

    return tool_messages

Extract tool call and tool response messages from a trace, based on the OpenAI schema.

Args:
    trace: The trace of the task run in OpenAI format

Returns:
    List of ToolCallMessage and ToolResponseMessage objects extracted from the trace
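
For example, given a short trace in OpenAI chat-completion format (the tool name, arguments, and id below are illustrative), only the tool-related messages are returned:

from kiln_ai.adapters.chat import build_tool_call_messages

trace = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "content": "Let me check.",
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_abc123", "content": '{"temp_c": 18}'},
]

messages = build_tool_call_messages(trace)
# -> [ToolCallMessage(...), ToolResponseMessage(...)]; the plain user message is skipped.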

def get_chat_formatter(strategy: ChatStrategy, system_message: str, user_input: Union[Dict[str, Any], List[Any], str], thinking_instructions: str | None = None) -> ChatFormatter:
def get_chat_formatter(
    strategy: ChatStrategy,
    system_message: str,
    user_input: InputType,
    thinking_instructions: str | None = None,
) -> ChatFormatter:
    match strategy:
        case ChatStrategy.single_turn:
            return SingleTurnFormatter(system_message, user_input)
        case ChatStrategy.two_message_cot_legacy:
            return TwoMessageCotLegacyFormatter(
                system_message, user_input, thinking_instructions
            )
        case ChatStrategy.two_message_cot:
            return TwoMessageCotFormatter(
                system_message, user_input, thinking_instructions
            )
        case ChatStrategy.single_turn_r1_thinking:
            return SingleTurnR1ThinkingFormatter(system_message, user_input)
        case _:
            raise_exhaustive_enum_error(strategy)
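
A sketch of selecting a chain-of-thought strategy. As the source above shows, thinking_instructions is only passed to the two-message CoT formatters; the user_input value here is illustrative structured input.

from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter

formatter = get_chat_formatter(
    strategy=ChatStrategy.two_message_cot,
    system_message="You are a careful reviewer.",
    user_input={"document": "report text here"},  # dicts, lists, and strings are accepted
    thinking_instructions="Think step by step before answering.",
)

# Drive the conversation with next_turn() as shown in the ChatFormatter example above;
# the thinking turn is expected to appear in formatter.intermediate_outputs() (assumption).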