kiln_ai.test_run_context

Tests for agent run context management using contextvars.

  1"""Tests for agent run context management using contextvars."""
  2
  3import asyncio
  4
  5import pytest
  6
  7from kiln_ai.run_context import (
  8    clear_agent_run_id,
  9    generate_agent_run_id,
 10    get_agent_run_id,
 11    set_agent_run_id,
 12)
 13
 14
 15class TestAdapterRunContext:
 16    """Unit tests for agent run context management."""
 17
 18    def test_default_is_none(self):
 19        """Test that the default run ID is None."""
 20        # Clear any existing context first
 21        clear_agent_run_id()
 22        assert get_agent_run_id() is None
 23
 24    def test_set_and_get(self):
 25        """Test basic set and get cycle."""
 26        clear_agent_run_id()
 27        assert get_agent_run_id() is None
 28
 29        test_run_id = "test_run_123"
 30        set_agent_run_id(test_run_id)
 31        assert get_agent_run_id() == test_run_id
 32
 33    def test_clear(self):
 34        """Test that clear resets to None."""
 35        set_agent_run_id("some_run")
 36        assert get_agent_run_id() == "some_run"
 37
 38        clear_agent_run_id()
 39        assert get_agent_run_id() is None
 40
 41    def test_generate_agent_run_id_format(self):
 42        """Test that generated run IDs have the correct format."""
 43        run_id = generate_agent_run_id()
 44        assert run_id.startswith("run_")
 45        assert len(run_id) == 20  # "run_" + 16 hex characters
 46
 47    def test_generate_agent_run_id_unique(self):
 48        """Test that generated run IDs are unique."""
 49        ids = [generate_agent_run_id() for _ in range(100)]
 50        assert len(set(ids)) == 100  # All should be unique
 51
 52    @pytest.mark.asyncio
 53    async def test_asyncio_gather_propagation(self):
 54        """Test that contextvar propagates through asyncio.gather."""
 55        clear_agent_run_id()
 56        parent_run_id = "parent_run"
 57        set_agent_run_id(parent_run_id)
 58
 59        child_results = await asyncio.gather(
 60            self._read_context_var(),
 61            self._read_context_var(),
 62            self._read_context_var(),
 63        )
 64
 65        # All children should see the parent's run ID
 66        for result in child_results:
 67            assert result == parent_run_id
 68
 69    @pytest.mark.asyncio
 70    async def test_asyncio_gather_isolation(self):
 71        """Test that child writes don't affect parent or siblings."""
 72        clear_agent_run_id()
 73        parent_run_id = "parent_run"
 74        set_agent_run_id(parent_run_id)
 75
 76        async def child_modify_context():
 77            # Child sets a different value
 78            set_agent_run_id("child_modified")
 79            return get_agent_run_id()
 80
 81        child_results = await asyncio.gather(
 82            child_modify_context(),
 83            child_modify_context(),
 84        )
 85
 86        # Children should see their own modifications
 87        for result in child_results:
 88            assert result == "child_modified"
 89
 90        # Parent should still have the original value
 91        assert get_agent_run_id() == parent_run_id
 92
 93    @pytest.mark.asyncio
 94    async def test_nested_async_calls_propagate(self):
 95        """Test that contextvar propagates through nested async calls."""
 96        clear_agent_run_id()
 97        root_run_id = "root_run"
 98        set_agent_run_id(root_run_id)
 99
100        async def level_2():
101            return get_agent_run_id()
102
103        async def level_1():
104            return await level_2()
105
106        result = await level_1()
107        assert result == root_run_id
108
109    async def _read_context_var(self) -> str | None:
110        """Helper to read the context var in an async context."""
111        return get_agent_run_id()
class TestAdapterRunContext:
 16class TestAdapterRunContext:
 17    """Unit tests for agent run context management."""
 18
 19    def test_default_is_none(self):
 20        """Test that the default run ID is None."""
 21        # Clear any existing context first
 22        clear_agent_run_id()
 23        assert get_agent_run_id() is None
 24
 25    def test_set_and_get(self):
 26        """Test basic set and get cycle."""
 27        clear_agent_run_id()
 28        assert get_agent_run_id() is None
 29
 30        test_run_id = "test_run_123"
 31        set_agent_run_id(test_run_id)
 32        assert get_agent_run_id() == test_run_id
 33
 34    def test_clear(self):
 35        """Test that clear resets to None."""
 36        set_agent_run_id("some_run")
 37        assert get_agent_run_id() == "some_run"
 38
 39        clear_agent_run_id()
 40        assert get_agent_run_id() is None
 41
 42    def test_generate_agent_run_id_format(self):
 43        """Test that generated run IDs have the correct format."""
 44        run_id = generate_agent_run_id()
 45        assert run_id.startswith("run_")
 46        assert len(run_id) == 20  # "run_" + 16 hex characters
 47
 48    def test_generate_agent_run_id_unique(self):
 49        """Test that generated run IDs are unique."""
 50        ids = [generate_agent_run_id() for _ in range(100)]
 51        assert len(set(ids)) == 100  # All should be unique
 52
 53    @pytest.mark.asyncio
 54    async def test_asyncio_gather_propagation(self):
 55        """Test that contextvar propagates through asyncio.gather."""
 56        clear_agent_run_id()
 57        parent_run_id = "parent_run"
 58        set_agent_run_id(parent_run_id)
 59
 60        child_results = await asyncio.gather(
 61            self._read_context_var(),
 62            self._read_context_var(),
 63            self._read_context_var(),
 64        )
 65
 66        # All children should see the parent's run ID
 67        for result in child_results:
 68            assert result == parent_run_id
 69
 70    @pytest.mark.asyncio
 71    async def test_asyncio_gather_isolation(self):
 72        """Test that child writes don't affect parent or siblings."""
 73        clear_agent_run_id()
 74        parent_run_id = "parent_run"
 75        set_agent_run_id(parent_run_id)
 76
 77        async def child_modify_context():
 78            # Child sets a different value
 79            set_agent_run_id("child_modified")
 80            return get_agent_run_id()
 81
 82        child_results = await asyncio.gather(
 83            child_modify_context(),
 84            child_modify_context(),
 85        )
 86
 87        # Children should see their own modifications
 88        for result in child_results:
 89            assert result == "child_modified"
 90
 91        # Parent should still have the original value
 92        assert get_agent_run_id() == parent_run_id
 93
 94    @pytest.mark.asyncio
 95    async def test_nested_async_calls_propagate(self):
 96        """Test that contextvar propagates through nested async calls."""
 97        clear_agent_run_id()
 98        root_run_id = "root_run"
 99        set_agent_run_id(root_run_id)
100
101        async def level_2():
102            return get_agent_run_id()
103
104        async def level_1():
105            return await level_2()
106
107        result = await level_1()
108        assert result == root_run_id
109
110    async def _read_context_var(self) -> str | None:
111        """Helper to read the context var in an async context."""
112        return get_agent_run_id()

Unit tests for agent run context management.

def test_default_is_none(self):
def test_default_is_none(self):
    """Test that the default run ID is None."""
    import contextvars

    # A brand-new, empty Context has never had the variable set, so this
    # observes the declared default rather than a value written by clear().
    assert contextvars.Context().run(get_agent_run_id) is None

    # Clearing must also bring the current context back to None.
    clear_agent_run_id()
    assert get_agent_run_id() is None

Test that the default run ID is None.

def test_set_and_get(self):
def test_set_and_get(self):
    """A value passed to set_agent_run_id is returned by get_agent_run_id."""
    clear_agent_run_id()
    assert get_agent_run_id() is None

    expected_run_id = "test_run_123"
    set_agent_run_id(expected_run_id)
    observed = get_agent_run_id()
    assert observed == expected_run_id

Test basic set and get cycle.

def test_clear(self):
def test_clear(self):
    """clear_agent_run_id drops a previously set value back to None."""
    run_id = "some_run"
    set_agent_run_id(run_id)
    assert get_agent_run_id() == run_id

    clear_agent_run_id()
    assert get_agent_run_id() is None

Test that clear resets to None.

def test_generate_agent_run_id_format(self):
def test_generate_agent_run_id_format(self):
    """Generated IDs are "run_" followed by exactly 16 hex characters."""
    prefix = "run_"
    run_id = generate_agent_run_id()
    assert run_id.startswith(prefix)
    assert len(run_id) == 20  # "run_" + 16 hex characters

    # Length alone would accept any 16 characters; check that the suffix
    # really is hexadecimal as the format promises.
    suffix = run_id[len(prefix):]
    assert all(ch in "0123456789abcdefABCDEF" for ch in suffix), run_id

Test that generated run IDs have the correct format.

def test_generate_agent_run_id_unique(self):
def test_generate_agent_run_id_unique(self):
    """One hundred consecutive IDs contain no duplicates."""
    sample_size = 100
    generated = {generate_agent_run_id() for _ in range(sample_size)}
    assert len(generated) == sample_size

Test that generated run IDs are unique.

@pytest.mark.asyncio
async def test_asyncio_gather_propagation(self):
@pytest.mark.asyncio
async def test_asyncio_gather_propagation(self):
    """Every task spawned by asyncio.gather inherits the caller's run ID."""
    clear_agent_run_id()
    expected = "parent_run"
    set_agent_run_id(expected)

    readers = [self._read_context_var() for _ in range(3)]
    observed = await asyncio.gather(*readers)

    # Each child task starts from a copy of the parent's context.
    assert observed == [expected, expected, expected]

Test that contextvar propagates through asyncio.gather.

@pytest.mark.asyncio
async def test_asyncio_gather_isolation(self):
@pytest.mark.asyncio
async def test_asyncio_gather_isolation(self):
    """Test that child writes don't affect parent or siblings."""
    clear_agent_run_id()
    parent_run_id = "parent_run"
    set_agent_run_id(parent_run_id)

    async def child_modify_context(
        child_run_id: str,
    ) -> tuple[str | None, str | None]:
        # Capture what the child inherited before it writes anything.
        inherited = get_agent_run_id()
        set_agent_run_id(child_run_id)
        # Yield to the event loop so the sibling performs its own write
        # before this child reads back. If the context were shared, the
        # read would return whichever child wrote last.
        await asyncio.sleep(0)
        return inherited, get_agent_run_id()

    child_results = await asyncio.gather(
        child_modify_context("child_a"),
        child_modify_context("child_b"),
    )

    # Each child inherited the parent's value and sees only its own write.
    assert child_results == [
        (parent_run_id, "child_a"),
        (parent_run_id, "child_b"),
    ]

    # Parent should still have the original value
    assert get_agent_run_id() == parent_run_id

Test that child writes don't affect parent or siblings.

@pytest.mark.asyncio
async def test_nested_async_calls_propagate(self):
 94    @pytest.mark.asyncio
 95    async def test_nested_async_calls_propagate(self):
 96        """Test that contextvar propagates through nested async calls."""
 97        clear_agent_run_id()
 98        root_run_id = "root_run"
 99        set_agent_run_id(root_run_id)
100
101        async def level_2():
102            return get_agent_run_id()
103
104        async def level_1():
105            return await level_2()
106
107        result = await level_1()
108        assert result == root_run_id

Test that contextvar propagates through nested async calls.