from __future__ import annotations
import re
from typing import Optional, Union
from mlflow.entities.model_registry.model_version import ModelVersion
from mlflow.entities.model_registry.model_version_tag import ModelVersionTag
from mlflow.exceptions import MlflowException
from mlflow.prompt.constants import (
    IS_PROMPT_TAG_KEY,
    PROMPT_ASSOCIATED_RUN_IDS_TAG_KEY,
    PROMPT_TEMPLATE_VARIABLE_PATTERN,
    PROMPT_TEXT_DISPLAY_LIMIT,
    PROMPT_TEXT_TAG_KEY,
)

# Alias type
PromptVersionTag = ModelVersionTag


def _is_reserved_tag(key: str) -> bool:
    return key in {IS_PROMPT_TAG_KEY, PROMPT_TEXT_TAG_KEY, PROMPT_ASSOCIATED_RUN_IDS_TAG_KEY}


# Prompt is implemented as a special type of ModelVersion. MLflow stores both prompts
# and model versions in the model registry as ModelVersion DB records, but distinguishes
# them using the special tag "mlflow.prompt.is_prompt".
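#
# Illustration (informal sketch, not normative): the tag payload of a prompt's
# ModelVersion record looks roughly like
#
#     {
#         IS_PROMPT_TAG_KEY: "true",          # marks the record as a prompt
#         PROMPT_TEXT_TAG_KEY: "<template>",  # the raw template text
#     }
#
# plus any user-supplied version metadata (see __init__ below). The literal
# key strings live in mlflow.prompt.constants.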
class Prompt(ModelVersion):
    """
    An entity representing a prompt (template) for GenAI applications.

    Args:
        name: The name of the prompt.
        version: The version number of the prompt.
        template: The template text of the prompt. It can contain variables enclosed in
            double curly braces, e.g. {{variable}}, which will be replaced with actual
            values by the `format` method. MLflow uses the same variable naming rules
            as Jinja2: https://jinja.palletsprojects.com/en/stable/api/#notes-on-identifiers
        commit_message: The commit message for the prompt version. Optional.
        creation_timestamp: Timestamp of the prompt creation. Optional.
        version_metadata: A dictionary of metadata associated with the **prompt version**.
            This is useful for storing version-specific information, such as the author of
            the changes. Optional.
        prompt_tags: A dictionary of tags associated with the entire prompt. This is
            different from the `version_metadata` as it is not tied to a specific version
            of the prompt. Optional.
    """

    def __init__(
        self,
        name: str,
        version: int,
        template: str,
        commit_message: Optional[str] = None,
        creation_timestamp: Optional[int] = None,
        version_metadata: Optional[dict[str, str]] = None,
        prompt_tags: Optional[dict[str, str]] = None,
        aliases: Optional[list[str]] = None,
    ):
        # Store the template text and the is-prompt marker as tags
        version_metadata = version_metadata or {}
        version_metadata[PROMPT_TEXT_TAG_KEY] = template
        version_metadata[IS_PROMPT_TAG_KEY] = "true"

        super().__init__(
            name=name,
            version=version,
            creation_timestamp=creation_timestamp,
            description=commit_message,
            # "version_metadata" is represented as ModelVersion tags.
            tags=[
                ModelVersionTag(key=key, value=value) for key, value in version_metadata.items()
            ],
            aliases=aliases,
        )

        self._variables = set(PROMPT_TEMPLATE_VARIABLE_PATTERN.findall(self.template))
        # Store the prompt-level tags (from RegisteredModel).
        self._prompt_tags = prompt_tags or {}

    def __repr__(self) -> str:
        text = (
            self.template[:PROMPT_TEXT_DISPLAY_LIMIT] + "..."
            if len(self.template) > PROMPT_TEXT_DISPLAY_LIMIT
            else self.template
        )
        return f"Prompt(name={self.name}, version={self.version}, template={text})"

    @property
    def template(self) -> str:
        """
        Return the template text of the prompt.
        """
        return self._tags[PROMPT_TEXT_TAG_KEY]

    @property
    def variables(self) -> set[str]:
        """
        Return the set of variable names found in the template text.

        Variables must be enclosed in double curly braces, e.g. {{variable}}.
        """
        return self._variables
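
    # Example (illustrative): for the template
    # "Hi {{first_name}} {{last_name}}, welcome back!", this property would
    # return {"first_name", "last_name"}, assuming PROMPT_TEMPLATE_VARIABLE_PATTERN
    # matches {{...}} identifiers as described in the class docstring.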

    @property
    def commit_message(self) -> Optional[str]:
        """
        Return the commit message of the prompt version.
        """
        return self.description  # inherited from ModelVersion

    @property
    def version_metadata(self) -> dict[str, str]:
        """Return the version-level metadata of the prompt as a dictionary."""
        # Exclude reserved tags (e.g. the prompt text) as they are not user-facing
        return {key: value for key, value in self._tags.items() if not _is_reserved_tag(key)}

    @property
    def tags(self) -> dict[str, str]:
        """
        Return the prompt-level tags (from RegisteredModel).
        """
        return {
            key: value for key, value in self._prompt_tags.items() if not _is_reserved_tag(key)
        }
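
    # Note (illustrative): `tags` and `version_metadata` are intentionally
    # separate. For example, a tag like {"team": "search"} set on the registered
    # prompt would surface via `tags` for every version, while metadata like
    # {"author": "alice"} attached to a single version would surface via
    # `version_metadata` only for that version. (The key/value pairs here are
    # made-up examples.)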

    @property
    def run_ids(self) -> list[str]:
        """Get the run IDs associated with the prompt."""
        run_tag = self._tags.get(PROMPT_ASSOCIATED_RUN_IDS_TAG_KEY)
        if not run_tag:
            return []
        return run_tag.split(",")
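
    # Example (illustrative): if the associated-run tag holds the string
    # "run_a,run_b", this property returns ["run_a", "run_b"]; if the tag is
    # absent or empty, it returns []. (The run IDs above are made up.)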

    @property
    def uri(self) -> str:
        """Return the URI of the prompt."""
        return f"prompts:/{self.name}/{self.version}"

    @classmethod
    def from_model_version(
        cls, model_version: ModelVersion, prompt_tags: Optional[dict[str, str]] = None
    ) -> Prompt:
        """
        Create a Prompt object from a ModelVersion object.

        Args:
            model_version: The ModelVersion object to convert to a Prompt.
            prompt_tags: The prompt-level tags (from RegisteredModel). Optional.
        """
        if IS_PROMPT_TAG_KEY not in model_version.tags:
            raise MlflowException.invalid_parameter_value(
                f"Name `{model_version.name}` is registered as a model, not a prompt. MLflow "
                "does not allow registering a prompt with the same name as an existing model.",
            )

        if PROMPT_TEXT_TAG_KEY not in model_version.tags:
            raise MlflowException.invalid_parameter_value(
                f"Prompt `{model_version.name}` does not contain a prompt text"
            )

        return cls(
            name=model_version.name,
            version=model_version.version,
            template=model_version.tags[PROMPT_TEXT_TAG_KEY],
            commit_message=model_version.description,
            creation_timestamp=model_version.creation_timestamp,
            version_metadata=model_version.tags,
            prompt_tags=prompt_tags,
            aliases=model_version.aliases,
        )
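

if __name__ == "__main__":
    # Illustrative sketch only, not part of the original module: build a
    # ModelVersion carrying the prompt tags and convert it with
    # `Prompt.from_model_version`. The keyword arguments mirror the ones used
    # in the super().__init__ call above; all other ModelVersion parameters
    # are left at their defaults, and every concrete value here is made up.
    mv = ModelVersion(
        name="qa_prompt",
        version=1,
        creation_timestamp=0,
        description="Initial version",
        tags=[
            ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"),
            ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value="Answer the question: {{question}}"),
        ],
    )
    prompt = Prompt.from_model_version(mv)
    print(prompt.uri)        # e.g. prompts:/qa_prompt/1
    print(prompt.variables)  # e.g. {'question'}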