kiln_ai.adapters.model_adapters.litellm_adapter
import asyncio
import copy
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM
from kiln_ai.adapters.chat.chat_formatter import ToolResponseMessage
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.datamodel_enums import InputType
from kiln_ai.datamodel.json_schema import (
    close_object_schemas,
    validate_schema_with_value_error,
)
from kiln_ai.datamodel.run_config import (
    KilnAgentRunConfigProperties,
    as_kiln_agent_run_config,
)
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
)

MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)


def _validate_unmanaged_tools(tools: list[KilnToolInterface]) -> None:
    for i, tool in enumerate(tools):
        if not isinstance(tool, KilnToolInterface):
            raise TypeError(
                f"unmanaged_tools[{i}] must be a KilnToolInterface instance, got {type(tool).__name__}"
            )


@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage
    interrupted_by_tool_calls: list[ChatCompletionMessageToolCall] | None = None


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

        unmanaged_tools = self.base_adapter_config.unmanaged_tools
        if unmanaged_tools:
            _validate_unmanaged_tools(unmanaged_tools)

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top level turn: from user message to agent message.

        It may handle iterations of tool calls between the user/agent messages if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for tool calls
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                # check if we should return control to caller
                if self.base_adapter_config.return_on_tool_call:
                    # filter out task_response tool (task_response tools are internal)
                    standard_tool_calls = [
                        tc for tc in tool_calls if tc.function.name != "task_response"
                    ]
                    has_task_response = any(
                        tc.function.name == "task_response" for tc in tool_calls
                    )
                    if standard_tool_calls and not has_task_response:
                        return ModelTurnResult(
                            # we don't have any content, we are waiting for toolcall output to come back from client
                            assistant_message="",
                            all_messages=messages,
                            model_response=model_response,
                            model_choice=response_choice,
                            usage=usage,
                            interrupted_by_tool_calls=standard_tool_calls,
                        )

                # otherwise: process tool calls internally until final output
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # If we get here with no content and no tool calls, break
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(
        self,
        input: InputType,
        prior_trace: list[ChatCompletionMessageParam] | None = None,
    ) -> tuple[RunOutput, Usage | None]:
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter
        chat_formatter = self.build_chat_formatter(input, prior_trace)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
            chat_formatter.initial_messages()
        )

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        # Same loop for both fresh runs and prior_trace continuation.
        # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues).
        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                msg_dict: dict = {"role": message.role, "content": message.content}
                if isinstance(message, ToolResponseMessage):
                    msg_dict["tool_call_id"] = message.tool_call_id
                messages.append(msg_dict)  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

            # Check if we were interrupted by tool calls
            if turn_result.interrupted_by_tool_calls:
                trace = self.all_messages_to_trace(messages)
                intermediate_outputs = chat_formatter.intermediate_outputs()
                output = RunOutput(
                    output=prior_output or "",
                    intermediate_outputs=intermediate_outputs,
                    output_logprobs=None,
                    trace=trace,
                )
                return output, usage

        if not prior_output:
            raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _create_run_stream(
        self,
        input: InputType,
        prior_trace: list[ChatCompletionMessageParam] | None = None,
    ) -> AdapterStream:
        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input, prior_trace)
        initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy(
            chat_formatter.initial_messages()
        )

        return AdapterStream(
            adapter=self,
            provider=provider,
            chat_formatter=chat_formatter,
            initial_messages=initial_messages,
            top_logprobs=self.base_adapter_config.top_logprobs,
        )

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from model choice and add to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs: Any
    ) -> Tuple[ModelResponse, Choices]:
        response = await litellm.acompletion(**kwargs)

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        run_config = as_kiln_agent_run_config(self.run_config)
        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task.output_schema()
        if output_schema is None:
            raise ValueError(
                "Invalid output schema for this task. Cannot use JSON schema response format."
            )
        output_schema = close_object_schemas(output_schema, strict=True)
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema = close_object_schemas(output_schema, strict=strict)

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden
        extra_body: dict[str, Any] = {}
        provider_options = {}

        run_config = as_kiln_agent_run_config(self.run_config)
        # For legacy config 'thinking_level' is not set, default to provider's default
        if "thinking_level" in run_config.model_fields_set:
            thinking_level = run_config.thinking_level
        else:
            thinking_level = provider.default_thinking_level

        # Set the reasoning_effort
        if thinking_level is not None:
            # Anthropic models in OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
            if (
                provider.name == ModelProviderName.openrouter
                and provider.openrouter_reasoning_object
            ):
                extra_body["reasoning"] = {"effort": thinking_level}
            else:
                extra_body["reasoning_effort"] = thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # Only set when the point of this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an api url
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    def _allowed_openai_params_for_completion_kwargs(
        self, completion_kwargs: dict[str, Any]
    ) -> list[str]:
        """
        LiteLLM drops params it thinks are not supported by the model when drop_params is True. Sometimes it is wrong
        and we know it is supported, so we whitelist them here and pass that as an allowed_openai_params parameter.
        """
        # callers could have set allowed_openai_params in the additional_body_options, so we need to check for that
        explicit_allowed_params: Any | list = completion_kwargs.get(
            "allowed_openai_params", []
        )
        if not isinstance(explicit_allowed_params, list):
            raise ValueError(
                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - expected list, got {type(explicit_allowed_params)}"
            )
        explicit_allowed_params_validated = [
            param for param in explicit_allowed_params if isinstance(param, str)
        ]
        invalid_count = len(explicit_allowed_params) - len(
            explicit_allowed_params_validated
        )
        if invalid_count > 0:
            raise ValueError(
                f"Unexpected allowed_openai_params format: {explicit_allowed_params} - {invalid_count} items are not strings"
            )

        # these are our own logic
        automatic_allowed_params: list[str] = []
        if "tools" in completion_kwargs:
            automatic_allowed_params.append("tools")
        if "tool_choice" in completion_kwargs:
            automatic_allowed_params.append("tool_choice")

        return list(set(explicit_allowed_params_validated + automatic_allowed_params))

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        run_config = as_kiln_agent_run_config(self.run_config)
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": run_config.temperature,
            "top_p": run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        if self.base_adapter_config.automatic_prompt_caching:
            # Mark the last message for cache control. Litellm's AnthropicCacheControlHook
            # handles provider-specific injection. Providers auto-cache matching prefixes,
            # so marking the last message is sufficient for multi-turn conversations.
            completion_kwargs["cache_control_injection_points"] = [
                {"location": "message", "index": -1}
            ]

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and response format using tools
            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        # any params listed in this list will be passed to the model regardless of LiteLLM's own validation
        allowed_openai_params = self._allowed_openai_params_for_completion_kwargs(
            completion_kwargs
        )
        if len(allowed_openai_params) > 0:
            completion_kwargs["allowed_openai_params"] = allowed_openai_params

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
            prompt_details = litellm_usage.get("prompt_tokens_details", None)
            if prompt_details and hasattr(prompt_details, "cached_tokens"):
                usage.cached_tokens = prompt_details.cached_tokens
            elif prompt_details:
                logger.warning(
                    f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted"
                )
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def _tools_for_execution(self) -> list[KilnToolInterface]:
        """Registry-resolved tools plus :attr:`AdapterConfig.unmanaged_tools` (same order as ``litellm_tools``)."""
        registry = await self.cached_available_tools()
        unmanaged = self.base_adapter_config.unmanaged_tools or []
        return registry + unmanaged

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        available_tools = await self.cached_available_tools()

        registry_defs = [await tool.toolcall_definition() for tool in available_tools]
        unmanaged = self.base_adapter_config.unmanaged_tools
        unmanaged_defs = (
            [await t.toolcall_definition() for t in unmanaged] if unmanaged else []
        )

        merged = registry_defs + unmanaged_defs
        seen_names: set[str] = set()
        for d in merged:
            name = d["function"]["name"]
            if name in seen_names:
                raise ValueError(
                    f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names."
                )
            seen_names.add(name)

        return merged

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self._tools_for_execution():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                    is_error=result.is_error if result.is_error else None,
                    error_message=result.error_message
                    if result.error_message
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
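
A minimal construction sketch, not part of the module source above. The only LiteLlmConfig and adapter attributes taken from the source are the ones this adapter reads (run_config_properties, additional_body_options, base_url, default_headers); the keyword names passed to LiteLlmConfig below are assumptions, so check the real LiteLlmConfig definition before relying on them.

# Hypothetical usage sketch. LiteLlmConfig keyword names are assumed to match the
# attributes read by LiteLlmAdapter.__init__; adjust to the real definition.
import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties


def build_adapter(
    task: datamodel.Task,
    run_config_properties: KilnAgentRunConfigProperties,
) -> LiteLlmAdapter:
    config = LiteLlmConfig(
        # Must be KilnAgentRunConfigProperties, or __init__ raises ValueError.
        run_config_properties=run_config_properties,
        # base_url is required later for custom/OpenAI-compatible providers
        # (litellm_model_id raises if a custom provider has no explicit base URL).
        base_url=None,
        default_headers=None,
        additional_body_options={},
    )
    # base_adapter_config is optional; defaults are used when omitted.
    return LiteLlmAdapter(config=config, kiln_task=task)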
905 ) 906 907 return message 908 909 def all_messages_to_trace( 910 self, messages: list[ChatCompletionMessageIncludingLiteLLM] 911 ) -> list[ChatCompletionMessageParam]: 912 """ 913 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 914 """ 915 trace: list[ChatCompletionMessageParam] = [] 916 for message in messages: 917 if isinstance(message, LiteLLMMessage): 918 trace.append(self.litellm_message_to_trace_message(message)) 919 else: 920 trace.append(message) 921 return trace
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
86 def __init__( 87 self, 88 config: LiteLlmConfig, 89 kiln_task: datamodel.Task, 90 base_adapter_config: AdapterConfig | None = None, 91 ): 92 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 93 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 94 self.config = config 95 self._additional_body_options = config.additional_body_options 96 self._api_base = config.base_url 97 self._headers = config.default_headers 98 self._litellm_model_id: str | None = None 99 self._cached_available_tools: list[KilnToolInterface] | None = None 100 101 super().__init__( 102 task=kiln_task, 103 run_config=config.run_config_properties, 104 config=base_adapter_config, 105 ) 106 107 unmanaged_tools = self.base_adapter_config.unmanaged_tools 108 if unmanaged_tools: 109 _validate_unmanaged_tools(unmanaged_tools)
376 async def acompletion_checking_response( 377 self, **kwargs: Any 378 ) -> Tuple[ModelResponse, Choices]: 379 response = await litellm.acompletion(**kwargs) 380 381 if ( 382 not isinstance(response, ModelResponse) 383 or not response.choices 384 or len(response.choices) == 0 385 or not isinstance(response.choices[0], Choices) 386 ): 387 raise RuntimeError( 388 f"Expected ModelResponse with Choices, got {type(response)}." 389 ) 390 return response, response.choices[0]
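litellm.acompletion is loosely typed: depending on call options it can return a streaming wrapper or a response without choices, so this helper narrows the result before the adapter reads choices[0]. Below is a minimal sketch of a plain litellm call with the same guard, outside the adapter; the model id is an assumption and any litellm-supported id would do:

import asyncio

import litellm
from litellm.types.utils import Choices, ModelResponse


async def demo() -> None:
    response = await litellm.acompletion(
        model="openrouter/openai/gpt-4o-mini",  # illustrative model id
        messages=[{"role": "user", "content": "Say hi"}],
    )
    # Same narrowing the adapter performs: fail loudly here rather than crash later.
    if (
        not isinstance(response, ModelResponse)
        or not response.choices
        or not isinstance(response.choices[0], Choices)
    ):
        raise RuntimeError(f"Expected ModelResponse with Choices, got {type(response)}.")
    print(response.choices[0].message.content)


# asyncio.run(demo())  # requires the relevant provider API key in the environment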
395 async def response_format_options(self) -> dict[str, Any]: 396 # Unstructured if task isn't structured 397 if not self.has_structured_output(): 398 return {} 399 400 run_config = as_kiln_agent_run_config(self.run_config) 401 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 402 403 match structured_output_mode: 404 case StructuredOutputMode.json_mode: 405 return {"response_format": {"type": "json_object"}} 406 case StructuredOutputMode.json_schema: 407 return self.json_schema_response_format() 408 case StructuredOutputMode.function_calling_weak: 409 return self.tool_call_params(strict=False) 410 case StructuredOutputMode.function_calling: 411 return self.tool_call_params(strict=True) 412 case StructuredOutputMode.json_instructions: 413 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 414 return {} 415 case StructuredOutputMode.json_custom_instructions: 416 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 417 return {} 418 case StructuredOutputMode.json_instruction_and_object: 419 # We set response_format to json_object and also set json instructions in the prompt 420 return {"response_format": {"type": "json_object"}} 421 case StructuredOutputMode.default: 422 provider_name = run_config.model_provider_name 423 if provider_name == ModelProviderName.ollama: 424 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 425 return self.json_schema_response_format() 426 elif provider_name == ModelProviderName.docker_model_runner: 427 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 428 return self.json_schema_response_format() 429 else: 430 # Default to function calling -- it's older than the other modes. Higher compatibility. 431 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 432 strict = provider_name == ModelProviderName.openai 433 return self.tool_call_params(strict=strict) 434 case StructuredOutputMode.unknown: 435 # See above, but this case should never happen. 436 raise ValueError("Structured output mode is unknown.") 437 case _: 438 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type]
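Since the return value is merged directly into the litellm kwargs, each structured output mode translates to a concrete request shape. A hedged summary of the branches above, written as plain data for quick reference (the outcome descriptions are paraphrases, not literal return values):

mode_to_request_shape = {
    "json_mode": "response_format = {'type': 'json_object'}",
    "json_instruction_and_object": "prompt instructions plus response_format json_object",
    "json_schema": "response_format json_schema built from the task's output schema",
    "function_calling": "forced 'task_response' tool call with strict=True",
    "function_calling_weak": "forced 'task_response' tool call with strict=False",
    "json_instructions": "nothing -- the JSON requirement lives in the prompt",
    "json_custom_instructions": "nothing -- instructions live in the system prompt",
    "default": "json_schema for Ollama / Docker Model Runner, otherwise a task_response tool call",
}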
440 def json_schema_response_format(self) -> dict[str, Any]: 441 output_schema = self.task.output_schema() 442 if output_schema is None: 443 raise ValueError( 444 "Invalid output schema for this task. Cannot use JSON schema response format." 445 ) 446 output_schema = close_object_schemas(output_schema, strict=True) 447 return { 448 "response_format": { 449 "type": "json_schema", 450 "json_schema": { 451 "name": "task_response", 452 "schema": output_schema, 453 }, 454 } 455 }
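For a concrete task, the method produces an OpenAI-style json_schema response format. A sketch of the expected shape for a hypothetical output schema; exactly how close_object_schemas tightens the schema (e.g. additionalProperties handling) is assumed rather than shown:

task_output_schema = {
    "type": "object",
    "properties": {
        "answer": {"type": "string"},
        "confidence": {"type": "number"},
    },
    "required": ["answer"],
}

expected_kwargs = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "task_response",
            "schema": task_output_schema,  # after close_object_schemas(..., strict=True)
        },
    }
}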
457 def tool_call_params(self, strict: bool) -> dict[str, Any]:
458     # Add additionalProperties: false to the schema (OpenAI requires this for some models)
459     output_schema = self.task.output_schema()
460     if not isinstance(output_schema, dict):
461         raise ValueError(
462             "Invalid output schema for this task. Cannot use tool calls."
463         )
464     output_schema = close_object_schemas(output_schema, strict=strict)
465
466     function_params = {
467         "name": "task_response",
468         "parameters": output_schema,
469     }
470     # Strict should normally be on, but function_calling_weak exists for APIs that don't support it.
471     if strict:
472         function_params["strict"] = True
473
474     return {
475         "tools": [
476             {
477                 "type": "function",
478                 "function": function_params,
479             }
480         ],
481         "tool_choice": {
482             "type": "function",
483             "function": {"name": "task_response"},
484         },
485     }
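The same hypothetical schema routed through tool calling yields a forced tool call instead of a response format. An illustrative payload; the strict flag is only present for function_calling, not function_calling_weak:

tool_calling_kwargs = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                },
                "strict": True,  # omitted when strict=False
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}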
487 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
488     # Don't love having this logic here. But it's worth the usability improvement
489     # so better to keep it than exclude it. Should figure out how I want to isolate
490     # this sort of logic so it's config driven and can be overridden
491     extra_body: dict[str, Any] = {}
492     provider_options = {}
493
494     run_config = as_kiln_agent_run_config(self.run_config)
495     # For legacy configs where 'thinking_level' is not set, default to the provider's default
496     if "thinking_level" in run_config.model_fields_set:
497         thinking_level = run_config.thinking_level
498     else:
499         thinking_level = provider.default_thinking_level
500
501     # Set the reasoning_effort
502     if thinking_level is not None:
503         # Anthropic models on OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
504         if (
505             provider.name == ModelProviderName.openrouter
506             and provider.openrouter_reasoning_object
507         ):
508             extra_body["reasoning"] = {"effort": thinking_level}
509         else:
510             extra_body["reasoning_effort"] = thinking_level
511
512     if provider.require_openrouter_reasoning:
513         # https://openrouter.ai/docs/use-cases/reasoning-tokens
514         extra_body["reasoning"] = {
515             "exclude": False,
516         }
517
518     if provider.gemini_reasoning_enabled:
519         extra_body["reasoning"] = {
520             "enabled": True,
521         }
522
523     if provider.name == ModelProviderName.openrouter:
524         # Ask OpenRouter to include usage in the response (cost)
525         extra_body["usage"] = {"include": True}
526
527         # Set a default provider order for more deterministic routing.
528         # OpenRouter will ignore providers that don't support the model.
529         # Special cases below (like R1) can override this order.
530         # allow_fallbacks is true by default, but we can override it here.
531         provider_options["order"] = [
532             "fireworks",
533             "parasail",
534             "together",
535             "deepinfra",
536             "novita",
537             "groq",
538             "amazon-bedrock",
539             "azure",
540             "nebius",
541         ]
542
543     if provider.anthropic_extended_thinking:
544         extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
545
546     if provider.r1_openrouter_options:
547         # Require providers that support the reasoning parameter
548         provider_options["require_parameters"] = True
549         # Prefer R1 providers with reasonable perf/quants
550         provider_options["order"] = ["fireworks", "together"]
551         # R1 providers with unreasonable quants
552         provider_options["ignore"] = ["deepinfra"]
553
554     # Only set when the purpose of this request is to get logprobs.
555     if (
556         provider.logprobs_openrouter_options
557         and self.base_adapter_config.top_logprobs is not None
558     ):
559         # Don't let OpenRouter choose a provider that doesn't support logprobs.
560         provider_options["require_parameters"] = True
561         # DeepInfra silently fails to return logprobs consistently.
562         provider_options["ignore"] = ["deepinfra"]
563
564     if provider.openrouter_skip_required_parameters:
565         # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
566         provider_options["require_parameters"] = False
567
568     # SiliconFlow uses a bool flag for thinking, for some models
569     if provider.siliconflow_enable_thinking is not None:
570         extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
571
572     if len(provider_options) > 0:
573         extra_body["provider"] = provider_options
574
575     return extra_body
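To make the branching concrete: for an OpenRouter provider with a thinking level set and none of the special-case flags enabled, the method would return roughly the dict below. The thinking level value is illustrative; the provider order list is the one hard-coded above:

openrouter_extra_body = {
    "reasoning_effort": "medium",  # or {"reasoning": {"effort": "medium"}} when
                                   # provider.openrouter_reasoning_object is set
    "usage": {"include": True},    # ask OpenRouter to report cost
    "provider": {
        "order": [
            "fireworks", "parasail", "together", "deepinfra", "novita",
            "groq", "amazon-bedrock", "azure", "nebius",
        ],
    },
}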
577 def litellm_model_id(self) -> str:
578     # The model ID is a combination of request format and URL endpoint.
579     # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
580     if self._litellm_model_id:
581         return self._litellm_model_id
582
583     litellm_provider_info = get_litellm_provider_info(self.model_provider())
584     if litellm_provider_info.is_custom and self._api_base is None:
585         raise ValueError(
586             "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
587         )
588
589     self._litellm_model_id = litellm_provider_info.litellm_model_id
590     return self._litellm_model_id
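The cached id follows LiteLLM's "provider/model" convention; for custom or OpenAI-compatible providers the id alone is not enough, which is why an explicit api_base is required. Illustrative values only; real ids come from get_litellm_provider_info and the Kiln model registry:

example_litellm_model_ids = [
    "openrouter/anthropic/claude-sonnet-4.5",
    "groq/llama-3.3-70b-versatile",
    "openai/gpt-4o-mini",
]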
627 async def build_completion_kwargs( 628 self, 629 provider: KilnModelProvider, 630 messages: list[ChatCompletionMessageIncludingLiteLLM], 631 top_logprobs: int | None, 632 skip_response_format: bool = False, 633 ) -> dict[str, Any]: 634 run_config = as_kiln_agent_run_config(self.run_config) 635 extra_body = self.build_extra_body(provider) 636 637 # Merge all parameters into a single kwargs dict for litellm 638 completion_kwargs = { 639 "model": self.litellm_model_id(), 640 "messages": messages, 641 "api_base": self._api_base, 642 "headers": self._headers, 643 "temperature": run_config.temperature, 644 "top_p": run_config.top_p, 645 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 646 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 647 # Better to ignore them than to fail the model call. 648 # https://docs.litellm.ai/docs/completion/input 649 "drop_params": True, 650 **extra_body, 651 **self._additional_body_options, 652 } 653 654 if self.base_adapter_config.automatic_prompt_caching: 655 # Mark the last message for cache control. Litellm's AnthropicCacheControlHook 656 # handles provider-specific injection. Providers auto-cache matching prefixes, 657 # so marking the last message is sufficient for multi-turn conversations. 658 completion_kwargs["cache_control_injection_points"] = [ 659 {"location": "message", "index": -1} 660 ] 661 662 tool_calls = await self.litellm_tools() 663 has_tools = len(tool_calls) > 0 664 if has_tools: 665 completion_kwargs["tools"] = tool_calls 666 completion_kwargs["tool_choice"] = "auto" 667 668 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 669 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 670 if provider.temp_top_p_exclusive: 671 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 672 del completion_kwargs["top_p"] 673 if ( 674 "temperature" in completion_kwargs 675 and completion_kwargs["temperature"] == 1.0 676 ): 677 del completion_kwargs["temperature"] 678 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 679 raise ValueError( 680 "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 681 ) 682 683 if not skip_response_format: 684 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 685 response_format_options = await self.response_format_options() 686 687 # Check for a conflict between tools and response format using tools 688 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 689 if has_tools and "tools" in response_format_options: 690 raise ValueError( 691 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 
692 ) 693 694 completion_kwargs.update(response_format_options) 695 696 if top_logprobs is not None: 697 completion_kwargs["logprobs"] = True 698 completion_kwargs["top_logprobs"] = top_logprobs 699 700 # any params listed in this list will be passed to the model regardless of LiteLLM's own validation 701 allowed_openai_params = self._allowed_openai_params_for_completion_kwargs( 702 completion_kwargs 703 ) 704 if len(allowed_openai_params) > 0: 705 completion_kwargs["allowed_openai_params"] = allowed_openai_params 706 707 return completion_kwargs
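The top_p/temperature exclusivity rule is the easiest part of this method to get wrong, so here is a standalone restatement of that pruning step, useful for reasoning about edge cases. The function name is illustrative and not part of the adapter API:

def prune_exclusive_sampling_params(kwargs: dict) -> dict:
    # Drop default (1.0) values first, keep anything the user customized.
    pruned = dict(kwargs)
    if pruned.get("top_p") == 1.0:
        pruned.pop("top_p", None)
    if pruned.get("temperature") == 1.0:
        pruned.pop("temperature", None)
    # If both survive, both were customized, which the provider rejects.
    if "top_p" in pruned and "temperature" in pruned:
        raise ValueError(
            "Set only one of top_p / temperature to a non-default value for this model."
        )
    return pruned

assert prune_exclusive_sampling_params({"temperature": 0.2, "top_p": 1.0}) == {"temperature": 0.2}
assert prune_exclusive_sampling_params({"temperature": 1.0, "top_p": 1.0}) == {}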
709 def usage_from_response(self, response: ModelResponse) -> Usage: 710 litellm_usage = response.get("usage", None) 711 712 # LiteLLM isn't consistent in how it returns the cost. 713 cost = response._hidden_params.get("response_cost", None) 714 if cost is None and litellm_usage: 715 cost = litellm_usage.get("cost", None) 716 717 usage = Usage() 718 719 if not litellm_usage and not cost: 720 return usage 721 722 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 723 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 724 usage.output_tokens = litellm_usage.get("completion_tokens", None) 725 usage.total_tokens = litellm_usage.get("total_tokens", None) 726 prompt_details = litellm_usage.get("prompt_tokens_details", None) 727 if prompt_details and hasattr(prompt_details, "cached_tokens"): 728 usage.cached_tokens = prompt_details.cached_tokens 729 elif prompt_details: 730 logger.warning( 731 f"prompt_tokens_details has unexpected type {type(prompt_details)}, cached_tokens not extracted" 732 ) 733 else: 734 logger.warning( 735 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 736 ) 737 738 if isinstance(cost, float): 739 usage.cost = cost 740 elif cost is not None: 741 # None is allowed, but no other types are expected 742 logger.warning( 743 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 744 ) 745 746 return usage
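A hedged illustration of the extraction: given a litellm usage block shaped like OpenAI's usage object (and a response_cost reported via litellm's hidden params), the Kiln Usage would be populated roughly as shown. The numbers are made up:

raw_litellm_usage = {
    "prompt_tokens": 1200,
    "completion_tokens": 250,
    "total_tokens": 1450,
    "prompt_tokens_details": {"cached_tokens": 1024},
}

expected_kiln_usage = {
    "input_tokens": 1200,
    "output_tokens": 250,
    "total_tokens": 1450,
    "cached_tokens": 1024,
    "cost": 0.0042,  # only populated when litellm reports the cost as a float
}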
759 async def litellm_tools(self) -> list[ToolCallDefinition]: 760 available_tools = await self.cached_available_tools() 761 762 registry_defs = [await tool.toolcall_definition() for tool in available_tools] 763 unmanaged = self.base_adapter_config.unmanaged_tools 764 unmanaged_defs = ( 765 [await t.toolcall_definition() for t in unmanaged] if unmanaged else [] 766 ) 767 768 merged = registry_defs + unmanaged_defs 769 seen_names: set[str] = set() 770 for d in merged: 771 name = d["function"]["name"] 772 if name in seen_names: 773 raise ValueError( 774 f"Duplicate tool name {name!r}: unmanaged and registry tools must have unique names." 775 ) 776 seen_names.add(name) 777 778 return merged
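The duplicate-name guard is easy to exercise in isolation. A sketch with hand-written tool definitions (the shapes mirror ToolCallDefinition; the names and empty schemas are made up for the example):

defs = [
    {"type": "function", "function": {"name": "search_docs", "parameters": {"type": "object", "properties": {}}}},
    {"type": "function", "function": {"name": "search_docs", "parameters": {"type": "object", "properties": {}}}},
]

seen: set[str] = set()
try:
    for d in defs:
        name = d["function"]["name"]
        if name in seen:
            raise ValueError(f"Duplicate tool name {name!r}")
        seen.add(name)
except ValueError as e:
    print(e)  # Duplicate tool name 'search_docs'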
780 async def process_tool_calls( 781 self, tool_calls: list[ChatCompletionMessageToolCall] | None 782 ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]: 783 if tool_calls is None: 784 return None, [] 785 786 assistant_output_from_toolcall: str | None = None 787 tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = [] 788 tool_run_coroutines = [] 789 790 for tool_call in tool_calls: 791 # Kiln "task_response" tool is used for returning structured output via tool calls. 792 # Load the output from the tool call. Also 793 if tool_call.function.name == "task_response": 794 assistant_output_from_toolcall = tool_call.function.arguments 795 continue 796 797 # Process normal tool calls (not the "task_response" tool) 798 tool_name = tool_call.function.name 799 tool = None 800 for tool_option in await self._tools_for_execution(): 801 if await tool_option.name() == tool_name: 802 tool = tool_option 803 break 804 if not tool: 805 raise RuntimeError( 806 f"A tool named '{tool_name}' was invoked by a model, but was not available." 807 ) 808 809 # Parse the arguments and validate them against the tool's schema 810 try: 811 parsed_args = json.loads(tool_call.function.arguments) 812 except json.JSONDecodeError: 813 raise RuntimeError( 814 f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}" 815 ) 816 try: 817 tool_call_definition = await tool.toolcall_definition() 818 json_schema = json.dumps(tool_call_definition["function"]["parameters"]) 819 validate_schema_with_value_error(parsed_args, json_schema) 820 except Exception as e: 821 raise RuntimeError( 822 f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}" 823 ) from e 824 825 # Create context with the calling task's allow_saving setting 826 context = ToolCallContext( 827 allow_saving=self.base_adapter_config.allow_saving 828 ) 829 830 async def run_tool_and_format( 831 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 832 ): 833 result = await t.run(c, **args) 834 return ChatCompletionToolMessageParamWrapper( 835 role="tool", 836 tool_call_id=tc_id, 837 content=result.output, 838 kiln_task_tool_data=result.kiln_task_tool_data 839 if isinstance(result, KilnTaskToolResult) 840 else None, 841 is_error=result.is_error if result.is_error else None, 842 error_message=result.error_message 843 if result.error_message 844 else None, 845 ) 846 847 tool_run_coroutines.append(run_tool_and_format()) 848 849 if tool_run_coroutines: 850 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 851 852 if ( 853 assistant_output_from_toolcall is not None 854 and len(tool_call_response_messages) > 0 855 ): 856 raise RuntimeError( 857 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 858 ) 859 860 return assistant_output_from_toolcall, tool_call_response_messages
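The "task_response" special case is worth spelling out: its arguments string is the task's structured output, it is never executed as a tool, and it cannot be mixed with real tool calls in the same turn. A small illustrative payload (the id and arguments are made up):

import json

task_response_call = {
    "id": "call_abc123",
    "function": {
        "name": "task_response",
        "arguments": json.dumps({"answer": "42", "confidence": 0.9}),
    },
}

# process_tool_calls would return these arguments as the assistant output and
# produce no tool-response messages for this call.
assistant_output = task_response_call["function"]["arguments"]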
862 def litellm_message_to_trace_message( 863 self, raw_message: LiteLLMMessage 864 ) -> ChatCompletionAssistantMessageParamWrapper: 865 """ 866 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 867 """ 868 message: ChatCompletionAssistantMessageParamWrapper = { 869 "role": "assistant", 870 } 871 if raw_message.role != "assistant": 872 raise ValueError( 873 "Model returned a message with a role other than assistant. This is not supported." 874 ) 875 876 if hasattr(raw_message, "content"): 877 message["content"] = raw_message.content 878 if hasattr(raw_message, "reasoning_content"): 879 message["reasoning_content"] = raw_message.reasoning_content 880 if hasattr(raw_message, "tool_calls"): 881 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 882 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 883 for litellm_tool_call in raw_message.tool_calls or []: 884 # Optional in the SDK for streaming responses, but should never be None at this point. 885 if litellm_tool_call.function.name is None: 886 raise ValueError( 887 "The model requested a tool call, without providing a function name (required)." 888 ) 889 open_ai_tool_calls.append( 890 ChatCompletionMessageToolCallParam( 891 id=litellm_tool_call.id, 892 type="function", 893 function={ 894 "name": litellm_tool_call.function.name, 895 "arguments": litellm_tool_call.function.arguments, 896 }, 897 ) 898 ) 899 if len(open_ai_tool_calls) > 0: 900 message["tool_calls"] = open_ai_tool_calls 901 902 if not message.get("content") and not message.get("tool_calls"): 903 raise ValueError( 904 "Model returned an assistant message, but no content or tool calls. This is not supported." 905 ) 906 907 return message
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
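A hedged before/after sketch, assuming LiteLLM's Message can be constructed directly with role and content (adapter construction is omitted):

from litellm.types.utils import Message as LiteLLMMessage

raw = LiteLLMMessage(role="assistant", content="The answer is 42.")

# litellm_message_to_trace_message(raw) would produce an OpenAI-style dict roughly like:
trace_message = {
    "role": "assistant",
    "content": "The answer is 42.",
    # "reasoning_content" and "tool_calls" are only added when the model returned them
}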
909 def all_messages_to_trace( 910 self, messages: list[ChatCompletionMessageIncludingLiteLLM] 911 ) -> list[ChatCompletionMessageParam]: 912 """ 913 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 914 """ 915 trace: list[ChatCompletionMessageParam] = [] 916 for message in messages: 917 if isinstance(message, LiteLLMMessage): 918 trace.append(self.litellm_message_to_trace_message(message)) 919 else: 920 trace.append(message) 921 return trace
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
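A short usage sketch, assuming an already-constructed adapter instance named adapter: dict-style messages pass through untouched and only LiteLLM Message objects are rewritten.

from litellm.types.utils import Message as LiteLLMMessage

mixed = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    LiteLLMMessage(role="assistant", content="Hi there!"),
]

trace = adapter.all_messages_to_trace(mixed)  # `adapter` is assumed to exist
# trace[0] and trace[1] are the same dicts; trace[2] is now an OpenAI-style assistant dict.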
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
  - task
  - run_config
  - base_adapter_config
  - output_schema
  - input_schema
  - model_provider
  - invoke
  - invoke_returning_run_output
  - invoke_openai_stream
  - invoke_ai_sdk_stream
  - has_structured_output
  - build_prompt
  - build_chat_formatter
  - generate_run
  - update_run_config_unknown_structured_output_mode
  - available_tools