463 lines
13 KiB
TypeScript
463 lines
13 KiB
TypeScript
import * as https from "node:https";
|
|
import { URL } from "node:url";
|
|
import {
|
|
type AssistantMessage,
|
|
type AssistantMessageEventStream,
|
|
type Context,
|
|
type Model,
|
|
type SimpleStreamOptions,
|
|
type StopReason,
|
|
type ToolCall,
|
|
calculateCost,
|
|
createAssistantMessageEventStream,
|
|
parseStreamingJson,
|
|
} from "@mariozechner/pi-ai";
|
|
import { captureFirstOutput, finalizePiTokenStats, type PiTokenStats } from "../shared/token-stats.js";
|
|
import {
|
|
AI_SERVER_CHAT_PATH,
|
|
AI_SERVER_URL,
|
|
getRequestTimeoutMs,
|
|
loadCerts,
|
|
} from "./config.js";
|
|
import {
|
|
contextToOpenAIMessages,
|
|
toolsToOpenAI,
|
|
} from "./messages.js";
|
|
|
|
/**
 * The content block that is currently open in the streamed assistant message.
 *
 * Each variant holds what has been accumulated for that block so far. The
 * `toolCall` variant also keeps the raw, possibly incomplete, argument JSON
 * in `partialArgs` until the block is closed and parsed.
 */
type CurrentBlock =
  | { kind: "text"; text: string }
  | { kind: "thinking"; thinking: string }
  | { kind: "toolCall"; id: string; name: string; partialArgs: string };
|
|
|
|
/**
 * Translates an OpenAI-style `finish_reason` into a pi-ai `StopReason`.
 *
 * Only `"length"` and `"tool_calls"` map to something other than `"stop"`.
 * Every other value is reported as a normal stop: `"stop"` itself, a missing
 * reason (`null`/`undefined`), and any unrecognized string.
 */
function mapFinishReason(fr: string | null | undefined): StopReason {
  if (fr === "length") return "length";
  if (fr === "tool_calls") return "toolUse";
  return "stop";
}
|
|
|
|
/**
 * Streams a chat completion from the AI server as pi-ai assistant events.
 *
 * Sends an OpenAI-compatible streaming chat request (with usage reporting)
 * to `AI_SERVER_URL + AI_SERVER_CHAT_PATH` over HTTPS. It authenticates with
 * the client certificate and key returned by `loadCerts()`. Server-sent-event
 * chunks are translated into pi-ai `text_*`, `thinking_*` and `toolcall_*`
 * events. The stream then finishes with exactly one `done` event or exactly
 * one `error` event (reason `"error"` or `"aborted"`).
 *
 * @param model   Target model. Its `id` is sent as the request model,
 *                `maxTokens` is the fallback `max_tokens`, and `reasoning`
 *                gates the thinking-level kwargs.
 * @param context Conversation messages and tools, converted to OpenAI format.
 * @param options Optional temperature, token limit, reasoning level, abort
 *                signal and (read via an untyped lookup) `toolChoice`.
 * @returns The event stream. It is filled asynchronously after this returns.
 */
export function streamAiServer(
  model: Model<any>,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const stream = createAssistantMessageEventStream();

  (async () => {
    const tokenStats: PiTokenStats = {
      requestStartMs: Date.now(),
    };
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    let currentBlock: CurrentBlock | null = null;
    // OpenAI `index` of the tool call held in `currentBlock`, if the server sent one.
    let currentToolIndex: number | undefined;
    let streamEnded = false;
    let thinkingTokens: number | undefined;
    // Removes our abort listener from `options.signal` once it is registered.
    // Called on every terminal path, so a reused, long-lived signal does not
    // accumulate one listener per request.
    let detachAbort = (): void => {};

    // Terminates the stream with an error event. Only the first terminal
    // event is emitted; later calls are no-ops.
    const endWithError = (reason: "error" | "aborted", message: string) => {
      if (streamEnded) return;
      streamEnded = true;
      detachAbort();
      output.stopReason = reason;
      output.errorMessage = message;
      stream.push({ type: "error", reason, error: output });
      stream.end();
    };

    // Index of the last (i.e. currently open) content block.
    const blockIndex = () => output.content.length - 1;

    // Emits the matching *_end event for the open block and clears it. Tool
    // calls get a final parse of their accumulated argument JSON first.
    const finishCurrentBlock = () => {
      if (!currentBlock) return;
      const idx = blockIndex();
      if (currentBlock.kind === "text") {
        stream.push({
          type: "text_end",
          contentIndex: idx,
          content: currentBlock.text,
          partial: output,
        });
      } else if (currentBlock.kind === "thinking") {
        stream.push({
          type: "thinking_end",
          contentIndex: idx,
          content: currentBlock.thinking,
          partial: output,
        });
      } else {
        const block = output.content[idx] as ToolCall;
        block.arguments = parseStreamingJson(currentBlock.partialArgs) ?? {};
        stream.push({
          type: "toolcall_end",
          contentIndex: idx,
          toolCall: block,
          partial: output,
        });
      }
      currentBlock = null;
    };

    try {
      stream.push({ type: "start", partial: output });

      const bodyObj: Record<string, unknown> = {
        model: model.id,
        messages: contextToOpenAIMessages(context),
        temperature: options?.temperature ?? 0.7,
        max_tokens: options?.maxTokens ?? model.maxTokens,
        stream: true,
        stream_options: { include_usage: true },
      };

      // Reasoning / thinking-level forwarding. pi-mono passes
      // `options.reasoning` (a ThinkingLevel: minimal|low|medium|high|xhigh)
      // from `defaultThinkingLevel` in ~/.pi/agent/settings.json. It is
      // forwarded to llama.cpp in two ways at once, so that both Qwen-style
      // chat templates and OpenAI-style providers see the directive:
      //   • chat_template_kwargs.enable_thinking: Qwen3 / MiMo / Devstral
      //     family templates respect this boolean.
      //   • reasoning_effort: passed through as a chat-template kwarg by
      //     llama-server. The few templates that read it (gpt-oss, MiniMax)
      //     get full granularity.
      // Skipped entirely for non-reasoning models, so that their chat
      // templates are not given kwargs they don't understand.
      const reasoningLevel = options?.reasoning;
      if (reasoningLevel && model.reasoning) {
        bodyObj.chat_template_kwargs = {
          enable_thinking: reasoningLevel !== "minimal",
        };
        bodyObj.reasoning_effort = reasoningLevel;
      }

      const openaiTools = toolsToOpenAI(context.tools);
      if (openaiTools) {
        bodyObj.tools = openaiTools;
        const toolChoice = (options as Record<string, unknown> | undefined)
          ?.toolChoice;
        if (toolChoice !== undefined) bodyObj.tool_choice = toolChoice;
      }

      const body = JSON.stringify(bodyObj);
      const certs = loadCerts();
      const url = new URL(AI_SERVER_URL + AI_SERVER_CHAT_PATH);
      const requestTimeoutMs = getRequestTimeoutMs();

      // No `ca:`. The server cert is publicly trusted (Let's Encrypt), so we
      // rely on Node's default Mozilla CA bundle. mTLS client auth still
      // requires `cert` + `key`.
      const req = https.request({
        hostname: url.hostname,
        port: url.port ? Number(url.port) : 443,
        path: url.pathname + url.search,
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          "Content-Length": Buffer.byteLength(body),
          Accept: "text/event-stream",
        },
        cert: certs.cert,
        key: certs.key,
        timeout: requestTimeoutMs,
      });

      // TCP keepalive: the kernel sends probes after 30s of idle. This stops
      // NAT or stateful firewalls on the LAN path from silently dropping the
      // flow during long prefills (when llama.cpp emits no SSE bytes yet). It
      // also surfaces real drops quickly instead of after the kernel's
      // retransmit deadline (~15min).
      req.on("socket", (socket) => {
        socket.setKeepAlive(true, 30_000);
      });

      const onAbort = () => {
        req.destroy();
        endWithError("aborted", "Request aborted");
      };
      const signal = options?.signal;
      if (signal) {
        if (signal.aborted) {
          onAbort();
          return;
        }
        signal.addEventListener("abort", onAbort, { once: true });
        detachAbort = () => signal.removeEventListener("abort", onAbort);
      }

      req.on("timeout", () => {
        req.destroy(
          new Error(`Request timed out after ${requestTimeoutMs}ms`),
        );
      });

      req.on("error", (err) => {
        endWithError("error", err.message);
      });

      req.on("response", (res) => {
        // These handlers are attached before branching on status, so they also
        // cover the HTTP-error path. Node only emits a response's
        // premature-close 'error' when an 'error' listener exists. Without
        // them, a connection dropped mid-body fires neither 'end' nor a handled
        // error, and the consumer waits forever.
        res.on("error", (err) => {
          endWithError("error", err.message);
        });
        res.on("close", () => {
          if (!res.complete) {
            endWithError("error", "Connection closed before response completed");
          }
        });

        if (res.statusCode !== 200) {
          let errBody = "";
          res.setEncoding("utf-8");
          res.on("data", (chunk: string) => {
            // Only the first 500 characters are reported; don't buffer more.
            if (errBody.length < 500) errBody += chunk;
          });
          res.on("end", () => {
            endWithError(
              "error",
              `HTTP ${res.statusCode}: ${errBody.slice(0, 500)}`,
            );
          });
          return;
        }

        const decoder = new TextDecoder();
        let buffer = "";

        // Handles one SSE event: every `data:` line carries one JSON chunk.
        const processEventBlock = (eventBlock: string) => {
          // Ignore anything that arrives after the stream has been terminated.
          if (streamEnded) return;
          for (const rawLine of eventBlock.split("\n")) {
            const line = rawLine.trim();
            if (!line.startsWith("data:")) continue;
            const payload = line.slice(5).trim();
            if (!payload || payload === "[DONE]") continue;

            // eslint-disable-next-line @typescript-eslint/no-explicit-any -- loosely shaped OpenAI chunk
            let data: any;
            try {
              data = JSON.parse(payload);
            } catch {
              // Skip malformed lines rather than failing the whole stream.
              continue;
            }

            if (data.id && !output.responseId) output.responseId = data.id;

            if (data.usage) {
              const reportedThinkingTokens =
                data.usage.completion_tokens_details?.reasoning_tokens;
              if (
                typeof reportedThinkingTokens === "number"
                && Number.isFinite(reportedThinkingTokens)
              ) {
                thinkingTokens = reportedThinkingTokens;
              }
              output.usage.input =
                data.usage.prompt_tokens ?? output.usage.input;
              output.usage.output =
                data.usage.completion_tokens ?? output.usage.output;
              // llama.cpp reports cached prompt tokens under
              // prompt_tokens_details.cached_tokens. This is a subset of
              // prompt_tokens, so it does NOT add to totalTokens (input
              // already counts them).
              output.usage.cacheRead =
                data.usage.prompt_tokens_details?.cached_tokens ??
                output.usage.cacheRead;
              output.usage.totalTokens =
                output.usage.input + output.usage.output;
            }

            const choice = data.choices?.[0];
            if (!choice) continue;

            if (choice.finish_reason) {
              output.stopReason = mapFinishReason(choice.finish_reason);
            }

            const delta = choice.delta;
            if (!delta) continue;

            // ── Reasoning / thinking (llama.cpp uses reasoning_content) ──
            const reasoningDelta: string | undefined =
              delta.reasoning_content ?? delta.reasoning;
            if (reasoningDelta) {
              captureFirstOutput(tokenStats, Date.now());
              if (!currentBlock || currentBlock.kind !== "thinking") {
                finishCurrentBlock();
                currentBlock = { kind: "thinking", thinking: "" };
                output.content.push({ type: "thinking", thinking: "" });
                stream.push({
                  type: "thinking_start",
                  contentIndex: blockIndex(),
                  partial: output,
                });
              }
              currentBlock.thinking += reasoningDelta;
              (output.content[blockIndex()] as { thinking: string }).thinking =
                currentBlock.thinking;
              stream.push({
                type: "thinking_delta",
                contentIndex: blockIndex(),
                delta: reasoningDelta,
                partial: output,
              });
            }

            // ── Text ──
            if (typeof delta.content === "string" && delta.content.length > 0) {
              captureFirstOutput(tokenStats, Date.now());
              if (!currentBlock || currentBlock.kind !== "text") {
                finishCurrentBlock();
                currentBlock = { kind: "text", text: "" };
                output.content.push({ type: "text", text: "" });
                stream.push({
                  type: "text_start",
                  contentIndex: blockIndex(),
                  partial: output,
                });
              }
              currentBlock.text += delta.content;
              (output.content[blockIndex()] as { text: string }).text =
                currentBlock.text;
              stream.push({
                type: "text_delta",
                contentIndex: blockIndex(),
                delta: delta.content,
                partial: output,
              });
            }

            // ── Tool calls ──
            if (Array.isArray(delta.tool_calls)) {
              for (const tc of delta.tool_calls) {
                captureFirstOutput(tokenStats, Date.now());
                const tcId: string | undefined = tc.id;
                const tcName: string | undefined = tc.function?.name;
                const tcArgs: string | undefined = tc.function?.arguments;
                const tcIndex: number | undefined =
                  typeof tc.index === "number" ? tc.index : undefined;

                // A delta continues the open tool call unless its id or its
                // index conflicts with the one already recorded. Checking the
                // index as well keeps back-to-back calls apart when the
                // server omits ids.
                let sameBlock = false;
                if (currentBlock?.kind === "toolCall") {
                  const idConflicts =
                    !!tcId && !!currentBlock.id && currentBlock.id !== tcId;
                  const indexConflicts =
                    tcIndex !== undefined &&
                    currentToolIndex !== undefined &&
                    currentToolIndex !== tcIndex;
                  sameBlock = !idConflicts && !indexConflicts;
                }

                if (!sameBlock) {
                  finishCurrentBlock();
                  currentBlock = {
                    kind: "toolCall",
                    id: tcId ?? "",
                    name: tcName ?? "",
                    partialArgs: "",
                  };
                  currentToolIndex = tcIndex;
                  const block: ToolCall = {
                    type: "toolCall",
                    id: currentBlock.id,
                    name: currentBlock.name,
                    arguments: {},
                  };
                  output.content.push(block);
                  stream.push({
                    type: "toolcall_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }

                if (currentBlock?.kind === "toolCall") {
                  const block = output.content[blockIndex()] as ToolCall;
                  if (currentToolIndex === undefined) currentToolIndex = tcIndex;
                  if (tcId && !currentBlock.id) {
                    currentBlock.id = tcId;
                    block.id = tcId;
                  }
                  if (tcName && !currentBlock.name) {
                    currentBlock.name = tcName;
                    block.name = tcName;
                  }
                  if (tcArgs) {
                    currentBlock.partialArgs += tcArgs;
                    const parsed = parseStreamingJson(currentBlock.partialArgs);
                    if (parsed && typeof parsed === "object") {
                      block.arguments = parsed as Record<string, unknown>;
                    }
                    stream.push({
                      type: "toolcall_delta",
                      contentIndex: blockIndex(),
                      delta: tcArgs,
                      partial: output,
                    });
                  }
                }
              }
            }
          }
        };

        res.on("data", (chunk: Buffer) => {
          // SSE allows CRLF line endings. Normalize them so "\n\n" finds event
          // boundaries while streaming. The replace runs over the whole
          // pending buffer, so a CR whose LF arrives in the next chunk is
          // fixed up then.
          buffer = (buffer + decoder.decode(chunk, { stream: true }))
            .replace(/\r\n/g, "\n");
          let sep: number;
          while ((sep = buffer.indexOf("\n\n")) >= 0) {
            const eventBlock = buffer.slice(0, sep);
            buffer = buffer.slice(sep + 2);
            processEventBlock(eventBlock);
          }
        });

        res.on("end", () => {
          if (streamEnded) return;
          // Flush the decoder and any trailing event that lacked a blank line.
          buffer += decoder.decode();
          if (buffer.trim()) processEventBlock(buffer);
          buffer = "";

          finishCurrentBlock();
          calculateCost(model, output.usage);
          (output as AssistantMessage & { piTokenStats?: PiTokenStats }).piTokenStats = finalizePiTokenStats(
            tokenStats,
            {
              input: output.usage.input,
              output: output.usage.output,
              thinking: thinkingTokens,
            },
            Date.now(),
          );

          if (options?.signal?.aborted) {
            endWithError("aborted", "Request aborted");
            return;
          }

          streamEnded = true;
          detachAbort();
          stream.push({
            type: "done",
            reason: output.stopReason as "stop" | "length" | "toolUse",
            message: output,
          });
          stream.end();
        });
      });

      req.write(body);
      req.end();
    } catch (error) {
      endWithError(
        options?.signal?.aborted ? "aborted" : "error",
        error instanceof Error ? error.message : String(error),
      );
    }
  })();

  return stream;
}
|