kiln_ai.adapters.chat
from .chat_formatter import (
    BasicChatMessage,
    ChatCompletionMessageIncludingLiteLLM,
    ChatFormatter,
    ChatMessage,
    ChatStrategy,
    MultiturnFormatter,
    ToolCallMessage,
    ToolResponseMessage,
    get_chat_formatter,
)
from .chat_utils import build_tool_call_messages

__all__ = [
    "BasicChatMessage",
    "ChatCompletionMessageIncludingLiteLLM",
    "ChatFormatter",
    "ChatMessage",
    "ChatStrategy",
    "MultiturnFormatter",
    "ToolCallMessage",
    "ToolResponseMessage",
    "build_tool_call_messages",
    "get_chat_formatter",
]
class ChatFormatter(ABC):
    def __init__(
        self,
        system_message: str,
        user_input: InputType,
        thinking_instructions: str | None = None,
    ) -> None:
        self.system_message = system_message
        self.user_input = user_input
        self.thinking_instructions = thinking_instructions
        self._messages: List[ChatMessage] = []
        self._state = "start"
        self._intermediate_outputs: Dict[str, str] = {}

    @property
    def messages(self) -> List[ChatMessage]:
        return list(self._messages)

    def append_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Append messages to the internal messages list."""
        self._messages.extend(messages)

    def message_dicts(self) -> List[dict]:
        result = []
        for m in self._messages:
            msg_dict = {"role": m.role, "content": m.content}
            if isinstance(m, ToolCallMessage):
                msg_dict["tool_calls"] = m.tool_calls
            elif isinstance(m, ToolResponseMessage):
                msg_dict["tool_call_id"] = m.tool_call_id
            result.append(msg_dict)
        return result

    def intermediate_outputs(self) -> Dict[str, str]:
        """Get the intermediate outputs from the chat formatter."""
        return self._intermediate_outputs

    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
        """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation."""
        return []

    @abstractmethod
    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        """Advance the conversation and return the next messages if any."""
        raise NotImplementedError
Abstract base class for chat formatters. Stores the system message, user input, and optional thinking instructions, accumulates the conversation as a list of ChatMessage objects, and leaves next_turn for subclasses to implement.
def append_messages(self, messages: Sequence[ChatMessage]) -> None:
    """Append messages to the internal messages list."""
    self._messages.extend(messages)
Append messages to the internal messages list.
def message_dicts(self) -> List[dict]:
    result = []
    for m in self._messages:
        msg_dict = {"role": m.role, "content": m.content}
        if isinstance(m, ToolCallMessage):
            msg_dict["tool_calls"] = m.tool_calls
        elif isinstance(m, ToolResponseMessage):
            msg_dict["tool_call_id"] = m.tool_call_id
        result.append(msg_dict)
    return result
def intermediate_outputs(self) -> Dict[str, str]:
    """Get the intermediate outputs from the chat formatter."""
    return self._intermediate_outputs
Get the intermediate outputs from the chat formatter.
def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
    """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation."""
    return []
Messages to seed the conversation. Empty for fresh runs; prior trace for continuation.
@abstractmethod
def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
    """Advance the conversation and return the next messages if any."""
    raise NotImplementedError
Advance the conversation and return the next messages if any.
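To show how the abstract API fits together, here is a minimal sketch of a concrete subclass, modeled on the MultiturnFormatter source later on this page. It assumes BasicChatMessage accepts a "system" role and that ChatTurn takes messages and final_call keyword arguments, as in that source; it is not one of the library's real formatters.

class EchoFormatter(ChatFormatter):
    """Hypothetical single-turn formatter, for illustration only."""

    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "start":
            # First call: emit the system and user messages and request the final answer.
            self._state = "awaiting_final"
            msgs = [
                BasicChatMessage("system", self.system_message),
                BasicChatMessage("user", format_user_message(self.user_input)),
            ]
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=True)
        if self._state == "awaiting_final":
            # Second call: record the model's answer and finish.
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(BasicChatMessage("assistant", previous_output))
            self._state = "done"
        return None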
class ChatStrategy(str, Enum):
    """Strategy for how a chat is structured."""

    # Single turn, immediately return the answer
    single_turn = "final_only"
    # Two turn, first turn is the thinking, second turn is the answer. Legacy format - used for old fine tunes but not new trains.
    two_message_cot_legacy = "final_and_intermediate"
    # Two turn, first turn is the thinking, second turn is the answer. New format - used for new trains.
    two_message_cot = "two_message_cot"
    # Single turn, with both the thinking and the answer in the same message, using R1-style thinking format in <think> tags
    single_turn_r1_thinking = "final_and_intermediate_r1_compatible"
Strategy for how a chat is structured.
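Because ChatStrategy subclasses str, members compare equal to their serialized values, and several of those values differ from the member names. A quick sketch of round-tripping the enum:

# Members compare equal to their serialized string values.
assert ChatStrategy.single_turn == "final_only"
assert ChatStrategy("final_and_intermediate") is ChatStrategy.two_message_cot_legacy

# Useful when reading a strategy back from stored config or a fine-tune record.
strategy = ChatStrategy("two_message_cot")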
class MultiturnFormatter(ChatFormatter):
    """
    Formatter for continuing a multi-turn conversation with prior trace.
    Takes prior_trace (existing conversation) and appends the new user message.
    Produces a single turn: the new user message. Tool calls and multi-turn
    model responses are handled by _run_model_turn's internal loop.

    When user_input is a dict or list with tool_call_id keys, the input is
    treated as tool call results (role "tool") rather than a user message.
    This supports resuming after a return_on_tool_call interrupt.
    """

    def __init__(
        self,
        prior_trace: list[ChatCompletionMessageParam],
        user_input: InputType,
    ) -> None:
        super().__init__(
            system_message="",
            user_input=user_input,
            thinking_instructions=None,
        )
        self._prior_trace = prior_trace

    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
        """Messages to seed the conversation (prior trace)."""
        return list(self._prior_trace)

    @property
    def _is_tool_result(self) -> bool:
        """Return True if user_input looks like one or more tool call results."""
        input = self.user_input
        if isinstance(input, dict):
            return "tool_call_id" in input
        if isinstance(input, list):
            return bool(input) and all(
                isinstance(item, dict) and "tool_call_id" in item for item in input
            )
        return False

    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "start":
            self._state = "awaiting_final"
            if self._is_tool_result:
                if isinstance(self.user_input, dict):
                    raw_items: list[dict] = [self.user_input]
                else:
                    raw_items = list(self.user_input)  # type: ignore[arg-type]
                msgs: list[ChatMessage] = [
                    ToolResponseMessage(
                        role="tool",
                        content=str(item.get("content", "")),
                        tool_call_id=item["tool_call_id"],
                        is_error=item.get("is_error"),
                        error_message=item.get("error_message"),
                        kiln_task_tool_data=item.get("kiln_task_tool_data"),
                    )
                    for item in raw_items
                ]
                self._messages.extend(msgs)
                return ChatTurn(messages=msgs, final_call=True)
            else:
                # prior trace is already in the messages list and contains system and so on, we only need
                # to append the latest new user message
                user_msg = BasicChatMessage(
                    "user", format_user_message(self.user_input)
                )
                self._messages.append(user_msg)
                return ChatTurn(messages=[user_msg], final_call=True)

        if self._state == "awaiting_final":
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(BasicChatMessage("assistant", previous_output))
            self._state = "done"
            return None

        return None
Formatter for continuing a multi-turn conversation with prior trace. Takes prior_trace (existing conversation) and appends the new user message. Produces a single turn: the new user message. Tool calls and multi-turn model responses are handled by _run_model_turn's internal loop.
When user_input is a dict or list with tool_call_id keys, the input is treated as tool call results (role "tool") rather than a user message. This supports resuming after a return_on_tool_call interrupt.
def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
    """Messages to seed the conversation (prior trace)."""
    return list(self._prior_trace)
Messages to seed the conversation (prior trace).
def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
    if self._state == "start":
        self._state = "awaiting_final"
        if self._is_tool_result:
            if isinstance(self.user_input, dict):
                raw_items: list[dict] = [self.user_input]
            else:
                raw_items = list(self.user_input)  # type: ignore[arg-type]
            msgs: list[ChatMessage] = [
                ToolResponseMessage(
                    role="tool",
                    content=str(item.get("content", "")),
                    tool_call_id=item["tool_call_id"],
                    is_error=item.get("is_error"),
                    error_message=item.get("error_message"),
                    kiln_task_tool_data=item.get("kiln_task_tool_data"),
                )
                for item in raw_items
            ]
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=True)
        else:
            # prior trace is already in the messages list and contains system and so on, we only need
            # to append the latest new user message
            user_msg = BasicChatMessage(
                "user", format_user_message(self.user_input)
            )
            self._messages.append(user_msg)
            return ChatTurn(messages=[user_msg], final_call=True)

    if self._state == "awaiting_final":
        if previous_output is None:
            raise ValueError("previous_output required for final step")
        self._messages.append(BasicChatMessage("assistant", previous_output))
        self._state = "done"
        return None

    return None
Advance the conversation and return the next messages if any.
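A sketch of driving a MultiturnFormatter by hand, assuming a prior trace in OpenAI chat-completions format; run_model stands in for whatever actually invokes the model and is a hypothetical placeholder, not part of this module.

# Prior conversation in OpenAI chat-completions format.
prior_trace = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="And of Spain?")

# Seed with the prior trace, then ask for the single new turn (the new user message).
messages = list(formatter.initial_messages())
turn = formatter.next_turn()
assert turn is not None and turn.final_call

# The caller invokes the model; run_model is a made-up placeholder.
messages += [{"role": m.role, "content": m.content} for m in turn.messages]
output = run_model(messages)

# Feeding the output back records the assistant message and ends the conversation.
assert formatter.next_turn(previous_output=output) is None

To resume after a return_on_tool_call interrupt, user_input would instead be a dict (or list of dicts) carrying tool_call_id and content keys; next_turn then emits ToolResponseMessage entries with role "tool" rather than a user message.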
@dataclass
class ToolCallMessage:
    """Assistant message with tool calls for chat formatting"""

    role: Literal["assistant"]
    tool_calls: List[ChatCompletionMessageToolCallParam]
    content: Optional[str] = None
Assistant message with tool calls for chat formatting
@dataclass
class ToolResponseMessage:
    """Tool response message for chat formatting"""

    role: Literal["tool"]
    content: str
    tool_call_id: str
    is_error: Optional[bool] = None
    error_message: Optional[str] = None
    kiln_task_tool_data: Optional[str] = None
Tool response message for chat formatting
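For illustration, a matched tool call and tool response pair; the tool_calls entry follows the OpenAI function tool-call schema that ChatCompletionMessageToolCallParam describes, and the ids and payloads are made up.

# An assistant tool call and the tool's result, as they would appear in a formatted chat.
call = ToolCallMessage(
    role="assistant",
    content=None,
    tool_calls=[
        {
            "id": "call_123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
)

response = ToolResponseMessage(
    role="tool",
    content='{"temperature_c": 21}',
    tool_call_id="call_123",  # must match the id of the originating call
)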
def build_tool_call_messages(
    trace: list[ChatCompletionMessageParam] | None,
) -> list[Union[ToolCallMessage, ToolResponseMessage]]:
    """
    Extract tool call and tool response messages from a trace. It's based off the OpenAI schema.

    Args:
        trace: The trace of the task run in OpenAI format

    Returns:
        List of ToolCallMessage and ToolResponseMessage objects extracted from the trace
    """
    if trace is None:
        return []

    tool_messages: list[Union[ToolCallMessage, ToolResponseMessage]] = []

    for message in trace:
        role = message.get("role")

        if role == "assistant" and "tool_calls" in message:
            tool_calls = message.get("tool_calls")
            if tool_calls:
                content = message.get("content")
                tool_messages.append(
                    ToolCallMessage(
                        role="assistant",
                        tool_calls=tool_calls,
                        content=extract_text_from_content(content),
                    )
                )
        elif role == "tool":
            content = message.get("content")
            tool_call_id = message.get("tool_call_id")

            if tool_call_id is None:
                raise ValueError("Tool call ID is required for tool response messages")
            if content is None:
                raise ValueError("Content is required for tool response messages")

            if not isinstance(content, str):
                content = str(content)

            tool_messages.append(
                ToolResponseMessage(
                    role="tool",
                    content=content,
                    tool_call_id=tool_call_id,
                )
            )

    return tool_messages
Extract tool call and tool response messages from a trace. The trace is expected to follow the OpenAI message schema.
Args:
    trace: The trace of the task run in OpenAI format

Returns:
    A list of ToolCallMessage and ToolResponseMessage objects extracted from the trace
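A small example of extracting tool messages from a trace; the trace literal below is illustrative.

trace = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_123", "content": '{"temperature_c": 21}'},
]

tool_messages = build_tool_call_messages(trace)
# -> [ToolCallMessage(role="assistant", ...), ToolResponseMessage(role="tool", ...)]
# Plain user/assistant messages without tool calls are skipped.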
def get_chat_formatter(
    strategy: ChatStrategy,
    system_message: str,
    user_input: InputType,
    thinking_instructions: str | None = None,
) -> ChatFormatter:
    match strategy:
        case ChatStrategy.single_turn:
            return SingleTurnFormatter(system_message, user_input)
        case ChatStrategy.two_message_cot_legacy:
            return TwoMessageCotLegacyFormatter(
                system_message, user_input, thinking_instructions
            )
        case ChatStrategy.two_message_cot:
            return TwoMessageCotFormatter(
                system_message, user_input, thinking_instructions
            )
        case ChatStrategy.single_turn_r1_thinking:
            return SingleTurnR1ThinkingFormatter(system_message, user_input)
        case _:
            raise_exhaustive_enum_error(strategy)
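A sketch of a caller-side turn loop built on get_chat_formatter; run_model is again a hypothetical placeholder for whatever invokes the model, and the loop shape is an assumption based on the formatters shown above rather than the library's own adapter code.

formatter = get_chat_formatter(
    strategy=ChatStrategy.two_message_cot,
    system_message="You are a careful math tutor.",
    user_input="What is 17 * 24?",
    thinking_instructions="Think step by step before answering.",
)

messages = list(formatter.initial_messages())  # empty for fresh runs
previous_output = None

# Ask for turns until the formatter signals it is done.
while (turn := formatter.next_turn(previous_output)) is not None:
    messages += [{"role": m.role, "content": m.content} for m in turn.messages]
    previous_output = run_model(messages)  # run_model: hypothetical model call

conversation = formatter.message_dicts()       # full formatted conversation
intermediate = formatter.intermediate_outputs()  # any captured intermediate results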