115 lines
4.1 KiB
TypeScript
115 lines
4.1 KiB
TypeScript
import type { ExtensionAPI, ExtensionCommandContext } from "@mariozechner/pi-coding-agent";
|
||
import { discoverModels } from "./discovery.js";
|
||
import { registerProviderWithModels } from "./model-utils.js";
|
||
import {
|
||
BASE_URL,
|
||
DISCOVERY_TIMEOUT_MS,
|
||
FALLBACK_MODEL_ID,
|
||
FALLBACK_NAME,
|
||
PROVIDER_ID,
|
||
} from "./config.js";
|
||
|
||
/**
 * Strip the trailing `/v1` segment from a base URL, used to build the models endpoint.
 *
 * Trailing slashes are removed first so that `http://host/v1/` is treated the
 * same as `http://host/v1`; otherwise appending `/v1/models` to the result
 * would produce a doubled `/v1//v1/models` path.
 *
 * @param url - Base URL, with or without a trailing `/v1` and/or trailing slashes.
 * @returns The URL with trailing slashes and a final `/v1` segment removed.
 */
const stripV1 = (url: string): string => url.replace(/\/+$/, "").replace(/\/v1$/, "");
|
||
|
||
// ─── Extension factory ───────────────────────────────────────────────────
|
||
|
||
/**
 * Extension entry point.
 *
 * In order, it:
 *  1. registers the provider with a single static fallback model,
 *  2. races `discoverModels()` against a `DISCOVERY_TIMEOUT_MS` timer and, if
 *     any models come back in time, re-registers the provider with them,
 *  3. registers the `local-llama-refresh` and `local-llama-status` commands.
 *
 * @param pi - Extension API handle used to register the provider and commands.
 * @returns Resolves once the startup discovery race has settled and both
 *   commands are registered. A discovery failure is caught here, not rethrown.
 */
export default async function (pi: ExtensionAPI): Promise<void> {
  /** Render any thrown value as text; non-Error throws have no `.message`. */
  const describeError = (err: unknown): string =>
    err instanceof Error ? err.message : String(err);

  // 1. Register provider IMMEDIATELY with the static fallback so pi startup
  //    isn't blocked on the HTTP round-trip.
  const fallback = [{ id: FALLBACK_MODEL_ID, name: FALLBACK_NAME }];
  registerProviderWithModels(pi, fallback);

  // 2. Race discovery against a short timeout. On LAN the server answers
  //    in ~10–50 ms; slow networks fall back to the static list.
  const discovery = discoverModels();
  let timer: ReturnType<typeof setTimeout> | undefined;
  const timeout = new Promise<never[]>((resolve) => {
    // Resolving with an empty list makes a timeout look like "nothing found".
    timer = setTimeout(() => resolve([]), DISCOVERY_TIMEOUT_MS);
  });

  try {
    const discovered = await Promise.race([discovery, timeout]);

    if (discovered.length > 0) {
      registerProviderWithModels(pi, discovered);
      if (process.env.PI_DEBUG) {
        console.log(
          `[local-llama] Discovered ${discovered.length} model(s) at ${BASE_URL}: ${discovered.map((m) => m.id).join(", ")}`,
        );
      }
    }
  } catch {
    // Discovery rejected — static fallback stays registered. (A timeout does
    // not land here: it resolves with [] and fails the length check above.)
    if (process.env.PI_DEBUG) {
      console.log(`[local-llama] Discovery failed; fallback "${FALLBACK_MODEL_ID}" remains`);
    }
  } finally {
    // Don't leave the timer pending once the race has been decided.
    clearTimeout(timer);
  }

  // ─── Slash command: /local-llama-refresh ─────────────────────────────

  pi.registerCommand("local-llama-refresh", {
    description: "Re-discover models from the local llama.cpp server",
    handler: async (_args: string, ctx: ExtensionCommandContext) => {
      try {
        ctx.ui.setStatus(PROVIDER_ID, "Discovering models…");
        const discovered = await discoverModels();
        ctx.ui.setStatus(PROVIDER_ID, undefined);

        // An empty result leaves the current registration untouched.
        if (discovered.length === 0) {
          ctx.ui.notify("No models returned from /v1/models endpoint", "warning");
          return;
        }

        registerProviderWithModels(pi, discovered);
        ctx.ui.notify(
          `Registered ${discovered.length} model(s): ${discovered.map((m) => m.id).join(", ")}`,
          "info",
        );
      } catch (err) {
        // Clear the status text here too, since the throw may have skipped it.
        ctx.ui.setStatus(PROVIDER_ID, undefined);
        ctx.ui.notify(`Discovery failed: ${describeError(err)}`, "error");
      }
    },
  });

  // ─── Slash command: /local-llama-status ──────────────────────────────

  pi.registerCommand("local-llama-status", {
    description: "Show local llama.cpp server URL and model status",
    handler: async (_args: string, ctx: ExtensionCommandContext) => {
      try {
        const url = `${stripV1(BASE_URL)}/v1/models`;
        const res = await fetch(url, { signal: AbortSignal.timeout(5_000) });
        if (!res.ok) {
          ctx.ui.notify(`Server responded with HTTP ${res.status}`, "error");
          return;
        }

        const json = (await res.json()) as { data?: Array<{ id: string; name?: string }> };
        // The cast above is unchecked: treat a missing or non-array `data` as
        // "no models" instead of throwing while iterating it.
        const models = Array.isArray(json.data) ? json.data : [];

        const lines = [`Local llama.cpp: ${BASE_URL}`];
        if (models.length === 0) {
          lines.push(" (no models reported)");
        } else {
          for (const m of models) {
            lines.push(` ${m.name ?? m.id}`);
          }
        }
        ctx.ui.notify(lines.join("\n"), "info");
      } catch (err) {
        ctx.ui.notify(
          `Cannot reach server at ${BASE_URL}: ${describeError(err)}`,
          "error",
        );
      }
    },
  });
}
|
||
|
||
// ─── Exports for testing ────────────────────────────────────────────────
|
||
|
||
export { discoverModels } from "./discovery.js";
|
||
export { registerProviderWithModels, isReasoningModel } from "./model-utils.js";
|