import { BaseHandler } from "./base-handler.ts"
import { Bot, Channel, Message } from "../snek/snek-socket.ts"
import { trim, trimStart } from "npm:lodash-es"
import {
  ChatSessionModelFunctions,
  ChatWrapper,
  defineChatSessionFunction,
  GemmaChatWrapper,
  GeneralChatWrapper,
  getLlama,
  LLamaChatPromptOptions,
  LlamaChatSession,
  LlamaContext,
  LlamaModel,
  resolveChatWrapper,
  Token,
} from "npm:node-llama-cpp"
import { getLogger } from "@logtape/logtape"
import { deepMerge } from "@std/collections/deep-merge"
import { fetchRemoteFunctions } from "../ai/remote-functions.ts"

const llama = await getLlama()
const textEncoder = new TextEncoder()

// Write a string or byte buffer to stdout synchronously so streamed tokens
// are not interleaved by the async scheduler.
function printSync(input: string | Uint8Array, to = Deno.stdout) {
  let bytesWritten = 0
  const bytes = typeof input === "string" ? textEncoder.encode(input) : input
  while (bytesWritten < bytes.length) {
    bytesWritten += to.writeSync(bytes.subarray(bytesWritten))
  }
}

const logger = getLogger(["llama-gen-handler"])

const optionsGenerator = <
  const Functions extends ChatSessionModelFunctions | undefined =
    | ChatSessionModelFunctions
    | undefined,
  LLamaOptions extends LLamaChatPromptOptions = LLamaChatPromptOptions,
>(
  model: LlamaModel,
  debugOutput: boolean = true,
  defaultTimeout = 5 * 60 * 1000,
  options?: LLamaOptions,
): LLamaOptions => {
  const manager = AbortSignal.timeout(defaultTimeout)
  const defaultOptions = {
    repeatPenalty: {
      lastTokens: 24,
      penalty: 1.12,
      penalizeNewLine: true,
      frequencyPenalty: 0.02,
      presencePenalty: 0.02,
      punishTokensFilter: (tokens: Token[]) => {
        return tokens.filter((token) => {
          const text = model.detokenize([token])
          // Exempt tokens containing "@" from the repeat penalty so the
          // model can repeat mentions verbatim
          return !text.toLowerCase().includes("@") // TODO: Exclude usernames
        })
      },
    },
    temperature: 0.7,
    minP: 0.03,
    // topK: 64,
    // topP: 0.95,
    // minP: 0.01,
    signal: manager,
    stopOnAbortSignal: true,
  } as LLamaOptions
  if (debugOutput) {
    defaultOptions.onResponseChunk = (chunk) => {
      options?.onResponseChunk?.(chunk)
      if (chunk.type === "segment" && chunk.segmentStartTime != null) {
        printSync(` [segment start: ${chunk.segmentType}] `)
      }
      printSync(chunk.text)
      if (chunk.type === "segment" && chunk.segmentEndTime != null) {
        printSync(` [segment end: ${chunk.segmentType}] `)
      }
    }
  }
  return deepMerge(defaultOptions, options ?? {})
}
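// Usage sketch (illustrative; not invoked at module scope): caller overrides
// are deep-merged over the defaults above, so a single knob can be changed
// without losing the repeat-penalty and abort-signal defaults:
//
//   const opts = optionsGenerator(model, false, 30_000, { temperature: 0.2 })
//   const reply = await session.prompt("Hello", opts)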
export class LLamaFuncHandler extends BaseHandler {
  joinMode = new Map<string, boolean>()
  streamMode = false
  debugLogResponses = true
  systemPrompt: string
  #activeModel: string
  #model: LlamaModel | null = null
  #context: LlamaContext | null = null
  #chatWrapper: ChatWrapper | null = null
  #session: LlamaChatSession | null = null
  #subCommands = {
    "prompt": this.prompt.bind(this),
    "join": this.join.bind(this),
    "stream": this.stream.bind(this),
    "reset": this.reset.bind(this),
  } as Record<string, (command: string, message: Message, bot: Bot) => void>

  constructor(
    activeModel: string,
    systemPrompt: string = "You are an AI chatbot.",
  ) {
    super("")
    this.#activeModel = activeModel
    this.systemPrompt = systemPrompt
    this.autoTyping = false
  }

  async calculateSystemPrompt(): Promise<string> {
    return this.systemPrompt
  }

  override async bind(bot: Bot): Promise<void> {
    await super.bind(bot)
    this.prefix = this.user!.username.toLowerCase()
    this.#model = await llama.loadModel({
      modelPath: this.#activeModel,
      defaultContextFlashAttention: true,
    })
    this.#context = await this.#model.createContext({
      flashAttention: true,
    })
    logger.info("Model loaded", {
      batchSize: this.#context.batchSize,
      contextSize: this.#context.contextSize,
    })
    this.#chatWrapper = // new Llama3ChatWrapper()
      resolveChatWrapper(this.#model) ?? new GeneralChatWrapper()
    this.#session = new LlamaChatSession({
      contextSequence: this.#context.getSequence(),
      chatWrapper: this.#chatWrapper,
      systemPrompt: await this.calculateSystemPrompt(),
    })
    // const channels = await bot.channels
    // const channel = channels.find((v) => v.tag === "public") || channels[0]
    // if (channel) {
    //   await bot.sendMessage(
    //     channel.uid,
    //     await this.cleanResponse(
    //       await this.#session.prompt(
    //         "Welcome to chat, greet everyone\n",
    //         optionsGenerator(this.#model, this.debugLogResponses),
    //       ),
    //       null,
    //       bot,
    //     ),
    //   )
    //   this.#session.resetChatHistory()
    // }
    // logger.info("LLamaHandler bound to bot")
  }

  cachedChannels: Channel[] | null = null

  async isInJoinMode(channelUID: string, bot: Bot): Promise<boolean> {
    if (!this.joinMode.has(channelUID)) {
      if (!this.cachedChannels) {
        this.cachedChannels = await bot.channels
      }
      const channel = this.cachedChannels?.find((c) => c.uid === channelUID)
      if (channel) {
        // DM channels default to join mode; everything else stays opt-in
        this.joinMode.set(channelUID, channel.tag === "dm")
      } else {
        logger.warn("Channel not found in cached channels", { channelUID })
        this.cachedChannels = await bot.channels
      }
    }
    return this.joinMode.get(channelUID) ?? false
  }
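  // Illustrative behavior (hypothetical channel objects): in a DM the bot
  // answers every message; in a public channel it stays mention-only until
  // toggled with the "join" subcommand:
  //
  //   await handler.isInJoinMode(dmChannel.uid, bot)     // => true
  //   await handler.isInJoinMode(publicChannel.uid, bot) // => false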
  override async isMatch(message: Message): Promise<boolean> {
    return message.userUID !== this.user?.uid && message.isFinal &&
      (await this.isInJoinMode(message.channelUID, message.bot) ||
        trim(message?.message, " `").toLowerCase().includes(this.prefix))
  }

  async cleanResponse(
    response: string,
    message: Message | null,
    bot: Bot,
  ): Promise<string> {
    const user = this.user
    // Strip any <think>…</think> reasoning segment, then peel off quoting
    // and self-addressed prefixes the model tends to emit
    response = trim(response.replace(/<think>.*?<\/think>/gs, ""), '" \t')
    let lwResponse = response.toLowerCase()
    if (lwResponse.startsWith("ai")) {
      response = response.substring(2).trim()
      lwResponse = response.toLowerCase()
    }
    if (user && lwResponse.startsWith(`@${user.username.toLowerCase()}:`)) {
      response = response.substring(user.username.length + 2).trim()
      lwResponse = response.toLowerCase()
    }
    if (
      message && lwResponse.startsWith(`@${message.username.toLowerCase()}:`)
    ) {
      response = response.substring(message.username.length + 2).trim()
      lwResponse = response.toLowerCase()
    }
    if (user && lwResponse.startsWith(`${user.username.toLowerCase()}:`)) {
      response = response.substring(user.username.length + 1).trim()
      lwResponse = response.toLowerCase()
    }
    if (
      message && lwResponse.startsWith(`${message.username.toLowerCase()}:`)
    ) {
      response = response.substring(message.username.length + 1).trim()
      lwResponse = response.toLowerCase()
    }
    response = trimStart(response, ":").trim()
    response = trim(response, '"')
    return response
  }

  async join(command: string, message: Message, bot: Bot): Promise<void> {
    this.joinMode.set(
      message.channelUID,
      !await this.isInJoinMode(message.channelUID, bot),
    )
  }

  async stream(command: string, message: Message, bot: Bot): Promise<void> {
    this.streamMode = !this.streamMode
  }

  async reset(command: string, message: Message, bot: Bot): Promise<void> {
    const session = this.#session
    if (!session) {
      return
    }
    await session.resetChatHistory()
    bot.sendMessage(
      message.channelUID,
      await this.cleanResponse(
        await session.prompt(
          "Your memory was just reset. Welcome to chat, greet everyone\n",
          optionsGenerator(this.#model!, this.debugLogResponses),
        ),
        message,
        bot,
      ),
    )
    await session.resetChatHistory()
  }

  async prompt(command: string, message: Message, bot: Bot): Promise<void> {
    const session = this.#session
    const user = this.user
    if (!session || !user) {
      return
    }
    let msgSoFar = ""
    let sentMessageInfo: Promise<string> | null = null
    let streamId: number | undefined = undefined
    if (this.streamMode) {
      // Periodically flush the partial response: the first tick sends a new
      // message, later ticks edit it in place
      streamId = setInterval(async () => {
        if (
          msgSoFar.length < 1 ||
          (msgSoFar.startsWith("@") && !msgSoFar.includes(" "))
        ) {
          return
        }
        sentMessageInfo = sentMessageInfo?.then(async (msgInfo) => {
          const msg = await this.cleanResponse(msgSoFar, message, bot)
          try {
            const msgRes = await bot.updateMessageText(msgInfo as string, msg)
            if ("error" in msgRes) {
              console.error("Error updating message text", msgRes)
              msgSoFar = ""
              return bot.sendMessage(message.channelUID, msg)
            }
          } catch (data) {
            console.error(data)
          }
          return msgInfo
        }) ?? bot.sendMessage(
          message.channelUID,
          await this.cleanResponse(msgSoFar, message, bot),
        )
      }, 50)
    }
    console.log("Prompting model", {
      model: this.#activeModel,
      command,
      message,
      user: user.username,
    })
    let response = await session.prompt(
      `@${message.username}: ${message.message}`,
      optionsGenerator(this.#model!, this.debugLogResponses, 5 * 60 * 1000, {
        onTextChunk: (text) => {
          bot.setTyping(message.channelUID).catch(() => {})
          msgSoFar += text
          if (this.streamMode) {
            // sentMessageInfo = sentMessageInfo
            //   ? sentMessageInfo.then(([msgInfo, msg]) => {
            //     const newMsg = msg + text
            //     return this.cleanResponse(newMsg, message, bot).then((msg) => {
            //       return bot.updateMessageText(
            //         msgInfo,
            //         msg,
            //       ).catch(console.error).then(() => [msgInfo, newMsg])
            //     })
            //   })
            //   : bot.sendMessage(
            //     message.channelUID,
            //     text,
            //   ).then((msg) => [msg as string, text])
          }
        },
        functions: {
          generateImage: defineChatSessionFunction({
            description: "Generate an image from a prompt",
            params: {
              type: "object",
              properties: {
                prompt: {
                  type: "string",
                  description: "The prompt to generate the image from",
                },
              },
              required: ["prompt"],
            },
            handler: async ({ prompt }) => {
              // Delegate to the image-generation bot in the same channel
              bot.sendMessage(message.channelUID, "@abot prompt " + prompt)
            },
          }),
          now: defineChatSessionFunction({
            description: "Get the current time",
            params: {
              type: "object",
              properties: {},
            },
            handler: async () => {
              const now = new Date()
              return now.toLocaleString()
            },
          }),
          ...(await fetchRemoteFunctions().catch((e) => {
            logger.error("Failed to fetch remote functions", { error: e })
            return {}
          })),
        },
      }),
    )
    response = await this.cleanResponse(response, message, bot)
    clearInterval(streamId)
    if (sentMessageInfo) {
      const msgInfo = await sentMessageInfo
      bot.updateMessageText(msgInfo, response)
    } else {
      bot.sendMessage(message.channelUID, response)
    }
  }

  override async handleMessage(message: Message, bot: Bot): Promise<void> {
    const user = this.user
    if (!user) {
      return
    }
    if (
      message.message.toLowerCase().startsWith(
        `@${user.username.toLowerCase()}`,
      )
    ) {
      const newMessage = message.message.substring(
        message.message.indexOf(" "),
        message.message.length,
      ).trim()
      if (newMessage && message.message.includes(" ")) {
        const [[_, command, rest]] = newMessage.matchAll(/^(\S+)\s*(.*)$/gs)
        if (command in this.#subCommands) {
          return this.#subCommands[command]?.(rest.trim(), message, bot)
        }
      }
    }
    return this.prompt(message.message, message, bot)
  }
}
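// Usage sketch (the model path and bot wiring are assumptions for
// illustration, not part of this module):
//
//   const handler = new LLamaFuncHandler(
//     "./models/example.Q4_K_M.gguf", // hypothetical local GGUF path
//     "You are an AI chatbot.",
//   )
//   await handler.bind(bot) // loads the model and opens the chat session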