Utils

promptolution.utils

Module for utility functions and classes.

callbacks

Callback classes for logging, saving, and tracking optimization progress.

BaseCallback

Bases: ABC

Base class for optimization callbacks.

Callbacks can be used to monitor the optimization process, save checkpoints, log metrics, or implement early stopping criteria.

Source code in promptolution/utils/callbacks.py
class BaseCallback(ABC):
    """Base class for optimization callbacks.

    Callbacks can be used to monitor the optimization process, save checkpoints,
    log metrics, or implement early stopping criteria.

    """

    def __init__(self, **kwargs):
        """Initialize the callback.

        Args:
            **kwargs: Additional keyword arguments, e.g. a configuration for the callback.
        """
        pass

    def on_step_end(self, optimizer):
        """Called at the end of each optimization step.

        Args:
            optimizer: The optimizer object that called the callback.

        Returns:
            Bool: True if the optimization should continue, False if it should stop.
        """
        return True

    def on_epoch_end(self, optimizer):
        """Called at the end of each optimization epoch.

        Args:
            optimizer: The optimizer object that called the callback.

        Returns:
            Bool: True if the optimization should continue, False if it should stop.
        """
        return True

    def on_train_end(self, optimizer):
        """Called at the end of the entire optimization process.

        Args:
            optimizer: The optimizer object that called the callback.

        Returns:
            Bool: True if the optimization should continue, False if it should stop.
        """
        return True
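For illustration, a minimal sketch of a custom callback that stops the run after a fixed step budget. The class name and budget are hypothetical; the only contract used is that a hook returning False asks the optimizer to stop.

from promptolution.utils.callbacks import BaseCallback


class MaxStepsCallback(BaseCallback):
    """Hypothetical callback that stops optimization after a fixed number of steps."""

    def __init__(self, max_steps: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.max_steps = max_steps
        self.steps_seen = 0

    def on_step_end(self, optimizer):
        # Returning False signals the optimizer to stop early.
        self.steps_seen += 1
        return self.steps_seen < self.max_steps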
__init__(**kwargs)

Initialize the callback.

Parameters:

**kwargs: Additional keyword arguments, e.g. a configuration for the callback. Default: {}.
Source code in promptolution/utils/callbacks.py
def __init__(self, **kwargs):
    """Initialize the callback.

    Args:
        **kwargs: Additional keyword arguments, e.g. a configuration for the callback.
    """
    pass
on_epoch_end(optimizer)

Called at the end of each optimization epoch.

Parameters:

optimizer: The optimizer object that called the callback. Required.

Returns:

bool: True if the optimization should continue, False if it should stop.

Source code in promptolution/utils/callbacks.py
def on_epoch_end(self, optimizer):
    """Called at the end of each optimization epoch.

    Args:
        optimizer: The optimizer object that called the callback.

    Returns:
        Bool: True if the optimization should continue, False if it should stop.
    """
    return True
on_step_end(optimizer)

Called at the end of each optimization step.

Parameters:

optimizer: The optimizer object that called the callback. Required.

Returns:

bool: True if the optimization should continue, False if it should stop.

Source code in promptolution/utils/callbacks.py
def on_step_end(self, optimizer):
    """Called at the end of each optimization step.

    Args:
        optimizer: The optimizer object that called the callback.

    Returns:
        Bool: True if the optimization should continue, False if it should stop.
    """
    return True
on_train_end(optimizer)

Called at the end of the entire optimization process.

Parameters:

optimizer: The optimizer object that called the callback. Required.

Returns:

bool: True if the optimization should continue, False if it should stop.

Source code in promptolution/utils/callbacks.py
def on_train_end(self, optimizer):
    """Called at the end of the entire optimization process.

    Args:
        optimizer: The optimizer object that called the callback.

    Returns:
        Bool: True if the optimization should continue, False if it should stop.
    """
    return True

BestPromptCallback

Bases: BaseCallback

Callback for tracking the best prompt during optimization.

This callback keeps track of the prompt with the highest score.

Attributes:

best_prompt (str): The prompt with the highest score so far.
best_score (float): The highest score achieved so far.

Source code in promptolution/utils/callbacks.py
class BestPromptCallback(BaseCallback):
    """Callback for tracking the best prompt during optimization.

    This callback keeps track of the prompt with the highest score.

    Attributes:
        best_prompt (str): The prompt with the highest score so far.
        best_score (float): The highest score achieved so far.
    """

    def __init__(self):
        """Initialize the BestPromptCallback."""
        self.best_prompt = ""
        self.best_score = -99999

    def on_step_end(self, optimizer):
        """Update the best prompt and score if a new high score is achieved.

        Args:
        optimizer: The optimizer object that called the callback.
        """
        if optimizer.scores[0] > self.best_score:
            self.best_score = optimizer.scores[0]
            self.best_prompt = optimizer.prompts[0]

        return True

    def get_best_prompt(self):
        """Get the best prompt and score achieved during optimization.

        Returns:
        Tuple[str, float]: The best prompt and score.
        """
        return self.best_prompt, self.best_score
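A short usage sketch: the optimizer stand-in below is a plain namespace exposing only the scores and prompts attributes the callback reads, with index 0 assumed to hold the current best candidate. A real run passes the actual optimizer.

from types import SimpleNamespace

from promptolution.utils.callbacks import BestPromptCallback

callback = BestPromptCallback()
# Stand-in for an optimizer: the callback only reads `scores` and `prompts`.
fake_optimizer = SimpleNamespace(
    scores=[0.87, 0.61],
    prompts=["Classify the sentiment of the text.", "Label the text."],
)
callback.on_step_end(fake_optimizer)
print(callback.get_best_prompt())  # ("Classify the sentiment of the text.", 0.87)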
__init__()

Initialize the BestPromptCallback.

Source code in promptolution/utils/callbacks.py
def __init__(self):
    """Initialize the BestPromptCallback."""
    self.best_prompt = ""
    self.best_score = -99999
get_best_prompt()

Get the best prompt and score achieved during optimization.

Returns: Tuple[str, float]: The best prompt and score.

Source code in promptolution/utils/callbacks.py
def get_best_prompt(self):
    """Get the best prompt and score achieved during optimization.

    Returns:
    Tuple[str, float]: The best prompt and score.
    """
    return self.best_prompt, self.best_score
on_step_end(optimizer)

Update the best prompt and score if a new high score is achieved.

Args: optimizer: The optimizer object that called the callback.

Source code in promptolution/utils/callbacks.py
def on_step_end(self, optimizer):
    """Update the best prompt and score if a new high score is achieved.

    Args:
    optimizer: The optimizer object that called the callback.
    """
    if optimizer.scores[0] > self.best_score:
        self.best_score = optimizer.scores[0]
        self.best_prompt = optimizer.prompts[0]

    return True

FileOutputCallback

Bases: BaseCallback

Callback for saving optimization progress to a specified file type.

This callback saves information about each step to a file.

Attributes:

dir (str): Directory the file is saved to.
step (int): The current step number.
file_type (str): The type of file to save the output to.

Source code in promptolution/utils/callbacks.py
class FileOutputCallback(BaseCallback):
    """Callback for saving optimization progress to a specified file type.

    This callback saves information about each step to a file.

    Attributes:
        dir (str): Directory the file is saved to.
        step (int): The current step number.
        file_type (str): The type of file to save the output to.
    """

    def __init__(self, dir, file_type: Literal["parquet", "csv"] = "parquet"):
        """Initialize the FileOutputCallback.

        Args:
            dir (str): Directory the output file is saved to.
            file_type (str): The type of file to save the output to.
        """
        if not os.path.exists(dir):
            os.makedirs(dir)

        self.file_type = file_type

        if file_type == "parquet":
            self.path = dir + "/step_results.parquet"
        elif file_type == "csv":
            self.path = dir + "/step_results.csv"
        else:
            raise ValueError(f"File type {file_type} not supported.")

        self.step = 0

    def on_step_end(self, optimizer):
        """Save prompts and scores to csv.

        Args:
        optimizer: The optimizer object that called the callback
        """
        self.step += 1
        df = pd.DataFrame(
            {
                "step": [self.step] * len(optimizer.prompts),
                "input_tokens": [optimizer.meta_llm.input_token_count] * len(optimizer.prompts),
                "output_tokens": [optimizer.meta_llm.output_token_count] * len(optimizer.prompts),
                "time": [datetime.now().timestamp()] * len(optimizer.prompts),
                "score": optimizer.scores,
                "prompt": optimizer.prompts,
            }
        )

        if self.file_type == "parquet":
            if self.step == 1:
                df.to_parquet(self.path, index=False)
            else:
                df.to_parquet(self.path, mode="a", index=False)
        elif self.file_type == "csv":
            if self.step == 1:
                df.to_csv(self.path, index=False)
            else:
                df.to_csv(self.path, mode="a", header=False, index=False)

        return True
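A sketch of the file output hook with a stand-in optimizer; CSV is used here to avoid a parquet engine dependency, and the attribute names mirror exactly what on_step_end reads (prompts, scores, and a meta_llm with token counters). Directory and values are made up.

from types import SimpleNamespace

from promptolution.utils.callbacks import FileOutputCallback

callback = FileOutputCallback(dir="./results", file_type="csv")
fake_optimizer = SimpleNamespace(
    prompts=["Summarize the review.", "Rate the review."],
    scores=[0.72, 0.65],
    meta_llm=SimpleNamespace(input_token_count=1234, output_token_count=567),
)
# Appends one row per prompt to ./results/step_results.csv.
callback.on_step_end(fake_optimizer)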
__init__(dir, file_type='parquet')

Initialize the FileOutputCallback.

Args: dir (str): Directory the output file is saved to. file_type (str): The type of file to save the output to.

Source code in promptolution/utils/callbacks.py
def __init__(self, dir, file_type: Literal["parquet", "csv"] = "parquet"):
    """Initialize the FileOutputCallback.

    Args:
        dir (str): Directory the output file is saved to.
        file_type (str): The type of file to save the output to.
    """
    if not os.path.exists(dir):
        os.makedirs(dir)

    self.file_type = file_type

    if file_type == "parquet":
        self.path = dir + "/step_results.parquet"
    elif file_type == "csv":
        self.path = dir + "/step_results.csv"
    else:
        raise ValueError(f"File type {file_type} not supported.")

    self.step = 0
on_step_end(optimizer)

Save prompts and scores to the output file.

Args: optimizer: The optimizer object that called the callback.

Source code in promptolution/utils/callbacks.py
def on_step_end(self, optimizer):
    """Save prompts and scores to csv.

    Args:
    optimizer: The optimizer object that called the callback
    """
    self.step += 1
    df = pd.DataFrame(
        {
            "step": [self.step] * len(optimizer.prompts),
            "input_tokens": [optimizer.meta_llm.input_token_count] * len(optimizer.prompts),
            "output_tokens": [optimizer.meta_llm.output_token_count] * len(optimizer.prompts),
            "time": [datetime.now().timestamp()] * len(optimizer.prompts),
            "score": optimizer.scores,
            "prompt": optimizer.prompts,
        }
    )

    if self.file_type == "parquet":
        if self.step == 1:
            df.to_parquet(self.path, index=False)
        else:
            df.to_parquet(self.path, mode="a", index=False)
    elif self.file_type == "csv":
        if self.step == 1:
            df.to_csv(self.path, index=False)
        else:
            df.to_csv(self.path, mode="a", header=False, index=False)

    return True

LoggerCallback

Bases: BaseCallback

Callback for logging optimization progress.

This callback logs information about each step, epoch, and the end of training.

Attributes:

logger: The logger object to use for logging.
step (int): The current step number.

Source code in promptolution/utils/callbacks.py
class LoggerCallback(BaseCallback):
    """Callback for logging optimization progress.

    This callback logs information about each step, epoch, and the end of training.

    Attributes:
        logger: The logger object to use for logging.
        step (int): The current step number.
    """

    def __init__(self, logger):
        """Initialize the LoggerCallback."""
        self.logger = logger
        self.step = 0

    def on_step_end(self, optimizer):
        """Log information about the current step."""
        self.step += 1
        time = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")
        self.logger.critical(f"{time} - ✨ Step {self.step} ended ✨")
        for i, (prompt, score) in enumerate(zip(optimizer.prompts, optimizer.scores)):
            self.logger.critical(f"📝 Prompt {i}: Score: {score}")
            self.logger.critical(f"💬 {prompt}")

        return True

    def on_train_end(self, optimizer, logs=None):
        """Log information at the end of training.

        Args:
        optimizer: The optimizer object that called the callback.
        logs: Additional information to log.
        """
        time = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")
        if logs is None:
            self.logger.critical(f"{time} - 🏁 Training ended")
        else:
            self.logger.critical(f"{time} - 🏁 Training ended - {logs}")

        return True
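A minimal sketch wiring the callback to a logger obtained from promptolution.utils.logging; the stand-in optimizer provides only the prompts and scores the hook iterates over, and a real experiment would pass the actual optimizer instead.

import logging
from types import SimpleNamespace

from promptolution.utils.callbacks import LoggerCallback
from promptolution.utils.logging import get_logger

callback = LoggerCallback(get_logger(__name__, level=logging.INFO))
fake_optimizer = SimpleNamespace(prompts=["Summarize the text."], scores=[0.8])
callback.on_step_end(fake_optimizer)   # logs the step number, prompt, and score
callback.on_train_end(fake_optimizer)  # logs the end-of-training message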
__init__(logger)

Initialize the LoggerCallback.

Source code in promptolution/utils/callbacks.py
def __init__(self, logger):
    """Initialize the LoggerCallback."""
    self.logger = logger
    self.step = 0
on_step_end(optimizer)

Log information about the current step.

Source code in promptolution/utils/callbacks.py
def on_step_end(self, optimizer):
    """Log information about the current step."""
    self.step += 1
    time = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")
    self.logger.critical(f"{time} - ✨ Step {self.step} ended ✨")
    for i, (prompt, score) in enumerate(zip(optimizer.prompts, optimizer.scores)):
        self.logger.critical(f"📝 Prompt {i}: Score: {score}")
        self.logger.critical(f"💬 {prompt}")

    return True
on_train_end(optimizer, logs=None)

Log information at the end of training.

Args: optimizer: The optimizer object that called the callback. logs: Additional information to log.

Source code in promptolution/utils/callbacks.py
def on_train_end(self, optimizer, logs=None):
    """Log information at the end of training.

    Args:
    optimizer: The optimizer object that called the callback.
    logs: Additional information to log.
    """
    time = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")
    if logs is None:
        self.logger.critical(f"{time} - 🏁 Training ended")
    else:
        self.logger.critical(f"{time} - 🏁 Training ended - {logs}")

    return True

ProgressBarCallback

Bases: BaseCallback

Callback for displaying a progress bar during optimization.

This callback uses tqdm to display a progress bar that updates at each step.

Attributes:

pbar (tqdm): The tqdm progress bar object.

Source code in promptolution/utils/callbacks.py
class ProgressBarCallback(BaseCallback):
    """Callback for displaying a progress bar during optimization.

    This callback uses tqdm to display a progress bar that updates at each step.

    Attributes:
        pbar (tqdm): The tqdm progress bar object.
    """

    def __init__(self, total_steps):
        """Initialize the ProgressBarCallback.

        Args:
        total_steps (int): The total number of steps in the optimization process.
        """
        self.pbar = tqdm(total=total_steps)

    def on_step_end(self, optimizer):
        """Update the progress bar at the end of each step.

        Args:
        optimizer: The optimizer object that called the callback.
        """
        self.pbar.update(1)

        return True

    def on_train_end(self, optimizer):
        """Close the progress bar at the end of training.

        Args:
        optimizer: The optimizer object that called the callback.
        """
        self.pbar.close()

        return True
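A self-contained sketch; this callback ignores the optimizer argument, so None is passed for brevity.

from promptolution.utils.callbacks import ProgressBarCallback

callback = ProgressBarCallback(total_steps=20)
for _ in range(20):
    callback.on_step_end(None)  # advances the bar by one step
callback.on_train_end(None)     # closes the bar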
__init__(total_steps)

Initialize the ProgressBarCallback.

Args: total_steps (int): The total number of steps in the optimization process.

Source code in promptolution/utils/callbacks.py
def __init__(self, total_steps):
    """Initialize the ProgressBarCallback.

    Args:
    total_steps (int): The total number of steps in the optimization process.
    """
    self.pbar = tqdm(total=total_steps)
on_step_end(optimizer)

Update the progress bar at the end of each step.

Args: optimizer: The optimizer object that called the callback.

Source code in promptolution/utils/callbacks.py
def on_step_end(self, optimizer):
    """Update the progress bar at the end of each step.

    Args:
    optimizer: The optimizer object that called the callback.
    """
    self.pbar.update(1)

    return True
on_train_end(optimizer)

Close the progress bar at the end of training.

Args: optimizer: The optimizer object that called the callback.

Source code in promptolution/utils/callbacks.py
def on_train_end(self, optimizer):
    """Close the progress bar at the end of training.

    Args:
    optimizer: The optimizer object that called the callback.
    """
    self.pbar.close()

    return True

TokenCountCallback

Bases: BaseCallback

Callback for stopping optimization based on the total token count.

Source code in promptolution/utils/callbacks.py
class TokenCountCallback(BaseCallback):
    """Callback for stopping optimization based on the total token count."""

    def __init__(
        self,
        max_tokens_for_termination: int,
        token_type_for_termination: Literal["input_tokens", "output_tokens", "total_tokens"],
    ):
        """Initialize the TokenCountCallback.

        Args:
            max_tokens_for_termination (int): Maximum number of tokens allowed before the algorithm is stopped.
            token_type_for_termination (str): One of "input_tokens", "output_tokens", or "total_tokens".
        """
        self.max_tokens_for_termination = max_tokens_for_termination
        self.token_type_for_termination = token_type_for_termination

    def on_step_end(self, optimizer):
        """Check if the total token count exceeds the maximum allowed. If so, stop the optimization."""
        token_counts = optimizer.predictor.llm.get_token_count()

        if token_counts[self.token_type_for_termination] > self.max_tokens_for_termination:
            return False

        return True
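A sketch of the termination check with a stand-in for the predictor/LLM chain the callback reads; the token numbers and budget are made up for the example.

from types import SimpleNamespace

from promptolution.utils.callbacks import TokenCountCallback

callback = TokenCountCallback(max_tokens_for_termination=50_000, token_type_for_termination="total_tokens")
# Mirrors the attribute chain optimizer.predictor.llm.get_token_count() -> dict of counts.
llm = SimpleNamespace(get_token_count=lambda: {"input_tokens": 30_000, "output_tokens": 25_000, "total_tokens": 55_000})
fake_optimizer = SimpleNamespace(predictor=SimpleNamespace(llm=llm))
print(callback.on_step_end(fake_optimizer))  # False: the token budget is exhausted, so optimization stops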
__init__(max_tokens_for_termination, token_type_for_termination)

Initialize the TokenCountCallback.

Args: max_tokens_for_termination (int): Maximum number of tokens allowed before the algorithm is stopped. token_type_for_termination (str): One of "input_tokens", "output_tokens", or "total_tokens".

Source code in promptolution/utils/callbacks.py
def __init__(
    self,
    max_tokens_for_termination: int,
    token_type_for_termination: Literal["input_tokens", "output_tokens", "total_tokens"],
):
    """Initialize the TokenCountCallback.

    Args:
        max_tokens_for_termination (int): Maximum number of tokens allowed before the algorithm is stopped.
        token_type_for_termination (str): One of "input_tokens", "output_tokens", or "total_tokens".
    """
    self.max_tokens_for_termination = max_tokens_for_termination
    self.token_type_for_termination = token_type_for_termination
on_step_end(optimizer)

Check if the total token count exceeds the maximum allowed. If so, stop the optimization.

Source code in promptolution/utils/callbacks.py
def on_step_end(self, optimizer):
    """Check if the total token count exceeds the maximum allowed. If so, stop the optimization."""
    token_counts = optimizer.predictor.llm.get_token_count()

    if token_counts[self.token_type_for_termination] > self.max_tokens_for_termination:
        return False

    return True

config

Configuration class for the promptolution library.

ExperimentConfig

Configuration class for the promptolution library.

This is a unified configuration class that handles all experiment settings. It provides validation and tracking of used fields.

Source code in promptolution/utils/config.py
class ExperimentConfig:
    """Configuration class for the promptolution library.

    This is a unified configuration class that handles all experiment settings.
    It provides validation and tracking of used fields.
    """

    def __init__(self, **kwargs):
        """Initialize the configuration with the provided keyword arguments."""
        self._used_attributes: Set[str] = set()
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setattr__(self, name, value):
        """Override attribute setting to track used attributes."""
        # Set the attribute using the standard mechanism
        object.__setattr__(self, name, value)
        if not name.startswith("_") and not callable(value):
            self._used_attributes.add(name)

    def __getattribute__(self, name):
        """Override attribute access to track used attributes."""
        # Get the attribute using the standard mechanism
        try:
            value = object.__getattribute__(self, name)
        except AttributeError:
            return None
        if not name.startswith("_") and not callable(value):
            self._used_attributes.add(name)

        return value

    def apply_to(self, obj):
        """Apply matching attributes from this config to an existing object.

        Examines each attribute of the target object and updates it if a matching
        attribute exists in the config.

        Args:
            obj: The object to update with config values

        Returns:
            The updated object
        """
        for attr_name in dir(obj):
            if attr_name.startswith("_") or not isinstance(
                getattr(obj, attr_name), (str, int, float, list, type(None))
            ):
                continue

            if hasattr(self, attr_name) and getattr(self, attr_name) is not None:
                setattr(obj, attr_name, getattr(self, attr_name))

        return obj

    def validate(self):
        """Check if any attributes were not used and run validation.

        Does not raise an error, but logs a warning if any attributes are unused or validation fails.
        """
        all_attributes = {k for k in self.__dict__ if not k.startswith("_")}
        unused_attributes = all_attributes - self._used_attributes
        if unused_attributes:
            logger.warning(f"⚠️ Unused configuration attributes: {unused_attributes}")
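A usage sketch; the field names and the Settings class below are purely illustrative stand-ins for whatever objects an experiment actually configures.

from promptolution.utils.config import ExperimentConfig

# Any keyword arguments become config attributes.
config = ExperimentConfig(n_steps=10, model_id="my-model")


class Settings:
    """Toy object standing in for an optimizer or task whose fields the config should override."""

    def __init__(self):
        self.n_steps = 5
        self.model_id = None


settings = config.apply_to(Settings())
print(settings.n_steps, settings.model_id)  # 10 my-model
config.validate()  # logs a warning if any configuration attribute was never used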
__getattribute__(name)

Override attribute access to track used attributes.

Source code in promptolution/utils/config.py
def __getattribute__(self, name):
    """Override attribute access to track used attributes."""
    # Get the attribute using the standard mechanism
    try:
        value = object.__getattribute__(self, name)
    except AttributeError:
        return None
    if not name.startswith("_") and not callable(value):
        self._used_attributes.add(name)

    return value
__init__(**kwargs)

Initialize the configuration with the provided keyword arguments.

Source code in promptolution/utils/config.py
def __init__(self, **kwargs):
    """Initialize the configuration with the provided keyword arguments."""
    self._used_attributes: Set[str] = set()
    for key, value in kwargs.items():
        setattr(self, key, value)
__setattr__(name, value)

Override attribute setting to track used attributes.

Source code in promptolution/utils/config.py
def __setattr__(self, name, value):
    """Override attribute setting to track used attributes."""
    # Set the attribute using the standard mechanism
    object.__setattr__(self, name, value)
    if not name.startswith("_") and not callable(value):
        self._used_attributes.add(name)
apply_to(obj)

Apply matching attributes from this config to an existing object.

Examines each attribute of the target object and updates it if a matching attribute exists in the config.

Parameters:

obj: The object to update with config values. Required.

Returns:

The updated object.

Source code in promptolution/utils/config.py
def apply_to(self, obj):
    """Apply matching attributes from this config to an existing object.

    Examines each attribute of the target object and updates it if a matching
    attribute exists in the config.

    Args:
        obj: The object to update with config values

    Returns:
        The updated object
    """
    for attr_name in dir(obj):
        if attr_name.startswith("_") or not isinstance(
            getattr(obj, attr_name), (str, int, float, list, type(None))
        ):
            continue

        if hasattr(self, attr_name) and getattr(self, attr_name) is not None:
            setattr(obj, attr_name, getattr(self, attr_name))

    return obj
validate()

Check if any attributes were not used and run validation.

Does not raise an error, but logs a warning if any attributes are unused or validation fails.

Source code in promptolution/utils/config.py
def validate(self):
    """Check if any attributes were not used and run validation.

    Does not raise an error, but logs a warning if any attributes are unused or validation fails.
    """
    all_attributes = {k for k in self.__dict__ if not k.startswith("_")}
    unused_attributes = all_attributes - self._used_attributes
    if unused_attributes:
        logger.warning(f"⚠️ Unused configuration attributes: {unused_attributes}")

logging

Logging configuration for the promptolution library.

get_logger(name, level=None)

Get a logger with the specified name and level.

This function provides a standardized way to get loggers throughout the library, ensuring consistent formatting and behavior.

Parameters:

name (str): Name of the logger, typically the __name__ of the module. Required.
level (int, optional): Logging level. Defaults to None, which uses the root logger's level.

Returns:

logging.Logger: Configured logger instance.

Source code in promptolution/utils/logging.py
def get_logger(name: str, level: Optional[int] = None) -> logging.Logger:
    """Get a logger with the specified name and level.

    This function provides a standardized way to get loggers throughout the library,
    ensuring consistent formatting and behavior.

    Args:
        name (str): Name of the logger, typically __name__ of the module.
        level (int, optional): Logging level. Defaults to None, which uses the root logger's level.

    Returns:
        logging.Logger: Configured logger instance.
    """
    logger = logging.getLogger(name)
    if level is not None:
        logger.setLevel(level)
    return logger

setup_logging(level=logging.INFO)

Set up logging for the promptolution library.

This function configures the root logger for the library with appropriate formatting and level.

Parameters:

level (int, optional): Logging level. Defaults to logging.INFO.
Source code in promptolution/utils/logging.py
def setup_logging(level: int = logging.INFO) -> None:
    """Set up logging for the promptolution library.

    This function configures the root logger for the library with appropriate
    formatting and level.

    Args:
        level (int, optional): Logging level. Defaults to logging.INFO.
    """
    # Configure the root logger
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
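A typical setup, assuming the import path mirrors the file layout shown above: configure the root logger once at program start, then request module-level loggers as needed.

import logging

from promptolution.utils.logging import get_logger, setup_logging

setup_logging(level=logging.DEBUG)  # configure root formatting and level once
logger = get_logger(__name__)       # module-level logger inherits the root configuration
logger.info("Starting prompt optimization")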

prompt_creation

Utility functions for prompt creation.

create_prompt_variation(prompt, llm, meta_prompt=None)

Generate a variation of the given prompt(s) while keeping the semantic meaning.

Idea taken from the paper Zhou et al. (2022): https://arxiv.org/pdf/2211.01910

Parameters:

prompt (Union[List[str], str]): The prompt(s) to generate variations of. Required.
llm (BaseLLM): The language model to use for generating the variations. Required.
meta_prompt (str, optional): The meta prompt to use for generating the variations. If None, a default meta prompt is used; it should contain a <prev_prompt> tag. Defaults to None.

Returns:

List[str]: A list of generated variations of the input prompt(s).

Source code in promptolution/utils/prompt_creation.py
def create_prompt_variation(prompt: Union[List[str], str], llm: "BaseLLM", meta_prompt: str = None) -> List[str]:
    """Generate a variation of the given prompt(s) while keeping the semantic meaning.

    Idea taken from the paper Zhou et al. (2022): https://arxiv.org/pdf/2211.01910

    Args:
        prompt (Union[List[str], str]): The prompt(s) to generate variations of.
        llm (BaseLLM): The language model to use for generating the variations.
        meta_prompt (str): The meta prompt to use for generating the variations.
        If None, a default meta prompt is used. Should contain <prev_prompt> tag.

    Returns:
        List[str]: A list of generated variations of the input prompt(s).
    """
    meta_prompt = PROMPT_VARIATION_TEMPLATE if meta_prompt is None else meta_prompt

    if isinstance(prompt, str):
        prompt = [prompt]
    varied_prompts = llm.get_response([meta_prompt.replace("<prev_prompt>", p) for p in prompt])

    varied_prompts = [p.split("</prompt>")[0].split("<prompt>")[-1] for p in varied_prompts]

    return varied_prompts
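A sketch with a stub LLM that only mimics the get_response interface used above; a real call would pass a configured BaseLLM from the library, and the stub's canned answer is obviously hypothetical.

from promptolution.utils.prompt_creation import create_prompt_variation


class StubLLM:
    """Stand-in for a BaseLLM; a real LLM would return rewrites wrapped in <prompt>...</prompt> tags."""

    def get_response(self, prompts):
        return ["<prompt>Determine the sentiment expressed in the text.</prompt>" for _ in prompts]


variations = create_prompt_variation("Classify the sentiment of the text.", llm=StubLLM())
print(variations)  # ['Determine the sentiment expressed in the text.']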

create_prompts_from_samples(task, llm, meta_prompt=None, n_samples=3, task_description=None, n_prompts=1, get_uniform_labels=False)

Generate a set of prompts from dataset examples sampled from a given task.

Idea taken from the paper Zhou et al. (2022), https://arxiv.org/pdf/2211.01910. Samples are selected such that (1) all possible classes are represented and (2) the samples are as representative as possible.

Parameters:

task (BaseTask): The task to generate prompts for; its xs and ys are used to build the examples. Required.
llm (BaseLLM): The language model to use for generating the prompts. Required.
meta_prompt (str, optional): The meta prompt to use for generating the prompts. If None, a default meta prompt is used. Defaults to None.
n_samples (int): The number of samples to use for generating prompts. Defaults to 3.
task_description (str, optional): The description of the task to include in the prompt. Defaults to None.
n_prompts (int): The number of prompts to generate. Defaults to 1.
get_uniform_labels (bool): If True, samples are selected such that all classes are represented. Defaults to False.

Returns:

List[str]: A list of generated prompts.

Source code in promptolution/utils/prompt_creation.py
def create_prompts_from_samples(
    task: "BaseTask",
    llm: "BaseLLM",
    meta_prompt: str = None,
    n_samples: int = 3,
    task_description: str = None,
    n_prompts: int = 1,
    get_uniform_labels: bool = False,
) -> List[str]:
    """Generate a set of prompts from dataset examples sampled from a given task.

    Idea taken from the paper Zhou et al. (2022): https://arxiv.org/pdf/2211.01910
    Samples are selected, such that
    (1) all possible classes are represented
    (2) the samples are as representative as possible

    Args:
        task (BaseTask): The task to generate prompts for.
        Xs and Ys from this object are used to generate the prompts.
        llm (BaseLLM): The language model to use for generating the prompts.
        meta_prompt (str): The meta prompt to use for generating the prompts.
        If None, a default meta prompt is used.
        n_samples (int): The number of samples to use for generating prompts.
        task_description (str): The description of the task to include in the prompt.
        n_prompts (int): The number of prompts to generate.
        get_uniform_labels (bool): If True, samples are selected such that all classes are represented.

    Returns:
        List[str]: A list of generated prompts.
    """
    if meta_prompt is None and task_description is None:
        meta_prompt_template = PROMPT_CREATION_TEMPLATE
    elif meta_prompt is None and task_description is not None:
        meta_prompt_template = PROMPT_CREATION_TEMPLATE_TD.replace("<task_desc>", task_description)
    elif meta_prompt is not None and task_description is None:
        meta_prompt_template = meta_prompt
    elif meta_prompt is not None and task_description is not None:
        meta_prompt_template = meta_prompt.replace("<task_desc>", task_description)

    meta_prompts = []
    for _ in range(n_prompts):
        if isinstance(task, ClassificationTask) and get_uniform_labels:
            # if classification task sample such that all classes are represented
            unique_labels, counts = np.unique(task.ys, return_counts=True)
            proportions = counts / len(task.ys)
            samples_per_class = np.round(proportions * n_samples).astype(int)
            samples_per_class = np.maximum(samples_per_class, 1)

            # sample
            xs = []
            ys = []
            for label, n_samples in zip(unique_labels, samples_per_class):
                indices = np.where(task.ys == label)[0]
                indices = np.random.choice(indices, n_samples, replace=False)
                xs.extend(task.xs[indices])
                ys.extend(task.ys[indices])

        else:
            # if not classification task, sample randomly
            indices = np.random.choice(len(task.xs), n_samples, replace=False)
            xs = task.xs[indices].tolist()
            ys = task.ys[indices].tolist()

        examples = "\n\n".join([f"Input: {x}\nOutput: {y}" for x, y in zip(xs, ys)])
        meta_prompt = meta_prompt_template.replace("<input_output_pairs>", examples)
        meta_prompts.append(meta_prompt)

    prompts = llm.get_response(meta_prompts)
    prompts = [prompt.split("</prompt>")[0].split("<prompt>")[-1].strip() for prompt in prompts]

    return prompts
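A sketch using a stub LLM and a plain namespace in place of a BaseTask; with get_uniform_labels left at False only xs and ys are read, so the ClassificationTask branch is not exercised. All names and data are illustrative.

import numpy as np
from types import SimpleNamespace

from promptolution.utils.prompt_creation import create_prompts_from_samples


class StubLLM:
    def get_response(self, meta_prompts):
        return ["<prompt>Decide whether the review is positive or negative.</prompt>" for _ in meta_prompts]


# Minimal stand-in for a task: only xs and ys are accessed on this code path.
task = SimpleNamespace(
    xs=np.array(["great movie", "terrible plot", "loved it", "boring"]),
    ys=np.array(["positive", "negative", "positive", "negative"]),
)
prompts = create_prompts_from_samples(task, StubLLM(), n_samples=2, n_prompts=1)
print(prompts)  # ['Decide whether the review is positive or negative.']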

test_statistics

Implementation of statistical significance tests used in the racing algorithm. Contains paired t-test functionality to compare prompt performance and determine statistical significance between candidates.

get_test_statistic_func(name)

Get the test statistic function based on the name provided.

Parameters:

name (str): Name of the test statistic function. Required.

Returns:

callable: The corresponding test statistic function.

Source code in promptolution/utils/test_statistics.py
def get_test_statistic_func(name: TestStatistics) -> callable:
    """
    Get the test statistic function based on the name provided.

    Args:
        name (str): Name of the test statistic function.

    Returns:
        callable: The corresponding test statistic function.
    """
    if name == "paired_t_test":
        return paired_t_test
    else:
        raise ValueError(f"Unknown test statistic function: {name}. Should be one of {TestStatistics.__args__}.")

paired_t_test(scores_a, scores_b, alpha=0.05)

Uses a paired t-test to test whether candidate A's accuracy is significantly higher than candidate B's accuracy at a confidence level of 1 - alpha. Assumptions: the samples are paired, and the differences between the pairs are approximately normally distributed (reasonable for n > 30).

Parameters:

scores_a (np.ndarray): Array of accuracy scores for candidate A. Required.
scores_b (np.ndarray): Array of accuracy scores for candidate B. Required.
alpha (float): Significance level. Defaults to 0.05 (95% confidence).

Returns:

bool: True if candidate A is significantly better than candidate B, False otherwise.

Source code in promptolution/utils/test_statistics.py
def paired_t_test(
    scores_a: np.ndarray,
    scores_b: np.ndarray,
    alpha: float = 0.05,
) -> bool:
    """
    Uses a paired t-test to test if candidate A's accuracy is significantly
    higher than candidate B's accuracy at a confidence level of 1 - alpha.
    Assumptions:
    - The samples are paired.
    - The differences between the pairs are normally distributed (-> n > 30).

    Parameters:
        scores_a (np.ndarray): Array of accuracy scores for candidate A.
        scores_b (np.ndarray): Array of accuracy scores for candidate B.
        alpha (float): Significance level (default 0.05 for 95% confidence).

    Returns:
        bool: True if candidate A is significantly better than candidate B, False otherwise.
    """

    _, p_value = ttest_rel(scores_a, scores_b, alternative="greater")

    result = p_value < alpha

    return result
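A self-contained example with synthetic per-sample accuracy scores; the values are made up and only illustrate how the two functions fit together.

import numpy as np

from promptolution.utils.test_statistics import get_test_statistic_func

rng = np.random.default_rng(0)
scores_a = rng.normal(loc=0.80, scale=0.05, size=50)  # candidate A: per-sample scores
scores_b = rng.normal(loc=0.72, scale=0.05, size=50)  # candidate B, evaluated on the same samples

test = get_test_statistic_func("paired_t_test")
print(test(scores_a, scores_b, alpha=0.05))  # True if A is significantly better than B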

token_counter

Token counter for LLMs.

This module provides a function to count the number of tokens in a given text.

get_token_counter(llm)

Get a token counter function for the given LLM.

This function returns a callable that counts tokens based on the LLM's tokenizer or a simple split method if no tokenizer is available.

Parameters:

llm: The language model object that may have a tokenizer. Required.

Returns:

A callable that takes a text input and returns the token count.

Source code in promptolution/utils/token_counter.py
def get_token_counter(llm):
    """Get a token counter function for the given LLM.

    This function returns a callable that counts tokens based on the LLM's tokenizer
    or a simple split method if no tokenizer is available.

    Args:
        llm: The language model object that may have a tokenizer.

    Returns:
        A callable that takes a text input and returns the token count.

    """
    if hasattr(llm, "tokenizer"):
        token_counter = lambda x: len(llm.tokenizer(x)["input_ids"])
    else:
        logger.warning("⚠️ The LLM does not have a tokenizer. Using simple token count.")
        token_counter = lambda x: len(x.split())

    return token_counter
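A sketch using a stand-in LLM without a tokenizer attribute, which triggers the whitespace fallback (and the corresponding warning).

from promptolution.utils.token_counter import get_token_counter


class PlainLLM:
    """Stand-in LLM with no tokenizer, so tokens are counted by splitting on whitespace."""
    pass


count_tokens = get_token_counter(PlainLLM())
print(count_tokens("Classify the sentiment of this review."))  # 6 whitespace-separated tokens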