import { BaseHandler } from "./base-handler.ts"
import { Bot, Channel, Message } from "../snek/snek-socket.ts"
import { trim, trimStart } from "npm:lodash-es"
import {
ChatSessionModelFunctions,
ChatWrapper,
defineChatSessionFunction,
GemmaChatWrapper,
GeneralChatWrapper,
getLlama,
LLamaChatPromptOptions,
LlamaChatSession,
LlamaContext,
LlamaModel,
resolveChatWrapper,
Token,
} from "npm:node-llama-cpp"
import { getLogger } from "@logtape/logtape"
import { deepMerge } from "@std/collections/deep-merge"
import { fetchRemoteFunctions } from "../ai/remote-functions.ts"
const llama = await getLlama()
const textEncoder = new TextEncoder()
function printSync(input: string | Uint8Array, to = Deno.stdout) {
let bytesWritten = 0
const bytes = typeof input === "string" ? textEncoder.encode(input) : input
while (bytesWritten < bytes.length) {
bytesWritten += to.writeSync(bytes.subarray(bytesWritten))
}
}
const logger = getLogger(["llama-gen-handler"])
const optionsGenerator = <
const Functions extends ChatSessionModelFunctions | undefined =
| ChatSessionModelFunctions
| undefined,
LLamaOptions extends LLamaChatPromptOptions<Functions> =
LLamaChatPromptOptions<Functions>,
>(
model: LlamaModel,
debugOutput: boolean = true,
defaultTimeout = 5 * 60 * 1000,
options?: LLamaOptions,
): LLamaOptions => {
const manager = AbortSignal.timeout(defaultTimeout)
const defaultOptions = {
repeatPenalty: {
lastTokens: 24,
penalty: 1.12,
penalizeNewLine: true,
frequencyPenalty: 0.02,
presencePenalty: 0.02,
punishTokensFilter: (tokens: Token[]) => {
return tokens.filter((token) => {
const text = model.detokenize([token])
// allow the model to repeat tokens
// that contain the word "better"
return !text.toLowerCase().includes("@")
// TODO: Exclude usernames
})
},
},
temperature: 0.7,
minP: 0.03,
// topK: 64,
// topP: 0.95,
// minP: 0.01,
signal: manager,
stopOnAbortSignal: true,
} as LLamaOptions
if (debugOutput) {
defaultOptions.onResponseChunk = (chunk) => {
options?.onResponseChunk?.(chunk)
const isThoughtSegment = chunk.type === "segment" &&
chunk.segmentType === "thought"
if (
chunk.type === "segment" && chunk.segmentStartTime != null
) {
printSync(` [segment start: ${chunk.segmentType}] `)
}
printSync(chunk.text)
if (chunk.type === "segment" && chunk.segmentEndTime != null) {
printSync(` [segment end: ${chunk.segmentType}] `)
}
}
}
return deepMerge(defaultOptions, options ?? {})
}
export class LLamaFuncHandler extends BaseHandler {
joinMode = new Map<string, boolean>()
streamMode = false
debugLogResponses = true
systemPrompt: string
#activeModel: string
#model: LlamaModel | null = null
#context: LlamaContext | null = null
#chatWrapper: ChatWrapper | null = null
#session: LlamaChatSession | null = null
#subCommands = {
"prompt": this.prompt.bind(this),
"join": this.join.bind(this),
"stream": this.stream.bind(this),
"reset": this.reset.bind(this),
} as Record<string, (command: string, message: Message, bot: Bot) => void>
constructor(
activeModel: string,
systemPrompt: string = "You are an AI chatbot.",
) {
super("")
this.#activeModel = activeModel
this.systemPrompt = systemPrompt
this.autoTyping = false
}
async calculateSystemPrompt(): Promise<string> {
return this.systemPrompt
}
override async bind(bot: Bot): Promise<void> {
await super.bind(bot)
this.prefix = this.user!.username.toLowerCase()
this.#model = await llama.loadModel({
modelPath: this.#activeModel,
defaultContextFlashAttention: true,
})
this.#context = await this.#model.createContext({
flashAttention: true,
})
logger.info("Model loaded", {
batchSize: this.#context.batchSize,
contextSize: this.#context.contextSize,
})
this.#chatWrapper = //new Llama3ChatWrapper()
resolveChatWrapper(this.#model) ?? new GeneralChatWrapper()
this.#session = new LlamaChatSession({
contextSequence: this.#context.getSequence(),
chatWrapper: this.#chatWrapper,
systemPrompt: await this.calculateSystemPrompt(),
})
// const channels = await bot.channels
// const channel = channels.find((v) => v.tag === "public") || channels[0]
// if (channel) {
// await bot.sendMessage(
// channel.uid,
// await this.cleanResponse(
// await this.#session.prompt(
// "Welcome to chat, greet everyone\n",
// optionsGenerator(this.#model, this.debugLogResponses),
// ),
// null,
// bot,
// ),
// )
// this.#session.resetChatHistory()
// }
// logger.info("LLamaHandler bound to bot")
}
cachedChannels: Channel[] | null = null
async isInJoinMode(channelUID: string, bot: Bot): Promise<boolean> {
if (!this.joinMode.has(channelUID)) {
if (!this.cachedChannels) {
this.cachedChannels = await bot.channels
}
const channel = this.cachedChannels?.find((c) => c.uid === channelUID)
if (channel) {
this.joinMode.set(channelUID, channel.tag === "dm")
} else {
logger.warn("Channel not found in cached channels", { channelUID })
this.cachedChannels = await bot.channels
}
}
return this.joinMode.get(channelUID) ?? false
}
override async isMatch(message: Message): Promise<boolean> {
return message
.userUID !== this.user?.uid &&
message.isFinal &&
(await this.isInJoinMode(message.channelUID, message.bot) ||
trim(message?.message, " `").toLowerCase().includes(this.prefix))
}
async cleanResponse(
response: string,
message: Message | null,
bot: Bot,
): Promise<string> {
const session = this.#session
const user = this.user
response = trim(response.replace(/<think>.*?<\/think>/gs, ""), '" \t')
let lwResponse = response.toLowerCase()
if (lwResponse.startsWith("ai")) {
response = response.substring(2).trim()
lwResponse = response.toLowerCase()
}
if (user && lwResponse.startsWith(`@${user.username.toLowerCase()}:`)) {
response = response.substring(user.username.length + 2).trim()
lwResponse = response.toLowerCase()
}
if (
message && lwResponse.startsWith(`@${message.username.toLowerCase()}:`)
) {
response = response.substring(message.username.length + 2).trim()
lwResponse = response.toLowerCase()
}
if (user && lwResponse.startsWith(`${user.username.toLowerCase()}:`)) {
response = response.substring(user.username.length + 2).trim()
lwResponse = response.toLowerCase()
}
if (
message && lwResponse.startsWith(`${message.username.toLowerCase()}:`)
) {
response = response.substring(message.username.length + 2).trim()
lwResponse = response.toLowerCase()
}
response = trimStart(response, ":").trim()
response = trim(response, '"')
return response
}
async join(command: string, message: Message, bot: Bot): Promise<void> {
this.joinMode.set(
message.channelUID,
!await this.isInJoinMode(message.channelUID, bot),
)
}
async stream(command: string, message: Message, bot: Bot): Promise<void> {
this.streamMode = !this.streamMode
}
async reset(command: string, message: Message, bot: Bot): Promise<void> {
const session = this.#session
if (!session) {
return
}
await session.resetChatHistory()
bot.sendMessage(
message.channelUID,
await this.cleanResponse(
await session.prompt(
"Your memory was just reset. Welcome to chat, greet everyone\n",
optionsGenerator(this.#model!, this.debugLogResponses),
),
message,
bot,
),
)
await session.resetChatHistory()
}
async prompt(command: string, message: Message, bot: Bot): Promise<void> {
const session = this.#session
const user = this.user
if (!session || !user) {
return
}
let msgSoFar = ""
let sentMessageInfo: Promise<unknown> | null = null
let streamId: number | undefined = undefined
if (this.streamMode) {
streamId = setInterval(async () => {
if (
msgSoFar.length < 1 ||
(msgSoFar.startsWith("@") && !msgSoFar.includes(" "))
) {
return
}
sentMessageInfo = sentMessageInfo?.then(async (msgInfo) => {
const msg = await this.cleanResponse(msgSoFar, message, bot)
try {
const msgRes = await bot.updateMessageText(
msgInfo as string,
msg,
)
if ("error" in msgRes) {
console.error("Error updating message text", msgRes)
msgSoFar = ""
return bot.sendMessage(message.channelUID, msg)
}
} catch (data) {
console.error(data)
}
return msgInfo
}) ??
bot.sendMessage(
message.channelUID,
await this.cleanResponse(msgSoFar, message, bot),
)
}, 50)
}
console.log("Prompting model", {
model: this.#activeModel,
command,
message,
user: user.username,
})
let response = await session.prompt(
`@${message.username}: ${message.message}`,
optionsGenerator(this.#model!, this.debugLogResponses, 5 * 60 * 1000, {
onTextChunk: (text) => {
bot.setTyping(message.channelUID).catch(() => {})
msgSoFar += text
if (this.streamMode) {
// sentMessageInfo = sentMessageInfo
// ? sentMessageInfo.then(([msgInfo, msg]) => {
// const newMsg = msg + text
// return this.cleanResponse(newMsg, message, bot).then((msg) => {
// return bot.updateMessageText(
// msgInfo,
// msg,
// ).catch(console.error).then(() => [msgInfo, newMsg])
// })
// })
// : bot.sendMessage(
// message.channelUID,
// text,
// ).then((msg) => [msg as string, text])
}
},
functions: {
generateImage: defineChatSessionFunction({
description: "Generate an image from a prompt",
params: {
type: "object",
properties: {
prompt: {
type: "string",
description: "The prompt to generate the image from",
},
},
required: ["prompt"],
},
handler: async ({ prompt }) => {
bot.sendMessage(message.channelUID, "@abot prompt " + prompt)
},
}),
now: defineChatSessionFunction({
description: "Get the current time",
params: {
type: "object",
properties: {},
},
handler: async () => {
const now = new Date()
return now.toLocaleString()
},
}),
...(await fetchRemoteFunctions().catch((e) => {
logger.error("Failed to fetch remote functions", { error: e })
return ({})
})),
},
}),
)
response = await this.cleanResponse(response, message, bot)
clearInterval(streamId)
if (sentMessageInfo) {
const msgInfo = await sentMessageInfo
bot.updateMessageText(
msgInfo,
response,
)
} else {
bot.sendMessage(message.channelUID, response)
}
}
override async handleMessage(message: Message, bot: Bot): Promise<void> {
const user = this.user
if (!user) {
return
}
if (
message.message.toLowerCase().startsWith(
`@${user.username.toLowerCase()}`,
)
) {
const newMessage = message.message.substring(
message.message.indexOf(" "),
message.message.length,
).trim()
if (newMessage && message.message.includes(" ")) {
const [[_, command, rest]] = newMessage.matchAll(/^(\S+)\s*(.*)$/gs)
if (command in this.#subCommands) {
return this.#subCommands[command]?.(rest.trim(), message, bot)
}
}
}
return this.prompt(message.message, message, bot)
}
}