kiln_ai.adapters.model_adapters.litellm_adapter
import asyncio
import copy
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, TypeAlias, Union

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.datamodel_enums import InputType
from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
from kiln_ai.datamodel.run_config import (
    KilnAgentRunConfigProperties,
    as_kiln_agent_run_config,
)
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
)

# Safety limits to stop runaway agent loops from burning tokens.
MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)

# Message history entries may be OpenAI-style dicts or LiteLLM Message objects.
ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    ChatCompletionMessageParam, LiteLLMMessage
]


@dataclass
class ModelTurnResult:
    """Result of a single top-level model turn (user message to assistant message)."""

    # Final assistant message text (or task_response tool output) for the turn.
    assistant_message: str
    # Full message history, including any tool calls/results made during the turn.
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    # Raw LiteLLM response from the final completion call of the turn, if any.
    model_response: ModelResponse | None
    # The selected choice from that response, if any.
    model_choice: Choices | None
    # Usage accumulated across every completion call made during the turn.
    usage: Usage


class LiteLlmAdapter(BaseAdapter):
    """Adapter that runs Kiln tasks through LiteLLM's OpenAI-compatible API surface."""

    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        """Initialize the adapter.

        Raises:
            ValueError: if config.run_config_properties is not KilnAgentRunConfigProperties.
        """
        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        # Lazily-computed caches; see litellm_model_id() and cached_available_tools().
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top level turn: from user message to agent message.

        It may handle iterations of tool calls between the user/agent message if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for tool calls
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # If we get here with no content and no tool calls, break
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
        """Run the task: iterate chat-formatter turns until done, returning output and usage."""
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                messages.append({"role": message.role, "content": message.content})  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

            if not prior_output:
                raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from model choice and add to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        """Call litellm.acompletion and validate the response shape.

        Raises:
            RuntimeError: if the response is not a ModelResponse with at least one Choices.
        """
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        """Stable adapter identifier used in saved runs."""
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        """Build litellm kwargs for the configured structured-output mode (empty dict if unstructured)."""
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        run_config = as_kiln_agent_run_config(self.run_config)
        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]

    def json_schema_response_format(self) -> dict[str, Any]:
        """Build a json_schema response_format payload from the task's output schema."""
        output_schema = self.task.output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        """Build forced function-calling params ("task_response" tool) for structured output.

        Raises:
            ValueError: if the task's output schema is not a dict.
        """
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        """Build provider-specific extra request body (reasoning/thinking flags, OpenRouter routing)."""
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        run_config = as_kiln_agent_run_config(self.run_config)
        # For legacy config 'thinking_level' is not set, default to provider's default
        if "thinking_level" in run_config.model_fields_set:
            thinking_level = run_config.thinking_level
        else:
            thinking_level = provider.default_thinking_level

        # Set the reasoning_effort
        if thinking_level is not None:
            # Anthropic models in OpenRouter uses reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
            if (
                provider.name == ModelProviderName.openrouter
                and provider.openrouter_reasoning_object
            ):
                extra_body["reasoning"] = {"effort": thinking_level}
            else:
                extra_body["reasoning_effort"] = thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # Only set when the goal of this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        """Resolve (and cache) the LiteLLM model ID for the configured provider.

        Raises:
            ValueError: if a custom provider is used without an explicit base URL.
        """
        # The model ID is an interesting combination of format and url endpoint.
        # It specifics the provider URL/host, but this is overridden if you manually set an api url
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        """Assemble the full kwargs dict passed to litellm.acompletion for one call."""
        run_config = as_kiln_agent_run_config(self.run_config)
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": run_config.temperature,
            "top_p": run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and response format using tools
            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage:
        """Extract token counts and cost from a LiteLLM response into a Kiln Usage object."""
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        """Return available tools, computing and caching them on first call."""
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        """Return tool definitions for litellm in the OpenAI-compatible format."""
        available_tools = await self.cached_available_tools()

        # LiteLLM takes the standard OpenAI-compatible tool call format
        return [await tool.toolcall_definition() for tool in available_tools]

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        """Execute the model's tool calls.

        Returns a tuple of:
        - the "task_response" output string, if that pseudo-tool was called (else None)
        - "tool" role messages with each real tool's result, to send back to the model

        Raises:
            RuntimeError: on unknown tools, invalid arguments, or a task_response
                combined with other tool calls in the same turn.
        """
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            # Default args bind the loop variables per-iteration (avoids late-binding closures).
            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
@dataclass
class ModelTurnResult:
    """Result of a single top-level model turn (user message to assistant message)."""

    # Final assistant message text (or task_response tool output) for the turn.
    assistant_message: str
    # Full message history, including any tool calls/results made during the turn.
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    # Raw LiteLLM response from the final completion call of the turn, if any.
    model_response: ModelResponse | None
    # The selected choice from that response, if any.
    model_choice: Choices | None
    # Usage accumulated across every completion call made during the turn.
    usage: Usage
74class LiteLlmAdapter(BaseAdapter): 75 def __init__( 76 self, 77 config: LiteLlmConfig, 78 kiln_task: datamodel.Task, 79 base_adapter_config: AdapterConfig | None = None, 80 ): 81 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 82 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 83 self.config = config 84 self._additional_body_options = config.additional_body_options 85 self._api_base = config.base_url 86 self._headers = config.default_headers 87 self._litellm_model_id: str | None = None 88 self._cached_available_tools: list[KilnToolInterface] | None = None 89 90 super().__init__( 91 task=kiln_task, 92 run_config=config.run_config_properties, 93 config=base_adapter_config, 94 ) 95 96 async def _run_model_turn( 97 self, 98 provider: KilnModelProvider, 99 prior_messages: list[ChatCompletionMessageIncludingLiteLLM], 100 top_logprobs: int | None, 101 skip_response_format: bool, 102 ) -> ModelTurnResult: 103 """ 104 Call the model for a single top level turn: from user message to agent message. 105 106 It may make handle iterations of tool calls between the user/agent message if needed. 107 """ 108 109 usage = Usage() 110 messages = list(prior_messages) 111 tool_calls_count = 0 112 113 while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: 114 # Build completion kwargs for tool calls 115 completion_kwargs = await self.build_completion_kwargs( 116 provider, 117 # Pass a copy, as acompletion mutates objects and breaks types. 
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # If we get here with no content and no tool calls, break
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
        """Execute the task: drive the chat formatter turn-by-turn and build the RunOutput.

        Loops calling _run_model_turn for each formatter turn, accumulating
        usage across turns. Raises RuntimeError if turns exceed
        MAX_CALLS_PER_TURN or the model never produces output.
        """
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            # The formatter consumes the prior assistant output to decide the next turn.
            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                messages.append({"role": message.role, "content": message.content})  # type: ignore

            # Only the final call needs the structured response format (and logprobs, if requested).
            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

        if not prior_output:
            raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from model choice and add to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                # Ignore whitespace-only reasoning payloads.
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        """Call litellm.acompletion and validate the response shape, returning (response, first choice)."""
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        """Stable adapter identifier recorded with runs."""
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        """Build litellm kwargs that request the configured structured output mode."""
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        run_config = as_kiln_agent_run_config(self.run_config)
        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
334 return {} 335 case StructuredOutputMode.json_instruction_and_object: 336 # We set response_format to json_object and also set json instructions in the prompt 337 return {"response_format": {"type": "json_object"}} 338 case StructuredOutputMode.default: 339 provider_name = run_config.model_provider_name 340 if provider_name == ModelProviderName.ollama: 341 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 342 return self.json_schema_response_format() 343 elif provider_name == ModelProviderName.docker_model_runner: 344 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 345 return self.json_schema_response_format() 346 else: 347 # Default to function calling -- it's older than the other modes. Higher compatibility. 348 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 349 strict = provider_name == ModelProviderName.openai 350 return self.tool_call_params(strict=strict) 351 case StructuredOutputMode.unknown: 352 # See above, but this case should never happen. 353 raise ValueError("Structured output mode is unknown.") 354 case _: 355 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type] 356 357 def json_schema_response_format(self) -> dict[str, Any]: 358 output_schema = self.task.output_schema() 359 return { 360 "response_format": { 361 "type": "json_schema", 362 "json_schema": { 363 "name": "task_response", 364 "schema": output_schema, 365 }, 366 } 367 } 368 369 def tool_call_params(self, strict: bool) -> dict[str, Any]: 370 # Add additional_properties: false to the schema (OpenAI requires this for some models) 371 output_schema = self.task.output_schema() 372 if not isinstance(output_schema, dict): 373 raise ValueError( 374 "Invalid output schema for this task. Can not use tool calls." 
            )
        # NOTE(review): this mutates the dict returned by output_schema() in
        # place — assumes output_schema() returns a fresh dict per call; confirm.
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        """Build provider-specific extra request body params (reasoning flags, OpenRouter routing).

        Statement order matters: later assignments intentionally overwrite
        earlier 'reasoning' values for certain provider configurations.
        """
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        run_config = as_kiln_agent_run_config(self.run_config)
        # For legacy config 'thinking_level' is not set, default to provider's default
        if "thinking_level" in run_config.model_fields_set:
            thinking_level = run_config.thinking_level
        else:
            thinking_level = provider.default_thinking_level

        # Set the reasoning_effort
        if thinking_level is not None:
            # Anthropic models in OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
            if (
                provider.name == ModelProviderName.openrouter
                and provider.openrouter_reasoning_object
            ):
                extra_body["reasoning"] = {"effort": thinking_level}
            else:
                extra_body["reasoning_effort"] = thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # The only point of a logprobs request is the logprobs themselves.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        """Return (and cache) the litellm model ID for this adapter's model provider.

        The model ID is an interesting combination of format and url endpoint.
        It specifies the provider URL/host, but this is overridden if you
        manually set an api url.
        """
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        """Assemble the kwargs dict for litellm.acompletion: model, messages, sampling params, tools, and response format."""
        run_config = as_kiln_agent_run_config(self.run_config)
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": run_config.temperature,
            "top_p": run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and response format using tools
            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage:
        """Extract token counts and cost from a litellm ModelResponse into a Kiln Usage."""
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        """Return available tools, computed once and cached for this adapter instance."""
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        """Return tool definitions in the OpenAI-compatible format litellm expects."""
        available_tools = await self.cached_available_tools()

        # LiteLLM takes the standard OpenAI-compatible tool call format
        return [await tool.toolcall_definition() for tool in available_tools]

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        """Run requested tool calls concurrently.

        Returns (task_response arguments if the task_response pseudo-tool was
        called, tool result messages to append to the chat). Raises
        RuntimeError for unknown tools, malformed arguments, or an impossible
        task_response + other-tools combination.
        """
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            # Default args bind this iteration's values, so each coroutine
            # captures its own tool/context/args (avoids late-binding closures).
            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        # Run all tool calls concurrently.
        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
736 ) 737 738 return message 739 740 def all_messages_to_trace( 741 self, messages: list[ChatCompletionMessageIncludingLiteLLM] 742 ) -> list[ChatCompletionMessageParam]: 743 """ 744 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 745 """ 746 trace: list[ChatCompletionMessageParam] = [] 747 for message in messages: 748 if isinstance(message, LiteLLMMessage): 749 trace.append(self.litellm_message_to_trace_message(message)) 750 else: 751 trace.append(message) 752 return trace
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
75 def __init__( 76 self, 77 config: LiteLlmConfig, 78 kiln_task: datamodel.Task, 79 base_adapter_config: AdapterConfig | None = None, 80 ): 81 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 82 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 83 self.config = config 84 self._additional_body_options = config.additional_body_options 85 self._api_base = config.base_url 86 self._headers = config.default_headers 87 self._litellm_model_id: str | None = None 88 self._cached_available_tools: list[KilnToolInterface] | None = None 89 90 super().__init__( 91 task=kiln_task, 92 run_config=config.run_config_properties, 93 config=base_adapter_config, 94 )
294 async def acompletion_checking_response( 295 self, **kwargs 296 ) -> Tuple[ModelResponse, Choices]: 297 response = await litellm.acompletion(**kwargs) 298 if ( 299 not isinstance(response, ModelResponse) 300 or not response.choices 301 or len(response.choices) == 0 302 or not isinstance(response.choices[0], Choices) 303 ): 304 raise RuntimeError( 305 f"Expected ModelResponse with Choices, got {type(response)}." 306 ) 307 return response, response.choices[0]
312 async def response_format_options(self) -> dict[str, Any]: 313 # Unstructured if task isn't structured 314 if not self.has_structured_output(): 315 return {} 316 317 run_config = as_kiln_agent_run_config(self.run_config) 318 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 319 320 match structured_output_mode: 321 case StructuredOutputMode.json_mode: 322 return {"response_format": {"type": "json_object"}} 323 case StructuredOutputMode.json_schema: 324 return self.json_schema_response_format() 325 case StructuredOutputMode.function_calling_weak: 326 return self.tool_call_params(strict=False) 327 case StructuredOutputMode.function_calling: 328 return self.tool_call_params(strict=True) 329 case StructuredOutputMode.json_instructions: 330 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 331 return {} 332 case StructuredOutputMode.json_custom_instructions: 333 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 334 return {} 335 case StructuredOutputMode.json_instruction_and_object: 336 # We set response_format to json_object and also set json instructions in the prompt 337 return {"response_format": {"type": "json_object"}} 338 case StructuredOutputMode.default: 339 provider_name = run_config.model_provider_name 340 if provider_name == ModelProviderName.ollama: 341 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 342 return self.json_schema_response_format() 343 elif provider_name == ModelProviderName.docker_model_runner: 344 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 345 return self.json_schema_response_format() 346 else: 347 # Default to function calling -- it's older than the other modes. Higher compatibility. 348 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 
349 strict = provider_name == ModelProviderName.openai 350 return self.tool_call_params(strict=strict) 351 case StructuredOutputMode.unknown: 352 # See above, but this case should never happen. 353 raise ValueError("Structured output mode is unknown.") 354 case _: 355 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type]
369 def tool_call_params(self, strict: bool) -> dict[str, Any]: 370 # Add additional_properties: false to the schema (OpenAI requires this for some models) 371 output_schema = self.task.output_schema() 372 if not isinstance(output_schema, dict): 373 raise ValueError( 374 "Invalid output schema for this task. Can not use tool calls." 375 ) 376 output_schema["additionalProperties"] = False 377 378 function_params = { 379 "name": "task_response", 380 "parameters": output_schema, 381 } 382 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 383 if strict: 384 function_params["strict"] = True 385 386 return { 387 "tools": [ 388 { 389 "type": "function", 390 "function": function_params, 391 } 392 ], 393 "tool_choice": { 394 "type": "function", 395 "function": {"name": "task_response"}, 396 }, 397 }
    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        """Build provider-specific extra request body params (reasoning flags, OpenRouter routing).

        Statement order matters: later assignments intentionally overwrite
        earlier 'reasoning' values for certain provider configurations.
        """
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden

        extra_body = {}
        provider_options = {}

        run_config = as_kiln_agent_run_config(self.run_config)
        # For legacy config 'thinking_level' is not set, default to provider's default
        if "thinking_level" in run_config.model_fields_set:
            thinking_level = run_config.thinking_level
        else:
            thinking_level = provider.default_thinking_level

        # Set the reasoning_effort
        if thinking_level is not None:
            # Anthropic models in OpenRouter use a reasoning object. See https://openrouter.ai/docs/use-cases/reasoning-tokens
            if (
                provider.name == ModelProviderName.openrouter
                and provider.openrouter_reasoning_object
            ):
                extra_body["reasoning"] = {"effort": thinking_level}
            else:
                extra_body["reasoning_effort"] = thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # The only point of a logprobs request is the logprobs themselves.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body
490 def litellm_model_id(self) -> str: 491 # The model ID is an interesting combination of format and url endpoint. 492 # It specifics the provider URL/host, but this is overridden if you manually set an api url 493 if self._litellm_model_id: 494 return self._litellm_model_id 495 496 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 497 if litellm_provider_info.is_custom and self._api_base is None: 498 raise ValueError( 499 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 500 ) 501 502 self._litellm_model_id = litellm_provider_info.litellm_model_id 503 return self._litellm_model_id
    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        """Assemble the kwargs dict for litellm.acompletion.

        Merges model ID, messages, sampling params, provider extra_body,
        tools, and (unless skipped) the structured response format. Raises
        ValueError when tools conflict with a tool-based response format, or
        when temperature and top_p are both customized on a model that only
        accepts one of them.
        """
        run_config = as_kiln_agent_run_config(self.run_config)
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": run_config.temperature,
            "top_p": run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and response format using tools
            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs
572 def usage_from_response(self, response: ModelResponse) -> Usage: 573 litellm_usage = response.get("usage", None) 574 575 # LiteLLM isn't consistent in how it returns the cost. 576 cost = response._hidden_params.get("response_cost", None) 577 if cost is None and litellm_usage: 578 cost = litellm_usage.get("cost", None) 579 580 usage = Usage() 581 582 if not litellm_usage and not cost: 583 return usage 584 585 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 586 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 587 usage.output_tokens = litellm_usage.get("completion_tokens", None) 588 usage.total_tokens = litellm_usage.get("total_tokens", None) 589 else: 590 logger.warning( 591 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 592 ) 593 594 if isinstance(cost, float): 595 usage.cost = cost 596 elif cost is not None: 597 # None is allowed, but no other types are expected 598 logger.warning( 599 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 600 ) 601 602 return usage
    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        """Execute the tool calls requested by the model and collect their responses.

        Args:
            tool_calls: Tool calls from the model's message, or None if there were none.

        Returns:
            A tuple of (assistant_output, tool_messages):
            - assistant_output: the raw arguments of a "task_response" pseudo-tool call
              (structured output returned via tool calling), or None.
            - tool_messages: one "tool" role message per executed tool call.

        Raises:
            RuntimeError: If a tool is unknown, its arguments are invalid JSON or fail
                schema validation, or the model mixed task_response with other tool calls.
        """
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call; it ends the turn, so no tool is actually run.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            # Default arguments bind the current loop values eagerly, avoiding the
            # classic late-binding closure bug when the coroutines run later in gather().
            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        # Run all requested tools concurrently.
        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages
693 def litellm_message_to_trace_message( 694 self, raw_message: LiteLLMMessage 695 ) -> ChatCompletionAssistantMessageParamWrapper: 696 """ 697 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 698 """ 699 message: ChatCompletionAssistantMessageParamWrapper = { 700 "role": "assistant", 701 } 702 if raw_message.role != "assistant": 703 raise ValueError( 704 "Model returned a message with a role other than assistant. This is not supported." 705 ) 706 707 if hasattr(raw_message, "content"): 708 message["content"] = raw_message.content 709 if hasattr(raw_message, "reasoning_content"): 710 message["reasoning_content"] = raw_message.reasoning_content 711 if hasattr(raw_message, "tool_calls"): 712 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 713 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 714 for litellm_tool_call in raw_message.tool_calls or []: 715 # Optional in the SDK for streaming responses, but should never be None at this point. 716 if litellm_tool_call.function.name is None: 717 raise ValueError( 718 "The model requested a tool call, without providing a function name (required)." 719 ) 720 open_ai_tool_calls.append( 721 ChatCompletionMessageToolCallParam( 722 id=litellm_tool_call.id, 723 type="function", 724 function={ 725 "name": litellm_tool_call.function.name, 726 "arguments": litellm_tool_call.function.arguments, 727 }, 728 ) 729 ) 730 if len(open_ai_tool_calls) > 0: 731 message["tool_calls"] = open_ai_tool_calls 732 733 if not message.get("content") and not message.get("tool_calls"): 734 raise ValueError( 735 "Model returned an assistant message, but no content or tool calls. This is not supported." 736 ) 737 738 return message
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
740 def all_messages_to_trace( 741 self, messages: list[ChatCompletionMessageIncludingLiteLLM] 742 ) -> list[ChatCompletionMessageParam]: 743 """ 744 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 745 """ 746 trace: list[ChatCompletionMessageParam] = [] 747 for message in messages: 748 if isinstance(message, LiteLLMMessage): 749 trace.append(self.litellm_message_to_trace_message(message)) 750 else: 751 trace.append(message) 752 return trace
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- task
- run_config
- base_adapter_config
- output_schema
- input_schema
- model_provider
- invoke
- invoke_returning_run_output
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode
- available_tools