diff options
Diffstat (limited to 'core/java/android')
3 files changed, 156 insertions, 123 deletions
diff --git a/core/java/android/view/textclassifier/SmartSelection.java b/core/java/android/view/textclassifier/SmartSelection.java index 8edf97ea0336..69c38ee4db4f 100644 --- a/core/java/android/view/textclassifier/SmartSelection.java +++ b/core/java/android/view/textclassifier/SmartSelection.java @@ -108,9 +108,9 @@ final class SmartSelection { } /** - * Returns the language of the model. + * Returns a comma separated list of locales supported by the model as BCP 47 tags. */ - public static String getLanguage(int fd) { + public static String getLanguages(int fd) { return nativeGetLanguage(fd); } diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index af76a7fa4c7d..fc034937312c 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -58,6 +58,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.StringJoiner; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -101,11 +102,9 @@ public final class TextClassifierImpl implements TextClassifier { private final Object mLock = new Object(); @GuardedBy("mLock") // Do not access outside this lock. - private Map<Locale, String> mModelFilePaths; + private List<ModelFile> mAllModelFiles; @GuardedBy("mLock") // Do not access outside this lock. - private Locale mLocale; - @GuardedBy("mLock") // Do not access outside this lock. - private int mVersion; + private ModelFile mModel; @GuardedBy("mLock") // Do not access outside this lock. private SmartSelection mSmartSelection; @@ -281,18 +280,18 @@ public final class TextClassifierImpl implements TextClassifier { private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException { synchronized (mLock) { localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; - final Locale locale = findBestSupportedLocaleLocked(localeList); - if (locale == null) { - throw new FileNotFoundException("No file for null locale"); + final ModelFile bestModel = findBestModelLocked(localeList); + if (bestModel == null) { + throw new FileNotFoundException("No model for " + localeList.toLanguageTags()); } - if (mSmartSelection == null || !Objects.equals(mLocale, locale)) { + if (mSmartSelection == null || !Objects.equals(mModel, bestModel)) { + Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); destroySmartSelectionIfExistsLocked(); - final ParcelFileDescriptor fd = getFdLocked(locale); - final int modelFd = fd.getFd(); - mVersion = SmartSelection.getVersion(modelFd); - mSmartSelection = new SmartSelection(modelFd); + final ParcelFileDescriptor fd = ParcelFileDescriptor.open( + new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); + mSmartSelection = new SmartSelection(fd.getFd()); closeAndLogError(fd); - mLocale = locale; + mModel = bestModel; } return mSmartSelection; } @@ -300,74 +299,8 @@ public final class TextClassifierImpl implements TextClassifier { private String getSignature(String text, int start, int end) { synchronized (mLock) { - return DefaultLogger.createSignature(text, start, end, mContext, mVersion, mLocale); - } - } - - @GuardedBy("mLock") // Do not call outside this lock. - private ParcelFileDescriptor getFdLocked(Locale locale) throws FileNotFoundException { - ParcelFileDescriptor updateFd; - int updateVersion = -1; - try { - updateFd = ParcelFileDescriptor.open( - new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY); - if (updateFd != null) { - updateVersion = SmartSelection.getVersion(updateFd.getFd()); - } - } catch (FileNotFoundException e) { - updateFd = null; - } - ParcelFileDescriptor factoryFd; - int factoryVersion = -1; - try { - final String factoryModelFilePath = getFactoryModelFilePathsLocked().get(locale); - if (factoryModelFilePath != null) { - factoryFd = ParcelFileDescriptor.open( - new File(factoryModelFilePath), ParcelFileDescriptor.MODE_READ_ONLY); - if (factoryFd != null) { - factoryVersion = SmartSelection.getVersion(factoryFd.getFd()); - } - } else { - factoryFd = null; - } - } catch (FileNotFoundException e) { - factoryFd = null; - } - - if (updateFd == null) { - if (factoryFd != null) { - return factoryFd; - } else { - throw new FileNotFoundException( - String.format(Locale.US, "No model file found for %s", locale)); - } - } - - final int updateFdInt = updateFd.getFd(); - final boolean localeMatches = Objects.equals( - locale.getLanguage().trim().toLowerCase(), - SmartSelection.getLanguage(updateFdInt).trim().toLowerCase()); - if (factoryFd == null) { - if (localeMatches) { - return updateFd; - } else { - closeAndLogError(updateFd); - throw new FileNotFoundException( - String.format(Locale.US, "No model file found for %s", locale)); - } - } - - if (!localeMatches) { - closeAndLogError(updateFd); - return factoryFd; - } - - if (updateVersion > factoryVersion) { - closeAndLogError(factoryFd); - return updateFd; - } else { - closeAndLogError(updateFd); - return factoryFd; + return DefaultLogger.createSignature(text, start, end, mContext, mModel.getVersion(), + mModel.getSupportedLocales()); } } @@ -379,60 +312,66 @@ public final class TextClassifierImpl implements TextClassifier { } } + /** + * Finds the most appropriate model to use for the given target locale list. + * + * The basic logic is: we ignore all models that don't support any of the target locales. For + * the remaining candidates, we take the update model unless its version number is lower than + * the factory version. It's assumed that factory models do not have overlapping locale ranges + * and conflict resolution between these models hence doesn't matter. + */ @GuardedBy("mLock") // Do not call outside this lock. @Nullable - private Locale findBestSupportedLocaleLocked(LocaleList localeList) { + private ModelFile findBestModelLocked(LocaleList localeList) { // Specified localeList takes priority over the system default, so it is listed first. final String languages = localeList.isEmpty() ? LocaleList.getDefault().toLanguageTags() : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags(); final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); - final List<Locale> supportedLocales = - new ArrayList<>(getFactoryModelFilePathsLocked().keySet()); - final Locale updatedModelLocale = getUpdatedModelLocale(); - if (updatedModelLocale != null) { - supportedLocales.add(updatedModelLocale); + ModelFile bestModel = null; + int bestModelVersion = -1; + for (ModelFile model : listAllModelsLocked()) { + if (model.isAnyLanguageSupported(languageRangeList)) { + if (model.getVersion() >= bestModelVersion) { + bestModel = model; + bestModelVersion = model.getVersion(); + } + } } - return Locale.lookup(languageRangeList, supportedLocales); + return bestModel; } + /** Returns a list of all model files available, in order of precedence. */ @GuardedBy("mLock") // Do not call outside this lock. - private Map<Locale, String> getFactoryModelFilePathsLocked() { - if (mModelFilePaths == null) { - final Map<Locale, String> modelFilePaths = new HashMap<>(); + private List<ModelFile> listAllModelsLocked() { + if (mAllModelFiles == null) { + final List<ModelFile> allModels = new ArrayList<>(); + // The update model has the highest precedence. + if (new File(UPDATED_MODEL_FILE_PATH).exists()) { + final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH); + if (updatedModel != null) { + allModels.add(updatedModel); + } + } + // Factory models should never have overlapping locales, so the order doesn't matter. final File modelsDir = new File(MODEL_DIR); if (modelsDir.exists() && modelsDir.isDirectory()) { - final File[] models = modelsDir.listFiles(); + final File[] modelFiles = modelsDir.listFiles(); final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX); - final int size = models.length; - for (int i = 0; i < size; i++) { - final File modelFile = models[i]; + for (File modelFile : modelFiles) { final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName()); if (matcher.matches() && modelFile.isFile()) { - final String language = matcher.group(1); - final Locale locale = Locale.forLanguageTag(language); - modelFilePaths.put(locale, modelFile.getAbsolutePath()); + final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath()); + if (model != null) { + allModels.add(model); + } } } } - mModelFilePaths = modelFilePaths; - } - return mModelFilePaths; - } - - @Nullable - private Locale getUpdatedModelLocale() { - try { - final ParcelFileDescriptor updateFd = ParcelFileDescriptor.open( - new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY); - final Locale locale = Locale.forLanguageTag( - SmartSelection.getLanguage(updateFd.getFd())); - closeAndLogError(updateFd); - return locale; - } catch (FileNotFoundException e) { - return null; + mAllModelFiles = allModels; } + return mAllModelFiles; } private TextClassification createClassificationResult( @@ -522,6 +461,95 @@ public final class TextClassifierImpl implements TextClassifier { } /** + * Describes TextClassifier model files on disk. + */ + private static final class ModelFile { + + private final String mPath; + private final String mName; + private final int mVersion; + private final List<Locale> mSupportedLocales; + + /** Returns null if the path did not point to a compatible model. */ + static @Nullable ModelFile fromPath(String path) { + final File file = new File(path); + try { + final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open( + file, ParcelFileDescriptor.MODE_READ_ONLY); + final int version = SmartSelection.getVersion(modelFd.getFd()); + final String supportedLocalesStr = SmartSelection.getLanguages(modelFd.getFd()); + if (supportedLocalesStr.isEmpty()) { + Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath()); + return null; + } + final List<Locale> supportedLocales = new ArrayList<>(); + for (String langTag : supportedLocalesStr.split(",")) { + supportedLocales.add(Locale.forLanguageTag(langTag)); + } + closeAndLogError(modelFd); + return new ModelFile(path, file.getName(), version, supportedLocales); + } catch (FileNotFoundException e) { + Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e); + return null; + } + } + + /** The absolute path to the model file. */ + String getPath() { + return mPath; + } + + /** A name to use for signature generation. Effectively the name of the model file. */ + String getName() { + return mName; + } + + /** Returns the version tag in the model's metadata. */ + int getVersion() { + return mVersion; + } + + /** Returns whether the language supports any language in the given ranges. */ + boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) { + return Locale.lookup(languageRanges, mSupportedLocales) != null; + } + + /** All locales supported by the model. */ + List<Locale> getSupportedLocales() { + return Collections.unmodifiableList(mSupportedLocales); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) { + return false; + } else { + final ModelFile otherModel = (ModelFile) other; + return mPath.equals(otherModel.mPath); + } + } + + @Override + public String toString() { + final StringJoiner localesJoiner = new StringJoiner(","); + for (Locale locale : mSupportedLocales) { + localesJoiner.add(locale.toLanguageTag()); + } + return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }", + mPath, mName, mVersion, localesJoiner.toString()); + } + + private ModelFile(String path, String name, int version, List<Locale> supportedLocales) { + mPath = path; + mName = name; + mVersion = version; + mSupportedLocales = supportedLocales; + } + } + + /** * Creates intents based on the classification type. */ static final class IntentFactory { diff --git a/core/java/android/view/textclassifier/logging/DefaultLogger.java b/core/java/android/view/textclassifier/logging/DefaultLogger.java index 6b848351cbf6..03a6d3a7f10f 100644 --- a/core/java/android/view/textclassifier/logging/DefaultLogger.java +++ b/core/java/android/view/textclassifier/logging/DefaultLogger.java @@ -17,7 +17,6 @@ package android.view.textclassifier.logging; import android.annotation.NonNull; -import android.annotation.Nullable; import android.content.Context; import android.metrics.LogMaker; import android.util.Log; @@ -27,8 +26,10 @@ import com.android.internal.logging.MetricsLogger; import com.android.internal.logging.nano.MetricsProto.MetricsEvent; import com.android.internal.util.Preconditions; +import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.StringJoiner; /** * Default Logger. @@ -210,12 +211,16 @@ public final class DefaultLogger extends Logger { */ public static String createSignature( String text, int start, int end, Context context, int modelVersion, - @Nullable Locale locale) { + List<Locale> locales) { Preconditions.checkNotNull(text); Preconditions.checkNotNull(context); - final String modelName = (locale != null) - ? String.format(Locale.US, "%s_v%d", locale.toLanguageTag(), modelVersion) - : ""; + Preconditions.checkNotNull(locales); + final StringJoiner localesJoiner = new StringJoiner(","); + for (Locale locale : locales) { + localesJoiner.add(locale.toLanguageTag()); + } + final String modelName = String.format(Locale.US, "%s_v%d", localesJoiner.toString(), + modelVersion); final int hash = Objects.hash(text, start, end, context.getPackageName()); return SignatureParser.createSignature(CLASSIFIER_ID, modelName, hash); } @@ -242,9 +247,9 @@ public final class DefaultLogger extends Logger { static String getModelName(String signature) { Preconditions.checkNotNull(signature); - final int start = signature.indexOf("|"); + final int start = signature.indexOf("|") + 1; final int end = signature.indexOf("|", start); - if (start >= 0 && end >= start) { + if (start >= 1 && end >= start) { return signature.substring(start, end); } return ""; |
