#!/bin/bash
#
# =============================================================================
# Bench2Drive Base Dataset Download Script (Direct aria2c with TLS fix)
# =============================================================================
# Downloads the Bench2Drive "Base" dataset with aria2c, passing explicit TLS
# settings. If aria2c is not installed, a static binary is fetched from the
# mirror into ./.aria2 and used from there.
#
# Usage: bash download_base_direct_aria2.sh [download_dir] [threads]
#   download_dir  target directory (default: ./Bench2Drive-Base)
#   threads       connections per file, 1-16 (default: 8)
# =============================================================================

set -euo pipefail

# --- Configuration -----------------------------------------------------------
# shellcheck disable=SC2034  # informational; not read by this script
REPO_ID="rethinklab/Bench2Drive"
# shellcheck disable=SC2034  # informational; not read by this script
DATASET_NAME="Bench2Drive-Base"
DEFAULT_DOWNLOAD_DIR="./Bench2Drive-Base"
DEFAULT_THREADS=8

# Mirror URLs for the file-list JSON and the static aria2c build.
JSON_URL="https://git.hyuyao.cn/sam/binary-mirror/raw/branch/main/bench2drive_base_1000.json"
ARIA2_URL="https://git.hyuyao.cn/sam/binary-mirror/raw/branch/main/aria2-x86_64-linux-musl_static.zip"

# Download directory and per-file connection count (positional arguments).
DOWNLOAD_DIR="${1:-$DEFAULT_DOWNLOAD_DIR}"
THREADS="${2:-$DEFAULT_THREADS}"

# Location of the locally installed aria2c; ARIA2_BIN is what actually runs.
LOCAL_ARIA2_DIR="./.aria2"
LOCAL_ARIA2_BIN="$LOCAL_ARIA2_DIR/aria2c"
ARIA2_BIN="aria2c"

# Base URL that every file name from the list is appended to.
FILE_LIST_URL="https://hf-mirror.com/datasets/rethinklab/Bench2Drive/resolve/main/"

# Scratch state, filled in by main() / download_file_list().
WORK_DIR=""     # private temp directory (mktemp -d), removed on exit
FILE_LIST=""    # one dataset file name per line
ARIA2_INPUT=""  # aria2c --input-file contents
FILE_COUNT=0    # number of entries in FILE_LIST

# --- Coloured output ---------------------------------------------------------
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
NC='\033[0m'

# Logging helpers. Warnings and errors go to stderr so that functions whose
# stdout is captured are never polluted by diagnostics.
print_info() { printf '%b[INFO]%b %s\n' "$BLUE" "$NC" "$1"; }
print_success() { printf '%b[SUCCESS]%b %s\n' "$GREEN" "$NC" "$1"; }
print_warning() { printf '%b[WARNING]%b %s\n' "$YELLOW" "$NC" "$1" >&2; }
print_error() { printf '%b[ERROR]%b %s\n' "$RED" "$NC" "$1" >&2; }
print_highlight() { printf '%b%s%b\n' "$CYAN" "$1" "$NC"; }

# Print the start-up banner.
show_banner() {
    echo ""
    print_highlight "============================================================================="
    print_highlight " Bench2Drive Base Dataset Downloader (Direct aria2c with TLS fix)"
    print_highlight "============================================================================="
    echo ""
}

#######################################
# Download the static aria2c build from the mirror and install it locally.
# Globals:  ARIA2_URL, LOCAL_ARIA2_DIR, LOCAL_ARIA2_BIN (read), ARIA2_BIN (set)
# Returns:  0 on success, 1 on download/extract failure; exits if unzip is
#           missing.
#######################################
download_local_aria2() {
    local temp_zip="$LOCAL_ARIA2_DIR/aria2.zip"
    local version_text

    # Check for unzip before spending time on the download.
    if ! command -v unzip > /dev/null 2>&1; then
        print_error "unzip is not installed. Please install it:"
        echo "  Ubuntu/Debian: sudo apt install unzip" >&2
        exit 1
    fi

    print_info "Downloading aria2c from mirror..."
    print_info "URL: $ARIA2_URL"

    mkdir -p "$LOCAL_ARIA2_DIR"

    # -f: treat HTTP errors as failures instead of saving the error page.
    # NOTE(review): -k skips TLS verification for an executable download; kept
    # because the script targets hosts with broken CA stores — confirm intent.
    if ! curl -f -k -L -o "$temp_zip" "$ARIA2_URL"; then
        rm -f -- "$temp_zip"
        print_error "Failed to download aria2c"
        return 1
    fi

    print_info "Extracting aria2c..."
    if ! unzip -o "$temp_zip" -d "$LOCAL_ARIA2_DIR"; then
        rm -f -- "$temp_zip"
        print_error "Failed to extract aria2c"
        return 1
    fi
    rm -f -- "$temp_zip"

    if [[ ! -f "$LOCAL_ARIA2_BIN" ]]; then
        print_error "Failed to extract aria2c"
        return 1
    fi

    chmod +x "$LOCAL_ARIA2_BIN"
    print_success "aria2c installed to: $LOCAL_ARIA2_BIN"

    # Running --version doubles as a check that the binary works on this host.
    if ! version_text=$("$LOCAL_ARIA2_BIN" --version); then
        print_error "Downloaded aria2c cannot be executed: $LOCAL_ARIA2_BIN"
        return 1
    fi
    printf '%s\n' "${version_text%%$'\n'*}"

    ARIA2_BIN="$LOCAL_ARIA2_BIN"
    return 0
}

#######################################
# Make sure curl and an aria2c binary are available.
# curl is checked first because it is needed to fetch aria2c itself.
# Globals:  LOCAL_ARIA2_BIN (read), ARIA2_BIN (may be set)
#######################################
check_dependencies() {
    local version_text

    print_info "Checking dependencies..."

    if ! command -v curl > /dev/null 2>&1; then
        print_error "curl is not installed."
        exit 1
    fi
    print_success "curl is installed"

    if command -v aria2c > /dev/null 2>&1; then
        print_success "System aria2c is installed"
        version_text=$(aria2c --version) || version_text="unknown"
        print_info "Version: ${version_text%%$'\n'*}"
    elif [[ -f "$LOCAL_ARIA2_BIN" ]]; then
        print_success "Local aria2c found: $LOCAL_ARIA2_BIN"
        ARIA2_BIN="$LOCAL_ARIA2_BIN"
    else
        print_warning "aria2c not found, will download from mirror..."
        # A non-zero return aborts the script through `set -e`.
        download_local_aria2
    fi
}

#######################################
# Fetch the JSON index from the mirror and write one file name per line to
# FILE_LIST. Exits if the index cannot be fetched or yields no names.
# Globals:  JSON_URL, WORK_DIR, FILE_LIST (read), FILE_COUNT (set)
#######################################
download_file_list() {
    local temp_json="$WORK_DIR/bench2drive_base_1000.json"

    print_info "Downloading file list from mirror..."
    print_info "URL: $JSON_URL"

    if ! curl -f -k -L -o "$temp_json" "$JSON_URL" || [[ ! -s "$temp_json" ]]; then
        print_error "Failed to download file list from mirror"
        exit 1
    fi

    # Preferred: take the top-level JSON keys with python3. Its stderr is
    # dropped on purpose; any failure (python3 missing, unexpected JSON shape)
    # falls through to the grep-based extraction below.
    if ! python3 -c "import json,sys; data=json.load(sys.stdin); print('\n'.join(data.keys()))" \
        < "$temp_json" > "$FILE_LIST" 2> /dev/null || [[ ! -s "$FILE_LIST" ]]; then
        # Fallback: pull every quoted "*.tar.gz" string out of the raw text.
        # `|| true`: an empty result is detected by the count check below.
        grep -o '"[^"]*\.tar\.gz"' "$temp_json" | tr -d '"' > "$FILE_LIST" || true
    fi
    rm -f -- "$temp_json"

    # Count non-empty lines; grep -c exits 1 when the count is zero.
    FILE_COUNT=$(grep -c . "$FILE_LIST" || true)
    if [[ "$FILE_COUNT" -eq 0 ]]; then
        print_error "Failed to download file list from mirror"
        exit 1
    fi

    print_success "Downloaded file list from mirror"
    print_info "Found $FILE_COUNT files to download"
}

#######################################
# Write the aria2c input file: one URI line per dataset file, followed by
# whitespace-indented per-download options (dir=, out=).
# Globals:  FILE_LIST, FILE_LIST_URL, DOWNLOAD_DIR, ARIA2_INPUT (read)
#######################################
create_aria2_input() {
    local filename

    print_info "Creating aria2c input file..."

    while IFS= read -r filename || [[ -n "$filename" ]]; do
        [[ -n "$filename" ]] || continue
        printf '%s%s\n  dir=%s\n  out=%s\n' \
            "$FILE_LIST_URL" "$filename" "$DOWNLOAD_DIR" "$filename"
    done < "$FILE_LIST" > "$ARIA2_INPUT"

    print_success "Created aria2c input file with $(wc -l < "$ARIA2_INPUT") lines"
}

#######################################
# Print the aria2c command-line options, one per line, on stdout.
# Diagnostics are sent to stderr so callers can capture stdout safely.
# Globals:  THREADS (read)
#######################################
get_aria2_options() {
    local candidate
    local ca_bundle=""
    local -a opts=(
        --continue=true
        --max-concurrent-downloads=5
        "--split=$THREADS"
        "--max-connection-per-server=$THREADS"
        --min-split-size=10M
        --max-tries=5
        --retry-wait=30
        --timeout=600
        --connect-timeout=60
        --allow-overwrite=false
        --auto-file-renaming=false
        --conditional-get=true
        --console-log-level=warn
        --summary-interval=0
    )

    # TLS: point aria2c at the first system CA bundle that exists
    # (RHEL/CentOS locations first, then Debian/Ubuntu).
    for candidate in \
        /etc/pki/tls/certs/ca-bundle.crt \
        /etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem \
        /etc/ssl/certs/ca-certificates.crt; do
        if [[ -f "$candidate" ]]; then
            ca_bundle="$candidate"
            break
        fi
    done

    if [[ -n "$ca_bundle" ]]; then
        opts+=("--ca-certificate=$ca_bundle" --check-certificate=true)
        print_info "Using CA certificate: $ca_bundle" >&2
    else
        opts+=(--check-certificate=false)
        print_warning "CA certificate not found, disabling certificate verification"
    fi

    printf '%s\n' "${opts[@]}"
}

#######################################
# Export HF_ENDPOINT and create the download directory if needed.
# Globals:  DOWNLOAD_DIR (read)
#######################################
setup_environment() {
    print_info "Setting up environment..."
    export HF_ENDPOINT="https://hf-mirror.com"
    print_success "HF_ENDPOINT set to: $HF_ENDPOINT"

    if [[ -d "$DOWNLOAD_DIR" ]]; then
        print_warning "Download directory already exists: $DOWNLOAD_DIR"
    else
        mkdir -p "$DOWNLOAD_DIR"
        print_success "Created download directory: $DOWNLOAD_DIR"
    fi
}

#######################################
# Run aria2c over the prepared input file. Exits 1 if aria2c fails.
# Globals:  ARIA2_BIN, ARIA2_INPUT, DOWNLOAD_DIR, THREADS (read)
#######################################
download_dataset() {
    local -a aria2_opts

    print_info "Starting download..."
    print_info "Download directory: $DOWNLOAD_DIR"
    print_info "Threads per file: $THREADS"
    print_info "Dataset size: ~400GB (1000 clips)"
    echo ""

    mapfile -t aria2_opts < <(get_aria2_options)

    print_info "Running aria2c..."
    echo ""

    # aria2c is the `if` condition so that a failure reaches the else branch
    # instead of terminating the script through `set -e`.
    if "$ARIA2_BIN" "${aria2_opts[@]}" --input-file="$ARIA2_INPUT"; then
        echo ""
        print_success "Download completed successfully!"
    else
        echo ""
        print_error "Download failed or interrupted."
        print_info "You can resume by running this script again."
        exit 1
    fi
}

#######################################
# Check that every file in FILE_LIST is present in DOWNLOAD_DIR.
# A file that still has an "<name>.aria2" control file next to it is treated
# as unfinished.
# Globals:  FILE_LIST, DOWNLOAD_DIR (read)
# Returns:  0 if all listed files are complete, 1 otherwise
#######################################
verify_download() {
    local filename target
    local expected_count=0
    local downloaded_count=0

    print_info "Verifying downloaded files..."

    while IFS= read -r filename || [[ -n "$filename" ]]; do
        [[ -n "$filename" ]] || continue
        expected_count=$((expected_count + 1))
        target="$DOWNLOAD_DIR/$filename"
        if [[ -f "$target" && ! -e "$target.aria2" ]]; then
            downloaded_count=$((downloaded_count + 1))
        fi
    done < "$FILE_LIST"

    echo ""
    print_info "Downloaded files: $downloaded_count / $expected_count"

    if [[ "$downloaded_count" -eq "$expected_count" ]]; then
        print_success "All files downloaded successfully!"
        return 0
    fi

    if [[ "$downloaded_count" -gt 0 ]]; then
        print_warning "Partial download: $downloaded_count / $expected_count"
        print_info "Run the script again to resume."
    else
        print_error "None of the expected files were found in $DOWNLOAD_DIR"
    fi
    return 1
}

# Remove the private scratch directory (installed as an EXIT trap).
cleanup() {
    if [[ -n "${WORK_DIR:-}" && -d "$WORK_DIR" ]]; then
        rm -rf -- "$WORK_DIR"
    fi
}

# Print usage information.
show_help() {
    cat << EOF
Bench2Drive Base Dataset Download Script

Usage:
    bash $0 [download_directory] [threads]

Arguments:
    download_directory  Directory to save the dataset (default: ./Bench2Drive-Base)
    threads             Number of download threads per file (default: 8, max: 16)

Examples:
    bash $0
    bash $0 ./Bench2Drive-Base 16

Features:
    - Auto-downloads aria2c binary if not installed
    - Downloads file list from mirror (git.hyuyao.cn)
    - Resume capability (断点续传)
    - Multi-threaded download

EOF
}

#######################################
# Entry point: validate arguments, then run every stage in order.
# Arguments: $1 - download directory or -h/--help (optional)
#            $2 - thread count, 1-16 (optional)
#######################################
main() {
    # Kept in a variable so the alternation is not parsed by the shell.
    local threads_re='^([1-9]|1[0-6])$'

    if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
        show_help
        exit 0
    fi

    show_banner

    # Thread count must be an integer in 1..16 (aria2c's per-server limit).
    if ! [[ "$THREADS" =~ $threads_re ]]; then
        print_error "Invalid thread count: $THREADS"
        exit 1
    fi

    trap cleanup EXIT

    # Private scratch directory instead of fixed, predictable /tmp file names.
    if ! WORK_DIR=$(mktemp -d "${TMPDIR:-/tmp}/bench2drive.XXXXXX"); then
        print_error "Failed to create temporary directory"
        exit 1
    fi
    FILE_LIST="$WORK_DIR/bench2drive_files.txt"
    ARIA2_INPUT="$WORK_DIR/bench2drive_aria2_input.txt"

    check_dependencies
    setup_environment
    download_file_list
    create_aria2_input
    download_dataset

    # Do not claim success when the file set is incomplete.
    if ! verify_download; then
        exit 1
    fi

    echo ""
    print_highlight "============================================================================="
    print_success "All done! Dataset saved to: $DOWNLOAD_DIR"
    print_highlight "============================================================================="
}

main "$@"