Skip to content

使用 snapshot_download 下載 HuggingFace Hub 上的模型

Last Updated on 2024-04-17 by Clay

介紹

HuggingFace Model Hub 現在已經是無人不知、無人不曉的重要開源平台了,每天都有無數的人或組織上傳自己訓練出來的最新模型(包含文字、圖像、語音...... 等等不同領域)到這個平台上。可說是現在凡舉是個做 AI 相關工作的人,想必都會經常瀏覽 HuggingFace 他們的家的平台網站。

具體來說,我們可以下載 HuggingFace Hub 上的模型到本地端,並在本地端自由使用這些模型(注意,並不是全部的模型皆可免費商用,仍然需要注意授權!)。那麼,我們該如何優雅地下載模型呢?

網路上有各式各樣的下載教學,有的甚至動用到多核心、分片式地下載,有的甚至支援斷點續傳 —— 然而,最簡單的方法仍然無非是 git clone 以及 snapshot_download(),尤其是不希望模型下載後自動在 ~/.cache/huggingface/ 底下儲存一份原始檔案的,更是可以直接使用上述兩種方法。

之後若是有研究其他性能更強、更簡單的下載方式,再另行撰文記錄分享。


git clone

git clone 可說是最簡單的方式之一了,但最大的缺點就是為了追蹤檔案版本,下載時會額外下載 .git 隱藏資料夾,並且這個隱藏資料夾的容量跟模型本體一樣大,比方說你下載了 14GB 的模型,最後去查看時會發現足足有 28GB!

這還滿浪費流量的,所以之後若無頻繁更新的需求,建議直接刪除 .git。

而下載方式則是,若你要下載的 repo_id 為 openai-community/gpt2,則前面要再記得加上 huggingface 的 URL:

git clone https://huggingface.co/openai-community/gpt2

snapshot_download()

要使用 snapshot_download(),必須先行安裝 huggingface_hub

pip3 install huggingface_hub


之後,我們可以使用:

snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks=False,
)


來下載模型。

這裡簡單介紹一下幾個最重要的參數:

  • repo_id (str): 要下載的 repo 名稱,例如 openai-community/gpt2
  • local_dir (str): 儲存模型的位置
  • local_dir_use_symlinks (bool): 是否要使用軟連結指向 cache 中的原始模型

這裡有個小坑是,如果你要下載的模型在本地端的 ~/.cache/huggingface/ 底下已經存在了,那麼你就算將 local_dir_use_symlinks 設為 False,仍然會使用軟連結指向 ~/.cache/huggingface/ 底下的模型。

所以建議在下載模型時,先確認以前是否已經有 cache 在本地端;若是難以確認並且 cache 資料不重要的話,乾脆把 ~/.cache/huggingface/ 底下的資料夾都先刪除了吧!不會影響到程式執行的,頂多需要小模型時重新拉。

以下提供一個簡單使用腳本。

不需要事先建立 model_hub/ 資料夾,這是程式自動建立的。

amazon/chronos-t5-tiny
prajjwal1/bert-tiny
NousResearch/Hermes-2-Pro-Mistral-7B


這是一個範例,裡面可以條列各種想要下載的模型 repo id,等等會依序下載。


import argparse 
import os

from huggingface_hub import snapshot_download


def main() -> None:
# Arguments
parser = argparse.ArgumentParser(description="Download a snapshot from Huggingface Model Hub")
parser.add_argument("--download_file", type=str, required=True, help="The list file of repository id")
parser.add_argument("--local_dir", type=str, default="./", help="Directory to save the downloaded snapshot")

# Parsing
args = parser.parse_args()

# Check `local_dir` is existed
os.makedirs(args.local_dir, exist_ok=True)

# Get all repo id
with open(args.download_file, "r") as f:
repo_ids = [repo_id for repo_id in f.read().splitlines() if repo_id.strip()]

# Donwload
for repo_id in repo_ids:
local_dir = os.path.join(args.local_dir, repo_id.replace("/", "--"))

if os.path.isdir(local_dir):
print(f"{repo_id} is existed, pass.")
continue

snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks=False,
)
print(f"\n{repo_id} is finished.\n")


if __name__ == "__main__":
main()


這裡是模型下載的程式碼,裡面我也提供了下載模型的 repo id 自動把 / 轉換成 -- 的操作、也在下載目錄已經有同樣模型的情況下,自動略過不要重複下載。


#!/bin/bash


time python3 download.py \
--download_file ./model_list.txt \
--local_dir ./model_hub/


最後就是自動執行的腳本了,這裡直接設定好了要輸入的檔案以及下載目錄。

別忘了使用 chmod +x download.sh 將其轉換成可執行檔。之後只需要使用 ./download.sh 就可以執行了。


References


Read More

Leave a Reply取消回覆

Exit mobile version