kiln_ai.adapters.model_adapters.litellm_adapter
import asyncio
import copy
import json
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM
from kiln_ai.adapters.chat.chat_formatter import chat_message_to_dict
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    MessageUsage,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.datamodel_enums import InputType
from kiln_ai.datamodel.json_schema import (
    close_object_schemas,
    validate_schema_with_value_error,
)
from kiln_ai.datamodel.run_config import (
    KilnAgentRunConfigProperties,
    as_kiln_agent_run_config,
)
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
    sanitize_messages_for_provider,
)

MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)


def _validate_unmanaged_tools(tools: list[KilnToolInterface]) -> None:
    for i, tool in enumerate(tools):
        if not isinstance(tool, KilnToolInterface):
            raise TypeError(
                f"unmanaged_tools[{i}] must be a KilnToolInterface instance, got {type(tool).__name__}"
            )


@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage
    interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None
    message_latency: dict[int, int] | None = None
    message_usage: dict[int, MessageUsage] | None = None


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

        unmanaged_tools = self.base_adapter_config.unmanaged_tools
        if unmanaged_tools:
            _validate_unmanaged_tools(unmanaged_tools)

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top-level turn: from user message to agent message.

        It may handle multiple iterations of tool calls between the user/agent messages if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0
        # Per-LLM-call latency / usage, keyed by index in the messages list.
        # Kept separate because we don't own the LiteLLM message objects.
        message_latency: dict[int, int] = {}
        message_usage: dict[int, MessageUsage] = {}

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for this call
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call (timed)
            start = time.monotonic()
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )
            call_latency_ms = int((time.monotonic() - start) * 1000)

            # count the usage
            call_usage = self.usage_from_response(model_response)
            usage += call_usage
            usage.total_llm_latency_ms = (
                usage.total_llm_latency_ms or 0
            ) + call_latency_ms

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)
            message_latency[len(messages) - 1] = call_latency_ms
            message_usage[len(messages) - 1] = call_usage

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                # check if we should return control to caller
                if self.base_adapter_config.return_on_tool_call:
                    # filter out task_response tool (task_response tools are internal)
                    standard_tool_calls = [
                        tc for tc in tool_calls if tc.function.name != "task_response"
                    ]
                    has_task_response = any(
                        tc.function.name == "task_response" for tc in tool_calls
                    )
                    if standard_tool_calls and not has_task_response:
                        return ModelTurnResult(
                            # we don't have any content, we are waiting for toolcall output to come back from client
                            assistant_message="",
                            all_messages=messages,
                            model_response=model_response,
                            model_choice=response_choice,
                            usage=usage,
                            interrupted_by_tool_calls=standard_tool_calls,
                            message_latency=message_latency,
                            message_usage=message_usage,
                        )

                # otherwise: process tool calls internally until final output
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                        message_latency=message_latency,
                        message_usage=message_usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                    message_latency=message_latency,
                    message_usage=message_usage,
                )

            # If we get here with no content and no tool calls, raise
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(
        self,
        input: InputType,
        prior_trace: list[ChatCompletionMessageParam] | None = None,
    ) -> tuple[RunOutput, Usage | None]:
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else a prompt-based formatter
        chat_formatter = self.build_chat_formatter(input, prior_trace)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
            chat_formatter.initial_messages()
        )

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0
        message_latency: dict[int, int] = {}
        message_usage: dict[int, MessageUsage] = {}

        # Same loop for both fresh runs and prior_trace continuation.
        # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues).
        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                messages.append(chat_message_to_dict(message))  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage
            if turn_result.message_latency:
                message_latency.update(turn_result.message_latency)
            if turn_result.message_usage:
                message_usage.update(turn_result.message_usage)

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

            # Check if we were interrupted by tool calls
            if turn_result.interrupted_by_tool_calls:
                trace = self.all_messages_to_trace(
                    messages, message_latency, message_usage
                )
                intermediate_outputs = chat_formatter.intermediate_outputs()
                output = RunOutput(
                    output=prior_output or "",
                    intermediate_outputs=intermediate_outputs,
                    output_logprobs=None,
                    trace=trace,
                )
                return output, usage

        if not prior_output:
            raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages, message_latency, message_usage)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _create_run_stream(
        self,
        input: InputType,
        prior_trace: list[ChatCompletionMessageParam] | None = None,
    ) -> AdapterStream:
        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input, prior_trace)
        initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
            chat_formatter.initial_messages()
        )

        return AdapterStream(
            adapter=self,
            provider=provider,
            chat_formatter=chat_formatter,
            initial_messages=initial_messages,
            top_logprobs=self.base_adapter_config.top_logprobs,
        )

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from the model choice and add it to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs: Any
    ) -> Tuple[ModelResponse, Choices]:
        response = await litellm.acompletion(**kwargs)

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        run_config = as_kiln_agent_run_config(self.run_config)
        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task.output_schema()
        if output_schema is None:
            raise ValueError(
                "Invalid output schema for this task. Cannot use JSON schema response format."
            )
        output_schema = close_object_schemas(output_schema, strict=True)
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Cannot use tool calls."
            )
        output_schema = close_object_schemas(output_schema, strict=strict)

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # Don't love having this logic here. But it's worth the usability improvement,
        # so better to keep it than exclude it. Should figure out how to isolate
        # this sort of logic so it's config driven and can be overridden.
        extra_body: dict[str, Any] = {}
        provider_options = {}

        run_config = as_kiln_agent_run_config(self.run_config)
        # For legacy configs 'thinking_level' is not set; default to the provider's default
        if "thinking_level" in run_config.model_fields_set:
            thinking_level = run_config.thinking_level
        else:
            thinking_level = provider.default_thinking_level

        # Set the reasoning_effort
        if thinking_level is not None:
            # Anthropic models in OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
            if (
                provider.name == ModelProviderName.openrouter
                and provider.openrouter_reasoning_object
            ):
                extra_body["reasoning"] = {"effort": thinking_level}
            else:
                extra_body["reasoning_effort"] = thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # Only set these options if the goal of this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case: R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # SiliconFlow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an api url
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    def _allowed_openai_params_for_completion_kwargs(
        self, completion_kwargs: dict[str, Any]
    ) -> list[str]:
        """
        LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong
        and we know the param is supported, so we whitelist it here and pass it via the allowed_openai_params parameter.
        """
        # Callers could have set allowed_openai_params in the additional_body_options, so we need to check for that
        explicit_allowed_params: Any | list = completion_kwargs.get(
            "allowed_openai_params", []
        )
        if not isinstance(explicit_allowed_params, list):
            raise ValueError(
                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}"
            )
        explicit_allowed_params_validated = [
            param for param in explicit_allowed_params if isinstance(param, str)
        ]
        invalid_count = len(explicit_allowed_params) - len(
            explicit_allowed_params_validated
        )
        if invalid_count > 0:
            raise ValueError(
                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings"
            )

        # These are added by our own logic
        automatic_allowed_params: list[str] = []
        if "tools" in completion_kwargs:
            automatic_allowed_params.append("tools")
        if "tool_choice" in completion_kwargs:
            automatic_allowed_params.append("tool_choice")

        return list(set(explicit_allowed_params_validated + automatic_allowed_params))

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        run_config = as_kiln_agent_run_config(self.run_config)
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": run_config.temperature,
            "top_p": run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p).
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        if self.base_adapter_config.automatic_prompt_caching:
            # Mark the last message for cache control. LiteLLM's AnthropicCacheControlHook
            # handles provider-specific injection. Providers auto-cache matching prefixes,
            # so marking the last message is sufficient for multi-turn conversations.
            completion_kwargs["cache_control_injection_points"] = [
                {"location": "message", "index": -1}
            ]

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and a response format that also uses tools.
            # We could reconsider this. The model could choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        # Any params in this list will be passed to the model regardless of LiteLLM's own validation
        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
            completion_kwargs
        )
        if len(allowed_openai_params) > 0:
            completion_kwargs["allowed_openai_params"] = allowed_openai_params

        completion_kwargs["messages"] = sanitize_messages_for_provider(messages)

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> MessageUsage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = MessageUsage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
            prompt_details = litellm_usage.get("prompt_tokens_details", None)
            if prompt_details and hasattr(prompt_details, "cached_tokens"):
                usage.cached_tokens = prompt_details.cached_tokens
            elif prompt_details:
                logger.warning(
                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
                )
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def _tools_for_execution(self) -> list[KilnToolInterface]:
        """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``)."""
        registry = await self.cached_available_tools()
        unmanaged = self.base_adapter_config.unmanaged_tools or []
        return registry + unmanaged

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        available_tools = await self.cached_available_tools()

        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
        unmanaged = self.base_adapter_config.unmanaged_tools
        unmanaged_defs = (
            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
        )

        merged = registry_defs + unmanaged_defs
        seen_names: set[str] = set()
        for d in merged:
            name = d["function"]["name"]
            if name in seen_names:
                raise ValueError(
                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
                )
            seen_names.add(name)

        return merged

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln's "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call and skip normal tool processing.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self._tools_for_execution():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                    is_error=result.is_error if result.is_error else None,
                    error_message=result.error_message
                    if result.error_message
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for an impossible combination: a task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self,
        raw_message: LiteLLMMessage,
        latency_ms: int | None = None,
        usage: MessageUsage | None = None,
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper.
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if latency_ms is not None:
            message["latency_ms"] = latency_ms

        if usage is not None:
            message["usage"] = usage

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        message_latency: dict[int, int] | None = None,
        message_usage: dict[int, MessageUsage] | None = None,
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for the trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.

        Non-LiteLLM dict messages pass through unchanged. Any per-message
        ``usage``/``latency_ms`` already attached to those dicts (e.g. from a
        seeded prior trace) is preserved.
        """
        trace: list[ChatCompletionMessageParam] = []
        for i, message in enumerate(messages):
            if isinstance(message, LiteLLMMessage):
                latency_ms = message_latency.get(i) if message_latency else None
                usage = message_usage.get(i) if message_usage else None
                trace.append(
                    self.litellm_message_to_trace_message(message, latency_ms, usage)
                )
            else:
                trace.append(message)
        return trace
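
Below is a minimal usage sketch, not part of the module itself: it only exercises the __init__ and litellm_model_id signatures shown above, and it assumes a pre-built LiteLlmConfig (with KilnAgentRunConfigProperties) and Task; how those objects are constructed is out of scope here.

# --- Usage sketch (illustrative assumption, not part of the module) ---------
import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig


def make_adapter(config: LiteLlmConfig, task: datamodel.Task) -> LiteLlmAdapter:
    # __init__ raises ValueError unless config.run_config_properties is a
    # KilnAgentRunConfigProperties (see the isinstance check in the class above).
    adapter = LiteLlmAdapter(config=config, kiln_task=task)
    # litellm_model_id() resolves and caches the provider-prefixed model ID; it
    # raises if a custom/OpenAI-compatible provider is used without a base URL.
    print(adapter.litellm_model_id())
    return adapter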
The arguments were: {parsed_args}\n The error was: {e}" 852 ) from e 853 854 # Create context with the calling task's allow_saving setting 855 context = ToolCallContext( 856 allow_saving=self.base_adapter_config.allow_saving 857 ) 858 859 async def run_tool_and_format( 860 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 861 ): 862 result = await t.run(c, **args) 863 return ChatCompletionToolMessageParamWrapper( 864 role="tool", 865 tool_call_id=tc_id, 866 content=result.output, 867 kiln_task_tool_data=result.kiln_task_tool_data 868 if isinstance(result, KilnTaskToolResult) 869 else None, 870 is_error=result.is_error if result.is_error else None, 871 error_message=result.error_message 872 if result.error_message 873 else None, 874 ) 875 876 tool_run_coroutines.append(run_tool_and_format()) 877 878 if tool_run_coroutines: 879 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 880 881 if ( 882 assistant_output_from_toolcall is not None 883 and len(tool_call_response_messages) > 0 884 ): 885 raise RuntimeError( 886 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 887 ) 888 889 return assistant_output_from_toolcall, tool_call_response_messages 890 891 def litellm_message_to_trace_message( 892 self, 893 raw_message: LiteLLMMessage, 894 latency_ms: int | None = None, 895 usage: MessageUsage | None = None, 896 ) -> ChatCompletionAssistantMessageParamWrapper: 897 """ 898 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 899 """ 900 message: ChatCompletionAssistantMessageParamWrapper = { 901 "role": "assistant", 902 } 903 if raw_message.role != "assistant": 904 raise ValueError( 905 "Model returned a message with a role other than assistant. This is not supported." 906 ) 907 908 if hasattr(raw_message, "content"): 909 message["content"] = raw_message.content 910 if hasattr(raw_message, "reasoning_content"): 911 message["reasoning_content"] = raw_message.reasoning_content 912 if hasattr(raw_message, "tool_calls"): 913 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 914 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 915 for litellm_tool_call in raw_message.tool_calls or []: 916 # Optional in the SDK for streaming responses, but should never be None at this point. 917 if litellm_tool_call.function.name is None: 918 raise ValueError( 919 "The model requested a tool call, without providing a function name (required)." 920 ) 921 open_ai_tool_calls.append( 922 ChatCompletionMessageToolCallParam( 923 id=litellm_tool_call.id, 924 type="function", 925 function={ 926 "name": litellm_tool_call.function.name, 927 "arguments": litellm_tool_call.function.arguments, 928 }, 929 ) 930 ) 931 if len(open_ai_tool_calls) > 0: 932 message["tool_calls"] = open_ai_tool_calls 933 934 if latency_ms is not None: 935 message["latency_ms"] = latency_ms 936 937 if usage is not None: 938 message["usage"] = usage 939 940 if not message.get("content") and not message.get("tool_calls"): 941 raise ValueError( 942 "Model returned an assistant message, but no content or tool calls. This is not supported." 
943 ) 944 945 return message 946 947 def all_messages_to_trace( 948 self, 949 messages: list[ChatCompletionMessageIncludingLiteLLM], 950 message_latency: dict[int, int] | None = None, 951 message_usage: dict[int, MessageUsage] | None = None, 952 ) -> list[ChatCompletionMessageParam]: 953 """ 954 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 955 956 Non-LiteLLM dict messages pass through unchanged. Any per-message 957 ``usage``/``latency_ms`` already attached to those dicts (e.g. from a 958 seeded prior trace) is preserved. 959 """ 960 trace: list[ChatCompletionMessageParam] = [] 961 for i, message in enumerate(messages): 962 if isinstance(message, LiteLLMMessage): 963 latency_ms = message_latency.get(i) if message_latency else None 964 usage = message_usage.get(i) if message_usage else None 965 trace.append( 966 self.litellm_message_to_trace_message(message, latency_ms, usage) 967 ) 968 else: 969 trace.append(message) 970 return trace
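The helper _allowed_openai_params_for_completion_kwargs, shown only in the full source listing above, merges any caller-supplied allowed_openai_params from additional_body_options with the parameters the adapter itself needs whitelisted (tools and tool_choice when tools are attached). A toy sketch of that merge; the "reasoning_effort" entry is an invented caller value used purely for illustration:

completion_kwargs = {
    "allowed_openai_params": ["reasoning_effort"],  # illustrative caller-supplied whitelist
    "tools": [],                                     # present, so "tools" must also be whitelisted
    "tool_choice": "auto",
}
explicit = completion_kwargs.get("allowed_openai_params", [])
automatic = [k for k in ("tools", "tool_choice") if k in completion_kwargs]
# De-duplicated union, order not guaranteed: reasoning_effort, tools, tool_choice
allowed = list(set(explicit + automatic))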
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
91 def __init__( 92 self, 93 config: LiteLlmConfig, 94 kiln_task: datamodel.Task, 95 base_adapter_config: AdapterConfig | None = None, 96 ): 97 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 98 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 99 self.config = config 100 self._additional_body_options = config.additional_body_options 101 self._api_base = config.base_url 102 self._headers = config.default_headers 103 self._litellm_model_id: str | None = None 104 self._cached_available_tools: list[KilnToolInterface] | None = None 105 106 super().__init__( 107 task=kiln_task, 108 run_config=config.run_config_properties, 109 config=base_adapter_config, 110 ) 111 112 unmanaged_tools = self.base_adapter_config.unmanaged_tools 113 if unmanaged_tools: 114 _validate_unmanaged_tools(unmanaged_tools)
403 async def acompletion_checking_response( 404 self, **kwargs: Any 405 ) -> Tuple[ModelResponse, Choices]: 406 response = await litellm.acompletion(**kwargs) 407 408 if ( 409 not isinstance(response, ModelResponse) 410 or not response.choices 411 or len(response.choices) == 0 412 or not isinstance(response.choices[0], Choices) 413 ): 414 raise RuntimeError( 415 f"Expected ModelResponse with Choices, got {type(response)}." 416 ) 417 return response, response.choices[0]
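In other words, this wrapper is a type guard on whatever litellm.acompletion returns. A minimal sketch of that contract, using only the litellm types this module already imports:

from litellm.types.utils import Choices, ModelResponse

def looks_like_valid_completion(response: object) -> bool:
    # Mirrors the guard in acompletion_checking_response: anything that is not a
    # ModelResponse with at least one Choices entry is rejected with a RuntimeError.
    return (
        isinstance(response, ModelResponse)
        and bool(response.choices)
        and isinstance(response.choices[0], Choices)
    )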
422 async def response_format_options(self) -> dict[str, Any]: 423 # Unstructured if task isn't structured 424 if not self.has_structured_output(): 425 return {} 426 427 run_config = as_kiln_agent_run_config(self.run_config) 428 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 429 430 match structured_output_mode: 431 case StructuredOutputMode.json_mode: 432 return {"response_format": {"type": "json_object"}} 433 case StructuredOutputMode.json_schema: 434 return self.json_schema_response_format() 435 case StructuredOutputMode.function_calling_weak: 436 return self.tool_call_params(strict=False) 437 case StructuredOutputMode.function_calling: 438 return self.tool_call_params(strict=True) 439 case StructuredOutputMode.json_instructions: 440 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 441 return {} 442 case StructuredOutputMode.json_custom_instructions: 443 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 444 return {} 445 case StructuredOutputMode.json_instruction_and_object: 446 # We set response_format to json_object and also set json instructions in the prompt 447 return {"response_format": {"type": "json_object"}} 448 case StructuredOutputMode.default: 449 provider_name = run_config.model_provider_name 450 if provider_name == ModelProviderName.ollama: 451 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 452 return self.json_schema_response_format() 453 elif provider_name == ModelProviderName.docker_model_runner: 454 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 455 return self.json_schema_response_format() 456 else: 457 # Default to function calling -- it's older than the other modes. Higher compatibility. 458 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 459 strict = provider_name == ModelProviderName.openai 460 return self.tool_call_params(strict=strict) 461 case StructuredOutputMode.unknown: 462 # See above, but this case should never happen. 463 raise ValueError("Structured output mode is unknown.") 464 case _: 465 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type]
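For quick reference, a sketch of the payloads the simpler branches hand back to LiteLLM; the json_schema and function-calling branches return the richer payloads shown under json_schema_response_format and tool_call_params below.

simple_mode_payloads = {
    "json_mode": {"response_format": {"type": "json_object"}},
    "json_instruction_and_object": {"response_format": {"type": "json_object"}},
    "json_instructions": {},         # schema is injected into the prompt instead
    "json_custom_instructions": {},  # schema lives in the user's own system prompt
}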
467 def json_schema_response_format(self) -> dict[str, Any]: 468 output_schema = self.task.output_schema() 469 if output_schema is None: 470 raise ValueError( 471 "Invalid output schema for this task. Cannot use JSON schema response format." 472 ) 473 output_schema = close_object_schemas(output_schema, strict=True) 474 return { 475 "response_format": { 476 "type": "json_schema", 477 "json_schema": { 478 "name": "task_response", 479 "schema": output_schema, 480 }, 481 } 482 }
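As an illustration, for a task whose output schema is the made-up single-field object below, the returned payload would look like this (the "task_response" name and the closed-object form come from the listing above):

toy_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    "additionalProperties": False,  # closed-object form, per close_object_schemas(strict=True)
}
json_schema_payload = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": toy_schema},
    }
}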
484 def tool_call_params(self, strict: bool) -> dict[str, Any]: 485 # Add additional_properties: false to the schema (OpenAI requires this for some models) 486 output_schema = self.task.output_schema() 487 if not isinstance(output_schema, dict): 488 raise ValueError( 489 "Invalid output schema for this task. Can not use tool calls." 490 ) 491 output_schema = close_object_schemas(output_schema, strict=strict) 492 493 function_params = { 494 "name": "task_response", 495 "parameters": output_schema, 496 } 497 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 498 if strict: 499 function_params["strict"] = True 500 501 return { 502 "tools": [ 503 { 504 "type": "function", 505 "function": function_params, 506 } 507 ], 508 "tool_choice": { 509 "type": "function", 510 "function": {"name": "task_response"}, 511 }, 512 }
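A sketch of the equivalent function-calling payload for the same kind of toy schema; with strict=False the "strict" key is simply omitted:

function_calling_payload = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                    "additionalProperties": False,
                },
                "strict": True,  # only included when strict=True
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}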
514 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 515 # Don't love having this logic here. But it's worth the usability improvement 516 # so better to keep it than exclude it. Should figure out how I want to isolate 517 # this sort of logic so it's config driven and can be overridden 518 extra_body: dict[str, Any] = {} 519 provider_options = {} 520 521 run_config = as_kiln_agent_run_config(self.run_config) 522 # For legacy config 'thinking_level' is not set, default to provider's default 523 if "thinking_level" in run_config.model_fields_set: 524 thinking_level = run_config.thinking_level 525 else: 526 thinking_level = provider.default_thinking_level 527 528 # Set the reasoning_effort 529 if thinking_level is not None: 530 # Anthropic models in OpenRouter uses reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens 531 if ( 532 provider.name == ModelProviderName.openrouter 533 and provider.openrouter_reasoning_object 534 ): 535 extra_body["reasoning"] = {"effort": thinking_level} 536 else: 537 extra_body["reasoning_effort"] = thinking_level 538 539 if provider.require_openrouter_reasoning: 540 # https://openrouter.ai/docs/use-cases/reasoning-tokens 541 extra_body["reasoning"] = { 542 "exclude": False, 543 } 544 545 if provider.gemini_reasoning_enabled: 546 extra_body["reasoning"] = { 547 "enabled": True, 548 } 549 550 if provider.name == ModelProviderName.openrouter: 551 # Ask OpenRouter to include usage in the response (cost) 552 extra_body["usage"] = {"include": True} 553 554 # Set a default provider order for more deterministic routing. 555 # OpenRouter will ignore providers that don't support the model. 556 # Special cases below (like R1) can override this order. 557 # allow_fallbacks is true by default, but we can override it here. 558 provider_options["order"] = [ 559 "fireworks", 560 "parasail", 561 "together", 562 "deepinfra", 563 "novita", 564 "groq", 565 "amazon-bedrock", 566 "azure", 567 "nebius", 568 ] 569 570 if provider.anthropic_extended_thinking: 571 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 572 573 if provider.r1_openrouter_options: 574 # Require providers that support the reasoning parameter 575 provider_options["require_parameters"] = True 576 # Prefer R1 providers with reasonable perf/quants 577 provider_options["order"] = ["fireworks", "together"] 578 # R1 providers with unreasonable quants 579 provider_options["ignore"] = ["deepinfra"] 580 581 # Only set of this request is to get logprobs. 582 if ( 583 provider.logprobs_openrouter_options 584 and self.base_adapter_config.top_logprobs is not None 585 ): 586 # Don't let OpenRouter choose a provider that doesn't support logprobs. 587 provider_options["require_parameters"] = True 588 # DeepInfra silently fails to return logprobs consistently. 589 provider_options["ignore"] = ["deepinfra"] 590 591 if provider.openrouter_skip_required_parameters: 592 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 593 provider_options["require_parameters"] = False 594 595 # Siliconflow uses a bool flag for thinking, for some models 596 if provider.siliconflow_enable_thinking is not None: 597 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 598 599 if len(provider_options) > 0: 600 extra_body["provider"] = provider_options 601 602 return extra_body
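Putting a few of those branches together, a hedged example of what extra_body might look like for an OpenRouter provider with a thinking level set and top_logprobs requested. The effort value and the flag combination are illustrative only; the actual contents depend on the KilnModelProvider flags.

openrouter_extra_body = {
    "reasoning": {"effort": "medium"},   # openrouter_reasoning_object path
    "usage": {"include": True},          # ask OpenRouter to report cost
    "provider": {
        "order": [
            "fireworks", "parasail", "together", "deepinfra", "novita",
            "groq", "amazon-bedrock", "azure", "nebius",
        ],
        "require_parameters": True,      # logprobs_openrouter_options path
        "ignore": ["deepinfra"],         # DeepInfra drops logprobs inconsistently
    },
}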
604 def litellm_model_id(self) -> str: 605 # The model ID is an interesting combination of format and URL endpoint. 606 # It specifies the provider URL/host, but this is overridden if you manually set an API URL 607 if self._litellm_model_id: 608 return self._litellm_model_id 609 610 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 611 if litellm_provider_info.is_custom and self._api_base is None: 612 raise ValueError( 613 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 614 ) 615 616 self._litellm_model_id = litellm_provider_info.litellm_model_id 617 return self._litellm_model_id
654 async def build_completion_kwargs( 655 self, 656 provider: KilnModelProvider, 657 messages: list[ChatCompletionMessageIncludingLiteLLM], 658 top_logprobs: int | None, 659 skip_response_format: bool = False, 660 ) -> dict[str, Any]: 661 run_config = as_kiln_agent_run_config(self.run_config) 662 extra_body = self.build_extra_body(provider) 663 664 # Merge all parameters into a single kwargs dict for litellm 665 completion_kwargs = { 666 "model": self.litellm_model_id(), 667 "messages": messages, 668 "api_base": self._api_base, 669 "headers": self._headers, 670 "temperature": run_config.temperature, 671 "top_p": run_config.top_p, 672 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 673 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 674 # Better to ignore them than to fail the model call. 675 # https://docs.litellm.ai/docs/completion/input 676 "drop_params": True, 677 **extra_body, 678 **self._additional_body_options, 679 } 680 681 if self.base_adapter_config.automatic_prompt_caching: 682 # Mark the last message for cache control. Litellm's AnthropicCacheControlHook 683 # handles provider-specific injection. Providers auto-cache matching prefixes, 684 # so marking the last message is sufficient for multi-turn conversations. 685 completion_kwargs["cache_control_injection_points"] = [ 686 {"location": "message", "index": -1} 687 ] 688 689 tool_calls = await self.litellm_tools() 690 has_tools = len(tool_calls) > 0 691 if has_tools: 692 completion_kwargs["tools"] = tool_calls 693 completion_kwargs["tool_choice"] = "auto" 694 695 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 696 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 697 if provider.temp_top_p_exclusive: 698 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 699 del completion_kwargs["top_p"] 700 if ( 701 "temperature" in completion_kwargs 702 and completion_kwargs["temperature"] == 1.0 703 ): 704 del completion_kwargs["temperature"] 705 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 706 raise ValueError( 707 "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 708 ) 709 710 if not skip_response_format: 711 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 712 response_format_options = await self.response_format_options() 713 714 # Check for a conflict between tools and response format using tools 715 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 716 if has_tools and "tools" in response_format_options: 717 raise ValueError( 718 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 
719 ) 720 721 completion_kwargs.update(response_format_options) 722 723 if top_logprobs is not None: 724 completion_kwargs["logprobs"] = True 725 completion_kwargs["top_logprobs"] = top_logprobs 726 727 # any params listed in this list will be passed to the model regardless of LiteLLM's own validation 728 allowed_openai_params = self._allowed_openai_params_for_completion_kwargs( 729 completion_kwargs 730 ) 731 if len(allowed_openai_params) > 0: 732 completion_kwargs["allowed_openai_params"] = allowed_openai_params 733 734 completion_kwargs["messages"] = sanitize_messages_for_provider(messages) 735 736 return completion_kwargs
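A trimmed, illustrative example of the final kwargs dict for an OpenRouter call with tools enabled and logprobs requested. Keys mirror the assembly above; the model id, messages, and values are placeholders, not verified settings.

completion_kwargs_example = {
    "model": "openrouter/deepseek/deepseek-chat",   # hypothetical litellm model id
    "messages": [{"role": "user", "content": "..."}],
    "api_base": None,
    "headers": None,
    "temperature": 0.7,
    "top_p": 1.0,
    "drop_params": True,
    "usage": {"include": True},                     # merged in from build_extra_body
    "tools": [{"type": "function", "function": {"name": "search", "parameters": {}}}],
    "tool_choice": "auto",
    "logprobs": True,                               # only when top_logprobs is set
    "top_logprobs": 5,
    "allowed_openai_params": ["tools", "tool_choice"],
}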
738 def usage_from_response(self, response: ModelResponse) -> MessageUsage: 739 litellm_usage = response.get("usage", None) 740 741 # LiteLLM isn't consistent in how it returns the cost. 742 cost = response._hidden_params.get("response_cost", None) 743 if cost is None and litellm_usage: 744 cost = litellm_usage.get("cost", None) 745 746 usage = MessageUsage() 747 748 if not litellm_usage and not cost: 749 return usage 750 751 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 752 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 753 usage.output_tokens = litellm_usage.get("completion_tokens", None) 754 usage.total_tokens = litellm_usage.get("total_tokens", None) 755 prompt_details = litellm_usage.get("prompt_tokens_details", None) 756 if prompt_details and hasattr(prompt_details, "cached_tokens"): 757 usage.cached_tokens = prompt_details.cached_tokens 758 elif prompt_details: 759 logger.warning( 760 f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted" 761 ) 762 else: 763 logger.warning( 764 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 765 ) 766 767 if isinstance(cost, float): 768 usage.cost = cost 769 elif cost is not None: 770 # None is allowed, but no other types are expected 771 logger.warning( 772 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 773 ) 774 775 return usage
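Field by field, the mapping the method performs looks like this; the numbers are made up, and MessageUsage is the type imported at the top of this module:

from kiln_ai.adapters.model_adapters.base_adapter import MessageUsage

usage = MessageUsage()
usage.input_tokens = 812      # litellm_usage["prompt_tokens"]
usage.output_tokens = 96      # litellm_usage["completion_tokens"]
usage.total_tokens = 908      # litellm_usage["total_tokens"]
usage.cached_tokens = 512     # litellm_usage["prompt_tokens_details"].cached_tokens
usage.cost = 0.00042          # response._hidden_params["response_cost"], else usage["cost"]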
788 async def litellm_tools(self) -> list[ToolCallDefinition]: 789 available_tools = await self.cached_available_tools() 790 791 registry_defs = [await tool.toolcall_definition() for tool in available_tools] 792 unmanaged = self.base_adapter_config.unmanaged_tools 793 unmanaged_defs = ( 794 [await t.toolcall_definition() for t in unmanaged] if unmanaged else [] 795 ) 796 797 merged = registry_defs + unmanaged_defs 798 seen_names: set[str] = set() 799 for d in merged: 800 name = d["function"]["name"] 801 if name in seen_names: 802 raise ValueError( 803 f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names." 804 ) 805 seen_names.add(name) 806 807 return merged
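The merged result is a list of plain OpenAI-style tool definitions, registry tools first and unmanaged tools after. A toy illustration of the shape and the uniqueness rule (names and empty parameter schemas are invented):

registry_defs = [
    {"type": "function", "function": {"name": "search", "parameters": {"type": "object", "properties": {}}}},
]
unmanaged_defs = [
    {"type": "function", "function": {"name": "calculator", "parameters": {"type": "object", "properties": {}}}},
]
merged = registry_defs + unmanaged_defs
# Reusing "search" in unmanaged_defs would raise:
# ValueError: Duplicate tool name 'search': unmanaged and registry tools must have unique names.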
809 async def process_tool_calls( 810 self, tool_calls: list[ChatCompletionMessageToolCall] | None 811 ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]: 812 if tool_calls is None: 813 return None, [] 814 815 assistant_output_from_toolcall: str | None = None 816 tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = [] 817 tool_run_coroutines = [] 818 819 for tool_call in tool_calls: 820 # Kiln "task_response" tool is used for returning structured output via tool calls. 821 # Load the output from the tool call. Also 822 if tool_call.function.name == "task_response": 823 assistant_output_from_toolcall = tool_call.function.arguments 824 continue 825 826 # Process normal tool calls (not the "task_response" tool) 827 tool_name = tool_call.function.name 828 tool = None 829 for tool_option in await self._tools_for_execution(): 830 if await tool_option.name() == tool_name: 831 tool = tool_option 832 break 833 if not tool: 834 raise RuntimeError( 835 f"A tool named '{tool_name}' was invoked by a model, but was not available." 836 ) 837 838 # Parse the arguments and validate them against the tool's schema 839 try: 840 parsed_args = json.loads(tool_call.function.arguments) 841 except json.JSONDecodeError: 842 raise RuntimeError( 843 f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}" 844 ) 845 try: 846 tool_call_definition = await tool.toolcall_definition() 847 json_schema = json.dumps(tool_call_definition["function"]["parameters"]) 848 validate_schema_with_value_error(parsed_args, json_schema) 849 except Exception as e: 850 raise RuntimeError( 851 f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}" 852 ) from e 853 854 # Create context with the calling task's allow_saving setting 855 context = ToolCallContext( 856 allow_saving=self.base_adapter_config.allow_saving 857 ) 858 859 async def run_tool_and_format( 860 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 861 ): 862 result = await t.run(c, **args) 863 return ChatCompletionToolMessageParamWrapper( 864 role="tool", 865 tool_call_id=tc_id, 866 content=result.output, 867 kiln_task_tool_data=result.kiln_task_tool_data 868 if isinstance(result, KilnTaskToolResult) 869 else None, 870 is_error=result.is_error if result.is_error else None, 871 error_message=result.error_message 872 if result.error_message 873 else None, 874 ) 875 876 tool_run_coroutines.append(run_tool_and_format()) 877 878 if tool_run_coroutines: 879 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 880 881 if ( 882 assistant_output_from_toolcall is not None 883 and len(tool_call_response_messages) > 0 884 ): 885 raise RuntimeError( 886 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 887 ) 888 889 return assistant_output_from_toolcall, tool_call_response_messages
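Two illustrative outcomes of a single turn; field names follow ChatCompletionToolMessageParamWrapper as used above, while the ids and contents are invented:

# Case 1: the model called the special "task_response" tool. Its raw JSON arguments
# become the assistant output and no tool messages are produced.
assistant_output_from_toolcall = '{"answer": "42"}'

# Case 2: the model called a regular tool. Each call yields one role="tool" message.
tool_response_message = {
    "role": "tool",
    "tool_call_id": "call_abc123",
    "content": "search results ...",
    "kiln_task_tool_data": None,   # populated only for KilnTaskToolResult results
    "is_error": None,
    "error_message": None,
}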
891 def litellm_message_to_trace_message( 892 self, 893 raw_message: LiteLLMMessage, 894 latency_ms: int | None = None, 895 usage: MessageUsage | None = None, 896 ) -> ChatCompletionAssistantMessageParamWrapper: 897 """ 898 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 899 """ 900 message: ChatCompletionAssistantMessageParamWrapper = { 901 "role": "assistant", 902 } 903 if raw_message.role != "assistant": 904 raise ValueError( 905 "Model returned a message with a role other than assistant. This is not supported." 906 ) 907 908 if hasattr(raw_message, "content"): 909 message["content"] = raw_message.content 910 if hasattr(raw_message, "reasoning_content"): 911 message["reasoning_content"] = raw_message.reasoning_content 912 if hasattr(raw_message, "tool_calls"): 913 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 914 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 915 for litellm_tool_call in raw_message.tool_calls or []: 916 # Optional in the SDK for streaming responses, but should never be None at this point. 917 if litellm_tool_call.function.name is None: 918 raise ValueError( 919 "The model requested a tool call, without providing a function name (required)." 920 ) 921 open_ai_tool_calls.append( 922 ChatCompletionMessageToolCallParam( 923 id=litellm_tool_call.id, 924 type="function", 925 function={ 926 "name": litellm_tool_call.function.name, 927 "arguments": litellm_tool_call.function.arguments, 928 }, 929 ) 930 ) 931 if len(open_ai_tool_calls) > 0: 932 message["tool_calls"] = open_ai_tool_calls 933 934 if latency_ms is not None: 935 message["latency_ms"] = latency_ms 936 937 if usage is not None: 938 message["usage"] = usage 939 940 if not message.get("content") and not message.get("tool_calls"): 941 raise ValueError( 942 "Model returned an assistant message, but no content or tool calls. This is not supported." 943 ) 944 945 return message
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
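An illustrative trace entry for an assistant turn that requested one tool call. The values are invented, and usage is shown as a comment because the real wrapper stores a MessageUsage object rather than a plain dict:

assistant_trace_message = {
    "role": "assistant",
    "content": None,
    "reasoning_content": "considered which tool to call",
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "search", "arguments": '{"query": "kiln"}'},
        }
    ],
    "latency_ms": 1250,
    # "usage": MessageUsage(...) is attached when per-message usage is known
}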
947 def all_messages_to_trace( 948 self, 949 messages: list[ChatCompletionMessageIncludingLiteLLM], 950 message_latency: dict[int, int] | None = None, 951 message_usage: dict[int, MessageUsage] | None = None, 952 ) -> list[ChatCompletionMessageParam]: 953 """ 954 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 955 956 Non-LiteLLM dict messages pass through unchanged. Any per-message 957 ``usage``/``latency_ms`` already attached to those dicts (e.g. from a 958 seeded prior trace) is preserved. 959 """ 960 trace: list[ChatCompletionMessageParam] = [] 961 for i, message in enumerate(messages): 962 if isinstance(message, LiteLLMMessage): 963 latency_ms = message_latency.get(i) if message_latency else None 964 usage = message_usage.get(i) if message_usage else None 965 trace.append( 966 self.litellm_message_to_trace_message(message, latency_ms, usage) 967 ) 968 else: 969 trace.append(message) 970 return trace
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
Non-LiteLLM dict messages pass through unchanged. Any per-message usage/latency_ms already attached to those dicts (e.g. from a seeded prior trace) is preserved.
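A short sketch of that pass-through behaviour; only LiteLLM Message objects are converted, and the optional per-index dicts supply their latency and usage:

messages = [
    {"role": "system", "content": "You are helpful."},   # dict: passes through unchanged
    {"role": "user", "content": "Summarize this."},       # dict: passes through unchanged
    # Index 2 would be a litellm Message object; message_latency={2: 840} and
    # message_usage={2: some_message_usage} attach stats to its converted trace entry.
]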
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- task
- run_config
- base_adapter_config
- output_schema
- input_schema
- model_provider
- invoke
- invoke_returning_run_output
- invoke_openai_stream
- invoke_ai_sdk_stream
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode
- available_tools