summaryrefslogtreecommitdiff
path: root/core/java/android
diff options
context:
space:
mode:
Diffstat (limited to 'core/java/android')
-rw-r--r--core/java/android/view/textclassifier/SmartSelection.java4
-rw-r--r--core/java/android/view/textclassifier/TextClassifierImpl.java256
-rw-r--r--core/java/android/view/textclassifier/logging/DefaultLogger.java19
3 files changed, 156 insertions, 123 deletions
diff --git a/core/java/android/view/textclassifier/SmartSelection.java b/core/java/android/view/textclassifier/SmartSelection.java
index 8edf97ea0336..69c38ee4db4f 100644
--- a/core/java/android/view/textclassifier/SmartSelection.java
+++ b/core/java/android/view/textclassifier/SmartSelection.java
@@ -108,9 +108,9 @@ final class SmartSelection {
}
/**
- * Returns the language of the model.
+ * Returns a comma separated list of locales supported by the model as BCP 47 tags.
*/
- public static String getLanguage(int fd) {
+ public static String getLanguages(int fd) {
return nativeGetLanguage(fd);
}
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java
index af76a7fa4c7d..fc034937312c 100644
--- a/core/java/android/view/textclassifier/TextClassifierImpl.java
+++ b/core/java/android/view/textclassifier/TextClassifierImpl.java
@@ -58,6 +58,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
+import java.util.StringJoiner;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -101,11 +102,9 @@ public final class TextClassifierImpl implements TextClassifier {
private final Object mLock = new Object();
@GuardedBy("mLock") // Do not access outside this lock.
- private Map<Locale, String> mModelFilePaths;
+ private List<ModelFile> mAllModelFiles;
@GuardedBy("mLock") // Do not access outside this lock.
- private Locale mLocale;
- @GuardedBy("mLock") // Do not access outside this lock.
- private int mVersion;
+ private ModelFile mModel;
@GuardedBy("mLock") // Do not access outside this lock.
private SmartSelection mSmartSelection;
@@ -281,18 +280,18 @@ public final class TextClassifierImpl implements TextClassifier {
private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException {
synchronized (mLock) {
localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
- final Locale locale = findBestSupportedLocaleLocked(localeList);
- if (locale == null) {
- throw new FileNotFoundException("No file for null locale");
+ final ModelFile bestModel = findBestModelLocked(localeList);
+ if (bestModel == null) {
+ throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
}
- if (mSmartSelection == null || !Objects.equals(mLocale, locale)) {
+ if (mSmartSelection == null || !Objects.equals(mModel, bestModel)) {
+ Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
destroySmartSelectionIfExistsLocked();
- final ParcelFileDescriptor fd = getFdLocked(locale);
- final int modelFd = fd.getFd();
- mVersion = SmartSelection.getVersion(modelFd);
- mSmartSelection = new SmartSelection(modelFd);
+ final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ mSmartSelection = new SmartSelection(fd.getFd());
closeAndLogError(fd);
- mLocale = locale;
+ mModel = bestModel;
}
return mSmartSelection;
}
@@ -300,74 +299,8 @@ public final class TextClassifierImpl implements TextClassifier {
private String getSignature(String text, int start, int end) {
synchronized (mLock) {
- return DefaultLogger.createSignature(text, start, end, mContext, mVersion, mLocale);
- }
- }
-
- @GuardedBy("mLock") // Do not call outside this lock.
- private ParcelFileDescriptor getFdLocked(Locale locale) throws FileNotFoundException {
- ParcelFileDescriptor updateFd;
- int updateVersion = -1;
- try {
- updateFd = ParcelFileDescriptor.open(
- new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
- if (updateFd != null) {
- updateVersion = SmartSelection.getVersion(updateFd.getFd());
- }
- } catch (FileNotFoundException e) {
- updateFd = null;
- }
- ParcelFileDescriptor factoryFd;
- int factoryVersion = -1;
- try {
- final String factoryModelFilePath = getFactoryModelFilePathsLocked().get(locale);
- if (factoryModelFilePath != null) {
- factoryFd = ParcelFileDescriptor.open(
- new File(factoryModelFilePath), ParcelFileDescriptor.MODE_READ_ONLY);
- if (factoryFd != null) {
- factoryVersion = SmartSelection.getVersion(factoryFd.getFd());
- }
- } else {
- factoryFd = null;
- }
- } catch (FileNotFoundException e) {
- factoryFd = null;
- }
-
- if (updateFd == null) {
- if (factoryFd != null) {
- return factoryFd;
- } else {
- throw new FileNotFoundException(
- String.format(Locale.US, "No model file found for %s", locale));
- }
- }
-
- final int updateFdInt = updateFd.getFd();
- final boolean localeMatches = Objects.equals(
- locale.getLanguage().trim().toLowerCase(),
- SmartSelection.getLanguage(updateFdInt).trim().toLowerCase());
- if (factoryFd == null) {
- if (localeMatches) {
- return updateFd;
- } else {
- closeAndLogError(updateFd);
- throw new FileNotFoundException(
- String.format(Locale.US, "No model file found for %s", locale));
- }
- }
-
- if (!localeMatches) {
- closeAndLogError(updateFd);
- return factoryFd;
- }
-
- if (updateVersion > factoryVersion) {
- closeAndLogError(factoryFd);
- return updateFd;
- } else {
- closeAndLogError(updateFd);
- return factoryFd;
+ return DefaultLogger.createSignature(text, start, end, mContext, mModel.getVersion(),
+ mModel.getSupportedLocales());
}
}
@@ -379,60 +312,66 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
+ /**
+ * Finds the most appropriate model to use for the given target locale list.
+ *
+ * The basic logic is: we ignore all models that don't support any of the target locales. For
+ * the remaining candidates, we take the update model unless its version number is lower than
+ * the factory version. It's assumed that factory models do not have overlapping locale ranges
+ * and conflict resolution between these models hence doesn't matter.
+ */
@GuardedBy("mLock") // Do not call outside this lock.
@Nullable
- private Locale findBestSupportedLocaleLocked(LocaleList localeList) {
+ private ModelFile findBestModelLocked(LocaleList localeList) {
// Specified localeList takes priority over the system default, so it is listed first.
final String languages = localeList.isEmpty()
? LocaleList.getDefault().toLanguageTags()
: localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
- final List<Locale> supportedLocales =
- new ArrayList<>(getFactoryModelFilePathsLocked().keySet());
- final Locale updatedModelLocale = getUpdatedModelLocale();
- if (updatedModelLocale != null) {
- supportedLocales.add(updatedModelLocale);
+ ModelFile bestModel = null;
+ int bestModelVersion = -1;
+ for (ModelFile model : listAllModelsLocked()) {
+ if (model.isAnyLanguageSupported(languageRangeList)) {
+ if (model.getVersion() >= bestModelVersion) {
+ bestModel = model;
+ bestModelVersion = model.getVersion();
+ }
+ }
}
- return Locale.lookup(languageRangeList, supportedLocales);
+ return bestModel;
}
+ /** Returns a list of all model files available, in order of precedence. */
@GuardedBy("mLock") // Do not call outside this lock.
- private Map<Locale, String> getFactoryModelFilePathsLocked() {
- if (mModelFilePaths == null) {
- final Map<Locale, String> modelFilePaths = new HashMap<>();
+ private List<ModelFile> listAllModelsLocked() {
+ if (mAllModelFiles == null) {
+ final List<ModelFile> allModels = new ArrayList<>();
+ // The update model has the highest precedence.
+ if (new File(UPDATED_MODEL_FILE_PATH).exists()) {
+ final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH);
+ if (updatedModel != null) {
+ allModels.add(updatedModel);
+ }
+ }
+ // Factory models should never have overlapping locales, so the order doesn't matter.
final File modelsDir = new File(MODEL_DIR);
if (modelsDir.exists() && modelsDir.isDirectory()) {
- final File[] models = modelsDir.listFiles();
+ final File[] modelFiles = modelsDir.listFiles();
final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX);
- final int size = models.length;
- for (int i = 0; i < size; i++) {
- final File modelFile = models[i];
+ for (File modelFile : modelFiles) {
final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName());
if (matcher.matches() && modelFile.isFile()) {
- final String language = matcher.group(1);
- final Locale locale = Locale.forLanguageTag(language);
- modelFilePaths.put(locale, modelFile.getAbsolutePath());
+ final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath());
+ if (model != null) {
+ allModels.add(model);
+ }
}
}
}
- mModelFilePaths = modelFilePaths;
- }
- return mModelFilePaths;
- }
-
- @Nullable
- private Locale getUpdatedModelLocale() {
- try {
- final ParcelFileDescriptor updateFd = ParcelFileDescriptor.open(
- new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
- final Locale locale = Locale.forLanguageTag(
- SmartSelection.getLanguage(updateFd.getFd()));
- closeAndLogError(updateFd);
- return locale;
- } catch (FileNotFoundException e) {
- return null;
+ mAllModelFiles = allModels;
}
+ return mAllModelFiles;
}
private TextClassification createClassificationResult(
@@ -522,6 +461,95 @@ public final class TextClassifierImpl implements TextClassifier {
}
/**
+ * Describes TextClassifier model files on disk.
+ */
+ private static final class ModelFile {
+
+ private final String mPath;
+ private final String mName;
+ private final int mVersion;
+ private final List<Locale> mSupportedLocales;
+
+ /** Returns null if the path did not point to a compatible model. */
+ static @Nullable ModelFile fromPath(String path) {
+ final File file = new File(path);
+ try {
+ final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
+ file, ParcelFileDescriptor.MODE_READ_ONLY);
+ final int version = SmartSelection.getVersion(modelFd.getFd());
+ final String supportedLocalesStr = SmartSelection.getLanguages(modelFd.getFd());
+ if (supportedLocalesStr.isEmpty()) {
+ Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
+ return null;
+ }
+ final List<Locale> supportedLocales = new ArrayList<>();
+ for (String langTag : supportedLocalesStr.split(",")) {
+ supportedLocales.add(Locale.forLanguageTag(langTag));
+ }
+ closeAndLogError(modelFd);
+ return new ModelFile(path, file.getName(), version, supportedLocales);
+ } catch (FileNotFoundException e) {
+ Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
+ return null;
+ }
+ }
+
+ /** The absolute path to the model file. */
+ String getPath() {
+ return mPath;
+ }
+
+ /** A name to use for signature generation. Effectively the name of the model file. */
+ String getName() {
+ return mName;
+ }
+
+ /** Returns the version tag in the model's metadata. */
+ int getVersion() {
+ return mVersion;
+ }
+
+ /** Returns whether the language supports any language in the given ranges. */
+ boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ return Locale.lookup(languageRanges, mSupportedLocales) != null;
+ }
+
+ /** All locales supported by the model. */
+ List<Locale> getSupportedLocales() {
+ return Collections.unmodifiableList(mSupportedLocales);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
+ return false;
+ } else {
+ final ModelFile otherModel = (ModelFile) other;
+ return mPath.equals(otherModel.mPath);
+ }
+ }
+
+ @Override
+ public String toString() {
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : mSupportedLocales) {
+ localesJoiner.add(locale.toLanguageTag());
+ }
+ return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }",
+ mPath, mName, mVersion, localesJoiner.toString());
+ }
+
+ private ModelFile(String path, String name, int version, List<Locale> supportedLocales) {
+ mPath = path;
+ mName = name;
+ mVersion = version;
+ mSupportedLocales = supportedLocales;
+ }
+ }
+
+ /**
* Creates intents based on the classification type.
*/
static final class IntentFactory {
diff --git a/core/java/android/view/textclassifier/logging/DefaultLogger.java b/core/java/android/view/textclassifier/logging/DefaultLogger.java
index 6b848351cbf6..03a6d3a7f10f 100644
--- a/core/java/android/view/textclassifier/logging/DefaultLogger.java
+++ b/core/java/android/view/textclassifier/logging/DefaultLogger.java
@@ -17,7 +17,6 @@
package android.view.textclassifier.logging;
import android.annotation.NonNull;
-import android.annotation.Nullable;
import android.content.Context;
import android.metrics.LogMaker;
import android.util.Log;
@@ -27,8 +26,10 @@ import com.android.internal.logging.MetricsLogger;
import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
import com.android.internal.util.Preconditions;
+import java.util.List;
import java.util.Locale;
import java.util.Objects;
+import java.util.StringJoiner;
/**
* Default Logger.
@@ -210,12 +211,16 @@ public final class DefaultLogger extends Logger {
*/
public static String createSignature(
String text, int start, int end, Context context, int modelVersion,
- @Nullable Locale locale) {
+ List<Locale> locales) {
Preconditions.checkNotNull(text);
Preconditions.checkNotNull(context);
- final String modelName = (locale != null)
- ? String.format(Locale.US, "%s_v%d", locale.toLanguageTag(), modelVersion)
- : "";
+ Preconditions.checkNotNull(locales);
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : locales) {
+ localesJoiner.add(locale.toLanguageTag());
+ }
+ final String modelName = String.format(Locale.US, "%s_v%d", localesJoiner.toString(),
+ modelVersion);
final int hash = Objects.hash(text, start, end, context.getPackageName());
return SignatureParser.createSignature(CLASSIFIER_ID, modelName, hash);
}
@@ -242,9 +247,9 @@ public final class DefaultLogger extends Logger {
static String getModelName(String signature) {
Preconditions.checkNotNull(signature);
- final int start = signature.indexOf("|");
+ final int start = signature.indexOf("|") + 1;
final int end = signature.indexOf("|", start);
- if (start >= 0 && end >= start) {
+ if (start >= 1 && end >= start) {
return signature.substring(start, end);
}
return "";