MLOps: More Oops than Ops

Tags: LLMs, MLOps, ONNX, performance, prem, TensorRT
Image generated using Stable Diffusion 2
August 30, 2023
Biswaroop Bhattacharjee



As model complexity keeps increasing, so does the need for effective MLOps practices. This post is a transparent write-up of all the MLOps frustrations I've experienced over the last few days. By sharing my challenges and insights, I hope to contribute to a community that openly discusses and shares solutions for MLOps problems.
My goal was to improve the inference latency of a few current state-of-the-art LLMs.
Unfortunately, simply downloading trained model weights & existing code doesn't solve this problem.
The Promise of Faster Inference
My first target here was Llama 2. I wanted to convert it into ONNX format, which could then be converted to TensorRT, and finally served using Triton Inference Server.
TensorRT optimizes the model network by combining layers and optimizing kernel selection for improved latency, throughput, power efficiency and memory consumption. If the application specifies, it will additionally optimize the network to run in lower precision, further increasing performance and reducing memory requirements.
From online benchmarks [1, 2] it seems possible to achieve a 2-3x latency improvement (by reducing precision without hurting quality much). But these kinds of format conversions feel super flaky; things break far too often, with no solutions to be found online. Admittedly, that's somewhat expected since these models are so new, with different architectures using different (not yet widely supported) layers and operators.
Model Conversion Errors
Let's start with Llama 2 7B Chat:
- First, I downloaded the Llama-2-7B-Chat weights from Meta's official repository here, after requesting access.
- Then I converted the raw weights to HuggingFace format using this script by HuggingFace. Let's say we save the result under the llama-2-7b-chat-hf directory locally.
Now I considered two options for converting Huggingface models to ONNX format:
torch.onnx.export → gibberish text
Let's write an export_to_onnx function which will load the tokenizer & model, and export it into ONNX format:

import torch
from composer.utils import parse_uri, reproducibility
from pathlib import Path
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def export_to_onnx(
    pretrained_model_name_or_path: str,
    output_folder: str,
    verify_export: bool,
    max_seq_len: int | None = None,
):
    reproducibility.seed_all(42)
    _, _, parsed_save_path = parse_uri(output_folder)
    # Load HF config/model/tokenizer
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=True)
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    if hasattr(config, 'attn_config'):
        config.attn_config['attn_impl'] = 'torch'

    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, config=config).to("cuda:0")
    model.eval()
    # tips: https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/llama2
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    sample_input = tokenizer(
        "Hello, my dog is cute",
        padding="max_length",
        max_length=max_seq_len or model.config.max_seq_len,
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True).to("cuda:0")

    with torch.no_grad():
        model(**sample_input)

    output_file = Path(parsed_save_path) / 'model.onnx'
    output_file.parent.mkdir(parents=True, exist_ok=True)
    # Put sample input on cpu for export
    sample_input = {k: v.cpu() for k, v in sample_input.items()}
    model = model.to("cpu")
    torch.onnx.export(
        model,
        (sample_input,),
        str(output_file),
        input_names=['input_ids', 'attention_mask'],
        output_names=['output'],
        opset_version=16)
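Calling this for our checkpoint (using the directory names referenced below) would look roughly like this sketch:

# Illustrative invocation of the function above; the paths match the ones
# used elsewhere in this post.
export_to_onnx(
    pretrained_model_name_or_path="llama-2-7b-chat-hf",
    output_folder="llama-2-7b-onnx",
    verify_export=True,
    max_seq_len=1024,
)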
We can also check if the exported & original models' outputs are similar:
# (Optional) verify onnx model outputs
import onnx
import onnx.checker
import onnxruntime as ort

with torch.no_grad():
    orig_out = model(**sample_input)
orig_out.logits = orig_out.logits.cpu()  # put on cpu for export

_ = onnx.load(str(output_file))
onnx.checker.check_model(str(output_file))
ort_session = ort.InferenceSession(str(output_file))
for key, value in sample_input.items():
    sample_input[key] = value.cpu().numpy()
loaded_model_out = ort_session.run(None, sample_input)
torch.testing.assert_close(
    orig_out.logits.detach().numpy(),
    loaded_model_out[0],
    rtol=1e-2,
    atol=1e-2,
    msg='output mismatch between the orig and onnx exported model')
print('Success: exported & original model outputs match')
Assuming we've saved the ONNX model in ./llama-2-7b-onnx/, we can now run inference using onnxruntime:

import onnx
import onnx.checker
import onnxruntime as ort
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

output_file = 'llama-2-7b-onnx/model.onnx'  # converted model from above
ort_session = ort.InferenceSession(str(output_file))
tokenizer = AutoTokenizer.from_pretrained("llama-2-7b-chat-hf", use_fast=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
inputs = tokenizer(
    "Hello, my dog is cute",
    padding="max_length",
    max_length=1024,
    truncation=True,
    return_tensors="np",
    add_special_tokens=True)
loaded_model_out = ort_session.run(None, inputs.data)
tokenizer.batch_decode(torch.argmax(torch.tensor(loaded_model_out[0]), dim=-1))
On my machine, this generates really funky output:
ΠΠΠΠΠΠ\n\n\n\n\n\n\n\n\n\n Hello Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis Hinweis..........SMSMSMSMSMSMSMSMSMSMSMS Unterscheidung, I name is ough,
… which is mostly due to the lack of a proper decoding strategy (greedy, beam search, etc.) while generating tokens.
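For illustration, a minimal greedy-decoding loop over the ONNX session could look like the sketch below. This is my own sketch, not part of the exporter; note that the model above was exported with fixed-length padded inputs and no KV cache, so a real loop would also need to handle padding and caching.

import numpy as np

def greedy_generate(ort_session, tokenizer, prompt, max_new_tokens=32):
    # Feed the running sequence back in, keep the argmax of the last
    # position's logits each step, and stop on EOS.
    input_ids = tokenizer(prompt, return_tensors="np").input_ids
    for _ in range(max_new_tokens):
        attention_mask = np.ones_like(input_ids)
        logits = ort_session.run(
            None, {"input_ids": input_ids, "attention_mask": attention_mask})[0]
        next_token = logits[:, -1, :].argmax(axis=-1, keepdims=True)
        if next_token[0, 0] == tokenizer.eos_token_id:
            break
        input_ids = np.concatenate([input_ids, next_token], axis=-1)
    return tokenizer.batch_decode(input_ids)[0]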
optimum-cli → gibberish text and tensorrt slowness
To solve the problem above, we can try a different exporter which includes decoding strategies.
Using the Optimum ONNX exporter instead (assuming the original model is in ./llama-2-7b-chat-hf/), we can do:

optimum-cli export onnx \
    --model ./llama-2-7b-chat-hf/ --task text-generation --framework pt \
    --opset 16 --sequence_length 1024 --batch_size 1 --device cuda --fp16 \
    llama-2-7b-optimum/
This takes a few minutes to generate. If you don't have a GPU for this conversion, remove --device cuda from the command above. The result is:

llama-2-7b-optimum
├── config.json
├── Constant_162_attr__value
├── Constant_170_attr__value
├── decoder_model.onnx
├── decoder_model.onnx_data
├── generation_config.json
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.json
└── tokenizer.model
Now when I try to do inference using optimum.onnxruntime.ORTModelForCausalLM, things work fine (though slowly) using the CPUExecutionProvider:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("./onnx_optimum")
model = ORTModelForCausalLM.from_pretrained("./onnx_optimum/", use_cache=False, use_io_binding=False)
inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
gen_tokens = model.generate(**inputs, max_length=16)
assert model.providers == ['CPUExecutionProvider']
print(tokenizer.batch_decode(gen_tokens))
After waiting a long time, we get a result:
<s> My name is Arthur and I live in a small town in the countr
But when switching to the faster CUDAExecutionProvider, I get gibberish text on inference:

model = ORTModelForCausalLM.from_pretrained(
    "./onnx_optimum/", use_cache=False, use_io_binding=False,
    provider="CUDAExecutionProvider")
inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt").to("cuda")
gen_tokens = model.generate(**inputs, max_length=16)
assert model.providers == ['CUDAExecutionProvider', 'CPUExecutionProvider']
print(tokenizer.batch_decode(gen_tokens))

2023-08-02 19:47:43.534099146 [W:onnxruntime:, session_state.cc:1169 VerifyEachNodeIsAssignedToAnEp]
Some nodes were not assigned to the preferred execution providers which may or may not
have an negative impact on performance. e.g. ORT explicitly assigns shape related ops
to CPU to improve perf.
2023-08-02 19:47:43.534136078 [W:onnxruntime:, session_state.cc:1171 VerifyEachNodeIsAssignedToAnEp]
Rerunning with verbose output on a non-minimal build will show node assignments.
<s> My name is Arthur and I live in a<unk><unk><unk><unk><unk><unk>
Even with different temperature and other parameter values, it always yields unintelligible outputs, as reported in optimum#1248.

Update: after about a week this issue seemed to magically disappear, possibly due to a new version of llama-2-7b-chat-hf being released.

Using the new model with max_length=128:
- Prompt: Why should one run Machine learning model on-premises?
- ONNX inference latency: 2.31s
- HuggingFace version latency: 3s
The ONNX model is ~23% faster than the HuggingFace variant!
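For context, a rough sketch of how such a generate-latency comparison can be timed is shown below; the model paths and settings here are illustrative, not necessarily the exact benchmark behind the numbers above.

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

def time_generate(model, tokenizer, prompt, max_length=128, n_runs=3):
    # Time a few generate() calls and keep the best run.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        model.generate(**inputs, max_length=max_length)
        timings.append(time.perf_counter() - start)
    return min(timings)

tokenizer = AutoTokenizer.from_pretrained("./llama-2-7b-chat-hf")
prompt = "Why should one run Machine learning model on-premises?"

onnx_model = ORTModelForCausalLM.from_pretrained(
    "./onnx_optimum/", use_cache=False, use_io_binding=False,
    provider="CUDAExecutionProvider")
hf_model = AutoModelForCausalLM.from_pretrained(
    "./llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda")

print("onnx latency:", time_generate(onnx_model, tokenizer, prompt))
print("hf latency:  ", time_generate(hf_model, tokenizer, prompt))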
However, while both the CPU and CUDA providers now work, there seems to be a bug when trying the TensorrtExecutionProvider, reported in optimum#1278.

optimum-cli → segfaults
Next, let's try Dolly-v2 7B from Databricks. The equivalent optimum-cli command for ONNX conversion is:

optimum-cli export onnx \
    --model 'databricks/dolly-v2-7b' --task text-generation --framework pt \
    --opset 17 --sequence_length 1024 --batch_size 1 --fp16 --device cuda \
    dolly_optimum
It uses around 17GB of my GPU RAM and seems to work fine, but finally ends with a segmentation fault:
======= Diagnostic Run torch.onnx.export version 2.1.0.dev20230804+cu118 =======
verbose: False, log level: 40
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Saving external data to one file...
2023-08-09 20:59:33.334484259 [W:onnxruntime:, session_state.cc:1169 VerifyEachNodeIsAssignedToAnEp]
Some nodes were not assigned to the preferred execution providers which may or may not
have an negative impact on performance. e.g. ORT explicitly assigns shape related ops
to CPU to improve perf.
2023-08-09 20:59:33.334531829 [W:onnxruntime:, session_state.cc:1171 VerifyEachNodeIsAssignedToAnEp]
Rerunning with verbose output on a non-minimal build will show node assignments.
Asked a sequence length of 1024, but a sequence length of 1 will be used with
use_past == True for `input_ids`.
Post-processing the exported models...
Segmentation fault (core dumped)
Confusingly, despite this error, all the model files seem to be converted and saved to disk. Others have reported similar segfault issues while exporting (transformers#21360, optimum#798).
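As a quick sanity check (my own addition, not part of optimum), we can at least confirm that the files written before the crash still load, assuming the export produced a decoder_model.onnx as in the Llama 2 case above:

import onnx

# Load the decoder graph (external weight files are pulled in automatically)
# and run the ONNX checker on it via its file path.
decoder = onnx.load("dolly_optimum/decoder_model.onnx")
onnx.checker.check_model("dolly_optimum/decoder_model.onnx")
print("decoder graph contains", len(decoder.graph.node), "nodes")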
Results using the Dolly v2 model:
- Prompt: Why should one run Machine learning model on-premises?
- ONNX inference latency: 8.2s
- HuggingFace version latency: 5.2s
The ONNX model is actually ~58% slower than the HuggingFace variant!
To make things faster, we can try to optimize the model (an equivalent Python-API sketch follows the list of levels below):

optimum-cli onnxruntime optimize -O4 --onnx_model ./dolly_optimum/ -o dolly_optimized/

- -O1: basic general optimizations.
- -O2: basic and extended general optimizations, plus transformers-specific fusions.
- -O3: same as -O2, with GELU approximation.
- -O4: same as -O3, with mixed precision (fp16, GPU-only).
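The same optimization can also be driven from Python via optimum's ORTOptimizer. A sketch, roughly equivalent to the -O2 level and reusing the paths from above:

from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# Load the exported model, apply graph optimizations, and save the result.
model = ORTModelForCausalLM.from_pretrained("./dolly_optimum/", use_cache=False)
optimizer = ORTOptimizer.from_pretrained(model)
optimization_config = OptimizationConfig(
    optimization_level=2,   # roughly -O2: extended optimizations + transformer fusions
    optimize_for_gpu=True,
    fp16=False,
)
optimizer.optimize(save_dir="dolly_optimized", optimization_config=optimization_config)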
We still hit errors at every level. With -O1 the model gets saved but there's no noticeable performance change; with -O2 the process gets killed (even though I have a 40GB A100 GPU plus 80GB of CPU RAM); and with -O3 & -O4 we get the segfault above, with only partially saved model files.

torch.onnx.export → gibberish images
Moving on from text-based models, let's now look at an image generator. We can try to speed up the Stable Diffusion 2.1 model. In an IPython shell:
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16).to("cuda:0")
%time img = pipe("Iron man laughing", num_inference_steps=20, num_images_per_prompt=1).images[0]
img.save("iron_man.png", format="PNG")
The latency (as measured by the %time magic) is 3.25s.

Next, we convert the pipeline to ONNX using the convert_stable_diffusion_checkpoint_to_onnx.py script from diffusers:

python convert_stable_diffusion_checkpoint_to_onnx.py \
    --model_path stabilityai/stable-diffusion-2-1 \
    --output_path sd_onnx/ --opset 16 --fp16
Note: if a model uses operators unsupported by the opset number above, you'll have to upgrade pytorch to the nightly build:

pip uninstall torch
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
The result is:
sd_onnx/
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   └── model.onnx
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── model.onnx
│   └── weights.pb
├── vae_decoder
│   └── model.onnx
└── vae_encoder
    └── model.onnx
There's a separate ONNX model for each Stable Diffusion subcomponent.
Now, to benchmark this in the same way:
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained("sd_onnx", provider="CUDAExecutionProvider")
%time img = pipe("Iron man laughing", num_inference_steps=20, num_images_per_prompt=1).images[0]
img.save("iron_man.png", format="PNG")
The overall performance results look great, at ~59% faster! We also didn't see any noticeable quality difference between the models:
- Prompt: Iron man laughing
- ONNX inference latency: 1.34s
- HuggingFace version latency: 3.25s
Since we know that the unet model is the bottleneck, taking ~90% of the compute time, we can focus on it for further optimization. We'll try to serialize the ONNX version of the UNet into a TensorRT engine. When building the engine, the builder object selects the most optimized kernels for the chosen platform and configuration. Building the engine from a network definition can be time-consuming, and shouldn't be repeated every time we need to perform inference unless the model, platform, or configuration changes. After building, we can save the engine to disk for later reuse (known as serializing the engine); deserializing is when we load the engine from disk back into memory.
To set up TensorRT properly, follow this support table. It's a bit painful, and (similar to CUDA/cuDNN) if you just want a quick solution you can use NVIDIA's tensorrt:22.12-py3 docker image as a base:

FROM nvcr.io/nvidia/tensorrt:22.12-py3
ENV CUDA_MODULE_LOADING=LAZY
RUN pip install ipython transformers optimum[onnxruntime-gpu] onnx diffusers accelerate scipy safetensors composer
RUN pip uninstall torch -y && pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
COPY sd_onnx sd_onnx
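Before building the engine, it's worth checking the exported UNet's input names and shapes, since the optimization profile in the script below has to cover them. This quick inspection is my own addition:

import onnx

# Read just the graph structure (skip loading the large external weights file).
unet = onnx.load("sd_onnx/unet/model.onnx", load_external_data=False)
for inp in unet.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)
# The profile below sets shapes for: sample, encoder_hidden_states, timestep.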
We can then use the following script for serialization:

import tensorrt as trt
import torch

onnx_model = "sd_onnx/unet/model.onnx"
engine_filename = "unet.trt"  # saved serialized tensorrt engine file path
# constants
batch_size = 1
height = 512
width = 512
latents_shape = (batch_size, 4, height // 8, width // 8)
# shape required by Stable Diffusion 2.1's UNet model
embed_shape = (batch_size, 64, 1024)
timestep_shape = (batch_size,)

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
TRT_BUILDER = trt.Builder(TRT_LOGGER)
network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = TRT_BUILDER.create_builder_config()
profile = TRT_BUILDER.create_optimization_profile()

print("Loading & validating ONNX model")
onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
parse_success = onnx_parser.parse_from_file(onnx_model)
for idx in range(onnx_parser.num_errors):
    print(onnx_parser.get_error(idx))
if not parse_success:
    raise ValueError("ONNX model parsing failed")

# set input, latent and other shapes required by the layers
profile.set_shape("sample", latents_shape, latents_shape, latents_shape)
profile.set_shape("encoder_hidden_states", embed_shape, embed_shape, embed_shape)
profile.set_shape("timestep", timestep_shape, timestep_shape, timestep_shape)
config.add_optimization_profile(profile)

config.set_flag(trt.BuilderFlag.FP16)
print(f"Serializing & saving engine to '{engine_filename}'")
serialized_engine = TRT_BUILDER.build_serialized_network(network, config)
with open(engine_filename, 'wb') as f:
    f.write(serialized_engine)
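For reference, deserializing the saved engine later is just a few lines with the TensorRT runtime; the TRTModel helper used below wraps this (plus the input/output buffer management):

import tensorrt as trt

# Load the serialized engine back from disk and create an execution context.
logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)
with open("unet.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()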
Now let's move on to deserializing unet.trt for inference. We'll use the TRTModel class from x-stable-diffusion's trt_model:

import torch
import tensorrt as trt
trt.init_libnvinfer_plugins(None, "")
import pycuda.autoinit
from diffusers import AutoencoderKL, LMSDiscreteScheduler
from PIL import Image
from torch import autocast
from transformers import CLIPTextModel, CLIPTokenizer
from trt_model import TRTModel
from tqdm.contrib import tenumerate
class TrtDiffusionModel:
    def __init__(self):
        self.device = torch.device("cuda")
        self.unet = TRTModel("./unet.trt")  # tensorrt engine saved path
        self.vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-2-1", subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            "stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(self.device)
        self.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
            num_train_timesteps=1000)

    def predict(
        self, prompts, num_inference_steps=50, height=512, width=512, max_seq_length=64
    ):
        guidance_scale = 7.5
        batch_size = 1
        text_input = self.tokenizer(
            prompts, padding="max_length", max_length=max_seq_length,
            truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_seq_length,
            return_tensors="pt")
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents = torch.randn((batch_size, 4, height // 8, width // 8)).to(self.device)
        self.scheduler.set_timesteps(num_inference_steps)
        latents = latents * self.scheduler.sigmas[0]

        with torch.inference_mode(), autocast("cuda"):
            for i, t in tenumerate(self.scheduler.timesteps):
                latent_model_input = torch.cat([latents] * 2)
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
                # predict the noise residual
                inputs = [
                    latent_model_input,
                    torch.tensor([t]).to(self.device),
                    text_embeddings]
                noise_pred = self.unet(inputs, timing=True)
                noise_pred = torch.reshape(noise_pred[0], (batch_size * 2, 4, 64, 64))
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred.cuda(), t, latents)["prev_sample"]
            # scale and decode the image latents with VAE
            latents = 1 / 0.18215 * latents
            image = self.vae.decode(latents).sample
        return image
model = TrtDiffusionModel()
image = model.predict(
    prompts="Iron man laughing, real photoshoot",
    num_inference_steps=25,
    height=512,
    width=512,
    max_seq_length=64)
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save("image_generated.png")
The above script runs, but the generated output looks like this:

[generated image: blank/noisy pixels rather than a recognizable picture]
Something's going wrong, and changing to different tensor shapes (defined above) also doesn't help fix the generation of blank/noisy images.
I don't know how to make Stable Diffusion 2.1 work with TensorRT, though it's proved possible for other Stable Diffusion variants in AUTOMATIC1111/stable-diffusion-webui. Others reporting similar issues in stable-diffusion-webui#5503 have suggested:
- Use more than 16 bits of precision: I did, but it didn't help.
- Use xformers: for our model we instead need pytorch's recently added scaled_dot_product_attention operator (a minimal example follows this list).
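For context, the operator in question is PyTorch's fused attention, available since PyTorch 2.0; a purely illustrative call looks like:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) query/key/value tensors
q = torch.randn(1, 8, 77, 64)
k = torch.randn(1, 8, 77, 64)
v = torch.randn(1, 8, 77, 64)
out = F.scaled_dot_product_attention(q, k, v)  # fused attention kernel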
Other Frustrations
Maybe the code above is at least partially in my control, but there are also other issues that have nothing to do with my code:
- Licences: Text Generation Inference recently adopted a new, more restrictive license for newer versions, so I can only use older releases (up to v0.9).
- Lack of GPU support: GGML doesn't currently support GPU inference, so I can't use it if I want very low latency.
- Quality: I've heard from peers who saw a big decrease in output quality with vLLM. I'd like to explore this in future.
Conclusion
I've listed my recent errors and frustrations. I need more time to dig deeper and solve them, but if you think you can help, please reply in any of the issues linked above! By sharing my experiences and challenges, I hope to spark discussions and new ideas. Maybe you've faced something similar?
While the world likes showcasing the latest advancements and shiny results, it's important to also acknowledge and address the underlying complexities that come with deploying & maintaining ML models. There's a scarcity of documentation and resources for these problems in the ML community. As the field continues to rapidly evolve, there is a need for more in-depth discussions and solutions to these technical hurdles.