找回密码
 立即注册
首页 业界区 安全 端侧大模型实践 - 生成预测模型&模型轻量化&端侧部署 ...

端侧大模型实践 - 生成预测模型&模型轻量化&端侧部署

梨恐 2026-2-12 17:40:09
为避免模型训练中出现内存异常,先临时增加交换内存:
  1. # 1. 先关闭旧的交换文件(如果之前创建过1GB的)
  2. swapoff /swapfile || true
  3. rm -rf /swapfile || true
  4. # 2. 创建4GB交换文件(bs=1M表示每个块1MB,count=4096表示4096个块=4GB)
  5. dd if=/dev/zero of=/swapfile bs=1M count=4096
  6. # 3. 设置文件权限(仅root可访问,安全要求)
  7. chmod 600 /swapfile
  8. # 4. 格式化交换文件
  9. mkswap /swapfile
  10. # 5. 启用交换文件
  11. swapon /swapfile
  12. # 6. 验证是否生效(查看Swap列,应该显示4.0Gi)
  13. free -h
复制代码
编写模型训练代码:
  1. import paddle
  2. import os
  3. import random
  4. from paddlenlp.transformers import ErnieTokenizer, ErnieForSequenceClassification
  5. from paddlenlp.data import Stack, Tuple, Pad
  6. from paddle.io import DataLoader, BatchSampler, Dataset
  7. # 核心配置:强制CPU + 最小化内存占用
  8. paddle.set_device("cpu")
  9. paddle.disable_static()
  10. paddle.set_default_dtype("float32")  # 降低精度减少内存
  11. # 1. 极简数据集类(降低内存波动)
  12. class SimpleDataset(Dataset):
  13.     def __init__(self, data):
  14.         self.data = data
  15.    
  16.     def __getitem__(self, idx):
  17.         return self.data[idx]
  18.    
  19.     def __len__(self):
  20.         return len(self.data)
  21. # 仅保留3条核心数据,避免内存占用
  22. raw_data = [
  23.     ("我的订单怎么还没发货", 0),
  24.     ("申请退款多久到账", 1),
  25.     ("产品保质期多久", 2)
  26. ]
  27. train_dataset = SimpleDataset(raw_data)
  28. print(f"加载极简训练数据,共 {len(train_dataset)} 条")
  29. # 2. 初始化模型和分词器(忽略权重警告)
  30. tokenizer = ErnieTokenizer.from_pretrained("ernie-3.0-mini-zh")
  31. model = ErnieForSequenceClassification.from_pretrained(
  32.     "ernie-3.0-mini-zh",
  33.     num_classes=3,
  34.     ignore_mismatched_sizes=True  # 关闭权重不匹配警告
  35. )
  36. # 3. 数据预处理(最短文本长度,减少张量大小)
  37. def convert_example(example):
  38.     text, label = example
  39.     inputs = tokenizer(
  40.         text,
  41.         max_len=16,  # 最短文本长度
  42.         padding="max_length",
  43.         truncation=True,
  44.         return_length=False
  45.     )
  46.     return inputs["input_ids"], inputs["token_type_ids"], label
  47. # 提前预处理所有数据,避免加载器中重复计算
  48. processed_data = [convert_example(example) for example in train_dataset]
  49. train_dataset = SimpleDataset(processed_data)
  50. # 4. 数据加载器(批量大小=1,关闭多线程)
  51. batchify_fn = lambda samples, fn=Tuple(
  52.     Pad(axis=0, pad_val=tokenizer.pad_token_id),
  53.     Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
  54.     Stack(dtype="int64")
  55. ): fn(samples)
  56. # 精细化批次控制,最低内存占用
  57. sampler = BatchSampler(train_dataset, batch_size=1, shuffle=True)
  58. train_loader = DataLoader(
  59.     dataset=train_dataset,
  60.     batch_sampler=sampler,
  61.     collate_fn=batchify_fn,
  62.     num_workers=0  # 关闭多线程,避免内存泄漏
  63. )
  64. # 5. 训练配置(极简优化器,降低内存)
  65. model.train()
  66. # SGD优化器内存占用远低于Adam
  67. optimizer = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())
  68. loss_fn = paddle.nn.CrossEntropyLoss()
  69. # 6. 训练循环(修复no_grad错误,极简逻辑)
  70. epochs = 1
  71. total_loss = 0.0
  72. batch_count = 0
  73. print(f"开始训练 Epoch 1/{epochs}")
  74. for batch in train_loader:
  75.     input_ids, token_type_ids, labels = batch
  76.    
  77.     # 核心修复:删除错误的paddle.no_grad(False),训练需要梯度
  78.     logits = model(input_ids, token_type_ids)
  79.     loss = loss_fn(logits, labels)
  80.    
  81.     # 反向传播 + 优化
  82.     loss.backward()
  83.     optimizer.step()
  84.     optimizer.clear_grad()
  85.    
  86.     # 记录损失
  87.     total_loss += loss.numpy()[0]
  88.     batch_count += 1
  89.     print(f"Batch {batch_count} 训练完成,损失:{loss.numpy()[0]:.4f}")
  90. # 7. 保存模型(仅保存必要文件)
  91. model_dir = "./ernie_demo_model_light"
  92. os.makedirs(model_dir, exist_ok=True)
  93. model.save_pretrained(model_dir, save_config=False)
  94. tokenizer.save_pretrained(model_dir)
  95. # 最终输出
  96. avg_loss = total_loss / batch_count if batch_count > 0 else 0
  97. print("\n==== 训练完全完成 ====")
  98. print(f"模型保存路径:{os.path.abspath(model_dir)}")
  99. print(f"总训练批次:{batch_count},平均损失:{avg_loss:.4f}")
  100. print("提示:权重警告是正常现象,模型已成功训练并保存!")
复制代码
训练记录如下:
1.png

验证模型训练结果,创建 validate_model.py 文件,命令:
  1. nano /root/validate_model.py
复制代码
代码:
  1. import paddle
  2. from paddlenlp.transformers import ErnieTokenizer
  3. # 核心配置:和训练时保持一致
  4. paddle.set_device("cpu")
  5. paddle.disable_static()
  6. # 1. 加载训练好的模型和分词器
  7. # 模型路径:和训练时保存的路径一致(ernie_demo_model_light)
  8. model_path = "./ernie_demo_model_light"
  9. # 加载分词器
  10. tokenizer = ErnieTokenizer.from_pretrained(model_path)
  11. # 加载模型(和训练时的模型结构一致)
  12. from paddlenlp.transformers import ErnieForSequenceClassification
  13. model = ErnieForSequenceClassification.from_pretrained(
  14.     model_path,
  15.     num_classes=3,
  16.     ignore_mismatched_sizes=True
  17. )
  18. # 切换到评估模式(禁用Dropout等训练层)
  19. model.eval()
  20. # 2. 定义标签映射(数字→中文,方便查看)
  21. label_map = {0: "物流咨询", 1: "退款咨询", 2: "产品咨询"}
  22. # 3. 定义验证函数(输入文本,输出分类结果)
  23. def predict(text):
  24.     # 预处理文本(和训练时的逻辑完全一致)
  25.     inputs = tokenizer(
  26.         text,
  27.         max_len=16,
  28.         padding="max_length",
  29.         truncation=True,
  30.         return_length=False
  31.     )
  32.     # 转换为Paddle张量
  33.     input_ids = paddle.to_tensor([inputs["input_ids"]], dtype="int64")
  34.     token_type_ids = paddle.to_tensor([inputs["token_type_ids"]], dtype="int64")
  35.    
  36.     # 模型推理(禁用梯度计算,节省内存)
  37.     with paddle.no_grad():
  38.         logits = model(input_ids, token_type_ids)
  39.         # 获取概率最大的标签
  40.         pred_label = paddle.argmax(logits, axis=1).numpy()[0]
  41.    
  42.     # 返回直观结果
  43.     return label_map[pred_label]
  44. # 4. 验证新文本(选3条未参与训练的文本)
  45. test_texts = [
  46.     "快递到哪了?",       # 预期:物流咨询
  47.     "退款多久能到账?",   # 预期:退款咨询
  48.     "产品怎么使用?"      # 预期:产品咨询
  49. ]
  50. # 5. 执行验证并打印结果
  51. print("==== 模型验证结果 ====")
  52. for text in test_texts:
  53.     result = predict(text)
  54.     print(f"输入文本:{text}")
  55.     print(f"模型分类结果:{result}\n"):
复制代码
验证有问题,但至少跑通了(不过虽然结果不对,这里也先不做追究,因为是要跑通流程,针对细节暂时不关注)
2.png

因为默认训练后的模型是动态图模型,需要输出静态图模型和移动端模型:
  1. import paddle
  2. import os
  3. from paddlelite.lite import *
  4. from paddlenlp.transformers import ErnieForSequenceClassification
  5. # ====================== 配置项(关键修正:匹配实际分类数)=====================
  6. # 动态图模型路径(你的ernie_demo_model_light)
  7. DYNAMIC_MODEL_PATH = "./ernie_demo_model_light"
  8. # 静态图模型保存路径
  9. STATIC_MODEL_PATH = "./ernie_demo_static/model"
  10. # 移动端模型输出路径
  11. MOBILE_MODEL_PATH = "./ernie_mobile_model"
  12. # 关键修正:改为实际的分类数(从报错看是3分类)
  13. NUM_CLASSES = 3  # 原代码是2,现在改成3,匹配模型参数
  14. # ====================================================
  15. def dynamic2static():
  16.     """第一步:动态图转静态图"""
  17.     # 1. 加载训练好的动态图模型
  18.     # 关键:ignore_mismatched_sizes=True 兼容可能的维度问题(兜底)
  19.     model = ErnieForSequenceClassification.from_pretrained(
  20.         DYNAMIC_MODEL_PATH,
  21.         num_classes=NUM_CLASSES,
  22.         ignore_mismatched_sizes=True  # 新增:忽略参数维度不匹配(防止漏改分类数)
  23.     )
  24.     model.eval()  # 必须切换到评估模式
  25.    
  26.     # 2. 定义输入规格(和模型输入匹配)
  27.     input_spec = [
  28.         paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
  29.         paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids")
  30.     ]
  31.    
  32.     # 3. 导出静态图模型(生成.pdmodel和.pdiparams)
  33.     paddle.jit.save(
  34.         layer=model,
  35.         path=STATIC_MODEL_PATH,
  36.         input_spec=input_spec
  37.     )
  38.     print(f"✅ 静态图模型已生成:")
  39.     print(f"   - {STATIC_MODEL_PATH}.pdmodel")
  40.     print(f"   - {STATIC_MODEL_PATH}.pdiparams")
  41. def static2mobile():
  42.     """第二步:静态图转移动端模型"""
  43.     # 1. 检查静态图文件是否存在
  44.     model_file = f"{STATIC_MODEL_PATH}.pdmodel"
  45.     param_file = f"{STATIC_MODEL_PATH}.pdiparams"
  46.     if not os.path.exists(model_file) or not os.path.exists(param_file):
  47.         raise FileNotFoundError("静态图模型文件不存在,请先运行dynamic2static()")
  48.    
  49.     # 2. 初始化Paddle Lite优化器
  50.     opt = Opt()
  51.     opt.set_model_file(model_file)
  52.     opt.set_param_file(param_file)
  53.    
  54.     # 3. 设置移动端适配参数(ARM架构,手机通用)
  55.     opt.set_valid_places("arm")
  56.     opt.set_model_type("naive_buffer")  # 移动端轻量化格式
  57.    
  58.     # 4. 设置输出路径并执行优化
  59.     opt.set_optimize_out(MOBILE_MODEL_PATH)
  60.     opt.run()
  61.    
  62.     print(f"\n✅ 移动端模型已生成:{MOBILE_MODEL_PATH}.nb")
  63.     print("   该文件可直接用于Android/iOS端部署")
  64. # 主执行逻辑
  65. if __name__ == "__main__":
  66.     # 动态图→静态图→移动端模型
  67.     dynamic2static()
  68.     static2mobile()
复制代码
3.png

使用paddle.lite执行端上模型预测:
  1. package com.baidu.paddle.lite;
  2. import android.content.Context;
  3. import android.util.Log;
  4. import com.baidu.paddle.lite.MobileConfig;
  5. import com.baidu.paddle.lite.PaddlePredictor;
  6. import com.baidu.paddle.lite.Tensor;
  7. import java.io.File;
  8. import java.io.FileOutputStream;
  9. import java.io.InputStream;
  10. import java.nio.FloatBuffer;
  11. /**
  12. * Paddle Lite 模型管理类
  13. * 负责模型加载、预测、资源释放
  14. */
  15. public class PaddleLiteManager {
  16.     private static final String TAG = "PaddleLiteManager";
  17.     // 模型文件名(请确保与assets中的文件完全一致,包括.nb后缀)
  18.     private static final String MODEL_FILE_NAME = "ernie_mobile_model.nb";
  19.     // 输入张量名称(根据你的模型实际输入名修改,可通过Paddle Lite工具查看)
  20.     private static final String INPUT_TENSOR_NAME = "input";
  21.     // 输出张量名称(根据你的模型实际输出名修改)
  22.     private static final String OUTPUT_TENSOR_NAME = "output";
  23.     private Context mContext;
  24.     private PaddlePredictor mPredictor; // 预测器实例
  25.     private boolean isInitSuccess = false; // 初始化状态
  26.     public PaddleLiteManager(Context context) {
  27.         this.mContext = context.getApplicationContext();
  28.     }
  29.     /**
  30.      * 初始化预测器(核心:解决模型文件找不到的问题)
  31.      */
  32.     public boolean initPredictor() {
  33.         try {
  34.             // 1. 将assets中的模型文件复制到应用内部存储(避免assets路径访问限制)
  35.             File modelFile = copyAssetFileToInternalStorage(MODEL_FILE_NAME);
  36.             if (modelFile == null || !modelFile.exists()) {
  37.                 Log.e(TAG, "模型文件复制失败或不存在:" + (modelFile != null ? modelFile.getAbsolutePath() : "null"));
  38.                 return false;
  39.             }
  40.             Log.d(TAG, "模型文件路径:" + modelFile.getAbsolutePath());
  41.             // 2. 配置MobileConfig
  42.             MobileConfig config = new MobileConfig();
  43.             config.setModelFromFile(modelFile.getAbsolutePath()); // 使用绝对路径加载
  44.             config.setThreads(4); // 设置线程数(根据设备调整)
  45. //            config.setPowerMode(MobileConfig.PowerMode.LITE_POWER_HIGH); // 高性能模式
  46.             // 可选:设置精度模式(根据需求选择)
  47.             // config.setPrecisionMode(MobileConfig.PrecisionMode.LITE_PRECISION_FP32);
  48.             // 3. 创建预测器
  49.             mPredictor = PaddlePredictor.createPaddlePredictor(config);
  50.             if (mPredictor == null) {
  51.                 Log.e(TAG, "预测器创建失败");
  52.                 return false;
  53.             }
  54.             isInitSuccess = true;
  55.             Log.d(TAG, "模型初始化成功!");
  56.             return true;
  57.         } catch (Exception e) {
  58.             Log.e(TAG, "初始化预测器异常:", e);
  59.             isInitSuccess = false;
  60.             return false;
  61.         }
  62.     }
  63.     /**
  64.      * 执行预测(示例:输入float数组,返回输出结果)
  65.      * @param inputData 输入数据(需与模型输入形状匹配,示例:[1, 128]的float数组)
  66.      * @param inputShape 输入形状(示例:new long[]{1, 128})
  67.      * @return 输出结果float数组,失败返回null
  68.      */
  69.     public float[] runPredict(float[] inputData, long[] inputShape) {
  70.         if (!isInitSuccess || mPredictor == null) {
  71.             Log.e(TAG, "预测器未初始化成功,无法执行预测");
  72.             return null;
  73.         }
  74.         try {
  75.             // 1. 获取输入张量
  76.             Tensor inputTensor = mPredictor.getInput(0);
  77.             if (inputTensor == null) {
  78.                 Log.e(TAG, "获取输入张量失败,张量名:" + INPUT_TENSOR_NAME);
  79.                 return null;
  80.             }
  81.             // 2. 设置输入形状和数据
  82.             inputTensor.resize(inputShape);
  83.             inputTensor.setData(inputData);
  84.             // 3. 执行预测
  85.             long startTime = System.currentTimeMillis();
  86.             boolean predictResult = mPredictor.run();
  87.             long endTime = System.currentTimeMillis();
  88.             Log.d(TAG, "预测耗时:" + (endTime - startTime) + "ms");
  89.             if (!predictResult) {
  90.                 Log.e(TAG, "预测执行失败");
  91.                 return null;
  92.             }
  93.             // 4. 获取输出张量
  94.             Tensor outputTensor = mPredictor.getOutput(0);
  95.             if (outputTensor == null) {
  96.                 Log.e(TAG, "获取输出张量失败,张量名:" + OUTPUT_TENSOR_NAME);
  97.                 return null;
  98.             }
  99.             // 5. 读取输出数据
  100. //            float[] outputBuffer = outputTensor.getByteData();
  101. //            float[] outputData = new float[outputBuffer.remaining()];
  102. //            outputBuffer.get(outputData);
  103.             Log.d(TAG, "预测成功,输出数据长度:" + outputTensor.getFloatData().length);
  104.             return outputTensor.getFloatData();
  105.         } catch (Exception e) {
  106.             Log.e(TAG, "预测过程异常:", e);
  107.             return null;
  108.         }
  109.     }
  110.     /**
  111.      * 释放资源
  112.      */
  113.     public void release() {
  114.         if (mPredictor != null) {
  115. //            mPredictor.close();
  116.             mPredictor = null;
  117.         }
  118.         isInitSuccess = false;
  119.         Log.d(TAG, "预测器资源已释放");
  120.     }
  121.     /**
  122.      * 将assets中的文件复制到应用内部存储
  123.      * @param fileName assets中的文件名
  124.      * @return 复制后的文件,失败返回null
  125.      */
  126.     private File copyAssetFileToInternalStorage(String fileName) {
  127.         InputStream inputStream = null;
  128.         FileOutputStream outputStream = null;
  129.         try {
  130.             // 目标文件路径:/data/data/com.baidu.paddle.lite/files/xxx.nb
  131.             File destFile = new File(mContext.getFilesDir(), fileName);
  132.             // 如果文件已存在,直接返回
  133.             if (destFile.exists()) {
  134.                 Log.d(TAG, "模型文件已存在,无需重复复制:" + destFile.getAbsolutePath());
  135.                 return destFile;
  136.             }
  137.             // 从assets读取文件
  138.             inputStream = mContext.getAssets().open(fileName);
  139.             outputStream = new FileOutputStream(destFile);
  140.             // 复制文件
  141.             byte[] buffer = new byte[1024 * 4]; // 4KB缓冲区
  142.             int bytesRead;
  143.             while ((bytesRead = inputStream.read(buffer)) != -1) {
  144.                 outputStream.write(buffer, 0, bytesRead);
  145.             }
  146.             outputStream.flush();
  147.             Log.d(TAG, "模型文件复制成功:" + destFile.getAbsolutePath());
  148.             return destFile;
  149.         } catch (Exception e) {
  150.             Log.e(TAG, "复制assets文件失败:", e);
  151.             return null;
  152.         } finally {
  153.             // 关闭流
  154.             try {
  155.                 if (inputStream != null) inputStream.close();
  156.                 if (outputStream != null) outputStream.close();
  157.             } catch (Exception e) {
  158.                 Log.e(TAG, "关闭流异常:", e);
  159.             }
  160.         }
  161.     }
  162.     /**
  163.      * 获取初始化状态
  164.      */
  165.     public boolean isInitSuccess() {
  166.         return isInitSuccess;
  167.     }
  168. }
复制代码
至此,我们就跑完了从端模型的初步探索:跑通了 Hadoop → Spark → 大模型轻量化 → 端侧部署 全流程Demo

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册