// Plain text  |  123 lines  |  4.33 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/tflite/encoder_common.h"

#include <algorithm>
#include <vector>

#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

// Builds a TfLiteIntArray holding `values` in order.
// The array comes from TfLiteIntArrayCreate and is returned to the caller
// without being freed here.
TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
  TfLiteIntArray* result = TfLiteIntArrayCreate(values.size());
  int* cursor = result->data;
  for (auto it = values.begin(); it != values.end(); ++it) {
    *cursor++ = *it;
  }
  return result;
}

TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
    const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
    int start_offset, TfLiteContext* context, TfLiteTensor* out) {
  TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
  TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
  const int output_size = out->dims->data[1];
  int output_offset = 0;
  for (int value_index = 0;
       value_index < encoding_end_offsets.size() && output_offset < output_size;
       ++value_index) {
    // Calculate how many elements need to be set with this value.
    // The low bound depends on the offset from the beginning. If this is 0, it
    // means that this value it truncated.
    // The upper bound depends on how many elements are in the output tensor.
    const int from_this_element =
        std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
                                 output_offset),
                 output_size - output_offset);
    if (from_this_element == 0) {
      continue;
    }

    switch (in.type) {
      case kTfLiteInt32: {
        std::fill(out->data.i32 + output_offset,
                  out->data.i32 + output_offset + from_this_element,
                  in.data.i32[value_index]);
      } break;
      case kTfLiteFloat32: {
        std::fill(out->data.f + output_offset,
                  out->data.f + output_offset + from_this_element,
                  in.data.f[value_index]);
      } break;
      default:
        context->ReportError(
            (context), __FILE__ " Not supported attribute type %d", in.type);
        return kTfLiteError;
    }
    output_offset += from_this_element;
  }
  // Do final padding.
  switch (in.type) {
    case kTfLiteInt32: {
      const int32_t value =
          (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
      std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
                value);
    } break;
    case kTfLiteFloat32: {
      const float value =
          (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
      std::fill(out->data.f + output_offset, out->data.f + output_size, value);
    } break;
    default:
      break;
  }
  return kTfLiteOk;
}

// Resizes `tensor` to shape [kEncoderBatchSize, max_output_length] and
// returns the status reported by context->ResizeTensor.
// The shape array is handed to ResizeTensor and is not freed here.
// Per the TFLite C API, ResizeTensor takes ownership of `new_size`.
TfLiteStatus ResizeOutputTensor(const int max_output_length,
                                TfLiteTensor* tensor, TfLiteContext* context) {
  TfLiteIntArray* const new_shape =
      CreateIntArray({kEncoderBatchSize, max_output_length});
  return context->ResizeTensor(context, tensor, new_shape);
}

int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
                                     const std::vector<int32_t>& data,
                                     const int32_t padding_value,
                                     TfLiteTensor* output_tensor) {
  const int num_skip =
      std::max(0, static_cast<int>(data.size()) - max_output_length);
  int output_offset = 0;
  int32_t* output_buffer = output_tensor->data.i32;
  for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
    output_buffer[output_offset] = data[i];
  }

  // Do padding.
  for (; output_offset < max_output_length; ++output_offset) {
    output_buffer[output_offset] = padding_value;
  }

  // Return number of skipped entries from the beginning.
  return num_skip;
}

}  // namespace libtextclassifier3