kiln_ai.adapters.model_adapters.litellm_adapter
```python
import asyncio
import copy
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, TypeAlias, Union

import litellm
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    ChoiceLogprobs,
    Choices,
    ModelResponse,
)
from litellm.types.utils import Message as LiteLLMMessage
from litellm.types.utils import Usage as LiteLlmUsage
from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
)

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
    KilnModelProvider,
    ModelProviderName,
    StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
    Usage,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.datamodel_enums import InputType
from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
from kiln_ai.tools.base_tool import (
    KilnToolInterface,
    ToolCallContext,
    ToolCallDefinition,
)
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.litellm import get_litellm_provider_info
from kiln_ai.utils.open_ai_types import (
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParamWrapper,
)

MAX_CALLS_PER_TURN = 10
MAX_TOOL_CALLS_PER_TURN = 30

logger = logging.getLogger(__name__)

ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    ChatCompletionMessageParam, LiteLLMMessage
]


@dataclass
class ModelTurnResult:
    assistant_message: str
    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
    model_response: ModelResponse | None
    model_choice: Choices | None
    usage: Usage


class LiteLlmAdapter(BaseAdapter):
    def __init__(
        self,
        config: LiteLlmConfig,
        kiln_task: datamodel.Task,
        base_adapter_config: AdapterConfig | None = None,
    ):
        self.config = config
        self._additional_body_options = config.additional_body_options
        self._api_base = config.base_url
        self._headers = config.default_headers
        self._litellm_model_id: str | None = None
        self._cached_available_tools: list[KilnToolInterface] | None = None

        super().__init__(
            task=kiln_task,
            run_config=config.run_config_properties,
            config=base_adapter_config,
        )

    async def _run_model_turn(
        self,
        provider: KilnModelProvider,
        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool,
    ) -> ModelTurnResult:
        """
        Call the model for a single top-level turn: from user message to agent message.

        It may handle several iterations of tool calls between the user and agent messages if needed.
        """

        usage = Usage()
        messages = list(prior_messages)
        tool_calls_count = 0

        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            # Build completion kwargs for this model call
            completion_kwargs = await self.build_completion_kwargs(
                provider,
                # Pass a copy, as acompletion mutates objects and breaks types.
                copy.deepcopy(messages),
                top_logprobs,
                skip_response_format,
            )

            # Make the completion call
            model_response, response_choice = await self.acompletion_checking_response(
                **completion_kwargs
            )

            # Count the usage
            usage += self.usage_from_response(model_response)

            # Extract content and tool calls
            if not hasattr(response_choice, "message"):
                raise ValueError("Response choice has no message")
            content = response_choice.message.content
            tool_calls = response_choice.message.tool_calls
            if not content and not tool_calls:
                raise ValueError(
                    "Model returned an assistant message, but no content or tool calls. This is not supported."
                )

            # Add message to messages, so it can be used in the next turn
            messages.append(response_choice.message)

            # Process tool calls if any
            if tool_calls and len(tool_calls) > 0:
                (
                    assistant_message_from_toolcall,
                    tool_call_messages,
                ) = await self.process_tool_calls(tool_calls)

                # Add tool call results to messages
                messages.extend(tool_call_messages)

                # If task_response tool was called, we're done
                if assistant_message_from_toolcall is not None:
                    return ModelTurnResult(
                        assistant_message=assistant_message_from_toolcall,
                        all_messages=messages,
                        model_response=model_response,
                        model_choice=response_choice,
                        usage=usage,
                    )

                # If there were tool calls, increment counter and continue
                if tool_call_messages:
                    tool_calls_count += 1
                    continue

            # If no tool calls, return the content as final output
            if content:
                return ModelTurnResult(
                    assistant_message=content,
                    all_messages=messages,
                    model_response=model_response,
                    model_choice=response_choice,
                    usage=usage,
                )

            # If we get here with no content and no tool calls, break
            raise RuntimeError(
                "Model returned neither content nor tool calls. It must return at least one of these."
            )

        raise RuntimeError(
            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
        )

    async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]:
        usage = Usage()

        provider = self.model_provider()
        if not provider.model_id:
            raise ValueError("Model ID is required for OpenAI compatible models")

        chat_formatter = self.build_chat_formatter(input)
        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

        prior_output: str | None = None
        final_choice: Choices | None = None
        turns = 0

        while True:
            turns += 1
            if turns > MAX_CALLS_PER_TURN:
                raise RuntimeError(
                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                )

            turn = chat_formatter.next_turn(prior_output)
            if turn is None:
                # No next turn, we're done
                break

            # Add messages from the turn to chat history
            for message in turn.messages:
                if message.content is None:
                    raise ValueError("Empty message content isn't allowed")
                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
                messages.append({"role": message.role, "content": message.content})  # type: ignore

            skip_response_format = not turn.final_call
            turn_result = await self._run_model_turn(
                provider,
                messages,
                self.base_adapter_config.top_logprobs if turn.final_call else None,
                skip_response_format,
            )

            usage += turn_result.usage

            prior_output = turn_result.assistant_message
            messages = turn_result.all_messages
            final_choice = turn_result.model_choice

            if not prior_output:
                raise RuntimeError("No assistant message/output returned from model")

        logprobs = self._extract_and_validate_logprobs(final_choice)

        # Save CoT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or OpenRouter, or anyone upstream)
        intermediate_outputs = chat_formatter.intermediate_outputs()
        self._extract_reasoning_to_intermediate_outputs(
            final_choice, intermediate_outputs
        )

        if not isinstance(prior_output, str):
            raise RuntimeError(f"assistant message is not a string: {prior_output}")

        trace = self.all_messages_to_trace(messages)
        output = RunOutput(
            output=prior_output,
            intermediate_outputs=intermediate_outputs,
            output_logprobs=logprobs,
            trace=trace,
        )

        return output, usage

    def _extract_and_validate_logprobs(
        self, final_choice: Choices | None
    ) -> ChoiceLogprobs | None:
        """
        Extract logprobs from the final choice and validate they exist if required.
        """
        logprobs = None
        if (
            final_choice is not None
            and hasattr(final_choice, "logprobs")
            and isinstance(final_choice.logprobs, ChoiceLogprobs)
        ):
            logprobs = final_choice.logprobs

        # Check logprobs worked, if required
        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

        return logprobs

    def _extract_reasoning_to_intermediate_outputs(
        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
    ) -> None:
        """Extract reasoning content from the model choice and add it to intermediate outputs if present."""
        if (
            final_choice is not None
            and hasattr(final_choice, "message")
            and hasattr(final_choice.message, "reasoning_content")
        ):
            reasoning_content = final_choice.message.reasoning_content
            if reasoning_content is not None:
                stripped_reasoning_content = reasoning_content.strip()
                if len(stripped_reasoning_content) > 0:
                    intermediate_outputs["reasoning"] = stripped_reasoning_content

    async def acompletion_checking_response(
        self, **kwargs
    ) -> Tuple[ModelResponse, Choices]:
        response = await litellm.acompletion(**kwargs)
        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )
        return response, response.choices[0]

    def adapter_name(self) -> str:
        return "kiln_openai_compatible_adapter"

    async def response_format_options(self) -> dict[str, Any]:
        # Unstructured if task isn't structured
        if not self.has_structured_output():
            return {}

        structured_output_mode: StructuredOutputMode = (
            self.run_config.structured_output_mode
        )

        match structured_output_mode:
            case StructuredOutputMode.json_mode:
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return self.json_schema_response_format()
            case StructuredOutputMode.function_calling_weak:
                return self.tool_call_params(strict=False)
            case StructuredOutputMode.function_calling:
                return self.tool_call_params(strict=True)
            case StructuredOutputMode.json_instructions:
                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
                return {}
            case StructuredOutputMode.json_custom_instructions:
                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
                return {}
            case StructuredOutputMode.json_instruction_and_object:
                # We set response_format to json_object and also set JSON instructions in the prompt
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.default:
                provider_name = self.run_config.model_provider_name
                if provider_name == ModelProviderName.ollama:
                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                    return self.json_schema_response_format()
                elif provider_name == ModelProviderName.docker_model_runner:
                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
                    return self.json_schema_response_format()
                else:
                    # Default to function calling -- it's older than the other modes. Higher compatibility.
                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
                    strict = provider_name == ModelProviderName.openai
                    return self.tool_call_params(strict=strict)
            case StructuredOutputMode.unknown:
                # See above, but this case should never happen.
                raise ValueError("Structured output mode is unknown.")
            case _:
                raise_exhaustive_enum_error(structured_output_mode)

    def json_schema_response_format(self) -> dict[str, Any]:
        output_schema = self.task.output_schema()
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "task_response",
                    "schema": output_schema,
                },
            }
        }

    def tool_call_params(self, strict: bool) -> dict[str, Any]:
        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
        output_schema = self.task.output_schema()
        if not isinstance(output_schema, dict):
            raise ValueError(
                "Invalid output schema for this task. Can not use tool calls."
            )
        output_schema["additionalProperties"] = False

        function_params = {
            "name": "task_response",
            "parameters": output_schema,
        }
        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
        if strict:
            function_params["strict"] = True

        return {
            "tools": [
                {
                    "type": "function",
                    "function": function_params,
                }
            ],
            "tool_choice": {
                "type": "function",
                "function": {"name": "task_response"},
            },
        }

    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
        # Don't love having this logic here. But it's worth the usability improvement
        # so better to keep it than exclude it. Should figure out how I want to isolate
        # this sort of logic so it's config driven and can be overridden.

        extra_body = {}
        provider_options = {}

        if provider.thinking_level is not None:
            extra_body["reasoning_effort"] = provider.thinking_level

        if provider.require_openrouter_reasoning:
            # https://openrouter.ai/docs/use-cases/reasoning-tokens
            extra_body["reasoning"] = {
                "exclude": False,
            }

        if provider.gemini_reasoning_enabled:
            extra_body["reasoning"] = {
                "enabled": True,
            }

        if provider.name == ModelProviderName.openrouter:
            # Ask OpenRouter to include usage in the response (cost)
            extra_body["usage"] = {"include": True}

            # Set a default provider order for more deterministic routing.
            # OpenRouter will ignore providers that don't support the model.
            # Special cases below (like R1) can override this order.
            # allow_fallbacks is true by default, but we can override it here.
            provider_options["order"] = [
                "fireworks",
                "parasail",
                "together",
                "deepinfra",
                "novita",
                "groq",
                "amazon-bedrock",
                "azure",
                "nebius",
            ]

        if provider.anthropic_extended_thinking:
            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}

        if provider.r1_openrouter_options:
            # Require providers that support the reasoning parameter
            provider_options["require_parameters"] = True
            # Prefer R1 providers with reasonable perf/quants
            provider_options["order"] = ["fireworks", "together"]
            # R1 providers with unreasonable quants
            provider_options["ignore"] = ["deepinfra"]

        # Only needed when this request is fetching logprobs.
        if (
            provider.logprobs_openrouter_options
            and self.base_adapter_config.top_logprobs is not None
        ):
            # Don't let OpenRouter choose a provider that doesn't support logprobs.
            provider_options["require_parameters"] = True
            # DeepInfra silently fails to return logprobs consistently.
            provider_options["ignore"] = ["deepinfra"]

        if provider.openrouter_skip_required_parameters:
            # Oddball case: R1 14/8/1.5B fail with this param, even though they support thinking params.
            provider_options["require_parameters"] = False

        # SiliconFlow uses a bool flag for thinking, for some models
        if provider.siliconflow_enable_thinking is not None:
            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking

        if len(provider_options) > 0:
            extra_body["provider"] = provider_options

        return extra_body

    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and url endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an api url.
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id

    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": self.run_config.temperature,
            "top_p": self.run_config.top_p,
            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
            # Not all models and providers support all openai params (for example, o3 doesn't support top_p).
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and a response format that also uses tools.
            # We could reconsider this. The model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs

    def usage_from_response(self, response: ModelResponse) -> Usage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage

    async def cached_available_tools(self) -> list[KilnToolInterface]:
        if self._cached_available_tools is None:
            self._cached_available_tools = await self.available_tools()
        return self._cached_available_tools

    async def litellm_tools(self) -> list[ToolCallDefinition]:
        available_tools = await self.cached_available_tools()

        # LiteLLM takes the standard OpenAI-compatible tool call format
        return [await tool.toolcall_definition() for tool in available_tools]

    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # Kiln's "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call arguments.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for an impossible combination: a task_response tool call and other tool calls were both provided in the same turn. This is not supported, as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages

    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper.
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message

    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for the trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
```
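To make the structured-output plumbing above easier to follow, here is a small illustrative sketch of the dictionaries that `json_schema_response_format()` and `tool_call_params(strict=True)` build. The toy schema below is invented for the example; in the adapter the schema always comes from `task.output_schema()`.

```python
# Hypothetical output schema, for illustration only.
toy_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# Shape returned by json_schema_response_format() (StructuredOutputMode.json_schema):
json_schema_options = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": toy_schema},
    }
}

# Shape returned by tool_call_params(strict=True) (StructuredOutputMode.function_calling):
# the schema gets additionalProperties=False and the model is forced to call "task_response".
tool_call_options = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {**toy_schema, "additionalProperties": False},
                "strict": True,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}
```

Either dictionary is merged into the `litellm.acompletion` kwargs by `build_completion_kwargs()`, unless `skip_response_format` is set for a non-final turn.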
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Prompt building is handled internally by the adapter, which uses a prompt builder based on the run config. To override the prompt building behavior, pass a custom prompt builder to the adapter config.
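The adapter in this module is normally constructed by Kiln's higher-level APIs, but a rough hand-wired sketch can help orient readers. Treat everything not shown in the source above as an assumption: the keyword arguments accepted by `LiteLlmConfig` and `RunConfigProperties`, the `RunConfigProperties` import path, the required `Task` fields, and the public `invoke()` entry point inherited from the base adapter may differ in your version.

```python
# Illustrative sketch only -- constructor fields, import paths, and invoke() are
# assumptions, not verified against a specific Kiln release.
import asyncio

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.task import RunConfigProperties  # assumed import path


async def main() -> None:
    # A minimal task; required fields may differ by version.
    task = datamodel.Task(
        name="Math question",
        instruction="Answer the user's question concisely.",
    )

    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(  # assumed field names
            model_name="gpt-4o-mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
            structured_output_mode=StructuredOutputMode.default,
        ),
        # Passed straight through to litellm.acompletion (see build_completion_kwargs).
        additional_body_options={"api_key": "sk-..."},
    )

    adapter = LiteLlmAdapter(config=config, kiln_task=task)
    run = await adapter.invoke("What is 2 + 2?")  # assumed public entry point
    print(run.output.output)  # assumed shape of the returned run


if __name__ == "__main__":
    asyncio.run(main())
```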
71 def __init__( 72 self, 73 config: LiteLlmConfig, 74 kiln_task: datamodel.Task, 75 base_adapter_config: AdapterConfig | None = None, 76 ): 77 self.config = config 78 self._additional_body_options = config.additional_body_options 79 self._api_base = config.base_url 80 self._headers = config.default_headers 81 self._litellm_model_id: str | None = None 82 self._cached_available_tools: list[KilnToolInterface] | None = None 83 84 super().__init__( 85 task=kiln_task, 86 run_config=config.run_config_properties, 87 config=base_adapter_config, 88 )
288 async def acompletion_checking_response( 289 self, **kwargs 290 ) -> Tuple[ModelResponse, Choices]: 291 response = await litellm.acompletion(**kwargs) 292 if ( 293 not isinstance(response, ModelResponse) 294 or not response.choices 295 or len(response.choices) == 0 296 or not isinstance(response.choices[0], Choices) 297 ): 298 raise RuntimeError( 299 f"Expected ModelResponse with Choices, got {type(response)}." 300 ) 301 return response, response.choices[0]
306 async def response_format_options(self) -> dict[str, Any]: 307 # Unstructured if task isn't structured 308 if not self.has_structured_output(): 309 return {} 310 311 structured_output_mode: StructuredOutputMode = ( 312 self.run_config.structured_output_mode 313 ) 314 315 match structured_output_mode: 316 case StructuredOutputMode.json_mode: 317 return {"response_format": {"type": "json_object"}} 318 case StructuredOutputMode.json_schema: 319 return self.json_schema_response_format() 320 case StructuredOutputMode.function_calling_weak: 321 return self.tool_call_params(strict=False) 322 case StructuredOutputMode.function_calling: 323 return self.tool_call_params(strict=True) 324 case StructuredOutputMode.json_instructions: 325 # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below). 326 return {} 327 case StructuredOutputMode.json_custom_instructions: 328 # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above). 329 return {} 330 case StructuredOutputMode.json_instruction_and_object: 331 # We set response_format to json_object and also set json instructions in the prompt 332 return {"response_format": {"type": "json_object"}} 333 case StructuredOutputMode.default: 334 provider_name = self.run_config.model_provider_name 335 if provider_name == ModelProviderName.ollama: 336 # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs 337 return self.json_schema_response_format() 338 elif provider_name == ModelProviderName.docker_model_runner: 339 # Docker Model Runner uses OpenAI-compatible API with JSON schema support 340 return self.json_schema_response_format() 341 else: 342 # Default to function calling -- it's older than the other modes. Higher compatibility. 343 # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI. 344 strict = provider_name == ModelProviderName.openai 345 return self.tool_call_params(strict=strict) 346 case StructuredOutputMode.unknown: 347 # See above, but this case should never happen. 348 raise ValueError("Structured output mode is unknown.") 349 case _: 350 raise_exhaustive_enum_error(structured_output_mode)
364 def tool_call_params(self, strict: bool) -> dict[str, Any]: 365 # Add additional_properties: false to the schema (OpenAI requires this for some models) 366 output_schema = self.task.output_schema() 367 if not isinstance(output_schema, dict): 368 raise ValueError( 369 "Invalid output schema for this task. Can not use tool calls." 370 ) 371 output_schema["additionalProperties"] = False 372 373 function_params = { 374 "name": "task_response", 375 "parameters": output_schema, 376 } 377 # This should be on, but we allow setting function_calling_weak for APIs that don't support it. 378 if strict: 379 function_params["strict"] = True 380 381 return { 382 "tools": [ 383 { 384 "type": "function", 385 "function": function_params, 386 } 387 ], 388 "tool_choice": { 389 "type": "function", 390 "function": {"name": "task_response"}, 391 }, 392 }
394 def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]: 395 # Don't love having this logic here. But it's worth the usability improvement 396 # so better to keep it than exclude it. Should figure out how I want to isolate 397 # this sort of logic so it's config driven and can be overridden 398 399 extra_body = {} 400 provider_options = {} 401 402 if provider.thinking_level is not None: 403 extra_body["reasoning_effort"] = provider.thinking_level 404 405 if provider.require_openrouter_reasoning: 406 # https://openrouter.ai/docs/use-cases/reasoning-tokens 407 extra_body["reasoning"] = { 408 "exclude": False, 409 } 410 411 if provider.gemini_reasoning_enabled: 412 extra_body["reasoning"] = { 413 "enabled": True, 414 } 415 416 if provider.name == ModelProviderName.openrouter: 417 # Ask OpenRouter to include usage in the response (cost) 418 extra_body["usage"] = {"include": True} 419 420 # Set a default provider order for more deterministic routing. 421 # OpenRouter will ignore providers that don't support the model. 422 # Special cases below (like R1) can override this order. 423 # allow_fallbacks is true by default, but we can override it here. 424 provider_options["order"] = [ 425 "fireworks", 426 "parasail", 427 "together", 428 "deepinfra", 429 "novita", 430 "groq", 431 "amazon-bedrock", 432 "azure", 433 "nebius", 434 ] 435 436 if provider.anthropic_extended_thinking: 437 extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000} 438 439 if provider.r1_openrouter_options: 440 # Require providers that support the reasoning parameter 441 provider_options["require_parameters"] = True 442 # Prefer R1 providers with reasonable perf/quants 443 provider_options["order"] = ["fireworks", "together"] 444 # R1 providers with unreasonable quants 445 provider_options["ignore"] = ["deepinfra"] 446 447 # Only set of this request is to get logprobs. 448 if ( 449 provider.logprobs_openrouter_options 450 and self.base_adapter_config.top_logprobs is not None 451 ): 452 # Don't let OpenRouter choose a provider that doesn't support logprobs. 453 provider_options["require_parameters"] = True 454 # DeepInfra silently fails to return logprobs consistently. 455 provider_options["ignore"] = ["deepinfra"] 456 457 if provider.openrouter_skip_required_parameters: 458 # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params. 459 provider_options["require_parameters"] = False 460 461 # Siliconflow uses a bool flag for thinking, for some models 462 if provider.siliconflow_enable_thinking is not None: 463 extra_body["enable_thinking"] = provider.siliconflow_enable_thinking 464 465 if len(provider_options) > 0: 466 extra_body["provider"] = provider_options 467 468 return extra_body
    def litellm_model_id(self) -> str:
        # The model ID is an interesting combination of format and URL endpoint.
        # It specifies the provider URL/host, but this is overridden if you manually set an API URL.
        if self._litellm_model_id:
            return self._litellm_model_id

        litellm_provider_info = get_litellm_provider_info(self.model_provider())
        if litellm_provider_info.is_custom and self._api_base is None:
            raise ValueError(
                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
            )

        self._litellm_model_id = litellm_provider_info.litellm_model_id
        return self._litellm_model_id
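For context, LiteLLM model IDs encode the provider as a prefix; the examples below are hypothetical and only illustrate the general shape, since the real values come from get_litellm_provider_info() and the Kiln model registry.

# Hypothetical examples of LiteLLM-style model IDs (not taken from the registry).
example_model_ids = {
    "openrouter": "openrouter/deepseek/deepseek-r1",  # routed through OpenRouter
    "openai_compatible": "openai/my-custom-model",     # generic prefix; requires an explicit api_base
}
# For the OpenAI-compatible case, an explicit base URL must also be configured, e.g.:
example_api_base = "http://localhost:11434/v1"  # hypothetical local endpoint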
    async def build_completion_kwargs(
        self,
        provider: KilnModelProvider,
        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
        extra_body = self.build_extra_body(provider)

        # Merge all parameters into a single kwargs dict for litellm
        completion_kwargs = {
            "model": self.litellm_model_id(),
            "messages": messages,
            "api_base": self._api_base,
            "headers": self._headers,
            "temperature": self.run_config.temperature,
            "top_p": self.run_config.top_p,
            # This drops params that are not supported by the model. Only OpenAI params like top_p and temperature -- not litellm params like model, etc.
            # Not all models and providers support all OpenAI params (for example, o3 doesn't support top_p).
            # Better to ignore them than to fail the model call.
            # https://docs.litellm.ai/docs/completion/input
            "drop_params": True,
            **extra_body,
            **self._additional_body_options,
        }

        tool_calls = await self.litellm_tools()
        has_tools = len(tool_calls) > 0
        if has_tools:
            completion_kwargs["tools"] = tool_calls
            completion_kwargs["tool_choice"] = "auto"

        # Special case for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temperature, not both.
        # Remove default values (1.0), prioritizing anything the user customized, then error with a helpful message if both are custom.
        if provider.temp_top_p_exclusive:
            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
                del completion_kwargs["top_p"]
            if (
                "temperature" in completion_kwargs
                and completion_kwargs["temperature"] == 1.0
            ):
                del completion_kwargs["temperature"]
            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
                raise ValueError(
                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
                )

        if not skip_response_format:
            # Response format: json_schema, json_instructions, json_mode, function_calling, etc.
            response_format_options = await self.response_format_options()

            # Check for a conflict between tools and a response format that also uses tools.
            # We could reconsider this: the model could choose between a final answer or a tool call on any turn.
            # However, models that are good at tools tend to also support json_schema, so do we need to support both?
            # If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
            if has_tools and "tools" in response_format_options:
                raise ValueError(
                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
                )

            completion_kwargs.update(response_format_options)

        if top_logprobs is not None:
            completion_kwargs["logprobs"] = True
            completion_kwargs["top_logprobs"] = top_logprobs

        return completion_kwargs
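A minimal sketch (hypothetical values, not from the source) of the merged kwargs this method assembles and how a caller forwards them to LiteLLM.

import litellm

# Illustrative only: a possible merged kwargs dict, assuming logprobs were requested.
completion_kwargs = {
    "model": "openrouter/deepseek/deepseek-r1",  # hypothetical litellm model ID
    "messages": [{"role": "user", "content": "Hello"}],
    "api_base": None,
    "headers": None,
    "temperature": 0.7,
    "top_p": 1.0,
    "drop_params": True,   # silently drop OpenAI params the provider doesn't support
    "logprobs": True,      # only present when top_logprobs was requested
    "top_logprobs": 5,
}

async def call_model():
    # The adapter passes the merged kwargs straight through to litellm.acompletion.
    return await litellm.acompletion(**completion_kwargs)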
    def usage_from_response(self, response: ModelResponse) -> Usage:
        litellm_usage = response.get("usage", None)

        # LiteLLM isn't consistent in how it returns the cost.
        cost = response._hidden_params.get("response_cost", None)
        if cost is None and litellm_usage:
            cost = litellm_usage.get("cost", None)

        usage = Usage()

        if not litellm_usage and not cost:
            return usage

        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
            usage.output_tokens = litellm_usage.get("completion_tokens", None)
            usage.total_tokens = litellm_usage.get("total_tokens", None)
        else:
            logger.warning(
                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
            )

        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            # None is allowed, but no other types are expected
            logger.warning(
                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
            )

        return usage
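The sketch below (hypothetical numbers) shows the OpenAI-style usage payload LiteLLM typically reports and the Usage fields it maps to; cost may instead come from the response's _hidden_params.

# Illustrative mapping only; values are made up.
litellm_usage_payload = {
    "prompt_tokens": 812,       # -> usage.input_tokens
    "completion_tokens": 164,   # -> usage.output_tokens
    "total_tokens": 976,        # -> usage.total_tokens
}
hidden_params = {"response_cost": 0.00042}  # -> usage.cost (float, hypothetical value)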
    async def process_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall] | None
    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
            return None, []

        assistant_output_from_toolcall: str | None = None
        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
        tool_run_coroutines = []

        for tool_call in tool_calls:
            # The Kiln "task_response" tool is used for returning structured output via tool calls.
            # Load the output from the tool call arguments; no actual tool is run for it.
            if tool_call.function.name == "task_response":
                assistant_output_from_toolcall = tool_call.function.arguments
                continue

            # Process normal tool calls (not the "task_response" tool)
            tool_name = tool_call.function.name
            tool = None
            for tool_option in await self.cached_available_tools():
                if await tool_option.name() == tool_name:
                    tool = tool_option
                    break
            if not tool:
                raise RuntimeError(
                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
                )

            # Parse the arguments and validate them against the tool's schema
            try:
                parsed_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                raise RuntimeError(
                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
                )
            try:
                tool_call_definition = await tool.toolcall_definition()
                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
                validate_schema_with_value_error(parsed_args, json_schema)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                ) from e

            # Create context with the calling task's allow_saving setting
            context = ToolCallContext(
                allow_saving=self.base_adapter_config.allow_saving
            )

            # Bind loop variables as default arguments so each coroutine captures this iteration's values.
            async def run_tool_and_format(
                t=tool, c=context, args=parsed_args, tc_id=tool_call.id
            ):
                result = await t.run(c, **args)
                return ChatCompletionToolMessageParamWrapper(
                    role="tool",
                    tool_call_id=tc_id,
                    content=result.output,
                    kiln_task_tool_data=result.kiln_task_tool_data
                    if isinstance(result, KilnTaskToolResult)
                    else None,
                )

            tool_run_coroutines.append(run_tool_and_format())

        if tool_run_coroutines:
            tool_call_response_messages = await asyncio.gather(*tool_run_coroutines)

        if (
            assistant_output_from_toolcall is not None
            and len(tool_call_response_messages) > 0
        ):
            raise RuntimeError(
                "Model asked for an impossible combination: a task_response tool call and other tool calls in the same turn. This is not supported, as it means the model asked us to both return the task_response result (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
            )

        return assistant_output_from_toolcall, tool_call_response_messages
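To make the two return paths concrete, here is a sketch (hypothetical tool, IDs, and outputs) of what process_tool_calls yields for a single turn.

# Illustrative only: the two possible outcomes for one turn.

# 1) A normal tool call produced one "tool" role message to send back to the model:
normal_result = (
    None,  # no task_response output yet
    [
        {
            "role": "tool",
            "tool_call_id": "call_abc123",       # hypothetical ID from the model
            "content": '{"temperature_c": 21}',  # hypothetical tool output
            "kiln_task_tool_data": None,
        }
    ],
)

# 2) The model called task_response, ending the turn with structured output:
task_response_result = ('{"answer": "42"}', [])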
    def litellm_message_to_trace_message(
        self, raw_message: LiteLLMMessage
    ) -> ChatCompletionAssistantMessageParamWrapper:
        """
        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
        """
        message: ChatCompletionAssistantMessageParamWrapper = {
            "role": "assistant",
        }
        if raw_message.role != "assistant":
            raise ValueError(
                "Model returned a message with a role other than assistant. This is not supported."
            )

        if hasattr(raw_message, "content"):
            message["content"] = raw_message.content
        if hasattr(raw_message, "reasoning_content"):
            message["reasoning_content"] = raw_message.reasoning_content
        if hasattr(raw_message, "tool_calls"):
            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
            for litellm_tool_call in raw_message.tool_calls or []:
                # Optional in the SDK for streaming responses, but should never be None at this point.
                if litellm_tool_call.function.name is None:
                    raise ValueError(
                        "The model requested a tool call, without providing a function name (required)."
                    )
                open_ai_tool_calls.append(
                    ChatCompletionMessageToolCallParam(
                        id=litellm_tool_call.id,
                        type="function",
                        function={
                            "name": litellm_tool_call.function.name,
                            "arguments": litellm_tool_call.function.arguments,
                        },
                    )
                )
            if len(open_ai_tool_calls) > 0:
                message["tool_calls"] = open_ai_tool_calls

        if not message.get("content") and not message.get("tool_calls"):
            raise ValueError(
                "Model returned an assistant message, but no content or tool calls. This is not supported."
            )

        return message
Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
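A sketch (hypothetical IDs and arguments) of the OpenAI-compatible assistant message this conversion produces when the LiteLLM message contained a tool call rather than text content.

# Illustrative only; "get_weather" and the IDs are made up.
converted_message = {
    "role": "assistant",
    "content": None,             # no text content this turn
    "reasoning_content": None,   # present when the SDK exposes it
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris"}',
            },
        }
    ],
}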
    def all_messages_to_trace(
        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
    ) -> list[ChatCompletionMessageParam]:
        """
        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
        """
        trace: list[ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, LiteLLMMessage):
                trace.append(self.litellm_message_to_trace_message(message))
            else:
                trace.append(message)
        return trace
Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
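A before/after sketch (hypothetical content) of the conversion: LiteLLM Message objects in the chat history become plain OpenAI-style assistant dicts in the trace, while dict messages pass through unchanged.

# Illustrative only.
trace_before = [
    {"role": "user", "content": "What's 2 + 2?"},
    # a LiteLLM Message object with role="assistant", content="4" (not a dict)
]
trace_after = [
    {"role": "user", "content": "What's 2 + 2?"},
    {"role": "assistant", "content": "4"},  # converted via litellm_message_to_trace_message
]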
Inherited Members
- kiln_ai.adapters.model_adapters.base_adapter.BaseAdapter
- task
- run_config
- base_adapter_config
- prompt_builder
- output_schema
- input_schema
- model_provider
- invoke
- invoke_returning_run_output
- has_structured_output
- build_prompt
- build_chat_formatter
- generate_run
- update_run_config_unknown_structured_output_mode
- available_tools