From 3ca22f7a9f0fa35b56fbbaeb69b12f21992fa734 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Wed, 15 Apr 2026 16:55:12 +0200 Subject: [PATCH] CI token usage --- Makefile | 4 +-- splunklib/ai/engines/langchain.py | 45 ++++++++++++++++++++++++++++++- tests/ai_testlib.py | 9 +++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 44c38d74..b24ac665 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ docs: # -ra prints a report on all failed tests after a run # -vv shows why a test failed while the rest of the suite is running PYTHON_CMD := uv run python -PYTEST_CMD := $(PYTHON_CMD) -m pytest --no-header --ff -ra -vv +PYTEST_CMD := $(PYTHON_CMD) -m pytest --no-header --ff -ra -vvv -s .PHONY: test test: @@ -47,7 +47,7 @@ test-unit: .PHONY: test-integration test-integration: - $(PYTEST_CMD) --ff ./tests/integration ./tests/system + $(PYTEST_CMD) --ff ./tests/integration/ai ./tests/system .PHONY: test-ai test-ai: diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index bd103724..8c17b5df 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -120,6 +120,8 @@ LC_AgentMiddleware = Langchain_AgentMiddleware[Any, "InvokeContext", Any] LC_ModelRequest = Langchain_ModelRequest["InvokeContext"] +total_token_usage: int = 0 + # Set to True to enable debugging mode. _DEBUG = False @@ -291,7 +293,6 @@ async def awrap_model_call( request: LC_ModelRequest, handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]], ) -> LC_ModelCallResult: - agent_thread_ids: dict[str, set[str]] = {} # Update the subagent schema definitions to include all thread_ids that the @@ -498,6 +499,9 @@ async def awrap_model_call( print("LLM CALL", request) try: resp = await handler(request) + except LC_StructuredOutputError as e: + print("LLM FAILURE", e, e.ai_message) + raise except Exception as e: print("LLM FAILURE", e) raise @@ -528,6 +532,45 @@ async def awrap_tool_call( if _DEBUG: lc_middleware.append(_DEBUGMiddleware()) + class _TOKENUsage(LC_AgentMiddleware): + @override + async def awrap_model_call( + self, + request: LC_ModelRequest, + handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]], + ) -> LC_ModelCallResult: + global total_token_usage + + def _extract_tokens(resp: LC_ModelCallResult) -> int: + ai_message = resp + if isinstance(ai_message, LC_ExtendedModelResponse): + ai_message = ai_message.model_response + if isinstance(ai_message, LC_ModelResponse): + ai_message = next( + ( + m + for m in ai_message.result + if isinstance(m, LC_AIMessage) + ), + None, + ) + if ai_message is not None and ai_message.usage_metadata: + return ai_message.usage_metadata.get("total_tokens", 0) + return 0 + + try: + resp = await handler(request) + total_token_usage += _extract_tokens(resp) + return resp + except LC_StructuredOutputError as e: + if e.ai_message.usage_metadata: + total_token_usage += e.ai_message.usage_metadata.get( + "total_tokens", 0 + ) + raise + + lc_middleware.append(_TOKENUsage()) + response_format = None if agent.output_schema is not None: if _supports_provider_strategy(model_impl): diff --git a/tests/ai_testlib.py b/tests/ai_testlib.py index 631fd16f..96ba00c4 100644 --- a/tests/ai_testlib.py +++ b/tests/ai_testlib.py @@ -1,4 +1,5 @@ from typing import override +import splunklib.ai.engines.langchain as langchain_engine from splunklib.ai.model import PredefinedModel from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model from tests.testlib import SDKTestCase @@ -6,10 +7,12 @@ class AITestCase(SDKTestCase): _model: PredefinedModel | None = None + _token_usage_before: int = 0 @override def setUp(self) -> None: super().setUp() + self._token_usage_before = langchain_engine.total_token_usage # Our tests don't expect this app to be installed, if needed it is # installed on demand. @@ -18,6 +21,12 @@ def setUp(self) -> None: app.delete() self.restart_splunk() + @override + def tearDown(self) -> None: + tokens_used = langchain_engine.total_token_usage - self._token_usage_before + print(f"\n[token usage] {self.id()}: {tokens_used} tokens") + super().tearDown() + @property def test_llm_settings(self) -> TestLLMSettings: client_id: str = self.opts.kwargs["internal_ai_client_id"]