diff options
Diffstat (limited to 'core/java/android')
4 files changed, 43 insertions, 52 deletions
diff --git a/core/java/android/view/textclassifier/EntityConfidence.java b/core/java/android/view/textclassifier/EntityConfidence.java index 0589d204ac3f..19660d95e927 100644 --- a/core/java/android/view/textclassifier/EntityConfidence.java +++ b/core/java/android/view/textclassifier/EntityConfidence.java @@ -18,13 +18,12 @@ package android.view.textclassifier; import android.annotation.FloatRange; import android.annotation.NonNull; +import android.util.ArrayMap; import com.android.internal.util.Preconditions; import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -36,42 +35,43 @@ import java.util.Map; */ final class EntityConfidence<T> { - private final Map<T, Float> mEntityConfidence = new HashMap<>(); - - private final Comparator<T> mEntityComparator = (e1, e2) -> { - float score1 = mEntityConfidence.get(e1); - float score2 = mEntityConfidence.get(e2); - if (score1 > score2) { - return -1; - } - if (score1 < score2) { - return 1; - } - return 0; - }; + private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>(); + private final ArrayList<T> mSortedEntities = new ArrayList<>(); EntityConfidence() {} EntityConfidence(@NonNull EntityConfidence<T> source) { Preconditions.checkNotNull(source); mEntityConfidence.putAll(source.mEntityConfidence); + mSortedEntities.addAll(source.mSortedEntities); } /** - * Sets an entity type for the classified text and assigns a confidence score. + * Constructs an EntityConfidence from a map of entity to confidence. * - * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence). - * 0 implies the entity does not exist for the classified text. - * Values greater than 1 are clamped to 1. + * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1. + * + * @param source a map from entity to a confidence value in the range 0 (low confidence) to + * 1 (high confidence). */ - public void setEntityType( - @NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - Preconditions.checkNotNull(type); - if (confidenceScore > 0) { - mEntityConfidence.put(type, Math.min(1, confidenceScore)); - } else { - mEntityConfidence.remove(type); + EntityConfidence(@NonNull Map<T, Float> source) { + Preconditions.checkNotNull(source); + + // Prune non-existent entities and clamp to 1. + mEntityConfidence.ensureCapacity(source.size()); + for (Map.Entry<T, Float> it : source.entrySet()) { + if (it.getValue() <= 0) continue; + mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue())); } + + // Create a list of entities sorted by decreasing confidence for getEntities(). + mSortedEntities.ensureCapacity(mEntityConfidence.size()); + mSortedEntities.addAll(mEntityConfidence.keySet()); + mSortedEntities.sort((e1, e2) -> { + float score1 = mEntityConfidence.get(e1); + float score2 = mEntityConfidence.get(e2); + return Float.compare(score2, score1); + }); } /** @@ -80,10 +80,7 @@ final class EntityConfidence<T> { */ @NonNull public List<T> getEntities() { - List<T> entities = new ArrayList<>(mEntityConfidence.size()); - entities.addAll(mEntityConfidence.keySet()); - entities.sort(mEntityComparator); - return Collections.unmodifiableList(entities); + return Collections.unmodifiableList(mSortedEntities); } /** diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java index f675c355638c..89163238ea4d 100644 --- a/core/java/android/view/textclassifier/TextClassification.java +++ b/core/java/android/view/textclassifier/TextClassification.java @@ -24,6 +24,7 @@ import android.content.Context; import android.content.Intent; import android.graphics.drawable.Drawable; import android.os.LocaleList; +import android.util.ArrayMap; import android.view.View.OnClickListener; import android.view.textclassifier.TextClassifier.EntityType; @@ -32,6 +33,7 @@ import com.android.internal.util.Preconditions; import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; /** * Information for generating a widget to handle classified text. @@ -95,7 +97,6 @@ public final class TextClassification { @NonNull private final List<Intent> mIntents; @NonNull private final List<OnClickListener> mOnClickListeners; @NonNull private final EntityConfidence<String> mEntityConfidence; - @NonNull private final List<String> mEntities; private int mLogType; @NonNull private final String mVersionInfo; @@ -105,7 +106,7 @@ public final class TextClassification { @NonNull List<String> labels, @NonNull List<Intent> intents, @NonNull List<OnClickListener> onClickListeners, - @NonNull EntityConfidence<String> entityConfidence, + @NonNull Map<String, Float> entityConfidence, int logType, @NonNull String versionInfo) { Preconditions.checkArgument(labels.size() == intents.size()); @@ -117,7 +118,6 @@ public final class TextClassification { mIntents = intents; mOnClickListeners = onClickListeners; mEntityConfidence = new EntityConfidence<>(entityConfidence); - mEntities = mEntityConfidence.getEntities(); mLogType = logType; mVersionInfo = versionInfo; } @@ -135,7 +135,7 @@ public final class TextClassification { */ @IntRange(from = 0) public int getEntityCount() { - return mEntities.size(); + return mEntityConfidence.getEntities().size(); } /** @@ -147,7 +147,7 @@ public final class TextClassification { */ @NonNull public @EntityType String getEntity(int index) { - return mEntities.get(index); + return mEntityConfidence.getEntities().get(index); } /** @@ -311,8 +311,7 @@ public final class TextClassification { @NonNull private final List<String> mLabels = new ArrayList<>(); @NonNull private final List<Intent> mIntents = new ArrayList<>(); @NonNull private final List<OnClickListener> mOnClickListeners = new ArrayList<>(); - @NonNull private final EntityConfidence<String> mEntityConfidence = - new EntityConfidence<>(); + @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>(); private int mLogType; @NonNull private String mVersionInfo = ""; @@ -334,7 +333,7 @@ public final class TextClassification { public Builder setEntityType( @NonNull @EntityType String type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - mEntityConfidence.setEntityType(type, confidenceScore); + mEntityConfidence.put(type, confidenceScore); return this; } diff --git a/core/java/android/view/textclassifier/TextLinks.java b/core/java/android/view/textclassifier/TextLinks.java index 76748d2b191a..0e039e35367e 100644 --- a/core/java/android/view/textclassifier/TextLinks.java +++ b/core/java/android/view/textclassifier/TextLinks.java @@ -103,11 +103,7 @@ public final class TextLinks { mOriginalText = originalText; mStart = start; mEnd = end; - mEntityScores = new EntityConfidence<>(); - - for (Map.Entry<String, Float> entry : entityScores.entrySet()) { - mEntityScores.setEntityType(entry.getKey(), entry.getValue()); - } + mEntityScores = new EntityConfidence<>(entityScores); } /** diff --git a/core/java/android/view/textclassifier/TextSelection.java b/core/java/android/view/textclassifier/TextSelection.java index 480b27a73fc1..ced4018bcd82 100644 --- a/core/java/android/view/textclassifier/TextSelection.java +++ b/core/java/android/view/textclassifier/TextSelection.java @@ -21,12 +21,13 @@ import android.annotation.IntRange; import android.annotation.NonNull; import android.annotation.Nullable; import android.os.LocaleList; +import android.util.ArrayMap; import android.view.textclassifier.TextClassifier.EntityType; import com.android.internal.util.Preconditions; -import java.util.List; import java.util.Locale; +import java.util.Map; /** * Information about where text selection should be. @@ -36,7 +37,6 @@ public final class TextSelection { private final int mStartIndex; private final int mEndIndex; @NonNull private final EntityConfidence<String> mEntityConfidence; - @NonNull private final List<String> mEntities; @NonNull private final String mLogSource; @NonNull private final String mVersionInfo; @@ -46,7 +46,6 @@ public final class TextSelection { mStartIndex = startIndex; mEndIndex = endIndex; mEntityConfidence = new EntityConfidence<>(entityConfidence); - mEntities = mEntityConfidence.getEntities(); mLogSource = logSource; mVersionInfo = versionInfo; } @@ -70,7 +69,7 @@ public final class TextSelection { */ @IntRange(from = 0) public int getEntityCount() { - return mEntities.size(); + return mEntityConfidence.getEntities().size(); } /** @@ -82,7 +81,7 @@ public final class TextSelection { */ @NonNull public @EntityType String getEntity(int index) { - return mEntities.get(index); + return mEntityConfidence.getEntities().get(index); } /** @@ -126,8 +125,7 @@ public final class TextSelection { private final int mStartIndex; private final int mEndIndex; - @NonNull private final EntityConfidence<String> mEntityConfidence = - new EntityConfidence<>(); + @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>(); @NonNull private String mLogSource = ""; @NonNull private String mVersionInfo = ""; @@ -154,7 +152,7 @@ public final class TextSelection { public Builder setEntityType( @NonNull @EntityType String type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) { - mEntityConfidence.setEntityType(type, confidenceScore); + mEntityConfidence.put(type, confidenceScore); return this; } @@ -181,7 +179,8 @@ public final class TextSelection { */ public TextSelection build() { return new TextSelection( - mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo); + mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence), mLogSource, + mVersionInfo); } } |
