Custom LLM Agent (with a ChatModel)
This notebook goes through how to create your own custom agent based on a chat model.
An LLM chat agent consists of four parts:
- PromptTemplate: This is the prompt template that can be used to instruct the language model on what to do
- ChatModel: This is the language model that powers the agent
- Stop sequence: Instructs the LLM to stop generating as soon as this string is found
- OutputParser: This determines how to parse the LLM output into an AgentAction or AgentFinish object
The LLMAgent is used in an AgentExecutor. This AgentExecutor can largely be thought of as a loop that:
- Passes user input and any previous steps to the Agent (in this case, the LLMAgent)
- If the Agent returns an AgentFinish, then return that directly to the user
- If the Agent returns an AgentAction, then use that to call a tool and get an Observation
- Repeat, passing the AgentAction and Observation back to the Agent until an AgentFinish is emitted.
AgentAction is a response that consists of tool and toolInput. tool refers to which tool to use, and toolInput refers to the input to that tool. log can also be provided as more context (that can be used for logging, tracing, etc).
AgentFinish is a response that contains the final message to be sent back to the user. This should be used to end an agent run.
With LCEL
- npm
- Yarn
- pnpm
npm install @langchain/openai
yarn add @langchain/openai
pnpm add @langchain/openai
import { AgentExecutor } from "langchain/agents";
import { formatLogToString } from "langchain/agents/format_scratchpad/log";
import { ChatOpenAI } from "@langchain/openai";
import { Calculator } from "@langchain/community/tools/calculator";
import { PromptTemplate } from "@langchain/core/prompts";
import { AgentAction, AgentFinish, AgentStep } from "@langchain/core/agents";
import { BaseMessage, HumanMessage } from "@langchain/core/messages";
import { InputValues } from "@langchain/core/memory";
import { RunnableSequence } from "@langchain/core/runnables";
import { SerpAPI } from "@langchain/community/tools/serpapi";
/**
 * Build the chat model first, then attach the stop sequence to it.
 * @important Always bind the stop sequence; without it the LLM will happily keep generating text.
 */
const chatModel = new ChatOpenAI({ temperature: 0 });
const model = chatModel.bind({ stop: ["\nObservation"] });
/** Tools the agent may pick from: a SerpAPI search and a calculator. */
const serpApiParams = {
  location: "Austin,Texas,United States",
  hl: "en",
  gl: "us",
};
const tools = [
  new SerpAPI(process.env.SERPAPI_API_KEY, serpApiParams),
  new Calculator(),
];
/** Prefix prompt: introduces the task and leaves a `{tools}` slot for the tool list. */
const PREFIX = [
  "Answer the following questions as best you can. You have access to the following tools:",
  "{tools}",
].join("\n");
/** Tool instructions prompt: the response format, with a `{tool_names}` slot. */
const TOOL_INSTRUCTIONS_TEMPLATE = [
  "Use the following format in your response:",
  "Question: the input question you must answer",
  "Thought: you should always think about what to do",
  "Action: the action to take, should be one of [{tool_names}]",
  "Action Input: the input to the action",
  "Observation: the result of the action",
  "... (this Thought/Action/Action Input/Observation can repeat N times)",
  "Thought: I now know the final answer",
  "Final Answer: the final answer to the original input question",
].join("\n");
/** Suffix prompt: carries the `{input}` question and opens the first thought. */
const SUFFIX = ["Begin!", "Question: {input}", "Thought:"].join("\n");
/**
 * Build the single prompt message sent to the chat model on each agent step.
 *
 * @param values - Must contain both `input` and `intermediate_steps` keys.
 * @returns A one-element array holding a `HumanMessage` whose text is the
 * prefix, the tool instructions, the suffix and the scratchpad joined by newlines.
 * @throws Error when `input` or `intermediate_steps` is absent from `values`.
 */
async function formatMessages(
  values: InputValues
): Promise<Array<BaseMessage>> {
  /** Check input and intermediate steps are both inside values */
  if (!("input" in values) || !("intermediate_steps" in values)) {
    throw new Error("Missing input or intermediate_steps from values.");
  }
  /** Extract and cast the intermediateSteps from values as Array<AgentStep> or an empty array if none are passed */
  const intermediateSteps = values.intermediate_steps
    ? (values.intermediate_steps as Array<AgentStep>)
    : [];
  /** Call the helper `formatLogToString` which returns the steps as a string  */
  const agentScratchpad = formatLogToString(intermediateSteps);
  /** Construct the tool strings */
  const toolStrings = tools
    .map((tool) => `${tool.name}: ${tool.description}`)
    .join("\n");
  const toolNames = tools.map((tool) => tool.name).join(",\n");
  /** Create templates for the prefix, instructions and suffix prompts */
  const prefixTemplate = new PromptTemplate({
    template: PREFIX,
    inputVariables: ["tools"],
  });
  const instructionsTemplate = new PromptTemplate({
    template: TOOL_INSTRUCTIONS_TEMPLATE,
    inputVariables: ["tool_names"],
  });
  const suffixTemplate = new PromptTemplate({
    template: SUFFIX,
    inputVariables: ["input"],
  });
  /** Format the three templates; they are independent, so run them together */
  const [formattedPrefix, formattedInstructions, formattedSuffix] =
    await Promise.all([
      prefixTemplate.format({ tools: toolStrings }),
      instructionsTemplate.format({ tool_names: toolNames }),
      suffixTemplate.format({ input: values.input }),
    ]);
  /** Construct the final prompt string; the scratchpad follows the trailing "Thought:" of the suffix */
  const formatted = [
    formattedPrefix,
    formattedInstructions,
    formattedSuffix,
    agentScratchpad,
  ].join("\n");
  /** Return the message as a HumanMessage. */
  return [new HumanMessage(formatted)];
}
/**
 * Custom output parser: turns the chat model's reply into the agent's next step.
 *
 * @param message - The model reply; its `content` must be a string.
 * @returns An `AgentFinish` (`returnValues.output` is the text after the last
 * "Final Answer:") when the reply contains "Final Answer:", otherwise an
 * `AgentAction` built from the "Action:" / "Action Input:" pair.
 * @throws Error when the content is not a string, or when neither a final
 * answer nor an action can be found in it.
 */
function customOutputParser(message: BaseMessage): AgentAction | AgentFinish {
  const text = message.content;
  if (typeof text !== "string") {
    throw new Error(
      `Message content is not a string. Received: ${JSON.stringify(
        text,
        null,
        2
      )}`
    );
  }
  /** If the input includes "Final Answer" return as an instance of `AgentFinish` */
  if (text.includes("Final Answer:")) {
    const parts = text.split("Final Answer:");
    const input = parts[parts.length - 1].trim();
    const finalAnswers = { output: input };
    return { log: text, returnValues: finalAnswers };
  }
  /**
   * Use RegEx to extract the action and its input. The `s` flag lets the
   * input span several lines; the tool capture is lazy so it ends at the
   * first "\nAction Input: " rather than swallowing text up to the last one.
   */
  const match = /Action: (.*?)\nAction Input: (.*)/s.exec(text);
  if (!match) {
    throw new Error(`Could not parse LLM output: ${text}`);
  }
  /** Return as an instance of `AgentAction` */
  return {
    tool: match[1].trim(),
    /** Strip surrounding double quotes from the tool input */
    toolInput: match[2].trim().replace(/^"+|"+$/g, ""),
    log: text,
  };
}
/**
 * Compose the agent with LCEL: pick the prompt inputs out of the executor's
 * values, format the prompt, call the model, then parse its reply.
 */
const runnable = RunnableSequence.from([
  {
    input: (executorValues: InputValues) => executorValues.input,
    intermediate_steps: (executorValues: InputValues) => executorValues.steps,
  },
  formatMessages,
  model,
  customOutputParser,
]);
/** Hand the runnable to `AgentExecutor` as its agent, together with the tools */
const executor = new AgentExecutor({ agent: runnable, tools });
console.log("Loaded agent.");
const input = `Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?`;
console.log(`Executing with input "${input}"...`);
const result = await executor.invoke({ input });
console.log(`Got output ${result.output}`);
/**
 * Got output Harry Styles' current age raised to the 0.23 power is approximately 2.1156502324195268.
 */
API Reference:
- AgentExecutor from langchain/agents
- formatLogToString from langchain/agents/format_scratchpad/log
- ChatOpenAI from @langchain/openai
- Calculator from @langchain/community/tools/calculator
- PromptTemplate from @langchain/core/prompts
- AgentAction from @langchain/core/agents
- AgentFinish from @langchain/core/agents
- AgentStep from @langchain/core/agents
- BaseMessage from @langchain/core/messages
- HumanMessage from @langchain/core/messages
- InputValues from @langchain/core/memory
- RunnableSequence from @langchain/core/runnables
- SerpAPI from @langchain/community/tools/serpapi
With LLMChain
import {
  AgentActionOutputParser,
  AgentExecutor,
  LLMSingleActionAgent,
} from "langchain/agents";
import { LLMChain } from "langchain/chains";
import { ChatOpenAI } from "@langchain/openai";
import { Calculator } from "@langchain/community/tools/calculator";
import {
  BaseChatPromptTemplate,
  SerializedBasePromptTemplate,
  renderTemplate,
} from "@langchain/core/prompts";
import { AgentAction, AgentFinish, AgentStep } from "@langchain/core/agents";
import { BaseMessage, HumanMessage } from "@langchain/core/messages";
import { InputValues } from "@langchain/core/memory";
import { PartialValues } from "@langchain/core/utils/types";
import { Tool } from "@langchain/core/tools";
import { SerpAPI } from "@langchain/community/tools/serpapi";
/** Opening line of the prompt; the tool list is appended after it. */
const PREFIX =
  "Answer the following questions as best you can. You have access to the following tools:";
/**
 * Build the response-format instructions.
 * @param toolNames - Text placed between the square brackets on the "Action:" line.
 * @returns The instruction lines joined by newlines.
 */
const formatInstructions = (toolNames: string) =>
  [
    "Use the following format in your response:",
    "Question: the input question you must answer",
    "Thought: you should always think about what to do",
    "Action: the action to take, should be one of [" + toolNames + "]",
    "Action Input: the input to the action",
    "Observation: the result of the action",
    "... (this Thought/Action/Action Input/Observation can repeat N times)",
    "Thought: I now know the final answer",
    "Final Answer: the final answer to the original input question",
  ].join("\n");
/** Closing part of the prompt, with `{input}` and `{agent_scratchpad}` slots. */
const SUFFIX = ["Begin!", "Question: {input}", "Thought:{agent_scratchpad}"].join(
  "\n"
);
/**
 * Chat prompt template that renders the tool list, the format instructions,
 * the user input and the agent scratchpad into a single human message.
 */
class CustomPromptTemplate extends BaseChatPromptTemplate {
  /** Tools whose names and descriptions are written into the prompt. */
  tools: Tool[];
  constructor(args: { tools: Tool[]; inputVariables: string[] }) {
    super({ inputVariables: args.inputVariables });
    this.tools = args.tools;
  }
  _getPromptType(): string {
    return "chat";
  }
  /**
   * Render the prompt for one agent step.
   * @param values - Template inputs; `intermediate_steps` (if present) is read
   * as the list of steps taken so far and the remaining keys fill the template.
   * @returns A one-element array containing the formatted `HumanMessage`.
   */
  async formatMessages(values: InputValues): Promise<BaseMessage[]> {
    /** Construct the final template */
    const toolStrings = this.tools
      .map((tool) => `${tool.name}: ${tool.description}`)
      .join("\n");
    const toolNames = this.tools.map((tool) => tool.name).join("\n");
    const instructions = formatInstructions(toolNames);
    const template = [PREFIX, toolStrings, instructions, SUFFIX].join("\n\n");
    /** Construct the agent_scratchpad; fall back to no steps when none are supplied */
    const intermediateSteps =
      (values.intermediate_steps as AgentStep[] | undefined) ?? [];
    const agentScratchpad = intermediateSteps.reduce(
      (thoughts, { action, observation }) =>
        thoughts +
        [action.log, `\nObservation: ${observation}`, "Thought:"].join("\n"),
      ""
    );
    /** Keys in `values` are spread last, so a caller-supplied agent_scratchpad wins */
    const newInput = { agent_scratchpad: agentScratchpad, ...values };
    /** Format the template. */
    const formatted = renderTemplate(template, "f-string", newInput);
    return [new HumanMessage(formatted)];
  }
  /** Partial application is not supported by this example template. */
  partial(_values: PartialValues): Promise<BaseChatPromptTemplate> {
    throw new Error("Not implemented");
  }
  /** Serialization is not supported by this example template. */
  serialize(): SerializedBasePromptTemplate {
    throw new Error("Not implemented");
  }
}
/**
 * Output parser that turns the LLM's text into the agent's next step.
 */
class CustomOutputParser extends AgentActionOutputParser {
  lc_namespace = ["langchain", "agents", "custom_llm_agent_chat"];
  /**
   * @param text - Raw text produced by the LLM.
   * @returns An `AgentFinish` (`returnValues.output` is the text after the last
   * "Final Answer:") when the text contains "Final Answer:", otherwise an
   * `AgentAction` built from the "Action:" / "Action Input:" pair.
   * @throws Error when neither a final answer nor an action can be found.
   */
  async parse(text: string): Promise<AgentAction | AgentFinish> {
    if (text.includes("Final Answer:")) {
      const parts = text.split("Final Answer:");
      const input = parts[parts.length - 1].trim();
      const finalAnswers = { output: input };
      return { log: text, returnValues: finalAnswers };
    }
    /**
     * The `s` flag lets the action input span several lines; the tool capture
     * is lazy so it ends at the first "\nAction Input: " rather than
     * swallowing text up to the last one.
     */
    const match = /Action: (.*?)\nAction Input: (.*)/s.exec(text);
    if (!match) {
      throw new Error(`Could not parse LLM output: ${text}`);
    }
    return {
      tool: match[1].trim(),
      /** Strip surrounding double quotes from the tool input */
      toolInput: match[2].trim().replace(/^"+|"+$/g, ""),
      log: text,
    };
  }
  /** Format instructions are not provided by this example parser. */
  getFormatInstructions(): string {
    throw new Error("Not implemented");
  }
}
/**
 * Assemble the LLMChain-based agent (prompt, chat model, output parser, tools),
 * run it on a sample question and log the output.
 */
export const run = async () => {
  const llm = new ChatOpenAI({ temperature: 0 });
  const search = new SerpAPI(process.env.SERPAPI_API_KEY, {
    location: "Austin,Texas,United States",
    hl: "en",
    gl: "us",
  });
  const tools = [search, new Calculator()];
  const prompt = new CustomPromptTemplate({
    tools,
    inputVariables: ["input", "agent_scratchpad"],
  });
  const llmChain = new LLMChain({ prompt, llm });
  /** The stop sequence is passed to the agent here */
  const agent = new LLMSingleActionAgent({
    llmChain,
    outputParser: new CustomOutputParser(),
    stop: ["\nObservation"],
  });
  const executor = new AgentExecutor({ agent, tools });
  console.log("Loaded agent.");
  const input = `Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?`;
  console.log(`Executing with input "${input}"...`);
  const { output } = await executor.invoke({ input });
  console.log(`Got output ${output}`);
};
run();
API Reference:
- AgentActionOutputParser from langchain/agents
- AgentExecutor from langchain/agents
- LLMSingleActionAgent from langchain/agents
- LLMChain from langchain/chains
- ChatOpenAI from @langchain/openai
- Calculator from @langchain/community/tools/calculator
- BaseChatPromptTemplate from @langchain/core/prompts
- SerializedBasePromptTemplate from @langchain/core/prompts
- renderTemplate from @langchain/core/prompts
- AgentAction from @langchain/core/agents
- AgentFinish from @langchain/core/agents
- AgentStep from @langchain/core/agents
- BaseMessage from @langchain/core/messages
- HumanMessage from @langchain/core/messages
- InputValues from @langchain/core/memory
- PartialValues from @langchain/core/utils/types
- Tool from @langchain/core/tools
- SerpAPI from @langchain/community/tools/serpapi