深度学习模型大多用Python开发,而服务端却多用Java,导致许多开发者不得不使用Java调用Python接口,效率低下且不够优雅。更糟糕的是,如果想在Android上进行推理,就必须使用Java。
别担心!现在,我们可以用Java直接进行深度学习了!DJL(Deep Java Library)是一个强大的开源深度学习框架,它支持模型构建、训练、推理,甚至在Android上运行。本文将带你深入了解DJL,并通过一个实战案例,教你用Java加载PyTorch模型进行图片分类。
DJL:Java深度学习的利器
DJL 的出现,为Java开发者打开了深度学习的大门。它提供了一套简洁易用的API,让Java开发者能够轻松地构建、训练和部署深度学习模型。
DJL 的优势:
- Java 开发: 使用熟悉的 Java 语言进行深度学习开发,无需学习其他语言。
- 跨平台支持: 支持 Windows、Linux、macOS 和 Android 等多种平台。
- GPU 加速: 支持 GPU 加速,提升模型训练和推理速度。
- 模型兼容性: 支持多种深度学习框架,包括 PyTorch、TensorFlow 和 MXNet。
DJL 核心 API 解密
DJL 的核心 API 包括 Criteria、Translator 和 NDArray,它们共同构成了深度学习模型的构建和操作基础。
1. Criteria:模型的定义
Criteria 类对象定义了模型的属性,例如模型路径、输入和输出类型等。
Criteria<Input, Output> criteria = Criteria.builder()
.setTypes(Input.class, Output.class) // 定义输入和输出数据类型
.optTranslator(new InputOutputTranslator()) // 设置输入输出转换器
.optModelPath(Paths.get("/var/models/my_resnet50")) // 指定模型路径
.optModelName("model/resnet50") // 指定模型文件前缀
.build();
ZooModel<Image, Classifications> model = criteria.loadModel();
这段代码定义了一个名为 "resnet50" 的模型,并加载了它。
2. Translator:数据转换桥梁
Translator 接口定义了如何将自定义的输入输出类转换为 Tensor 类型。
private Translator<Input, Output> translator = new Translator<Input, Output>() {
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
return null;
}
@Override
public Output processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
return null;
}
};
Translator 接口包含两个方法:
- processInput: 将输入类对象转换为 Tensor。
- processOutput: 将模型输出的 Tensor 转换为自定义类。
3. NDArray:Tensor 操作的利器
NDArray 类类似于 Python 中的 NumPy 数组,它提供了丰富的 Tensor 操作功能。
NDManager ndManager = NDManager.newBaseManager(); // 创建 NDManager 对象
NDArray ndArray = ndManager.create(new Shape(1, 2, 3, 4)); // 创建一个 Shape 为 (1, 2, 3, 4) 的 Tensor
DJL 提供了多种 NDArray 操作,例如:
- 创建 NDArray
- 变更数据类型
- 运算(加减乘除)
- 切片
- 赋值
- 翻转
实战:用 DJL 加载 PyTorch 模型进行图片分类
下面,我们将使用 PyTorch 提供的 ResNet18 模型进行图片分类。
步骤:
- 引入依赖: 在项目的 pom.xml 文件中添加 DJL 的依赖。
- 导出 PyTorch 模型: 使用 Python 将 ResNet18 模型保存为 TorchScript 模型。
- 创建 Translator: 定义输入为图片路径,输出为类别。
- 定义 Criteria: 定义模型路径、输入输出类型和 Translator。
- 实例化模型: 使用 Criteria 加载模型。
- 创建 Predictor: 使用模型创建 Predictor 对象。
- 进行预测: 使用 Predictor 对图片进行分类。
代码示例:
// ... (引入依赖)
// 创建 Translator
Translator<String, String> translator = new Translator<String, String>() {
@Override
public NDList processInput(TranslatorContext ctx, String input) throws Exception {
// ... (读取图片,进行预处理)
return new NDList(ndArray);
}
@Override
public String processOutput(TranslatorContext ctx, NDList list) throws Exception {
// ... (获取预测结果)
return index + "";
}
};
// 定义 Criteria
Criteria<String, String> criteria = Criteria.builder()
.setTypes(String.class, String.class)
.optModelPath(Paths.get("model/traced_resnet_model.pt"))
.optOption("mapLocation", "true")
.optTranslator(translator)
.build();
// 实例化模型
ZooModel model = criteria.loadModel();
// 创建 Predictor
Predictor predictor = model.newPredictor();
// 进行预测
System.out.println(predictor.predict("test/test.jpg"));
最终输出:
258
258 对应的类别为 Samoyed(萨摩耶),说明预测成功。
总结
DJL 为 Java 开发者提供了强大的深度学习能力,让我们能够使用 Java 语言进行模型构建、训练和推理。本文通过一个简单的图片分类案例,展示了如何使用 DJL 加载 PyTorch 模型进行预测。
参考文献:
希望本文能够帮助你快速入门 DJL,并开始你的 Java 深度学习之旅!