Skip to content

Sampler SDK

The Sampler SDK gives you access to the raw Grok-1 model, allowing you to do your own prompt engineering. Please note that the version of Grok-1 currently available in the IDE and via the API is fine-tuned for conversations. This means that phrasing a task in the form of a dialogue often yields the best results.

Getting started

To get started with the Sampler SDK, create a new client and access the sampler property.

import xai_sdk

# Create a client; its `sampler` property exposes the Sampler SDK.
client = xai_sdk.Client()
sampler = client.sampler

Sampling from the model

The main function of the Sampler SDK is the sample function, which, given a prompt, samples tokens from the model. The function returns an async iterator that streams out the generated tokens as they are produced.

simple_completion.py
"""A simple example demonstrating text completion."""

import asyncio

import xai_sdk


async def main():
    """Runs the example: streams a short completion of a fixed prompt to stdout."""
    client = xai_sdk.Client()

    prompt = "The answer to life and the universe is"
    # Echo the prompt so the streamed tokens read as its continuation.
    print(prompt, end="")
    # `prompt` is deprecated in favour of `inputs` but is still a required keyword argument,
    # hence the empty string.
    async for token in client.sampler.sample(prompt="", inputs=(prompt,), max_len=3):
        # Flush so each token shows up as soon as it is streamed back.
        print(token.token_str, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())

The sampler can also complete multimodal inputs, such as text combined with an image.

simple_completion.py
"""An example demonstrating multimodal completion."""

import asyncio

import xai_sdk


async def main():
    """Runs the example: asks the `vlm-1` model what is shown in `dog.png`."""

    # Raw image bytes can be passed directly as an element of `inputs`.
    with open("dog.png", "rb") as file:
        image = file.read()

    client = xai_sdk.Client()

    prompt = "What is this? "
    print(prompt, end="")
    async for token in client.sampler.sample(
        prompt="",
        inputs=(prompt, image),
        max_len=8,
        model_name="vlm-1",
    ):
        # Flush so each token shows up as soon as it is streamed back.
        print(token.token_str, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())

API Reference

xai_sdk.sampler

Sampler API to generate text completions.

This API gives access to the raw underlying models and is the most versatile and complex way to interact with our models.

xai_sdk.sampler.AsyncSampler

Allows sampling from the raw model API. All functions are asynchronous.

Source code in xai_sdk/sampler.py
class AsyncSampler:
    """Allows sampling from the raw model API. All functions are asynchronous."""

    def __init__(
        self,
        stub: "sampler_public_pb2_grpc.SamplerStub",
        initial_rng_seed: Optional[int] = None,
    ):
        """Initializes a new instance of the `AsyncSampler` class.

        Args:
            stub: The gRPC stub to use for interacting with the API.
            initial_rng_seed: First RNG seed to use for sampling. Each time we sample from the model
                and no RNG seed is explicitly specified, we deterministically generate a new seed
                based on the initial seed. This ensures that the generated responses are
                deterministic given a call order. If no initial seed is specified, a random
                initial seed in [0, 100000] is drawn.
        """
        if initial_rng_seed is None:
            initial_rng_seed = random.randint(0, 100000)

        self._stub = stub
        self._initial_rng_seed = initial_rng_seed

    def _get_next_rng_seed(self) -> int:
        """Returns the current seed and advances it by one.

        Consecutive calls therefore yield consecutive integers starting at the initial seed.
        """
        seed = self._initial_rng_seed
        self._initial_rng_seed += 1
        return seed

    async def tokenize(self, prompt: str, model_name: str = "") -> list["Token"]:
        """Converts the given prompt text into a sequence of discrete tokens.

        Args:
            prompt: Text to convert into a sequence of tokens.
            model_name: Model whose tokenizer should be used. Make sure to use the same value when
                tokenizing and when sampling as different models use different tokenizers. Leave
                empty to use the default model's tokenizer.

        Returns:
            A sequence of discrete tokens that represent the original `prompt` text. An empty
            prompt returns an empty list without contacting the API.
        """
        # Nothing to do if the text is empty.
        if not prompt:
            return []

        # Lazy %-style arguments: the original string lacked an f-prefix and logged the literal
        # text "{len(prompt)}".
        logging.debug("Tokenizing prompt with %d characters.", len(prompt))
        response: sampler_public_pb2.TokenizeResponse = await self._stub.Tokenize(
            sampler_public_pb2.TokenizeRequest(text=prompt, model_name=model_name)
        )
        tokens = response.tokens
        # `prompt` is non-empty here, so the division is safe.
        compression = (1 - len(tokens) / len(prompt)) * 100
        logging.debug(
            "Tokenization done. %d tokens detected (Compression of %.1f%%).",
            len(tokens),
            compression,
        )

        return [Token.from_proto(token) for token in tokens]

    async def sample(
        self,
        *,
        prompt: Union[str, Sequence[int], Sequence["Token"]],
        inputs: Sequence[Union[str, Sequence[int], bytes]] = (),
        model_name: str = "",
        max_len: int = 256,
        temperature: float = 0.7,
        nucleus_p: float = 0.95,
        stop_tokens: Optional[list[str]] = None,
        stop_strings: Optional[list[str]] = None,
        rng_seed: Optional[int] = None,
        return_attention: bool = False,
        allowed_tokens: Optional[Sequence[Union[int, str]]] = None,
        disallowed_tokens: Optional[Sequence[Union[int, str]]] = None,
        augment_tokens: bool = True,
    ) -> AsyncGenerator["Token", None]:
        """Generates a model response by continuing `prompt`.

        Args:
            prompt: [Deprecated, use inputs instead] Prompt to continue. This can either be a
                string, a sequence of token IDs, or a sequence of `Token` instances. Only used
                when `inputs` is empty.
            inputs: Multimodal input of the model. Each element can be a text string, a
                base64-encoded image given as a `data:image...` string, a list or tuple of token
                IDs, or raw image bytes. Elements of any other type are logged as errors and
                skipped.
            model_name: Name of the model to sample from. Leave empty to sample from the default
                model.
            max_len: Maximum number of tokens to generate.
            temperature: Temperature of the final softmax operation. The lower the temperature, the
                lower the variance of the token distribution. In the limit, the distribution
                collapses onto the single token with the highest probability.
            nucleus_p: Threshold of the Top-P sampling technique: We rank all tokens by their
                probability and then only actually sample from the set of tokens that ranks in the
                Top-P percentile of the distribution.
            stop_tokens: A list of strings, each of which will be mapped independently to a single
                token. If a string does not map cleanly to one token, it will be silently ignored.
                If the network samples one of these tokens, sampling is stopped and the stop token
                *is not* included in the response.
            stop_strings: A list of strings. If any of these strings occurs in the network output,
                sampling is stopped but the string that triggered the stop *will be* included in the
                response. Note that the response may be longer than the stop string. For example, if
                the stop string is "Hel" and the network predicts the single-token response "Hello",
                sampling will be stopped but the response will still read "Hello".
            rng_seed: Seed of the random number generator used to sample from the model outputs. If
                unspecified, a seed is chosen deterministically from the `initial_rng_seed`
                specified in the constructor.
            return_attention: If true, returns the attention mask. Note that this can significantly
                increase the response size for long sequences.
            allowed_tokens: If set, only these tokens can be sampled. Invalid input tokens are
                ignored. At most one of `allowed_tokens` and `disallowed_tokens` should be set.
            disallowed_tokens: If set, these tokens cannot be sampled. Invalid input tokens are
                ignored. At most one of `allowed_tokens` and `disallowed_tokens` should be set.
            augment_tokens: If true, every string passed to `stop_tokens`, `allowed_tokens` and
                `disallowed_tokens` that does not already start with `▁` is complemented by a
                `▁`-prefixed variant (the tokenizer's leading-whitespace marker). This is useful
                because most words have two corresponding vocabulary entries: one with leading
                whitespace and one without. Integer token IDs are passed through unchanged.

        Yields:
            `Token` instances as they are streamed back from the model. Token-budget updates sent
            on the same stream are logged instead of yielded.
        """

        if rng_seed is None:
            rng_seed = self._get_next_rng_seed()

        logging.debug(
            "Sampling %d tokens [seed=%d, temperature=%f, nucleus_p=%f, stop_tokens=%s, stop_strings=%s].",
            max_len,
            rng_seed,
            temperature,
            nucleus_p,
            stop_tokens,
            stop_strings,
        )

        if augment_tokens:
            # The underscore-like character below is not an ordinary underscore `_`; it is the
            # special character ▁ (U+2581) used by the tokenizer to indicate whitespace.
            def with_whitespace_variants(tokens):
                """Returns `tokens` as a list plus a ▁-prefixed copy of each unprefixed string."""
                tokens = list(tokens)
                return tokens + [
                    f"▁{t}" for t in tokens if isinstance(t, str) and not t.startswith("▁")
                ]

            if stop_tokens:
                stop_tokens = with_whitespace_variants(stop_tokens)
            if allowed_tokens:
                allowed_tokens = with_whitespace_variants(allowed_tokens)
            if disallowed_tokens:
                disallowed_tokens = with_whitespace_variants(disallowed_tokens)

        # Convert inputs
        converted_inputs = []
        if inputs:
            for element in inputs:
                if isinstance(element, str):
                    # A `data:image...` string is a base64 image; anything else is plain text.
                    # The `else` matters: without it an image string was also sent as text.
                    if element.startswith("data:image"):
                        converted_inputs.append(PromptInput(image_base64=element))
                    else:
                        converted_inputs.append(PromptInput(text=element))
                elif isinstance(element, (list, tuple)):
                    # Token IDs may be given as any list or tuple, matching `Sequence[int]`.
                    converted_inputs.append(PromptInput(token_ids=TokenIds(tokens=element)))
                elif isinstance(element, bytes):
                    converted_inputs.append(PromptInput(image_bytes=element))
                else:
                    logging.error("Invalid input type %s.", type(element))
        else:
            # Legacy path: fall back to the deprecated `prompt` argument.
            converted_inputs.append(_prompt_to_input(prompt))

        request = sampler_public_pb2.SampleTokensRequest(
            inputs=converted_inputs,
            settings=sampler_public_pb2.SampleSettings(
                max_len=max_len or 0,
                temperature=temperature,
                nucleus_p=nucleus_p,
                stop_tokens=stop_tokens or [],
                stop_strings=stop_strings or [],
                rng_seed=rng_seed,
                allowed_tokens=[_parse_input_token(t) for t in allowed_tokens or []],
                disallowed_tokens=[_parse_input_token(t) for t in disallowed_tokens or []],
            ),
            return_attention=return_attention,
            model_name=model_name,
        )
        response = self._stub.SampleTokens(request)

        token_counter = 0
        async for token in response:
            if token.HasField("token"):
                token_counter += 1
                if token_counter % 10 == 0:
                    logging.debug("Sampled %d tokens", token_counter)
                yield Token.from_proto(token.token)
            elif token.HasField("budget"):
                # The sample request also sends the current token budget information.
                log_budget_update(token.budget)
xai_sdk.sampler.AsyncSampler.__init__(stub, initial_rng_seed=None)

Initializes a new instance of the AsyncSampler class.

Parameters:

Name Type Description Default
stub SamplerStub

The gRPC stub to use for interacting with the API.

required
initial_rng_seed Optional[int]

First RNG seed to use for sampling. Each time we sample from the model and no RNG seed is explicitly specified, we deterministically generate a new seed based on the initial seed. This ensures that the generated responses are deterministic given a call order. If no initial seed is specified, we sample a random initial seed.

None
Source code in xai_sdk/sampler.py
def __init__(
    self, stub: sampler_public_pb2_grpc.SamplerStub, initial_rng_seed: Optional[int] = None
):
    """Creates an `AsyncSampler` bound to the given gRPC stub.

    Args:
        stub: gRPC stub through which every API call is made.
        initial_rng_seed: Seed used by the first sampling call that does not pass an explicit
            RNG seed; later such calls derive their seeds deterministically from it, so the
            responses are reproducible for a fixed call order. When omitted, a random integer
            in [0, 100000] is drawn instead.
    """
    self._stub = stub
    self._initial_rng_seed = (
        initial_rng_seed if initial_rng_seed is not None else random.randint(0, 100000)
    )
xai_sdk.sampler.AsyncSampler.sample(*, prompt, inputs=(), model_name='', max_len=256, temperature=0.7, nucleus_p=0.95, stop_tokens=None, stop_strings=None, rng_seed=None, return_attention=False, allowed_tokens=None, disallowed_tokens=None, augment_tokens=True) async

Generates a model response by continuing prompt.

Parameters:

Name Type Description Default
prompt Union[str, Sequence[int], Sequence[Token]]

[Deprecated, use inputs instead] Prompt to continue. This can either be a string, a sequence of token IDs, or a sequence of Token instances.

required
inputs Sequence[Union[str, Sequence[int], bytes]]

Multimodal input of the model. This can be a sequence of strings, token IDs, image in bytes or base64 encoded string.

()
model_name str

Name of the model to sample from. Leave empty to sample from the default model.

''
max_len int

Maximum number of tokens to generate.

256
temperature float

Temperature of the final softmax operation. The lower the temperature, the lower the variance of the token distribution. In the limit, the distribution collapses onto the single token with the highest probability.

0.7
nucleus_p float

Threshold of the Top-P sampling technique: We rank all tokens by their probability and then only actually sample from the set of tokens that ranks in the Top-P percentile of the distribution.

0.95
stop_tokens Optional[list[str]]

A list of strings, each of which will be mapped independently to a single token. If a string does not map cleanly to one token, it will be silently ignored. If the network samples one of these tokens, sampling is stopped and the stop token is not included in the response.

None
stop_strings Optional[list[str]]

A list of strings. If any of these strings occurs in the network output, sampling is stopped but the string that triggered the stop will be included in the response. Note that the response may be longer than the stop string. For example, if the stop string is "Hel" and the network predicts the single-token response "Hello", sampling will be stopped but the response will still read "Hello".

None
rng_seed Optional[int]

Seed of the random number generator used to sample from the model outputs. If unspecified, a seed is chosen deterministically from the initial_rng_seed specified in the constructor.

None
return_attention bool

If true, returns the attention mask. Note that this can significantly increase the response size for long sequences.

False
allowed_tokens Optional[Sequence[Union[int, str]]]

If set, only these tokens can be sampled. Invalid input tokens are ignored. At most one of allowed_tokens and disallowed_tokens should be set.

None
disallowed_tokens Optional[Sequence[Union[int, str]]]

If set, these tokens cannot be sampled. Invalid input tokens are ignored. At most one of allowed_tokens and disallowed_tokens should be set.

None
augment_tokens bool

If true, strings passed to stop_tokens, allowed_tokens and disallowed_tokens will be augmented to include both the passed token and the version with leading whitespace. This is useful because most words have two corresponding vocabulary entries: one with leading whitespace and one without.

True

Yields:

Type Description
AsyncGenerator[Token, None]

A sequence of Token instances.

Source code in xai_sdk/sampler.py
async def sample(
    self,
    *,
    prompt: Union[str, Sequence[int], Sequence["Token"]],
    inputs: Sequence[Union[str, Sequence[int], bytes]] = (),
    model_name: str = "",
    max_len: int = 256,
    temperature: float = 0.7,
    nucleus_p: float = 0.95,
    stop_tokens: Optional[list[str]] = None,
    stop_strings: Optional[list[str]] = None,
    rng_seed: Optional[int] = None,
    return_attention: bool = False,
    allowed_tokens: Optional[Sequence[Union[int, str]]] = None,
    disallowed_tokens: Optional[Sequence[Union[int, str]]] = None,
    augment_tokens: bool = True,
) -> AsyncGenerator["Token", None]:
    """Generates a model response by continuing `prompt`.

    Args:
        prompt: [Deprecated, use inputs instead] Prompt to continue. This can either be a
            string, a sequence of token IDs, or a sequence of `Token` instances. Only used
            when `inputs` is empty.
        inputs: Multimodal input of the model. Each element can be a text string, a
            base64-encoded image given as a `data:image...` string, a list or tuple of token
            IDs, or raw image bytes. Elements of any other type are logged as errors and
            skipped.
        model_name: Name of the model to sample from. Leave empty to sample from the default
            model.
        max_len: Maximum number of tokens to generate.
        temperature: Temperature of the final softmax operation. The lower the temperature, the
            lower the variance of the token distribution. In the limit, the distribution
            collapses onto the single token with the highest probability.
        nucleus_p: Threshold of the Top-P sampling technique: We rank all tokens by their
            probability and then only actually sample from the set of tokens that ranks in the
            Top-P percentile of the distribution.
        stop_tokens: A list of strings, each of which will be mapped independently to a single
            token. If a string does not map cleanly to one token, it will be silently ignored.
            If the network samples one of these tokens, sampling is stopped and the stop token
            *is not* included in the response.
        stop_strings: A list of strings. If any of these strings occurs in the network output,
            sampling is stopped but the string that triggered the stop *will be* included in the
            response. Note that the response may be longer than the stop string. For example, if
            the stop string is "Hel" and the network predicts the single-token response "Hello",
            sampling will be stopped but the response will still read "Hello".
        rng_seed: Seed of the random number generator used to sample from the model outputs. If
            unspecified, a seed is chosen deterministically from the `initial_rng_seed`
            specified in the constructor.
        return_attention: If true, returns the attention mask. Note that this can significantly
            increase the response size for long sequences.
        allowed_tokens: If set, only these tokens can be sampled. Invalid input tokens are
            ignored. At most one of `allowed_tokens` and `disallowed_tokens` should be set.
        disallowed_tokens: If set, these tokens cannot be sampled. Invalid input tokens are
            ignored. At most one of `allowed_tokens` and `disallowed_tokens` should be set.
        augment_tokens: If true, every string passed to `stop_tokens`, `allowed_tokens` and
            `disallowed_tokens` that does not already start with `▁` is complemented by a
            `▁`-prefixed variant (the tokenizer's leading-whitespace marker). This is useful
            because most words have two corresponding vocabulary entries: one with leading
            whitespace and one without. Integer token IDs are passed through unchanged.

    Yields:
        `Token` instances as they are streamed back from the model. Token-budget updates sent
        on the same stream are logged instead of yielded.
    """

    if rng_seed is None:
        rng_seed = self._get_next_rng_seed()

    logging.debug(
        "Sampling %d tokens [seed=%d, temperature=%f, nucleus_p=%f, stop_tokens=%s, stop_strings=%s].",
        max_len,
        rng_seed,
        temperature,
        nucleus_p,
        stop_tokens,
        stop_strings,
    )

    if augment_tokens:
        # The underscore-like character below is not an ordinary underscore `_`; it is the
        # special character ▁ (U+2581) used by the tokenizer to indicate whitespace.
        def with_whitespace_variants(tokens):
            """Returns `tokens` as a list plus a ▁-prefixed copy of each unprefixed string."""
            tokens = list(tokens)
            return tokens + [
                f"▁{t}" for t in tokens if isinstance(t, str) and not t.startswith("▁")
            ]

        if stop_tokens:
            stop_tokens = with_whitespace_variants(stop_tokens)
        if allowed_tokens:
            allowed_tokens = with_whitespace_variants(allowed_tokens)
        if disallowed_tokens:
            disallowed_tokens = with_whitespace_variants(disallowed_tokens)

    # Convert inputs
    converted_inputs = []
    if inputs:
        for element in inputs:
            if isinstance(element, str):
                # A `data:image...` string is a base64 image; anything else is plain text.
                # The `else` matters: without it an image string was also sent as text.
                if element.startswith("data:image"):
                    converted_inputs.append(PromptInput(image_base64=element))
                else:
                    converted_inputs.append(PromptInput(text=element))
            elif isinstance(element, (list, tuple)):
                # Token IDs may be given as any list or tuple, matching `Sequence[int]`.
                converted_inputs.append(PromptInput(token_ids=TokenIds(tokens=element)))
            elif isinstance(element, bytes):
                converted_inputs.append(PromptInput(image_bytes=element))
            else:
                logging.error("Invalid input type %s.", type(element))
    else:
        # Legacy path: fall back to the deprecated `prompt` argument.
        converted_inputs.append(_prompt_to_input(prompt))

    request = sampler_public_pb2.SampleTokensRequest(
        inputs=converted_inputs,
        settings=sampler_public_pb2.SampleSettings(
            max_len=max_len or 0,
            temperature=temperature,
            nucleus_p=nucleus_p,
            stop_tokens=stop_tokens or [],
            stop_strings=stop_strings or [],
            rng_seed=rng_seed,
            allowed_tokens=[_parse_input_token(t) for t in allowed_tokens or []],
            disallowed_tokens=[_parse_input_token(t) for t in disallowed_tokens or []],
        ),
        return_attention=return_attention,
        model_name=model_name,
    )
    response = self._stub.SampleTokens(request)

    token_counter = 0
    async for token in response:
        if token.HasField("token"):
            token_counter += 1
            if token_counter % 10 == 0:
                logging.debug("Sampled %d tokens", token_counter)
            yield Token.from_proto(token.token)
        elif token.HasField("budget"):
            # The sample request also sends the current token budget information.
            log_budget_update(token.budget)
xai_sdk.sampler.AsyncSampler.tokenize(prompt, model_name='') async

Converts the given prompt text into a sequence of discrete tokens.

Parameters:

Name Type Description Default
prompt str

Text to convert into a sequence of tokens.

required
model_name str

Model whose tokenizer should be used. Make sure to use the same value when tokenizing and when sampling as different models use different tokenizers. Leave empty to use the default model's tokenizer.

''

Returns:

Type Description
list[Token]

A sequence of discrete tokens that represent the original prompt text.

Source code in xai_sdk/sampler.py
async def tokenize(self, prompt: str, model_name: str = "") -> list["Token"]:
    """Converts the given prompt text into a sequence of discrete tokens.

    Args:
        prompt: Text to convert into a sequence of tokens.
        model_name: Model whose tokenizer should be used. Make sure to use the same value when
            tokenizing and when sampling as different models use different tokenizers. Leave
            empty to use the default model's tokenizer.

    Returns:
        A sequence of discrete tokens that represent the original `prompt` text. An empty
        prompt returns an empty list without contacting the API.
    """
    # Nothing to do if the text is empty.
    if not prompt:
        return []

    # Lazy %-style arguments: the original string lacked an f-prefix and logged the literal
    # text "{len(prompt)}".
    logging.debug("Tokenizing prompt with %d characters.", len(prompt))
    response: sampler_public_pb2.TokenizeResponse = await self._stub.Tokenize(
        sampler_public_pb2.TokenizeRequest(text=prompt, model_name=model_name)
    )
    tokens = response.tokens
    # `prompt` is non-empty here, so the division is safe.
    compression = (1 - len(tokens) / len(prompt)) * 100
    logging.debug(
        "Tokenization done. %d tokens detected (Compression of %.1f%%).",
        len(tokens),
        compression,
    )

    return [Token.from_proto(token) for token in tokens]

xai_sdk.sampler.Token dataclass

A token is an element of our vocabulary that has a unique index and string representation.

A token can either be sampled from a model or provided by the user (i.e. prompted). If the token comes from the model, we may have additional metadata such as its sampling probability, the attention pattern used when sampling the token, and alternative tokens.

Source code in xai_sdk/sampler.py
@dataclasses.dataclass(frozen=True)
class Token:
    """A token is an element of our vocabulary that has a unique index and string representation.

    A token can either be sampled from a model or provided by the user (i.e. prompted). If the token
    comes from the model, we may have additional metadata such as its sampling probability, the
    attention pattern used when sampling the token, and alternative tokens.
    """

    # The integer representation of the token. Corresponds to its index in the vocabulary.
    token_id: int
    # The string representation of the token. Corresponds to its value in the vocabulary.
    token_str: str
    # If this token was sampled, the token sampling probability. 0 if not sampled.
    prob: float
    # If this token was sampled, alternative tokens that could have been sampled instead.
    top_k: list["Token"]
    # If this token was sampled with the correct options, the token's attention pattern. The array
    # contains one value for every token in the context.
    attn_weights: list[float]
    # 1 if this token was created by a user and 2 if it was created by model.
    token_type: int

    @classmethod
    def from_proto(cls, proto: "sampler_public_pb2.Token") -> "Token":
        """Converts the protocol buffer instance to a `Token` instance.

        Each entry of `proto.top_k` is wrapped in a bare, model-typed proto `Token` (no nested
        alternatives, no attention) and converted recursively.
        """
        return cls(
            token_id=proto.final_logit.token_id,
            token_str=proto.final_logit.string_token,
            prob=proto.final_logit.prob,
            top_k=[
                cls.from_proto(
                    sampler_public_pb2.Token(
                        final_logit=logit,
                        top_k=[],
                        attention=[],
                        token_type=sampler_public_pb2.Token.TokenType.MODEL,
                    )
                )
                for logit in proto.top_k
            ],
            # Copy the repeated proto field into a plain Python list.
            attn_weights=list(proto.attention),
            token_type=proto.token_type,
        )
xai_sdk.sampler.Token.from_proto(proto) classmethod

Converts the protocol buffer instance to a Token instance.

Source code in xai_sdk/sampler.py
@classmethod
def from_proto(cls, proto: "sampler_public_pb2.Token") -> "Token":
    """Converts the protocol buffer instance to a `Token` instance.

    Each entry of `proto.top_k` is wrapped in a bare, model-typed proto `Token` (no nested
    alternatives, no attention) and converted recursively.
    """
    return cls(
        token_id=proto.final_logit.token_id,
        token_str=proto.final_logit.string_token,
        prob=proto.final_logit.prob,
        top_k=[
            cls.from_proto(
                sampler_public_pb2.Token(
                    final_logit=logit,
                    top_k=[],
                    attention=[],
                    token_type=sampler_public_pb2.Token.TokenType.MODEL,
                )
            )
            for logit in proto.top_k
        ],
        # Copy the repeated proto field into a plain Python list.
        attn_weights=list(proto.attention),
        token_type=proto.token_type,
    )

xai_sdk.sampler.log_budget_update(budget)

Logs a budget update.

Source code in xai_sdk/sampler.py
def log_budget_update(budget: "sampler_public_pb2.TokenBudget") -> None:
    """Logs a budget update at INFO level.

    Args:
        budget: Budget message exposing `tokens_spent` and `token_limit`.

    The usage percentage is only reported when `token_limit` is non-zero, so a zero limit no
    longer raises `ZeroDivisionError`.
    """
    tokens_spent = budget.tokens_spent
    token_limit = budget.token_limit
    if token_limit:
        # `%.1f%%` renders e.g. "12.5%"; the previous `%f.1%%` rendered "12.500000.1%".
        logging.info(
            "xAI Tokens used: %d (%.1f%%). Token limit: %d",
            tokens_spent,
            tokens_spent / token_limit * 100,
            token_limit,
        )
    else:
        logging.info("xAI Tokens used: %d. Token limit: %d", tokens_spent, token_limit)