Skip to content

Commit 5c37eb7

Browse files
author
James Lee
committed
add typed dataset
1 parent 8679f5e commit 5c37eb7

File tree

3 files changed

+135
-1
lines changed

3 files changed

+135
-1
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package com.sparkTutorial.sparkSql;
2+
3+
import java.io.Serializable;
4+
5+
/**
 * One survey response, exposed as a JavaBean: a public no-argument constructor
 * plus a getter/setter pair per property.
 *
 * <p>The numeric properties are primitive {@code float}s and therefore cannot be
 * {@code null}; callers that need to represent a missing value must choose their
 * own sentinel.
 */
public class Response implements Serializable {

    /** Pins the serialized form so recompiling the class does not change its id. */
    private static final long serialVersionUID = 1L;

    private String country;
    private float ageMidPoint;
    private String occupation;
    private float salaryMidPoint;

    /**
     * Creates a fully populated response.
     *
     * @param country        country of the respondent
     * @param ageMidPoint    age mid point of the respondent
     * @param occupation     occupation of the respondent
     * @param salaryMidPoint salary mid point of the respondent
     */
    public Response(String country, float ageMidPoint, String occupation, float salaryMidPoint) {
        this.country = country;
        this.ageMidPoint = ageMidPoint;
        this.occupation = occupation;
        this.salaryMidPoint = salaryMidPoint;
    }

    /** Creates an empty response; properties are populated through the setters. */
    public Response() {
    }

    /** @return the country of the respondent, possibly {@code null} */
    public String getCountry() {
        return country;
    }

    /** @param country the country of the respondent */
    public void setCountry(String country) {
        this.country = country;
    }

    /** @return the age mid point */
    public float getAgeMidPoint() {
        return ageMidPoint;
    }

    /** @param ageMidPoint the age mid point */
    public void setAgeMidPoint(float ageMidPoint) {
        this.ageMidPoint = ageMidPoint;
    }

    /** @return the occupation of the respondent, possibly {@code null} */
    public String getOccupation() {
        return occupation;
    }

    /** @param occupation the occupation of the respondent */
    public void setOccupation(String occupation) {
        this.occupation = occupation;
    }

    /** @return the salary mid point */
    public float getSalaryMidPoint() {
        return salaryMidPoint;
    }

    /** @param salaryMidPoint the salary mid point */
    public void setSalaryMidPoint(float salaryMidPoint) {
        this.salaryMidPoint = salaryMidPoint;
    }

    /**
     * Compares two responses property by property.
     *
     * @param other the object to compare with
     * @return {@code true} when {@code other} is a {@code Response} with equal properties
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Response)) {
            return false;
        }
        Response that = (Response) other;
        // Float.compare keeps equals consistent with the boxed hashing used in hashCode.
        return Float.compare(ageMidPoint, that.ageMidPoint) == 0
                && Float.compare(salaryMidPoint, that.salaryMidPoint) == 0
                && java.util.Objects.equals(country, that.country)
                && java.util.Objects.equals(occupation, that.occupation);
    }

    /** @return a hash code derived from all four properties */
    @Override
    public int hashCode() {
        return java.util.Objects.hash(country, ageMidPoint, occupation, salaryMidPoint);
    }

    /** @return a readable rendering of all four properties */
    @Override
    public String toString() {
        return "Response{country=" + country
                + ", ageMidPoint=" + ageMidPoint
                + ", occupation=" + occupation
                + ", salaryMidPoint=" + salaryMidPoint
                + '}';
    }
}

src/main/java/com/sparkTutorial/sparkSql/StackOverFlowSurvey.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class StackOverFlowSurvey {
1414

1515
private static final String AGE_MIDPOINT = "age_midpoint";
1616
private static final String SALARY_MIDPOINT = "salary_midpoint";
17-
public static final String SALARY_MIDPOINT_BUCKET = "salary_midpoint_bucket";
17+
private static final String SALARY_MIDPOINT_BUCKET = "salary_midpoint_bucket";
1818

1919
public static void main(String[] args) throws Exception {
2020

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package com.sparkTutorial.sparkSql;
2+
3+
import org.apache.log4j.Level;
4+
import org.apache.log4j.Logger;
5+
import org.apache.spark.SparkConf;
6+
import org.apache.spark.api.java.JavaRDD;
7+
import org.apache.spark.api.java.JavaSparkContext;
8+
import org.apache.spark.sql.Dataset;
9+
import org.apache.spark.sql.Encoders;
10+
import org.apache.spark.sql.SparkSession;
11+
12+
import static org.apache.spark.sql.functions.avg;
13+
import static org.apache.spark.sql.functions.max;
14+
15+
16+
public class TypedDataset {
17+
private static final String AGE_MIDPOINT = "ageMidpoint";
18+
private static final String SALARY_MIDPOINT = "salaryMidPoint";
19+
private static final String SALARY_MIDPOINT_BUCKET = "salaryMidpointBucket";
20+
private static final float NULL_VALUE = -1.0f;
21+
private static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)";
22+
23+
public static void main(String[] args) throws Exception {
24+
25+
Logger.getLogger("org").setLevel(Level.ERROR);
26+
SparkConf conf = new SparkConf().setAppName("StackOverFlowSurvey").setMaster("local[1]");
27+
28+
JavaSparkContext sc = new JavaSparkContext(conf);
29+
30+
SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate();
31+
32+
JavaRDD<String> lines = sc.textFile("in/2016-stack-overflow-survey-responses.csv");
33+
34+
JavaRDD<Response> responseRDD = lines
35+
.filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country"))
36+
.map(line -> {
37+
String[] splits = line.split(COMMA_DELIMITER, -1);
38+
return new Response(splits[2], convertStringToFloat(splits[6]), splits[9], convertStringToFloat(splits[14]));
39+
});
40+
Dataset<Response> responseDataset = session.createDataset(responseRDD.rdd(), Encoders.bean(Response.class));
41+
42+
System.out.println("=== Print out schema ===");
43+
responseDataset.printSchema();
44+
45+
System.out.println("=== Print 20 records of responses table ===");
46+
responseDataset.show(20);
47+
48+
System.out.println("=== Print records where the response is from Afghanistan ===");
49+
responseDataset.filter(response -> response.getCountry().equals("Afghanistan")).show();
50+
51+
System.out.println("=== Print the count of occupations ===");
52+
responseDataset.groupBy(responseDataset.col("occupation")).count().show();
53+
54+
55+
System.out.println("=== Print records with average mid age less than 20 ===");
56+
responseDataset.filter(response -> response.getAgeMidPoint() != NULL_VALUE && response.getAgeMidPoint() < 20).show();
57+
58+
System.out.println("=== Print the result with salary middle point in descending order ===");
59+
responseDataset.orderBy(responseDataset.col(SALARY_MIDPOINT ).desc()).show();
60+
61+
System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ===");
62+
responseDataset
63+
.filter(response -> response.getSalaryMidPoint() != NULL_VALUE)
64+
.groupBy("country")
65+
.agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT))
66+
.show();
67+
68+
System.out.println("=== Group by salary bucket ===");
69+
70+
responseDataset
71+
.map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT())
72+
.withColumnRenamed("value", SALARY_MIDPOINT_BUCKET)
73+
.groupBy(SALARY_MIDPOINT_BUCKET)
74+
.count()
75+
.orderBy(SALARY_MIDPOINT_BUCKET).show();
76+
}
77+
78+
private static float convertStringToFloat(String split) {
79+
return split.isEmpty() ? NULL_VALUE : Float.valueOf(split);
80+
}
81+
82+
}

0 commit comments

Comments
 (0)