Skip to content

Commit a4f93c7

Browse files
committed
Implement case-insensitive key comparison for csvjoin
1 parent ffe5f44 commit a4f93c7

File tree

4 files changed

+108
-26
lines changed

4 files changed

+108
-26
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ docs/_build
1010
.coverage
1111
.tox
1212
cover
13+
env

csvkit/join.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
#!/usr/bin/env python
22

33

4-
def _get_ordered_keys(rows, column_index):
4+
def _get_keys(rows, column_index, lowercase=False):
55
"""
6-
Get ordered keys from rows, given the key column index.
6+
Get keys from rows as keys in a dictionary (i.e. unordered), given the key column index.
77
"""
8-
return [r[column_index] for r in rows]
8+
pairs = ((r[column_index], True) for r in rows)
9+
return CaseInsensitiveDict(pairs) if lowercase else dict(pairs)
910

1011

11-
def _get_mapped_keys(rows, column_index):
12-
mapped_keys = {}
12+
def _get_mapped_keys(rows, column_index, case_insensitive=False):
13+
mapped_keys = CaseInsensitiveDict() if case_insensitive else {}
1314

1415
for r in rows:
1516
key = r[column_index]
@@ -21,6 +22,11 @@ def _get_mapped_keys(rows, column_index):
2122

2223
return mapped_keys
2324

25+
def _lower(key):
26+
"""Transforms a string to lowercase, leaves other types alone."""
27+
keyfn = getattr(key, 'lower', None)
28+
return keyfn() if keyfn else key
29+
2430

2531
def sequential_join(left_rows, right_rows, header=True):
2632
"""
@@ -49,7 +55,7 @@ def sequential_join(left_rows, right_rows, header=True):
4955
return output
5056

5157

52-
def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
58+
def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
5359
"""
5460
Execute an inner join on two tables and return the combined table.
5561
"""
@@ -63,7 +69,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr
6369
output = []
6470

6571
# Map right rows to keys
66-
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
72+
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)
6773

6874
for left_row in left_rows:
6975
len_left_row = len(left_row)
@@ -80,7 +86,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr
8086
return output
8187

8288

83-
def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
89+
def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
8490
"""
8591
Execute full outer join on two tables and return the combined table.
8692
"""
@@ -94,11 +100,11 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
94100
else:
95101
output = []
96102

97-
# Get ordered keys
98-
left_ordered_keys = _get_ordered_keys(left_rows, left_column_id)
103+
# Get left keys
104+
left_keys = _get_keys(left_rows, left_column_id, ignore_case)
99105

100106
# Get mapped keys
101-
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
107+
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)
102108

103109
for left_row in left_rows:
104110
len_left_row = len(left_row)
@@ -116,13 +122,13 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
116122
for right_row in right_rows:
117123
right_key = right_row[right_column_id]
118124

119-
if right_key not in left_ordered_keys:
125+
if right_key not in left_keys:
120126
output.append(([u''] * len_left_headers) + right_row)
121127

122128
return output
123129

124130

125-
def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
131+
def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
126132
"""
127133
Execute left outer join on two tables and return the combined table.
128134
"""
@@ -137,7 +143,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
137143
output = []
138144

139145
# Get mapped keys
140-
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
146+
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)
141147

142148
for left_row in left_rows:
143149
len_left_row = len(left_row)
@@ -155,7 +161,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
155161
return output
156162

157163

158-
def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
164+
def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
159165
"""
160166
Execute right outer join on two tables and return the combined table.
161167
"""
@@ -168,11 +174,11 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea
168174
else:
169175
output = []
170176

171-
# Get ordered keys
172-
left_ordered_keys = _get_ordered_keys(left_rows, left_column_id)
177+
# Get left keys
178+
left_keys = _get_keys(left_rows, left_column_id, ignore_case)
173179

174180
# Get mapped keys
175-
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
181+
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)
176182

177183
for left_row in left_rows:
178184
len_left_row = len(left_row)
@@ -188,7 +194,47 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea
188194
for right_row in right_rows:
189195
right_key = right_row[right_column_id]
190196

191-
if right_key not in left_ordered_keys:
197+
if right_key not in left_keys:
192198
output.append(([u''] * len_left_headers) + right_row)
193199

194200
return output
201+
202+
203+
204+
class CaseInsensitiveDict(dict):
    """
    Dictionary whose keys are matched without regard to string case.

    Every key is passed through ``_lower`` before it is stored or looked up,
    so ``d[u'ABC']`` and ``d[u'abc']`` address the same entry. Keys are stored
    in their lowercased form; keys that have no ``lower`` method are used
    unchanged.

    Adapted from http://stackoverflow.com/a/32888599/1583437
    """
    def __init__(self, *args, **kwargs):
        # Let dict parse the arguments, then rewrite the stored keys.
        super(CaseInsensitiveDict, self).__init__(*args, **kwargs)
        self._convert_keys()

    def __getitem__(self, key):
        return super(CaseInsensitiveDict, self).__getitem__(_lower(key))

    def __setitem__(self, key, value):
        super(CaseInsensitiveDict, self).__setitem__(_lower(key), value)

    def __delitem__(self, key):
        return super(CaseInsensitiveDict, self).__delitem__(_lower(key))

    def __contains__(self, key):
        return super(CaseInsensitiveDict, self).__contains__(_lower(key))

    def pop(self, key, *args, **kwargs):
        return super(CaseInsensitiveDict, self).pop(_lower(key), *args, **kwargs)

    def get(self, key, *args, **kwargs):
        return super(CaseInsensitiveDict, self).get(_lower(key), *args, **kwargs)

    def setdefault(self, key, *args, **kwargs):
        return super(CaseInsensitiveDict, self).setdefault(_lower(key), *args, **kwargs)

    def update(self, single_arg=None, **kwargs):
        """
        Merge entries from ``single_arg`` (a mapping or an iterable of pairs)
        and then from ``kwargs``, normalizing every incoming key.
        """
        # Only build from single_arg when one was supplied: dict(None) raises
        # TypeError, which made update(a=1) and update() unusable.
        if single_arg is not None:
            super(CaseInsensitiveDict, self).update(self.__class__(single_arg))
        if kwargs:
            super(CaseInsensitiveDict, self).update(self.__class__(**kwargs))

    def _convert_keys(self):
        """
        Re-insert every stored entry under its normalized key.
        """
        # Snapshot and empty the dict first. Re-keying in place lets an
        # earlier entry (u'A') overwrite a later one that is already stored
        # under the normalized key (u'a'); replaying the items into an empty
        # dict makes the entry that comes last in iteration order win.
        items = list(self.items())
        self.clear()
        for k, v in items:
            self[k] = v

csvkit/utilities/csvjoin.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def add_arguments(self):
2222
help='Perform a left outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of left outer joins, starting at the left.')
2323
self.argparser.add_argument('--right', dest='right_join', action='store_true',
2424
help='Perform a right outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of right outer joins, starting at the right.')
25+
self.argparser.add_argument('--ignorecase', dest='ignore_case', action='store_true',
26+
help='Whether to ignore string case when comparing keys.')
2527

2628
def main(self):
2729
self.input_files = []
@@ -62,10 +64,11 @@ def main(self):
6264

6365
jointab = tables[0]
6466

67+
ignore_case = self.args.ignore_case
6568
if self.args.left_join:
6669
# Left outer join
6770
for i, t in enumerate(tables[1:]):
68-
jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
71+
jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
6972
elif self.args.right_join:
7073
# Right outer join
7174
jointab = tables[-1]
@@ -74,15 +77,15 @@ def main(self):
7477
remaining_tables.reverse()
7578

7679
for i, t in enumerate(remaining_tables):
77-
jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header)
80+
jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header, ignore_case=ignore_case)
7881
elif self.args.outer_join:
7982
# Full outer join
8083
for i, t in enumerate(tables[1:]):
81-
jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
84+
jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
8285
elif self.args.columns:
8386
# Inner join
8487
for i, t in enumerate(tables[1:]):
85-
jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
88+
jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
8689
else:
8790
# Sequential join
8891
for t in tables[1:]:

tests/test_join.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,23 @@ def setUp(self):
2525
[u'1', u'second', u'0'],
2626
[u'2', u'only', u'0', u'0']] # Note extra value in this column
2727

28-
def test_get_ordered_keys(self):
29-
self.assertEqual(join._get_ordered_keys(self.tab1[1:], 0), [u'1', u'2', u'3', u'1'])
30-
self.assertEqual(join._get_ordered_keys(self.tab2[1:], 0), [u'1', u'4', u'1', u'2'])
28+
def test_get_keys(self):
    """
    _get_keys should return the distinct values of the key column.
    """
    # Wrap the result in set(): comparing .keys() directly to a set only
    # works where keys() returns a set-like view (Python 3); a Python 2
    # key list never compares equal to a set.
    self.assertEqual(set(join._get_keys(self.tab1[1:], 0)), set([u'1', u'2', u'3', u'1']))
    self.assertEqual(set(join._get_keys(self.tab2[1:], 0)), set([u'1', u'4', u'1', u'2']))
3131

3232
def test_get_mapped_keys(self):
3333
self.assertEqual(join._get_mapped_keys(self.tab1[1:], 0), {
3434
u'1': [[u'1', u'Chicago Reader', u'first'], [u'1', u'Chicago Reader', u'second']],
3535
u'2': [[u'2', u'Chicago Sun-Times', u'only']],
3636
u'3': [[u'3', u'Chicago Tribune', u'only']]})
3737

38+
def test_get_mapped_keys_ignore_case(self):
    """
    Keys mapped with case_insensitive=True should match in any letter case.
    """
    mapped_keys = join._get_mapped_keys(self.tab1[1:], 1, case_insensitive=True)

    # unittest assertions instead of bare ``assert``: they still run under
    # ``python -O`` and report the offending key on failure.
    self.assertIn(u'Chicago Reader', mapped_keys)
    self.assertIn(u'chicago reader', mapped_keys)
    self.assertIn(u'CHICAGO SUN-TIMES', mapped_keys)
    # Column 1 holds names, so a value from the id column must not match.
    self.assertNotIn(u'1', mapped_keys)
44+
3845
def test_sequential_join(self):
3946
self.assertEqual(join.sequential_join(self.tab1, self.tab2), [
4047
['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'],
@@ -82,3 +89,28 @@ def test_right_outer_join(self):
8289
[u'1', u'Chicago Reader', u'second', u'1', u'first', u'0'],
8390
[u'1', u'Chicago Reader', u'second', u'1', u'second', u'0'],
8491
[u'', u'', u'', u'4', u'only', u'0']])
92+
93+
def test_right_outer_join_ignore_case(self):
    """
    Right outer join with ignore_case=True pairs ids that differ only in case.
    """
    # Right outer join exercises all the case dependencies
    left_table = [
        ['id', 'name', 'i_work_here'],
        [u'a', u'Chicago Reader', u'first'],
        [u'b', u'Chicago Sun-Times', u'only'],
        [u'c', u'Chicago Tribune', u'only'],
        [u'a', u'Chicago Reader', u'second'],
    ]

    right_table = [
        ['id', 'age', 'i_work_here'],
        [u'A', u'first', u'0'],
        [u'D', u'only', u'0'],
        [u'A', u'second', u'0'],
        [u'B', u'only', u'0', u'0'],  # Note extra value in this column
    ]

    expected = [
        ['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'],
        [u'a', u'Chicago Reader', u'first', u'A', u'first', u'0'],
        [u'a', u'Chicago Reader', u'first', u'A', u'second', u'0'],
        [u'b', u'Chicago Sun-Times', u'only', u'B', u'only', u'0', u'0'],
        [u'a', u'Chicago Reader', u'second', u'A', u'first', u'0'],
        [u'a', u'Chicago Reader', u'second', u'A', u'second', u'0'],
        [u'', u'', u'', u'D', u'only', u'0'],
    ]

    joined = join.right_outer_join(left_table, 0, right_table, 0, ignore_case=True)
    self.assertEqual(joined, expected)

0 commit comments

Comments
 (0)