break;
case DNN_UINT8:
dt = TF_UINT8;
- size = sizeof(char);
+ size = 1;
break;
default:
av_assert0(!"should not reach here");
input_dims[1] * input_dims[2] * input_dims[3] * size);
}
+static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
+{
+ TFModel *tf_model = (TFModel *)model;
+ TF_Status *status;
+ int64_t dims[4];
+
+ TF_Output tf_output;
+ tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
+ if (!tf_output.oper)
+ return DNN_ERROR;
+
+ tf_output.index = 0;
+ input->dt = TF_OperationOutputType(tf_output);
+
+ status = TF_NewStatus();
+ TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
+ if (TF_GetCode(status) != TF_OK){
+ TF_DeleteStatus(status);
+ return DNN_ERROR;
+ }
+ TF_DeleteStatus(status);
+
+ // currently only NHWC is supported
+ av_assert0(dims[0] == 1);
+ input->height = dims[1];
+ input->width = dims[2];
+ input->channels = dims[3];
+
+ return DNN_SUCCESS;
+}
+
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{
TFModel *tf_model = (TFModel *)model;
model->model = (void *)tf_model;
model->set_input_output = &set_input_output_tf;
+ model->get_input = &get_input_tf;
return model;
}