Last Updated on 2024-09-01 by Clay
Introduction
Generative models are becoming increasingly powerful, and independent researchers are releasing one open-source large language model (LLM) after another. However, when running inference with an LLM, waiting for the full response to finish generating can take a long time.
Streaming the output as ChatGPT does, emitting the generated tokens in chunks as they are produced, significantly improves the user experience.
As the GitHub of the open-source model community, HuggingFace naturally recognized this demand. As of transformers 4.30.1, HuggingFace provides the following two streamer interfaces that can be passed to model.generate():
- TextStreamer: Directly prints the model-generated response to standard output (stdout)
- TextIteratorStreamer: Runs generation in a separate thread and exposes the generated text through an iterator, so you can consume it however you like (typically with a for loop)
For testing models, TextStreamer is sufficient; for production-level applications, TextIteratorStreamer is a must.
Usage
TextStreamer
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

# Load the tokenizer and model
pretrained_model_name_or_path = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)

# Tokenize the prompt
text = "How are you?"
inputs = tokenizer(text, return_tensors="pt")

# The streamer decodes tokens and prints them to stdout as they are generated
streamer = TextStreamer(tokenizer=tokenizer)

# Generation
model.generate(**inputs, streamer=streamer, max_new_tokens=50)
The generated text is printed directly to the screen (stdout) as it is produced.
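Note that by default the streamer also echoes the prompt and any special tokens. TextStreamer accepts a skip_prompt flag and forwards extra keyword arguments (such as skip_special_tokens) to the tokenizer's decode call, so a minimal variant that prints only the newly generated text looks like this:

# Print only new tokens, with special tokens stripped from the decoded text
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
model.generate(**inputs, streamer=streamer, max_new_tokens=50)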
TextIteratorStreamer
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Load the tokenizer and model
pretrained_model_name_or_path = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)

# Tokenize the prompt
text = "How are you?"
inputs = tokenizer(text, return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer=tokenizer)

# Generation arguments
inputs.update({"streamer": streamer, "max_new_tokens": 50})

# Run generation in a background thread so we can consume the stream here
thread = Thread(target=model.generate, kwargs=inputs)
thread.start()

# Iterate over the text chunks as they arrive
for token in streamer:
    print(token, end="", flush=True)

thread.join()
Here I chose to print the output directly, but you are free to handle the streamed chunks however you like; in a production environment, for example, they might be pushed to a webpage.
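As a sketch of what that might look like, the iterator can be handed straight to a web framework that supports chunked responses. The example below uses FastAPI's StreamingResponse; the endpoint name and parameter values are illustrative assumptions, not part of the original post:

from threading import Thread

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

app = FastAPI()

pretrained_model_name_or_path = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)

@app.get("/generate")  # hypothetical endpoint for illustration
def generate(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt drops the echoed prompt; timeout stops the iterator from
    # blocking forever if generation fails (both are TextIteratorStreamer parameters)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60.0)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=50)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    # StreamingResponse consumes the iterator and sends each chunk as it arrives
    return StreamingResponse(streamer, media_type="text/plain")

Because TextIteratorStreamer is just a plain Python iterator, the same pattern works with any consumer: a websocket handler, a server-sent-events endpoint, or a GUI callback.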