anya/core/message-processor.ts

441 lines
14 KiB
TypeScript
Raw Permalink Normal View History

2024-10-06 13:56:33 +05:30
import { PlatformAdapter } from "../interfaces/platform-adapter";
import { Message, SentMessage } from "../interfaces/message";
import { getTools, zodFunction } from "../tools";
import OpenAI from "openai";
import { createHash } from "crypto";
import { format } from "date-fns";
import { saveApiUsage } from "../usage";
import { buildSystemPrompts } from "../assistant/system-prompts";
import YAML from "yaml";
import { ask, get_transcription } from "../tools/ask";
import { z } from "zod";
import { send_sys_log } from "../interfaces/log";
/**
 * Bookkeeping for the request that is currently being processed in a channel.
 */
interface MessageQueueEntry {
  /**
   * Set to true once the model starts calling tools. While true, newer
   * messages in the same channel are ignored instead of aborting this request.
   */
  runningTools: boolean;
  /** Controller whose signal is handed to the OpenAI runner so the request can be cancelled. */
  abortController: AbortController;
}
/**
 * Routes an incoming platform message through the OpenAI tool runner and sends
 * the model's answer back through the message's own `send` method.
 *
 * Per channel, at most one request is tracked in `messageQueue`:
 * - a newer message aborts the tracked request unless tools are already running,
 *   in which case the newer message is ignored;
 * - "stop" / "reset" clears the remembered tool calls for the channel and marks
 *   the point from which history is considered.
 */
export class MessageProcessor {
  /** Content of the marker message posted on "stop"/"reset"; older history is ignored. */
  private static readonly RESET_MARKER = "---setting this point as the start---";
  /** How long a request may run (without tools running) before it is aborted. */
  private static readonly PROCESSING_TIMEOUT_MS = 10 * 60 * 1000;
  /** Number of messages requested from the adapter as conversation history. */
  private static readonly HISTORY_LIMIT = 50;
  /** Conversations longer than this many AI messages are summarised. */
  private static readonly SUMMARY_THRESHOLD = 25;
  /** Number of most recent AI messages kept verbatim when summarising. */
  private static readonly SUMMARY_TAIL_SIZE = 10;

  private openai: OpenAI;
  private model: string = "gpt-4o";
  /** In-flight request per channel id. */
  private messageQueue: Map<string, MessageQueueEntry> = new Map();
  /** Tool-call messages keyed by the SHA-256 of the raw content of the message that triggered them. */
  private toolsCallMap: Map<string, OpenAI.Chat.ChatCompletionMessageParam[]> =
    new Map();
  /** Content hashes seen per channel; used to purge `toolsCallMap` on reset. */
  private channelIdHashMap: Map<string, string[]> = new Map();

  constructor(private adapter: PlatformAdapter) {
    this.openai = new OpenAI({
      apiKey: process.env.OPENAI_API_KEY!,
    });
  }

  /**
   * Handles one incoming message end to end.
   *
   * Errors raised while processing are logged and swallowed (the user is not
   * notified); the status message, if any, is removed.
   *
   * @param message The incoming platform message.
   */
  public async processMessage(message: Message): Promise<void> {
    const userId = message.author.id;
    // Fall back to the author's id when the message carries no channel id.
    const channelId = message.channelId || userId;

    // "stop" / "reset": forget remembered tool calls and mark a new starting point.
    if (["stop", "reset"].includes(message.content.toLowerCase())) {
      if (message.platform !== "whatsapp") {
        await message.send({ content: MessageProcessor.RESET_MARKER });
      }
      const hashes = this.channelIdHashMap.get(channelId) ?? [];
      for (const hash of hashes) {
        this.toolsCallMap.delete(hash);
      }
      this.channelIdHashMap.set(channelId, []);
      return;
    }

    const inFlight = this.messageQueue.get(channelId);
    if (inFlight) {
      if (inFlight.runningTools) {
        // Tools are running: do not abort them, drop this message instead.
        return;
      }
      inFlight.abortController.abort();
      this.messageQueue.delete(channelId);
    }

    const abortController = new AbortController();
    const queueEntry: MessageQueueEntry = {
      abortController,
      runningTools: false,
    };
    this.messageQueue.set(channelId, queueEntry);

    // Abort if the request is still pending after the timeout. The identity
    // check makes sure a stale timer never touches a newer request that has
    // since taken over the channel.
    const timeoutHandle = setTimeout(async () => {
      if (
        this.messageQueue.get(channelId) !== queueEntry ||
        queueEntry.runningTools
      ) {
        return;
      }
      abortController.abort();
      this.messageQueue.delete(channelId);
      try {
        await message.send({ content: "Timed out." });
      } catch (sendError) {
        console.error("Failed to send timeout notice:", sendError);
      }
    }, MessageProcessor.PROCESSING_TIMEOUT_MS);

    // Status ("thinking...") message of THIS request only; keeping it local
    // prevents concurrent requests from editing/deleting each other's message.
    let statusMessage: SentMessage | null = null;
    try {
      if (message.platformAdapter.config.indicators.typing) {
        await message.sendTyping();
      }

      const history = await this.adapter.fetchMessages(channelId, {
        limit: MessageProcessor.HISTORY_LIMIT,
      });

      if (message.platformAdapter.config.indicators.processing) {
        statusMessage = await message.send({ content: "thinking..." });
      }

      // Only the messages in front of the first reset marker are used.
      const stopIndex = history.findIndex((entry) =>
        this.isResetMarker(entry.content)
      );
      const effectiveHistory =
        stopIndex !== -1 ? history.slice(0, stopIndex) : history;

      const aiMessages = await this.constructAIMessages(
        effectiveHistory,
        message,
        channelId
      );

      const response = await this.runAI(
        aiMessages,
        message.author.username,
        message,
        abortController,
        channelId,
        statusMessage
      );

      if (response && !response.includes("<NOREPLY>")) {
        // The model may answer with a JSON envelope; prefer its user_message.
        const parsed = this.isJsonParseable(response);
        if (
          parsed &&
          typeof parsed.user_message === "string" &&
          parsed.user_message
        ) {
          await message.send({ content: parsed.user_message });
        } else {
          await message.send({ content: response });
        }
      }

      if (statusMessage) {
        if (statusMessage.deletable) {
          await statusMessage.delete();
        } else {
          // Not deletable: overwrite it to show that processing finished.
          await statusMessage.edit({ content: "Response sent." });
        }
      }
    } catch (error) {
      // Deliberately not reported to the user; only logged.
      console.error("Error processing message:", error);
      if (statusMessage) {
        try {
          await statusMessage.delete();
        } catch (cleanupError) {
          console.error("Failed to remove status message:", cleanupError);
        }
      }
    } finally {
      clearTimeout(timeoutHandle);
      // Only remove our own entry; a newer request may own the channel by now.
      if (this.messageQueue.get(channelId) === queueEntry) {
        this.messageQueue.delete(channelId);
      }
    }
  }

  /**
   * Tells whether a history message marks the point where the conversation was
   * reset: either the marker posted by `processMessage`, or a user's own
   * "stop"/"reset" (case-insensitive, "!" ignored) for platforms where no
   * marker message is posted.
   */
  private isResetMarker(content: unknown): boolean {
    if (typeof content !== "string") {
      return false;
    }
    if (content === MessageProcessor.RESET_MARKER) {
      return true;
    }
    const normalized = content.replace(/!/g, "").trim().toLowerCase();
    return normalized === "stop" || normalized === "reset";
  }

  /**
   * Converts platform history into OpenAI chat messages, prefixed by the system
   * prompts built for `message`.
   *
   * Each history message becomes one JSON-encoded chat message; tool calls
   * remembered for a message (see `toolsCallMap`) are re-inserted right after it.
   * Conversations longer than `SUMMARY_THRESHOLD` are summarised.
   *
   * @param history   Messages as returned by the adapter; reversed (on a copy)
   *                  before being sent to the model.
   * @param message   The message being answered (used for the system prompts).
   * @param channelId Channel used for context lookups and hash bookkeeping.
   */
  private async constructAIMessages(
    history: Message[],
    message: Message,
    channelId: string
  ): Promise<OpenAI.Chat.ChatCompletionMessageParam[]> {
    const systemMessages: OpenAI.Chat.ChatCompletionMessageParam[] =
      await buildSystemPrompts(message);

    // Reverse a copy so the caller's array is left untouched.
    const orderedHistory = [...history].reverse();

    const aiMessagesArrays = await Promise.all(
      orderedHistory.map(async (msg) => {
        const role =
          msg.author.id === this.adapter.getBotId() ? "assistant" : "user";

        const attachments = msg.attachments || [];
        const files = attachments
          .filter((a) => !a.url.includes("voice-message.ogg"))
          .map((a) => a.url);
        const embeds = (msg.embeds || [])
          .map((e) => JSON.stringify(e))
          .join("\n");

        // Transcribe voice messages; WhatsApp passes the attachment data plus
        // the message id, other platforms pass the attachment URL.
        const isWhatsapp = msg.platform === "whatsapp";
        const voiceMessages = await Promise.all(
          attachments
            .filter(
              (a) => a.url.includes("voice-message.ogg") || a.type === "ptt"
            )
            .map(async (a) => ({
              file: a.url,
              transcription: await get_transcription(
                isWhatsapp ? (a.data as string) : a.url,
                isWhatsapp,
                isWhatsapp ? msg.id : undefined
              ),
            }))
        );

        // Resolve the message referenced by threadId, fetching it when it is
        // not part of the loaded history.
        let contextMessage: Message | null | undefined = null;
        if (msg.threadId) {
          contextMessage = history.find((m) => m.id === msg.threadId);
          if (!contextMessage) {
            contextMessage = await this.adapter.fetchMessageById(
              channelId,
              msg.threadId
            );
          }
        }

        const contextAsJson = JSON.stringify({
          embeds: embeds || undefined,
          files: files.length > 0 ? files : undefined,
          user_message: msg.content,
          user_voice_messages:
            voiceMessages.length > 0 ? voiceMessages : undefined,
          created_at: format(msg.timestamp, "yyyy-MM-dd HH:mm:ss") + " IST",
          context_message: contextMessage
            ? {
                author: contextMessage.author.username,
                created_at:
                  format(contextMessage.timestamp, "yyyy-MM-dd HH:mm:ss") +
                  " IST",
                content: contextMessage.content,
              }
            : undefined,
          context_files:
            contextMessage?.attachments?.map((a) => a.url) || undefined,
          context_embeds:
            contextMessage?.embeds?.map((e) => JSON.stringify(e)).join("\n") ||
            undefined,
        });

        const user = this.adapter.getUserById(msg.author.id);
        const aiMessage: OpenAI.Chat.ChatCompletionMessageParam = {
          role,
          content: contextAsJson,
          name:
            user?.name ||
            msg.author.username.replace(/\s+/g, "_").substring(0, 64),
        };

        // Re-attach tool calls remembered for this message, if any.
        const calls = this.toolsCallMap.get(this.generateHash(msg.content));
        return calls ? [aiMessage, ...calls] : [aiMessage];
      })
    );

    let aiMessages = aiMessagesArrays.flat();

    // Remember the content hashes of this channel (without duplicates) so a
    // later reset can purge the matching toolsCallMap entries.
    const knownHashes = new Set(this.channelIdHashMap.get(channelId) ?? []);
    for (const msg of history) {
      knownHashes.add(this.generateHash(msg.content));
    }
    this.channelIdHashMap.set(channelId, [...knownHashes]);

    if (aiMessages.length > MessageProcessor.SUMMARY_THRESHOLD) {
      aiMessages = await this.summarizeConversation(aiMessages);
    }

    return systemMessages.concat(aiMessages);
  }

  /**
   * Shortens a long conversation: the most recent messages are kept verbatim
   * and everything before them is replaced by one system message containing a
   * model-generated summary.
   *
   * The kept tail is extended backwards when it would otherwise start with a
   * tool/function result, so a result is never separated from the assistant
   * message that requested it.
   *
   * @param messages Conversation messages, oldest first.
   * @returns The summary message followed by the kept tail; `messages`
   *          unchanged when there is nothing before the tail.
   */
  private async summarizeConversation(
    messages: OpenAI.Chat.ChatCompletionMessageParam[]
  ): Promise<OpenAI.Chat.ChatCompletionMessageParam[]> {
    let tailStart = Math.max(
      0,
      messages.length - MessageProcessor.SUMMARY_TAIL_SIZE
    );
    while (
      tailStart > 0 &&
      (messages[tailStart].role === "tool" ||
        messages[tailStart].role === "function")
    ) {
      tailStart--;
    }
    if (tailStart === 0) {
      return messages;
    }

    const olderMessages = messages.slice(0, tailStart);
    const recentMessages = messages.slice(tailStart);

    const summaryResponse = await ask({
      model: "gpt-4o-mini",
      prompt: `Summarize the below conversation into 2 sections:
1. General info about the conversation
2. Tools used in the conversation and their data in relation to the conversation.
Conversation:
----
${YAML.stringify(olderMessages)}
----
Notes:
- Keep only important information and points, remove anything repetitive.
- Keep tools information if they are relevant.
- The summary is to give context about the conversation that was happening previously.
`,
    });
    const summaryContent = summaryResponse.choices[0].message.content;

    return [
      {
        role: "system",
        content: `Previous messages summarized:
${summaryContent}
`,
      },
      ...recentMessages,
    ];
  }

  /**
   * Runs the conversation through the OpenAI tool runner and returns the final
   * assistant text ("" when there is none).
   *
   * Side effects: marks the channel's queue entry as running tools on the first
   * function call, reports progress/errors through `statusMessage` (or a new
   * message when there is none), records API usage, and remembers the tool-call
   * messages of this run under the hash of `message.content`.
   *
   * @param messages        System prompts plus conversation, oldest first.
   * @param username        Passed as the OpenAI `user` and to `getTools`.
   * @param message         The message being answered.
   * @param abortController Its signal cancels the OpenAI request.
   * @param channelId       Channel whose queue entry is updated.
   * @param statusMessage   Progress message of this request, if one was sent.
   */
  private async runAI(
    messages: OpenAI.Chat.ChatCompletionMessageParam[],
    username: string,
    message: Message,
    abortController: AbortController,
    channelId: string,
    statusMessage: SentMessage | null = null
  ): Promise<string> {
    const changeModel = async ({ model }: { model: string }) => {
      this.model = model;
      console.log("Model changed to", model);
      return { message: "Model changed to " + model };
    };

    const tools = getTools(username, message, "self");
    const toolCalls: OpenAI.Chat.ChatCompletionMessageParam[] = [];
    console.log("Current Model", this.model);

    const runner = this.openai.beta.chat.completions
      .runTools(
        {
          model: this.model,
          temperature: 0.6,
          user: username,
          messages,
          stream: true,
          tools: [
            zodFunction({
              name: "changeModel",
              schema: z.object({
                // z.enum validates the value; z.string(z.enum(...)) did not.
                model: z.enum(["gpt-4o-mini", "gpt-4o"]),
              }),
              function: changeModel,
              description: `Change the model at run time.
Default Model is 'gpt-4o-mini'.
Current Model: ${this.model}
Switch to 'gpt-4o' before running any other tool.
Try to switch back to 'gpt-4o-mini' after running tools.
`,
            }),
            ...tools,
          ],
        },
        { signal: abortController.signal }
      )
      .on("functionCall", async (fnc) => {
        console.log("Function call:", fnc);
        send_sys_log(`calling function: ${fnc.name}, in channel ${channelId}`);
        // From now on newer messages must not abort this request.
        const queueEntry = this.messageQueue.get(channelId);
        if (queueEntry) {
          queueEntry.runningTools = true;
        }
        try {
          const progress = { content: `Running ${fnc.name}...` };
          if (statusMessage) {
            await statusMessage.edit(progress);
          } else {
            await message.send(progress);
          }
        } catch (progressError) {
          // A failed progress update must not surface as an unhandled rejection.
          console.error("Failed to report tool progress:", progressError);
        }
      })
      .on("message", (m) => {
        const isToolRequest =
          m.role === "assistant" &&
          (m.function_call || (m as any).tool_calls?.length);
        const isToolResult =
          (m.role === "function" || m.role === "tool") &&
          ((m as any).function_call || (m as any).tool_call_id);
        if (isToolRequest || isToolResult) {
          toolCalls.push(m);
        }
      })
      .on("error", (err) => {
        console.error("Error:", err);
        send_sys_log(`Error: ${err}, in channel ${channelId}`);
        const notice = { content: "Error: " + JSON.stringify(err) };
        const delivery = statusMessage
          ? statusMessage.edit(notice)
          : message.send(notice);
        Promise.resolve(delivery).catch((deliveryError) => {
          console.error("Failed to report error:", deliveryError);
        });
      })
      .on("abort", () => {
        send_sys_log(`Aborting in channel ${channelId}`);
        console.log("Aborted");
      })
      .on("totalUsage", (stat) => {
        send_sys_log(`Usage: ${JSON.stringify(stat)}, in channel ${channelId}`);
        saveApiUsage(
          format(new Date(), "yyyy-MM-dd"),
          this.model,
          stat.prompt_tokens,
          stat.completion_tokens
        );
      });

    const finalContent = await runner.finalContent();

    // Key by the RAW content of the triggering message: that is what
    // constructAIMessages hashes when it looks tool calls up, and what the
    // reset handler hashes when it purges them.
    const hash = this.generateHash(message.content);
    this.toolsCallMap.set(hash, toolCalls);

    return finalContent ?? "";
  }

  /**
   * Parses `str` as JSON.
   *
   * @returns The parsed value, or `false` when `str` is not valid JSON. Note
   *          that valid JSON may itself parse to a falsy value.
   */
  private isJsonParseable(str: string) {
    try {
      return JSON.parse(str);
    } catch (e) {
      return false;
    }
  }

  /** Returns the hex-encoded SHA-256 digest of `input`. */
  private generateHash(input: string): string {
    return createHash("sha256").update(input).digest("hex");
  }
}