Commit 7560db93 authored by Wenbin Chen's avatar Wenbin Chen Committed by Guo Yejun

libavfi/dnn: enable LibTorch xpu device option support

Add xpu device support to libtorch backend.
To enable xpu support you need to add
 "-Wl,--no-as-needed -lintel-ext-pt-gpu -Wl,--as-needed" to
"--extra-libs" when configure ffmpeg.
Signed-off-by: 's avatarWenbin Chen <wenbin.chen@intel.com>
parent f68f4073
......@@ -250,6 +250,10 @@ static int th_start_inference(void *args)
av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
return DNN_GENERIC_ERROR;
}
// Transfer tensor to the same device as model
c10::Device device = (*th_model->jit_model->parameters().begin()).device();
if (infer_request->input_tensor->device() != device)
*infer_request->input_tensor = infer_request->input_tensor->to(device);
inputs.push_back(*infer_request->input_tensor);
*infer_request->output = th_model->jit_model->forward(inputs).toTensor();
......@@ -285,6 +289,9 @@ static void infer_completion_callback(void *args) {
switch (th_model->model.func_type) {
case DFT_PROCESS_FRAME:
if (task->do_ioproc) {
// Post process can only deal with CPU memory.
if (output->device() != torch::kCPU)
*output = output->to(torch::kCPU);
outputs.scale = 255;
outputs.data = output->data_ptr();
if (th_model->model.frame_post_proc != NULL) {
......@@ -424,7 +431,13 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
th_model->ctx = ctx;
c10::Device device = c10::Device(device_name);
if (!device.is_cpu()) {
if (device.is_xpu()) {
if (!at::hasXPU()) {
av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
goto fail;
}
at::detail::getXPUHooks().initXPU();
} else if (!device.is_cpu()) {
av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name);
goto fail;
}
......@@ -432,6 +445,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
try {
th_model->jit_model = new torch::jit::Module;
(*th_model->jit_model) = torch::jit::load(ctx->model_filename);
th_model->jit_model->to(device);
} catch (const c10::Error& e) {
av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
goto fail;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment