import { BaseHandler } from "./base-handler.ts"
import { Bot, Channel, Message } from "../snek/snek-socket.ts"
import { trim, trimStart } from "npm:lodash-es"
import {
  ChatSessionModelFunctions,
  ChatWrapper,
  defineChatSessionFunction,
  GeneralChatWrapper,
  getLlama,
  LLamaChatPromptOptions,
  LlamaChatSession,
  LlamaContext,
  LlamaModel,
  resolveChatWrapper,
  Token,
} from "npm:node-llama-cpp"
import { getLogger } from "@logtape/logtape"
import { deepMerge } from "@std/collections/deep-merge"
import { fetchRemoteFunctions } from "../ai/remote-functions.ts"

const llama = await getLlama()

const textEncoder = new TextEncoder()
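
// Synchronously write a string (or raw bytes) to stdout without a trailing
// newline, so streamed tokens appear as they arrive. Loops because writeSync
// may write fewer bytes than requested.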
function printSync(input: string | Uint8Array, to = Deno.stdout) {
  let bytesWritten = 0
  const bytes = typeof input === "string" ? textEncoder.encode(input) : input
  while (bytesWritten < bytes.length) {
    bytesWritten += to.writeSync(bytes.subarray(bytesWritten))
  }
}

const logger = getLogger(["llama-gen-handler"])
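
// Build the prompt options for a generation call: repetition penalties,
// sampling settings, and an abort timeout, with optional debug printing of
// response chunks. Caller-supplied options are deep-merged over the defaults.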
const optionsGenerator = <
  const Functions extends ChatSessionModelFunctions | undefined =
    | ChatSessionModelFunctions
    | undefined,
  LLamaOptions extends LLamaChatPromptOptions<Functions> =
    LLamaChatPromptOptions<Functions>,
>(
  model: LlamaModel,
  debugOutput: boolean = true,
  defaultTimeout = 5 * 60 * 1000,
  options?: LLamaOptions,
): LLamaOptions => {
  // Abort generation once the timeout elapses.
  const timeoutSignal = AbortSignal.timeout(defaultTimeout)

  const defaultOptions = {
    repeatPenalty: {
      lastTokens: 24,
      penalty: 1.12,
      penalizeNewLine: true,
      frequencyPenalty: 0.02,
      presencePenalty: 0.02,
      punishTokensFilter: (tokens: Token[]) => {
        return tokens.filter((token) => {
          const text = model.detokenize([token])

          // allow the model to repeat tokens that contain "@",
          // so mentions are not suppressed by the repeat penalty
          return !text.toLowerCase().includes("@")
          // TODO: Exclude usernames
        })
      },
    },
    temperature: 0.7,
    minP: 0.03,

    // topK: 64,
    // topP: 0.95,
    // minP: 0.01,
    signal: timeoutSignal,
    stopOnAbortSignal: true,
  } as LLamaOptions

  if (debugOutput) {
    defaultOptions.onResponseChunk = (chunk) => {
      // Forward each chunk to any caller-supplied handler, then echo it to
      // stdout, marking segment boundaries (e.g. "thought" segments).
      options?.onResponseChunk?.(chunk)

      if (chunk.type === "segment" && chunk.segmentStartTime != null) {
        printSync(` [segment start: ${chunk.segmentType}] `)
      }

      printSync(chunk.text)

      if (chunk.type === "segment" && chunk.segmentEndTime != null) {
        printSync(` [segment end: ${chunk.segmentType}] `)
      }
    }
  }

  const merged = deepMerge(defaultOptions, options ?? {}) as LLamaOptions
  if (debugOutput) {
    // deepMerge replaces function values with the caller's, which would drop
    // the debug wrapper; restore it (it already forwards to the caller's
    // onResponseChunk).
    merged.onResponseChunk = defaultOptions.onResponseChunk
  }
  return merged
}
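
// Chat handler backed by a local llama.cpp model. Replies when mentioned by
// name (or to every message in channels where "join" mode is on), optionally
// streams the in-progress response by editing the sent message, and exposes
// tool functions to the model.
//
// Usage sketch (hypothetical model path):
//   const handler = new LLamaFuncHandler("./models/chat.gguf", "You are a helpful bot.")
//   await handler.bind(bot)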
export class LLamaFuncHandler extends BaseHandler {
  joinMode = new Map<string, boolean>()
  streamMode = false
  debugLogResponses = true

  systemPrompt: string

  #activeModel: string

  #model: LlamaModel | null = null
  #context: LlamaContext | null = null

  #chatWrapper: ChatWrapper | null = null

  #session: LlamaChatSession | null = null

  #subCommands = {
    "prompt": this.prompt.bind(this),
    "join": this.join.bind(this),
    "stream": this.stream.bind(this),
    "reset": this.reset.bind(this),
  } as Record<string, (command: string, message: Message, bot: Bot) => void>

  constructor(
    activeModel: string,
    systemPrompt: string = "You are an AI chatbot.",
  ) {
    super("")
    this.#activeModel = activeModel
    this.systemPrompt = systemPrompt
    this.autoTyping = false
  }
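
  // Hook for subclasses that want to build the system prompt dynamically.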
  async calculateSystemPrompt(): Promise<string> {
    return this.systemPrompt
  }
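
  // Load the model, create a context and chat session, and derive the
  // mention prefix from the bot's own username.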
  override async bind(bot: Bot): Promise<void> {
    await super.bind(bot)
    this.prefix = this.user!.username.toLowerCase()

    this.#model = await llama.loadModel({
      modelPath: this.#activeModel,
      defaultContextFlashAttention: true,
    })

    this.#context = await this.#model.createContext({
      flashAttention: true,
    })

    logger.info("Model loaded", {
      batchSize: this.#context.batchSize,
      contextSize: this.#context.contextSize,
    })

    // Alternative: new Llama3ChatWrapper()
    this.#chatWrapper = resolveChatWrapper(this.#model) ??
      new GeneralChatWrapper()

    this.#session = new LlamaChatSession({
      contextSequence: this.#context.getSequence(),
      chatWrapper: this.#chatWrapper,
      systemPrompt: await this.calculateSystemPrompt(),
    })

    // const channels = await bot.channels
    // const channel = channels.find((v) => v.tag === "public") || channels[0]
    // if (channel) {
    //   await bot.sendMessage(
    //     channel.uid,
    //     await this.cleanResponse(
    //       await this.#session.prompt(
    //         "Welcome to chat, greet everyone\n",
    //         optionsGenerator(this.#model, this.debugLogResponses),
    //       ),
    //       null,
    //       bot,
    //     ),
    //   )
    //   this.#session.resetChatHistory()
    // }

    // logger.info("LLamaHandler bound to bot")
  }
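
  // Channel list cache plus lazy join-mode lookup: DMs default to join mode
  // (reply to everything); other channels default to mention-only.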
  cachedChannels: Channel[] | null = null

  async isInJoinMode(channelUID: string, bot: Bot): Promise<boolean> {
    if (!this.joinMode.has(channelUID)) {
      if (!this.cachedChannels) {
        this.cachedChannels = await bot.channels
      }
      const channel = this.cachedChannels?.find((c) => c.uid === channelUID)
      if (channel) {
        this.joinMode.set(channelUID, channel.tag === "dm")
      } else {
        // Refresh the cache so a newly created channel is found next time.
        logger.warn("Channel not found in cached channels", { channelUID })
        this.cachedChannels = await bot.channels
      }
    }
    return this.joinMode.get(channelUID) ?? false
  }

  override async isMatch(message: Message): Promise<boolean> {
    // Respond to final messages from other users, either in joined channels
    // or when the bot is mentioned by name.
    return message.userUID !== this.user?.uid &&
      message.isFinal &&
      (await this.isInJoinMode(message.channelUID, message.bot) ||
        trim(message?.message, " `").toLowerCase().includes(this.prefix))
  }
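
  // Strip model artifacts from a response before it is sent to chat:
  // <think> blocks, wrapping quotes, and echoed "@user:" / "user:" prefixes.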
  async cleanResponse(
    response: string,
    message: Message | null,
    bot: Bot,
  ): Promise<string> {
    const user = this.user

    // Drop <think>...</think> reasoning blocks and surrounding quotes.
    response = trim(response.replace(/<think>.*?<\/think>/gs, ""), '" \t')
    let lwResponse = response.toLowerCase()

    if (lwResponse.startsWith("ai")) {
      response = response.substring(2).trim()
      lwResponse = response.toLowerCase()
    }

    // Strip echoed "@user:" / "user:" prefixes for the bot itself and for
    // the message author.
    const prefixes = [
      user ? `@${user.username.toLowerCase()}:` : null,
      message ? `@${message.username.toLowerCase()}:` : null,
      user ? `${user.username.toLowerCase()}:` : null,
      message ? `${message.username.toLowerCase()}:` : null,
    ].filter((p): p is string => p !== null)

    for (const prefix of prefixes) {
      if (lwResponse.startsWith(prefix)) {
        response = response.substring(prefix.length).trim()
        lwResponse = response.toLowerCase()
      }
    }

    response = trimStart(response, ":").trim()
    response = trim(response, '"')
    return response
  }
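
  // Subcommand: toggle join mode for the message's channel.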
  async join(command: string, message: Message, bot: Bot): Promise<void> {
    this.joinMode.set(
      message.channelUID,
      !await this.isInJoinMode(message.channelUID, bot),
    )
  }
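
  // Subcommand: toggle streaming of in-progress responses (bot-wide).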
  async stream(command: string, message: Message, bot: Bot): Promise<void> {
    this.streamMode = !this.streamMode
  }
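
  // Subcommand: wipe the chat history and announce the reset to the channel.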
  async reset(command: string, message: Message, bot: Bot): Promise<void> {
    const session = this.#session
    if (!session) {
      return
    }

    await session.resetChatHistory()
    await bot.sendMessage(
      message.channelUID,
      await this.cleanResponse(
        await session.prompt(
          "Your memory was just reset. Welcome to chat, greet everyone\n",
          optionsGenerator(this.#model!, this.debugLogResponses),
        ),
        message,
        bot,
      ),
    )
    // Reset again so the greeting itself doesn't linger in history.
    await session.resetChatHistory()
  }
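
  // Prompt the model with the incoming message. In stream mode, the partial
  // response is flushed to chat on an interval and the sent message is edited
  // in place as more tokens arrive; the final, cleaned response replaces it.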
  async prompt(command: string, message: Message, bot: Bot): Promise<void> {
    const session = this.#session
    const user = this.user
    if (!session || !user) {
      return
    }

    let msgSoFar = ""

    let sentMessageInfo: Promise<unknown> | null = null

    let streamId: number | undefined = undefined

    if (this.streamMode) {
      streamId = setInterval(async () => {
        // Skip until there is something worth sending (not just a bare
        // "@mention" with no text yet).
        if (
          msgSoFar.length < 1 ||
          (msgSoFar.startsWith("@") && !msgSoFar.includes(" "))
        ) {
          return
        }
        // Chain onto the previous send/update so edits stay ordered; fall
        // back to sending a fresh message if the update fails.
        sentMessageInfo = sentMessageInfo?.then(async (msgInfo) => {
          const msg = await this.cleanResponse(msgSoFar, message, bot)
          try {
            const msgRes = await bot.updateMessageText(
              msgInfo as string,
              msg,
            )
            if ("error" in msgRes) {
              console.error("Error updating message text", msgRes)
              msgSoFar = ""
              return bot.sendMessage(message.channelUID, msg)
            }
          } catch (data) {
            console.error(data)
          }
          return msgInfo
        }) ??
          bot.sendMessage(
            message.channelUID,
            await this.cleanResponse(msgSoFar, message, bot),
          )
      }, 50)
    }

    logger.info("Prompting model", {
      model: this.#activeModel,
      command,
      message,
      user: user.username,
    })
    let response = await session.prompt(
      `@${message.username}: ${message.message}`,
      optionsGenerator(this.#model!, this.debugLogResponses, 5 * 60 * 1000, {
        onTextChunk: (text) => {
          // Keep the typing indicator alive while tokens stream in.
          bot.setTyping(message.channelUID).catch(() => {})

          msgSoFar += text

          if (this.streamMode) {
            // sentMessageInfo = sentMessageInfo
            //   ? sentMessageInfo.then(([msgInfo, msg]) => {
            //     const newMsg = msg + text
            //     return this.cleanResponse(newMsg, message, bot).then((msg) => {
            //       return bot.updateMessageText(
            //         msgInfo,
            //         msg,
            //       ).catch(console.error).then(() => [msgInfo, newMsg])
            //     })
            //   })
            //   : bot.sendMessage(
            //     message.channelUID,
            //     text,
            //   ).then((msg) => [msg as string, text])
          }
        },
        functions: {
          generateImage: defineChatSessionFunction({
            description: "Generate an image from a prompt",
            params: {
              type: "object",
              properties: {
                prompt: {
                  type: "string",
                  description: "The prompt to generate the image from",
                },
              },
              required: ["prompt"],
            },
            // Delegate to the image bot by posting a command in the channel.
            handler: async ({ prompt }) => {
              bot.sendMessage(message.channelUID, "@abot prompt " + prompt)
            },
          }),
          now: defineChatSessionFunction({
            description: "Get the current time",
            params: {
              type: "object",
              properties: {},
            },
            handler: async () => {
              const now = new Date()
              return now.toLocaleString()
            },
          }),
          // Merge in remotely served tool definitions; fail soft if the
          // fetch errors out.
          ...(await fetchRemoteFunctions().catch((e) => {
            logger.error("Failed to fetch remote functions", { error: e })
            return ({})
          })),
        },
      }),
    )

    response = await this.cleanResponse(response, message, bot)

    clearInterval(streamId)

    if (sentMessageInfo) {
      // A streamed message already exists; replace its text with the final,
      // cleaned response.
      const msgInfo = await sentMessageInfo

      await bot.updateMessageText(
        msgInfo as string,
        response,
      )
    } else {
      await bot.sendMessage(message.channelUID, response)
    }
  }
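
  // Route an incoming message: "@<botname> <subcommand> ..." dispatches to a
  // registered subcommand; anything else is treated as a prompt.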
  override async handleMessage(message: Message, bot: Bot): Promise<void> {
    const user = this.user
    if (!user) {
      return
    }

    if (
      message.message.toLowerCase().startsWith(
        `@${user.username.toLowerCase()}`,
      )
    ) {
      const newMessage = message.message.substring(
        message.message.indexOf(" "),
      ).trim()
      if (newMessage && message.message.includes(" ")) {
        // Split off the first word as the subcommand; the rest is its
        // argument.
        const [[, command, rest]] = newMessage.matchAll(/^(\S+)\s*(.*)$/gs)

        if (command in this.#subCommands) {
          return this.#subCommands[command]?.(rest.trim(), message, bot)
        }
      }
    }

    return this.prompt(message.message, message, bot)
  }
}