kiln_ai.adapters.model_adapters.litellm_adapter
import asyncio
import copy
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, TypeAlias, Union

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.datamodel_enums import InputType
from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
from kiln_ai.datamodel.run_config import (
    KilnAgentRunConfigProperties,
    as_kiln_agent_run_config,
)
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
)

# Hard caps to avoid runaway token usage in the agent loop.
MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)

# Internally we allow LiteLLM's own Message objects alongside the
# OpenAI-compatible dict messages; traces convert everything to the latter.
ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    ChatCompletionMessageParam, LiteLLMMessage
]


@dataclass
class ModelTurnResult:
    """Result of a single top-level model turn (user message to final assistant message)."""

    # Final assistant text for the turn (may come from a task_response tool call).
    assistant_message: str
    # Full message history, including tool call request/response messages.
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    # Raw LiteLLM response from the last completion call, if any.
    model_response: ModelResponse | None
    # Selected choice from the last completion call, if any.
    model_choice: Choices | None
    # Accumulated usage across all completion calls in this turn.
    usage: Usage


class LiteLlmAdapter(BaseAdapter):
    """Model adapter that runs Kiln tasks through LiteLLM's OpenAI-compatible API."""

    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties):
            raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties")
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top level turn: from user message to agent message.

        It may handle iterations of tool calls between the user/agent message if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for tool calls
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # If we get here with no content and no tool calls, break
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
        """Run the task: iterate chat-formatter turns until the final output is produced."""
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                messages.append({"role": message.role, "content": message.content})  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

            if not prior_output:
                raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save COT/reasoning if it exists.
        # May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from model choice and add to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        """Call litellm.acompletion and validate the response shape, returning (response, first choice)."""
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        """Stable adapter identifier recorded with runs."""
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        """Build the completion kwargs that request structured output, per the run config's mode."""
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        run_config = as_kiln_agent_run_config(self.run_config)
        structured_output_mode: StructuredOutputMode = run_config.structured_output_mode

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set json instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)  # type: ignore[arg-type]

    def json_schema_response_format(self) -> dict[str, Any]:
        """Completion kwargs requesting JSON-schema constrained output for the task's schema."""
        output_schema = self.task.output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        """Completion kwargs that force a 'task_response' function call matching the task's schema."""
        # Add additional_properties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        """Provider-specific request body additions (reasoning flags, OpenRouter routing, etc)."""
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden
        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # Only point of this request is to get logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # Siliconflow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        """Lazily resolve and cache the LiteLLM model ID for this provider/model."""
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an api url
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        """Assemble the full kwargs dict for a litellm.acompletion call."""
        run_config = as_kiln_agent_run_config(self.run_config)
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": run_config.temperature,
            "top_p": run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and response format using tools
            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage:
        """Extract token counts and cost from a LiteLLM response into a Kiln Usage object."""
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        """Return available tools, computing them once and caching for the adapter's lifetime."""
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        """Tool definitions for the completion call."""
        available_tools = await self.cached_available_tools()

        # LiteLLM takes the standard OpenAI-compatible tool call format
        return [await tool.toolcall_definition() for tool in available_tools]

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        """Run requested tool calls in parallel.

        Returns (assistant_output, tool_messages): assistant_output is set only when the
        special 'task_response' tool was called (ending the turn); tool_messages are the
        role="tool" responses to feed back to the model.
        """
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call; it ends the turn.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            # Default args bind the current loop values (avoids late-binding closure bug).
            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
@dataclass
class ModelTurnResult:
    """Result of a single top-level model turn (user message to final assistant message)."""

    # Final assistant text for the turn (may come from a task_response tool call).
    assistant_message: str
    # Full message history, including tool call request/response messages.
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    # Raw LiteLLM response from the last completion call, if any.
    model_response: ModelResponse | None
    # Selected choice from the last completion call, if any.
    model_choice: Choices | None
    # Accumulated usage across all completion calls in this turn.
    usage: Usage
74class LiteLlmAdapter(BaseAdapter): 75 def __init__( 76 self, 77 config: LiteLlmConfig, 78 kiln_task: datamodel.Task, 79 base_adapter_config: AdapterConfig | None = None, 80 ): 81 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 82 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 83 self.config = config 84 self._additional_body_options = config.additional_body_options 85 self._api_base = config.base_url 86 self._headers = config.default_headers 87 self._litellm_model_id: str | None = None 88 self._cached_available_tools: list[KilnToolInterface] | None = None 89 90 super().__init__( 91 task=kiln_task, 92 run_config=config.run_config_properties, 93 config=base_adapter_config, 94 ) 95 96 async def _run_model_turn( 97 self, 98 provider: KilnModelProvider, 99 prior_messages: list[ChatCompletionMessageIncludingLiteLLM], 100 top_logprobs: int | None, 101 skip_response_format: bool, 102 ) -> ModelTurnResult: 103 """ 104 Call the model for a single top level turn: from user message to agent message. 105 106 It may make handle iterations of tool calls between the user/agent message if needed. 107 """ 108 109 usage = Usage() 110 messages = list(prior_messages) 111 tool_calls_count = 0 112 113 while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: 114 # Build completion kwargs for tool calls 115 completion_kwargs = await self.build_completion_kwargs( 116 provider, 117 # Pass a copy, as acompletion mutates objects and breaks types. 
118 copy.deepcopy(messages), 119 top_logprobs, 120 skip_response_format, 121 ) 122 123 # Make the completion call 124 model_response, response_choice = await self.acompletion_checking_response( 125 **completion_kwargs 126 ) 127 128 # count the usage 129 usage += self.usage_from_response(model_response) 130 131 # Extract content and tool calls 132 if not hasattr(response_choice, "message"): 133 raise ValueError("Response choice has no message") 134 content = response_choice.message.content 135 tool_calls = response_choice.message.tool_calls 136 if not content and not tool_calls: 137 raise ValueError( 138 "Model returned an assistant message, but no content or tool calls. This is not supported." 139 ) 140 141 # Add message to messages, so it can be used in the next turn 142 messages.append(response_choice.message) 143 144 # Process tool calls if any 145 if tool_calls and len(tool_calls) > 0: 146 ( 147 assistant_message_from_toolcall, 148 tool_call_messages, 149 ) = await self.process_tool_calls(tool_calls) 150 151 # Add tool call results to messages 152 messages.extend(tool_call_messages) 153 154 # If task_response tool was called, we're done 155 if assistant_message_from_toolcall is not None: 156 return ModelTurnResult( 157 assistant_message=assistant_message_from_toolcall, 158 all_messages=messages, 159 model_response=model_response, 160 model_choice=response_choice, 161 usage=usage, 162 ) 163 164 # If there were tool calls, increment counter and continue 165 if tool_call_messages: 166 tool_calls_count += 1 167 continue 168 169 # If no tool calls, return the content as final output 170 if content: 171 return ModelTurnResult( 172 assistant_message=content, 173 all_messages=messages, 174 model_response=model_response, 175 model_choice=response_choice, 176 usage=usage, 177 ) 178 179 # If we get here with no content and no tool calls, break 180 raise RuntimeError( 181 "Model returned neither content nor tool calls. It must return at least one of these." 
182 ) 183 184 raise RuntimeError( 185 f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." 186 ) 187 188 async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: 189 usage = Usage() 190 191 provider = self.model_provider() 192 if not provider.model_id: 193 raise ValueError("Model ID is required for OpenAI compatible models") 194 195 chat_formatter = self.build_chat_formatter(input) 196 messages: list[ChatCompletionMessageIncludingLiteLLM] = [] 197 198 prior_output: str | None = None 199 final_choice: Choices | None = None 200 turns = 0 201 202 while True: 203 turns += 1 204 if turns > MAX_CALLS_PER_TURN: 205 raise RuntimeError( 206 f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens." 207 ) 208 209 turn = chat_formatter.next_turn(prior_output) 210 if turn is None: 211 # No next turn, we're done 212 break 213 214 # Add messages from the turn to chat history 215 for message in turn.messages: 216 if message.content is None: 217 raise ValueError("Empty message content isn't allowed") 218 # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role. 219 messages.append({"role": message.role, "content": message.content}) # type: ignore 220 221 skip_response_format = not turn.final_call 222 turn_result = await self._run_model_turn( 223 provider, 224 messages, 225 self.base_adapter_config.top_logprobs if turn.final_call else None, 226 skip_response_format, 227 ) 228 229 usage += turn_result.usage 230 231 prior_output = turn_result.assistant_message 232 messages = turn_result.all_messages 233 final_choice = turn_result.model_choice 234 235 if not prior_output: 236 raise RuntimeError("No assistant message/output returned from model") 237 238 logprobs = self._extract_and_validate_logprobs(final_choice) 239 240 # Save COT/reasoning if it exists. 
May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream) 241 intermediate_outputs = chat_formatter.intermediate_outputs() 242 self._extract_reasoning_to_intermediate_outputs( 243 final_choice, intermediate_outputs 244 ) 245 246 if not isinstance(prior_output, str): 247 raise RuntimeError(f"assistant message is not a string: {prior_output}") 248 249 trace = self.all_messages_to_trace(messages) 250 output = RunOutput( 251 output=prior_output, 252 intermediate_outputs=intermediate_outputs, 253 output_logprobs=logprobs, 254 trace=trace, 255 ) 256 257 return output, usage 258 259 def _extract_and_validate_logprobs( 260 self, final_choice: Choices | None 261 ) -> ChoiceLogprobs | None: 262 """ 263 Extract logprobs from the final choice and validate they exist if required. 264 """ 265 logprobs = None 266 if ( 267 final_choice is not None 268 and hasattr(final_choice, "logprobs") 269 and isinstance(final_choice.logprobs, ChoiceLogprobs) 270 ): 271 logprobs = final_choice.logprobs 272 273 # Check logprobs worked, if required 274 if self.base_adapter_config.top_logprobs is not None and logprobs is None: 275 raise RuntimeError("Logprobs were required, but no logprobs were returned.") 276 277 return logprobs 278 279 def _extract_reasoning_to_intermediate_outputs( 280 self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any] 281 ) -> None: 282 """Extract reasoning content from model choice and add to intermediate outputs if present.""" 283 if ( 284 final_choice is not None 285 and hasattr(final_choice, "message") 286 and hasattr(final_choice.message, "reasoning_content") 287 ): 288 reasoning_content = final_choice.message.reasoning_content 289 if reasoning_content is not None: 290 stripped_reasoning_content = reasoning_content.strip() 291 if len(stripped_reasoning_content) > 0: 292 intermediate_outputs["reasoning"] = stripped_reasoning_content 293 294 async def acompletion_checking_response( 295 self, **kwargs 296 ) -> 
Tuple[ModelResponse, Choices]: 297 response = await litellm.acompletion(**kwargs) 298 if ( 299 not isinstance(response, ModelResponse) 300 or not response.choices 301 or len(response.choices) == 0 302 or not isinstance(response.choices[0], Choices) 303 ): 304 raise RuntimeError( 305 f"Expected ModelResponse with Choices, got {type(response)}." 306 ) 307 return response, response.choices[0] 308 309 def adapter_name(self) -> str: 310 return "kiln_openai_compatible_adapter" 311 312 async def response_format_options(self) -> dict[str, Any]: 313 # Unstructured if task isn't structured 314 if not self.has_structured_output(): 315 return {} 316 317 run_config = as_kiln_agent_run_config(self.run_config) 318 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 319 320 match structured_output_mode: 321 case StructuredOutputMode.json_mode: 322 return {"response_format": {"type": "json_object"}} 323 case StructuredOutputMode.json_schema: 324 return self.json_schema_response_format() 325 case StructuredOutputMode.function_calling_weak: 326 return self.tool_call_params(strict=False) 327 case StructuredOutputMode.function_calling: 328 return self.tool_call_params(strict=True) 329 case StructuredOutputMode.json_instructions: 330 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 331 return {} 332 case StructuredOutputMode.json_custom_instructions: 333 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 
334 return {} 335 case StructuredOutputMode.json_instruction_and_object: 336 # We set response_format to json_object and also set json instructions in the prompt 337 return {"response_format": {"type": "json_object"}} 338 case StructuredOutputMode.default: 339 provider_name = run_config.model_provider_name 340 if provider_name == ModelProviderName.ollama: 341 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 342 return self.json_schema_response_format() 343 elif provider_name == ModelProviderName.docker_model_runner: 344 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 345 return self.json_schema_response_format() 346 else: 347 # Default to function calling -- it's older than the other modes. Higher compatibility. 348 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 349 strict = provider_name == ModelProviderName.openai 350 return self.tool_call_params(strict=strict) 351 case StructuredOutputMode.unknown: 352 # See above, but this case should never happen. 353 raise ValueError("Structured output mode is unknown.") 354 case _: 355 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type] 356 357 def json_schema_response_format(self) -> dict[str, Any]: 358 output_schema = self.task.output_schema() 359 return { 360 "response_format": { 361 "type": "json_schema", 362 "json_schema": { 363 "name": "task_response", 364 "schema": output_schema, 365 }, 366 } 367 } 368 369 def tool_call_params(self, strict: bool) -> dict[str, Any]: 370 # Add additional_properties: false to the schema (OpenAI requires this for some models) 371 output_schema = self.task.output_schema() 372 if not isinstance(output_schema, dict): 373 raise ValueError( 374 "Invalid output schema for this task. Can not use tool calls." 
375 ) 376 output_schema["additionalProperties"] = False 377 378 function_params = { 379 "name": "task_response", 380 "parameters": output_schema, 381 } 382 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 383 if strict: 384 function_params["strict"] = True 385 386 return { 387 "tools": [ 388 { 389 "type": "function", 390 "function": function_params, 391 } 392 ], 393 "tool_choice": { 394 "type": "function", 395 "function": {"name": "task_response"}, 396 }, 397 } 398 399 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 400 # Don't love having this logic here. But it's worth the usability improvement 401 # so better to keep it than exclude it. Should figure out how I want to isolate 402 # this sort of logic so it's config driven and can be overridden 403 404 extra_body = {} 405 provider_options = {} 406 407 if provider.thinking_level is not None: 408 extra_body["reasoning_effort"] = provider.thinking_level 409 410 if provider.require_openrouter_reasoning: 411 # https://openrouter.ai/docs/use-cases/reasoning-tokens 412 extra_body["reasoning"] = { 413 "exclude": False, 414 } 415 416 if provider.gemini_reasoning_enabled: 417 extra_body["reasoning"] = { 418 "enabled": True, 419 } 420 421 if provider.name == ModelProviderName.openrouter: 422 # Ask OpenRouter to include usage in the response (cost) 423 extra_body["usage"] = {"include": True} 424 425 # Set a default provider order for more deterministic routing. 426 # OpenRouter will ignore providers that don't support the model. 427 # Special cases below (like R1) can override this order. 428 # allow_fallbacks is true by default, but we can override it here. 
429 provider_options["order"] = [ 430 "fireworks", 431 "parasail", 432 "together", 433 "deepinfra", 434 "novita", 435 "groq", 436 "amazon-bedrock", 437 "azure", 438 "nebius", 439 ] 440 441 if provider.anthropic_extended_thinking: 442 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 443 444 if provider.r1_openrouter_options: 445 # Require providers that support the reasoning parameter 446 provider_options["require_parameters"] = True 447 # Prefer R1 providers with reasonable perf/quants 448 provider_options["order"] = ["fireworks", "together"] 449 # R1 providers with unreasonable quants 450 provider_options["ignore"] = ["deepinfra"] 451 452 # Only set of this request is to get logprobs. 453 if ( 454 provider.logprobs_openrouter_options 455 and self.base_adapter_config.top_logprobs is not None 456 ): 457 # Don't let OpenRouter choose a provider that doesn't support logprobs. 458 provider_options["require_parameters"] = True 459 # DeepInfra silently fails to return logprobs consistently. 460 provider_options["ignore"] = ["deepinfra"] 461 462 if provider.openrouter_skip_required_parameters: 463 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 464 provider_options["require_parameters"] = False 465 466 # Siliconflow uses a bool flag for thinking, for some models 467 if provider.siliconflow_enable_thinking is not None: 468 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 469 470 if len(provider_options) > 0: 471 extra_body["provider"] = provider_options 472 473 return extra_body 474 475 def litellm_model_id(self) -> str: 476 # The model ID is an interesting combination of format and url endpoint. 
477 # It specifics the provider URL/host, but this is overridden if you manually set an api url 478 if self._litellm_model_id: 479 return self._litellm_model_id 480 481 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 482 if litellm_provider_info.is_custom and self._api_base is None: 483 raise ValueError( 484 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 485 ) 486 487 self._litellm_model_id = litellm_provider_info.litellm_model_id 488 return self._litellm_model_id 489 490 async def build_completion_kwargs( 491 self, 492 provider: KilnModelProvider, 493 messages: list[ChatCompletionMessageIncludingLiteLLM], 494 top_logprobs: int | None, 495 skip_response_format: bool = False, 496 ) -> dict[str, Any]: 497 run_config = as_kiln_agent_run_config(self.run_config) 498 extra_body = self.build_extra_body(provider) 499 500 # Merge all parameters into a single kwargs dict for litellm 501 completion_kwargs = { 502 "model": self.litellm_model_id(), 503 "messages": messages, 504 "api_base": self._api_base, 505 "headers": self._headers, 506 "temperature": run_config.temperature, 507 "top_p": run_config.top_p, 508 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 509 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 510 # Better to ignore them than to fail the model call. 511 # https://docs.litellm.ai/docs/completion/input 512 "drop_params": True, 513 **extra_body, 514 **self._additional_body_options, 515 } 516 517 tool_calls = await self.litellm_tools() 518 has_tools = len(tool_calls) > 0 519 if has_tools: 520 completion_kwargs["tools"] = tool_calls 521 completion_kwargs["tool_choice"] = "auto" 522 523 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 
524 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 525 if provider.temp_top_p_exclusive: 526 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 527 del completion_kwargs["top_p"] 528 if ( 529 "temperature" in completion_kwargs 530 and completion_kwargs["temperature"] == 1.0 531 ): 532 del completion_kwargs["temperature"] 533 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 534 raise ValueError( 535 "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 536 ) 537 538 if not skip_response_format: 539 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 540 response_format_options = await self.response_format_options() 541 542 # Check for a conflict between tools and response format using tools 543 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 544 if has_tools and "tools" in response_format_options: 545 raise ValueError( 546 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 547 ) 548 549 completion_kwargs.update(response_format_options) 550 551 if top_logprobs is not None: 552 completion_kwargs["logprobs"] = True 553 completion_kwargs["top_logprobs"] = top_logprobs 554 555 return completion_kwargs 556 557 def usage_from_response(self, response: ModelResponse) -> Usage: 558 litellm_usage = response.get("usage", None) 559 560 # LiteLLM isn't consistent in how it returns the cost. 
561 cost = response._hidden_params.get("response_cost", None) 562 if cost is None and litellm_usage: 563 cost = litellm_usage.get("cost", None) 564 565 usage = Usage() 566 567 if not litellm_usage and not cost: 568 return usage 569 570 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 571 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 572 usage.output_tokens = litellm_usage.get("completion_tokens", None) 573 usage.total_tokens = litellm_usage.get("total_tokens", None) 574 else: 575 logger.warning( 576 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 577 ) 578 579 if isinstance(cost, float): 580 usage.cost = cost 581 elif cost is not None: 582 # None is allowed, but no other types are expected 583 logger.warning( 584 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 585 ) 586 587 return usage 588 589 async def cached_available_tools(self) -> list[KilnToolInterface]: 590 if self._cached_available_tools is None: 591 self._cached_available_tools = await self.available_tools() 592 return self._cached_available_tools 593 594 async def litellm_tools(self) -> list[ToolCallDefinition]: 595 available_tools = await self.cached_available_tools() 596 597 # LiteLLM takes the standard OpenAI-compatible tool call format 598 return [await tool.toolcall_definition() for tool in available_tools] 599 600 async def process_tool_calls( 601 self, tool_calls: list[ChatCompletionMessageToolCall] | None 602 ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]: 603 if tool_calls is None: 604 return None, [] 605 606 assistant_output_from_toolcall: str | None = None 607 tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = [] 608 tool_run_coroutines = [] 609 610 for tool_call in tool_calls: 611 # Kiln "task_response" tool is used for returning structured output via tool calls. 612 # Load the output from the tool call. 
Also 613 if tool_call.function.name == "task_response": 614 assistant_output_from_toolcall = tool_call.function.arguments 615 continue 616 617 # Process normal tool calls (not the "task_response" tool) 618 tool_name = tool_call.function.name 619 tool = None 620 for tool_option in await self.cached_available_tools(): 621 if await tool_option.name() == tool_name: 622 tool = tool_option 623 break 624 if not tool: 625 raise RuntimeError( 626 f"A tool named '{tool_name}' was invoked by a model, but was not available." 627 ) 628 629 # Parse the arguments and validate them against the tool's schema 630 try: 631 parsed_args = json.loads(tool_call.function.arguments) 632 except json.JSONDecodeError: 633 raise RuntimeError( 634 f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}" 635 ) 636 try: 637 tool_call_definition = await tool.toolcall_definition() 638 json_schema = json.dumps(tool_call_definition["function"]["parameters"]) 639 validate_schema_with_value_error(parsed_args, json_schema) 640 except Exception as e: 641 raise RuntimeError( 642 f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. 
The arguments were: {parsed_args}\n The error was: {e}" 643 ) from e 644 645 # Create context with the calling task's allow_saving setting 646 context = ToolCallContext( 647 allow_saving=self.base_adapter_config.allow_saving 648 ) 649 650 async def run_tool_and_format( 651 t=tool, c=context, args=parsed_args, tc_id=tool_call.id 652 ): 653 result = await t.run(c, **args) 654 return ChatCompletionToolMessageParamWrapper( 655 role="tool", 656 tool_call_id=tc_id, 657 content=result.output, 658 kiln_task_tool_data=result.kiln_task_tool_data 659 if isinstance(result, KilnTaskToolResult) 660 else None, 661 ) 662 663 tool_run_coroutines.append(run_tool_and_format()) 664 665 if tool_run_coroutines: 666 tool_call_response_messages = await asyncio.gather(*tool_run_coroutines) 667 668 if ( 669 assistant_output_from_toolcall is not None 670 and len(tool_call_response_messages) > 0 671 ): 672 raise RuntimeError( 673 "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible." 674 ) 675 676 return assistant_output_from_toolcall, tool_call_response_messages 677 678 def litellm_message_to_trace_message( 679 self, raw_message: LiteLLMMessage 680 ) -> ChatCompletionAssistantMessageParamWrapper: 681 """ 682 Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper 683 """ 684 message: ChatCompletionAssistantMessageParamWrapper = { 685 "role": "assistant", 686 } 687 if raw_message.role != "assistant": 688 raise ValueError( 689 "Model returned a message with a role other than assistant. This is not supported." 
690 ) 691 692 if hasattr(raw_message, "content"): 693 message["content"] = raw_message.content 694 if hasattr(raw_message, "reasoning_content"): 695 message["reasoning_content"] = raw_message.reasoning_content 696 if hasattr(raw_message, "tool_calls"): 697 # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam 698 open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = [] 699 for litellm_tool_call in raw_message.tool_calls or []: 700 # Optional in the SDK for streaming responses, but should never be None at this point. 701 if litellm_tool_call.function.name is None: 702 raise ValueError( 703 "The model requested a tool call, without providing a function name (required)." 704 ) 705 open_ai_tool_calls.append( 706 ChatCompletionMessageToolCallParam( 707 id=litellm_tool_call.id, 708 type="function", 709 function={ 710 "name": litellm_tool_call.function.name, 711 "arguments": litellm_tool_call.function.arguments, 712 }, 713 ) 714 ) 715 if len(open_ai_tool_calls) > 0: 716 message["tool_calls"] = open_ai_tool_calls 717 718 if not message.get("content") and not message.get("tool_calls"): 719 raise ValueError( 720 "Model returned an assistant message, but no content or tool calls. This is not supported." 721 ) 722 723 return message 724 725 def all_messages_to_trace( 726 self, messages: list[ChatCompletionMessageIncludingLiteLLM] 727 ) -> list[ChatCompletionMessageParam]: 728 """ 729 Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types. 730 """ 731 trace: list[ChatCompletionMessageParam] = [] 732 for message in messages: 733 if isinstance(message, LiteLLMMessage): 734 trace.append(self.litellm_message_to_trace_message(message)) 735 else: 736 trace.append(message) 737 return trace
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
75 def __init__( 76 self, 77 config: LiteLlmConfig, 78 kiln_task: datamodel.Task, 79 base_adapter_config: AdapterConfig | None = None, 80 ): 81 if not isinstance(config.run_config_properties, KilnAgentRunConfigProperties): 82 raise ValueError("LiteLlmAdapter requires KilnAgentRunConfigProperties") 83 self.config = config 84 self._additional_body_options = config.additional_body_options 85 self._api_base = config.base_url 86 self._headers = config.default_headers 87 self._litellm_model_id: str | None = None 88 self._cached_available_tools: list[KilnToolInterface] | None = None 89 90 super().__init__( 91 task=kiln_task, 92 run_config=config.run_config_properties, 93 config=base_adapter_config, 94 )
294 async def acompletion_checking_response( 295 self, **kwargs 296 ) -> Tuple[ModelResponse, Choices]: 297 response = await litellm.acompletion(**kwargs) 298 if ( 299 not isinstance(response, ModelResponse) 300 or not response.choices 301 or len(response.choices) == 0 302 or not isinstance(response.choices[0], Choices) 303 ): 304 raise RuntimeError( 305 f"Expected ModelResponse with Choices, got {type(response)}." 306 ) 307 return response, response.choices[0]
312 async def response_format_options(self) -> dict[str, Any]: 313 # Unstructured if task isn't structured 314 if not self.has_structured_output(): 315 return {} 316 317 run_config = as_kiln_agent_run_config(self.run_config) 318 structured_output_mode: StructuredOutputMode = run_config.structured_output_mode 319 320 match structured_output_mode: 321 case StructuredOutputMode.json_mode: 322 return {"response_format": {"type": "json_object"}} 323 case StructuredOutputMode.json_schema: 324 return self.json_schema_response_format() 325 case StructuredOutputMode.function_calling_weak: 326 return self.tool_call_params(strict=False) 327 case StructuredOutputMode.function_calling: 328 return self.tool_call_params(strict=True) 329 case StructuredOutputMode.json_instructions: 330 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 331 return {} 332 case StructuredOutputMode.json_custom_instructions: 333 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 334 return {} 335 case StructuredOutputMode.json_instruction_and_object: 336 # We set response_format to json_object and also set json instructions in the prompt 337 return {"response_format": {"type": "json_object"}} 338 case StructuredOutputMode.default: 339 provider_name = run_config.model_provider_name 340 if provider_name == ModelProviderName.ollama: 341 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 342 return self.json_schema_response_format() 343 elif provider_name == ModelProviderName.docker_model_runner: 344 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 345 return self.json_schema_response_format() 346 else: 347 # Default to function calling -- it's older than the other modes. Higher compatibility. 348 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 
349 strict = provider_name == ModelProviderName.openai 350 return self.tool_call_params(strict=strict) 351 case StructuredOutputMode.unknown: 352 # See above, but this case should never happen. 353 raise ValueError("Structured output mode is unknown.") 354 case _: 355 raise_exhaustive_enum_error(structured_output_mode) # type: ignore[arg-type]
369 def tool_call_params(self, strict: bool) -> dict[str, Any]: 370 # Add additional_properties: false to the schema (OpenAI requires this for some models) 371 output_schema = self.task.output_schema() 372 if not isinstance(output_schema, dict): 373 raise ValueError( 374 "Invalid output schema for this task. Can not use tool calls." 375 ) 376 output_schema["additionalProperties"] = False 377 378 function_params = { 379 "name": "task_response", 380 "parameters": output_schema, 381 } 382 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 383 if strict: 384 function_params["strict"] = True 385 386 return { 387 "tools": [ 388 { 389 "type": "function", 390 "function": function_params, 391 } 392 ], 393 "tool_choice": { 394 "type": "function", 395 "function": {"name": "task_response"}, 396 }, 397 }
399 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 400 # Don't love having this logic here. But it's worth the usability improvement 401 # so better to keep it than exclude it. Should figure out how I want to isolate 402 # this sort of logic so it's config driven and can be overridden 403 404 extra_body = {} 405 provider_options = {} 406 407 if provider.thinking_level is not None: 408 extra_body["reasoning_effort"] = provider.thinking_level 409 410 if provider.require_openrouter_reasoning: 411 # https://openrouter.ai/docs/use-cases/reasoning-tokens 412 extra_body["reasoning"] = { 413 "exclude": False, 414 } 415 416 if provider.gemini_reasoning_enabled: 417 extra_body["reasoning"] = { 418 "enabled": True, 419 } 420 421 if provider.name == ModelProviderName.openrouter: 422 # Ask OpenRouter to include usage in the response (cost) 423 extra_body["usage"] = {"include": True} 424 425 # Set a default provider order for more deterministic routing. 426 # OpenRouter will ignore providers that don't support the model. 427 # Special cases below (like R1) can override this order. 428 # allow_fallbacks is true by default, but we can override it here. 429 provider_options["order"] = [ 430 "fireworks", 431 "parasail", 432 "together", 433 "deepinfra", 434 "novita", 435 "groq", 436 "amazon-bedrock", 437 "azure", 438 "nebius", 439 ] 440 441 if provider.anthropic_extended_thinking: 442 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 443 444 if provider.r1_openrouter_options: 445 # Require providers that support the reasoning parameter 446 provider_options["require_parameters"] = True 447 # Prefer R1 providers with reasonable perf/quants 448 provider_options["order"] = ["fireworks", "together"] 449 # R1 providers with unreasonable quants 450 provider_options["ignore"] = ["deepinfra"] 451 452 # Only set of this request is to get logprobs. 
453 if ( 454 provider.logprobs_openrouter_options 455 and self.base_adapter_config.top_logprobs is not None 456 ): 457 # Don't let OpenRouter choose a provider that doesn't support logprobs. 458 provider_options["require_parameters"] = True 459 # DeepInfra silently fails to return logprobs consistently. 460 provider_options["ignore"] = ["deepinfra"] 461 462 if provider.openrouter_skip_required_parameters: 463 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 464 provider_options["require_parameters"] = False 465 466 # Siliconflow uses a bool flag for thinking, for some models 467 if provider.siliconflow_enable_thinking is not None: 468 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 469 470 if len(provider_options) > 0: 471 extra_body["provider"] = provider_options 472 473 return extra_body
475 def litellm_model_id(self) -> str: 476 # The model ID is an interesting combination of format and url endpoint. 477 # It specifics the provider URL/host, but this is overridden if you manually set an api url 478 if self._litellm_model_id: 479 return self._litellm_model_id 480 481 litellm_provider_info = get_litellm_provider_info(self.model_provider()) 482 if litellm_provider_info.is_custom and self._api_base is None: 483 raise ValueError( 484 "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)" 485 ) 486 487 self._litellm_model_id = litellm_provider_info.litellm_model_id 488 return self._litellm_model_id
490 async def build_completion_kwargs( 491 self, 492 provider: KilnModelProvider, 493 messages: list[ChatCompletionMessageIncludingLiteLLM], 494 top_logprobs: int | None, 495 skip_response_format: bool = False, 496 ) -> dict[str, Any]: 497 run_config = as_kiln_agent_run_config(self.run_config) 498 extra_body = self.build_extra_body(provider) 499 500 # Merge all parameters into a single kwargs dict for litellm 501 completion_kwargs = { 502 "model": self.litellm_model_id(), 503 "messages": messages, 504 "api_base": self._api_base, 505 "headers": self._headers, 506 "temperature": run_config.temperature, 507 "top_p": run_config.top_p, 508 # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc. 509 # Not all models and providers support all openai params (for example, o3 doesn't support top_p) 510 # Better to ignore them than to fail the model call. 511 # https://docs.litellm.ai/docs/completion/input 512 "drop_params": True, 513 **extra_body, 514 **self._additional_body_options, 515 } 516 517 tool_calls = await self.litellm_tools() 518 has_tools = len(tool_calls) > 0 519 if has_tools: 520 completion_kwargs["tools"] = tool_calls 521 completion_kwargs["tool_choice"] = "auto" 522 523 # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both. 524 # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom. 525 if provider.temp_top_p_exclusive: 526 if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0: 527 del completion_kwargs["top_p"] 528 if ( 529 "temperature" in completion_kwargs 530 and completion_kwargs["temperature"] == 1.0 531 ): 532 del completion_kwargs["temperature"] 533 if "top_p" in completion_kwargs and "temperature" in completion_kwargs: 534 raise ValueError( 535 "top_p and temperature can not both have custom values for this model. 
This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)." 536 ) 537 538 if not skip_response_format: 539 # Response format: json_schema, json_instructions, json_mode, function_calling, etc 540 response_format_options = await self.response_format_options() 541 542 # Check for a conflict between tools and response format using tools 543 # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged). 544 if has_tools and "tools" in response_format_options: 545 raise ValueError( 546 "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode." 547 ) 548 549 completion_kwargs.update(response_format_options) 550 551 if top_logprobs is not None: 552 completion_kwargs["logprobs"] = True 553 completion_kwargs["top_logprobs"] = top_logprobs 554 555 return completion_kwargs
557 def usage_from_response(self, response: ModelResponse) -> Usage: 558 litellm_usage = response.get("usage", None) 559 560 # LiteLLM isn't consistent in how it returns the cost. 561 cost = response._hidden_params.get("response_cost", None) 562 if cost is None and litellm_usage: 563 cost = litellm_usage.get("cost", None) 564 565 usage = Usage() 566 567 if not litellm_usage and not cost: 568 return usage 569 570 if litellm_usage and isinstance(litellm_usage, LiteLlmUsage): 571 usage.input_tokens = litellm_usage.get("prompt_tokens", None) 572 usage.output_tokens = litellm_usage.get("completion_tokens", None) 573 usage.total_tokens = litellm_usage.get("total_tokens", None) 574 else: 575 logger.warning( 576 f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}" 577 ) 578 579 if isinstance(cost, float): 580 usage.cost = cost 581 elif cost is not None: 582 # None is allowed, but no other types are expected 583 logger.warning( 584 f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}" 585 ) 586 587 return usage
async def process_tool_calls(
    self, tool_calls: list[ChatCompletionMessageToolCall] | None
) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
    """
    Execute the tool calls requested by the model for one turn.

    The special "task_response" tool is not executed: it's the mechanism for
    returning structured output via tool calls, so its arguments are captured
    as the assistant's final output instead.

    Returns a tuple of:
    - The assistant output captured from a "task_response" call, if any.
    - Tool response messages (one per normal tool call) to send back to the model.

    Raises RuntimeError if an invoked tool isn't available, its arguments fail
    to parse or validate against the tool's schema, or the model mixed a
    "task_response" call with normal tool calls in the same turn.
    """
    if tool_calls is None:
        return None, []

    assistant_output_from_toolcall: str | None = None
    tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
    tool_run_coroutines = []

    for tool_call in tool_calls:
        # Kiln "task_response" tool is used for returning structured output via
        # tool calls. Capture its arguments as the final assistant output.
        if tool_call.function.name == "task_response":
            assistant_output_from_toolcall = tool_call.function.arguments
            continue

        # Process normal tool calls (not the "task_response" tool)
        tool_name = tool_call.function.name
        tool = None
        for tool_option in await self.cached_available_tools():
            if await tool_option.name() == tool_name:
                tool = tool_option
                break
        if not tool:
            raise RuntimeError(
                f"A tool named '{tool_name}' was invoked by a model, but was not available."
            )

        # Parse the arguments and validate them against the tool's schema
        try:
            parsed_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            raise RuntimeError(
                f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
            )
        try:
            tool_call_definition = await tool.toolcall_definition()
            json_schema = json.dumps(tool_call_definition["function"]["parameters"])
            validate_schema_with_value_error(parsed_args, json_schema)
        except Exception as e:
            raise RuntimeError(
                f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
            ) from e

        # Create context with the calling task's allow_saving setting
        context = ToolCallContext(
            allow_saving=self.base_adapter_config.allow_saving
        )

        # Bind loop variables via default args so each coroutine captures the
        # values from its own iteration (avoids late-binding closure bugs).
        async def run_tool_and_format(
            t=tool, c=context, args=parsed_args, tc_id=tool_call.id
        ):
            result = await t.run(c, **args)
            return ChatCompletionToolMessageParamWrapper(
                role="tool",
                tool_call_id=tc_id,
                content=result.output,
                kiln_task_tool_data=result.kiln_task_tool_data
                if isinstance(result, KilnTaskToolResult)
                else None,
            )

        tool_run_coroutines.append(run_tool_and_format())

    # Run all of this turn's tool calls concurrently.
    if tool_run_coroutines:
        tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

    if (
        assistant_output_from_toolcall is not None
        and len(tool_call_response_messages) > 0
    ):
        # Fixed typos in this user-facing error: "difference structured data
        # model" -> "different structured output mode", "tools calls" -> "tool calls".
        raise RuntimeError(
            "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
        )

    return assistant_output_from_toolcall, tool_call_response_messages
def litellm_message_to_trace_message(
    self, raw_message: LiteLLMMessage
) -> ChatCompletionAssistantMessageParamWrapper:
    """
    Convert a LiteLLM Message object to an OpenAI compatible message, our
    ChatCompletionAssistantMessageParamWrapper.

    Raises ValueError if the role isn't "assistant", a tool call is missing a
    function name, or the message has neither content nor tool calls.
    """
    if raw_message.role != "assistant":
        raise ValueError(
            "Model returned a message with a role other than assistant. This is not supported."
        )

    converted: ChatCompletionAssistantMessageParamWrapper = {
        "role": "assistant",
    }

    if hasattr(raw_message, "content"):
        converted["content"] = raw_message.content
    if hasattr(raw_message, "reasoning_content"):
        converted["reasoning_content"] = raw_message.reasoning_content

    if hasattr(raw_message, "tool_calls"):
        # Convert each ChatCompletionMessageToolCall to the OpenAI param type.
        converted_tool_calls: List[ChatCompletionMessageToolCallParam] = []
        for litellm_tool_call in raw_message.tool_calls or []:
            # function.name is optional in the SDK for streaming responses,
            # but should never be None at this point.
            function_name = litellm_tool_call.function.name
            if function_name is None:
                raise ValueError(
                    "The model requested a tool call, without providing a function name (required)."
                )
            converted_tool_calls.append(
                ChatCompletionMessageToolCallParam(
                    id=litellm_tool_call.id,
                    type="function",
                    function={
                        "name": function_name,
                        "arguments": litellm_tool_call.function.arguments,
                    },
                )
            )
        if converted_tool_calls:
            converted["tool_calls"] = converted_tool_calls

    if not converted.get("content") and not converted.get("tool_calls"):
        raise ValueError(
            "Model returned an assistant message, but no content or tool calls. This is not supported."
        )

    return converted
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
def all_messages_to_trace(
    self, messages: list[ChatCompletionMessageIncludingLiteLLM]
) -> list[ChatCompletionMessageParam]:
    """
    Internally we allow LiteLLM Message objects, but for trace we need OpenAI
    compatible types. Replace LiteLLM Message objects with OpenAI compatible
    types, leaving all other messages untouched.
    """
    return [
        self.litellm_message_to_trace_message(message)
        if isinstance(message, LiteLLMMessage)
        else message
        for message in messages
    ]
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- task
- run_config
- base_adapter_config
- output_schema
- input_schema
- model_provider
- invoke
- invoke_returning_run_output
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode
- available_tools