缘起:TABLESAMPLE 的非随机性
最近需要实现一段 Spark SQL 逻辑,对数据集进行抽样指定的行数,才发现直接使用TABLESAMPLE函数抽样指定行数的方法其实是非随机的。
由于数据集较大,刚开始的逻辑是,取窗口函数随机排序后 row_number 的前 n 行。但运行速度较慢,所以想起了 TABLESAMLE 函数,支持直接取 Rows, 尝试后发现速度特别快,基本上几秒内就完成对亿级数据的采样。所以好奇就去查看文档和代码逻辑。
The TABLESAMPLE statement is used to sample the table. It supports the following sampling methods:
TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages are defined as a number between 0 and 100.
TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a x out of y fraction.
Note: TABLESAMPLE returns the approximate number of rows or fraction requested.
文档中没有对实现逻辑有过多的说明,所以去代码中找问题。
源码中,匹配 SampleByRowsContext时,调用的方法是 Limit(expression(ctx.expression), query),也就是说和 limit rows是一个逻辑。
而 SampleByPercentileContext实现的才是随机采样。
所以,如果对抽样的随机性有要求,还是老老实实用 SampleByPercentileContext,或者窗口函数。
总结:Spark SQL 随机抽样方法
随机抽样
抽取固定数量
使用窗口函数 + 随机排序进行抽样
WITH RankedData
AS (
SELECT *,
row_number()
OVER (
ORDER BY rand(
2077))
as rn
FROM your_table)
SELECT *FROM RankedData
WHERE rn
<= 1000
抽取固定比例
直接使用TABLESAMPLE函数,实现对整体的固定比例抽样
SELECT *FROM your_table
TABLESAMPLE (
10 PERCENT)
分层随机抽样
分层抽样通常在数据科学中使用较多,为了保证样本的随机性,通常情况下,我们需要对 y标签进行分层抽样;如果考虑时间因素的影响,为了保证样本时间的随机性,通常还需要对月份 + y标签或者日期 + y标签进行双层的分层抽样。
抽取固定数量
WITH RankedData
AS (
SELECT *,
row_number()
OVER (
PARTITION BY 分层字段
ORDER BY rand(
2077))
as rn
FROM your_table)
SELECT *FROM RankedData
WHERE rn
<= 100 -- 每层抽取100条数据
抽取固定比例
对单字段分层
通常用在对时间的随机性要求不严格的场景,如在二分类任务中,可以将分桶字段设置为y列,那么就可以保证最终抽样出来的样本y均值和总体y均值相等:
WITH RankedData
AS (
SELECT *,
row_number()
OVER (
PARTITION BY 分层字段
ORDER BY rand())
as rn,
count(
*)
OVER (
PARTITION BY 分层字段)
as total_count
FROM your_table)
SELECT *FROM RankedData
WHERE rn
<= total_count
* 0.1 -- 每层抽取10%的数据
对双字段分层
通常用在对时间的随机性也有严格要求的场景,这时可以将分层字段1和分层字段2分别设置为y列和时间列,那么就可以保证样本逐时间的y分布和整体随时间的y分布是近似的:
WITH RankedData
AS (
SELECT *,
row_number()
OVER (
PARTITION BY 分层字段
1, 分层字段
2 ORDER BY rand())
as rn,
count(
*)
OVER (
PARTITION BY 分层字段
1, 分层字段
2)
as total_count
FROM your_table)
SELECT *FROM RankedData
WHERE rn
<= total_count
* 0.1 -- 每层抽取10%的数据
附 相关源码:
/** * Add a [[Sample]] to a logical plan. * * This currently supports the following sampling methods: * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. * - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with * seed y. Note that percentages are defined as a number between 0 and 100. * - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to a x divided by * y fraction with seed z. */ private def
withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Create a sampled plan if we need one. def
sample(fraction: Double, seed: Long): Sample = {
// The range of fraction accepted by Sample is [0, 1]. Because Hives block sampling // function takes X PERCENT as the input and the range of X is [0, 100], we need to // adjust the fraction. val eps = RandomSampler.roundingEpsilon
validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) Sample(
0.0, fraction, withReplacement =
false, seed, query) }
if (ctx.sampleMethod() ==
null) {
throw QueryParsingErrors.emptyInputForTableSampleError(ctx) }
val seed = if (ctx.seed !=
null) { ctx.seed.getText.toLong }
else { (math.random() *
1000).toLong } ctx.sampleMethod() match {
case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query)
case ctx: SampleByPercentileContext =>
val fraction = ctx.percentage.getText.toDouble
val sign = if (ctx.negativeSign ==
null)
1 else -
1 sample(sign * fraction /
100.0d, seed)
case ctx: SampleByBytesContext =>
val bytesStr = ctx.bytes.getText
if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
"byteLengthLiteral", ctx) }
else {
throw QueryParsingErrors.invalidByteLengthLiteralError(bytesStr, ctx) }
case ctx: SampleByBucketContext
if ctx.ON() !=
null =>
if (ctx.identifier !=
null) {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
"BUCKET x OUT OF y ON colname", ctx) }
else {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
"BUCKET x OUT OF y ON function", ctx) }
case ctx: SampleByBucketContext => sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed) } }