Files
2026-05-17 22:55:46 +02:00

463 lines
13 KiB
TypeScript

import * as https from "node:https";
import { URL } from "node:url";
import {
type AssistantMessage,
type AssistantMessageEventStream,
type Context,
type Model,
type SimpleStreamOptions,
type StopReason,
type ToolCall,
calculateCost,
createAssistantMessageEventStream,
parseStreamingJson,
} from "@mariozechner/pi-ai";
import { captureFirstOutput, finalizePiTokenStats, type PiTokenStats } from "../shared/token-stats.js";
import {
AI_SERVER_CHAT_PATH,
AI_SERVER_URL,
getRequestTimeoutMs,
loadCerts,
} from "./config.js";
import {
contextToOpenAIMessages,
toolsToOpenAI,
} from "./messages.js";
/** Plain-text content block that is still being streamed. */
interface PendingTextBlock {
  kind: "text";
  /** Text accumulated so far. */
  text: string;
}

/** Reasoning ("thinking") content block that is still being streamed. */
interface PendingThinkingBlock {
  kind: "thinking";
  /** Reasoning text accumulated so far. */
  thinking: string;
}

/** Tool-call content block whose arguments are still arriving as JSON fragments. */
interface PendingToolCallBlock {
  kind: "toolCall";
  /** Tool-call id; an empty string until the server has supplied one. */
  id: string;
  /** Tool (function) name; an empty string until the server has supplied one. */
  name: string;
  /** Raw, possibly incomplete JSON text of the arguments received so far. */
  partialArgs: string;
}

/**
 * The content block currently being assembled from stream deltas,
 * discriminated by `kind`.
 */
type CurrentBlock =
  | PendingTextBlock
  | PendingThinkingBlock
  | PendingToolCallBlock;
/**
 * Translates an OpenAI-style `finish_reason` into a `StopReason`.
 *
 * `"length"` maps to `"length"` and `"tool_calls"` maps to `"toolUse"`.
 * Every other value — `"stop"`, `null`, `undefined`, or any unrecognised
 * string — maps to `"stop"`.
 *
 * @param fr - The `finish_reason` reported by the server, if any.
 * @returns The corresponding stop reason.
 */
function mapFinishReason(fr: string | null | undefined): StopReason {
  if (fr === "length") return "length";
  if (fr === "tool_calls") return "toolUse";
  return "stop";
}
/**
 * Streams a chat completion from the AI server as assistant-message events.
 *
 * Builds an OpenAI-style streaming chat request from `context`, POSTs it over
 * HTTPS using the client certificate/key returned by `loadCerts()`, and
 * translates the server-sent events of the response into `start`,
 * `thinking_*`, `text_*`, `toolcall_*` and finally `done` events on the
 * returned stream.
 *
 * Failures are not thrown to the caller. They are reported as one `error`
 * event whose reason is `"aborted"` (the abort signal fired) or `"error"`
 * (transport failure, timeout, non-200 status, or an exception while setting
 * up the request), after which the stream is ended.
 *
 * @param model - Model descriptor; supplies the id, default `maxTokens`, the
 *   `reasoning` capability flag and the data used for cost calculation.
 * @param context - Conversation (and optional tools) to send.
 * @param options - Optional temperature, max tokens, reasoning level, abort
 *   signal and `toolChoice`.
 * @returns The event stream, returned before any network activity completes.
 */
export function streamAiServer(
  model: Model<any>,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const stream = createAssistantMessageEventStream();
  (async () => {
    const tokenStats: PiTokenStats = {
      requestStartMs: Date.now(),
    };
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    let currentBlock: CurrentBlock | null = null;
    let streamEnded = false;
    let thinkingTokens: number | undefined;
    // Replaced with a real remover once the abort listener has been attached,
    // so every terminal path can release the listener unconditionally.
    let detachAbortListener = (): void => {};

    // Terminates the stream with a single error event; later calls are no-ops.
    const endWithError = (reason: "error" | "aborted", message: string) => {
      if (streamEnded) return;
      streamEnded = true;
      detachAbortListener();
      output.stopReason = reason;
      output.errorMessage = message;
      stream.push({ type: "error", reason, error: output });
      stream.end();
    };

    // The block being assembled is always the last entry of output.content.
    const blockIndex = () => output.content.length - 1;

    // Emits the matching *_end event for the block in progress (if any) and
    // clears it. Tool-call arguments get a final parse of the collected JSON.
    const finishCurrentBlock = () => {
      if (!currentBlock) return;
      const idx = blockIndex();
      if (currentBlock.kind === "text") {
        stream.push({
          type: "text_end",
          contentIndex: idx,
          content: currentBlock.text,
          partial: output,
        });
      } else if (currentBlock.kind === "thinking") {
        stream.push({
          type: "thinking_end",
          contentIndex: idx,
          content: currentBlock.thinking,
          partial: output,
        });
      } else {
        const block = output.content[idx] as ToolCall;
        block.arguments = parseStreamingJson(currentBlock.partialArgs) ?? {};
        stream.push({
          type: "toolcall_end",
          contentIndex: idx,
          toolCall: block,
          partial: output,
        });
      }
      currentBlock = null;
    };

    try {
      stream.push({ type: "start", partial: output });

      // Bail out before creating any request when the caller has already
      // aborted. Destroying a ClientRequest that has no 'error' listener yet
      // would surface as an unhandled "socket hang up" error event.
      const signal = options?.signal;
      if (signal?.aborted) {
        endWithError("aborted", "Request aborted");
        return;
      }

      const bodyObj: Record<string, unknown> = {
        model: model.id,
        messages: contextToOpenAIMessages(context),
        temperature: options?.temperature ?? 0.7,
        max_tokens: options?.maxTokens ?? model.maxTokens,
        stream: true,
        stream_options: { include_usage: true },
      };

      // Reasoning / thinking-level forwarding. pi-mono passes
      // `options.reasoning` (a ThinkingLevel: minimal|low|medium|high|xhigh)
      // from `defaultThinkingLevel` in ~/.pi/agent/settings.json. Forward
      // it to llama.cpp two ways simultaneously so both qwen-style chat
      // templates and openai-style providers see the directive:
      //   • chat_template_kwargs.enable_thinking — Qwen3 / MiMo / Devstral
      //     family templates respect this boolean.
      //   • reasoning_effort — passed through as a chat-template kwarg by
      //     llama-server; the few templates that read it (gpt-oss, MiniMax)
      //     get full granularity.
      // Skip entirely for non-reasoning models so we don't poison their
      // chat templates with kwargs they don't understand.
      const reasoning = options?.reasoning;
      if (reasoning && model.reasoning) {
        bodyObj.chat_template_kwargs = {
          enable_thinking: reasoning !== "minimal",
        };
        bodyObj.reasoning_effort = reasoning;
      }

      const openaiTools = toolsToOpenAI(context.tools);
      if (openaiTools) {
        bodyObj.tools = openaiTools;
        const toolChoice = (options as Record<string, unknown> | undefined)
          ?.toolChoice;
        if (toolChoice !== undefined) bodyObj.tool_choice = toolChoice;
      }

      const body = JSON.stringify(bodyObj);
      const certs = loadCerts();
      const url = new URL(AI_SERVER_URL + AI_SERVER_CHAT_PATH);
      const requestTimeoutMs = getRequestTimeoutMs();

      // No `ca:` — server cert is publicly-trusted (Let's Encrypt), so
      // rely on Node's default Mozilla CA bundle. mTLS client auth still
      // requires `cert` + `key`.
      const req = https.request({
        hostname: url.hostname,
        port: url.port ? Number(url.port) : 443,
        path: url.pathname + url.search,
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          "Content-Length": Buffer.byteLength(body),
          Accept: "text/event-stream",
        },
        cert: certs.cert,
        key: certs.key,
        timeout: requestTimeoutMs,
      });

      // Attach the error handler first so that no later destroy() (abort or
      // timeout) can ever emit an unhandled 'error' event.
      req.on("error", (err) => {
        endWithError("error", err.message);
      });

      // TCP keepalive: kernel sends probes every 30s of idle. Stops NAT /
      // stateful firewalls on the LAN path from silently dropping the flow
      // during long prefills (when llama.cpp emits no SSE bytes yet) and
      // surfaces real drops fast instead of after the kernel retransmit
      // deadline (~15min).
      req.on("socket", (socket) => {
        socket.setKeepAlive(true, 30_000);
      });

      const onAbort = () => {
        req.destroy();
        endWithError("aborted", "Request aborted");
      };
      if (signal) {
        signal.addEventListener("abort", onAbort, { once: true });
        // Released on every terminal path so a long-lived signal does not
        // keep this request (and its output) reachable after completion.
        detachAbortListener = () => {
          signal.removeEventListener("abort", onAbort);
        };
      }

      req.on("timeout", () => {
        req.destroy(
          new Error(`Request timed out after ${requestTimeoutMs}ms`),
        );
      });

      req.on("response", (res) => {
        // Registered before the status check so a response that breaks off
        // mid-body terminates the stream on the error path as well as on
        // the success path.
        res.on("error", (err) => {
          endWithError("error", err.message);
        });

        if (res.statusCode !== 200) {
          // Collect the error body and report its first 500 characters.
          let errBody = "";
          res.setEncoding("utf-8");
          res.on("data", (chunk: string) => {
            errBody += chunk;
          });
          res.on("end", () => {
            endWithError(
              "error",
              `HTTP ${res.statusCode}: ${errBody.slice(0, 500)}`,
            );
          });
          return;
        }

        const decoder = new TextDecoder();
        let buffer = "";
        // SSE events are separated by a blank line; accept LF and CRLF.
        const eventSeparator = /\r?\n\r?\n/;

        // Handles one SSE event: every `data:` line is parsed as a JSON chunk
        // and folded into `output`, emitting the corresponding stream events.
        const processEventBlock = (eventBlock: string) => {
          for (const rawLine of eventBlock.split("\n")) {
            const line = rawLine.trim();
            if (!line.startsWith("data:")) continue;
            const payload = line.slice(5).trim();
            if (!payload || payload === "[DONE]") continue;
            let data: any;
            try {
              data = JSON.parse(payload);
            } catch {
              // Best effort: skip payloads that are not valid JSON.
              continue;
            }

            if (data.id && !output.responseId) output.responseId = data.id;

            if (data.usage) {
              const reportedThinkingTokens =
                data.usage.completion_tokens_details?.reasoning_tokens;
              if (
                typeof reportedThinkingTokens === "number" &&
                Number.isFinite(reportedThinkingTokens)
              ) {
                thinkingTokens = reportedThinkingTokens;
              }
              output.usage.input =
                data.usage.prompt_tokens ?? output.usage.input;
              output.usage.output =
                data.usage.completion_tokens ?? output.usage.output;
              // llama.cpp reports cached prompt tokens under
              // prompt_tokens_details.cached_tokens. This is a subset
              // of prompt_tokens, so it does NOT add to totalTokens
              // (input already counts them).
              output.usage.cacheRead =
                data.usage.prompt_tokens_details?.cached_tokens ??
                output.usage.cacheRead;
              output.usage.totalTokens =
                output.usage.input + output.usage.output;
            }

            const choice = data.choices?.[0];
            if (!choice) continue;
            if (choice.finish_reason) {
              output.stopReason = mapFinishReason(choice.finish_reason);
            }
            const delta = choice.delta;
            if (!delta) continue;

            // ── Reasoning / thinking (llama.cpp uses reasoning_content) ──
            const reasoningDelta: string | undefined =
              delta.reasoning_content ?? delta.reasoning;
            if (reasoningDelta) {
              captureFirstOutput(tokenStats, Date.now());
              if (!currentBlock || currentBlock.kind !== "thinking") {
                finishCurrentBlock();
                currentBlock = { kind: "thinking", thinking: "" };
                output.content.push({ type: "thinking", thinking: "" });
                stream.push({
                  type: "thinking_start",
                  contentIndex: blockIndex(),
                  partial: output,
                });
              }
              currentBlock.thinking += reasoningDelta;
              (output.content[blockIndex()] as { thinking: string }).thinking =
                currentBlock.thinking;
              stream.push({
                type: "thinking_delta",
                contentIndex: blockIndex(),
                delta: reasoningDelta,
                partial: output,
              });
            }

            // ── Text ──
            if (typeof delta.content === "string" && delta.content.length > 0) {
              captureFirstOutput(tokenStats, Date.now());
              if (!currentBlock || currentBlock.kind !== "text") {
                finishCurrentBlock();
                currentBlock = { kind: "text", text: "" };
                output.content.push({ type: "text", text: "" });
                stream.push({
                  type: "text_start",
                  contentIndex: blockIndex(),
                  partial: output,
                });
              }
              currentBlock.text += delta.content;
              (output.content[blockIndex()] as { text: string }).text =
                currentBlock.text;
              stream.push({
                type: "text_delta",
                contentIndex: blockIndex(),
                delta: delta.content,
                partial: output,
              });
            }

            // ── Tool calls ──
            if (Array.isArray(delta.tool_calls)) {
              for (const tc of delta.tool_calls) {
                captureFirstOutput(tokenStats, Date.now());
                const tcId: string | undefined = tc.id;
                const tcName: string | undefined = tc.function?.name;
                const tcArgs: string | undefined = tc.function?.arguments;
                // A fragment continues the open tool call unless both sides
                // carry an id and the ids differ.
                const sameBlock =
                  currentBlock?.kind === "toolCall" &&
                  (!tcId || !currentBlock.id || currentBlock.id === tcId);
                if (!sameBlock) {
                  finishCurrentBlock();
                  currentBlock = {
                    kind: "toolCall",
                    id: tcId ?? "",
                    name: tcName ?? "",
                    partialArgs: "",
                  };
                  const block: ToolCall = {
                    type: "toolCall",
                    id: currentBlock.id,
                    name: currentBlock.name,
                    arguments: {},
                  };
                  output.content.push(block);
                  stream.push({
                    type: "toolcall_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
                if (currentBlock?.kind === "toolCall") {
                  const block = output.content[blockIndex()] as ToolCall;
                  // Fill in id / name if they only arrive in a later fragment.
                  if (tcId && !currentBlock.id) {
                    currentBlock.id = tcId;
                    block.id = tcId;
                  }
                  if (tcName && !currentBlock.name) {
                    currentBlock.name = tcName;
                    block.name = tcName;
                  }
                  if (tcArgs) {
                    currentBlock.partialArgs += tcArgs;
                    const parsed = parseStreamingJson(currentBlock.partialArgs);
                    if (parsed && typeof parsed === "object") {
                      block.arguments = parsed as Record<string, unknown>;
                    }
                    stream.push({
                      type: "toolcall_delta",
                      contentIndex: blockIndex(),
                      delta: tcArgs,
                      partial: output,
                    });
                  }
                }
              }
            }
          }
        };

        res.on("data", (chunk: Buffer) => {
          // Nothing may be pushed once the stream has been ended.
          if (streamEnded) return;
          buffer += decoder.decode(chunk, { stream: true });
          let separator: RegExpExecArray | null;
          while ((separator = eventSeparator.exec(buffer)) !== null) {
            const eventBlock = buffer.slice(0, separator.index);
            buffer = buffer.slice(separator.index + separator[0].length);
            processEventBlock(eventBlock);
          }
        });

        res.on("end", () => {
          if (streamEnded) return;
          // Flush the decoder and process a trailing event that was not
          // followed by a blank line.
          buffer += decoder.decode();
          if (buffer.trim()) processEventBlock(buffer);
          buffer = "";
          finishCurrentBlock();
          calculateCost(model, output.usage);
          (output as AssistantMessage & { piTokenStats?: PiTokenStats }).piTokenStats = finalizePiTokenStats(
            tokenStats,
            {
              input: output.usage.input,
              output: output.usage.output,
              thinking: thinkingTokens,
            },
            Date.now(),
          );
          if (options?.signal?.aborted) {
            endWithError("aborted", "Request aborted");
            return;
          }
          streamEnded = true;
          detachAbortListener();
          stream.push({
            type: "done",
            reason: output.stopReason as "stop" | "length" | "toolUse",
            message: output,
          });
          stream.end();
        });
      });

      req.write(body);
      req.end();
    } catch (error) {
      endWithError(
        options?.signal?.aborted ? "aborted" : "error",
        error instanceof Error ? error.message : String(error),
      );
    }
  })();
  return stream;
}