diff --git a/src/agents/claude_client.py b/src/agents/claude_client.py index 2d8d21c..6dd6c01 100644 --- a/src/agents/claude_client.py +++ b/src/agents/claude_client.py @@ -1,10 +1,15 @@ """Shared Claude CLI Client mit Usage-Tracking.""" import asyncio +import contextvars import json import logging from dataclasses import dataclass from config import CLAUDE_PATH, CLAUDE_TIMEOUT, CLAUDE_MODEL_FAST +# ContextVar fuer Cancel-Event: Wird vom Orchestrator gesetzt, +# call_claude prueft automatisch darauf -- kein Durchreichen noetig. +_cancel_event_var: contextvars.ContextVar[asyncio.Event | None] = contextvars.ContextVar("_cancel_event_var", default=None) + logger = logging.getLogger("osint.claude_client") @@ -78,9 +83,37 @@ async def call_claude(prompt: str, tools: str | None = "WebSearch,WebFetch", mod }, ) try: - stdout, stderr = await asyncio.wait_for( - process.communicate(input=prompt.encode("utf-8")), timeout=CLAUDE_TIMEOUT - ) + cancel_event = _cancel_event_var.get(None) + if cancel_event: + # Cancel-aware: Monitor cancel_event while process runs + communicate_task = asyncio.create_task( + process.communicate(input=prompt.encode("utf-8")) + ) + cancel_wait_task = asyncio.create_task(cancel_event.wait()) + timeout_task = asyncio.create_task(asyncio.sleep(CLAUDE_TIMEOUT)) + + done, pending = await asyncio.wait( + [communicate_task, cancel_wait_task, timeout_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + for p in pending: + p.cancel() + + if communicate_task in done: + stdout, stderr = communicate_task.result() + elif cancel_wait_task in done: + process.kill() + await process.wait() + raise asyncio.CancelledError("Cancel angefordert") + else: + process.kill() + await process.wait() + raise TimeoutError(f"Claude CLI Timeout nach {CLAUDE_TIMEOUT}s") + else: + stdout, stderr = await asyncio.wait_for( + process.communicate(input=prompt.encode("utf-8")), timeout=CLAUDE_TIMEOUT + ) except asyncio.TimeoutError: process.kill() raise TimeoutError(f"Claude CLI Timeout nach {CLAUDE_TIMEOUT}s") diff --git a/src/agents/orchestrator.py b/src/agents/orchestrator.py index a1c3b6b..74e4f8e 100644 --- a/src/agents/orchestrator.py +++ b/src/agents/orchestrator.py @@ -10,7 +10,7 @@ from urllib.parse import urlparse, urlunparse, quote_plus import httpx -from agents.claude_client import UsageAccumulator +from agents.claude_client import UsageAccumulator, _cancel_event_var from agents.factchecker import find_matching_claim, deduplicate_new_facts, TWOPHASE_MIN_FACTS from source_rules import ( _detect_category, @@ -398,6 +398,7 @@ class AgentOrchestrator: self._ws_manager = None self._queued_ids: set[int] = set() self._cancel_requested: set[int] = set() + self._cancel_event: asyncio.Event | None = None def set_ws_manager(self, ws_manager): """WebSocket-Manager setzen für Echtzeit-Updates.""" @@ -441,6 +442,8 @@ class AgentOrchestrator: # Check if it's the currently running task if self._current_task == incident_id: self._cancel_requested.add(incident_id) + if self._cancel_event: + self._cancel_event.set() logger.info(f"Cancel angefordert fuer laufende Lage {incident_id}") if self._ws_manager: try: @@ -512,6 +515,8 @@ class AgentOrchestrator: user_id = None self._queued_ids.discard(incident_id) self._current_task = incident_id + self._cancel_event = asyncio.Event() + _cancel_event_var.set(self._cancel_event) logger.info(f"Starte Refresh für Lage {incident_id} (Trigger: {trigger_type})") RETRY_DELAYS = [0, 120, 300] # Sekunden: sofort, 2min, 5min @@ -585,6 +590,8 @@ class AgentOrchestrator: }, _vis, _cb, _tid) finally: self._current_task = None + self._cancel_event = None + _cancel_event_var.set(None) self._queue.task_done() async def _mark_refresh_cancelled(self, incident_id: int):