
Hugging Face

Installation#

Install AG2 with the necessary dependencies for custom model integration:

pip install ag2[openai] torch transformers sentencepiece

Tip

If you have been using autogen or ag2, all you need to do is upgrade it using:

pip install -U autogen[openai] torch transformers sentencepiece
or
pip install -U ag2[openai] torch transformers sentencepiece
as autogen and ag2 are aliases for the same PyPI package.

Dependencies#

We will explore how to use custom models (specifically Hugging Face models) in AG2. Install the necessary dependencies with the following command:

pip install ag2[openai] torch transformers sentencepiece

For more information, please refer to the installation guide.

Features#

AG2 supports custom model integration through the ModelClient protocol, allowing you to:

  • Use any Hugging Face model or other local models with AG2 agents
  • Define custom model loading and inference logic
  • Control model initialization and configuration
  • Support both model-loaded-in-client and pre-loaded model approaches
  • Integrate seamlessly with AG2's agent framework

Note

Depending on which model you use, you may need to adjust the agent's default prompts (for example, the AssistantAgent's system message).
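
For example, you can pass a custom system_message when constructing the agent. The prompt text below is illustrative only, and llm_config refers to the configuration built in Step 1 further down:

# Illustrative only: override the default system message for a small local model
assistant = AssistantAgent(
    "assistant",
    system_message="You are a concise coding assistant. Reply with Python code only.",
    llm_config=llm_config,
)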

Main Distinctions#

  • Protocol-based: Custom models must adhere to the ModelClient protocol defined in AG2
  • Flexible initialization: Models can be loaded within the client class or pre-loaded and passed as arguments
  • Local inference: Custom models run locally, so there are no API costs (cost returns 0)
  • No streaming support: Local models do not support streaming (raises NotImplementedError)
  • Response structure: Must return responses that follow the ModelClientResponseProtocol

ModelClient Protocol#

A custom model class must implement the ModelClient protocol. The protocol requires the following methods:

  • create(): Must return a response object that implements the ModelClientResponseProtocol
  • message_retrieval(): Must return a list of strings or a list of Choice.Message objects
  • cost(): Must return the cost of the response (typically 0 for local models)
  • get_usage(): Must return a dictionary with keys: prompt_tokens, completion_tokens, total_tokens, cost, model

The response protocol structure:

from typing import Dict, List, Optional, Protocol, Union

class ModelClient(Protocol):
    RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]

    class ModelClientResponseProtocol(Protocol):
        class Choice(Protocol):
            class Message(Protocol):
                content: Optional[str]

            message: Message

        choices: List[Choice]
        model: str

    def create(self, params) -> ModelClientResponseProtocol:
        ...

    def message_retrieval(
        self, response: ModelClientResponseProtocol
    ) -> Union[List[str], List[ModelClientResponseProtocol.Choice.Message]]:
        ...

    def cost(self, response: ModelClientResponseProtocol) -> float:
        ...

    @staticmethod
    def get_usage(response: ModelClientResponseProtocol) -> Dict:
        ...
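
To see how these methods fit together, here is a rough sketch of the call sequence AG2 performs internally once a client is registered. It uses the CustomModelClient class and config values from the example later in this guide, and is not something you normally call yourself:

# Manual walk-through of the protocol methods (sketch only; AG2's wrapper calls these for you)
client = CustomModelClient(config={"model": "Open-Orca/Mistral-7B-OpenOrca", "device": "cpu"})
response = client.create({"messages": [{"role": "user", "content": "Say hi."}], "n": 1})

texts = client.message_retrieval(response)     # list of generated strings
price = client.cost(response)                  # 0 for local models
usage = CustomModelClient.get_usage(response)  # usage accounting (empty dict in the example below)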

Configuration#

To use a custom model, you need to add the model_client_cls field to your configuration. Here's an example configuration for using the Open-Orca/Mistral-7B-OpenOrca model:

[
    {
        "model": "Open-Orca/Mistral-7B-OpenOrca",
        "model_client_cls": "CustomModelClient",
        "device": "cuda",
        "n": 1,
        "params": {
            "max_length": 1000
        }
    }
]

The configuration supports the following fields:

  • model: The Hugging Face model identifier or path
  • model_client_cls: The name of your custom client class (must match the registered class name)
  • device: The device to run the model on ("cpu" or "cuda")
  • n: The number of responses to generate per request
  • params: Custom parameters for your model (e.g., max_length for generation)
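
In the steps below, this configuration is read from a file named OAI_CONFIG_LIST. As a sketch, one way to create that file programmatically (you can also write it by hand) is:

import json

# Write the custom-model entry to the OAI_CONFIG_LIST file that LLMConfig.from_json reads below
config_list = [
    {
        "model": "Open-Orca/Mistral-7B-OpenOrca",
        "model_client_cls": "CustomModelClient",
        "device": "cuda",
        "n": 1,
        "params": {"max_length": 1000},
    }
]
with open("OAI_CONFIG_LIST", "w") as f:
    json.dump(config_list, f, indent=4)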

Custom Model Client Example#

Here's a complete example of a custom model client that loads a Hugging Face model:

from types import SimpleNamespace
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from autogen import AssistantAgent, LLMConfig, UserProxyAgent

class CustomModelClient:
    def __init__(self, config, **kwargs):
        print(f"CustomModelClient config: {config}")
        self.device = config.get("device", "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(config["model"]).to(self.device)
        self.model_name = config["model"]
        self.tokenizer = AutoTokenizer.from_pretrained(config["model"], use_fast=False)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        gen_config_params = config.get("params", {})
        self.max_length = gen_config_params.get("max_length", 256)

        print(f"Loaded model {config['model']} to {self.device}")

    def create(self, params):
        if params.get("stream", False) and "messages" in params:
            raise NotImplementedError("Local models do not support streaming.")

        num_of_responses = params.get("n", 1)
        response = SimpleNamespace()

        inputs = self.tokenizer.apply_chat_template(
            params["messages"], return_tensors="pt", add_generation_prompt=True
        ).to(self.device)
        inputs_length = inputs.shape[-1]

        max_length = self.max_length + inputs_length
        generation_config = GenerationConfig(
            max_length=max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        response.choices = []
        response.model = self.model_name

        for _ in range(num_of_responses):
            outputs = self.model.generate(inputs, generation_config=generation_config)
            text = self.tokenizer.decode(outputs[0, inputs_length:])
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = text
            choice.message.function_call = None
            response.choices.append(choice)

        return response

    def message_retrieval(self, response):
        """Retrieve the messages from the response."""
        choices = response.choices
        return [choice.message.content for choice in choices]

    def cost(self, response) -> float:
        """Calculate the cost of the response."""
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        """Return usage statistics (prompt_tokens, completion_tokens, total_tokens, cost, model).

        Local inference is free, so no usage is tracked and an empty dict is returned.
        """
        return {}

Using the Custom Model#

Step 1: Configure the LLM#

filter_dict = {"model_client_cls": "CustomModelClient"}

llm_config = LLMConfig.from_json(path="OAI_CONFIG_LIST").where(**filter_dict)

Step 2: Create Agents#

assistant = AssistantAgent("assistant", llm_config=llm_config)

user_proxy = UserProxyAgent(
    "user_proxy",
    code_execution_config={
        "work_dir": "coding",
        "use_docker": False,  # Set use_docker=True if docker is available
    },
)

Step 3: Register the Custom Client#

assistant.register_model_client(model_client_cls=CustomModelClient)

Step 4: Initiate a Conversation#

user_proxy.initiate_chat(assistant, message="Write python code to print Hello World!")

Pre-loaded Model Approach#

If you want more control over when the model gets loaded, you can pre-load the model and pass it to the client:

class CustomModelClientWithArguments(CustomModelClient):
    def __init__(self, config, loaded_model, tokenizer, **kwargs):
        print(f"CustomModelClientWithArguments config: {config}")
        self.model_name = config["model"]
        self.model = loaded_model
        self.tokenizer = tokenizer
        self.device = config.get("device", "cpu")

        gen_config_params = config.get("params", {})
        self.max_length = gen_config_params.get("max_length", 256)
        print(f"Loaded model {config['model']} to {self.device}")

# Filter the config list for the pre-loaded client class
filter_dict = {"model_client_cls": ["CustomModelClientWithArguments"]}
llm_config = LLMConfig.from_json(path="OAI_CONFIG_LIST").where(**filter_dict)

# Pre-load the model and tokenizer from that config entry
config = llm_config.config_list[0]
device = config.get("device", "cpu")
loaded_model = AutoModelForCausalLM.from_pretrained(config["model"]).to(device)
tokenizer = AutoTokenizer.from_pretrained(config["model"], use_fast=False)
tokenizer.pad_token_id = tokenizer.eos_token_id

assistant = AssistantAgent("assistant", llm_config=llm_config)

assistant.register_model_client(
    model_client_cls=CustomModelClientWithArguments,
    loaded_model=loaded_model,
    tokenizer=tokenizer,
)
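
After registering the pre-loaded client, the agents are used exactly as in Step 4 above (assuming the user_proxy created in Step 2 is still in scope):

user_proxy.initiate_chat(assistant, message="Write python code to print Hello World!")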

Additional Resources#