Multi-Node LLM Inference with SGLang on SLURM-Enabled Clusters
Performing inference for a Llama 405B-BF16 model on Multi-Node SLURM Clusters with SGLang
Hi folks,
Last week I needed to run the Llama 405B model in BF16 precision. However, none of the tools I could find played well with SLURM, or they had other issues.
For instance, vLLM would not configure properly across multiple SLURM nodes, and I kept hitting strange Ray errors. I also looked at TGI, but its GitHub issues suggested that it did not support the full context length for the models. I considered TensorRT as well, but its setup needed things like sudo access to install Linux packages, and I did not have that on the cluster. There are probably other ways to do this, but in the end I settled on SGLang.
I have been a regular SGLang user, but I had never run it on multiple nodes. I also couldn't find documentation for doing that through its Engine API, so I had put it on the back burner. Eventually I decided to try launching an OpenAI-compatible server instead, and it worked well, apart from the trouble I had writing the SLURM config. I followed the snippet here. To save you all that time, here is the SLURM file I use, along with a breakdown of it.
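First, a quick sanity check on why this needs two nodes at all. The snippet below is a back-of-the-envelope sketch. It assumes about 405 billion parameters, 2 bytes per parameter in BF16, and 80 GB per H100. It counts only the weights; the KV cache, activations, and CUDA/NCCL buffers need memory on top of that.

```python
# Back-of-the-envelope check: do the BF16 weights of a ~405B-parameter model
# fit on one or two 8x H100 nodes? Weights only; the KV cache, activations and
# CUDA/NCCL buffers need extra memory on top of this.

BYTES_PER_PARAM_BF16 = 2
GB_PER_H100 = 80  # assumption: the 80 GB H100 variant
GPUS_PER_NODE = 8


def weight_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM_BF16) -> float:
    """Memory needed for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def weights_fit(num_params: float, num_gpus: int, gb_per_gpu: float = GB_PER_H100) -> bool:
    """True if the weights alone fit in the combined memory of num_gpus GPUs."""
    return weight_gb(num_params) <= num_gpus * gb_per_gpu


params = 405e9
print(f"BF16 weights: {weight_gb(params):.0f} GB")
for nodes in (1, 2):
    gpus = nodes * GPUS_PER_NODE
    print(f"{nodes} node(s), {gpus} GPUs, {gpus * GB_PER_H100} GB total: "
          f"fits = {weights_fit(params, gpus)}")
```

On a single node, the roughly 810 GB of weights do not fit into 640 GB. That is why the script requests two nodes and uses a tensor-parallel size of 16 (2 nodes x 8 GPUs), which leaves about 470 GB across the 16 GPUs for everything else.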
The SLURM submit script:
#!/bin/bash -l
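# Create the log directory before submitting (mkdir -p SLURM_Logs); SLURM does not create missing log directories for you.
#SBATCH -o SLURM_Logs/h100-%x_%j_master.out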
#SBATCH -e SLURM_Logs/h100-%x_%j_master.err
#SBATCH -D ./
#SBATCH -J h100-405B-Online-TP16-Sglang
#SBATCH --nodes=2
#SBATCH --ntasks=2  # Total tasks across all nodes
#SBATCH --cpus-per-task=18
#SBATCH --mem=224GB
#SBATCH --partition="h100" # Partition with H100 GPUs. Partition names vary between clusters, so use yours.
#SBATCH --gres=gpu:h100:8 # 8 H100 GPUs per node. Change the GPU type and count to match your nodes (and update tp_size below).
#SBATCH --nodelist=sws-8h100grid-05,sws-8h100grid-07 # Optional: this specific pair worked reliably for me, while other pairs crashed (this likely depends on your network setup). Your nodes will have different names, so you can drop this line.
#SBATCH --time=12:00:00
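# Load required modules or set environment variables if needed.
# This part of the script runs only on the first allocated node. srun passes the resulting environment
# (including the activated virtualenv) on to the job steps, so ENV_FOLDER and MODEL_PATH should be on a
# filesystem that both nodes can see. ENV_FOLDER and MODEL_PATH are placeholders for your own paths.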
echo "[INFO] Activating environment on node $SLURM_PROCID"
if ! source ENV_FOLDER/bin/activate; then
    echo "[ERROR] Failed to activate environment" >&2
    exit 1
fi
# Define parameters
model=MODEL_PATH
tp_size=16  # Tensor parallelism over all GPUs: 2 nodes x 8 GPUs
echo "[INFO] Running inference"
echo "[INFO] Model: $model"
echo "[INFO] TP Size: $tp_size"
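# Define the NCCL init address using the hostname of the head node (the first node in the allocation).
# Newer SGLang releases name this flag --dist-init-addr instead of --nccl-init-addr;
# check `python3 -m sglang.launch_server --help` for the version you have installed.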
HEAD_NODE=$(scontrol show hostname "$SLURM_NODELIST" | head -n 1)
NCCL_INIT_ADDR="${HEAD_NODE}:8000"
echo "[INFO] NCCL_INIT_ADDR: $NCCL_INIT_ADDR"
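# Give each node's process its own OUTLINES_CACHE_DIR. Without this, I got odd sqlite errors on the SLURM nodes when running multi-node.
# Each backgrounded srun picks up the environment at the moment it is launched, so re-exporting the variable before the second srun is enough.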
export OUTLINES_CACHE_DIR="/tmp/node_0_cache"
# Launch processes with srun
srun --ntasks=1 --nodes=1 --exclusive --output="SLURM_Logs/%x_%j_node0.out" \
    --error="SLURM_Logs/%x_%j_node0.err" \
    python3 -m sglang.launch_server \
    --model-path "$model" \
    --tp "$tp_size" \
    --nccl-init-addr "$NCCL_INIT_ADDR" \
    --nnodes 2 \
    --node-rank 0 &
export OUTLINES_CACHE_DIR="/tmp/node_1_cache"
srun --ntasks=1 --nodes=1 --exclusive --output="SLURM_Logs/%x_%j_node1.out" \
    --error="SLURM_Logs/%x_%j_node1.err" \
    python3 -m sglang.launch_server \
    --model-path "$model" \
    --tp "$tp_size" \
    --nccl-init-addr "$NCCL_INIT_ADDR" \
    --nnodes 2 \
    --node-rank 1 &
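# Wait for localhost:30000 (SGLang's default port) to accept connections. This relies on node rank 0, which serves
# the HTTP API, running on the same node as this batch script. Loading the weights can take a while, and the loop
# never exits on its own if a server crashes, so keep an eye on the logs in SLURM_Logs/.
while ! nc -z localhost 30000; do
    echo "[INFO] Waiting for localhost:30000 to accept connections"
    sleep 1
done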
echo "[INFO] localhost:30000 is ready to accept connections"
# Test the server by sending a request with curl
response=$(curl -s -X POST http://127.0.0.1:30000/v1/chat/completions \
-H "Authorization: Bearer None" \
-H "Content-Type: application/json" \
-d '{
  "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
  "messages": [
    {
      "role": "user",
      "content": "List 3 countries and their capitals."
    }
  ],
  "temperature": 0,
  "max_tokens": 64
}')
echo "[INFO] Response from server:"
echo "$response"
# Run inference via a Python script
python sglang_requester.py --model "meta-llama/Meta-Llama-3.1-405B-Instruct"
# Keep waiting so the servers stay alive until the end of the requested time limit. I do this so that I can later
# SSH into the SLURM nodes and run other Python files there (similar to sglang_requester.py). You can omit this,
# or cancel the job with scancel, once you have all your results.
wait
Here is the Python request script, sglang_requester.py:
import openai
import argparse
import time


def perform_request(client, messages, model):
    """Send one chat-completion request to the SGLang server and return the generated text."""
    s_time = time.time()
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,  # Greedy decoding, for the most repeatable outputs
        max_tokens=20000,  # Upper bound on generated tokens; lower it if you only need short answers
        top_p=1,  # 1 means no nucleus (top-p) filtering
        n=1,  # Number of completions to generate for this request
        seed=20242,  # Fixed seed for reproducibility
    )
    e_time = time.time()
    print("Time taken for request: ", e_time - s_time)
    return response.choices[0].message.content


def main(args):
    print("Arguments: ", args)
    # OpenAI-compatible endpoint started by the SLURM script; the server was launched without an API key,
    # so any placeholder string works here
    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
    message = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hey"},
    ]
    response = perform_request(client, message, args.model)
    print(response)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--model", type=str, required=True)  # e.g. meta-llama/Meta-Llama-3.1-405B-Instruct
    args = argparser.parse_args()
    main(args)
I hope this helps others run large models across multiple nodes on SLURM clusters.
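One caveat with the `nc -z` loop in the submit script: if a server crashes during startup, the loop keeps waiting until the job's time limit. For scripts run against the server later (for example after SSH-ing into the nodes), a readiness check with a timeout is a little safer. The snippet below is a minimal standard-library sketch of that idea. `wait_for_server` is a hypothetical helper name, and the one-hour default timeout and 5-second poll interval are arbitrary choices, not SGLang settings. Like `nc -z`, it only checks that the port accepts TCP connections and does not send a request to the model.

```python
import socket
import time


def wait_for_server(host: str = "127.0.0.1", port: int = 30000,
                    timeout: float = 3600.0, interval: float = 5.0) -> bool:
    """Poll host:port until it accepts a TCP connection or `timeout` seconds pass.

    Returns True as soon as a connection succeeds and False on timeout.
    Like `nc -z`, this only checks that the port is open.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            # Open and immediately close a TCP connection (the equivalent of `nc -z`)
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Connection refused or timed out: nothing is listening yet
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

For example, you could call `wait_for_server()` at the top of `main()` in sglang_requester.py and exit with an error when it returns False. The script would then fail fast instead of hanging.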


