close
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
' backend. Please use Vertex AI backend.'
)
llm_request.live_connect_config.tools = llm_request.config.tools
if llm_request.config.thinking_config is not None:
llm_request.live_connect_config.thinking_config = (
llm_request.config.thinking_config
)
logger.debug('Connecting to live with llm_request:%s', llm_request)
logger.debug('Live connect config: %s', llm_request.live_connect_config)
async with self._live_api_client.aio.live.connect(
Expand Down
85 changes: 75 additions & 10 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@
from google.genai.types import Part
import pytest

_GENERATIVE_LANGUAGE_API_BASE_URL = (
"https://generativelanguage.googleapis.com/v1alpha"
)
_GENERATIVE_LANGUAGE_API_ROOT = "https://generativelanguage.googleapis.com/"
_GENERATIVE_LANGUAGE_MTLS_API_BASE_URL = (
"https://generativelanguage.mtls.googleapis.com/v1alpha"
)
_GENERATIVE_LANGUAGE_MTLS_API_ROOT = (
"https://generativelanguage.mtls.googleapis.com/"
)


class MockAsyncIterator:
"""Mock for async iterator."""
Expand Down Expand Up @@ -252,17 +263,27 @@ def test_client_version_header_with_agent_engine(monkeypatch):
)


def test_api_client_uses_api_version_from_google_base_url():
@pytest.mark.parametrize(
("base_url", "normalized_base_url"),
(
(_GENERATIVE_LANGUAGE_API_BASE_URL, _GENERATIVE_LANGUAGE_API_ROOT),
(
_GENERATIVE_LANGUAGE_MTLS_API_BASE_URL,
_GENERATIVE_LANGUAGE_MTLS_API_ROOT,
),
),
)
def test_api_client_uses_api_version_from_google_base_url(
base_url, normalized_base_url
):
model = Gemini(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1alpha",
base_url=base_url,
)

client = model.api_client

assert client._api_client._http_options.base_url == (
"https://generativelanguage.googleapis.com/"
)
assert client._api_client._http_options.base_url == normalized_base_url
assert client._api_client._http_options.api_version == "v1alpha"


Expand Down Expand Up @@ -670,7 +691,7 @@ async def test_generate_content_async_patches_api_version(
):
gemini_llm = Gemini(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1alpha",
base_url=_GENERATIVE_LANGUAGE_API_BASE_URL,
)
llm_request.config.http_options = types.HttpOptions(
headers={"custom-header": "custom-value"}
Expand Down Expand Up @@ -718,7 +739,7 @@ def test_live_api_version_vertex_ai(gemini_llm):
def test_live_api_version_uses_google_base_url_version():
gemini_llm = Gemini(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1alpha",
base_url=_GENERATIVE_LANGUAGE_API_BASE_URL,
)

assert gemini_llm._live_api_version == "v1alpha"
Expand All @@ -732,16 +753,28 @@ def test_live_api_version_gemini_api(gemini_llm):
assert gemini_llm._live_api_version == "v1alpha"


def test_live_api_client_uses_api_version_from_google_base_url():
@pytest.mark.parametrize(
("base_url", "normalized_base_url"),
(
(_GENERATIVE_LANGUAGE_API_BASE_URL, _GENERATIVE_LANGUAGE_API_ROOT),
(
_GENERATIVE_LANGUAGE_MTLS_API_BASE_URL,
_GENERATIVE_LANGUAGE_MTLS_API_ROOT,
),
),
)
def test_live_api_client_uses_api_version_from_google_base_url(
base_url, normalized_base_url
):
gemini_llm = Gemini(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1alpha",
base_url=base_url,
)

client = gemini_llm._live_api_client
http_options = client._api_client._http_options

assert http_options.base_url == "https://generativelanguage.googleapis.com/"
assert http_options.base_url == normalized_base_url
assert http_options.api_version == "v1alpha"


Expand Down Expand Up @@ -852,6 +885,38 @@ async def __aexit__(self, *args):
)


@pytest.mark.asyncio
async def test_connect_copies_thinking_config_to_live_config(
gemini_llm, llm_request
):
"""Test that live connections preserve thinking_config from generate config."""
thinking_config = types.ThinkingConfig(
thinking_budget=10,
include_thoughts=True,
)
llm_request.config.thinking_config = thinking_config
llm_request.live_connect_config = types.LiveConnectConfig()

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:

class MockLiveConnect:

async def __aenter__(self):
return mock_live_session

async def __aexit__(self, *args):
pass

mock_live_client.aio.live.connect.return_value = MockLiveConnect()

async with gemini_llm.connect(llm_request):
mock_live_client.aio.live.connect.assert_called_once()
config_arg = mock_live_client.aio.live.connect.call_args.kwargs["config"]
assert config_arg.thinking_config == thinking_config


@pytest.mark.parametrize(
(
"api_backend, "
Expand Down