diff options
| author | TreeHugger Robot <treehugger-gerrit@google.com> | 2018-11-13 11:14:08 +0000 |
|---|---|---|
| committer | Android (Google) Code Review <android-gerrit@google.com> | 2018-11-13 11:14:08 +0000 |
| commit | 9005dafb7440db3130c03bfdacf56759c5a1b606 (patch) | |
| tree | a8c18734674c4833edb0b602bba47ce51a570f49 /core/java/android | |
| parent | bda37423d44663c8d07800ccfa9399244b85f191 (diff) | |
| parent | adbebcc634e7c0a876f887a0625ea77a042e63d9 (diff) | |
Merge "Implements TextClassifierImpl.suggestConversationActions"
Diffstat (limited to 'core/java/android')
3 files changed, 146 insertions, 4 deletions
diff --git a/core/java/android/view/textclassifier/ModelFileManager.java b/core/java/android/view/textclassifier/ModelFileManager.java index adea1259b943..896b516bbf9a 100644 --- a/core/java/android/view/textclassifier/ModelFileManager.java +++ b/core/java/android/view/textclassifier/ModelFileManager.java @@ -74,10 +74,9 @@ public final class ModelFileManager { * @param localeList the required locales, use {@code null} if there is no preference. */ public ModelFile findBestModelFile(@Nullable LocaleList localeList) { - // Specified localeList takes priority over the system default, so it is listed first. final String languages = localeList == null || localeList.isEmpty() ? LocaleList.getDefault().toLanguageTags() - : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags(); + : localeList.toLanguageTags(); final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); ModelFile bestModel = null; diff --git a/core/java/android/view/textclassifier/TextClassificationConstants.java b/core/java/android/view/textclassifier/TextClassificationConstants.java index 2fc74221a456..50801a2b3e3f 100644 --- a/core/java/android/view/textclassifier/TextClassificationConstants.java +++ b/core/java/android/view/textclassifier/TextClassificationConstants.java @@ -90,6 +90,10 @@ public final class TextClassificationConstants { "entity_list_not_editable"; private static final String ENTITY_LIST_EDITABLE = "entity_list_editable"; + private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT = + "in_app_conversation_action_types_default"; + private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT = + "notification_conversation_action_types_default"; private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; @@ -111,6 +115,18 @@ public final class TextClassificationConstants { .add(TextClassifier.TYPE_DATE) .add(TextClassifier.TYPE_DATE_TIME) .add(TextClassifier.TYPE_FLIGHT_NUMBER).toString(); + private static final String CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES = + new StringJoiner(ENTITY_LIST_DELIMITER) + .add(ConversationActions.TYPE_TEXT_REPLY) + .add(ConversationActions.TYPE_CREATE_REMINDER) + .add(ConversationActions.TYPE_CALL_PHONE) + .add(ConversationActions.TYPE_OPEN_URL) + .add(ConversationActions.TYPE_SEND_EMAIL) + .add(ConversationActions.TYPE_SEND_SMS) + .add(ConversationActions.TYPE_TRACK_FLIGHT) + .add(ConversationActions.TYPE_VIEW_CALENDAR) + .add(ConversationActions.TYPE_VIEW_MAP) + .toString(); private final boolean mSystemTextClassifierEnabled; private final boolean mLocalTextClassifierEnabled; @@ -126,6 +142,8 @@ public final class TextClassificationConstants { private final List<String> mEntityListDefault; private final List<String> mEntityListNotEditable; private final List<String> mEntityListEditable; + private final List<String> mInAppConversationActionTypesDefault; + private final List<String> mNotificationConversationActionTypesDefault; private TextClassificationConstants(@Nullable String settings) { final KeyValueListParser parser = new KeyValueListParser(','); @@ -177,6 +195,12 @@ public final class TextClassificationConstants { mEntityListEditable = parseEntityList(parser.getString( ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE)); + mInAppConversationActionTypesDefault = parseEntityList(parser.getString( + IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, + CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES)); + mNotificationConversationActionTypesDefault = parseEntityList(parser.getString( + NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, + CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES)); } /** Load from a settings string. */ @@ -240,6 +264,14 @@ public final class TextClassificationConstants { return mEntityListEditable; } + public List<String> getInAppConversationActionTypes() { + return mInAppConversationActionTypesDefault; + } + + public List<String> getNotificationConversationActionTypes() { + return mNotificationConversationActionTypesDefault; + } + private static List<String> parseEntityList(String listStr) { return Collections.unmodifiableList(Arrays.asList(listStr.split(ENTITY_LIST_DELIMITER))); } @@ -261,6 +293,9 @@ public final class TextClassificationConstants { pw.printPair("getEntityListDefault", mEntityListDefault); pw.printPair("getEntityListNotEditable", mEntityListNotEditable); pw.printPair("getEntityListEditable", mEntityListEditable); + pw.printPair("getInAppConversationActionTypes", mInAppConversationActionTypesDefault); + pw.printPair("getNotificationConversationActionTypes", + mNotificationConversationActionTypesDefault); pw.decreaseIndent(); pw.println(); } diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index 159bfaa2ab26..798a8208e240 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -40,11 +40,13 @@ import android.os.UserManager; import android.provider.Browser; import android.provider.CalendarContract; import android.provider.ContactsContract; +import android.text.TextUtils; import com.android.internal.annotations.GuardedBy; import com.android.internal.util.IndentingPrintWriter; import com.android.internal.util.Preconditions; +import com.google.android.textclassifier.ActionsSuggestionsModel; import com.google.android.textclassifier.AnnotatorModel; import com.google.android.textclassifier.LangIdModel; @@ -90,6 +92,11 @@ public final class TextClassifierImpl implements TextClassifier { private static final File UPDATED_LANG_ID_MODEL_FILE = new File("/data/misc/textclassifier/lang_id.model"); + // Actions + private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = "actions_suggestions.model"; + private static final File UPDATED_ACTIONS_MODEL = + new File("/data/misc/textclassifier/actions_suggestions.model"); + private final Context mContext; private final TextClassifier mFallback; private final GenerateLinksLogger mGenerateLinksLogger; @@ -101,6 +108,8 @@ public final class TextClassifierImpl implements TextClassifier { private AnnotatorModel mAnnotatorImpl; @GuardedBy("mLock") // Do not access outside this lock. private LangIdModel mLangIdImpl; + @GuardedBy("mLock") // Do not access outside this lock. + private ActionsSuggestionsModel mActionsImpl; private final Object mLoggerLock = new Object(); @GuardedBy("mLoggerLock") // Do not access outside this lock. @@ -110,6 +119,7 @@ public final class TextClassifierImpl implements TextClassifier { private final ModelFileManager mAnnotatorModelFileManager; private final ModelFileManager mLangIdModelFileManager; + private final ModelFileManager mActionsModelFileManager; public TextClassifierImpl( Context context, TextClassificationConstants settings, TextClassifier fallback) { @@ -131,6 +141,13 @@ public final class TextClassifierImpl implements TextClassifier { UPDATED_LANG_ID_MODEL_FILE, fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd) fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT)); + mActionsModelFileManager = new ModelFileManager( + new ModelFileManager.ModelFileSupplierImpl( + FACTORY_MODEL_DIR, + ACTIONS_FACTORY_MODEL_FILENAME_REGEX, + UPDATED_ACTIONS_MODEL, + ActionsSuggestionsModel::getVersion, + ActionsSuggestionsModel::getLocales)); } public TextClassifierImpl(Context context, TextClassificationConstants settings) { @@ -346,10 +363,69 @@ public final class TextClassifierImpl implements TextClassifier { return mFallback.detectLanguage(request); } + @Override + public ConversationActions suggestConversationActions(ConversationActions.Request request) { + Preconditions.checkNotNull(request); + Utils.checkMainThread(); + try { + ActionsSuggestionsModel actionsImpl = getActionsImpl(); + if (actionsImpl == null) { + // Actions model is optional, fallback if it is not available. + return mFallback.suggestConversationActions(request); + } + List<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayList<>(); + for (ConversationActions.Message message : request.getConversation()) { + if (TextUtils.isEmpty(message.getText())) { + continue; + } + // TODO: We need to map the Person object to user id. + int userId = 1; + nativeMessages.add( + new ActionsSuggestionsModel.ConversationMessage( + userId, message.getText().toString())); + } + ActionsSuggestionsModel.Conversation nativeConversation = + new ActionsSuggestionsModel.Conversation(nativeMessages.toArray( + new ActionsSuggestionsModel.ConversationMessage[0])); + + ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions = + actionsImpl.suggestActions(nativeConversation, null); + + Collection<String> expectedTypes = resolveActionTypesFromRequest(request); + List<ConversationActions.ConversationAction> conversationActions = new ArrayList<>(); + int maxSuggestions = Math.min(request.getMaxSuggestions(), nativeSuggestions.length); + for (int i = 0; i < maxSuggestions; i++) { + ActionsSuggestionsModel.ActionSuggestion nativeSuggestion = nativeSuggestions[i]; + String actionType = nativeSuggestion.getActionType(); + if (!expectedTypes.contains(actionType)) { + continue; + } + conversationActions.add( + new ConversationActions.ConversationAction.Builder(actionType) + .setTextReply(nativeSuggestion.getResponseText()) + .setConfidenceScore(nativeSuggestion.getScore()) + .build()); + } + return new ConversationActions(conversationActions); + } catch (Throwable t) { + // Avoid throwing from this method. Log the error. + Log.e(LOG_TAG, "Error suggesting conversation actions.", t); + } + return mFallback.suggestConversationActions(request); + } + + private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) { + List<String> defaultActionTypes = + request.getHints().contains(ConversationActions.HINT_FOR_NOTIFICATION) + ? mSettings.getNotificationConversationActionTypes() + : mSettings.getInAppConversationActionTypes(); + return request.getTypeConfig().resolveTypes(defaultActionTypes); + } + private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException { synchronized (mLock) { - localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; + localeList = localeList == null ? LocaleList.getDefault() : localeList; final ModelFileManager.ModelFile bestModel = mAnnotatorModelFileManager.findBestModelFile(localeList); if (bestModel == null) { @@ -386,7 +462,7 @@ public final class TextClassifierImpl implements TextClassifier { synchronized (mLock) { if (mLangIdImpl == null) { final ModelFileManager.ModelFile bestModel = - mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList()); + mLangIdModelFileManager.findBestModelFile(null); if (bestModel == null) { throw new FileNotFoundException("No LangID model is found"); } @@ -404,6 +480,30 @@ public final class TextClassifierImpl implements TextClassifier { } } + @Nullable + private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException { + synchronized (mLock) { + if (mActionsImpl == null) { + // TODO: Use LangID to determine the locale we should use here? + final ModelFileManager.ModelFile bestModel = + mActionsModelFileManager.findBestModelFile(LocaleList.getDefault()); + if (bestModel == null) { + return null; + } + final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( + new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); + try { + if (pfd != null) { + mActionsImpl = new ActionsSuggestionsModel(pfd.getFd()); + } + } finally { + maybeCloseAndLogError(pfd); + } + } + return mActionsImpl; + } + } + private String createId(String text, int start, int end) { synchronized (mLock) { return SelectionSessionLogger.createId(text, start, end, mContext, @@ -471,11 +571,19 @@ public final class TextClassifierImpl implements TextClassifier { } printWriter.decreaseIndent(); printWriter.println("LangID model file(s):"); + printWriter.increaseIndent(); for (ModelFileManager.ModelFile modelFile : mLangIdModelFileManager.listModelFiles()) { printWriter.println(modelFile.toString()); } printWriter.decreaseIndent(); + printWriter.println("Actions model file(s):"); + printWriter.increaseIndent(); + for (ModelFileManager.ModelFile modelFile : + mActionsModelFileManager.listModelFiles()) { + printWriter.println(modelFile.toString()); + } + printWriter.decreaseIndent(); printWriter.printPair("mFallback", mFallback); printWriter.decreaseIndent(); printWriter.println(); |
