kiln_ai.adapters.model_adapters.litellm_adapter
1import asyncio 2import copy 3import json 4import logging 5import time 6from dataclasses import dataclass 7from typing import Any, Dict, List, Tuple 8 9import litellm 10from litellm.types.utils import ( 11 ChatCompletionMessageToolCall, 12 ChoiceLogprobs, 13 Choices, 14 ModelResponse, 15) 16from litellm.types.utils import Message as LiteLLMMessage 17from litellm.types.utils import Usage as LiteLlmUsage 18from openai.types.chat.chat_completion_message_tool_call_param import ( 19 ChatCompletionMessageToolCallParam, 20) 21 22import kiln_ai.datamodel as datamodel 23from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM 24from kiln_ai.adapters.chat.chat_formatter import chat_message_to_dict 25from kiln_ai.adapters.ml_model_list import ( 26 KilnModelProvider, 27 ModelProviderName, 28 StructuredOutputMode, 29) 30from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream 31from kiln_ai.adapters.model_adapters.base_adapter import ( 32 AdapterConfig, 33 BaseAdapter, 34 MessageUsage, 35 RunOutput, 36 Usage, 37) 38from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig 39from kiln_ai.datamodel.datamodel_enums import InputType 40from kiln_ai.datamodel.json_schema import ( 41 close_object_schemas, 42 validate_schema_with_value_error, 43) 44from kiln_ai.datamodel.run_config import ( 45 KilnAgentRunConfigProperties, 46 as_kiln_agent_run_config, 47) 48from kiln_ai.tools.base_tool import ( 49 KilnToolInterface, 50 ToolCallContext, 51 ToolCallDefinition, 52) 53from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult 54from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error 55from kiln_ai.utils.litellm import get_litellm_provider_info 56from kiln_ai.utils.open_ai_types import ( 57 ChatCompletionAssistantMessageParamWrapper, 58 ChatCompletionMessageParam, 59 ChatCompletionToolMessageParamWrapper, 60 sanitize_messages_for_provider, 61) 62 63MAX_CALLS_PER_TURN = 10 64MAX_TOOL_CALLS_PER_TURN = 30 65 66logger = logging.getLogger(__name__) 67 68 69def _validate_unmanaged_tools(tools: list[KilnToolInterface]) -> None: 70 for i, tool in enumerate(tools): 71 if not isinstance(tool, KilnToolInterface): 72 raise TypeError( 73 f"unmanaged_tools[{i}] must be a KilnToolInterface instance, got {type(tool).__name__}" 74 ) 75 76 77@dataclass 78class ModelTurnResult: 79 assistant_message: str 80 all_messages: list[ChatCompletionMessageIncludingLiteLLM] 81 model_response: ModelResponse | None 82 model_choice: Choices | None 83 usage: Usage 84 interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None 85 message_latency: dict[int, int] | None = None 86 message_usage: dict[int, MessageUsage] | None = None 87 88 89class LiteLlmAdapter(BaseAdapter): 90 def __init__( 91 self, 92 config: LiteLlmConfig, 93 kiln_task: datamodel.Task, 94 base_adapter_config: AdapterConfig | None = None, 95 ): 96 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 97 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 98 self.config = config 99 self._additional_body_options = config.additional_body_options 100 self._api_base = config.base_url 101 self._headers = config.default_headers 102 self._litellm_model_id: str | None = None 103 self._cached_available_tools: list[KilnToolInterface] | None = None 104 105 super().__init__( 106 task=kiln_task, 107 run_config=config.run_config_properties, 108 config=base_adapter_config, 109 ) 110 111 unmanaged_tools = self.base_adapter_config.unmanaged_tools 112 if unmanaged_tools: 113 _validate_unmanaged_tools(unmanaged_tools) 114 115 async def _run_model_turn( 116 self, 117 provider: KilnModelProvider, 118 messages: list[ChatCompletionMessageIncludingLiteLLM], 119 top_logprobs: int | None, 120 skip_response_format: bool, 121 ) -> ModelTurnResult: 122 """ 123 Call the model for a single top level turn: from user message to agent message. 124 125 It may make handle iterations of tool calls between the user/agent message if needed. 126 127 `messages` is the caller-owned list and is mutated in place (append/extend 128 only — never rebound) so the partial trace survives any exception 129 escaping this method. 130 """ 131 132 usage = Usage() 133 tool_calls_count = 0 134 # Per-LLM-call latency / usage, keyed by index in the messages list. 135 # Kept separate because we don't own the LiteLLM message objects. 136 message_latency: dict[int, int] = {} 137 message_usage: dict[int, MessageUsage] = {} 138 139 while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: 140 # Build completion kwargs for tool calls 141 completion_kwargs = await self.build_completion_kwargs( 142 provider, 143 # Pass a copy, as acompletion mutates objects and breaks types. 144 copy.deepcopy(messages), 145 top_logprobs, 146 skip_response_format, 147 ) 148 149 # Make the completion call (timed) 150 start = time.monotonic() 151 model_response, response_choice = await self.acompletion_checking_response( 152 **completion_kwargs 153 ) 154 call_latency_ms = int((time.monotonic() - start) * 1000) 155 156 # count the usage 157 call_usage = self.usage_from_response(model_response) 158 usage += call_usage 159 usage.total_llm_latency_ms = ( 160 usage.total_llm_latency_ms or 0 161 ) + call_latency_ms 162 163 # Extract content and tool calls 164 if not hasattr(response_choice, "message"): 165 raise ValueError("Response choice has no message") 166 content = response_choice.message.content 167 tool_calls = response_choice.message.tool_calls 168 if not content and not tool_calls: 169 raise ValueError( 170 "Model returned an assistant message, but no content or tool calls. This is not supported." 171 ) 172 173 # Add message to messages, so it can be used in the next turn 174 messages.append(response_choice.message) 175 message_latency[len(messages) - 1] = call_latency_ms 176 message_usage[len(messages) - 1] = call_usage 177 178 # Process tool calls if any 179 if tool_calls and len(tool_calls) > 0: 180 # check if we should return control to caller 181 if self.base_adapter_config.return_on_tool_call: 182 # filter out task_response tool (task_response tools are internal) 183 standard_tool_calls = [ 184 tc for tc in tool_calls if tc.function.name != "task_response" 185 ] 186 has_task_response = any( 187 tc.function.name == "task_response" for tc in tool_calls 188 ) 189 if standard_tool_calls and not has_task_response: 190 return ModelTurnResult( 191 # we don't have any content, we are waiting for toolcall output to come back from client 192 assistant_message="", 193 all_messages=messages, 194 model_response=model_response, 195 model_choice=response_choice, 196 usage=usage, 197 interrupted_by_tool_calls=standard_tool_calls, 198 message_latency=message_latency, 199 message_usage=message_usage, 200 ) 201 202 # otherwise: process tool calls internally until final output 203 ( 204 assistant_message_from_toolcall, 205 tool_call_messages, 206 ) = await self.process_tool_calls(tool_calls) 207 208 # Add tool call results to messages 209 messages.extend(tool_call_messages) 210 211 # If task_response tool was called, we're done 212 if assistant_message_from_toolcall is not None: 213 return ModelTurnResult( 214 assistant_message=assistant_message_from_toolcall, 215 all_messages=messages, 216 model_response=model_response, 217 model_choice=response_choice, 218 usage=usage, 219 message_latency=message_latency, 220 message_usage=message_usage, 221 ) 222 223 # If there were tool calls, increment counter and continue 224 if tool_call_messages: 225 tool_calls_count += 1 226 continue 227 228 # If no tool calls, return the content as final output 229 if content: 230 return ModelTurnResult( 231 assistant_message=content, 232 all_messages=messages, 233 model_response=model_response, 234 model_choice=response_choice, 235 usage=usage, 236 message_latency=message_latency, 237 message_usage=message_usage, 238 ) 239 240 # If we get here with no content and no tool calls, break 241 raise RuntimeError( 242 "Model returned neither content nor tool calls. It must return at least one of these." 243 ) 244 245 raise RuntimeError( 246 f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." 247 ) 248 249 async def _run( 250 self, 251 input: InputType, 252 trace_ref: list[ChatCompletionMessageParam], 253 prior_trace: list[ChatCompletionMessageParam] | None = None, 254 ) -> tuple[RunOutput, Usage | None]: 255 usage = Usage() 256 257 provider = self.model_provider() 258 if not provider.model_id: 259 raise ValueError("Model ID is required for OpenAI compatible models") 260 261 # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter 262 chat_formatter = self.build_chat_formatter(input, prior_trace) 263 # `trace_ref` is typed as `ChatCompletionMessageParam` for the 264 # caller's API surface, but internally we store LiteLLM `Message` 265 # objects transiently (converted by `all_messages_to_trace` before 266 # export). We widen the type alias once here so the internal 267 # mutations don't each need their own suppression. 268 messages_internal: list[ChatCompletionMessageIncludingLiteLLM] = trace_ref # type: ignore[assignment] 269 messages_internal.extend(copy.deepcopy(chat_formatter.initial_messages())) 270 271 prior_output: str | None = None 272 final_choice: Choices | None = None 273 turns = 0 274 message_latency: dict[int, int] = {} 275 message_usage: dict[int, MessageUsage] = {} 276 277 # Same loop for both fresh runs and prior_trace continuation. 278 # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). 279 while True: 280 turns += 1 281 if turns > MAX_CALLS_PER_TURN: 282 raise RuntimeError( 283 f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens." 284 ) 285 286 turn = chat_formatter.next_turn(prior_output) 287 if turn is None: 288 # No next turn, we're done 289 break 290 291 # Add messages from the turn to chat history 292 for message in turn.messages: 293 if message.content is None: 294 raise ValueError("Empty message content isn't allowed") 295 messages_internal.append(chat_message_to_dict(message)) # type: ignore 296 297 skip_response_format = not turn.final_call 298 turn_result = await self._run_model_turn( 299 provider, 300 messages_internal, 301 self.base_adapter_config.top_logprobs if turn.final_call else None, 302 skip_response_format, 303 ) 304 305 usage += turn_result.usage 306 if turn_result.message_latency: 307 message_latency.update(turn_result.message_latency) 308 if turn_result.message_usage: 309 message_usage.update(turn_result.message_usage) 310 311 prior_output = turn_result.assistant_message 312 # Update messages_internal in place from turn_result so trace_ref 313 # stays in sync (required by the base adapter error-wrapping contract). 314 messages_internal[:] = turn_result.all_messages 315 final_choice = turn_result.model_choice 316 317 # Check if we were interrupted by tool calls 318 if turn_result.interrupted_by_tool_calls: 319 trace = self.all_messages_to_trace( 320 messages_internal, message_latency, message_usage 321 ) 322 intermediate_outputs = chat_formatter.intermediate_outputs() 323 output = RunOutput( 324 output=prior_output or "", 325 intermediate_outputs=intermediate_outputs, 326 output_logprobs=None, 327 trace=trace, 328 ) 329 return output, usage 330 331 if not prior_output: 332 raise RuntimeError("No assistant message/output returned from model") 333 334 logprobs = self._extract_and_validate_logprobs(final_choice) 335 336 # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream) 337 intermediate_outputs = chat_formatter.intermediate_outputs() 338 self._extract_reasoning_to_intermediate_outputs( 339 final_choice, intermediate_outputs 340 ) 341 342 if not isinstance(prior_output, str): 343 raise RuntimeError(f"assistant message is not a string: {prior_output}") 344 345 trace = self.all_messages_to_trace( 346 messages_internal, message_latency, message_usage 347 ) 348 output = RunOutput( 349 output=prior_output, 350 intermediate_outputs=intermediate_outputs, 351 output_logprobs=logprobs, 352 trace=trace, 353 ) 354 355 return output, usage 356 357 def _create_run_stream( 358 self, 359 input: InputType, 360 prior_trace: list[ChatCompletionMessageParam] | None = None, 361 ) -> AdapterStream: 362 provider = self.model_provider() 363 if not provider.model_id: 364 raise ValueError("Model ID is required for OpenAI compatible models") 365 366 chat_formatter = self.build_chat_formatter(input, prior_trace) 367 initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( 368 chat_formatter.initial_messages() 369 ) 370 371 return AdapterStream( 372 adapter=self, 373 provider=provider, 374 chat_formatter=chat_formatter, 375 initial_messages=initial_messages, 376 top_logprobs=self.base_adapter_config.top_logprobs, 377 ) 378 379 def _extract_and_validate_logprobs( 380 self, final_choice: Choices | None 381 ) -> ChoiceLogprobs | None: 382 """ 383 Extract logprobs from the final choice and validate they exist if required. 384 """ 385 logprobs = None 386 if ( 387 final_choice is not None 388 and hasattr(final_choice, "logprobs") 389 and isinstance(final_choice.logprobs, ChoiceLogprobs) 390 ): 391 logprobs = final_choice.logprobs 392 393 # Check logprobs worked, if required 394 if self.base_adapter_config.top_logprobs is not None and logprobs is None: 395 raise RuntimeError("Logprobs were required, but no logprobs were returned.") 396 397 return logprobs 398 399 def _extract_reasoning_to_intermediate_outputs( 400 self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any] 401 ) -> None: 402 """Extract reasoning content from model choice and add to intermediate outputs if present.""" 403 if ( 404 final_choice is not None 405 and hasattr(final_choice, "message") 406 and hasattr(final_choice.message, "reasoning_content") 407 ): 408 reasoning_content = final_choice.message.reasoning_content 409 if reasoning_content is not None: 410 stripped_reasoning_content = reasoning_content.strip() 411 if len(stripped_reasoning_content) > 0: 412 intermediate_outputs["reasoning"] = stripped_reasoning_content 413 414 async def acompletion_checking_response( 415 self, **kwargs: Any 416 ) -> Tuple[ModelResponse, Choices]: 417 response = await litellm.acompletion(**kwargs) 418 419 if ( 420 not isinstance(response, ModelResponse) 421 or not response.choices 422 or len(response.choices) == 0 423 or not isinstance(response.choices[0], Choices) 424 ): 425 raise RuntimeError( 426 f"Expected ModelResponse with Choices, got {type(response)}." 427 ) 428 return response, response.choices[0] 429 430 def adapter_name(self) -> str: 431 return "kiln_openai_compatible_adapter" 432 433 async def response_format_options(self) -> dict[str, Any]: 434 # Unstructured if task isn't structured 435 if not self.has_structured_output(): 436 return {} 437 438 run_config = as_kiln_agent_run_config(self.run_config) 439 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 440 441 match structured_output_mode: 442 case StructuredOutputMode.json_mode: 443 return {"response_format": {"type": "json_object"}} 444 case StructuredOutputMode.json_schema: 445 return self.json_schema_response_format() 446 case StructuredOutputMode.function_calling_weak: 447 return self.tool_call_params(strict=False) 448 case StructuredOutputMode.function_calling: 449 return self.tool_call_params(strict=True) 450 case StructuredOutputMode.json_instructions: 451 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 452 return {} 453 case StructuredOutputMode.json_custom_instructions: 454 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 455 return {} 456 case StructuredOutputMode.json_instruction_and_object: 457 # We set response_format to json_object and also set json instructions in the prompt 458 return {"response_format": {"type": "json_object"}} 459 case StructuredOutputMode.default: 460 provider_name = run_config.model_provider_name 461 if provider_name == ModelProviderName.ollama: 462 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 463 return self.json_schema_response_format() 464 elif provider_name == ModelProviderName.docker_model_runner: 465 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 466 return self.json_schema_response_format() 467 else: 468 # Default to function calling -- it's older than the other modes. Higher compatibility. 469 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 470 strict = provider_name == ModelProviderName.openai 471 return self.tool_call_params(strict=strict) 472 case StructuredOutputMode.unknown: 473 # See above, but this case should never happen. 474 raise ValueError("Structured output mode is unknown.") 475 case _: 476 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type] 477 478 def json_schema_response_format(self) -> dict[str, Any]: 479 output_schema = self.task.output_schema() 480 if output_schema is None: 481 raise ValueError( 482 "Invalid output schema for this task. Cannot use JSON schema response format." 483 ) 484 output_schema = close_object_schemas(output_schema, strict=True) 485 return { 486 "response_format": { 487 "type": "json_schema", 488 "json_schema": { 489 "name": "task_response", 490 "schema": output_schema, 491 }, 492 } 493 } 494 495 def tool_call_params(self, strict: bool) -> dict[str, Any]: 496 # Add additional_properties: false to the schema (OpenAI requires this for some models) 497 output_schema = self.task.output_schema() 498 if not isinstance(output_schema, dict): 499 raise ValueError( 500 "Invalid output schema for this task. Can not use tool calls." 501 ) 502 output_schema = close_object_schemas(output_schema, strict=strict) 503 504 function_params = { 505 "name": "task_response", 506 "parameters": output_schema, 507 } 508 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 509 if strict: 510 function_params["strict"] = True 511 512 return { 513 "tools": [ 514 { 515 "type": "function", 516 "function": function_params, 517 } 518 ], 519 "tool_choice": { 520 "type": "function", 521 "function": {"name": "task_response"}, 522 }, 523 } 524 525 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 526 # Don't love having this logic here. But it's worth the usability improvement 527 # so better to keep it than exclude it. Should figure out how I want to isolate 528 # this sort of logic so it's config driven and can be overridden 529 extra_body: dict[str, Any] = {} 530 provider_options = {} 531 532 run_config = as_kiln_agent_run_config(self.run_config) 533 # For legacy config 'thinking_level' is not set, default to provider's default 534 if "thinking_level" in run_config.model_fields_set: 535 thinking_level = run_config.thinking_level 536 else: 537 thinking_level = provider.default_thinking_level 538 539 # Skip if provider doesn't support thinking levels (stale configs may still have one set) 540 if ( 541 thinking_level is not None 542 and provider.available_thinking_levels is not None 543 ): 544 # Anthropic models in OpenRouter uses reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens 545 if ( 546 provider.name == ModelProviderName.openrouter 547 and provider.openrouter_reasoning_object 548 ): 549 extra_body["reasoning"] = {"effort": thinking_level} 550 elif ( 551 provider.name == ModelProviderName.anthropic 552 and thinking_level == "none" 553 ): 554 # Anthropic's native API has no reasoning_effort="none"; passing it makes 555 # litellm map thinking to None and then crash. Omitting reasoning_effort 556 # disables extended thinking, which also frees temperature from the 557 # temperature=1 requirement that applies whenever thinking is enabled. 558 pass 559 else: 560 extra_body["reasoning_effort"] = thinking_level 561 # Opus 4.7/4.8 default thinking display to "omitted", returning empty 562 # thinking text. Request the summary so reasoning is surfaced. litellm 563 # still maps reasoning_effort to output_config.effort; this only adds the 564 # display to the adaptive thinking object. 565 if ( 566 provider.name == ModelProviderName.anthropic 567 and provider.anthropic_summarized_thinking 568 ): 569 extra_body["thinking"] = { 570 "type": "adaptive", 571 "display": "summarized", 572 } 573 574 if provider.require_openrouter_reasoning: 575 # https://openrouter.ai/docs/use-cases/reasoning-tokens 576 extra_body["reasoning"] = { 577 "exclude": False, 578 } 579 580 if provider.gemini_reasoning_enabled: 581 extra_body["reasoning"] = { 582 "enabled": True, 583 } 584 585 if provider.name == ModelProviderName.openrouter: 586 # Ask OpenRouter to include usage in the response (cost) 587 extra_body["usage"] = {"include": True} 588 589 # Set a default provider order for more deterministic routing. 590 # OpenRouter will ignore providers that don't support the model. 591 # Special cases below (like R1) can override this order. 592 # allow_fallbacks is true by default, but we can override it here. 593 provider_options["order"] = [ 594 "fireworks", 595 "parasail", 596 "together", 597 "deepinfra", 598 "novita", 599 "groq", 600 "amazon-bedrock", 601 "azure", 602 "nebius", 603 ] 604 605 if provider.anthropic_extended_thinking and "thinking" not in extra_body: 606 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 607 608 if provider.r1_openrouter_options: 609 # Require providers that support the reasoning parameter 610 provider_options["require_parameters"] = True 611 # Prefer R1 providers with reasonable perf/quants 612 provider_options["order"] = ["fireworks", "together"] 613 # R1 providers with unreasonable quants 614 provider_options["ignore"] = ["deepinfra"] 615 616 # Only set of this request is to get logprobs. 617 if ( 618 provider.logprobs_openrouter_options 619 and self.base_adapter_config.top_logprobs is not None 620 ): 621 # Don't let OpenRouter choose a provider that doesn't support logprobs. 622 provider_options["require_parameters"] = True 623 # DeepInfra silently fails to return logprobs consistently. 624 provider_options["ignore"] = ["deepinfra"] 625 626 if provider.openrouter_skip_required_parameters: 627 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 628 provider_options["require_parameters"] = False 629 630 # Siliconflow uses a bool flag for thinking, for some models 631 if provider.siliconflow_enable_thinking is not None: 632 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 633 634 if len(provider_options) > 0: 635 extra_body["provider"] = provider_options 636 637 return extra_body 638 639 def litellm_model_id(self) -> str: 640 # The model ID is an interesting combination of format and url endpoint. 641 # It specifics the provider URL/host, but this is overridden if you manually set an api url 642 if self._litellm_model_id: 643 return self._litellm_model_id 644 645 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 646 if litellm_provider_info.is_custom and self._api_base is None: 647 raise ValueError( 648 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 649 ) 650 651 self._litellm_model_id = litellm_provider_info.litellm_model_id 652 return self._litellm_model_id 653 654 def _allowed_openai_params_for_completion_kwargs( 655 self, completion_kwargs: dict[str, Any] 656 ) -> list[str]: 657 """ 658 LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong 659 and we know it is supported, so we whitelist them here and pass that as an allowed_openai_params parameter. 660 """ 661 # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that 662 explicit_allowed_params: Any | list = completion_kwargs.get( 663 "allowed_openai_params", [] 664 ) 665 if not isinstance(explicit_allowed_params, list): 666 raise ValueError( 667 f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}" 668 ) 669 explicit_allowed_params_validated = [ 670 param for param in explicit_allowed_params if isinstance(param, str) 671 ] 672 invalid_count = len(explicit_allowed_params) - len( 673 explicit_allowed_params_validated 674 ) 675 if invalid_count > 0: 676 raise ValueError( 677 f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings" 678 ) 679 680 # these are our own logic 681 automatic_allowed_params: list[str] = [] 682 if "tools" in completion_kwargs: 683 automatic_allowed_params.append("tools") 684 if "tool_choice" in completion_kwargs: 685 automatic_allowed_params.append("tool_choice") 686 687 return list(set(explicit_allowed_params_validated + automatic_allowed_params)) 688 689 async def build_completion_kwargs( 690 self, 691 provider: KilnModelProvider, 692 messages: list[ChatCompletionMessageIncludingLiteLLM], 693 top_logprobs: int | None, 694 skip_response_format: bool = False, 695 ) -> dict[str, Any]: 696 run_config = as_kiln_agent_run_config(self.run_config) 697 extra_body = self.build_extra_body(provider) 698 699 # Merge all parameters into a single kwargs dict for litellm 700 completion_kwargs = { 701 "model": self.litellm_model_id(), 702 "messages": messages, 703 "api_base": self._api_base, 704 "headers": self._headers, 705 "temperature": run_config.temperature, 706 "top_p": run_config.top_p, 707 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 708 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 709 # Better to ignore them than to fail the model call. 710 # https://docs.litellm.ai/docs/completion/input 711 "drop_params": True, 712 **extra_body, 713 **self._additional_body_options, 714 } 715 716 if self.base_adapter_config.automatic_prompt_caching: 717 # Mark the last message for cache control. Litellm's AnthropicCacheControlHook 718 # handles provider-specific injection. Providers auto-cache matching prefixes, 719 # so marking the last message is sufficient for multi-turn conversations. 720 completion_kwargs["cache_control_injection_points"] = [ 721 {"location": "message", "index": -1} 722 ] 723 724 tool_calls = await self.litellm_tools() 725 has_tools = len(tool_calls) > 0 726 if has_tools: 727 completion_kwargs["tools"] = tool_calls 728 completion_kwargs["tool_choice"] = "auto" 729 730 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 731 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 732 if provider.temp_top_p_exclusive: 733 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 734 del completion_kwargs["top_p"] 735 if ( 736 "temperature" in completion_kwargs 737 and completion_kwargs["temperature"] == 1.0 738 ): 739 del completion_kwargs["temperature"] 740 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 741 raise ValueError( 742 "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 743 ) 744 745 if not skip_response_format: 746 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 747 response_format_options = await self.response_format_options() 748 749 # Check for a conflict between tools and response format using tools 750 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 751 if has_tools and "tools" in response_format_options: 752 raise ValueError( 753 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 754 ) 755 756 completion_kwargs.update(response_format_options) 757 758 if top_logprobs is not None: 759 completion_kwargs["logprobs"] = True 760 completion_kwargs["top_logprobs"] = top_logprobs 761 762 # any params listed in this list will be passed to the model regardless of LiteLLM's own validation 763 allowed_openai_params = self._allowed_openai_params_for_completion_kwargs( 764 completion_kwargs 765 ) 766 if len(allowed_openai_params) > 0: 767 completion_kwargs["allowed_openai_params"] = allowed_openai_params 768 769 completion_kwargs["messages"] = sanitize_messages_for_provider(messages) 770 771 return completion_kwargs 772 773 def usage_from_response(self, response: ModelResponse) -> MessageUsage: 774 litellm_usage = response.get("usage", None) 775 776 # LiteLLM isn't consistent in how it returns the cost. 777 cost = response._hidden_params.get("response_cost", None) 778 if cost is None and litellm_usage: 779 cost = litellm_usage.get("cost", None) 780 781 usage = MessageUsage() 782 783 if not litellm_usage and not cost: 784 return usage 785 786 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 787 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 788 usage.output_tokens = litellm_usage.get("completion_tokens", None) 789 usage.total_tokens = litellm_usage.get("total_tokens", None) 790 prompt_details = litellm_usage.get("prompt_tokens_details", None) 791 if prompt_details and hasattr(prompt_details, "cached_tokens"): 792 usage.cached_tokens = prompt_details.cached_tokens 793 elif prompt_details: 794 logger.warning( 795 f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted" 796 ) 797 else: 798 logger.warning( 799 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 800 ) 801 802 if isinstance(cost, float): 803 usage.cost = cost 804 elif cost is not None: 805 # None is allowed, but no other types are expected 806 logger.warning( 807 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 808 ) 809 810 return usage 811 812 async def cached_available_tools(self) -> list[KilnToolInterface]: 813 if self._cached_available_tools is None: 814 self._cached_available_tools = await self.available_tools() 815 return self._cached_available_tools 816 817 async def _tools_for_execution(self) -> list[KilnToolInterface]: 818 """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``).""" 819 registry = await self.cached_available_tools() 820 unmanaged = self.base_adapter_config.unmanaged_tools or [] 821 return registry + unmanaged 822 823 async def litellm_tools(self) -> list[ToolCallDefinition]: 824 available_tools = await self.cached_available_tools() 825 826 registry_defs = [await tool.toolcall_definition() for tool in available_tools] 827 unmanaged = self.base_adapter_config.unmanaged_tools 828 unmanaged_defs = ( 829 [await t.toolcall_definition() for t in unmanaged] if unmanaged else [] 830 ) 831 832 merged = registry_defs + unmanaged_defs 833 seen_names: set[str] = set() 834 for d in merged: 835 name = d["function"]["name"] 836 if name in seen_names: 837 raise ValueError( 838 f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names." 839 ) 840 seen_names.add(name) 841 842 return merged 843 844 async def process_tool_calls( 845 self, tool_calls: list[ChatCompletionMessageToolCall] | None 846 ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]: 847 if tool_calls is None: 848 return None, [] 849 850 assistant_output_from_toolcall: str | None = None 851 tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = [] 852 tool_run_coroutines = [] 853 854 for tool_call in tool_calls: 855 # Kiln "task_response" tool is used for returning structured output via tool calls. 856 # Load the output from the tool call. Also 857 if tool_call.function.name == "task_response": 858 assistant_output_from_toolcall = tool_call.function.arguments 859 continue 860 861 # Process normal tool calls (not the "task_response" tool) 862 tool_name = tool_call.function.name 863 tool = None 864 for tool_option in await self._tools_for_execution(): 865 if await tool_option.name() == tool_name: 866 tool = tool_option 867 break 868 if not tool: 869 raise RuntimeError( 870 f"A tool named '{tool_name}' was invoked by a model, but was not available." 871 ) 872 873 # Parse the arguments and validate them against the tool's schema 874 try: 875 parsed_args = json.loads(tool_call.function.arguments) 876 except json.JSONDecodeError: 877 raise RuntimeError( 878 f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}" 879 ) 880 try: 881 tool_call_definition = await tool.toolcall_definition() 882 json_schema = json.dumps(tool_call_definition["function"]["parameters"]) 883 validate_schema_with_value_error(parsed_args, json_schema) 884 except Exception as e: 885 raise RuntimeError( 886 f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}" 887 ) from e 888 889 # Create context with the calling task's allow_saving setting 890 context = ToolCallContext( 891 allow_saving=self.base_adapter_config.allow_saving 892 ) 893 894 async def run_tool_and_format( 895 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 896 ): 897 result = await t.run(c, **args) 898 return ChatCompletionToolMessageParamWrapper( 899 role="tool", 900 tool_call_id=tc_id, 901 content=result.output, 902 kiln_task_tool_data=result.kiln_task_tool_data 903 if isinstance(result, KilnTaskToolResult) 904 else None, 905 is_error=result.is_error if result.is_error else None, 906 error_message=result.error_message 907 if result.error_message 908 else None, 909 ) 910 911 tool_run_coroutines.append(run_tool_and_format()) 912 913 if tool_run_coroutines: 914 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 915 916 if ( 917 assistant_output_from_toolcall is not None 918 and len(tool_call_response_messages) > 0 919 ): 920 raise RuntimeError( 921 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 922 ) 923 924 return assistant_output_from_toolcall, tool_call_response_messages 925 926 def litellm_message_to_trace_message( 927 self, 928 raw_message: LiteLLMMessage, 929 latency_ms: int | None = None, 930 usage: MessageUsage | None = None, 931 ) -> ChatCompletionAssistantMessageParamWrapper: 932 """ 933 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 934 """ 935 message: ChatCompletionAssistantMessageParamWrapper = { 936 "role": "assistant", 937 } 938 if raw_message.role != "assistant": 939 raise ValueError( 940 "Model returned a message with a role other than assistant. This is not supported." 941 ) 942 943 if hasattr(raw_message, "content"): 944 message["content"] = raw_message.content 945 if hasattr(raw_message, "reasoning_content"): 946 message["reasoning_content"] = raw_message.reasoning_content 947 if hasattr(raw_message, "tool_calls"): 948 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 949 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 950 for litellm_tool_call in raw_message.tool_calls or []: 951 # Optional in the SDK for streaming responses, but should never be None at this point. 952 if litellm_tool_call.function.name is None: 953 raise ValueError( 954 "The model requested a tool call, without providing a function name (required)." 955 ) 956 open_ai_tool_calls.append( 957 ChatCompletionMessageToolCallParam( 958 id=litellm_tool_call.id, 959 type="function", 960 function={ 961 "name": litellm_tool_call.function.name, 962 "arguments": litellm_tool_call.function.arguments, 963 }, 964 ) 965 ) 966 if len(open_ai_tool_calls) > 0: 967 message["tool_calls"] = open_ai_tool_calls 968 969 if latency_ms is not None: 970 message["latency_ms"] = latency_ms 971 972 if usage is not None: 973 message["usage"] = usage 974 975 if not message.get("content") and not message.get("tool_calls"): 976 raise ValueError( 977 "Model returned an assistant message, but no content or tool calls. This is not supported." 978 ) 979 980 return message 981 982 def all_messages_to_trace( 983 self, 984 messages: list[ChatCompletionMessageIncludingLiteLLM], 985 message_latency: dict[int, int] | None = None, 986 message_usage: dict[int, MessageUsage] | None = None, 987 ) -> list[ChatCompletionMessageParam]: 988 """ 989 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 990 991 Non-LiteLLM dict messages pass through unchanged. Any per-message 992 ``usage``/``latency_ms`` already attached to those dicts (e.g. from a 993 seeded prior trace) is preserved. 994 """ 995 trace: list[ChatCompletionMessageParam] = [] 996 for i, message in enumerate(messages): 997 if isinstance(message, LiteLLMMessage): 998 latency_ms = message_latency.get(i) if message_latency else None 999 usage = message_usage.get(i) if message_usage else None 1000 trace.append( 1001 self.litellm_message_to_trace_message(message, latency_ms, usage) 1002 ) 1003 else: 1004 trace.append(message) 1005 return trace 1006 1007 def _messages_to_trace( 1008 self, 1009 messages: list[ChatCompletionMessageParam], 1010 ) -> list[ChatCompletionMessageParam]: 1011 """Override: the `messages` list may transiently contain LiteLLM 1012 Message objects. Normalize to API-safe shapes for export on error. 1013 """ 1014 return self.all_messages_to_trace(messages) # type: ignore[arg-type]
78@dataclass 79class ModelTurnResult: 80 assistant_message: str 81 all_messages: list[ChatCompletionMessageIncludingLiteLLM] 82 model_response: ModelResponse | None 83 model_choice: Choices | None 84 usage: Usage 85 interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None 86 message_latency: dict[int, int] | None = None 87 message_usage: dict[int, MessageUsage] | None = None
90class LiteLlmAdapter(BaseAdapter): 91 def __init__( 92 self, 93 config: LiteLlmConfig, 94 kiln_task: datamodel.Task, 95 base_adapter_config: AdapterConfig | None = None, 96 ): 97 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 98 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 99 self.config = config 100 self._additional_body_options = config.additional_body_options 101 self._api_base = config.base_url 102 self._headers = config.default_headers 103 self._litellm_model_id: str | None = None 104 self._cached_available_tools: list[KilnToolInterface] | None = None 105 106 super().__init__( 107 task=kiln_task, 108 run_config=config.run_config_properties, 109 config=base_adapter_config, 110 ) 111 112 unmanaged_tools = self.base_adapter_config.unmanaged_tools 113 if unmanaged_tools: 114 _validate_unmanaged_tools(unmanaged_tools) 115 116 async def _run_model_turn( 117 self, 118 provider: KilnModelProvider, 119 messages: list[ChatCompletionMessageIncludingLiteLLM], 120 top_logprobs: int | None, 121 skip_response_format: bool, 122 ) -> ModelTurnResult: 123 """ 124 Call the model for a single top level turn: from user message to agent message. 125 126 It may make handle iterations of tool calls between the user/agent message if needed. 127 128 `messages` is the caller-owned list and is mutated in place (append/extend 129 only — never rebound) so the partial trace survives any exception 130 escaping this method. 131 """ 132 133 usage = Usage() 134 tool_calls_count = 0 135 # Per-LLM-call latency / usage, keyed by index in the messages list. 136 # Kept separate because we don't own the LiteLLM message objects. 137 message_latency: dict[int, int] = {} 138 message_usage: dict[int, MessageUsage] = {} 139 140 while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: 141 # Build completion kwargs for tool calls 142 completion_kwargs = await self.build_completion_kwargs( 143 provider, 144 # Pass a copy, as acompletion mutates objects and breaks types. 145 copy.deepcopy(messages), 146 top_logprobs, 147 skip_response_format, 148 ) 149 150 # Make the completion call (timed) 151 start = time.monotonic() 152 model_response, response_choice = await self.acompletion_checking_response( 153 **completion_kwargs 154 ) 155 call_latency_ms = int((time.monotonic() - start) * 1000) 156 157 # count the usage 158 call_usage = self.usage_from_response(model_response) 159 usage += call_usage 160 usage.total_llm_latency_ms = ( 161 usage.total_llm_latency_ms or 0 162 ) + call_latency_ms 163 164 # Extract content and tool calls 165 if not hasattr(response_choice, "message"): 166 raise ValueError("Response choice has no message") 167 content = response_choice.message.content 168 tool_calls = response_choice.message.tool_calls 169 if not content and not tool_calls: 170 raise ValueError( 171 "Model returned an assistant message, but no content or tool calls. This is not supported." 172 ) 173 174 # Add message to messages, so it can be used in the next turn 175 messages.append(response_choice.message) 176 message_latency[len(messages) - 1] = call_latency_ms 177 message_usage[len(messages) - 1] = call_usage 178 179 # Process tool calls if any 180 if tool_calls and len(tool_calls) > 0: 181 # check if we should return control to caller 182 if self.base_adapter_config.return_on_tool_call: 183 # filter out task_response tool (task_response tools are internal) 184 standard_tool_calls = [ 185 tc for tc in tool_calls if tc.function.name != "task_response" 186 ] 187 has_task_response = any( 188 tc.function.name == "task_response" for tc in tool_calls 189 ) 190 if standard_tool_calls and not has_task_response: 191 return ModelTurnResult( 192 # we don't have any content, we are waiting for toolcall output to come back from client 193 assistant_message="", 194 all_messages=messages, 195 model_response=model_response, 196 model_choice=response_choice, 197 usage=usage, 198 interrupted_by_tool_calls=standard_tool_calls, 199 message_latency=message_latency, 200 message_usage=message_usage, 201 ) 202 203 # otherwise: process tool calls internally until final output 204 ( 205 assistant_message_from_toolcall, 206 tool_call_messages, 207 ) = await self.process_tool_calls(tool_calls) 208 209 # Add tool call results to messages 210 messages.extend(tool_call_messages) 211 212 # If task_response tool was called, we're done 213 if assistant_message_from_toolcall is not None: 214 return ModelTurnResult( 215 assistant_message=assistant_message_from_toolcall, 216 all_messages=messages, 217 model_response=model_response, 218 model_choice=response_choice, 219 usage=usage, 220 message_latency=message_latency, 221 message_usage=message_usage, 222 ) 223 224 # If there were tool calls, increment counter and continue 225 if tool_call_messages: 226 tool_calls_count += 1 227 continue 228 229 # If no tool calls, return the content as final output 230 if content: 231 return ModelTurnResult( 232 assistant_message=content, 233 all_messages=messages, 234 model_response=model_response, 235 model_choice=response_choice, 236 usage=usage, 237 message_latency=message_latency, 238 message_usage=message_usage, 239 ) 240 241 # If we get here with no content and no tool calls, break 242 raise RuntimeError( 243 "Model returned neither content nor tool calls. It must return at least one of these." 244 ) 245 246 raise RuntimeError( 247 f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." 248 ) 249 250 async def _run( 251 self, 252 input: InputType, 253 trace_ref: list[ChatCompletionMessageParam], 254 prior_trace: list[ChatCompletionMessageParam] | None = None, 255 ) -> tuple[RunOutput, Usage | None]: 256 usage = Usage() 257 258 provider = self.model_provider() 259 if not provider.model_id: 260 raise ValueError("Model ID is required for OpenAI compatible models") 261 262 # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter 263 chat_formatter = self.build_chat_formatter(input, prior_trace) 264 # `trace_ref` is typed as `ChatCompletionMessageParam` for the 265 # caller's API surface, but internally we store LiteLLM `Message` 266 # objects transiently (converted by `all_messages_to_trace` before 267 # export). We widen the type alias once here so the internal 268 # mutations don't each need their own suppression. 269 messages_internal: list[ChatCompletionMessageIncludingLiteLLM] = trace_ref # type: ignore[assignment] 270 messages_internal.extend(copy.deepcopy(chat_formatter.initial_messages())) 271 272 prior_output: str | None = None 273 final_choice: Choices | None = None 274 turns = 0 275 message_latency: dict[int, int] = {} 276 message_usage: dict[int, MessageUsage] = {} 277 278 # Same loop for both fresh runs and prior_trace continuation. 279 # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). 280 while True: 281 turns += 1 282 if turns > MAX_CALLS_PER_TURN: 283 raise RuntimeError( 284 f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens." 285 ) 286 287 turn = chat_formatter.next_turn(prior_output) 288 if turn is None: 289 # No next turn, we're done 290 break 291 292 # Add messages from the turn to chat history 293 for message in turn.messages: 294 if message.content is None: 295 raise ValueError("Empty message content isn't allowed") 296 messages_internal.append(chat_message_to_dict(message)) # type: ignore 297 298 skip_response_format = not turn.final_call 299 turn_result = await self._run_model_turn( 300 provider, 301 messages_internal, 302 self.base_adapter_config.top_logprobs if turn.final_call else None, 303 skip_response_format, 304 ) 305 306 usage += turn_result.usage 307 if turn_result.message_latency: 308 message_latency.update(turn_result.message_latency) 309 if turn_result.message_usage: 310 message_usage.update(turn_result.message_usage) 311 312 prior_output = turn_result.assistant_message 313 # Update messages_internal in place from turn_result so trace_ref 314 # stays in sync (required by the base adapter error-wrapping contract). 315 messages_internal[:] = turn_result.all_messages 316 final_choice = turn_result.model_choice 317 318 # Check if we were interrupted by tool calls 319 if turn_result.interrupted_by_tool_calls: 320 trace = self.all_messages_to_trace( 321 messages_internal, message_latency, message_usage 322 ) 323 intermediate_outputs = chat_formatter.intermediate_outputs() 324 output = RunOutput( 325 output=prior_output or "", 326 intermediate_outputs=intermediate_outputs, 327 output_logprobs=None, 328 trace=trace, 329 ) 330 return output, usage 331 332 if not prior_output: 333 raise RuntimeError("No assistant message/output returned from model") 334 335 logprobs = self._extract_and_validate_logprobs(final_choice) 336 337 # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream) 338 intermediate_outputs = chat_formatter.intermediate_outputs() 339 self._extract_reasoning_to_intermediate_outputs( 340 final_choice, intermediate_outputs 341 ) 342 343 if not isinstance(prior_output, str): 344 raise RuntimeError(f"assistant message is not a string: {prior_output}") 345 346 trace = self.all_messages_to_trace( 347 messages_internal, message_latency, message_usage 348 ) 349 output = RunOutput( 350 output=prior_output, 351 intermediate_outputs=intermediate_outputs, 352 output_logprobs=logprobs, 353 trace=trace, 354 ) 355 356 return output, usage 357 358 def _create_run_stream( 359 self, 360 input: InputType, 361 prior_trace: list[ChatCompletionMessageParam] | None = None, 362 ) -> AdapterStream: 363 provider = self.model_provider() 364 if not provider.model_id: 365 raise ValueError("Model ID is required for OpenAI compatible models") 366 367 chat_formatter = self.build_chat_formatter(input, prior_trace) 368 initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( 369 chat_formatter.initial_messages() 370 ) 371 372 return AdapterStream( 373 adapter=self, 374 provider=provider, 375 chat_formatter=chat_formatter, 376 initial_messages=initial_messages, 377 top_logprobs=self.base_adapter_config.top_logprobs, 378 ) 379 380 def _extract_and_validate_logprobs( 381 self, final_choice: Choices | None 382 ) -> ChoiceLogprobs | None: 383 """ 384 Extract logprobs from the final choice and validate they exist if required. 385 """ 386 logprobs = None 387 if ( 388 final_choice is not None 389 and hasattr(final_choice, "logprobs") 390 and isinstance(final_choice.logprobs, ChoiceLogprobs) 391 ): 392 logprobs = final_choice.logprobs 393 394 # Check logprobs worked, if required 395 if self.base_adapter_config.top_logprobs is not None and logprobs is None: 396 raise RuntimeError("Logprobs were required, but no logprobs were returned.") 397 398 return logprobs 399 400 def _extract_reasoning_to_intermediate_outputs( 401 self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any] 402 ) -> None: 403 """Extract reasoning content from model choice and add to intermediate outputs if present.""" 404 if ( 405 final_choice is not None 406 and hasattr(final_choice, "message") 407 and hasattr(final_choice.message, "reasoning_content") 408 ): 409 reasoning_content = final_choice.message.reasoning_content 410 if reasoning_content is not None: 411 stripped_reasoning_content = reasoning_content.strip() 412 if len(stripped_reasoning_content) > 0: 413 intermediate_outputs["reasoning"] = stripped_reasoning_content 414 415 async def acompletion_checking_response( 416 self, **kwargs: Any 417 ) -> Tuple[ModelResponse, Choices]: 418 response = await litellm.acompletion(**kwargs) 419 420 if ( 421 not isinstance(response, ModelResponse) 422 or not response.choices 423 or len(response.choices) == 0 424 or not isinstance(response.choices[0], Choices) 425 ): 426 raise RuntimeError( 427 f"Expected ModelResponse with Choices, got {type(response)}." 428 ) 429 return response, response.choices[0] 430 431 def adapter_name(self) -> str: 432 return "kiln_openai_compatible_adapter" 433 434 async def response_format_options(self) -> dict[str, Any]: 435 # Unstructured if task isn't structured 436 if not self.has_structured_output(): 437 return {} 438 439 run_config = as_kiln_agent_run_config(self.run_config) 440 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 441 442 match structured_output_mode: 443 case StructuredOutputMode.json_mode: 444 return {"response_format": {"type": "json_object"}} 445 case StructuredOutputMode.json_schema: 446 return self.json_schema_response_format() 447 case StructuredOutputMode.function_calling_weak: 448 return self.tool_call_params(strict=False) 449 case StructuredOutputMode.function_calling: 450 return self.tool_call_params(strict=True) 451 case StructuredOutputMode.json_instructions: 452 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 453 return {} 454 case StructuredOutputMode.json_custom_instructions: 455 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 456 return {} 457 case StructuredOutputMode.json_instruction_and_object: 458 # We set response_format to json_object and also set json instructions in the prompt 459 return {"response_format": {"type": "json_object"}} 460 case StructuredOutputMode.default: 461 provider_name = run_config.model_provider_name 462 if provider_name == ModelProviderName.ollama: 463 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 464 return self.json_schema_response_format() 465 elif provider_name == ModelProviderName.docker_model_runner: 466 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 467 return self.json_schema_response_format() 468 else: 469 # Default to function calling -- it's older than the other modes. Higher compatibility. 470 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 471 strict = provider_name == ModelProviderName.openai 472 return self.tool_call_params(strict=strict) 473 case StructuredOutputMode.unknown: 474 # See above, but this case should never happen. 475 raise ValueError("Structured output mode is unknown.") 476 case _: 477 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type] 478 479 def json_schema_response_format(self) -> dict[str, Any]: 480 output_schema = self.task.output_schema() 481 if output_schema is None: 482 raise ValueError( 483 "Invalid output schema for this task. Cannot use JSON schema response format." 484 ) 485 output_schema = close_object_schemas(output_schema, strict=True) 486 return { 487 "response_format": { 488 "type": "json_schema", 489 "json_schema": { 490 "name": "task_response", 491 "schema": output_schema, 492 }, 493 } 494 } 495 496 def tool_call_params(self, strict: bool) -> dict[str, Any]: 497 # Add additional_properties: false to the schema (OpenAI requires this for some models) 498 output_schema = self.task.output_schema() 499 if not isinstance(output_schema, dict): 500 raise ValueError( 501 "Invalid output schema for this task. Can not use tool calls." 502 ) 503 output_schema = close_object_schemas(output_schema, strict=strict) 504 505 function_params = { 506 "name": "task_response", 507 "parameters": output_schema, 508 } 509 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 510 if strict: 511 function_params["strict"] = True 512 513 return { 514 "tools": [ 515 { 516 "type": "function", 517 "function": function_params, 518 } 519 ], 520 "tool_choice": { 521 "type": "function", 522 "function": {"name": "task_response"}, 523 }, 524 } 525 526 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 527 # Don't love having this logic here. But it's worth the usability improvement 528 # so better to keep it than exclude it. Should figure out how I want to isolate 529 # this sort of logic so it's config driven and can be overridden 530 extra_body: dict[str, Any] = {} 531 provider_options = {} 532 533 run_config = as_kiln_agent_run_config(self.run_config) 534 # For legacy config 'thinking_level' is not set, default to provider's default 535 if "thinking_level" in run_config.model_fields_set: 536 thinking_level = run_config.thinking_level 537 else: 538 thinking_level = provider.default_thinking_level 539 540 # Skip if provider doesn't support thinking levels (stale configs may still have one set) 541 if ( 542 thinking_level is not None 543 and provider.available_thinking_levels is not None 544 ): 545 # Anthropic models in OpenRouter uses reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens 546 if ( 547 provider.name == ModelProviderName.openrouter 548 and provider.openrouter_reasoning_object 549 ): 550 extra_body["reasoning"] = {"effort": thinking_level} 551 elif ( 552 provider.name == ModelProviderName.anthropic 553 and thinking_level == "none" 554 ): 555 # Anthropic's native API has no reasoning_effort="none"; passing it makes 556 # litellm map thinking to None and then crash. Omitting reasoning_effort 557 # disables extended thinking, which also frees temperature from the 558 # temperature=1 requirement that applies whenever thinking is enabled. 559 pass 560 else: 561 extra_body["reasoning_effort"] = thinking_level 562 # Opus 4.7/4.8 default thinking display to "omitted", returning empty 563 # thinking text. Request the summary so reasoning is surfaced. litellm 564 # still maps reasoning_effort to output_config.effort; this only adds the 565 # display to the adaptive thinking object. 566 if ( 567 provider.name == ModelProviderName.anthropic 568 and provider.anthropic_summarized_thinking 569 ): 570 extra_body["thinking"] = { 571 "type": "adaptive", 572 "display": "summarized", 573 } 574 575 if provider.require_openrouter_reasoning: 576 # https://openrouter.ai/docs/use-cases/reasoning-tokens 577 extra_body["reasoning"] = { 578 "exclude": False, 579 } 580 581 if provider.gemini_reasoning_enabled: 582 extra_body["reasoning"] = { 583 "enabled": True, 584 } 585 586 if provider.name == ModelProviderName.openrouter: 587 # Ask OpenRouter to include usage in the response (cost) 588 extra_body["usage"] = {"include": True} 589 590 # Set a default provider order for more deterministic routing. 591 # OpenRouter will ignore providers that don't support the model. 592 # Special cases below (like R1) can override this order. 593 # allow_fallbacks is true by default, but we can override it here. 594 provider_options["order"] = [ 595 "fireworks", 596 "parasail", 597 "together", 598 "deepinfra", 599 "novita", 600 "groq", 601 "amazon-bedrock", 602 "azure", 603 "nebius", 604 ] 605 606 if provider.anthropic_extended_thinking and "thinking" not in extra_body: 607 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 608 609 if provider.r1_openrouter_options: 610 # Require providers that support the reasoning parameter 611 provider_options["require_parameters"] = True 612 # Prefer R1 providers with reasonable perf/quants 613 provider_options["order"] = ["fireworks", "together"] 614 # R1 providers with unreasonable quants 615 provider_options["ignore"] = ["deepinfra"] 616 617 # Only set of this request is to get logprobs. 618 if ( 619 provider.logprobs_openrouter_options 620 and self.base_adapter_config.top_logprobs is not None 621 ): 622 # Don't let OpenRouter choose a provider that doesn't support logprobs. 623 provider_options["require_parameters"] = True 624 # DeepInfra silently fails to return logprobs consistently. 625 provider_options["ignore"] = ["deepinfra"] 626 627 if provider.openrouter_skip_required_parameters: 628 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 629 provider_options["require_parameters"] = False 630 631 # Siliconflow uses a bool flag for thinking, for some models 632 if provider.siliconflow_enable_thinking is not None: 633 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 634 635 if len(provider_options) > 0: 636 extra_body["provider"] = provider_options 637 638 return extra_body 639 640 def litellm_model_id(self) -> str: 641 # The model ID is an interesting combination of format and url endpoint. 642 # It specifics the provider URL/host, but this is overridden if you manually set an api url 643 if self._litellm_model_id: 644 return self._litellm_model_id 645 646 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 647 if litellm_provider_info.is_custom and self._api_base is None: 648 raise ValueError( 649 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 650 ) 651 652 self._litellm_model_id = litellm_provider_info.litellm_model_id 653 return self._litellm_model_id 654 655 def _allowed_openai_params_for_completion_kwargs( 656 self, completion_kwargs: dict[str, Any] 657 ) -> list[str]: 658 """ 659 LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong 660 and we know it is supported, so we whitelist them here and pass that as an allowed_openai_params parameter. 661 """ 662 # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that 663 explicit_allowed_params: Any | list = completion_kwargs.get( 664 "allowed_openai_params", [] 665 ) 666 if not isinstance(explicit_allowed_params, list): 667 raise ValueError( 668 f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}" 669 ) 670 explicit_allowed_params_validated = [ 671 param for param in explicit_allowed_params if isinstance(param, str) 672 ] 673 invalid_count = len(explicit_allowed_params) - len( 674 explicit_allowed_params_validated 675 ) 676 if invalid_count > 0: 677 raise ValueError( 678 f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings" 679 ) 680 681 # these are our own logic 682 automatic_allowed_params: list[str] = [] 683 if "tools" in completion_kwargs: 684 automatic_allowed_params.append("tools") 685 if "tool_choice" in completion_kwargs: 686 automatic_allowed_params.append("tool_choice") 687 688 return list(set(explicit_allowed_params_validated + automatic_allowed_params)) 689 690 async def build_completion_kwargs( 691 self, 692 provider: KilnModelProvider, 693 messages: list[ChatCompletionMessageIncludingLiteLLM], 694 top_logprobs: int | None, 695 skip_response_format: bool = False, 696 ) -> dict[str, Any]: 697 run_config = as_kiln_agent_run_config(self.run_config) 698 extra_body = self.build_extra_body(provider) 699 700 # Merge all parameters into a single kwargs dict for litellm 701 completion_kwargs = { 702 "model": self.litellm_model_id(), 703 "messages": messages, 704 "api_base": self._api_base, 705 "headers": self._headers, 706 "temperature": run_config.temperature, 707 "top_p": run_config.top_p, 708 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 709 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 710 # Better to ignore them than to fail the model call. 711 # https://docs.litellm.ai/docs/completion/input 712 "drop_params": True, 713 **extra_body, 714 **self._additional_body_options, 715 } 716 717 if self.base_adapter_config.automatic_prompt_caching: 718 # Mark the last message for cache control. Litellm's AnthropicCacheControlHook 719 # handles provider-specific injection. Providers auto-cache matching prefixes, 720 # so marking the last message is sufficient for multi-turn conversations. 721 completion_kwargs["cache_control_injection_points"] = [ 722 {"location": "message", "index": -1} 723 ] 724 725 tool_calls = await self.litellm_tools() 726 has_tools = len(tool_calls) > 0 727 if has_tools: 728 completion_kwargs["tools"] = tool_calls 729 completion_kwargs["tool_choice"] = "auto" 730 731 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 732 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 733 if provider.temp_top_p_exclusive: 734 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 735 del completion_kwargs["top_p"] 736 if ( 737 "temperature" in completion_kwargs 738 and completion_kwargs["temperature"] == 1.0 739 ): 740 del completion_kwargs["temperature"] 741 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 742 raise ValueError( 743 "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 744 ) 745 746 if not skip_response_format: 747 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 748 response_format_options = await self.response_format_options() 749 750 # Check for a conflict between tools and response format using tools 751 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 752 if has_tools and "tools" in response_format_options: 753 raise ValueError( 754 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 755 ) 756 757 completion_kwargs.update(response_format_options) 758 759 if top_logprobs is not None: 760 completion_kwargs["logprobs"] = True 761 completion_kwargs["top_logprobs"] = top_logprobs 762 763 # any params listed in this list will be passed to the model regardless of LiteLLM's own validation 764 allowed_openai_params = self._allowed_openai_params_for_completion_kwargs( 765 completion_kwargs 766 ) 767 if len(allowed_openai_params) > 0: 768 completion_kwargs["allowed_openai_params"] = allowed_openai_params 769 770 completion_kwargs["messages"] = sanitize_messages_for_provider(messages) 771 772 return completion_kwargs 773 774 def usage_from_response(self, response: ModelResponse) -> MessageUsage: 775 litellm_usage = response.get("usage", None) 776 777 # LiteLLM isn't consistent in how it returns the cost. 778 cost = response._hidden_params.get("response_cost", None) 779 if cost is None and litellm_usage: 780 cost = litellm_usage.get("cost", None) 781 782 usage = MessageUsage() 783 784 if not litellm_usage and not cost: 785 return usage 786 787 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 788 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 789 usage.output_tokens = litellm_usage.get("completion_tokens", None) 790 usage.total_tokens = litellm_usage.get("total_tokens", None) 791 prompt_details = litellm_usage.get("prompt_tokens_details", None) 792 if prompt_details and hasattr(prompt_details, "cached_tokens"): 793 usage.cached_tokens = prompt_details.cached_tokens 794 elif prompt_details: 795 logger.warning( 796 f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted" 797 ) 798 else: 799 logger.warning( 800 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 801 ) 802 803 if isinstance(cost, float): 804 usage.cost = cost 805 elif cost is not None: 806 # None is allowed, but no other types are expected 807 logger.warning( 808 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 809 ) 810 811 return usage 812 813 async def cached_available_tools(self) -> list[KilnToolInterface]: 814 if self._cached_available_tools is None: 815 self._cached_available_tools = await self.available_tools() 816 return self._cached_available_tools 817 818 async def _tools_for_execution(self) -> list[KilnToolInterface]: 819 """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``).""" 820 registry = await self.cached_available_tools() 821 unmanaged = self.base_adapter_config.unmanaged_tools or [] 822 return registry + unmanaged 823 824 async def litellm_tools(self) -> list[ToolCallDefinition]: 825 available_tools = await self.cached_available_tools() 826 827 registry_defs = [await tool.toolcall_definition() for tool in available_tools] 828 unmanaged = self.base_adapter_config.unmanaged_tools 829 unmanaged_defs = ( 830 [await t.toolcall_definition() for t in unmanaged] if unmanaged else [] 831 ) 832 833 merged = registry_defs + unmanaged_defs 834 seen_names: set[str] = set() 835 for d in merged: 836 name = d["function"]["name"] 837 if name in seen_names: 838 raise ValueError( 839 f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names." 840 ) 841 seen_names.add(name) 842 843 return merged 844 845 async def process_tool_calls( 846 self, tool_calls: list[ChatCompletionMessageToolCall] | None 847 ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]: 848 if tool_calls is None: 849 return None, [] 850 851 assistant_output_from_toolcall: str | None = None 852 tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = [] 853 tool_run_coroutines = [] 854 855 for tool_call in tool_calls: 856 # Kiln "task_response" tool is used for returning structured output via tool calls. 857 # Load the output from the tool call. Also 858 if tool_call.function.name == "task_response": 859 assistant_output_from_toolcall = tool_call.function.arguments 860 continue 861 862 # Process normal tool calls (not the "task_response" tool) 863 tool_name = tool_call.function.name 864 tool = None 865 for tool_option in await self._tools_for_execution(): 866 if await tool_option.name() == tool_name: 867 tool = tool_option 868 break 869 if not tool: 870 raise RuntimeError( 871 f"A tool named '{tool_name}' was invoked by a model, but was not available." 872 ) 873 874 # Parse the arguments and validate them against the tool's schema 875 try: 876 parsed_args = json.loads(tool_call.function.arguments) 877 except json.JSONDecodeError: 878 raise RuntimeError( 879 f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}" 880 ) 881 try: 882 tool_call_definition = await tool.toolcall_definition() 883 json_schema = json.dumps(tool_call_definition["function"]["parameters"]) 884 validate_schema_with_value_error(parsed_args, json_schema) 885 except Exception as e: 886 raise RuntimeError( 887 f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}" 888 ) from e 889 890 # Create context with the calling task's allow_saving setting 891 context = ToolCallContext( 892 allow_saving=self.base_adapter_config.allow_saving 893 ) 894 895 async def run_tool_and_format( 896 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 897 ): 898 result = await t.run(c, **args) 899 return ChatCompletionToolMessageParamWrapper( 900 role="tool", 901 tool_call_id=tc_id, 902 content=result.output, 903 kiln_task_tool_data=result.kiln_task_tool_data 904 if isinstance(result, KilnTaskToolResult) 905 else None, 906 is_error=result.is_error if result.is_error else None, 907 error_message=result.error_message 908 if result.error_message 909 else None, 910 ) 911 912 tool_run_coroutines.append(run_tool_and_format()) 913 914 if tool_run_coroutines: 915 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 916 917 if ( 918 assistant_output_from_toolcall is not None 919 and len(tool_call_response_messages) > 0 920 ): 921 raise RuntimeError( 922 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 923 ) 924 925 return assistant_output_from_toolcall, tool_call_response_messages 926 927 def litellm_message_to_trace_message( 928 self, 929 raw_message: LiteLLMMessage, 930 latency_ms: int | None = None, 931 usage: MessageUsage | None = None, 932 ) -> ChatCompletionAssistantMessageParamWrapper: 933 """ 934 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 935 """ 936 message: ChatCompletionAssistantMessageParamWrapper = { 937 "role": "assistant", 938 } 939 if raw_message.role != "assistant": 940 raise ValueError( 941 "Model returned a message with a role other than assistant. This is not supported." 942 ) 943 944 if hasattr(raw_message, "content"): 945 message["content"] = raw_message.content 946 if hasattr(raw_message, "reasoning_content"): 947 message["reasoning_content"] = raw_message.reasoning_content 948 if hasattr(raw_message, "tool_calls"): 949 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 950 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 951 for litellm_tool_call in raw_message.tool_calls or []: 952 # Optional in the SDK for streaming responses, but should never be None at this point. 953 if litellm_tool_call.function.name is None: 954 raise ValueError( 955 "The model requested a tool call, without providing a function name (required)." 956 ) 957 open_ai_tool_calls.append( 958 ChatCompletionMessageToolCallParam( 959 id=litellm_tool_call.id, 960 type="function", 961 function={ 962 "name": litellm_tool_call.function.name, 963 "arguments": litellm_tool_call.function.arguments, 964 }, 965 ) 966 ) 967 if len(open_ai_tool_calls) > 0: 968 message["tool_calls"] = open_ai_tool_calls 969 970 if latency_ms is not None: 971 message["latency_ms"] = latency_ms 972 973 if usage is not None: 974 message["usage"] = usage 975 976 if not message.get("content") and not message.get("tool_calls"): 977 raise ValueError( 978 "Model returned an assistant message, but no content or tool calls. This is not supported." 979 ) 980 981 return message 982 983 def all_messages_to_trace( 984 self, 985 messages: list[ChatCompletionMessageIncludingLiteLLM], 986 message_latency: dict[int, int] | None = None, 987 message_usage: dict[int, MessageUsage] | None = None, 988 ) -> list[ChatCompletionMessageParam]: 989 """ 990 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 991 992 Non-LiteLLM dict messages pass through unchanged. Any per-message 993 ``usage``/``latency_ms`` already attached to those dicts (e.g. from a 994 seeded prior trace) is preserved. 995 """ 996 trace: list[ChatCompletionMessageParam] = [] 997 for i, message in enumerate(messages): 998 if isinstance(message, LiteLLMMessage): 999 latency_ms = message_latency.get(i) if message_latency else None 1000 usage = message_usage.get(i) if message_usage else None 1001 trace.append( 1002 self.litellm_message_to_trace_message(message, latency_ms, usage) 1003 ) 1004 else: 1005 trace.append(message) 1006 return trace 1007 1008 def _messages_to_trace( 1009 self, 1010 messages: list[ChatCompletionMessageParam], 1011 ) -> list[ChatCompletionMessageParam]: 1012 """Override: the `messages` list may transiently contain LiteLLM 1013 Message objects. Normalize to API-safe shapes for export on error. 1014 """ 1015 return self.all_messages_to_trace(messages) # type: ignore[arg-type]
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
91 def __init__( 92 self, 93 config: LiteLlmConfig, 94 kiln_task: datamodel.Task, 95 base_adapter_config: AdapterConfig | None = None, 96 ): 97 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 98 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 99 self.config = config 100 self._additional_body_options = config.additional_body_options 101 self._api_base = config.base_url 102 self._headers = config.default_headers 103 self._litellm_model_id: str | None = None 104 self._cached_available_tools: list[KilnToolInterface] | None = None 105 106 super().__init__( 107 task=kiln_task, 108 run_config=config.run_config_properties, 109 config=base_adapter_config, 110 ) 111 112 unmanaged_tools = self.base_adapter_config.unmanaged_tools 113 if unmanaged_tools: 114 _validate_unmanaged_tools(unmanaged_tools)
415 async def acompletion_checking_response( 416 self, **kwargs: Any 417 ) -> Tuple[ModelResponse, Choices]: 418 response = await litellm.acompletion(**kwargs) 419 420 if ( 421 not isinstance(response, ModelResponse) 422 or not response.choices 423 or len(response.choices) == 0 424 or not isinstance(response.choices[0], Choices) 425 ): 426 raise RuntimeError( 427 f"Expected ModelResponse with Choices, got {type(response)}." 428 ) 429 return response, response.choices[0]
434 async def response_format_options(self) -> dict[str, Any]: 435 # Unstructured if task isn't structured 436 if not self.has_structured_output(): 437 return {} 438 439 run_config = as_kiln_agent_run_config(self.run_config) 440 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 441 442 match structured_output_mode: 443 case StructuredOutputMode.json_mode: 444 return {"response_format": {"type": "json_object"}} 445 case StructuredOutputMode.json_schema: 446 return self.json_schema_response_format() 447 case StructuredOutputMode.function_calling_weak: 448 return self.tool_call_params(strict=False) 449 case StructuredOutputMode.function_calling: 450 return self.tool_call_params(strict=True) 451 case StructuredOutputMode.json_instructions: 452 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 453 return {} 454 case StructuredOutputMode.json_custom_instructions: 455 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 456 return {} 457 case StructuredOutputMode.json_instruction_and_object: 458 # We set response_format to json_object and also set json instructions in the prompt 459 return {"response_format": {"type": "json_object"}} 460 case StructuredOutputMode.default: 461 provider_name = run_config.model_provider_name 462 if provider_name == ModelProviderName.ollama: 463 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 464 return self.json_schema_response_format() 465 elif provider_name == ModelProviderName.docker_model_runner: 466 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 467 return self.json_schema_response_format() 468 else: 469 # Default to function calling -- it's older than the other modes. Higher compatibility. 470 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 471 strict = provider_name == ModelProviderName.openai 472 return self.tool_call_params(strict=strict) 473 case StructuredOutputMode.unknown: 474 # See above, but this case should never happen. 475 raise ValueError("Structured output mode is unknown.") 476 case _: 477 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type]
479 def json_schema_response_format(self) -> dict[str, Any]: 480 output_schema = self.task.output_schema() 481 if output_schema is None: 482 raise ValueError( 483 "Invalid output schema for this task. Cannot use JSON schema response format." 484 ) 485 output_schema = close_object_schemas(output_schema, strict=True) 486 return { 487 "response_format": { 488 "type": "json_schema", 489 "json_schema": { 490 "name": "task_response", 491 "schema": output_schema, 492 }, 493 } 494 }
496 def tool_call_params(self, strict: bool) -> dict[str, Any]: 497 # Add additional_properties: false to the schema (OpenAI requires this for some models) 498 output_schema = self.task.output_schema() 499 if not isinstance(output_schema, dict): 500 raise ValueError( 501 "Invalid output schema for this task. Can not use tool calls." 502 ) 503 output_schema = close_object_schemas(output_schema, strict=strict) 504 505 function_params = { 506 "name": "task_response", 507 "parameters": output_schema, 508 } 509 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 510 if strict: 511 function_params["strict"] = True 512 513 return { 514 "tools": [ 515 { 516 "type": "function", 517 "function": function_params, 518 } 519 ], 520 "tool_choice": { 521 "type": "function", 522 "function": {"name": "task_response"}, 523 }, 524 }
526 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 527 # Don't love having this logic here. But it's worth the usability improvement 528 # so better to keep it than exclude it. Should figure out how I want to isolate 529 # this sort of logic so it's config driven and can be overridden 530 extra_body: dict[str, Any] = {} 531 provider_options = {} 532 533 run_config = as_kiln_agent_run_config(self.run_config) 534 # For legacy config 'thinking_level' is not set, default to provider's default 535 if "thinking_level" in run_config.model_fields_set: 536 thinking_level = run_config.thinking_level 537 else: 538 thinking_level = provider.default_thinking_level 539 540 # Skip if provider doesn't support thinking levels (stale configs may still have one set) 541 if ( 542 thinking_level is not None 543 and provider.available_thinking_levels is not None 544 ): 545 # Anthropic models in OpenRouter uses reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens 546 if ( 547 provider.name == ModelProviderName.openrouter 548 and provider.openrouter_reasoning_object 549 ): 550 extra_body["reasoning"] = {"effort": thinking_level} 551 elif ( 552 provider.name == ModelProviderName.anthropic 553 and thinking_level == "none" 554 ): 555 # Anthropic's native API has no reasoning_effort="none"; passing it makes 556 # litellm map thinking to None and then crash. Omitting reasoning_effort 557 # disables extended thinking, which also frees temperature from the 558 # temperature=1 requirement that applies whenever thinking is enabled. 559 pass 560 else: 561 extra_body["reasoning_effort"] = thinking_level 562 # Opus 4.7/4.8 default thinking display to "omitted", returning empty 563 # thinking text. Request the summary so reasoning is surfaced. litellm 564 # still maps reasoning_effort to output_config.effort; this only adds the 565 # display to the adaptive thinking object. 566 if ( 567 provider.name == ModelProviderName.anthropic 568 and provider.anthropic_summarized_thinking 569 ): 570 extra_body["thinking"] = { 571 "type": "adaptive", 572 "display": "summarized", 573 } 574 575 if provider.require_openrouter_reasoning: 576 # https://openrouter.ai/docs/use-cases/reasoning-tokens 577 extra_body["reasoning"] = { 578 "exclude": False, 579 } 580 581 if provider.gemini_reasoning_enabled: 582 extra_body["reasoning"] = { 583 "enabled": True, 584 } 585 586 if provider.name == ModelProviderName.openrouter: 587 # Ask OpenRouter to include usage in the response (cost) 588 extra_body["usage"] = {"include": True} 589 590 # Set a default provider order for more deterministic routing. 591 # OpenRouter will ignore providers that don't support the model. 592 # Special cases below (like R1) can override this order. 593 # allow_fallbacks is true by default, but we can override it here. 594 provider_options["order"] = [ 595 "fireworks", 596 "parasail", 597 "together", 598 "deepinfra", 599 "novita", 600 "groq", 601 "amazon-bedrock", 602 "azure", 603 "nebius", 604 ] 605 606 if provider.anthropic_extended_thinking and "thinking" not in extra_body: 607 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 608 609 if provider.r1_openrouter_options: 610 # Require providers that support the reasoning parameter 611 provider_options["require_parameters"] = True 612 # Prefer R1 providers with reasonable perf/quants 613 provider_options["order"] = ["fireworks", "together"] 614 # R1 providers with unreasonable quants 615 provider_options["ignore"] = ["deepinfra"] 616 617 # Only set of this request is to get logprobs. 618 if ( 619 provider.logprobs_openrouter_options 620 and self.base_adapter_config.top_logprobs is not None 621 ): 622 # Don't let OpenRouter choose a provider that doesn't support logprobs. 623 provider_options["require_parameters"] = True 624 # DeepInfra silently fails to return logprobs consistently. 625 provider_options["ignore"] = ["deepinfra"] 626 627 if provider.openrouter_skip_required_parameters: 628 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 629 provider_options["require_parameters"] = False 630 631 # Siliconflow uses a bool flag for thinking, for some models 632 if provider.siliconflow_enable_thinking is not None: 633 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 634 635 if len(provider_options) > 0: 636 extra_body["provider"] = provider_options 637 638 return extra_body
640 def litellm_model_id(self) -> str: 641 # The model ID is an interesting combination of format and url endpoint. 642 # It specifics the provider URL/host, but this is overridden if you manually set an api url 643 if self._litellm_model_id: 644 return self._litellm_model_id 645 646 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 647 if litellm_provider_info.is_custom and self._api_base is None: 648 raise ValueError( 649 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 650 ) 651 652 self._litellm_model_id = litellm_provider_info.litellm_model_id 653 return self._litellm_model_id
690 async def build_completion_kwargs( 691 self, 692 provider: KilnModelProvider, 693 messages: list[ChatCompletionMessageIncludingLiteLLM], 694 top_logprobs: int | None, 695 skip_response_format: bool = False, 696 ) -> dict[str, Any]: 697 run_config = as_kiln_agent_run_config(self.run_config) 698 extra_body = self.build_extra_body(provider) 699 700 # Merge all parameters into a single kwargs dict for litellm 701 completion_kwargs = { 702 "model": self.litellm_model_id(), 703 "messages": messages, 704 "api_base": self._api_base, 705 "headers": self._headers, 706 "temperature": run_config.temperature, 707 "top_p": run_config.top_p, 708 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 709 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 710 # Better to ignore them than to fail the model call. 711 # https://docs.litellm.ai/docs/completion/input 712 "drop_params": True, 713 **extra_body, 714 **self._additional_body_options, 715 } 716 717 if self.base_adapter_config.automatic_prompt_caching: 718 # Mark the last message for cache control. Litellm's AnthropicCacheControlHook 719 # handles provider-specific injection. Providers auto-cache matching prefixes, 720 # so marking the last message is sufficient for multi-turn conversations. 721 completion_kwargs["cache_control_injection_points"] = [ 722 {"location": "message", "index": -1} 723 ] 724 725 tool_calls = await self.litellm_tools() 726 has_tools = len(tool_calls) > 0 727 if has_tools: 728 completion_kwargs["tools"] = tool_calls 729 completion_kwargs["tool_choice"] = "auto" 730 731 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 732 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 733 if provider.temp_top_p_exclusive: 734 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 735 del completion_kwargs["top_p"] 736 if ( 737 "temperature" in completion_kwargs 738 and completion_kwargs["temperature"] == 1.0 739 ): 740 del completion_kwargs["temperature"] 741 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 742 raise ValueError( 743 "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 744 ) 745 746 if not skip_response_format: 747 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 748 response_format_options = await self.response_format_options() 749 750 # Check for a conflict between tools and response format using tools 751 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 752 if has_tools and "tools" in response_format_options: 753 raise ValueError( 754 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 755 ) 756 757 completion_kwargs.update(response_format_options) 758 759 if top_logprobs is not None: 760 completion_kwargs["logprobs"] = True 761 completion_kwargs["top_logprobs"] = top_logprobs 762 763 # any params listed in this list will be passed to the model regardless of LiteLLM's own validation 764 allowed_openai_params = self._allowed_openai_params_for_completion_kwargs( 765 completion_kwargs 766 ) 767 if len(allowed_openai_params) > 0: 768 completion_kwargs["allowed_openai_params"] = allowed_openai_params 769 770 completion_kwargs["messages"] = sanitize_messages_for_provider(messages) 771 772 return completion_kwargs
774 def usage_from_response(self, response: ModelResponse) -> MessageUsage: 775 litellm_usage = response.get("usage", None) 776 777 # LiteLLM isn't consistent in how it returns the cost. 778 cost = response._hidden_params.get("response_cost", None) 779 if cost is None and litellm_usage: 780 cost = litellm_usage.get("cost", None) 781 782 usage = MessageUsage() 783 784 if not litellm_usage and not cost: 785 return usage 786 787 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 788 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 789 usage.output_tokens = litellm_usage.get("completion_tokens", None) 790 usage.total_tokens = litellm_usage.get("total_tokens", None) 791 prompt_details = litellm_usage.get("prompt_tokens_details", None) 792 if prompt_details and hasattr(prompt_details, "cached_tokens"): 793 usage.cached_tokens = prompt_details.cached_tokens 794 elif prompt_details: 795 logger.warning( 796 f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted" 797 ) 798 else: 799 logger.warning( 800 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 801 ) 802 803 if isinstance(cost, float): 804 usage.cost = cost 805 elif cost is not None: 806 # None is allowed, but no other types are expected 807 logger.warning( 808 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 809 ) 810 811 return usage
824 async def litellm_tools(self) -> list[ToolCallDefinition]: 825 available_tools = await self.cached_available_tools() 826 827 registry_defs = [await tool.toolcall_definition() for tool in available_tools] 828 unmanaged = self.base_adapter_config.unmanaged_tools 829 unmanaged_defs = ( 830 [await t.toolcall_definition() for t in unmanaged] if unmanaged else [] 831 ) 832 833 merged = registry_defs + unmanaged_defs 834 seen_names: set[str] = set() 835 for d in merged: 836 name = d["function"]["name"] 837 if name in seen_names: 838 raise ValueError( 839 f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names." 840 ) 841 seen_names.add(name) 842 843 return merged
845 async def process_tool_calls( 846 self, tool_calls: list[ChatCompletionMessageToolCall] | None 847 ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]: 848 if tool_calls is None: 849 return None, [] 850 851 assistant_output_from_toolcall: str | None = None 852 tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = [] 853 tool_run_coroutines = [] 854 855 for tool_call in tool_calls: 856 # Kiln "task_response" tool is used for returning structured output via tool calls. 857 # Load the output from the tool call. Also 858 if tool_call.function.name == "task_response": 859 assistant_output_from_toolcall = tool_call.function.arguments 860 continue 861 862 # Process normal tool calls (not the "task_response" tool) 863 tool_name = tool_call.function.name 864 tool = None 865 for tool_option in await self._tools_for_execution(): 866 if await tool_option.name() == tool_name: 867 tool = tool_option 868 break 869 if not tool: 870 raise RuntimeError( 871 f"A tool named '{tool_name}' was invoked by a model, but was not available." 872 ) 873 874 # Parse the arguments and validate them against the tool's schema 875 try: 876 parsed_args = json.loads(tool_call.function.arguments) 877 except json.JSONDecodeError: 878 raise RuntimeError( 879 f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}" 880 ) 881 try: 882 tool_call_definition = await tool.toolcall_definition() 883 json_schema = json.dumps(tool_call_definition["function"]["parameters"]) 884 validate_schema_with_value_error(parsed_args, json_schema) 885 except Exception as e: 886 raise RuntimeError( 887 f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}" 888 ) from e 889 890 # Create context with the calling task's allow_saving setting 891 context = ToolCallContext( 892 allow_saving=self.base_adapter_config.allow_saving 893 ) 894 895 async def run_tool_and_format( 896 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 897 ): 898 result = await t.run(c, **args) 899 return ChatCompletionToolMessageParamWrapper( 900 role="tool", 901 tool_call_id=tc_id, 902 content=result.output, 903 kiln_task_tool_data=result.kiln_task_tool_data 904 if isinstance(result, KilnTaskToolResult) 905 else None, 906 is_error=result.is_error if result.is_error else None, 907 error_message=result.error_message 908 if result.error_message 909 else None, 910 ) 911 912 tool_run_coroutines.append(run_tool_and_format()) 913 914 if tool_run_coroutines: 915 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 916 917 if ( 918 assistant_output_from_toolcall is not None 919 and len(tool_call_response_messages) > 0 920 ): 921 raise RuntimeError( 922 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 923 ) 924 925 return assistant_output_from_toolcall, tool_call_response_messages
927 def litellm_message_to_trace_message( 928 self, 929 raw_message: LiteLLMMessage, 930 latency_ms: int | None = None, 931 usage: MessageUsage | None = None, 932 ) -> ChatCompletionAssistantMessageParamWrapper: 933 """ 934 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 935 """ 936 message: ChatCompletionAssistantMessageParamWrapper = { 937 "role": "assistant", 938 } 939 if raw_message.role != "assistant": 940 raise ValueError( 941 "Model returned a message with a role other than assistant. This is not supported." 942 ) 943 944 if hasattr(raw_message, "content"): 945 message["content"] = raw_message.content 946 if hasattr(raw_message, "reasoning_content"): 947 message["reasoning_content"] = raw_message.reasoning_content 948 if hasattr(raw_message, "tool_calls"): 949 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 950 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 951 for litellm_tool_call in raw_message.tool_calls or []: 952 # Optional in the SDK for streaming responses, but should never be None at this point. 953 if litellm_tool_call.function.name is None: 954 raise ValueError( 955 "The model requested a tool call, without providing a function name (required)." 956 ) 957 open_ai_tool_calls.append( 958 ChatCompletionMessageToolCallParam( 959 id=litellm_tool_call.id, 960 type="function", 961 function={ 962 "name": litellm_tool_call.function.name, 963 "arguments": litellm_tool_call.function.arguments, 964 }, 965 ) 966 ) 967 if len(open_ai_tool_calls) > 0: 968 message["tool_calls"] = open_ai_tool_calls 969 970 if latency_ms is not None: 971 message["latency_ms"] = latency_ms 972 973 if usage is not None: 974 message["usage"] = usage 975 976 if not message.get("content") and not message.get("tool_calls"): 977 raise ValueError( 978 "Model returned an assistant message, but no content or tool calls. This is not supported." 979 ) 980 981 return message
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
983 def all_messages_to_trace( 984 self, 985 messages: list[ChatCompletionMessageIncludingLiteLLM], 986 message_latency: dict[int, int] | None = None, 987 message_usage: dict[int, MessageUsage] | None = None, 988 ) -> list[ChatCompletionMessageParam]: 989 """ 990 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 991 992 Non-LiteLLM dict messages pass through unchanged. Any per-message 993 ``usage``/``latency_ms`` already attached to those dicts (e.g. from a 994 seeded prior trace) is preserved. 995 """ 996 trace: list[ChatCompletionMessageParam] = [] 997 for i, message in enumerate(messages): 998 if isinstance(message, LiteLLMMessage): 999 latency_ms = message_latency.get(i) if message_latency else None 1000 usage = message_usage.get(i) if message_usage else None 1001 trace.append( 1002 self.litellm_message_to_trace_message(message, latency_ms, usage) 1003 ) 1004 else: 1005 trace.append(message) 1006 return trace
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
Non-LiteLLM dict messages pass through unchanged. Any per-message
usage/latency_ms already attached to those dicts (e.g. from a
seeded prior trace) is preserved.
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- task
- run_config
- base_adapter_config
- output_schema
- input_schema
- model_provider
- invoke
- invoke_returning_run_output
- invoke_openai_stream
- invoke_ai_sdk_stream
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode
- available_tools