找回密码
 立即注册
首页 业界区 业界 使用HuggingFace 模型并预测

使用HuggingFace 模型并预测

晌集涟 2025-6-2 23:51:41
下载HuggingFace 模型

首先打开网址:https://huggingface.co/models 这个网址是huggingface/transformers支持的所有模型,目前大约一千多个。搜索gpt2(其他的模型类似,比如bert-base-uncased等),并点击进去。
进入之后,可以看到gpt2模型的说明页,点击页面中的list all files in model,可以看到模型的所有文件。
1.png

通常需要把保存的是三个文件以及一些额外的文件

  • 配置文件 -- config.json
  • 词典文件 -- vocab.json
  • 预训练模型文件
    pytorch -- pytorch_model.bin文件
    tensorflow 2 -- tf_model.h5文件
额外的文件,指的是merges.txtspecial_tokens_map.jsonadded_tokens.jsontokenizer_config.jsonsentencepiece.bpe.model等,这几类是tokenizer需要使用的文件,如果出现的话,也需要保存下来。没有的话,就不必在意。如果不确定哪些需要下,哪些不需要的话,可以把图1中类似的文件全部下载下来。
2.png

看下这几个文件都是什么:

  • config.json配置文件
    3.png

    包含了模型的类型、激活函数等配置信息
  • vocab.json 词典文件
    4.png

  • merges.txt
    5.png

使用HuggingFace模型

将上述下载的模型存储在本地:
6.png

加载本地HuggingFace模型


  • 导入依赖
  1. import torch
  2. from transformers import GPT2Tokenizer, GPT2LMHeadModel
复制代码
导入PyTorch框架和HuggingFace Transformers库的GPT-2组件

  • 初始化分词器
  1. tokenizer = GPT2Tokenizer.from_pretrained("../../Models/gpt2/")
  2. text = "Who was Jim Henson ? Jim Henson was a"
  3. indexed_tokens = tokenizer.encode(text)
  4. print(indexed_tokens) # [8241, 373, 5395, 367, 19069, 5633]
  5. # 转换为torch Tensor
  6. token_tensor = torch.tensor([indexed_tokens])
  7. print(token_tensor) # tensor([[ 8241,   373,  5395,   367, 19069,  5633]])
复制代码
tokenizer.encode(text)执行流程如下:
分词器处理:
首先将文本分词为子词(subwords),如:
"Who was Jim Henson ?" → ['Who', 'Ġwas', 'ĠJim', 'ĠHen', 'son', '?']
ID转换:
然后将每个子词转换为对应的整数ID(来自vcab.json),如:
['Who', 'Ġwas', 'ĠJim', 'ĠHen', 'son', '?'] -> [8241, 373, 5395, 367, 19069, 5633]
可以查看vcab.json文件:
7.png

返回的是 token ID 列表(整数列表),而非词向量

  • 加载预训练模型并预测
  1. # 加载预训练模型
  2. model = GPT2LMHeadModel.from_pretrained("../../Models/gpt2/")
  3. # print(model)
  4. model.eval()
  5. with torch.no_grad():
  6.     outputs = model(token_tensor)
  7.     predictions = outputs[0]
  8. # 我们需要预测下一个单词,所以是使用predictions第一个batch,最后一个词的logits去计算
  9. # predicted_index = 582,通过计算最大得分的索引得到的
  10. predicted_index = torch.argmax(predictions[0, -1, :]).item()
  11. # 反向解码为我们需要的文本
  12. predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
  13. # 解码后的文本:'Who was Jim Henson? Jim Henson was a man'
  14. # 成功预测出单词 'man'
  15. print(predicted_text)
复制代码
输出结果:
8.png


来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册