import {BaseHandler} from "./base-handler.ts"
import {Bot, Message} from "../snek/snek-socket.ts"
import {trim, trimStart} from "npm:lodash-es"
import {
ChatSessionModelFunctions,
ChatWrapper,
GeneralChatWrapper,
getLlama,
LLamaChatPromptOptions,
LlamaChatSession,
LlamaContext,
LlamaModel,
resolveChatWrapper,
Token,
} from "npm:node-llama-cpp"
import {getLogger} from "@logtape/logtape"
import {deepMerge} from "@std/collections/deep-merge"
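// Load the llama.cpp bindings once at module load; each handler loads its own model in bind()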
const llama = await getLlama()
const textEncoder = new TextEncoder()
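// Write synchronously to stdout so streamed tokens appear as soon as they are generated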
function printSync(input: string | Uint8Array, to = Deno.stdout) {
let bytesWritten = 0
const bytes = typeof input === "string" ? textEncoder.encode(input) : input
while (bytesWritten < bytes.length) {
bytesWritten += to.writeSync(bytes.subarray(bytesWritten))
}
}
const logger = getLogger(["llama-gen-handler"])
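// Build the default prompt options (repeat penalty, temperature, timeout-based abort signal,
// optional debug streaming) and deep-merge any caller-supplied overrides on top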
const optionsGenerator = <
const Functions extends ChatSessionModelFunctions | undefined = undefined,
LLamaOptions = LLamaChatPromptOptions<Functions>
>(
model: LlamaModel,
debugOutput: boolean = true,
defaultTimeout = 5 * 60 * 1000,
options?: LLamaOptions,
): LLamaChatPromptOptions<Functions> => {
  const timeoutSignal = AbortSignal.timeout(defaultTimeout)
const defaultOptions: LLamaChatPromptOptions<Functions> = {
repeatPenalty: {
lastTokens: 24,
penalty: 1.12,
penalizeNewLine: true,
frequencyPenalty: 0.02,
presencePenalty: 0.02,
punishTokensFilter: (tokens: Token[]) => {
return tokens.filter((token) => {
const text = model.detokenize([token])
          // exempt tokens containing "@" from the repeat penalty
          // so the model stays free to repeat user mentions
          return !text.toLowerCase().includes("@")
          // TODO: exempt full usernames, not just the "@" marker
})
},
},
temperature: 0.6,
    signal: timeoutSignal,
stopOnAbortSignal: true,
}
  if (debugOutput) {
    // Stream the generated text to stdout, marking segment boundaries
    // (e.g. "thought" segments) as they open and close
    defaultOptions.onResponseChunk = (chunk) => {
      if (chunk.type === "segment" && chunk.segmentStartTime != null) {
        printSync(` [segment start: ${chunk.segmentType}] `)
      }
      printSync(chunk.text)
      if (chunk.type === "segment" && chunk.segmentEndTime != null) {
        printSync(` [segment end: ${chunk.segmentType}] `)
      }
    }
  }
return deepMerge(defaultOptions, options ?? {})
}
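// Chat handler that runs a local model through node-llama-cpp and replies to messages
// that mention the bot's @handle (or to every message when joinMode is enabled)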
export class LLamaHandler extends BaseHandler {
joinMode = false
debugLogResponses = true
systemPrompt: string
#activeModel: string
#model: LlamaModel | null = null
#context: LlamaContext | null = null
#chatWrapper: ChatWrapper | null = null
#session: LlamaChatSession | null = null
constructor(
activeModel: string,
systemPrompt: string = "You are an AI chatbot.",
) {
super("")
this.#activeModel = activeModel
this.systemPrompt = systemPrompt
}
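  // Kept async so subclasses can compute the system prompt dynamically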
async calculateSystemPrompt(): Promise<string> {
return this.systemPrompt
}
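  // Load the model, create a context and chat session, then greet the first public channel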
override async bind(bot: Bot): Promise<void> {
await super.bind(bot)
    this.prefix = "@" + (this.user?.username.toLowerCase() ?? "")
this.#model = await llama.loadModel({
modelPath: this.#activeModel,
defaultContextFlashAttention: true,
})
this.#context = await this.#model.createContext({
flashAttention: true,
})
logger.info("Model loaded", {
batchSize: this.#context.batchSize,
contextSize: this.#context.contextSize,
})
console.log("Model loaded", {
batchSize: this.#context.batchSize,
contextSize: this.#context.contextSize,
})
    // Pick a chat wrapper matching the model's prompt format,
    // falling back to the generic wrapper when none is detected
    this.#chatWrapper = resolveChatWrapper({
      bosString: this.#model.tokens.bosString,
      filename: this.#model.filename,
      fileInfo: this.#model.fileInfo,
      tokenizer: this.#model.tokenizer,
    }) ?? new GeneralChatWrapper()
this.#session = new LlamaChatSession({
contextSequence: this.#context.getSequence(),
chatWrapper: this.#chatWrapper,
systemPrompt: await this.calculateSystemPrompt(),
})
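    // Greet the public channel (or the first available one), then clear the greeting
    // from the session history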
const channels = await bot.channels
const channel = channels.find((v) => v.tag === "public") || channels[0]
if (channel) {
await bot.sendMessage(
channel.uid,
await this.#session.prompt(
"Welcome to chat, greet everyone\n",
optionsGenerator(this.#model, this.debugLogResponses),
),
)
this.#session.resetChatHistory()
}
}
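  // Ignore the bot's own messages; match everything in join mode,
  // otherwise only messages that mention the bot's @handle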
override async isMatch(message: Message): Promise<boolean> {
return message.userUID !== this.user?.uid && (this.joinMode ||
trim(message?.message, " `").toLowerCase().includes(this.prefix))
}
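  // Forward the incoming message to the model and post the cleaned-up reply
  // back to the originating channel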
override async handleMessage(message: Message, bot: Bot): Promise<void> {
const session = this.#session
const user = this.user
if (!session || !user) {
return
}
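    // Present the message to the model as "@username: text" so it can tell speakers apart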
let response = await session.prompt(
`@${message.username}: ${message.message}`,
optionsGenerator(this.#model!, this.debugLogResponses),
)
    // Drop any chain-of-thought block emitted by reasoning models
    response = response.replace(/<think>.*?<\/think>/gs, "")
    // Strip a leading "AI" speaker label (e.g. "AI:") if the model echoes one,
    // without clipping replies that merely begin with "ai"
    response = response.replace(/^ai\b[:,]?\s*/i, "").trim()
    let lwResponse = response.toLowerCase()
    // Strip a leading self-mention or an echoed mention of the sender
    if (lwResponse.startsWith(`@${user.username.toLowerCase()}`)) {
      response = response.substring(user.username.length + 1).trim()
      lwResponse = response.toLowerCase()
    }
    if (lwResponse.startsWith(`@${message.username.toLowerCase()}`)) {
      response = response.substring(message.username.length + 1).trim()
    }
    response = trimStart(response, ":").trim()
bot.send("send_message", message.channelUID, response)
}
}