diff --git a/openkb/agent/chat.py b/openkb/agent/chat.py index 42ac9f9..6a36d66 100644 --- a/openkb/agent/chat.py +++ b/openkb/agent/chat.py @@ -189,7 +189,41 @@ def _make_prompt_session(session: ChatSession, style: Style, use_color: bool) -> ) -async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: Style) -> None: +def _make_rich_console() -> Any: + """Create a Rich Console with a Claude-Code-like Markdown theme.""" + from rich.console import Console + from rich.theme import Theme + + theme = Theme({ + # Headings: bold with blue tint + "markdown.h1": "bold #5fa0e0", + "markdown.h2": "bold #5fa0e0", + "markdown.h3": "bold #7ab0e8", + "markdown.h4": "bold #8abae0", + # Code + "markdown.code": "#e8c87a on #1e1e1e", + # Links + "markdown.link": "underline #5fa0e0", + "markdown.link_url": "#5fa0e0", + # Emphasis + "markdown.bold": "bold #e0e0e0", + "markdown.italic": "italic #c0c0c0", + # Lists and block quotes + "markdown.item.bullet": "#6ac0a0", + "markdown.item.number": "#6ac0a0", + "markdown.block_quote": "italic #8a8a8a", + # Horizontal rule + "markdown.hr": "#4a4a4a", + # Paragraphs — ensure normal text is visible + "markdown.paragraph": "#d0d0d0", + }) + return Console(theme=theme) + + +async def _run_turn( + agent: Any, session: ChatSession, user_input: str, style: Style, + *, use_color: bool = True, +) -> None: """Run one agent turn with streaming output and persist the new history.""" from agents import ( RawResponsesStreamEvent, @@ -202,11 +236,29 @@ async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: St result = Runner.run_streamed(agent, new_input, max_turns=MAX_TURNS) - sys.stdout.write("\n") - sys.stdout.flush() + print() collected: list[str] = [] last_was_text = False need_blank_before_text = False + + if use_color: + from rich.console import Console + from rich.live import Live + from rich.markdown import Markdown + + console = _make_rich_console() + else: + console = None # type: ignore[assignment] + + def _start_live() -> Any: + if console is None: + return None + lv = Live(console=console, vertical_overflow="visible") + lv.start() + return lv + + live = _start_live() + try: async for event in result.stream_events(): if isinstance(event, RawResponsesStreamEvent): @@ -214,27 +266,45 @@ async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: St text = event.data.delta if text: if need_blank_before_text: - sys.stdout.write("\n") + if live: + live.stop() + live = None + print() + live = _start_live() + else: + sys.stdout.write("\n") need_blank_before_text = False - sys.stdout.write(text) - sys.stdout.flush() collected.append(text) last_was_text = True + if live: + live.update(Markdown("".join(collected), code_theme="monokai")) + else: + sys.stdout.write(text) + sys.stdout.flush() elif isinstance(event, RunItemStreamEvent): item = event.item if item.type == "tool_call_item": if last_was_text: - sys.stdout.write("\n") - sys.stdout.flush() + if live: + live.stop() + live = None + else: + sys.stdout.write("\n") + sys.stdout.flush() last_was_text = False raw = item.raw_item name = getattr(raw, "name", "?") args = getattr(raw, "arguments", "") or "" + if live: + live.stop() + live = None _fmt(style, ("class:tool", _format_tool_line(name, args) + "\n")) + live = _start_live() need_blank_before_text = True finally: - sys.stdout.write("\n\n") - sys.stdout.flush() + if live: + live.stop() + print() answer = "".join(collected).strip() if not answer: @@ -371,7 +441,7 @@ async def run_chat( append_log(kb_dir / "wiki", "query", user_input) try: - await _run_turn(agent, session, user_input, style) + await _run_turn(agent, session, user_input, style, use_color=use_color) except KeyboardInterrupt: _fmt(style, ("class:error", "\n[aborted]\n")) except Exception as exc: diff --git a/openkb/agent/query.py b/openkb/agent/query.py index 39e0e40..a7f2053 100644 --- a/openkb/agent/query.py +++ b/openkb/agent/query.py @@ -120,25 +120,55 @@ async def run_query(question: str, kb_dir: Path, model: str, stream: bool = Fals result = await Runner.run(agent, question, max_turns=MAX_TURNS) return result.final_output or "" + import os + use_color = sys.stdout.isatty() and not os.environ.get("NO_COLOR", "") + + if use_color: + from rich.live import Live + from rich.markdown import Markdown + from openkb.agent.chat import _make_rich_console + console = _make_rich_console() + else: + console = None # type: ignore[assignment] + + def _start_live() -> Live | None: + if console is None: + return None + lv = Live(console=console, vertical_overflow="visible") + lv.start() + return lv + + live = _start_live() + result = Runner.run_streamed(agent, question, max_turns=MAX_TURNS) - collected = [] - async for event in result.stream_events(): - if isinstance(event, RawResponsesStreamEvent): - if isinstance(event.data, ResponseTextDeltaEvent): - text = event.data.delta - if text: - sys.stdout.write(text) + collected: list[str] = [] + try: + async for event in result.stream_events(): + if isinstance(event, RawResponsesStreamEvent): + if isinstance(event.data, ResponseTextDeltaEvent): + text = event.data.delta + if text: + collected.append(text) + if live: + live.update(Markdown("".join(collected), code_theme="monokai")) + else: + sys.stdout.write(text) + sys.stdout.flush() + elif isinstance(event, RunItemStreamEvent): + item = event.item + if item.type == "tool_call_item": + raw = item.raw_item + args = getattr(raw, "arguments", "{}") + if live: + live.stop() + live = None + sys.stdout.write(f"\n[tool call] {raw.name}({args})\n\n") sys.stdout.flush() - collected.append(text) - elif isinstance(event, RunItemStreamEvent): - item = event.item - if item.type == "tool_call_item": - raw = item.raw_item - args = getattr(raw, "arguments", "{}") - sys.stdout.write(f"\n[tool call] {raw.name}({args})\n\n") - sys.stdout.flush() - elif item.type == "tool_call_output_item": - pass - sys.stdout.write("\n") - sys.stdout.flush() + live = _start_live() + elif item.type == "tool_call_output_item": + pass + finally: + if live: + live.stop() + print() return "".join(collected) if collected else result.final_output or "" diff --git a/pyproject.toml b/pyproject.toml index 4af87be..e368a97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "python-dotenv", "json-repair", "prompt_toolkit>=3.0", + "rich>=13.0", ] [project.urls]