import {BaseHandler} from "./base-handler.ts"
import {Bot, Message} from "../snek/snek-socket.ts"
import {trim, trimStart} from "npm:lodash-es"
import {
    ChatWrapper,
    GeneralChatWrapper,
    getLlama,
    LLamaChatPromptOptions,
    LlamaChatSession,
    LlamaContext,
    LlamaModel,
    resolveChatWrapper,
    Token,
} from "npm:node-llama-cpp"
import {getLogger} from "@logtape/logtape"
import {deepMerge} from "@std/collections/deep-merge"

const llama = await getLlama()
const textEncoder = new TextEncoder()

// Synchronously write a string (or raw bytes) to stdout, looping until
// every byte has been flushed.
function printSync(input: string | Uint8Array, to = Deno.stdout) {
    let bytesWritten = 0
    const bytes = typeof input === "string" ? textEncoder.encode(input) : input
    while (bytesWritten < bytes.length) {
        bytesWritten += to.writeSync(bytes.subarray(bytesWritten))
    }
}

const logger = getLogger(["llama-gen-handler"])

// Assemble the prompt options for a single generation: repetition penalties,
// a timeout-backed abort signal, and (optionally) debug streaming of response
// chunks to stdout. Caller-supplied options are deep-merged over the defaults.
const optionsGenerator = (
    model: LlamaModel,
    debugOutput: boolean = true,
    defaultTimeout = 5 * 60 * 1000,
    options?: Partial<LLamaChatPromptOptions>,
): LLamaChatPromptOptions => {
    const signal = AbortSignal.timeout(defaultTimeout)
    const defaultOptions: LLamaChatPromptOptions = {
        repeatPenalty: {
            lastTokens: 24,
            penalty: 1.12,
            penalizeNewLine: true,
            frequencyPenalty: 0.02,
            presencePenalty: 0.02,
            punishTokensFilter: (tokens: Token[]) => {
                return tokens.filter((token) => {
                    const text = model.detokenize([token])
                    // Don't penalize repeating tokens that contain "@", so
                    // the model can keep addressing users by mention.
                    return !text.toLowerCase().includes("@")
                    // TODO: Exclude usernames
                })
            },
        },
        temperature: 0.6,
        signal,
        stopOnAbortSignal: true,
    }

    if (debugOutput) {
        // Stream each chunk to stdout, marking where segments (e.g.
        // "thought" segments) start and end.
        defaultOptions.onResponseChunk = (chunk) => {
            if (chunk.type === "segment" && chunk.segmentStartTime != null) {
                printSync(` [segment start: ${chunk.segmentType}] `)
            }
            printSync(chunk.text)
            if (chunk.type === "segment" && chunk.segmentEndTime != null) {
                printSync(` [segment end: ${chunk.segmentType}] `)
            }
        }
    }

    // deepMerge expects plain records, so cast through Record to merge the
    // caller's overrides into the defaults.
    return deepMerge(
        defaultOptions as Record<PropertyKey, unknown>,
        (options ?? {}) as Record<PropertyKey, unknown>,
    ) as LLamaChatPromptOptions
}

export class LLamaHandler extends BaseHandler {
    // When true, respond to every message rather than only to mentions.
    joinMode = false
    debugLogResponses = true
    systemPrompt: string

    #activeModel: string
    #model: LlamaModel | null = null
    #context: LlamaContext | null = null
    #chatWrapper: ChatWrapper | null = null
    #session: LlamaChatSession | null = null

    constructor(
        activeModel: string,
        systemPrompt: string = "You are an AI chatbot.",
    ) {
        super("")
        this.#activeModel = activeModel
        this.systemPrompt = systemPrompt
    }

    async calculateSystemPrompt(): Promise<string> {
        return this.systemPrompt
    }

    override async bind(bot: Bot): Promise<void> {
        await super.bind(bot)
        // Messages containing "@<botname>" are treated as addressed to us.
        this.prefix = "@" + this.user?.username.toLowerCase()

        this.#model = await llama.loadModel({
            modelPath: this.#activeModel,
            defaultContextFlashAttention: true,
        })
        this.#context = await this.#model.createContext({
            flashAttention: true,
        })
        logger.info("Model loaded", {
            batchSize: this.#context.batchSize,
            contextSize: this.#context.contextSize,
        })
        console.log("Model loaded", {
            batchSize: this.#context.batchSize,
            contextSize: this.#context.contextSize,
        })

        this.#chatWrapper = //new Llama3ChatWrapper()
            resolveChatWrapper({
                bosString: this.#model.tokens.bosString,
                filename: this.#model.filename,
                fileInfo: this.#model.fileInfo,
                tokenizer: this.#model.tokenizer,
            }) ??
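                // resolveChatWrapper() found no known prompt format for this
                // model, so fall back to the format-agnostic wrapper.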
                new GeneralChatWrapper()

        this.#session = new LlamaChatSession({
            contextSequence: this.#context.getSequence(),
            chatWrapper: this.#chatWrapper,
            systemPrompt: await this.calculateSystemPrompt(),
        })

        // Greet the public channel (or the first channel available), then
        // reset the history so the greeting doesn't leak into later chats.
        const channels = await bot.channels
        const channel = channels.find((v) => v.tag === "public") || channels[0]
        if (channel) {
            await bot.sendMessage(
                channel.uid,
                await this.#session.prompt(
                    "Welcome to chat, greet everyone\n",
                    optionsGenerator(this.#model, this.debugLogResponses),
                ),
            )
            this.#session.resetChatHistory()
        }
    }

    override async isMatch(message: Message): Promise<boolean> {
        // React to other users' messages that mention us, or to everything
        // in join mode; backticks are trimmed so quoted mentions still match.
        return message.userUID !== this.user?.uid &&
            (this.joinMode ||
                trim(message?.message, " `").toLowerCase()
                    .includes(this.prefix))
    }

    override async handleMessage(message: Message, bot: Bot): Promise<void> {
        const session = this.#session
        const user = this.user
        if (!session || !user) {
            return
        }

        let response = await session.prompt(
            `@${message.username}: ${message.message}`,
            optionsGenerator(this.#model!, this.debugLogResponses),
        )

        // Drop any chain-of-thought emitted as <think>...</think> segments.
        response = response.replace(/.*?<\/think>/gs, "")

        // Strip leading artifacts the model tends to echo back: a bare "ai"
        // tag, a self-mention, or a mention of the user being replied to.
        let lwResponse = response.toLowerCase()
        if (lwResponse.startsWith("ai")) {
            response = response.substring(2).trim()
            lwResponse = response.toLowerCase()
        }
        if (lwResponse.startsWith(`@${user.username.toLowerCase()}`)) {
            response = response.substring(user.username.length + 1).trim()
            lwResponse = response.toLowerCase()
        }
        if (lwResponse.startsWith(`@${message.username.toLowerCase()}`)) {
            response = response.substring(message.username.length + 1).trim()
        }
        response = trimStart(response, ":").trim()

        bot.send("send_message", message.channelUID, response)
    }
}
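// Usage sketch (illustrative, not part of this module): assumes a connected
// `Bot` instance from "../snek/snek-socket.ts" and a local GGUF model file;
// the model path and system prompt below are placeholders.
//
//   const handler = new LLamaHandler(
//       "./models/model.gguf",
//       "You are a friendly chat participant.",
//   )
//   await handler.bind(bot) // loads the model and posts a greeting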