summaryrefslogtreecommitdiff
path: root/core/java/android
diff options
context:
space:
mode:
authorTreeHugger Robot <treehugger-gerrit@google.com>2018-11-13 11:14:08 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2018-11-13 11:14:08 +0000
commit9005dafb7440db3130c03bfdacf56759c5a1b606 (patch)
treea8c18734674c4833edb0b602bba47ce51a570f49 /core/java/android
parentbda37423d44663c8d07800ccfa9399244b85f191 (diff)
parentadbebcc634e7c0a876f887a0625ea77a042e63d9 (diff)
Merge "Implements TextClassifierImpl.suggestConversationActions"
Diffstat (limited to 'core/java/android')
-rw-r--r--core/java/android/view/textclassifier/ModelFileManager.java3
-rw-r--r--core/java/android/view/textclassifier/TextClassificationConstants.java35
-rw-r--r--core/java/android/view/textclassifier/TextClassifierImpl.java112
3 files changed, 146 insertions, 4 deletions
diff --git a/core/java/android/view/textclassifier/ModelFileManager.java b/core/java/android/view/textclassifier/ModelFileManager.java
index adea1259b943..896b516bbf9a 100644
--- a/core/java/android/view/textclassifier/ModelFileManager.java
+++ b/core/java/android/view/textclassifier/ModelFileManager.java
@@ -74,10 +74,9 @@ public final class ModelFileManager {
* @param localeList the required locales, use {@code null} if there is no preference.
*/
public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
- // Specified localeList takes priority over the system default, so it is listed first.
final String languages = localeList == null || localeList.isEmpty()
? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
+ : localeList.toLanguageTags();
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
ModelFile bestModel = null;
diff --git a/core/java/android/view/textclassifier/TextClassificationConstants.java b/core/java/android/view/textclassifier/TextClassificationConstants.java
index 2fc74221a456..50801a2b3e3f 100644
--- a/core/java/android/view/textclassifier/TextClassificationConstants.java
+++ b/core/java/android/view/textclassifier/TextClassificationConstants.java
@@ -90,6 +90,10 @@ public final class TextClassificationConstants {
"entity_list_not_editable";
private static final String ENTITY_LIST_EDITABLE =
"entity_list_editable";
+ private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "in_app_conversation_action_types_default";
+ private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "notification_conversation_action_types_default";
private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
@@ -111,6 +115,18 @@ public final class TextClassificationConstants {
.add(TextClassifier.TYPE_DATE)
.add(TextClassifier.TYPE_DATE_TIME)
.add(TextClassifier.TYPE_FLIGHT_NUMBER).toString();
+ private static final String CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
+ new StringJoiner(ENTITY_LIST_DELIMITER)
+ .add(ConversationActions.TYPE_TEXT_REPLY)
+ .add(ConversationActions.TYPE_CREATE_REMINDER)
+ .add(ConversationActions.TYPE_CALL_PHONE)
+ .add(ConversationActions.TYPE_OPEN_URL)
+ .add(ConversationActions.TYPE_SEND_EMAIL)
+ .add(ConversationActions.TYPE_SEND_SMS)
+ .add(ConversationActions.TYPE_TRACK_FLIGHT)
+ .add(ConversationActions.TYPE_VIEW_CALENDAR)
+ .add(ConversationActions.TYPE_VIEW_MAP)
+ .toString();
private final boolean mSystemTextClassifierEnabled;
private final boolean mLocalTextClassifierEnabled;
@@ -126,6 +142,8 @@ public final class TextClassificationConstants {
private final List<String> mEntityListDefault;
private final List<String> mEntityListNotEditable;
private final List<String> mEntityListEditable;
+ private final List<String> mInAppConversationActionTypesDefault;
+ private final List<String> mNotificationConversationActionTypesDefault;
private TextClassificationConstants(@Nullable String settings) {
final KeyValueListParser parser = new KeyValueListParser(',');
@@ -177,6 +195,12 @@ public final class TextClassificationConstants {
mEntityListEditable = parseEntityList(parser.getString(
ENTITY_LIST_EDITABLE,
ENTITY_LIST_DEFAULT_VALUE));
+ mInAppConversationActionTypesDefault = parseEntityList(parser.getString(
+ IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT,
+ CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES));
+ mNotificationConversationActionTypesDefault = parseEntityList(parser.getString(
+ NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT,
+ CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES));
}
/** Load from a settings string. */
@@ -240,6 +264,14 @@ public final class TextClassificationConstants {
return mEntityListEditable;
}
+ public List<String> getInAppConversationActionTypes() {
+ return mInAppConversationActionTypesDefault;
+ }
+
+ public List<String> getNotificationConversationActionTypes() {
+ return mNotificationConversationActionTypesDefault;
+ }
+
private static List<String> parseEntityList(String listStr) {
return Collections.unmodifiableList(Arrays.asList(listStr.split(ENTITY_LIST_DELIMITER)));
}
@@ -261,6 +293,9 @@ public final class TextClassificationConstants {
pw.printPair("getEntityListDefault", mEntityListDefault);
pw.printPair("getEntityListNotEditable", mEntityListNotEditable);
pw.printPair("getEntityListEditable", mEntityListEditable);
+ pw.printPair("getInAppConversationActionTypes", mInAppConversationActionTypesDefault);
+ pw.printPair("getNotificationConversationActionTypes",
+ mNotificationConversationActionTypesDefault);
pw.decreaseIndent();
pw.println();
}
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java
index 159bfaa2ab26..798a8208e240 100644
--- a/core/java/android/view/textclassifier/TextClassifierImpl.java
+++ b/core/java/android/view/textclassifier/TextClassifierImpl.java
@@ -40,11 +40,13 @@ import android.os.UserManager;
import android.provider.Browser;
import android.provider.CalendarContract;
import android.provider.ContactsContract;
+import android.text.TextUtils;
import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
@@ -90,6 +92,11 @@ public final class TextClassifierImpl implements TextClassifier {
private static final File UPDATED_LANG_ID_MODEL_FILE =
new File("/data/misc/textclassifier/lang_id.model");
+ // Actions
+ private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = "actions_suggestions.model";
+ private static final File UPDATED_ACTIONS_MODEL =
+ new File("/data/misc/textclassifier/actions_suggestions.model");
+
private final Context mContext;
private final TextClassifier mFallback;
private final GenerateLinksLogger mGenerateLinksLogger;
@@ -101,6 +108,8 @@ public final class TextClassifierImpl implements TextClassifier {
private AnnotatorModel mAnnotatorImpl;
@GuardedBy("mLock") // Do not access outside this lock.
private LangIdModel mLangIdImpl;
+ @GuardedBy("mLock") // Do not access outside this lock.
+ private ActionsSuggestionsModel mActionsImpl;
private final Object mLoggerLock = new Object();
@GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -110,6 +119,7 @@ public final class TextClassifierImpl implements TextClassifier {
private final ModelFileManager mAnnotatorModelFileManager;
private final ModelFileManager mLangIdModelFileManager;
+ private final ModelFileManager mActionsModelFileManager;
public TextClassifierImpl(
Context context, TextClassificationConstants settings, TextClassifier fallback) {
@@ -131,6 +141,13 @@ public final class TextClassifierImpl implements TextClassifier {
UPDATED_LANG_ID_MODEL_FILE,
fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd)
fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
+ mActionsModelFileManager = new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
+ UPDATED_ACTIONS_MODEL,
+ ActionsSuggestionsModel::getVersion,
+ ActionsSuggestionsModel::getLocales));
}
public TextClassifierImpl(Context context, TextClassificationConstants settings) {
@@ -346,10 +363,69 @@ public final class TextClassifierImpl implements TextClassifier {
return mFallback.detectLanguage(request);
}
+ @Override
+ public ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ Preconditions.checkNotNull(request);
+ Utils.checkMainThread();
+ try {
+ ActionsSuggestionsModel actionsImpl = getActionsImpl();
+ if (actionsImpl == null) {
+ // Actions model is optional, fallback if it is not available.
+ return mFallback.suggestConversationActions(request);
+ }
+ List<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayList<>();
+ for (ConversationActions.Message message : request.getConversation()) {
+ if (TextUtils.isEmpty(message.getText())) {
+ continue;
+ }
+ // TODO: We need to map the Person object to user id.
+ int userId = 1;
+ nativeMessages.add(
+ new ActionsSuggestionsModel.ConversationMessage(
+ userId, message.getText().toString()));
+ }
+ ActionsSuggestionsModel.Conversation nativeConversation =
+ new ActionsSuggestionsModel.Conversation(nativeMessages.toArray(
+ new ActionsSuggestionsModel.ConversationMessage[0]));
+
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
+ actionsImpl.suggestActions(nativeConversation, null);
+
+ Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
+ List<ConversationActions.ConversationAction> conversationActions = new ArrayList<>();
+ int maxSuggestions = Math.min(request.getMaxSuggestions(), nativeSuggestions.length);
+ for (int i = 0; i < maxSuggestions; i++) {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion = nativeSuggestions[i];
+ String actionType = nativeSuggestion.getActionType();
+ if (!expectedTypes.contains(actionType)) {
+ continue;
+ }
+ conversationActions.add(
+ new ConversationActions.ConversationAction.Builder(actionType)
+ .setTextReply(nativeSuggestion.getResponseText())
+ .setConfidenceScore(nativeSuggestion.getScore())
+ .build());
+ }
+ return new ConversationActions(conversationActions);
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ Log.e(LOG_TAG, "Error suggesting conversation actions.", t);
+ }
+ return mFallback.suggestConversationActions(request);
+ }
+
+ private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
+ List<String> defaultActionTypes =
+ request.getHints().contains(ConversationActions.HINT_FOR_NOTIFICATION)
+ ? mSettings.getNotificationConversationActionTypes()
+ : mSettings.getInAppConversationActionTypes();
+ return request.getTypeConfig().resolveTypes(defaultActionTypes);
+ }
+
private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
throws FileNotFoundException {
synchronized (mLock) {
- localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
+ localeList = localeList == null ? LocaleList.getDefault() : localeList;
final ModelFileManager.ModelFile bestModel =
mAnnotatorModelFileManager.findBestModelFile(localeList);
if (bestModel == null) {
@@ -386,7 +462,7 @@ public final class TextClassifierImpl implements TextClassifier {
synchronized (mLock) {
if (mLangIdImpl == null) {
final ModelFileManager.ModelFile bestModel =
- mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
+ mLangIdModelFileManager.findBestModelFile(null);
if (bestModel == null) {
throw new FileNotFoundException("No LangID model is found");
}
@@ -404,6 +480,30 @@ public final class TextClassifierImpl implements TextClassifier {
}
}
+ @Nullable
+ private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
+ synchronized (mLock) {
+ if (mActionsImpl == null) {
+ // TODO: Use LangID to determine the locale we should use here?
+ final ModelFileManager.ModelFile bestModel =
+ mActionsModelFileManager.findBestModelFile(LocaleList.getDefault());
+ if (bestModel == null) {
+ return null;
+ }
+ final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd != null) {
+ mActionsImpl = new ActionsSuggestionsModel(pfd.getFd());
+ }
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return mActionsImpl;
+ }
+ }
+
private String createId(String text, int start, int end) {
synchronized (mLock) {
return SelectionSessionLogger.createId(text, start, end, mContext,
@@ -471,11 +571,19 @@ public final class TextClassifierImpl implements TextClassifier {
}
printWriter.decreaseIndent();
printWriter.println("LangID model file(s):");
+ printWriter.increaseIndent();
for (ModelFileManager.ModelFile modelFile :
mLangIdModelFileManager.listModelFiles()) {
printWriter.println(modelFile.toString());
}
printWriter.decreaseIndent();
+ printWriter.println("Actions model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile :
+ mActionsModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
printWriter.printPair("mFallback", mFallback);
printWriter.decreaseIndent();
printWriter.println();