kiln_ai.test_run_context
Tests for agent run context management using contextvars.
"""Tests for agent run context management using contextvars."""

import asyncio

import pytest

from kiln_ai.run_context import (
    clear_agent_run_id,
    generate_agent_run_id,
    get_agent_run_id,
    set_agent_run_id,
)


class TestAdapterRunContext:
    """Unit tests for agent run context management.

    Every test runs between two resets of the run ID (see ``_reset_run_id``),
    so a value stored by one test is not left behind for whatever runs next.
    """

    @pytest.fixture(autouse=True)
    def _reset_run_id(self):
        """Clear the run ID before and after each test in this class."""
        clear_agent_run_id()
        yield
        # Teardown: the synchronous tests below store a run ID and would
        # otherwise return with it still set.
        clear_agent_run_id()

    def test_default_is_none(self):
        """Test that the run ID reads as None once the context is cleared."""
        clear_agent_run_id()
        assert get_agent_run_id() is None

    def test_set_and_get(self):
        """Test basic set and get cycle."""
        clear_agent_run_id()
        assert get_agent_run_id() is None

        test_run_id = "test_run_123"
        set_agent_run_id(test_run_id)
        assert get_agent_run_id() == test_run_id

    def test_clear(self):
        """Test that clear resets to None."""
        set_agent_run_id("some_run")
        assert get_agent_run_id() == "some_run"

        clear_agent_run_id()
        assert get_agent_run_id() is None

    def test_generate_agent_run_id_format(self):
        """Test that generated run IDs are "run_" followed by 16 hex characters."""
        prefix = "run_"
        run_id = generate_agent_run_id()
        assert run_id.startswith(prefix)

        suffix = run_id[len(prefix):]
        assert len(suffix) == 16
        # Length alone would accept any 16 characters; check the alphabet too.
        assert set(suffix) <= set("0123456789abcdefABCDEF")

    def test_generate_agent_run_id_unique(self):
        """Test that generated run IDs are unique."""
        ids = [generate_agent_run_id() for _ in range(100)]
        assert len(set(ids)) == 100  # All should be unique

    @pytest.mark.asyncio
    async def test_asyncio_gather_propagation(self):
        """Test that contextvar propagates through asyncio.gather."""
        clear_agent_run_id()
        parent_run_id = "parent_run"
        set_agent_run_id(parent_run_id)

        child_results = await asyncio.gather(
            self._read_context_var(),
            self._read_context_var(),
            self._read_context_var(),
        )

        # All children should see the parent's run ID
        for result in child_results:
            assert result == parent_run_id

    @pytest.mark.asyncio
    async def test_asyncio_gather_isolation(self):
        """Test that child writes don't affect parent or siblings."""
        clear_agent_run_id()
        parent_run_id = "parent_run"
        set_agent_run_id(parent_run_id)

        async def child_modify_context(child_run_id):
            set_agent_run_id(child_run_id)
            # Yield to the event loop so the sibling performs its own write
            # before this child reads its value back.
            await asyncio.sleep(0)
            return get_agent_run_id()

        # Distinct values per child: identical values would hide a write
        # leaking from one sibling into the other.
        child_run_ids = ["child_modified_a", "child_modified_b"]
        child_results = await asyncio.gather(
            *(child_modify_context(run_id) for run_id in child_run_ids)
        )

        # Each child should see only its own modification
        assert list(child_results) == child_run_ids

        # Parent should still have the original value
        assert get_agent_run_id() == parent_run_id

    @pytest.mark.asyncio
    async def test_nested_async_calls_propagate(self):
        """Test that contextvar propagates through nested async calls."""
        clear_agent_run_id()
        root_run_id = "root_run"
        set_agent_run_id(root_run_id)

        async def level_2():
            return get_agent_run_id()

        async def level_1():
            return await level_2()

        result = await level_1()
        assert result == root_run_id

    async def _read_context_var(self) -> str | None:
        """Helper to read the context var in an async context."""
        return get_agent_run_id()
class TestAdapterRunContext:
    """Unit tests for agent run context management.

    An autouse fixture resets the run ID around every test so that a value
    stored by one test cannot still be set when the next one starts.
    """

    @pytest.fixture(autouse=True)
    def _isolate_run_id(self):
        """Reset the run ID on entry to and exit from each test."""
        clear_agent_run_id()
        yield
        # Several tests finish with a run ID still stored; drop it here.
        clear_agent_run_id()

    def test_default_is_none(self):
        """Test that no run ID is reported after the context is cleared."""
        clear_agent_run_id()
        assert get_agent_run_id() is None

    def test_set_and_get(self):
        """Test basic set and get cycle."""
        clear_agent_run_id()
        assert get_agent_run_id() is None

        test_run_id = "test_run_123"
        set_agent_run_id(test_run_id)
        assert get_agent_run_id() == test_run_id

    def test_clear(self):
        """Test that clear resets to None."""
        set_agent_run_id("some_run")
        assert get_agent_run_id() == "some_run"

        clear_agent_run_id()
        assert get_agent_run_id() is None

    def test_generate_agent_run_id_format(self):
        """Test that generated run IDs have the correct format."""
        run_id = generate_agent_run_id()
        assert run_id.startswith("run_")
        assert len(run_id) == 20  # "run_" + 16 hex characters
        # The length check accepts any characters; verify they are hex digits.
        hex_part = run_id[len("run_"):]
        assert all(ch in "0123456789abcdefABCDEF" for ch in hex_part)

    def test_generate_agent_run_id_unique(self):
        """Test that generated run IDs are unique."""
        ids = [generate_agent_run_id() for _ in range(100)]
        assert len(set(ids)) == 100  # All should be unique

    @pytest.mark.asyncio
    async def test_asyncio_gather_propagation(self):
        """Test that contextvar propagates through asyncio.gather."""
        clear_agent_run_id()
        parent_run_id = "parent_run"
        set_agent_run_id(parent_run_id)

        child_results = await asyncio.gather(
            self._read_context_var(),
            self._read_context_var(),
            self._read_context_var(),
        )

        # All children should see the parent's run ID
        for result in child_results:
            assert result == parent_run_id

    @pytest.mark.asyncio
    async def test_asyncio_gather_isolation(self):
        """Test that child writes don't affect parent or siblings."""
        clear_agent_run_id()
        parent_run_id = "parent_run"
        set_agent_run_id(parent_run_id)

        async def child_modify_context(own_run_id):
            # Child sets its own value, lets the sibling run, then reads back.
            set_agent_run_id(own_run_id)
            await asyncio.sleep(0)
            return get_agent_run_id()

        # The children write different values so that a write leaking
        # between siblings would change one of the results.
        first, second = await asyncio.gather(
            child_modify_context("child_modified_1"),
            child_modify_context("child_modified_2"),
        )

        # Children should see their own modifications
        assert first == "child_modified_1"
        assert second == "child_modified_2"

        # Parent should still have the original value
        assert get_agent_run_id() == parent_run_id

    @pytest.mark.asyncio
    async def test_nested_async_calls_propagate(self):
        """Test that contextvar propagates through nested async calls."""
        clear_agent_run_id()
        root_run_id = "root_run"
        set_agent_run_id(root_run_id)

        async def level_2():
            return get_agent_run_id()

        async def level_1():
            return await level_2()

        result = await level_1()
        assert result == root_run_id

    async def _read_context_var(self) -> str | None:
        """Helper to read the context var in an async context."""
        return get_agent_run_id()
Unit tests for agent run context management.
def test_default_is_none(self):
    """Check that the run ID reads as None after the context is cleared."""
    clear_agent_run_id()  # drop any value left over from earlier code
    current_run_id = get_agent_run_id()
    assert current_run_id is None
Test that the default run ID is None.
def test_set_and_get(self):
    """Check that a stored run ID is handed back by the getter."""
    # Begin from a cleared context so the read below reflects only this write.
    clear_agent_run_id()
    assert get_agent_run_id() is None

    expected = "test_run_123"
    set_agent_run_id(expected)
    observed = get_agent_run_id()
    assert observed == expected
Test basic set and get cycle.
def test_clear(self):
    """Check that clearing removes a previously stored run ID."""
    stored = "some_run"
    set_agent_run_id(stored)
    assert get_agent_run_id() == stored

    clear_agent_run_id()
    after_clear = get_agent_run_id()
    assert after_clear is None
Test that clear resets to None.
def test_generate_agent_run_id_format(self):
    """Test that generated run IDs are "run_" followed by 16 hex characters."""
    prefix = "run_"
    run_id = generate_agent_run_id()
    assert run_id.startswith(prefix)

    suffix = run_id[len(prefix):]
    assert len(suffix) == 16
    # A length check alone would accept any 16 characters; confirm that each
    # one is a hexadecimal digit.
    assert set(suffix) <= set("0123456789abcdefABCDEF")
Test that generated run IDs have the correct format.
def test_generate_agent_run_id_unique(self):
    """Check that repeated generation never yields the same ID twice."""
    sample_size = 100
    # A set keeps one copy of each value, so its size counts distinct IDs.
    distinct_ids = {generate_agent_run_id() for _ in range(sample_size)}
    assert len(distinct_ids) == sample_size
Test that generated run IDs are unique.
@pytest.mark.asyncio
async def test_asyncio_gather_propagation(self):
    """Check that coroutines run through asyncio.gather see the caller's run ID."""
    clear_agent_run_id()
    expected = "parent_run"
    set_agent_run_id(expected)

    # Three concurrent readers, none of which writes to the context.
    readers = [self._read_context_var() for _ in range(3)]
    observed = await asyncio.gather(*readers)

    assert all(value == expected for value in observed)
Test that contextvar propagates through asyncio.gather.
@pytest.mark.asyncio
async def test_asyncio_gather_isolation(self):
    """Test that child writes don't affect parent or siblings."""
    clear_agent_run_id()
    parent_run_id = "parent_run"
    set_agent_run_id(parent_run_id)

    async def child_modify_context(child_run_id):
        set_agent_run_id(child_run_id)
        # Yield to the event loop so the sibling performs its own write
        # before this child reads its value back.
        await asyncio.sleep(0)
        return get_agent_run_id()

    # Each child writes a different value; with identical values a write
    # leaking from one sibling into the other could not be detected.
    child_run_ids = ["child_modified_a", "child_modified_b"]
    child_results = await asyncio.gather(
        *(child_modify_context(run_id) for run_id in child_run_ids)
    )

    # Each child should see only its own modification
    assert list(child_results) == child_run_ids

    # Parent should still have the original value
    assert get_agent_run_id() == parent_run_id
Test that child writes don't affect parent or siblings.
@pytest.mark.asyncio
async def test_nested_async_calls_propagate(self):
    """Check that a run ID set by the caller is visible two awaits deep."""
    clear_agent_run_id()
    expected = "root_run"
    set_agent_run_id(expected)

    async def inner():
        # Innermost coroutine: only reads the current run ID.
        return get_agent_run_id()

    async def outer():
        return await inner()

    assert await outer() == expected
Test that contextvar propagates through nested async calls.