3

我正在尝试为 Android 中的一个存储库（Repository）类编写单元测试，该类借助远程调解器（RemoteMediator）和分页源（PagingSource）实现分页。

但运行测试时，返回的结果是空的，而它实际上应该包含测试数据项。

代码如下所示：

这是我的存储库（Repository）类：

/**
 * Repository exposing posts as a paged stream.
 *
 * Pages are read from the paging source supplied by [PostDao]. A [PostsPageRemoteMediator]
 * is built with both the API and the DAO so that Paging can ask it for more data.
 */
class PostsRepository @Inject constructor(
    private val postsApi: AutomatticPostsApi,
    private val postsDao: PostDao
) : IPostsRepository {

    /**
     * Builds a [Pager] (page size 20, initial key 1) over the DAO's paging source plus a
     * [PostsPageRemoteMediator], and converts every loaded entity to the domain [Post].
     *
     * Pager flows are cold: loading is driven by whoever consumes the emitted [PagingData]
     * (e.g. a differ/adapter). Transforming it with `PagingData.map` alone loads nothing.
     *
     * NOTE(review): if the production DAO's PagingSource treats its Int key as an item offset
     * starting at 0, `initialKey = 1` would skip the first row — confirm against the DAO.
     */
    @ExperimentalPagingApi
    override fun loadPosts(): Flow<PagingData<Post>> {
        println("loadPosts")
        val pager = Pager(
            config = PagingConfig(pageSize = 20),
            initialKey = 1,
            remoteMediator = PostsPageRemoteMediator(postsApi, postsDao),
            pagingSourceFactory = { postsDao.getPostsPagingSource() }
        )
        return pager.flow.map { page ->
            page.map { entity -> entity.toPost() }
        }
    }
}

这是我的单元测试（UT）：

/**
 * Unit test for [PostsRepository], wired to in-memory fakes for the API and the DAO.
 */
@ExperimentalCoroutinesApi
@ExperimentalPagingApi
class PostsRepositoryTest {
    @get:Rule
    val instantTaskExecutorRule = InstantTaskExecutorRule()

    private val coroutineDispatcher = TestCoroutineDispatcher()
    private lateinit var postDao: FakePostDao
    private lateinit var postsApi: CommonAutomatticPostsApi
    // One remote post; its domain mapping is what the repository is expected to emit.
    private val remotePosts = listOf(createDummyPostResponse())
    private val domainPosts = remotePosts.map { it.toPost() }

    //GIVEN: subject under test
    private lateinit var postsRepository: PostsRepository

    /**
     * Creates fresh fakes and a new repository before every test.
     *
     * NOTE(review): nothing in this body suspends, so wrapping it in runBlockingTest is
     * unnecessary; plain assignments would do.
     */
    @Before
    fun createRepository() =  coroutineDispatcher.runBlockingTest {
        postsApi = CommonAutomatticPostsApi(remotePosts.toMutableList())
        postDao = FakePostDao()
        postsRepository = PostsRepository(postsApi, postDao)
    }

    @Test
    fun loadPosts_returnsCorrectPosts() = runBlockingTest {
        //WHEN: posts are retrieved from paging source

        // NOTE(review): the Pager flow returned by loadPosts() does not complete on its own,
        // so this launched collector stays active. runBlockingTest is likely to fail at the end
        // with an uncompleted-coroutine error. Consume only the first emission instead.
        launch {

            postsRepository.loadPosts().collect { pagingData ->

                val posts = mutableListOf<Post>()
                // NOTE(review): BUG — PagingData.map is a lazy transformation. It returns a new
                // PagingData, and its lambda only runs when that result is presented by a differ.
                // The result is discarded here, so posts.add() never runs and `posts` stays
                // empty. That is why the assertion sees an empty list. Materialize the items
                // through a PagingDataDiffer / AsyncPagingDataDiffer instead.
                pagingData.map {

                    posts.add(it)
                    println(it)
                }

                // Runs once per emitted PagingData, not once per test.
                //THEN: retrieved posts should be the remotePosts
                assertThat(posts, IsEqual(domainPosts))
            }

        }

    }
}

以下是伪造的 API（CommonAutomatticPostsApi）、伪造的 PagingSource（FakePostsPagingSource）和伪造的 DAO（FakePostDao）：

/**
 * In-memory fake of [AutomatticPostsApi] that answers every request with the posts it holds.
 *
 * NOTE(review): `page` and `itemCount` are ignored. Every call returns the complete [posts]
 * list (the same live list, not a copy) with `posts.size` as the first argument, presumably
 * a total count. If the remote mediator keeps requesting pages until it sees an empty or
 * short page, it may insert the same posts repeatedly — confirm against the mediator.
 */
class CommonAutomatticPostsApi(val posts: MutableList<PostResponse> = mutableListOf()) : AutomatticPostsApi {

    override suspend fun loadPosts(page: Int, itemCount: Int): PostsResponse {
        println("Loaded")
        val total = posts.size.toLong()
        return PostsResponse(total, posts)
    }

    companion object {
        const val SUBSCRIBER_COUNT = 2L
        const val AUTHOR_NAME = "RR"
    }
}

/**
 * Fake [PagingSource] that serves [posts] as one single page with no previous/next key.
 *
 * Set [triggerError] to make [load] return a [PagingSource.LoadResult.Error] instead.
 *
 * @param initialPosts posts served until [posts] is reassigned; copied on construction.
 */
class FakePostsPagingSource(
    initialPosts: List<PostEntity> = emptyList()
) : PagingSource<Int, PostEntity>() {
    var triggerError = false

    /**
     * Posts returned by [load].
     *
     * Assigning stores a defensive copy, so later mutation of the caller's (possibly mutable)
     * list cannot change pages already handed to Paging. It then invalidates this source,
     * so Paging asks its factory for a new one.
     */
    var posts: List<PostEntity> = initialPosts.toList()
        set(value) {
            println("set")
            field = value.toList()
            invalidate()
        }

    override suspend fun load(params: LoadParams<Int>): LoadResult<Int, PostEntity> {
        println("load")
        if (triggerError) {
            return LoadResult.Error(Exception("A test error triggered"))
        }
        println("not error")

        // Everything fits in one page, so there is nothing to load before or after it.
        return LoadResult.Page(
            data = posts,
            prevKey = null,
            nextKey = null
        )
    }

    override fun getRefreshKey(state: PagingState<Int, PostEntity>): Int? {
        println("refresh")

        // NOTE(review): anchorPosition is an item index, not a page key. This is harmless here
        // only because load() ignores params.key; don't copy it into a real paging source.
        return state.anchorPosition ?: 1
    }
}

/**
 * In-memory fake of [PostDao] backed by [posts].
 *
 * Every mutation pushes a snapshot of [posts] into the current [pagingSource], which
 * invalidates it. [getPostsPagingSource] then hands out a fresh source, because Paging
 * does not allow an invalidated PagingSource to be reused.
 */
class FakePostDao(val posts: MutableList<PostEntity> = mutableListOf()) : PostDao {
    /**
     * The paging source most recently handed out (initially one seeded with [posts]).
     *
     * Once invalidated, it is replaced by a new instance that carries over
     * [FakePostsPagingSource.triggerError].
     */
    var pagingSource = FakePostsPagingSource(posts)
        private set

    override suspend fun insertPosts(posts: List<PostEntity>) {
        this.posts.addAll(posts)
        println("insertPosts")
        updatePagingSource()
    }

    override suspend fun updatePost(post: PostEntity) {
        onValidPost(post.id) {
            posts[it] = post
            updatePagingSource()
        }
    }

    /**
     * Runs [block] with the index of the post whose id is [postId].
     *
     * @return `true` if such a post exists (and [block] ran), `false` otherwise.
     */
    private fun onValidPost(postId: Long, block: (index: Int) -> Unit): Boolean {
        println("onValidPost")

        val index = posts.indexOfFirst { it.id == postId }
        if (index == -1) return false
        block(index)
        return true
    }

    override suspend fun updatePost(postId: Long, subscriberCount: Long) {
        onValidPost(postId) {
            posts[it] = posts[it].copy(subscriberCount = subscriberCount)
            updatePagingSource()
        }
    }

    override suspend fun getPostById(postId: Long): PostEntity? =
        posts.firstOrNull { it.id == postId }

    /** Returns a snapshot copy, so callers cannot mutate the DAO's storage through it. */
    override suspend fun getPosts(): List<PostEntity> {
        println("getPosts")

        return posts.toList()
    }

    override fun getPostsPagingSource(): PagingSource<Int, PostEntity> {
        println("getPostsPagingSource")

        // Paging never reuses an invalidated source. Handing back the stale instance would
        // break the Pager's refresh, so replace it with one serving the current data.
        if (pagingSource.invalid) {
            val previous = pagingSource
            pagingSource = FakePostsPagingSource(posts).apply { triggerError = previous.triggerError }
        }
        return pagingSource
    }

    override suspend fun clearAll() {
        posts.clear()
        updatePagingSource()
    }

    /** Pushes a snapshot of [posts] into the current paging source, invalidating it. */
    private fun updatePagingSource() {
        println("updatePagingSource")

        pagingSource.posts = posts
    }

    /** Replaces all stored posts with [newPosts]. */
    @Transaction
    override suspend fun refreshPosts(newPosts: List<PostEntity>) {
        println("refreshPosts")
        clearAll()
        insertPosts(newPosts)
    }
}
4

2 个回答

4

用于测试 PagingData 的工具函数（Kotlin，Paging 版本：3.0.0-rc01）。感谢法里德。

   /**
    * Drains this [PagingData] through a throw-away [PagingDataDiffer] and returns the loaded
    * items of every list the differ presents (Paging 3.0.0-rc01 signature of presentNewList).
    *
    * Placeholders (the null padding before/after loaded items) are skipped; only items held in
    * storage are returned.
    *
    * NOTE(review): collectFrom() suspends while the PagingData's event stream is active. For
    * PagingData emitted by a live Pager, that may not end by itself — confirm in your test setup.
    */
   private suspend fun <T : Any> PagingData<T>.collectDataForTest(): List<T> {
        // The differ requires a callback, but only the presented lists matter here.
        val dcb = object : DifferCallback {
            override fun onChanged(position: Int, count: Int) {}
            override fun onInserted(position: Int, count: Int) {}
            override fun onRemoved(position: Int, count: Int) {}
        }
        val items = mutableListOf<T>()
        val dif = object : PagingDataDiffer<T>(dcb, TestCoroutineDispatcher()) {
            override suspend fun presentNewList(
                previousList: NullPaddedList<T>,
                newList: NullPaddedList<T>,
                newCombinedLoadStates: CombinedLoadStates,
                lastAccessedIndex: Int,
                onListPresentable: () -> Unit
            ): Int? {
                // `size` also counts placeholders, but getFromStorage() only accepts indices in
                // 0 until storageCount. Iterating to `size` overruns when placeholders exist.
                for (idx in 0 until newList.storageCount) {
                    items.add(newList.getFromStorage(idx))
                }
                onListPresentable()
                return null
            }
        }
        dif.collectFrom(this)
        return items
    }

用法：

    // searchHistoryList is a Flow<PagingData<YourDataType>>; keep only its first emission.
    val firstPagingData = useCase.searchHistoryList.take(1).toList().first()
    // Materialize that PagingData into a plain List<YourDataType> and compare it.
    val result = firstPagingData.collectDataForTest()
    assertEquals(expect, result)
于 2021-05-05T05:12:44.353 回答
2

8 月 28 日更新：

这里有一个专门讨论分页测试最佳实践的完整主题。

旧答案：

查看 Paging 库的源代码可以发现，只有在开始从 PagingData 收集数据之后，才会开始从 DAO 或远程调解器（RemoteMediator）获取数据。我找到了一个可用于从 PagingData 收集数据的工具函数：

/**
 * Drains this [PagingData] through a throw-away [PagingDataDiffer] and returns the loaded
 * items of every list the differ presents. Collecting is what makes Paging start loading.
 *
 * Placeholders (the null padding before/after loaded items) are skipped; only items held in
 * storage are returned.
 *
 * NOTE(review): this override matches a Paging 3 pre-release whose `presentNewList` has no
 * `onListPresentable` parameter. On 3.0.0-rc01 the override also takes that callback and must
 * invoke it. Other versions differ again, so check against the Paging version in use.
 */
suspend fun <T : Any> PagingData<T>.collectData(): List<T> {
    // The differ requires a callback, but only the presented lists matter here.
    val dcb = object : DifferCallback {
        override fun onChanged(position: Int, count: Int) {}
        override fun onInserted(position: Int, count: Int) {}
        override fun onRemoved(position: Int, count: Int) {}
    }
    val items = mutableListOf<T>()
    val dif = object : PagingDataDiffer<T>(dcb, TestCoroutineDispatcher()) {
        override suspend fun presentNewList(
            previousList: NullPaddedList<T>,
            newList: NullPaddedList<T>,
            newCombinedLoadStates: CombinedLoadStates,
            lastAccessedIndex: Int
        ): Int? {
            // `size` also counts placeholders, but getFromStorage() only accepts indices in
            // 0 until storageCount. Iterating to `size` overruns when placeholders exist.
            for (idx in 0 until newList.storageCount) {
                items.add(newList.getFromStorage(idx))
            }
            return null
        }
    }
    dif.collectFrom(this)
    return items
}

可以像这样使用它：

// Collect each PagingData emitted by the repository, materialize it, and assert on it.
// NOTE(review): the Pager flow behind loadPosts() does not complete by itself, so this
// `collect` keeps running after the assertion. Limiting it to the first emission
// (e.g. with the `take(1)` Flow operator) lets the test finish.
postsRepository.loadPosts().collect { pagingData ->
        val posts = pagingData.collectData ()
        //THEN: retrieved posts should be the remotePosts
        assertThat(posts, IsEqual(domainPosts))
    }

我曾尝试编写这样的测试，但我认为这不是测试分页的最佳方式，因为这样做测试的并不是你自己的代码，而是 Paging 库本身。更好的做法是分别测试你的 DAO、PagingSource 以及映射函数（mapper）。

于 2021-03-18T07:58:58.897 回答