kiln_ai.adapters.langchain_adapters
import os
from os import getenv
from typing import Any, Dict

from langchain_aws import ChatBedrockConverse
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_base_url,
    ollama_model_installed,
)
from kiln_ai.utils.config import Config

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import KilnModelProvider, ModelProviderName
from .provider_tools import kiln_model_provider_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


class LangchainAdapter(BaseAdapter):
    _model: LangChainModelType | None = None

    def __init__(
        self,
        kiln_task: datamodel.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
        tags: list[str] | None = None,
    ):
        super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
        if custom_model is not None:
            self._model = custom_model

            # Attempt to infer model provider and name from custom model
            self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
            self.model_name = "custom.langchain:unknown_model"
            if hasattr(custom_model, "model_name") and isinstance(
                getattr(custom_model, "model_name"), str
            ):
                self.model_name = "custom.langchain:" + getattr(
                    custom_model, "model_name"
                )
            if hasattr(custom_model, "model") and isinstance(
                getattr(custom_model, "model"), str
            ):
                self.model_name = "custom.langchain:" + getattr(custom_model, "model")
        elif model_name is not None:
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )

    async def model(self) -> LangChainModelType:
        # cached model
        if self._model:
            return self._model

        self._model = await langchain_model_from(self.model_name, self.model_provider)

        if self.has_structured_output():
            if not hasattr(self._model, "with_structured_output") or not callable(
                getattr(self._model, "with_structured_output")
            ):
                raise ValueError(
                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                )
            # Langchain expects title/description to be at top level, on top of json schema
            output_schema = self.kiln_task.output_schema()
            if output_schema is None:
                raise ValueError(
                    f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                )
            output_schema["title"] = "task_response"
            output_schema["description"] = "A response from the task"
            with_structured_output_options = await get_structured_output_options(
                self.model_name, self.model_provider
            )
            self._model = self._model.with_structured_output(
                output_schema,
                include_raw=True,
                **with_structured_output_options,
            )
        return self._model

    async def _run(self, input: Dict | str) -> RunOutput:
        model = await self.model()
        chain = model
        intermediate_outputs = {}

        prompt = self.build_prompt()
        user_msg = self.prompt_builder.build_user_message(input)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=user_msg),
        ]

        # COT with structured output
        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
        if cot_prompt and self.has_structured_output():
            # Base model (without structured output) used for COT message
            base_model = await langchain_model_from(
                self.model_name, self.model_provider
            )
            messages.append(
                SystemMessage(content=cot_prompt),
            )

            cot_messages = [*messages]
            cot_response = await base_model.ainvoke(cot_messages)
            intermediate_outputs["chain_of_thought"] = cot_response.content
            messages.append(AIMessage(content=cot_response.content))
            messages.append(
                SystemMessage(content="Considering the above, return a final result.")
            )
        elif cot_prompt:
            messages.append(SystemMessage(content=cot_prompt))

        response = await chain.ainvoke(messages)

        if self.has_structured_output():
            if (
                not isinstance(response, dict)
                or "parsed" not in response
                or not isinstance(response["parsed"], dict)
            ):
                raise RuntimeError(f"structured response not returned: {response}")
            structured_response = response["parsed"]
            return RunOutput(
                output=self._munge_response(structured_response),
                intermediate_outputs=intermediate_outputs,
            )
        else:
            if not isinstance(response, BaseMessage):
                raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                raise RuntimeError(f"response is not a string: {text_content}")
            return RunOutput(
                output=text_content,
                intermediate_outputs=intermediate_outputs,
            )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            model_name=self.model_name,
            model_provider=self.model_provider,
            adapter_name="kiln_langchain_adapter",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
        )

    def _munge_response(self, response: Dict) -> Dict:
        # Mistral Large tool calling format is a bit different. Convert to standard format.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response


async def get_structured_output_options(
    model_name: str, model_provider: str
) -> Dict[str, Any]:
    finetune_provider = await kiln_model_provider_from(model_name, model_provider)
    if finetune_provider and finetune_provider.adapter_options.get("langchain"):
        return finetune_provider.adapter_options["langchain"].get(
            "with_structured_output_options", {}
        )
    return {}


async def langchain_model_from(
    name: str, provider_name: str | None = None
) -> BaseChatModel:
    provider = await kiln_model_provider_from(name, provider_name)
    return await langchain_model_from_provider(provider, name)


async def langchain_model_from_provider(
    provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
    if provider.name == ModelProviderName.openai:
        api_key = Config.shared().open_ai_api_key
        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.openai_compatible:
        # See provider_tools.py for how base_url, key and other parameters are set
        return ChatOpenAI(**provider.provider_options)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.groq:
        api_key = Config.shared().groq_api_key
        if api_key is None:
            raise ValueError(
                "Attempted to use Groq without an API key set. "
                "Get your API key from https://console.groq.com/keys"
            )
        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
    elif provider.name == ModelProviderName.amazon_bedrock:
        api_key = Config.shared().bedrock_access_key
        secret_key = Config.shared().bedrock_secret_key
        # langchain doesn't allow passing these, so ugly hack to set env vars
        os.environ["AWS_ACCESS_KEY_ID"] = api_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        return ChatBedrockConverse(
            **provider.provider_options,
        )
    elif provider.name == ModelProviderName.fireworks_ai:
        api_key = Config.shared().fireworks_api_key
        return ChatFireworks(**provider.provider_options, api_key=api_key)
    elif provider.name == ModelProviderName.ollama:
        # Ollama model naming is pretty flexible. We try a few versions of the model name
        potential_model_names = []
        if "model" in provider.provider_options:
            potential_model_names.append(provider.provider_options["model"])
        if "model_aliases" in provider.provider_options:
            potential_model_names.extend(provider.provider_options["model_aliases"])

        # Get the list of models Ollama supports
        ollama_connection = await get_ollama_connection()
        if ollama_connection is None:
            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")

        for model_name in potential_model_names:
            if ollama_model_installed(ollama_connection, model_name):
                return ChatOllama(model=model_name, base_url=ollama_base_url())

        raise ValueError(f"Model {model_name} not installed on Ollama")
    elif provider.name == ModelProviderName.openrouter:
        api_key = Config.shared().open_router_api_key
        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
        return ChatOpenAI(
            **provider.provider_options,
            openai_api_key=api_key,  # type: ignore[arg-type]
            openai_api_base=base_url,  # type: ignore[arg-type]
            default_headers={
                "HTTP-Referer": "https://getkiln.ai/openrouter",
                "X-Title": "KilnAI",
            },
        )
    else:
        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
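For illustration of the response handling above: with include_raw=True, LangChain's structured-output runnables return a dict containing "raw", "parsed", and "parsing_error" keys, and _run passes the "parsed" payload through _munge_response. A minimal sketch of the Mistral-style unwrapping that method performs (the dictionaries below are invented examples, not real model output):

# Hypothetical structured payloads, for illustration only.
mistral_style = {
    "name": "task_response",
    "arguments": {"summary": "A short answer"},
}
standard_style = {"summary": "A short answer"}


def munge_response(response: dict) -> dict:
    # Mirrors LangchainAdapter._munge_response: unwrap Mistral-style tool calls.
    if response.get("name") == "task_response" and "arguments" in response:
        return response["arguments"]
    return response


assert munge_response(mistral_style) == {"summary": "A short answer"}
assert munge_response(standard_style) == {"summary": "A short answer"}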
LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]

(LanguageModelInput expands to PromptValue | str | Sequence[BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]].)
class LangchainAdapter(BaseAdapter):
Base class for AI model adapters that handle task execution.
This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking.
Attributes:
    prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
    kiln_task (Task): The task configuration and metadata
    output_schema (dict | None): JSON schema for validating structured outputs
    input_schema (dict | None): JSON schema for validating structured inputs
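As a concrete (made-up) example of the structured-output path: a task's output_json_schema is a JSON Schema string, and model() later injects the top-level title/description keys LangChain expects before calling with_structured_output. The schema below is illustrative, not part of the library:

import json

# A hypothetical output_json_schema value for a task (illustrative only).
output_json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "summary": {"type": "string"},
            "confidence": {"type": "number"},
        },
        "required": ["summary"],
    }
)

# What LangchainAdapter.model() does before calling with_structured_output:
output_schema = json.loads(output_json_schema)
output_schema["title"] = "task_response"
output_schema["description"] = "A response from the task"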
LangchainAdapter(
    kiln_task: kiln_ai.datamodel.Task,
    custom_model: langchain_core.language_models.chat_models.BaseChatModel | None = None,
    model_name: str | None = None,
    provider: str | None = None,
    prompt_builder: kiln_ai.adapters.prompt_builders.BasePromptBuilder | None = None,
    tags: list[str] | None = None
)
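A usage sketch of the two construction paths. The provider and model identifiers below, and the pre-existing task, are assumptions for illustration, not values guaranteed by the Kiln model registry:

from langchain_ollama import ChatOllama

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangchainAdapter

task: datamodel.Task = ...  # an existing Kiln task with input/output schemas

# Option 1: let Kiln resolve the model from its registry (names are illustrative).
adapter = LangchainAdapter(task, model_name="llama_3_1_8b", provider="ollama")

# Option 2: bring your own LangChain chat model; name/provider are then inferred.
custom = ChatOllama(model="llama3.1", base_url="http://localhost:11434")
adapter = LangchainAdapter(task, custom_model=custom)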
async def model(self) -> LangChainModelType:
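To see the shape model() produces for structured-output tasks, here is a standalone sketch using plain LangChain; the schema and Ollama model name are assumptions. With include_raw=True the resulting runnable returns a dict with "raw", "parsed", and "parsing_error" keys rather than a bare message, which is what _run checks for:

from langchain_ollama import ChatOllama

# Illustrative schema; in the adapter this comes from kiln_task.output_schema().
schema = {
    "title": "task_response",
    "description": "A response from the task",
    "type": "object",
    "properties": {"summary": {"type": "string"}},
    "required": ["summary"],
}

base = ChatOllama(model="llama3.1", base_url="http://localhost:11434")  # assumed local model
structured = base.with_structured_output(schema, include_raw=True)
# result = await structured.ainvoke(messages)
# -> {"raw": AIMessage(...), "parsed": {...}, "parsing_error": None}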
async def get_structured_output_options(model_name: str, model_provider: str) -> Dict[str, Any]:
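The provider-side configuration this reads is a nested dict under adapter_options. A hypothetical shape, mirroring the lookup on plain dicts (the "method" value is just an example of a with_structured_output keyword, not a documented Kiln setting):

# Hypothetical KilnModelProvider.adapter_options content (illustrative only).
adapter_options = {
    "langchain": {
        "with_structured_output_options": {"method": "json_mode"},
    }
}

# Equivalent lookup to get_structured_output_options, on a plain dict:
options = adapter_options.get("langchain", {}).get("with_structured_output_options", {})
assert options == {"method": "json_mode"}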
async def langchain_model_from(name: str, provider_name: str | None = None) -> langchain_core.language_models.chat_models.BaseChatModel:
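Both helpers are coroutines, so they need an event loop. A sketch assuming an Ollama provider entry exists in the registry for the given names and the local Ollama server is running; the identifiers are illustrative:

import asyncio

from kiln_ai.adapters.langchain_adapters import langchain_model_from


async def main() -> None:
    # Model/provider names are assumptions; registry resolution and the
    # Ollama connectivity check both happen inside this call.
    model = await langchain_model_from("llama_3_1_8b", "ollama")
    print(type(model).__name__)


asyncio.run(main())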
async def langchain_model_from_provider(provider: kiln_ai.adapters.ml_model_list.KilnModelProvider, model_name: str) -> langchain_core.language_models.chat_models.BaseChatModel:
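For reference, the OpenRouter branch boils down to a ChatOpenAI client pointed at the OpenRouter base URL with Kiln's attribution headers. A standalone equivalent; the API key and model id below are placeholders, not values supplied by the library:

from os import getenv

from langchain_openai import ChatOpenAI

openrouter_model = ChatOpenAI(
    model="meta-llama/llama-3.1-8b-instruct",  # placeholder model id
    openai_api_key=getenv("OPENROUTER_API_KEY", "sk-placeholder"),
    openai_api_base=getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1",
    default_headers={
        "HTTP-Referer": "https://getkiln.ai/openrouter",
        "X-Title": "KilnAI",
    },
)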