# SPDX-FileCopyrightText: 2026 Google LLC
# SPDX-License-Identifier: Apache-2.0
"""ADK 2.0 chat runner with closure-bound tools and parallel intent classification."""
from __future__ import annotations
import time
import uuid
from hashlib import sha256
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, cast
import structlog
from google.adk import Runner
from google.adk.agents import LlmAgent
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.genai import errors as genai_errors
from google.genai import types
from sqlspec.adapters.oracledb import OracleAsyncDriver # noqa: TC002
from sqlspec.extensions.adk import SQLSpecSessionService # noqa: TC002
from app.domain.chat.exceptions import AIServiceUnconfigured
from app.domain.chat.services._adk_grounding import (
_build_map_actions,
_coerce_dict_rows,
_extract_location_filters,
_extract_product_query,
_format_availability_answer,
_format_store_location_answer,
_get_field,
_ground_product_rag_turn,
_record_product_search_result,
_request_coordinates,
_safe_location_context,
)
from app.domain.chat.services._adk_support import (
_coerce_history_messages,
_coerce_sql_phases,
_effective_intent,
_event_content_text,
_event_history_messages,
_record_tool_sql_phases,
_response_cache_phase,
_sha256_text,
_similarity_score,
_sql_phase,
_summarize_vector,
)
from app.domain.chat.services.classifier import FlashLiteIntentClassifier, IntentLabel
from app.domain.chat.services.workflow import make_workflow
from app.domain.products.services import ProductService, StoreService, VertexAIService # noqa: TC001
from app.domain.system.schemas import SearchMetricsCreate
from app.domain.system.services import BASE_SYSTEM_INSTRUCTION, CacheService, MetricsService, PersonaManager
from app.lib.service import OracleAsyncService
from app.lib.settings import get_settings
from app.utils.serialization import sanitize_for_json
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import ToolUnion
from app.domain.chat.schemas import ChatMessage
logger = structlog.get_logger()
_UNCONFIGURED_MESSAGE = "AI service is not configured. Set GOOGLE_API_KEY or VERTEX_AI_API_KEY in your .env file."
_PLACEHOLDER_PROJECT_IDS = frozenset({"demo-project", "your-project-id", "your-gcp-project-id"})
_PRODUCT_RAG_INTENT = "PRODUCT_RAG"
_PRODUCT_AVAILABILITY_INTENT = "PRODUCT_AVAILABILITY"
_STORE_LOCATION_INTENT = "STORE_LOCATION"
_ORDER_STATUS_INTENT = "ORDER_STATUS"
_DISPLAY_HISTORY_STATE_KEY = "display_history"
async def _collect_workflow_stream(
events: AsyncIterator[Any],
*,
workflow_output: dict[str, Any],
answer_parts: list[str],
partial_answer_parts: list[str],
) -> AsyncIterator[dict[str, str]]:
"""Collect ADK workflow output while yielding streaming deltas."""
async for event in events:
if isinstance(event.output, dict) and "intent" in event.output:
workflow_output.update(event.output)
text = _event_content_text(event)
if not text:
continue
if event.partial:
partial_answer_parts.append(text)
yield {"type": "delta", "text": text}
else:
answer_parts.append(text)
def _final_event(
*,
answer: str,
session_id: str,
response_time_ms: float,
intent_detected: str,
search_metrics: dict[str, Any],
sql_phases: list[dict[str, Any]],
from_cache: bool = False,
embedding_cache_hit: bool = False,
store_results: list[dict[str, Any]] | None = None,
inventory_results: list[dict[str, Any]] | None = None,
map_actions: list[dict[str, Any]] | None = None,
location_context: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build the final stream event shared by every chat route.
Returns:
The complete reply payload with masked location context.
"""
return {
"type": "final",
"answer": answer,
"session_id": session_id,
"response_time_ms": response_time_ms,
"intent_detected": intent_detected,
"search_metrics": search_metrics,
"from_cache": from_cache,
"embedding_cache_hit": embedding_cache_hit,
"sql_phases": sql_phases,
"store_results": store_results if store_results is not None else [],
"inventory_results": inventory_results if inventory_results is not None else [],
"map_actions": map_actions if map_actions is not None else [],
"location_context": _safe_location_context(location_context),
}
def _inventory_response_payload(metric_state: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
"""Return inventory rows and map actions recorded during store-aware RAG."""
inventory_results = _coerce_dict_rows(metric_state.get("inventory_results"))
return {"inventory_results": inventory_results, "map_actions": _build_map_actions(inventory_results)}
class AgentToolsService(OracleAsyncService):
"""Business logic invoked by closure-bound ADK tools."""
def __init__(
self,
driver: OracleAsyncDriver,
product_service: ProductService,
metrics_service: MetricsService,
vertex_ai_service: VertexAIService,
store_service: StoreService,
cache_service: CacheService,
) -> None:
super().__init__(driver)
self.product_service = product_service
self.metrics_service = metrics_service
self.vertex_ai_service = vertex_ai_service
self.store_service = store_service
self.cache_service = cache_service
async def search_products_by_vector(
self,
query: str,
limit: int | None = None,
similarity_threshold: float | None = None,
*,
store_id: int | None = None,
) -> dict[str, Any]:
chat = get_settings().chat
limit = chat.product_search_limit if limit is None else limit
similarity_threshold = chat.product_search_threshold if similarity_threshold is None else similarity_threshold
embedding_start = time.time()
embedding, cache_hit = await self.vertex_ai_service.get_text_embedding(
query, embedding_purpose="query", return_cache_status=True
)
embedding_ms = (time.time() - embedding_start) * 1000
oracle_start = time.time()
products = await self.product_service.search_by_vector(
embedding, similarity_threshold, limit, store_id=store_id
)
oracle_ms = (time.time() - oracle_start) * 1000
tool_total_ms = embedding_ms + oracle_ms
model = str(getattr(self.vertex_ai_service, "embedding_model", "unknown"))
vector_sql_key = "vector-search-products-by-store" if store_id is not None else "vector-search-products"
vector_binds: dict[str, Any] = {
"query_vector": _summarize_vector(embedding),
"threshold": similarity_threshold,
"limit": limit,
}
if store_id is not None:
vector_binds["store_id"] = store_id
await self.metrics_service.record_search(
SearchMetricsCreate(
query_id=str(uuid.uuid4()),
user_id="chat",
search_time_ms=tool_total_ms,
embedding_time_ms=embedding_ms,
oracle_time_ms=oracle_ms,
similarity_score=_similarity_score(products),
result_count=len(products),
)
)
return {
"products": sanitize_for_json(products),
"embedding_cache_hit": cache_hit,
"results_count": len(products),
"vector_query": query,
"search_metrics": {
"vector_query": query,
"embedding_ms": round(embedding_ms, 2),
"oracle_ms": round(oracle_ms, 2),
"tool_ms": round(tool_total_ms, 2),
},
"sql_phases": [
_sql_phase(
label="Embedding cache lookup",
sql_key="get-cached-embedding",
binds={"hash": _sha256_text(query), "model": model},
row_count=1 if cache_hit else 0,
runtime_ms=embedding_ms,
cache_status="hit" if cache_hit else "miss",
),
_sql_phase(
label="Oracle vector search",
sql_key=vector_sql_key,
binds=vector_binds,
row_count=len(products),
runtime_ms=oracle_ms,
cache_status="miss",
),
],
}
async def get_product_details(self, product_id: str) -> dict[str, Any]:
try:
product = await self.product_service.get_by_id(int(product_id))
except ValueError:
product = await self.product_service.get_by_name(product_id)
return cast("dict[str, Any]", sanitize_for_json(product)) if product else {"error": "Product not found"}
async def get_all_store_locations(self) -> list[dict[str, Any]]:
stores = await self.store_service.get_all_stores()
return cast("list[dict[str, Any]]", sanitize_for_json(stores))
async def find_stores_by_location(
self, *, city: str | None = None, state: str | None = None, zip_code: str | None = None
) -> dict[str, Any]:
started = time.time()
stores = await self.store_service.find_stores_by_location(city=city, state=state, zip_code=zip_code)
return {
"stores": sanitize_for_json(stores),
"results_count": len(stores),
"sql_phases": [
_sql_phase(
label="Store location lookup",
sql_key="find-stores-by-location",
binds={"city": city, "state": state, "zip_code": zip_code},
row_count=len(stores),
runtime_ms=(time.time() - started) * 1000,
cache_status="miss",
)
],
}
async def get_store_hours(self, store_id: int) -> dict[str, Any]:
started = time.time()
hours = await self.store_service.get_store_hours(store_id)
payload: dict[str, Any] = (
cast("dict[str, Any]", sanitize_for_json(hours)) if hours else {"error": "Store not found"}
)
payload["sql_phases"] = [
_sql_phase(
label="Store hours lookup",
sql_key="get-store-by-id",
binds={"id": store_id},
row_count=0 if hours is None else 1,
runtime_ms=(time.time() - started) * 1000,
cache_status="miss",
)
]
return payload
async def find_nearest_stores(self, latitude: float, longitude: float, limit: int = 5) -> dict[str, Any]:
started = time.time()
stores = await self.store_service.find_nearest_stores(latitude, longitude, limit)
return {
"stores": sanitize_for_json(stores),
"results_count": len(stores),
"sql_phases": [
_sql_phase(
label="Nearest store lookup",
sql_key="list-stores",
binds={"origin": "<REQUEST_COORDINATES>", "limit": limit},
row_count=len(stores),
runtime_ms=(time.time() - started) * 1000,
cache_status="miss",
)
],
}
async def find_stores_with_product(
self, product_query: str, latitude: float | None = None, longitude: float | None = None
) -> dict[str, Any]:
started = time.time()
coordinates = (latitude, longitude) if latitude is not None and longitude is not None else None
# 1. Try exact match first
availability = await self.store_service.find_product_availability(product_query, coordinates=coordinates)
sql_key = "find-product-availability-by-query"
binds: dict[str, Any] = {"product_query": product_query}
# 2. If no exact match, try vector search fallback
if not availability:
query_embedding = await self.vertex_ai_service.get_text_embedding(product_query, embedding_purpose="query")
matches = await self.product_service.search_by_vector(query_embedding, similarity_threshold=0.6, limit=1)
if matches:
resolved_product = matches[0]
availability = await self.store_service.find_stores_with_product(
resolved_product.id, latitude=latitude, longitude=longitude
)
sql_key = "find-stores-with-product-inventory"
binds = {"product_id": resolved_product.id}
await logger.ainfo(
"Resolved product query via vector search",
query=product_query,
resolved_name=resolved_product.name,
similarity=resolved_product.similarity_score,
)
if coordinates:
binds["origin"] = "<REQUEST_COORDINATES>"
return {
"availability": sanitize_for_json(availability),
"results_count": len(availability),
"sql_phases": [
_sql_phase(
label="Product availability lookup",
sql_key=sql_key,
binds=binds,
row_count=len(availability),
runtime_ms=(time.time() - started) * 1000,
cache_status="miss",
)
],
}
def make_response_cache_key(self, query: str, persona: str) -> str:
normalized = " ".join(query.casefold().split())
model = self.vertex_ai_service.model
version = get_settings().chat.response_cache_version
digest = sha256(f"{version}:{model}:{persona}:{normalized}".encode()).hexdigest()
return f"chat:{digest}"
async def get_cached_chat_response(self, cache_key: str) -> dict[str, Any] | None:
cached = await self.cache_service.get_cached_response(cache_key)
return cached.response_data if cached else None
async def set_cached_chat_response(self, cache_key: str, response_data: dict[str, Any]) -> None:
await self.cache_service.set_cached_response(
cache_key, response_data, ttl_minutes=get_settings().chat.response_cache_ttl_minutes
)
def credential_guard_callback(callback_context: CallbackContext) -> types.Content | None:
"""Short-circuit the agent with a 503 message when credentials are missing.
Returns:
A model response when credentials are missing, otherwise ``None``.
"""
del callback_context
if _has_vertex_ai_backend_config():
return None
return types.Content(role="model", parts=[types.Part(text=_UNCONFIGURED_MESSAGE)])
def _has_vertex_ai_backend_config() -> bool:
settings = get_settings()
project_id = settings.ai.project_id.strip()
return bool(settings.ai.api_key or (project_id and project_id not in _PLACEHOLDER_PROJECT_IDS))
def _ensure_vertex_ai_backend_configured() -> None:
if not _has_vertex_ai_backend_config():
raise AIServiceUnconfigured(_UNCONFIGURED_MESSAGE)
def _is_credential_error(exc: BaseException) -> bool:
text = str(exc).lower()
if isinstance(exc, genai_errors.ClientError):
return any(
marker in text
for marker in (
"api key",
"credentials",
"permission_denied",
"service_disabled",
"forbidden",
"unauthorized",
)
)
return "api key" in text or "credentials" in text
[docs]
class ADKRunner:
"""Per-request ADK 2.0 workflow with closure-bound tools."""
[docs]
def __init__(
self,
session_service: SQLSpecSessionService,
classifier: FlashLiteIntentClassifier,
persona_manager: PersonaManager,
) -> None:
self._session_service = session_service
self._classifier = classifier
self._persona_manager = persona_manager
def _make_tool_factories(
self,
tools_service: AgentToolsService,
metric_state: dict[str, Any],
location_context: dict[str, Any] | None = None,
) -> list[ToolUnion]:
chat_settings = get_settings().chat
async def search_products_by_vector(
query: str,
limit: int = chat_settings.product_search_limit,
similarity_threshold: float = chat_settings.product_search_threshold,
) -> dict[str, Any]:
"""Search the Cymbal Coffee menu with vector RAG.
Use for menu, catalog, recommendation, flavor, roast, price, caffeine,
availability, dietary substitution, and idiomatic preference requests.
Returns:
Matching menu products and cache/search metadata. Only these
returned products may be recommended to the user.
"""
target_store = await self._resolve_rag_store(
tools_service=tools_service, query=query, location_context=location_context
)
store_id = target_store.id if target_store else None
result = await tools_service.search_products_by_vector(
query, limit, similarity_threshold, store_id=store_id
)
_record_product_search_result(metric_state, result, query)
return result
async def get_product_details(product_id: str) -> dict[str, Any]:
"""Get exact details for a Cymbal Coffee product by id or name.
Returns:
Product details, or an error object when no product matches.
"""
result = await tools_service.get_product_details(product_id)
if "error" not in result and result.get("name"):
metric_state["rag_products"] = [result]
return result
async def get_all_store_locations() -> list[dict[str, Any]]:
"""List Cymbal Coffee store locations for address, hours, pickup, or nearest-cafe questions.
Returns:
Store location records.
"""
return await tools_service.get_all_store_locations()
async def find_stores_by_location(
city: str | None = None, state: str | None = None, zip_code: str | None = None
) -> dict[str, Any]:
"""Find Cymbal Coffee stores by city, state, or ZIP code.
Returns:
Matching store records and named-SQL telemetry.
"""
result = await tools_service.find_stores_by_location(city=city, state=state, zip_code=zip_code)
_record_tool_sql_phases(metric_state, result)
return result
async def get_store_hours(store_id: int) -> dict[str, Any]:
"""Get business hours for a Cymbal Coffee store.
Returns:
Store hours, timezone, and named-SQL telemetry.
"""
result = await tools_service.get_store_hours(store_id)
_record_tool_sql_phases(metric_state, result)
return result
async def find_nearest_stores(latitude: float, longitude: float, limit: int = 5) -> dict[str, Any]:
"""Find nearest Cymbal Coffee stores from request-scoped browser coordinates.
Returns:
Nearest local stores and named-SQL telemetry with coordinates masked.
"""
result = await tools_service.find_nearest_stores(latitude, longitude, limit)
_record_tool_sql_phases(metric_state, result)
return result
async def find_stores_with_product(
product_query: str, latitude: float | None = None, longitude: float | None = None
) -> dict[str, Any]:
"""Find stores with availability for a Cymbal Coffee product.
Returns:
Store-level availability and named-SQL telemetry.
"""
result = await tools_service.find_stores_with_product(product_query, latitude, longitude)
_record_tool_sql_phases(metric_state, result)
return result
return [
search_products_by_vector,
get_product_details,
get_all_store_locations,
find_stores_by_location,
get_store_hours,
find_nearest_stores,
find_stores_with_product,
]
def _build_workflow(self, instruction: str, temperature: float, tools: list[ToolUnion]) -> Any:
agent = LlmAgent(
name="CoffeeAssistant",
model=get_settings().ai.chat_model,
instruction=instruction,
tools=tools,
generate_content_config=types.GenerateContentConfig(temperature=temperature),
before_agent_callback=credential_guard_callback,
)
return make_workflow(self._classifier, agent)
[docs]
async def get_history(self, user_id: str, session_id: str) -> list[ChatMessage]:
"""Return displayable chat history for the current ADK session."""
session = await self._session_service.get_session(
app_name=get_settings().chat.session_app_name, user_id=user_id, session_id=session_id
)
if not session:
return []
state = getattr(session, "state", None) or {}
if isinstance(state, dict):
persisted = _coerce_history_messages(state.get(_DISPLAY_HISTORY_STATE_KEY))
if persisted:
return persisted
return _event_history_messages(getattr(session, "events", []))
[docs]
async def get_history_or_empty(self, user_id: str, session_id: str) -> list[ChatMessage]:
"""Return displayable chat history, or an empty list if loading fails."""
try:
return await self.get_history(user_id=user_id, session_id=session_id)
except Exception as exc: # noqa: BLE001
await logger.awarning("Chat history unavailable", error_type=type(exc).__name__)
return []
[docs]
async def clear_session(self, user_id: str, session_id: str) -> None:
"""Delete the current ADK session and its event history."""
await self._session_service.delete_session(
app_name=get_settings().chat.session_app_name, user_id=user_id, session_id=session_id
)
async def _append_display_history(
self,
*,
user_id: str,
session_id: str,
query: str,
answer: str,
intent_detected: str | None = None,
last_products: list[str] | None = None,
) -> None:
session = await self._session_service.get_session(
app_name=get_settings().chat.session_app_name, user_id=user_id, session_id=session_id
)
if not session:
return
state = dict(getattr(session, "state", None) or {})
if intent_detected:
state["intent"] = intent_detected
if last_products is not None:
state["last_products"] = last_products
history = [
*[
{"source": message.source, "message": message.message}
for message in _coerce_history_messages(state.get(_DISPLAY_HISTORY_STATE_KEY))
],
{"source": "human", "message": query},
{"source": "ai", "message": answer},
]
state[_DISPLAY_HISTORY_STATE_KEY] = history[-get_settings().chat.display_history_limit :]
result = self._session_service.store.update_session_state(session_id, state)
if isawaitable(result):
await result
async def _cached_response_event(
self,
*,
start: float,
query: str,
user_id: str,
session: Any,
cached_response: dict[str, Any],
response_cache_phase: dict[str, Any],
location_context: dict[str, Any] | None,
) -> dict[str, Any]:
elapsed_ms = (time.time() - start) * 1000
cached_metrics = dict(cached_response.get("search_metrics") or {})
cached_metrics["total_ms"] = round(elapsed_ms)
answer = str(cached_response.get("answer", ""))
sql_phases = _coerce_sql_phases(cached_response.get("sql_phases"))
intent_detected = _effective_intent(
str(cached_response.get("intent_detected") or "GENERAL_CONVERSATION"), cached_metrics, sql_phases
)
last_products = cached_response.get("last_products")
if answer:
await self._append_display_history(
user_id=user_id,
session_id=session.id,
query=query,
answer=answer,
intent_detected=intent_detected,
last_products=last_products,
)
return _final_event(
answer=answer,
session_id=session.id,
response_time_ms=elapsed_ms,
intent_detected=intent_detected,
search_metrics=cached_metrics,
sql_phases=[response_cache_phase, *sql_phases],
from_cache=True,
embedding_cache_hit=bool(cached_response.get("embedding_cache_hit")),
store_results=_coerce_dict_rows(cached_response.get("store_results")),
inventory_results=_coerce_dict_rows(cached_response.get("inventory_results")),
map_actions=_coerce_dict_rows(cached_response.get("map_actions")),
location_context=location_context,
)
async def _product_rag_event(
self,
*,
start: float,
query: str,
user_id: str,
session: Any,
cache_key: str | None,
response_cache_phase: dict[str, Any] | None,
tools_service: AgentToolsService,
location_context: dict[str, Any] | None,
) -> dict[str, Any]:
metric_state: dict[str, Any] = {"search_metrics": {}, "embedding_cache_hit": False, "sql_phases": []}
target_store = await self._resolve_rag_store(
tools_service=tools_service, query=query, location_context=location_context
)
answer = await _ground_product_rag_turn(
query, metric_state, tools_service, store_id=target_store.id if target_store else None
)
elapsed_ms = (time.time() - start) * 1000
search_metrics = dict(metric_state.get("search_metrics", {}))
search_metrics["total_ms"] = round(elapsed_ms)
product_sql_phases = _coerce_sql_phases(metric_state.get("sql_phases"))
products = metric_state.get("rag_products", [])
last_products = [p["name"] for p in products] if products else None
inventory_payload = _inventory_response_payload(metric_state)
response_data = {
"answer": answer,
"intent_detected": _PRODUCT_RAG_INTENT,
"search_metrics": search_metrics,
"embedding_cache_hit": bool(metric_state.get("embedding_cache_hit")),
"sql_phases": product_sql_phases,
"store_results": [],
**inventory_payload,
"last_products": last_products,
}
if answer:
if cache_key:
await tools_service.set_cached_chat_response(cache_key, response_data)
await self._append_display_history(
user_id=user_id,
session_id=session.id,
query=query,
answer=answer,
intent_detected=_PRODUCT_RAG_INTENT,
last_products=last_products,
)
return _final_event(
answer=answer,
session_id=session.id,
response_time_ms=elapsed_ms,
intent_detected=_PRODUCT_RAG_INTENT,
search_metrics=search_metrics,
sql_phases=([response_cache_phase] if response_cache_phase else []) + product_sql_phases,
embedding_cache_hit=bool(metric_state.get("embedding_cache_hit")),
inventory_results=inventory_payload["inventory_results"],
map_actions=inventory_payload["map_actions"],
location_context=location_context,
)
async def _store_location_event(
self,
*,
start: float,
query: str,
user_id: str,
session: Any,
tools_service: AgentToolsService,
location_context: dict[str, Any] | None,
response_cache_phase: dict[str, Any] | None,
) -> dict[str, Any]:
coordinates = _request_coordinates(location_context)
if coordinates:
result = await tools_service.find_nearest_stores(coordinates[0], coordinates[1], 5)
else:
filters = _extract_location_filters(query, location_context)
result = await tools_service.find_stores_by_location(
city=filters["city"], state=filters["state"], zip_code=filters["zip_code"]
)
stores = _coerce_dict_rows(result.get("stores"))
sql_phases = _coerce_sql_phases(result.get("sql_phases"))
elapsed_ms = (time.time() - start) * 1000
answer = _format_store_location_answer(stores)
await self._append_display_history(
user_id=user_id, session_id=session.id, query=query, answer=answer, intent_detected=_STORE_LOCATION_INTENT
)
return _final_event(
answer=answer,
session_id=session.id,
response_time_ms=elapsed_ms,
intent_detected=_STORE_LOCATION_INTENT,
search_metrics={"total_ms": round(elapsed_ms), "results_count": len(stores)},
sql_phases=([response_cache_phase] if response_cache_phase else []) + sql_phases,
store_results=stores,
map_actions=_build_map_actions(stores),
location_context=location_context,
)
async def _product_availability_event(
self,
*,
start: float,
query: str,
user_id: str,
session: Any,
tools_service: AgentToolsService,
location_context: dict[str, Any] | None,
response_cache_phase: dict[str, Any] | None,
) -> dict[str, Any]:
coordinates = _request_coordinates(location_context)
product_query = _extract_product_query(query)
if not product_query:
state = dict(getattr(session, "state", None) or {})
last_products = state.get("last_products", [])
product_query = last_products[0] if last_products else query
location_hint = (location_context or {}).get("store_name") if location_context else None
if not location_hint:
filters = _extract_location_filters(query, location_context)
location_hint = " ".join(str(filters[k]) for k in ("city", "state", "zip_code") if filters[k]) or query
target_store = await tools_service.store_service.resolve_store(
location_hint=location_hint, coordinates=coordinates
)
if coordinates:
result = await tools_service.find_stores_with_product(product_query, coordinates[0], coordinates[1])
else:
result = await tools_service.find_stores_with_product(product_query)
inventory = _coerce_dict_rows(result.get("availability"))
sql_phases = _coerce_sql_phases(result.get("sql_phases"))
elapsed_ms = (time.time() - start) * 1000
target_row = None
alternatives = []
if target_store:
for row in inventory:
if _get_field(row, "store_id") == target_store.id:
target_row = row
else:
alternatives.append(row)
if not target_row:
alternatives = inventory
else:
alternatives = inventory
answer = _format_availability_answer(
target=target_row, alternatives=alternatives, target_store_name=target_store.name if target_store else None
)
await self._append_display_history(
user_id=user_id,
session_id=session.id,
query=query,
answer=answer,
intent_detected=_PRODUCT_AVAILABILITY_INTENT,
)
return _final_event(
answer=answer,
session_id=session.id,
response_time_ms=elapsed_ms,
intent_detected=_PRODUCT_AVAILABILITY_INTENT,
search_metrics={
"total_ms": round(elapsed_ms),
"results_count": len(inventory),
"product_query": product_query,
},
sql_phases=([response_cache_phase] if response_cache_phase else []) + sql_phases,
inventory_results=inventory,
map_actions=_build_map_actions(inventory),
location_context=location_context,
)
async def _deterministic_route_event(
self,
*,
intent_detected: str,
start: float,
query: str,
user_id: str,
session: Any,
tools_service: AgentToolsService,
location_context: dict[str, Any] | None,
cache_key: str | None,
response_cache_phase: dict[str, Any] | None,
) -> dict[str, Any] | None:
if intent_detected == _PRODUCT_RAG_INTENT:
return await self._product_rag_event(
start=start,
query=query,
user_id=user_id,
session=session,
cache_key=cache_key,
response_cache_phase=response_cache_phase,
tools_service=tools_service,
location_context=location_context,
)
if intent_detected == _STORE_LOCATION_INTENT:
return await self._store_location_event(
start=start,
query=query,
user_id=user_id,
session=session,
tools_service=tools_service,
location_context=location_context,
response_cache_phase=response_cache_phase,
)
if intent_detected == _PRODUCT_AVAILABILITY_INTENT:
return await self._product_availability_event(
start=start,
query=query,
user_id=user_id,
session=session,
tools_service=tools_service,
location_context=location_context,
response_cache_phase=response_cache_phase,
)
if intent_detected == _ORDER_STATUS_INTENT:
elapsed_ms = (time.time() - start) * 1000
answer = (
"This demo does not include order tracking yet. I can help with menu recommendations, "
"store locations, and product availability."
)
await self._append_display_history(
user_id=user_id, session_id=session.id, query=query, answer=answer, intent_detected=_ORDER_STATUS_INTENT
)
return _final_event(
answer=answer,
session_id=session.id,
response_time_ms=elapsed_ms,
intent_detected=_ORDER_STATUS_INTENT,
search_metrics={"total_ms": round(elapsed_ms)},
sql_phases=[response_cache_phase] if response_cache_phase else [],
location_context=location_context,
)
return None
async def _resolve_rag_store(
self, *, tools_service: AgentToolsService, query: str, location_context: dict[str, Any] | None
) -> Any | None:
coordinates = _request_coordinates(location_context)
location_hint = (location_context or {}).get("store_name") if location_context else None
if not location_hint:
filters = _extract_location_filters(query, location_context)
location_hint = " ".join(str(filters[k]) for k in ("city", "state", "zip_code") if filters[k]) or None
if not location_hint and not coordinates:
return None
return await tools_service.store_service.resolve_store(
location_hint=str(location_hint) if location_hint else None, coordinates=coordinates
)
[docs]
async def stream_request(
self,
query: str,
user_id: str,
session_id: str | None,
persona: str,
tools_service: AgentToolsService,
location_context: dict[str, Any] | None = None,
) -> AsyncIterator[dict[str, Any]]:
"""Stream a chat turn as ADK produces partial events.
Yields:
Delta events followed by one final event with the complete reply payload.
Raises:
AIServiceUnconfigured: If configured credentials are missing or invalid.
ClientError: If the Gemini client fails for a non-credential reason.
ValueError: If ADK raises a non-credential validation error.
"""
start = time.time()
_ensure_vertex_ai_backend_configured()
session = (
await self._session_service.get_session(
app_name=get_settings().chat.session_app_name, user_id=user_id, session_id=session_id
)
if session_id
else None
)
if not session:
session = await self._session_service.create_session(
app_name=get_settings().chat.session_app_name, user_id=user_id, session_id=session_id
)
# Location context is request-scoped to the turn, so store-specific answers skip the response cache.
cache_key: str | None = None
response_cache_phase: dict[str, Any] | None = None
cached_response: dict[str, Any] | None = None
if not _safe_location_context(location_context):
cache_key = tools_service.make_response_cache_key(query, persona)
cache_start = time.time()
cached_response = await tools_service.get_cached_chat_response(cache_key)
response_cache_phase = _response_cache_phase(
cache_key, hit=bool(cached_response), runtime_ms=(time.time() - cache_start) * 1000
)
if cached_response and response_cache_phase:
yield await self._cached_response_event(
start=start,
query=query,
user_id=user_id,
session=session,
cached_response=cached_response,
response_cache_phase=response_cache_phase,
location_context=location_context,
)
return
intent_result = self._classifier.classify(query)
intent_label = await intent_result if isawaitable(intent_result) else intent_result
intent_detected = intent_label.value if isinstance(intent_label, IntentLabel) else str(intent_label)
route_event = await self._deterministic_route_event(
intent_detected=intent_detected,
start=start,
query=query,
user_id=user_id,
session=session,
tools_service=tools_service,
location_context=location_context,
cache_key=cache_key,
response_cache_phase=response_cache_phase,
)
if route_event:
yield route_event
return
async for event in self._general_conversation_event(
start=start,
query=query,
user_id=user_id,
session=session,
persona=persona,
classifier_intent=intent_detected,
tools_service=tools_service,
location_context=location_context,
cache_key=cache_key,
response_cache_phase=response_cache_phase,
):
yield event
async def _general_conversation_event(
self,
*,
start: float,
query: str,
user_id: str,
session: Any,
persona: str,
classifier_intent: str,
tools_service: AgentToolsService,
location_context: dict[str, Any] | None,
cache_key: str | None,
response_cache_phase: dict[str, Any] | None,
) -> AsyncIterator[dict[str, Any]]:
"""Run the ADK fan-out for a general-conversation turn.
Yields:
Delta events followed by one final event. When the model called the vector
tool, the turn is relabeled to PRODUCT_RAG and re-grounded from that lookup.
Raises:
AIServiceUnconfigured: If configured credentials are missing or invalid.
ClientError: If the Gemini client fails for a non-credential reason.
ValueError: If ADK raises a non-credential validation error.
"""
metric_state: dict[str, Any] = {"search_metrics": {}, "embedding_cache_hit": False, "sql_phases": []}
tools = self._make_tool_factories(tools_service, metric_state, location_context)
workflow = self._build_workflow(
self._persona_manager.get_system_prompt(persona, BASE_SYSTEM_INSTRUCTION),
self._persona_manager.get_temperature(persona),
tools,
)
events = Runner(
node=workflow, app_name=get_settings().chat.session_app_name, session_service=self._session_service
).run_async(
user_id=user_id,
session_id=session.id,
new_message=types.Content(role="user", parts=[types.Part(text=query)]),
run_config=RunConfig(streaming_mode=StreamingMode.SSE),
)
answer_parts: list[str] = []
partial_answer_parts: list[str] = []
workflow_output: dict[str, Any] = {}
try:
async for delta in _collect_workflow_stream(
events,
workflow_output=workflow_output,
answer_parts=answer_parts,
partial_answer_parts=partial_answer_parts,
):
yield delta
except (genai_errors.ClientError, ValueError) as exc:
if _is_credential_error(exc):
raise AIServiceUnconfigured(_UNCONFIGURED_MESSAGE) from exc
raise
answer = str(workflow_output.get("answer") or "".join(answer_parts) or "".join(partial_answer_parts))
elapsed_ms = (time.time() - start) * 1000
search_metrics = dict(metric_state.get("search_metrics", {}))
search_metrics["total_ms"] = round(elapsed_ms)
product_sql_phases = _coerce_sql_phases(metric_state.get("sql_phases"))
# When the model actually called the vector tool, _effective_intent relabels this
# turn to PRODUCT_RAG; re-ground the answer/metrics from the recorded lookup.
intent_detected = _effective_intent(
str(workflow_output.get("intent") or classifier_intent or "GENERAL_CONVERSATION"),
search_metrics,
product_sql_phases,
)
if intent_detected == _PRODUCT_RAG_INTENT:
target_store = await self._resolve_rag_store(
tools_service=tools_service, query=query, location_context=location_context
)
answer = await _ground_product_rag_turn(
query, metric_state, tools_service, store_id=target_store.id if target_store else None
)
elapsed_ms = (time.time() - start) * 1000
search_metrics = dict(metric_state.get("search_metrics", {}))
search_metrics["total_ms"] = round(elapsed_ms)
product_sql_phases = _coerce_sql_phases(metric_state.get("sql_phases"))
sql_phases = ([response_cache_phase] if response_cache_phase else []) + product_sql_phases
inventory_payload = _inventory_response_payload(metric_state)
if answer:
if cache_key:
await tools_service.set_cached_chat_response(
cache_key,
{
"answer": answer,
"intent_detected": intent_detected,
"search_metrics": search_metrics,
"embedding_cache_hit": bool(metric_state.get("embedding_cache_hit")),
"sql_phases": product_sql_phases,
"store_results": [],
**inventory_payload,
},
)
await self._append_display_history(
user_id=user_id, session_id=session.id, query=query, answer=answer, intent_detected=intent_detected
)
yield _final_event(
answer=answer,
session_id=session.id,
response_time_ms=elapsed_ms,
intent_detected=intent_detected,
search_metrics=search_metrics,
sql_phases=sql_phases,
embedding_cache_hit=bool(metric_state.get("embedding_cache_hit")),
inventory_results=inventory_payload["inventory_results"],
map_actions=inventory_payload["map_actions"],
location_context=location_context,
)
__all__ = ("ADKRunner", "AgentToolsService", "credential_guard_callback")