import {BaseHandler} from "./base-handler.ts"
import {Bot, Message} from "../snek/snek-socket.ts"
import {trim, trimStart} from "npm:lodash-es"
import {
  ChatSessionModelFunctions,
  ChatWrapper,
  GeneralChatWrapper,
  getLlama,
  LLamaChatPromptOptions,
  LlamaChatSession,
  LlamaContext,
  LlamaModel,
  resolveChatWrapper,
  Token,
} from "npm:node-llama-cpp"
import {getLogger} from "@logtape/logtape"
import {deepMerge} from "@std/collections/deep-merge"

const llama = await getLlama()

const textEncoder = new TextEncoder()
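
// Writes to stdout synchronously so streamed tokens appear immediately;
// loops because Deno's writeSync may write fewer bytes than requested.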
function printSync(input: string | Uint8Array, to = Deno.stdout) {
  let bytesWritten = 0
  const bytes = typeof input === "string" ? textEncoder.encode(input) : input
  while (bytesWritten < bytes.length) {
    bytesWritten += to.writeSync(bytes.subarray(bytesWritten))
  }
}

const logger = getLogger(["llama-gen-handler"])
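
// Builds prompt options for LlamaChatSession.prompt(): caller overrides
// are deep-merged onto defaults (a repeat penalty that exempts "@"
// mentions, a generation timeout, and optional streaming debug output).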
const optionsGenerator = <
  const Functions extends ChatSessionModelFunctions | undefined = undefined,
  LLamaOptions = LLamaChatPromptOptions<Functions>,
>(
  model: LlamaModel,
  debugOutput: boolean = true,
  defaultTimeout = 5 * 60 * 1000,
  options?: LLamaOptions,
): LLamaChatPromptOptions<Functions> => {
  const manager = AbortSignal.timeout(defaultTimeout)

  const defaultOptions: LLamaChatPromptOptions<Functions> = {
    repeatPenalty: {
      lastTokens: 24,
      penalty: 1.12,
      penalizeNewLine: true,
      frequencyPenalty: 0.02,
      presencePenalty: 0.02,
      punishTokensFilter: (tokens: Token[]) => {
        return tokens.filter((token) => {
          const text = model.detokenize([token])

          // allow the model to repeat tokens that contain "@",
          // so mentions are not suppressed by the repeat penalty
          return !text.toLowerCase().includes("@")
          // TODO: Exclude usernames
        })
      },
    },
    temperature: 0.6,

    signal: manager,
    stopOnAbortSignal: true,
  }

  if (debugOutput) {
    // Stream each chunk to stdout, marking segment (e.g. "thought")
    // boundaries so reasoning output is visible while debugging.
    defaultOptions.onResponseChunk = (chunk) => {
      if (chunk.type === "segment" && chunk.segmentStartTime != null) {
        printSync(` [segment start: ${chunk.segmentType}] `)
      }

      printSync(chunk.text)

      if (chunk.type === "segment" && chunk.segmentEndTime != null) {
        printSync(` [segment end: ${chunk.segmentType}] `)
      }
    }
  }

  return deepMerge(defaultOptions, options ?? {})
}

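// Chat handler that answers mentions using a locally loaded
// node-llama-cpp model.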
export class LLamaHandler extends BaseHandler {
  // When true, respond to every message instead of only to mentions.
  joinMode = false
  debugLogResponses = true

  systemPrompt: string

  #activeModel: string

  #model: LlamaModel | null = null
  #context: LlamaContext | null = null

  #chatWrapper: ChatWrapper | null = null

  #session: LlamaChatSession | null = null

  constructor(
    activeModel: string,
    systemPrompt: string = "You are an AI chatbot.",
  ) {
    // The mention prefix is filled in once the bot user is known, in bind().
    super("")
    this.#activeModel = activeModel
    this.systemPrompt = systemPrompt
  }

  async calculateSystemPrompt(): Promise<string> {
    return this.systemPrompt
  }

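  // Loads the model, resolves a chat wrapper, starts the chat session,
  // and posts a generated greeting to the first public channel.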
  override async bind(bot: Bot): Promise<void> {
    await super.bind(bot)
    this.prefix = "@" + this.user?.username.toLowerCase()

    this.#model = await llama.loadModel({
      modelPath: this.#activeModel,
      defaultContextFlashAttention: true,
    })

    this.#context = await this.#model.createContext({
      flashAttention: true,
    })

    logger.info("Model loaded", {
      batchSize: this.#context.batchSize,
      contextSize: this.#context.contextSize,
    })

    // Resolve a chat wrapper matching the model's prompt format (a
    // Llama3ChatWrapper could be forced here instead); fall back to
    // GeneralChatWrapper when none can be resolved.
    this.#chatWrapper = resolveChatWrapper({
      bosString: this.#model.tokens.bosString,
      filename: this.#model.filename,
      fileInfo: this.#model.fileInfo,
      tokenizer: this.#model.tokenizer,
    }) ?? new GeneralChatWrapper()

    this.#session = new LlamaChatSession({
      contextSequence: this.#context.getSequence(),
      chatWrapper: this.#chatWrapper,
      systemPrompt: await this.calculateSystemPrompt(),
    })

    const channels = await bot.channels
    const channel = channels.find((v) => v.tag === "public") || channels[0]

    if (channel) {
      // Send a generated greeting, then reset the history so the greeting
      // exchange does not leak into later conversations.
      await bot.sendMessage(
        channel.uid,
        await this.#session.prompt(
          "Welcome to chat, greet everyone\n",
          optionsGenerator(this.#model, this.debugLogResponses),
        ),
      )
      this.#session.resetChatHistory()
    }
  }

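  // Match messages from other users that mention the bot (or every
  // message when joinMode is enabled).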
  override async isMatch(message: Message): Promise<boolean> {
    return message.userUID !== this.user?.uid && (this.joinMode ||
      trim(message.message, " `").toLowerCase().includes(this.prefix))
  }

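  // Prompt the model with the incoming message and post the cleaned-up
  // reply back to the channel.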
  override async handleMessage(message: Message, bot: Bot): Promise<void> {
    const session = this.#session
    const user = this.user
    if (!session || !user) {
      return
    }

    let response = await session.prompt(
      `@${message.username}: ${message.message}`,
      optionsGenerator(this.#model!, this.debugLogResponses),
    )

    // Strip <think>...</think> reasoning blocks from the reply.
    response = response.replace(/<think>.*?<\/think>/gs, "")
    let lwResponse = response.toLowerCase()

    // Strip a leading "AI" label some models prepend.
    if (lwResponse.startsWith("ai")) {
      response = response.substring(2).trim()
      lwResponse = response.toLowerCase()
    }

    // Drop a leading self-mention the model sometimes echoes back.
    if (lwResponse.startsWith(`@${user.username.toLowerCase()}`)) {
      response = response.substring(user.username.length + 1).trim()
      lwResponse = response.toLowerCase()
    }

    // Drop a leading mention of the user being replied to.
    if (lwResponse.startsWith(`@${message.username.toLowerCase()}`)) {
      response = response.substring(message.username.length + 1).trim()
      lwResponse = response.toLowerCase()
    }

    response = trimStart(response, ":").trim()
    bot.send("send_message", message.channelUID, response)
  }
}
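
// Usage sketch (hypothetical wiring; how a Bot is constructed depends on
// snek-socket's actual API, so treat this as an assumption):
//
//   const bot = new Bot(/* connection options */)
//   const handler = new LLamaHandler("./models/example.gguf")
//   await handler.bind(bot)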