kiln_ai.adapters.model_adapters.litellm_adapter

  1import logging
  2from typing import Any, Dict
  3
  4import litellm
  5from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
  6from litellm.types.utils import Usage as LiteLlmUsage
  7
  8import kiln_ai.datamodel as datamodel
  9from kiln_ai.adapters.ml_model_list import (
 10    KilnModelProvider,
 11    ModelProviderName,
 12    StructuredOutputMode,
 13)
 14from kiln_ai.adapters.model_adapters.base_adapter import (
 15    AdapterConfig,
 16    BaseAdapter,
 17    RunOutput,
 18    Usage,
 19)
 20from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 21from kiln_ai.datamodel.task import run_config_from_run_config_properties
 22from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 23
 24logger = logging.getLogger(__name__)
 25
 26
 27class LiteLlmAdapter(BaseAdapter):
 28    def __init__(
 29        self,
 30        config: LiteLlmConfig,
 31        kiln_task: datamodel.Task,
 32        base_adapter_config: AdapterConfig | None = None,
 33    ):
 34        self.config = config
 35        self._additional_body_options = config.additional_body_options
 36        self._api_base = config.base_url
 37        self._headers = config.default_headers
 38        self._litellm_model_id: str | None = None
 39
 40        # Create a RunConfig, adding the task to the RunConfigProperties
 41        run_config = run_config_from_run_config_properties(
 42            task=kiln_task,
 43            run_config_properties=config.run_config_properties,
 44        )
 45
 46        super().__init__(
 47            run_config=run_config,
 48            config=base_adapter_config,
 49        )
 50
 51    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
 52        provider = self.model_provider()
 53        if not provider.model_id:
 54            raise ValueError("Model ID is required for OpenAI compatible models")
 55
 56        chat_formatter = self.build_chat_formatter(input)
 57
 58        prior_output = None
 59        prior_message = None
 60        response = None
 61        turns = 0
 62        while True:
 63            turns += 1
 64            if turns > 10:
 65                raise RuntimeError(
 66                    "Too many turns. Stopping iteration to avoid using too many tokens."
 67                )
 68
 69            turn = chat_formatter.next_turn(prior_output)
 70            if turn is None:
 71                break
 72
 73            skip_response_format = not turn.final_call
 74            all_messages = chat_formatter.message_dicts()
 75            completion_kwargs = await self.build_completion_kwargs(
 76                provider,
 77                all_messages,
 78                self.base_adapter_config.top_logprobs if turn.final_call else None,
 79                skip_response_format,
 80            )
 81            response = await litellm.acompletion(**completion_kwargs)
 82            if (
 83                not isinstance(response, ModelResponse)
 84                or not response.choices
 85                or len(response.choices) == 0
 86                or not isinstance(response.choices[0], Choices)
 87            ):
 88                raise RuntimeError(
 89                    f"Expected ModelResponse with Choices, got {type(response)}."
 90                )
 91            prior_message = response.choices[0].message
 92            prior_output = prior_message.content
 93
 94            # Fallback: use the arguments of the first tool call to task_response, if present
 95            if (
 96                not prior_output
 97                and hasattr(prior_message, "tool_calls")
 98                and prior_message.tool_calls
 99            ):
100                tool_call = next(
101                    (
102                        tool_call
103                        for tool_call in prior_message.tool_calls
104                        if tool_call.function.name == "task_response"
105                    ),
106                    None,
107                )
108                if tool_call:
109                    prior_output = tool_call.function.arguments
110
111            if not prior_output:
112                raise RuntimeError("No output returned from model")
113
114        if response is None or prior_message is None:
115            raise RuntimeError("No response returned from model")
116
117        intermediate_outputs = chat_formatter.intermediate_outputs()
118
119        logprobs = (
120            response.choices[0].logprobs
121            if hasattr(response.choices[0], "logprobs")
122            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
123            else None
124        )
125
126        # Check logprobs worked, if requested
127        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
128            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
129
130        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
131        if (
132            prior_message is not None
133            and hasattr(prior_message, "reasoning_content")
134            and prior_message.reasoning_content
135            and len(prior_message.reasoning_content.strip()) > 0
136        ):
137            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()
138
139        # the string content of the response
140        response_content = prior_output
141
142        if not isinstance(response_content, str):
143            raise RuntimeError(f"response is not a string: {response_content}")
144
145        return RunOutput(
146            output=response_content,
147            intermediate_outputs=intermediate_outputs,
148            output_logprobs=logprobs,
149        ), self.usage_from_response(response)
150
151    def adapter_name(self) -> str:
152        return "kiln_openai_compatible_adapter"
153
154    async def response_format_options(self) -> dict[str, Any]:
155        # Unstructured if task isn't structured
156        if not self.has_structured_output():
157            return {}
158
159        structured_output_mode = self.run_config.structured_output_mode
160
161        match structured_output_mode:
162            case StructuredOutputMode.json_mode:
163                return {"response_format": {"type": "json_object"}}
164            case StructuredOutputMode.json_schema:
165                return self.json_schema_response_format()
166            case StructuredOutputMode.function_calling_weak:
167                return self.tool_call_params(strict=False)
168            case StructuredOutputMode.function_calling:
169                return self.tool_call_params(strict=True)
170            case StructuredOutputMode.json_instructions:
171                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
172                return {}
173            case StructuredOutputMode.json_custom_instructions:
174                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
175                return {}
176            case StructuredOutputMode.json_instruction_and_object:
177                # We set response_format to json_object and also set json instructions in the prompt
178                return {"response_format": {"type": "json_object"}}
179            case StructuredOutputMode.default:
180                provider_name = self.run_config.model_provider_name
181                if provider_name == ModelProviderName.ollama:
182                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
183                    return self.json_schema_response_format()
184                else:
185                    # Default to function calling -- it's older than the other modes. Higher compatibility.
186                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
187                    strict = provider_name == ModelProviderName.openai
188                    return self.tool_call_params(strict=strict)
189            case StructuredOutputMode.unknown:
190                # See above, but this case should never happen.
191                raise ValueError("Structured output mode is unknown.")
192            case _:
193                raise_exhaustive_enum_error(structured_output_mode)
194
195    def json_schema_response_format(self) -> dict[str, Any]:
196        output_schema = self.task().output_schema()
197        return {
198            "response_format": {
199                "type": "json_schema",
200                "json_schema": {
201                    "name": "task_response",
202                    "schema": output_schema,
203                },
204            }
205        }
206
207    def tool_call_params(self, strict: bool) -> dict[str, Any]:
208        # Add additional_properties: false to the schema (OpenAI requires this for some models)
209        output_schema = self.task().output_schema()
210        if not isinstance(output_schema, dict):
211            raise ValueError(
212                "Invalid output schema for this task. Can not use tool calls."
213            )
214        output_schema["additionalProperties"] = False
215
216        function_params = {
217            "name": "task_response",
218            "parameters": output_schema,
219        }
220        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
221        if strict:
222            function_params["strict"] = True
223
224        return {
225            "tools": [
226                {
227                    "type": "function",
228                    "function": function_params,
229                }
230            ],
231            "tool_choice": {
232                "type": "function",
233                "function": {"name": "task_response"},
234            },
235        }
236
237    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
238        # TODO P1: Don't love having this logic here. But it's a usability improvement
239        # so better to keep it than exclude it. Should figure out how I want to isolate
240        # this sort of logic so it's config driven and can be overridden
241
242        extra_body = {}
243        provider_options = {}
244
245        if provider.thinking_level is not None:
246            extra_body["reasoning_effort"] = provider.thinking_level
247
248        if provider.require_openrouter_reasoning:
249            # https://openrouter.ai/docs/use-cases/reasoning-tokens
250            extra_body["reasoning"] = {
251                "exclude": False,
252            }
253
254        if provider.anthropic_extended_thinking:
255            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
256
257        if provider.r1_openrouter_options:
258            # Require providers that support the reasoning parameter
259            provider_options["require_parameters"] = True
260            # Prefer R1 providers with reasonable perf/quants
261            provider_options["order"] = ["Fireworks", "Together"]
262            # R1 providers with unreasonable quants
263            provider_options["ignore"] = ["DeepInfra"]
264
 265        # Only set these options when the purpose of this request is to get logprobs.
266        if (
267            provider.logprobs_openrouter_options
268            and self.base_adapter_config.top_logprobs is not None
269        ):
270            # Don't let OpenRouter choose a provider that doesn't support logprobs.
271            provider_options["require_parameters"] = True
272            # DeepInfra silently fails to return logprobs consistently.
273            provider_options["ignore"] = ["DeepInfra"]
274
275        if provider.openrouter_skip_required_parameters:
276            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
277            provider_options["require_parameters"] = False
278
279        if len(provider_options) > 0:
280            extra_body["provider"] = provider_options
281
282        return extra_body
283
284    def litellm_model_id(self) -> str:
 285        # The model ID combines the request format and the URL endpoint.
 286        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
287
288        if self._litellm_model_id:
289            return self._litellm_model_id
290
291        provider = self.model_provider()
292        if not provider.model_id:
293            raise ValueError("Model ID is required for OpenAI compatible models")
294
295        litellm_provider_name: str | None = None
296        is_custom = False
297        match provider.name:
298            case ModelProviderName.openrouter:
299                litellm_provider_name = "openrouter"
300            case ModelProviderName.openai:
301                litellm_provider_name = "openai"
302            case ModelProviderName.groq:
303                litellm_provider_name = "groq"
304            case ModelProviderName.anthropic:
305                litellm_provider_name = "anthropic"
306            case ModelProviderName.ollama:
307                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
308                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
309                is_custom = True
310            case ModelProviderName.gemini_api:
311                litellm_provider_name = "gemini"
312            case ModelProviderName.fireworks_ai:
313                litellm_provider_name = "fireworks_ai"
314            case ModelProviderName.amazon_bedrock:
315                litellm_provider_name = "bedrock"
316            case ModelProviderName.azure_openai:
317                litellm_provider_name = "azure"
318            case ModelProviderName.huggingface:
319                litellm_provider_name = "huggingface"
320            case ModelProviderName.vertex:
321                litellm_provider_name = "vertex_ai"
322            case ModelProviderName.together_ai:
323                litellm_provider_name = "together_ai"
324            case ModelProviderName.openai_compatible:
325                is_custom = True
326            case ModelProviderName.kiln_custom_registry:
327                is_custom = True
328            case ModelProviderName.kiln_fine_tune:
329                is_custom = True
330            case _:
331                raise_exhaustive_enum_error(provider.name)
332
333        if is_custom:
334            if self._api_base is None:
335                raise ValueError(
336                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
337                )
338            # Use openai as it's only used for format, not url
339            litellm_provider_name = "openai"
340
 341        # Shouldn't be possible, but keep the type checker happy
342        if litellm_provider_name is None:
343            raise ValueError(
 344                f"Could not map provider name to a valid litellm provider ID for model {provider.model_id}"
345            )
346
347        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
348        return self._litellm_model_id
349
350    async def build_completion_kwargs(
351        self,
352        provider: KilnModelProvider,
353        messages: list[dict[str, Any]],
354        top_logprobs: int | None,
355        skip_response_format: bool = False,
356    ) -> dict[str, Any]:
357        extra_body = self.build_extra_body(provider)
358
359        # Merge all parameters into a single kwargs dict for litellm
360        completion_kwargs = {
361            "model": self.litellm_model_id(),
362            "messages": messages,
363            "api_base": self._api_base,
364            "headers": self._headers,
365            "temperature": self.run_config.temperature,
366            "top_p": self.run_config.top_p,
367            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
368            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
369            # Better to ignore them than to fail the model call.
370            # https://docs.litellm.ai/docs/completion/input
371            "drop_params": True,
372            **extra_body,
373            **self._additional_body_options,
374        }
375
376        if not skip_response_format:
377            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
378            response_format_options = await self.response_format_options()
379            completion_kwargs.update(response_format_options)
380
381        if top_logprobs is not None:
382            completion_kwargs["logprobs"] = True
383            completion_kwargs["top_logprobs"] = top_logprobs
384
385        return completion_kwargs
386
387    def usage_from_response(self, response: ModelResponse) -> Usage | None:
388        litellm_usage = response.get("usage", None)
389        cost = response._hidden_params.get("response_cost", None)
390        if not litellm_usage and not cost:
391            return None
392
393        usage = Usage()
394
395        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
396            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
397            usage.output_tokens = litellm_usage.get("completion_tokens", None)
398            usage.total_tokens = litellm_usage.get("total_tokens", None)
399        else:
400            logger.warning(
401                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
402            )
403
404        if isinstance(cost, float):
405            usage.cost = cost
406        elif cost is not None:
407            # None is allowed, but no other types are expected
408            logger.warning(
409                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
410            )
411
412        return usage

Base class for AI model adapters that handle task execution.

This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.

Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs

LiteLlmAdapter( config: kiln_ai.adapters.model_adapters.litellm_config.LiteLlmConfig, kiln_task: kiln_ai.datamodel.Task, base_adapter_config: kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig | None = None)
29    def __init__(
30        self,
31        config: LiteLlmConfig,
32        kiln_task: datamodel.Task,
33        base_adapter_config: AdapterConfig | None = None,
34    ):
35        self.config = config
36        self._additional_body_options = config.additional_body_options
37        self._api_base = config.base_url
38        self._headers = config.default_headers
39        self._litellm_model_id: str | None = None
40
41        # Create a RunConfig, adding the task to the RunConfigProperties
42        run_config = run_config_from_run_config_properties(
43            task=kiln_task,
44            run_config_properties=config.run_config_properties,
45        )
46
47        super().__init__(
48            run_config=run_config,
49            config=base_adapter_config,
50        )
config
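
A minimal construction-and-run sketch follows. It assumes LiteLlmConfig and RunConfigProperties can be built with the keyword fields referenced in the source above (run_config_properties, base_url, additional_body_options); check kiln_ai.datamodel and litellm_config for the exact schemas. The private _run coroutine documented on this page is called directly only to keep the sketch self-contained; real callers would normally go through BaseAdapter's public entry point.

    import asyncio

    import kiln_ai.datamodel as datamodel
    from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
    from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

    async def run_once(task: datamodel.Task, run_config_properties) -> None:
        # run_config_properties: a RunConfigProperties naming the model, provider,
        # prompt and structured output mode for this run.
        config = LiteLlmConfig(
            run_config_properties=run_config_properties,
            # base_url is required for Ollama and other OpenAI-compatible providers.
            base_url="http://localhost:11434/v1",
            # Passed straight through to litellm.acompletion (e.g. an API key).
            additional_body_options={"api_key": "..."},
        )
        adapter = LiteLlmAdapter(config=config, kiln_task=task)
        run_output, usage = await adapter._run("Summarize this text: ...")
        print(run_output.output, usage)

    # asyncio.run(run_once(my_task, my_run_config_properties))
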
def adapter_name(self) -> str:
152    def adapter_name(self) -> str:
153        return "kiln_openai_compatible_adapter"
async def response_format_options(self) -> dict[str, typing.Any]:
155    async def response_format_options(self) -> dict[str, Any]:
156        # Unstructured if task isn't structured
157        if not self.has_structured_output():
158            return {}
159
160        structured_output_mode = self.run_config.structured_output_mode
161
162        match structured_output_mode:
163            case StructuredOutputMode.json_mode:
164                return {"response_format": {"type": "json_object"}}
165            case StructuredOutputMode.json_schema:
166                return self.json_schema_response_format()
167            case StructuredOutputMode.function_calling_weak:
168                return self.tool_call_params(strict=False)
169            case StructuredOutputMode.function_calling:
170                return self.tool_call_params(strict=True)
171            case StructuredOutputMode.json_instructions:
172                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
173                return {}
174            case StructuredOutputMode.json_custom_instructions:
175                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
176                return {}
177            case StructuredOutputMode.json_instruction_and_object:
178                # We set response_format to json_object and also set json instructions in the prompt
179                return {"response_format": {"type": "json_object"}}
180            case StructuredOutputMode.default:
181                provider_name = self.run_config.model_provider_name
182                if provider_name == ModelProviderName.ollama:
183                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
184                    return self.json_schema_response_format()
185                else:
186                    # Default to function calling -- it's older than the other modes. Higher compatibility.
187                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
188                    strict = provider_name == ModelProviderName.openai
189                    return self.tool_call_params(strict=strict)
190            case StructuredOutputMode.unknown:
191                # See above, but this case should never happen.
192                raise ValueError("Structured output mode is unknown.")
193            case _:
194                raise_exhaustive_enum_error(structured_output_mode)
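
For quick orientation, here is the mapping implemented by the match statement above, written out as comments; the schema-based entries delegate to json_schema_response_format() and tool_call_params(), documented below. Illustrative summary only, not an attribute of the adapter.

    # json_mode, json_instruction_and_object -> {"response_format": {"type": "json_object"}}
    # json_instructions, json_custom_instructions -> {}  (JSON format is requested in the prompt instead)
    # json_schema -> json_schema_response_format()
    # function_calling -> tool_call_params(strict=True)
    # function_calling_weak -> tool_call_params(strict=False)
    # default -> json_schema_response_format() on Ollama, otherwise tool calls (strict only for OpenAI)
    # unknown -> raises ValueError
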
def json_schema_response_format(self) -> dict[str, typing.Any]:
196    def json_schema_response_format(self) -> dict[str, Any]:
197        output_schema = self.task().output_schema()
198        return {
199            "response_format": {
200                "type": "json_schema",
201                "json_schema": {
202                    "name": "task_response",
203                    "schema": output_schema,
204                },
205            }
206        }
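
For example, with a hypothetical task output schema of a single required string field, the return value looks like this (the schema is made up for illustration):

    # Hypothetical output schema:
    # {"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}
    {
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "task_response",
                "schema": {
                    "type": "object",
                    "properties": {"answer": {"type": "string"}},
                    "required": ["answer"],
                },
            },
        }
    }
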
def tool_call_params(self, strict: bool) -> dict[str, typing.Any]:
208    def tool_call_params(self, strict: bool) -> dict[str, Any]:
209        # Add additional_properties: false to the schema (OpenAI requires this for some models)
210        output_schema = self.task().output_schema()
211        if not isinstance(output_schema, dict):
212            raise ValueError(
213                "Invalid output schema for this task. Can not use tool calls."
214            )
215        output_schema["additionalProperties"] = False
216
217        function_params = {
218            "name": "task_response",
219            "parameters": output_schema,
220        }
221        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
222        if strict:
223            function_params["strict"] = True
224
225        return {
226            "tools": [
227                {
228                    "type": "function",
229                    "function": function_params,
230                }
231            ],
232            "tool_choice": {
233                "type": "function",
234                "function": {"name": "task_response"},
235            },
236        }
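
With the same hypothetical schema, tool_call_params(strict=True) yields the payload below. Note that additionalProperties is forced to False and the strict flag is only added when requested:

    {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "task_response",
                    "parameters": {
                        "type": "object",
                        "properties": {"answer": {"type": "string"}},
                        "required": ["answer"],
                        "additionalProperties": False,
                    },
                    "strict": True,
                },
            }
        ],
        "tool_choice": {
            "type": "function",
            "function": {"name": "task_response"},
        },
    }
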
def build_extra_body( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider) -> dict[str, typing.Any]:
238    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
239        # TODO P1: Don't love having this logic here. But it's a usability improvement
240        # so better to keep it than exclude it. Should figure out how I want to isolate
241        # this sort of logic so it's config driven and can be overridden
242
243        extra_body = {}
244        provider_options = {}
245
246        if provider.thinking_level is not None:
247            extra_body["reasoning_effort"] = provider.thinking_level
248
249        if provider.require_openrouter_reasoning:
250            # https://openrouter.ai/docs/use-cases/reasoning-tokens
251            extra_body["reasoning"] = {
252                "exclude": False,
253            }
254
255        if provider.anthropic_extended_thinking:
256            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
257
258        if provider.r1_openrouter_options:
259            # Require providers that support the reasoning parameter
260            provider_options["require_parameters"] = True
261            # Prefer R1 providers with reasonable perf/quants
262            provider_options["order"] = ["Fireworks", "Together"]
263            # R1 providers with unreasonable quants
264            provider_options["ignore"] = ["DeepInfra"]
265
266        # Only set these options when the purpose of this request is to get logprobs.
267        if (
268            provider.logprobs_openrouter_options
269            and self.base_adapter_config.top_logprobs is not None
270        ):
271            # Don't let OpenRouter choose a provider that doesn't support logprobs.
272            provider_options["require_parameters"] = True
273            # DeepInfra silently fails to return logprobs consistently.
274            provider_options["ignore"] = ["DeepInfra"]
275
276        if provider.openrouter_skip_required_parameters:
277            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
278            provider_options["require_parameters"] = False
279
280        if len(provider_options) > 0:
281            extra_body["provider"] = provider_options
282
283        return extra_body
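
As an illustration, a hypothetical OpenRouter provider entry with require_openrouter_reasoning=True and r1_openrouter_options=True (and no thinking_level, extended thinking, or logprobs options) would produce:

    {
        "reasoning": {"exclude": False},
        "provider": {
            "require_parameters": True,
            "order": ["Fireworks", "Together"],
            "ignore": ["DeepInfra"],
        },
    }
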
def litellm_model_id(self) -> str:
285    def litellm_model_id(self) -> str:
286        # The model ID combines the request format and the URL endpoint.
287        # It specifies the provider URL/host, but this is overridden if you manually set an API base URL.
288
289        if self._litellm_model_id:
290            return self._litellm_model_id
291
292        provider = self.model_provider()
293        if not provider.model_id:
294            raise ValueError("Model ID is required for OpenAI compatible models")
295
296        litellm_provider_name: str | None = None
297        is_custom = False
298        match provider.name:
299            case ModelProviderName.openrouter:
300                litellm_provider_name = "openrouter"
301            case ModelProviderName.openai:
302                litellm_provider_name = "openai"
303            case ModelProviderName.groq:
304                litellm_provider_name = "groq"
305            case ModelProviderName.anthropic:
306                litellm_provider_name = "anthropic"
307            case ModelProviderName.ollama:
308                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
309                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
310                is_custom = True
311            case ModelProviderName.gemini_api:
312                litellm_provider_name = "gemini"
313            case ModelProviderName.fireworks_ai:
314                litellm_provider_name = "fireworks_ai"
315            case ModelProviderName.amazon_bedrock:
316                litellm_provider_name = "bedrock"
317            case ModelProviderName.azure_openai:
318                litellm_provider_name = "azure"
319            case ModelProviderName.huggingface:
320                litellm_provider_name = "huggingface"
321            case ModelProviderName.vertex:
322                litellm_provider_name = "vertex_ai"
323            case ModelProviderName.together_ai:
324                litellm_provider_name = "together_ai"
325            case ModelProviderName.openai_compatible:
326                is_custom = True
327            case ModelProviderName.kiln_custom_registry:
328                is_custom = True
329            case ModelProviderName.kiln_fine_tune:
330                is_custom = True
331            case _:
332                raise_exhaustive_enum_error(provider.name)
333
334        if is_custom:
335            if self._api_base is None:
336                raise ValueError(
337                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
338                )
339            # Use openai as it's only used for format, not url
340            litellm_provider_name = "openai"
341
342        # Shouldn't be possible, but keep the type checker happy
343        if litellm_provider_name is None:
344            raise ValueError(
345                f"Could not map provider name to a valid litellm provider ID for model {provider.model_id}"
346            )
347
348        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
349        return self._litellm_model_id
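
Illustrative results (the model IDs are made-up examples, not entries from the Kiln model list):

    # openrouter provider, model_id "deepseek/deepseek-r1" -> "openrouter/deepseek/deepseek-r1"
    # anthropic provider,  model_id "claude-3-5-sonnet"    -> "anthropic/claude-3-5-sonnet"
    # ollama provider,     model_id "llama3.1"             -> "openai/llama3.1"
    #   (the "openai" prefix only selects the request format; the call is routed to the
    #    explicitly configured base_url, which is required for Ollama and other custom providers)
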
async def build_completion_kwargs( self, provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, messages: list[dict[str, typing.Any]], top_logprobs: int | None, skip_response_format: bool = False) -> dict[str, typing.Any]:
351    async def build_completion_kwargs(
352        self,
353        provider: KilnModelProvider,
354        messages: list[dict[str, Any]],
355        top_logprobs: int | None,
356        skip_response_format: bool = False,
357    ) -> dict[str, Any]:
358        extra_body = self.build_extra_body(provider)
359
360        # Merge all parameters into a single kwargs dict for litellm
361        completion_kwargs = {
362            "model": self.litellm_model_id(),
363            "messages": messages,
364            "api_base": self._api_base,
365            "headers": self._headers,
366            "temperature": self.run_config.temperature,
367            "top_p": self.run_config.top_p,
368            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
369            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
370            # Better to ignore them than to fail the model call.
371            # https://docs.litellm.ai/docs/completion/input
372            "drop_params": True,
373            **extra_body,
374            **self._additional_body_options,
375        }
376
377        if not skip_response_format:
378            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
379            response_format_options = await self.response_format_options()
380            completion_kwargs.update(response_format_options)
381
382        if top_logprobs is not None:
383            completion_kwargs["logprobs"] = True
384            completion_kwargs["top_logprobs"] = top_logprobs
385
386        return completion_kwargs
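
An illustrative result for an OpenRouter model in json_schema mode, with no logprobs requested and no extra body options (all values are placeholders):

    {
        "model": "openrouter/some-org/some-model",
        "messages": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."},
        ],
        "api_base": None,
        "headers": None,
        "temperature": 0.7,
        "top_p": 1.0,
        "drop_params": True,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "task_response", "schema": {"...": "..."}},
        },
    }
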
def usage_from_response( self, response: litellm.types.utils.ModelResponse) -> kiln_ai.datamodel.Usage | None:
388    def usage_from_response(self, response: ModelResponse) -> Usage | None:
389        litellm_usage = response.get("usage", None)
390        cost = response._hidden_params.get("response_cost", None)
391        if not litellm_usage and not cost:
392            return None
393
394        usage = Usage()
395
396        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
397            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
398            usage.output_tokens = litellm_usage.get("completion_tokens", None)
399            usage.total_tokens = litellm_usage.get("total_tokens", None)
400        else:
401            logger.warning(
402                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
403            )
404
405        if isinstance(cost, float):
406            usage.cost = cost
407        elif cost is not None:
408            # None is allowed, but no other types are expected
409            logger.warning(
410                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
411            )
412
413        return usage
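
A worked example of the mapping, with made-up numbers:

    # litellm usage block: prompt_tokens=52, completion_tokens=21, total_tokens=73
    # hidden params: response_cost=0.00042
    # -> Usage(input_tokens=52, output_tokens=21, total_tokens=73, cost=0.00042)
    # If neither a usage block nor a cost is present, the method returns None.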