kiln_ai.adapters.model_adapters.litellm_adapter

  1import logging
  2from typing import Any, Dict
  3
  4import litellm
  5from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
  6from litellm.types.utils import Usage as LiteLlmUsage
  7
  8import kiln_ai.datamodel as datamodel
  9from kiln_ai.adapters.ml_model_list import (
 10    KilnModelProvider,
 11    ModelProviderName,
 12    StructuredOutputMode,
 13)
 14from kiln_ai.adapters.model_adapters.base_adapter import (
 15    COT_FINAL_ANSWER_PROMPT,
 16    AdapterConfig,
 17    BaseAdapter,
 18    RunOutput,
 19    Usage,
 20)
 21from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 22from kiln_ai.datamodel import PromptGenerators, PromptId
 23from kiln_ai.datamodel.task import RunConfig
 24from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 25
 26logger = logging.getLogger(__name__)
 27
 28
 29class LiteLlmAdapter(BaseAdapter):
 30    def __init__(
 31        self,
 32        config: LiteLlmConfig,
 33        kiln_task: datamodel.Task,
 34        prompt_id: PromptId | None = None,
 35        base_adapter_config: AdapterConfig | None = None,
 36    ):
 37        self.config = config
 38        self._additional_body_options = config.additional_body_options
 39        self._api_base = config.base_url
 40        self._headers = config.default_headers
 41        self._litellm_model_id: str | None = None
 42
 43        run_config = RunConfig(
 44            task=kiln_task,
 45            model_name=config.model_name,
 46            model_provider_name=config.provider_name,
 47            prompt_id=prompt_id or PromptGenerators.SIMPLE,
 48        )
 49
 50        super().__init__(
 51            run_config=run_config,
 52            config=base_adapter_config,
 53        )
 54
 55    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
 56        provider = self.model_provider()
 57        if not provider.model_id:
 58            raise ValueError("Model ID is required for OpenAI compatible models")
 59
 60        intermediate_outputs: dict[str, str] = {}
 61        prompt = self.build_prompt()
 62        user_msg = self.prompt_builder.build_user_message(input)
 63        messages = [
 64            {"role": "system", "content": prompt},
 65            {"role": "user", "content": user_msg},
 66        ]
 67
 68        run_strategy, cot_prompt = self.run_strategy()
 69
 70        if run_strategy == "cot_as_message":
 71            # Used for reasoning-capable models that can output thinking and structured format
 72            if not cot_prompt:
 73                raise ValueError("cot_prompt is required for cot_as_message strategy")
 74            messages.append({"role": "system", "content": cot_prompt})
 75        elif run_strategy == "cot_two_call":
 76            if not cot_prompt:
 77                raise ValueError("cot_prompt is required for cot_two_call strategy")
 78            messages.append({"role": "system", "content": cot_prompt})
 79
 80            # First call for chain of thought
 81            # No response format as this request is for "thinking" in plain text
 82            # No logprobs as only needed for final answer
 83            completion_kwargs = await self.build_completion_kwargs(
 84                provider, messages, None, skip_response_format=True
 85            )
 86            cot_response = await litellm.acompletion(**completion_kwargs)
 87            if (
 88                not isinstance(cot_response, ModelResponse)
 89                or not cot_response.choices
 90                or len(cot_response.choices) == 0
 91                or not isinstance(cot_response.choices[0], Choices)
 92            ):
 93                raise RuntimeError(
 94                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
 95                )
 96            cot_content = cot_response.choices[0].message.content
 97            if cot_content is not None:
 98                intermediate_outputs["chain_of_thought"] = cot_content
 99
100            messages.extend(
101                [
102                    {"role": "assistant", "content": cot_content or ""},
103                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
104                ]
105            )
106
107        # Make the API call using litellm
108        completion_kwargs = await self.build_completion_kwargs(
109            provider, messages, self.base_adapter_config.top_logprobs
110        )
111        response = await litellm.acompletion(**completion_kwargs)
112
113        if not isinstance(response, ModelResponse):
114            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
115
116        # Maybe remove this? There is no error attribute on the response object.
117        # Keeping it in a typesafe way since we added it for a reason, but we should investigate what that was and whether it still applies.
118        if hasattr(response, "error") and response.__getattribute__("error"):
119            raise RuntimeError(
120                f"LLM API returned an error: {response.__getattribute__('error')}"
121            )
122
123        if (
124            not response.choices
125            or len(response.choices) == 0
126            or not isinstance(response.choices[0], Choices)
127        ):
128            raise RuntimeError(
129                "No message content returned in the response from LLM API"
130            )
131
132        message = response.choices[0].message
133        logprobs = (
134            response.choices[0].logprobs
135            if hasattr(response.choices[0], "logprobs")
136            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
137            else None
138        )
139
140        # Check logprobs worked, if requested
141        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
142            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
143
144        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
145        if (
146            hasattr(message, "reasoning_content")
147            and message.reasoning_content
148            and len(message.reasoning_content.strip()) > 0
149        ):
150            intermediate_outputs["reasoning"] = message.reasoning_content.strip()
151
152        # the string content of the response
153        response_content = message.content
154
155        # Fallback: use the arguments of the first task_response tool call, if one exists
156        if (
157            not response_content
158            and hasattr(message, "tool_calls")
159            and message.tool_calls
160        ):
161            tool_call = next(
162                (
163                    tool_call
164                    for tool_call in message.tool_calls
165                    if tool_call.function.name == "task_response"
166                ),
167                None,
168            )
169            if tool_call:
170                response_content = tool_call.function.arguments
171
172        if not isinstance(response_content, str):
173            raise RuntimeError(f"response is not a string: {response_content}")
174
175        return RunOutput(
176            output=response_content,
177            intermediate_outputs=intermediate_outputs,
178            output_logprobs=logprobs,
179        ), self.usage_from_response(response)
180
181    def adapter_name(self) -> str:
182        return "kiln_openai_compatible_adapter"
183
184    async def response_format_options(self) -> dict[str, Any]:
185        # Unstructured if task isn't structured
186        if not self.has_structured_output():
187            return {}
188
189        provider = self.model_provider()
190        match provider.structured_output_mode:
191            case StructuredOutputMode.json_mode:
192                return {"response_format": {"type": "json_object"}}
193            case StructuredOutputMode.json_schema:
194                return self.json_schema_response_format()
195            case StructuredOutputMode.function_calling_weak:
196                return self.tool_call_params(strict=False)
197            case StructuredOutputMode.function_calling:
198                return self.tool_call_params(strict=True)
199            case StructuredOutputMode.json_instructions:
200                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
201                return {}
202            case StructuredOutputMode.json_custom_instructions:
203                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
204                return {}
205            case StructuredOutputMode.json_instruction_and_object:
206                # We set response_format to json_object and also set json instructions in the prompt
207                return {"response_format": {"type": "json_object"}}
208            case StructuredOutputMode.default:
209                if provider.name == ModelProviderName.ollama:
210                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
211                    return self.json_schema_response_format()
212                else:
213                    # Default to function calling -- it's older than the other modes. Higher compatibility.
214                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
215                    strict = provider.name == ModelProviderName.openai
216                    return self.tool_call_params(strict=strict)
217            case _:
218                raise_exhaustive_enum_error(provider.structured_output_mode)
219
220    def json_schema_response_format(self) -> dict[str, Any]:
221        output_schema = self.task().output_schema()
222        return {
223            "response_format": {
224                "type": "json_schema",
225                "json_schema": {
226                    "name": "task_response",
227                    "schema": output_schema,
228                },
229            }
230        }
231
232    def tool_call_params(self, strict: bool) -> dict[str, Any]:
233        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
234        output_schema = self.task().output_schema()
235        if not isinstance(output_schema, dict):
236            raise ValueError(
237                "Invalid output schema for this task. Can not use tool calls."
238            )
239        output_schema["additionalProperties"] = False
240
241        function_params = {
242            "name": "task_response",
243            "parameters": output_schema,
244        }
245        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
246        if strict:
247            function_params["strict"] = True
248
249        return {
250            "tools": [
251                {
252                    "type": "function",
253                    "function": function_params,
254                }
255            ],
256            "tool_choice": {
257                "type": "function",
258                "function": {"name": "task_response"},
259            },
260        }
261
262    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
263        # TODO P1: Don't love having this logic here. But it's a usability improvement
264        # so better to keep it than exclude it. Should figure out how I want to isolate
265        # this sort of logic so it's config driven and can be overridden
266
267        extra_body = {}
268        provider_options = {}
269
270        if provider.thinking_level is not None:
271            extra_body["reasoning_effort"] = provider.thinking_level
272
273        if provider.require_openrouter_reasoning:
274            # https://openrouter.ai/docs/use-cases/reasoning-tokens
275            extra_body["reasoning"] = {
276                "exclude": False,
277            }
278
279        if provider.anthropic_extended_thinking:
280            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
281
282        if provider.r1_openrouter_options:
283            # Require providers that support the reasoning parameter
284            provider_options["require_parameters"] = True
285            # Prefer R1 providers with reasonable perf/quants
286            provider_options["order"] = ["Fireworks", "Together"]
287            # R1 providers with unreasonable quants
288            provider_options["ignore"] = ["DeepInfra"]
289
290        # Only set these options when this request needs logprobs.
291        if (
292            provider.logprobs_openrouter_options
293            and self.base_adapter_config.top_logprobs is not None
294        ):
295            # Don't let OpenRouter choose a provider that doesn't support logprobs.
296            provider_options["require_parameters"] = True
297            # DeepInfra silently fails to return logprobs consistently.
298            provider_options["ignore"] = ["DeepInfra"]
299
300        if provider.openrouter_skip_required_parameters:
301            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
302            provider_options["require_parameters"] = False
303
304        if len(provider_options) > 0:
305            extra_body["provider"] = provider_options
306
307        return extra_body
308
309    def litellm_model_id(self) -> str:
310        # The model ID is an interesting combination of format and URL endpoint.
311        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
312
313        if self._litellm_model_id:
314            return self._litellm_model_id
315
316        provider = self.model_provider()
317        if not provider.model_id:
318            raise ValueError("Model ID is required for OpenAI compatible models")
319
320        litellm_provider_name: str | None = None
321        is_custom = False
322        match provider.name:
323            case ModelProviderName.openrouter:
324                litellm_provider_name = "openrouter"
325            case ModelProviderName.openai:
326                litellm_provider_name = "openai"
327            case ModelProviderName.groq:
328                litellm_provider_name = "groq"
329            case ModelProviderName.anthropic:
330                litellm_provider_name = "anthropic"
331            case ModelProviderName.ollama:
332                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
333                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
334                is_custom = True
335            case ModelProviderName.gemini_api:
336                litellm_provider_name = "gemini"
337            case ModelProviderName.fireworks_ai:
338                litellm_provider_name = "fireworks_ai"
339            case ModelProviderName.amazon_bedrock:
340                litellm_provider_name = "bedrock"
341            case ModelProviderName.azure_openai:
342                litellm_provider_name = "azure"
343            case ModelProviderName.huggingface:
344                litellm_provider_name = "huggingface"
345            case ModelProviderName.vertex:
346                litellm_provider_name = "vertex_ai"
347            case ModelProviderName.together_ai:
348                litellm_provider_name = "together_ai"
349            case ModelProviderName.openai_compatible:
350                is_custom = True
351            case ModelProviderName.kiln_custom_registry:
352                is_custom = True
353            case ModelProviderName.kiln_fine_tune:
354                is_custom = True
355            case _:
356                raise_exhaustive_enum_error(provider.name)
357
358        if is_custom:
359            if self._api_base is None:
360                raise ValueError(
361                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
362                )
363            # Use openai as it's only used for format, not url
364            litellm_provider_name = "openai"
365
366        # Shouldn't be possible, but keep the type checker happy
367        if litellm_provider_name is None:
368            raise ValueError(
369                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
370            )
371
372        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
373        return self._litellm_model_id
374
375    async def build_completion_kwargs(
376        self,
377        provider: KilnModelProvider,
378        messages: list[dict[str, Any]],
379        top_logprobs: int | None,
380        skip_response_format: bool = False,
381    ) -> dict[str, Any]:
382        extra_body = self.build_extra_body(provider)
383
384        # Merge all parameters into a single kwargs dict for litellm
385        completion_kwargs = {
386            "model": self.litellm_model_id(),
387            "messages": messages,
388            "api_base": self._api_base,
389            "headers": self._headers,
390            **extra_body,
391            **self._additional_body_options,
392        }
393
394        if not skip_response_format:
395            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
396            response_format_options = await self.response_format_options()
397            completion_kwargs.update(response_format_options)
398
399        if top_logprobs is not None:
400            completion_kwargs["logprobs"] = True
401            completion_kwargs["top_logprobs"] = top_logprobs
402
403        return completion_kwargs
404
405    def usage_from_response(self, response: ModelResponse) -> Usage | None:
406        litellm_usage = response.get("usage", None)
407        cost = response._hidden_params.get("response_cost", None)
408        if not litellm_usage and not cost:
409            return None
410
411        usage = Usage()
412
413        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
414            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
415            usage.output_tokens = litellm_usage.get("completion_tokens", None)
416            usage.total_tokens = litellm_usage.get("total_tokens", None)
417        else:
418            logger.warning(
419                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
420            )
421
422        if isinstance(cost, float):
423            usage.cost = cost
424        elif cost is not None:
425            # None is allowed, but no other types are expected
426            logger.warning(
427                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
428            )
429
430        return usage

class LiteLlmAdapter(BaseAdapter):

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs

LiteLlmAdapter(
    config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig,
    kiln_task: kiln_ai.datamodel.Task,
    prompt_id: PromptId | None = None,
    base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
31    def __init__(
32        self,
33        config: LiteLlmConfig,
34        kiln_task: datamodel.Task,
35        prompt_id: PromptId | None = None,
36        base_adapter_config: AdapterConfig | None = None,
37    ):
38        self.config = config
39        self._additional_body_options = config.additional_body_options
40        self._api_base = config.base_url
41        self._headers = config.default_headers
42        self._litellm_model_id: str | None = None
43
44        run_config = RunConfig(
45            task=kiln_task,
46            model_name=config.model_name,
47            model_provider_name=config.provider_name,
48            prompt_id=prompt_id or PromptGenerators.SIMPLE,
49        )
50
51        super().__init__(
52            run_config=run_config,
53            config=base_adapter_config,
54        )
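
A minimal construction sketch. It assumes LiteLlmConfig accepts the fields read in __init__ above (model_name, provider_name, base_url, default_headers, additional_body_options) as keyword arguments, and that my_task is an existing kiln_ai.datamodel.Task; the model name and API key are placeholders.

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

config = LiteLlmConfig(
    model_name="gpt_4o_mini",                       # Kiln model name (placeholder)
    provider_name=ModelProviderName.openai,         # provider for this run
    base_url=None,                                  # only needed for custom / OpenAI-compatible endpoints
    default_headers=None,
    additional_body_options={"api_key": "sk-..."},  # passed straight through to litellm (assumed field)
)

# prompt_id defaults to PromptGenerators.SIMPLE when omitted
adapter = LiteLlmAdapter(config=config, kiln_task=my_task)
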
config
def adapter_name(self) -> str:
182    def adapter_name(self) -> str:
183        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
185    async def response_format_options(self) -> dict[str, Any]:
186        # Unstructured if task isn't structured
187        if not self.has_structured_output():
188            return {}
189
190        provider = self.model_provider()
191        match provider.structured_output_mode:
192            case StructuredOutputMode.json_mode:
193                return {"response_format": {"type": "json_object"}}
194            case StructuredOutputMode.json_schema:
195                return self.json_schema_response_format()
196            case StructuredOutputMode.function_calling_weak:
197                return self.tool_call_params(strict=False)
198            case StructuredOutputMode.function_calling:
199                return self.tool_call_params(strict=True)
200            case StructuredOutputMode.json_instructions:
201                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
202                return {}
203            case StructuredOutputMode.json_custom_instructions:
204                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
205                return {}
206            case StructuredOutputMode.json_instruction_and_object:
207                # We set response_format to json_object and also set json instructions in the prompt
208                return {"response_format": {"type": "json_object"}}
209            case StructuredOutputMode.default:
210                if provider.name == ModelProviderName.ollama:
211                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
212                    return self.json_schema_response_format()
213                else:
214                    # Default to function calling -- it's older than the other modes. Higher compatibility.
215                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
216                    strict = provider.name == ModelProviderName.openai
217                    return self.tool_call_params(strict=strict)
218            case _:
219                raise_exhaustive_enum_error(provider.structured_output_mode)
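
For reference, the branches above produce dictionaries of these shapes; json_schema and the function_calling modes delegate to the helpers documented below, and whatever is returned is merged into the litellm kwargs by build_completion_kwargs().

# Returned for StructuredOutputMode.json_mode and json_instruction_and_object:
json_object_options = {"response_format": {"type": "json_object"}}

# json_instructions / json_custom_instructions return {} -- the JSON requirement
# lives in the prompt, not in the API request.
no_api_options: dict = {}

# json_schema -> json_schema_response_format(); function_calling[_weak] ->
# tool_call_params(strict=True / strict=False).
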
def json_schema_response_format(self) -> dict[str, typing.Any]:
221    def json_schema_response_format(self) -> dict[str, Any]:
222        output_schema = self.task().output_schema()
223        return {
224            "response_format": {
225                "type": "json_schema",
226                "json_schema": {
227                    "name": "task_response",
228                    "schema": output_schema,
229                },
230            }
231        }
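
With an illustrative output schema (the real one comes from self.task().output_schema()), the returned payload looks like this:

toy_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

json_schema_options = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "task_response",
            "schema": toy_schema,
        },
    }
}
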
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
233    def tool_call_params(self, strict: bool) -> dict[str, Any]:
234        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
235        output_schema = self.task().output_schema()
236        if not isinstance(output_schema, dict):
237            raise ValueError(
238                "Invalid output schema for this task. Can not use tool calls."
239            )
240        output_schema["additionalProperties"] = False
241
242        function_params = {
243            "name": "task_response",
244            "parameters": output_schema,
245        }
246        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
247        if strict:
248            function_params["strict"] = True
249
250        return {
251            "tools": [
252                {
253                    "type": "function",
254                    "function": function_params,
255                }
256            ],
257            "tool_choice": {
258                "type": "function",
259                "function": {"name": "task_response"},
260            },
261        }
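
Using the same illustrative schema, tool_call_params(strict=True) produces a payload of this shape; the "strict" key is omitted when strict=False.

tool_options = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                    "additionalProperties": False,  # added by tool_call_params
                },
                "strict": True,
            },
        }
    ],
    "tool_choice": {
        "type": "function",
        "function": {"name": "task_response"},
    },
}
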
def build_extra_body(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
263    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
264        # TODO P1: Don't love having this logic here. But it's a usability improvement
265        # so better to keep it than exclude it. Should figure out how I want to isolate
266        # this sort of logic so it's config driven and can be overridden
267
268        extra_body = {}
269        provider_options = {}
270
271        if provider.thinking_level is not None:
272            extra_body["reasoning_effort"] = provider.thinking_level
273
274        if provider.require_openrouter_reasoning:
275            # https://openrouter.ai/docs/use-cases/reasoning-tokens
276            extra_body["reasoning"] = {
277                "exclude": False,
278            }
279
280        if provider.anthropic_extended_thinking:
281            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
282
283        if provider.r1_openrouter_options:
284            # Require providers that support the reasoning parameter
285            provider_options["require_parameters"] = True
286            # Prefer R1 providers with reasonable perf/quants
287            provider_options["order"] = ["Fireworks", "Together"]
288            # R1 providers with unreasonable quants
289            provider_options["ignore"] = ["DeepInfra"]
290
291        # Only set these options when this request needs logprobs.
292        if (
293            provider.logprobs_openrouter_options
294            and self.base_adapter_config.top_logprobs is not None
295        ):
296            # Don't let OpenRouter choose a provider that doesn't support logprobs.
297            provider_options["require_parameters"] = True
298            # DeepInfra silently fails to return logprobs consistently.
299            provider_options["ignore"] = ["DeepInfra"]
300
301        if provider.openrouter_skip_required_parameters:
302            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
303            provider_options["require_parameters"] = False
304
305        if len(provider_options) > 0:
306            extra_body["provider"] = provider_options
307
308        return extra_body
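
As an illustration, for a hypothetical OpenRouter-hosted R1 model with require_openrouter_reasoning and r1_openrouter_options both enabled, the method would return:

extra_body = {
    "reasoning": {"exclude": False},    # ask OpenRouter to return reasoning tokens
    "provider": {
        "require_parameters": True,     # only route to providers supporting the params
        "order": ["Fireworks", "Together"],
        "ignore": ["DeepInfra"],
    },
}
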
def litellm_model_id(self) -> str:
310    def litellm_model_id(self) -> str:
311        # The model ID is an interesting combination of format and URL endpoint.
312        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
313
314        if self._litellm_model_id:
315            return self._litellm_model_id
316
317        provider = self.model_provider()
318        if not provider.model_id:
319            raise ValueError("Model ID is required for OpenAI compatible models")
320
321        litellm_provider_name: str | None = None
322        is_custom = False
323        match provider.name:
324            case ModelProviderName.openrouter:
325                litellm_provider_name = "openrouter"
326            case ModelProviderName.openai:
327                litellm_provider_name = "openai"
328            case ModelProviderName.groq:
329                litellm_provider_name = "groq"
330            case ModelProviderName.anthropic:
331                litellm_provider_name = "anthropic"
332            case ModelProviderName.ollama:
333                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
334                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
335                is_custom = True
336            case ModelProviderName.gemini_api:
337                litellm_provider_name = "gemini"
338            case ModelProviderName.fireworks_ai:
339                litellm_provider_name = "fireworks_ai"
340            case ModelProviderName.amazon_bedrock:
341                litellm_provider_name = "bedrock"
342            case ModelProviderName.azure_openai:
343                litellm_provider_name = "azure"
344            case ModelProviderName.huggingface:
345                litellm_provider_name = "huggingface"
346            case ModelProviderName.vertex:
347                litellm_provider_name = "vertex_ai"
348            case ModelProviderName.together_ai:
349                litellm_provider_name = "together_ai"
350            case ModelProviderName.openai_compatible:
351                is_custom = True
352            case ModelProviderName.kiln_custom_registry:
353                is_custom = True
354            case ModelProviderName.kiln_fine_tune:
355                is_custom = True
356            case _:
357                raise_exhaustive_enum_error(provider.name)
358
359        if is_custom:
360            if self._api_base is None:
361                raise ValueError(
362                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
363                )
364            # Use openai as it's only used for format, not url
365            litellm_provider_name = "openai"
366
367        # Shouldn't be possible, but keep the type checker happy
368        if litellm_provider_name is None:
369            raise ValueError(
370                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
371            )
372
373        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
374        return self._litellm_model_id
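
The returned ID is always "<litellm_provider>/<provider.model_id>". A few illustrative pairs follow; the model_id values are placeholders, not a verified list.

examples = {
    # (provider.name, provider.model_id)      -> litellm model id
    ("openai", "gpt-4o-mini"): "openai/gpt-4o-mini",
    ("openrouter", "deepseek/deepseek-r1"): "openrouter/deepseek/deepseek-r1",
    # Ollama and other OpenAI-compatible endpoints use the "openai" prefix for
    # request format only; the actual host comes from the configured base_url.
    ("ollama", "llama3.1:8b"): "openai/llama3.1:8b",
}
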
async def build_completion_kwargs(self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[dict[str, typing.Any]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
376    async def build_completion_kwargs(
377        self,
378        provider: KilnModelProvider,
379        messages: list[dict[str, Any]],
380        top_logprobs: int | None,
381        skip_response_format: bool = False,
382    ) -> dict[str, Any]:
383        extra_body = self.build_extra_body(provider)
384
385        # Merge all parameters into a single kwargs dict for litellm
386        completion_kwargs = {
387            "model": self.litellm_model_id(),
388            "messages": messages,
389            "api_base": self._api_base,
390            "headers": self._headers,
391            **extra_body,
392            **self._additional_body_options,
393        }
394
395        if not skip_response_format:
396            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
397            response_format_options = await self.response_format_options()
398            completion_kwargs.update(response_format_options)
399
400        if top_logprobs is not None:
401            completion_kwargs["logprobs"] = True
402            completion_kwargs["top_logprobs"] = top_logprobs
403
404        return completion_kwargs
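
A minimal sketch of how these kwargs feed litellm, mirroring _run above; adapter and provider are assumed to be an already-constructed LiteLlmAdapter and its KilnModelProvider, and messages is a plain chat message list.

import litellm

async def complete(adapter, provider, messages: list[dict]):
    kwargs = await adapter.build_completion_kwargs(provider, messages, top_logprobs=None)
    # kwargs always includes model, messages, api_base and headers, plus any
    # extra_body / additional_body_options and (unless skipped) response format options.
    return await litellm.acompletion(**kwargs)
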
def usage_from_response(self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage | None:
406    def usage_from_response(self, response: ModelResponse) -> Usage | None:
407        litellm_usage = response.get("usage", None)
408        cost = response._hidden_params.get("response_cost", None)
409        if not litellm_usage and not cost:
410            return None
411
412        usage = Usage()
413
414        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
415            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
416            usage.output_tokens = litellm_usage.get("completion_tokens", None)
417            usage.total_tokens = litellm_usage.get("total_tokens", None)
418        else:
419            logger.warning(
420                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
421            )
422
423        if isinstance(cost, float):
424            usage.cost = cost
425        elif cost is not None:
426            # None is allowed, but no other types are expected
427            logger.warning(
428                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
429            )
430
431        return usage
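
A small sketch of consuming the result, assuming adapter and response come from a prior litellm.acompletion call as in _run(); the returned Usage may be None, and any of its fields may individually be None depending on what litellm reported.

usage = adapter.usage_from_response(response)
if usage is not None:
    print("input tokens:", usage.input_tokens)
    print("output tokens:", usage.output_tokens)
    print("total tokens:", usage.total_tokens)
    print("cost:", usage.cost)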