Call setPageSize for LegacyPagingSource from .asPagingSourceFactory

There are currently some instances where we receive a PagingSource, but
do not check if it is an instance of LegacyPagingSource. This can cause
use-cases which explicitly call .asPagingSourceFactory.invoke() such as
Room to fail to set pageSize, which can cause an upgrade to paging3 to
inadvertently load a different number of items than expected.

Fixes: 181805612
Test: ./gradlew paging:paging-runtime:cC
Change-Id: Ic394fabf3e4712dcc9484eca3ddd2da2d8841a8b
diff --git a/paging/common/src/main/kotlin/androidx/paging/LegacyPagingSource.kt b/paging/common/src/main/kotlin/androidx/paging/LegacyPagingSource.kt
index 0dec76a..cdf77c1 100644
--- a/paging/common/src/main/kotlin/androidx/paging/LegacyPagingSource.kt
+++ b/paging/common/src/main/kotlin/androidx/paging/LegacyPagingSource.kt
@@ -16,6 +16,7 @@
 
 package androidx.paging
 
+import androidx.annotation.RestrictTo
 import androidx.paging.DataSource.KeyType.ITEM_KEYED
 import androidx.paging.DataSource.KeyType.PAGE_KEYED
 import androidx.paging.DataSource.KeyType.POSITIONAL
@@ -30,8 +31,11 @@
 
 /**
  * A wrapper around [DataSource] which adapts it to the [PagingSource] API.
+ *
+ * @hide
  */
-internal class LegacyPagingSource<Key : Any, Value : Any>(
+@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
+public class LegacyPagingSource<Key : Any, Value : Any>(
     private val fetchDispatcher: CoroutineDispatcher,
     internal val dataSource: DataSource<Key, Value>
 ) : PagingSource<Key, Value>() {
@@ -60,7 +64,11 @@
         }
     }
 
-    fun setPageSize(pageSize: Int) {
+    /**
+     * @hide
+     */
+    @RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
+    public fun setPageSize(pageSize: Int) {
         check(this.pageSize == PAGE_SIZE_NOT_SET || pageSize == this.pageSize) {
             "Page size is already set to ${this.pageSize}."
         }
@@ -147,7 +155,7 @@
     override val jumpingSupported: Boolean
         get() = dataSource.type == POSITIONAL
 
-    companion object {
-        const val PAGE_SIZE_NOT_SET = Integer.MIN_VALUE
+    private companion object {
+        private const val PAGE_SIZE_NOT_SET = Integer.MIN_VALUE
     }
 }
diff --git a/paging/common/src/main/kotlin/androidx/paging/PagedList.kt b/paging/common/src/main/kotlin/androidx/paging/PagedList.kt
index 0dd12d2..39b8e5c 100644
--- a/paging/common/src/main/kotlin/androidx/paging/PagedList.kt
+++ b/paging/common/src/main/kotlin/androidx/paging/PagedList.kt
@@ -493,9 +493,11 @@
                 LegacyPagingSource(
                     fetchDispatcher = fetchDispatcher,
                     dataSource = dataSource
-                ).also {
-                    it.setPageSize(config.pageSize)
-                }
+                )
+            }
+
+            if (pagingSource is LegacyPagingSource) {
+                pagingSource.setPageSize(config.pageSize)
             }
 
             check(pagingSource != null) {
diff --git a/paging/runtime/src/androidTest/java/androidx/paging/LivePagedListBuilderTest.kt b/paging/runtime/src/androidTest/java/androidx/paging/LivePagedListBuilderTest.kt
index e07551d..2b6ab75 100644
--- a/paging/runtime/src/androidTest/java/androidx/paging/LivePagedListBuilderTest.kt
+++ b/paging/runtime/src/androidTest/java/androidx/paging/LivePagedListBuilderTest.kt
@@ -27,8 +27,10 @@
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.SmallTest
 import androidx.testutils.TestExecutor
+import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.Runnable
+import kotlinx.coroutines.asCoroutineDispatcher
 import org.junit.After
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertNotNull
@@ -260,6 +262,60 @@
         )
     }
 
+    @Test
+    fun legacyPagingSourcePageSize() {
+        val dataSources = mutableListOf<DataSource<Int, Int>>()
+        val pagedLists = mutableListOf<PagedList<Int>>()
+        val requestedLoadSizes = mutableListOf<Int>()
+        val livePagedList = LivePagedListBuilder(
+            pagingSourceFactory = object : DataSource.Factory<Int, Int>() {
+                override fun create(): DataSource<Int, Int> {
+                    return object : PositionalDataSource<Int>() {
+                        override fun loadInitial(
+                            params: LoadInitialParams,
+                            callback: LoadInitialCallback<Int>
+                        ) {
+                            requestedLoadSizes.add(params.requestedLoadSize)
+                            callback.onResult(listOf(1, 2, 3), 0)
+                        }
+
+                        override fun loadRange(
+                            params: LoadRangeParams,
+                            callback: LoadRangeCallback<Int>
+                        ) {
+                            requestedLoadSizes.add(params.loadSize)
+                        }
+                    }.also {
+                        dataSources.add(it)
+                    }
+                }
+            }.asPagingSourceFactory(backgroundExecutor.asCoroutineDispatcher()),
+            config = PagedList.Config.Builder()
+                .setPageSize(3)
+                .setInitialLoadSizeHint(3)
+                .setEnablePlaceholders(false)
+                .build()
+        ).setFetchExecutor(backgroundExecutor)
+            .build()
+
+        livePagedList.observeForever { pagedLists.add(it) }
+
+        drain()
+        assertThat(requestedLoadSizes).containsExactly(3)
+
+        pagedLists.last().loadAround(2)
+        drain()
+        assertThat(requestedLoadSizes).containsExactly(3, 3)
+
+        dataSources[0].invalidate()
+        drain()
+        assertThat(requestedLoadSizes).containsExactly(3, 3, 3)
+
+        pagedLists.last().loadAround(2)
+        drain()
+        assertThat(requestedLoadSizes).containsExactly(3, 3, 3, 3)
+    }
+
     private fun drain() {
         var executed: Boolean
         do {
diff --git a/paging/runtime/src/main/java/androidx/paging/LivePagedList.kt b/paging/runtime/src/main/java/androidx/paging/LivePagedList.kt
index 29cd5d2..e9b696a 100644
--- a/paging/runtime/src/main/java/androidx/paging/LivePagedList.kt
+++ b/paging/runtime/src/main/java/androidx/paging/LivePagedList.kt
@@ -75,6 +75,9 @@
             currentData.pagingSource.unregisterInvalidatedCallback(callback)
             val pagingSource = pagingSourceFactory()
             pagingSource.registerInvalidatedCallback(callback)
+            if (pagingSource is LegacyPagingSource) {
+                pagingSource.setPageSize(config.pageSize)
+            }
 
             withContext(notifyDispatcher) {
                 currentData.setInitialLoadState(REFRESH, Loading)
@@ -83,6 +86,7 @@
             @Suppress("UNCHECKED_CAST")
             val lastKey = currentData.lastKey as Key?
             val params = config.toRefreshLoadParams(lastKey)
+
             when (val initialResult = pagingSource.load(params)) {
                 is PagingSource.LoadResult.Error -> {
                     currentData.setInitialLoadState(
diff --git a/paging/rxjava2/src/main/java/androidx/paging/RxPagedListBuilder.kt b/paging/rxjava2/src/main/java/androidx/paging/RxPagedListBuilder.kt
index d721431..9397d75 100644
--- a/paging/rxjava2/src/main/java/androidx/paging/RxPagedListBuilder.kt
+++ b/paging/rxjava2/src/main/java/androidx/paging/RxPagedListBuilder.kt
@@ -391,6 +391,9 @@
                 currentData.pagingSource.unregisterInvalidatedCallback(callback)
                 val pagingSource = pagingSourceFactory()
                 pagingSource.registerInvalidatedCallback(callback)
+                if (pagingSource is LegacyPagingSource) {
+                    pagingSource.setPageSize(config.pageSize)
+                }
 
                 withContext(notifyDispatcher) {
                     currentData.setInitialLoadState(LoadType.REFRESH, Loading)
diff --git a/paging/rxjava3/src/main/java/androidx/paging/rxjava3/RxPagedListBuilder.kt b/paging/rxjava3/src/main/java/androidx/paging/rxjava3/RxPagedListBuilder.kt
index 13de21b..0a48711 100644
--- a/paging/rxjava3/src/main/java/androidx/paging/rxjava3/RxPagedListBuilder.kt
+++ b/paging/rxjava3/src/main/java/androidx/paging/rxjava3/RxPagedListBuilder.kt
@@ -21,6 +21,7 @@
 import androidx.paging.DataSource
 import androidx.paging.InitialPagedList
 import androidx.paging.InitialPagingSource
+import androidx.paging.LegacyPagingSource
 import androidx.paging.LoadState
 import androidx.paging.LoadState.Loading
 import androidx.paging.LoadType
@@ -399,6 +400,9 @@
                 currentData.pagingSource.unregisterInvalidatedCallback(callback)
                 val pagingSource = pagingSourceFactory()
                 pagingSource.registerInvalidatedCallback(callback)
+                if (pagingSource is LegacyPagingSource) {
+                    pagingSource.setPageSize(config.pageSize)
+                }
 
                 withContext(notifyDispatcher) {
                     currentData.setInitialLoadState(LoadType.REFRESH, Loading)