diff --git a/Rakefile b/Rakefile index e69de29..debc11c 100644 --- a/Rakefile +++ b/Rakefile @@ -0,0 +1,8 @@ +require 'rake/testtask' + +Rake::TestTask.new do |t| + t.libs << 'test' +end + +desc "Run tests" +task :default => :test diff --git a/lib/activerecord-import/adapters/sqlserver_adapter.rb b/lib/activerecord-import/adapters/sqlserver_adapter.rb index c8520d2..6f57cb0 100644 --- a/lib/activerecord-import/adapters/sqlserver_adapter.rb +++ b/lib/activerecord-import/adapters/sqlserver_adapter.rb @@ -8,13 +8,10 @@ def insert_many( sql, values, options = {}, *args ) [sql.shift, sql.join( ' ' )] end - columns_names = base_sql.match(/INSERT INTO (\[.*\]) (\(.*\)) VALUES /)[2][1..-1].split(',') - sql_id_index = columns_names.index('[id]') - sql_noid = if sql_id_index.nil? - nil - else - (sql_id_index == (columns_names.length - 1) ? base_sql.clone.gsub(/\[id\]/, '') : base_sql.clone.gsub(/\[id\],/, '')) - end + column_override = get_identity_column_name options + columns_names = parse_column_names_from_sql base_sql + sql_id_index = columns_names.index("[#{column_override}]") + sql_noid = get_sql_noid sql_id_index, columns_names, base_sql, column_override max = max_allowed_packet @@ -25,7 +22,7 @@ def insert_many( sql, values, options = {}, *args ) supplied_ids = [] batch.each do |value| - values_sql = value[1..-2].split(',') + values_sql = values_to_array(value) if values_sql[sql_id_index] == "NULL" values_sql.delete_at(sql_id_index) null_ids << "(#{values_sql.join(',')})" @@ -57,4 +54,44 @@ def insert_many( sql, values, options = {}, *args ) def max_allowed_packet 1000 end + + def get_identity_column_name( options ) + options[:id_column_name] || 'id' + end + + def get_sql_noid(sql_id_index, columns_names, base_sql, column_override) + if sql_id_index.nil? + nil + else + (sql_id_index == (columns_names.length - 1) ? base_sql.clone.gsub(/,\[#{column_override}\]/, '') : base_sql.clone.gsub(/\[#{column_override}\],/, '')) + end + end + + # This can be removed, it's just here to show the old way and compare with the new way + def get_sql_noid_OLD(sql_id_index, columns_names, base_sql, column_override) + if sql_id_index.nil? + nil + else + (sql_id_index == (columns_names.length - 1) ? base_sql.clone.gsub(/\[id\]/, '') : base_sql.clone.gsub(/\[id\],/, '')) + end + end + + def parse_column_names_from_sql( sql ) + sql.match(/(?<=\().*(?=\).*VALUES)/)[0].split(',') + end + + # This can be removed, it's just here to show the old way and compare with the new way + def parse_column_names_from_sql_OLD( sql ) + sql.match(/INSERT INTO (\[.*\]) (\(.*\)) VALUES /)[2][1..-1].split(',') + end + + def values_to_array( value ) + value[1..-2].scan(/N\'.*?\'|[^,]+/) + end + + # This can be removed, it's just here to show the old way and compare with the new way + def values_to_array_OLD( value ) + value[1..-2].split(',') + end + end diff --git a/test/test_sqlserver_adapter.rb b/test/test_sqlserver_adapter.rb new file mode 100644 index 0000000..2391c27 --- /dev/null +++ b/test/test_sqlserver_adapter.rb @@ -0,0 +1,108 @@ +require 'minitest/autorun' +require 'active_record' +require 'activerecord-import' +require 'activerecord-import/active_record/adapters/sqlserver_adapter' + +class SqlServerAdapterTest < Minitest::Test + + def test_get_identity_column_name + options = {} + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_identity_column_name(options) + assert_equal 'id', result + + options[:id_column_name] = 'foobar' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_identity_column_name(options) + assert_equal 'foobar', result + end + + def test_get_sql_noid + sql_id_index = nil + columns_names = nil + base_sql = nil + column_override = nil + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid(sql_id_index, columns_names, base_sql, column_override) + assert_nil result + + sql_id_index = 0 + columns_names = ['id', 'col1','col2','col3','col4','col5'] + base_sql = 'INSERT INTO [schema].[table] ([id],[col1],[col2],[col3],[col4],[col5]) VALUES ' + column_override = 'id' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid(sql_id_index, columns_names, base_sql, column_override) + assert_equal 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ', result + + sql_id_index = 1 + columns_names = ['col1','id','col2','col3','col4','col5'] + base_sql = 'INSERT INTO [schema].[table] ([col1],[id],[col2],[col3],[col4],[col5]) VALUES ' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid(sql_id_index, columns_names, base_sql, column_override) + assert_equal 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ', result + + sql_id_index = 5 + columns_names = ['col1','col2','col3','col4','col5','id'] + base_sql = 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5],[id]) VALUES ' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid(sql_id_index, columns_names, base_sql, column_override) + assert_equal 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ', result + end + + def test_get_sql_noid_OLD_AND_BROKEN + sql_id_index = nil + columns_names = nil + base_sql = nil + column_override = nil + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid_OLD(sql_id_index, columns_names, base_sql, column_override) + assert_nil result + + sql_id_index = 0 + columns_names = ['id', 'col1','col2','col3','col4','col5'] + base_sql = 'INSERT INTO [schema].[table] ([id],[col1],[col2],[col3],[col4],[col5]) VALUES ' + column_override = 'id' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid_OLD(sql_id_index, columns_names, base_sql, column_override) + assert_equal 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ', result + + sql_id_index = 1 + columns_names = ['col1','id','col2','col3','col4','col5'] + base_sql = 'INSERT INTO [schema].[table] ([col1],[id],[col2],[col3],[col4],[col5]) VALUES ' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid_OLD(sql_id_index, columns_names, base_sql, column_override) + assert_equal 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ', result + + sql_id_index = 5 + columns_names = ['col1','col2','col3','col4','col5','id'] + base_sql = 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5],[id]) VALUES ' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).get_sql_noid_OLD(sql_id_index, columns_names, base_sql, column_override) + # Notice that there is a comma at the end of the columns that shouldn't be there + assert_equal 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5],) VALUES ', result + end + + def test_values_to_array + base_string = "(N'firstcolval',N'second col val',N'third, col & value, with commas and stuff',0,true)" + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).values_to_array base_string + expected = ["N'firstcolval'", "N'second col val'", "N'third, col & value, with commas and stuff'", "0", "true"] + assert_equal expected, result + end + + def test_values_to_array_OLD_AND_BROKEN + base_string = "(N'firstcolval',N'second col val',N'third, col & value, with commas and stuff',0,true)" + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).values_to_array_OLD base_string + + # Notice that the expected values using the old (split) method does weird things when the data contains commas + expected = ["N'firstcolval'", "N'second col val'", "N'third"," col & value"," with commas and stuff'", "0", "true"] + assert_equal expected, result + end + + def test_parse_column_names + sql = 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).parse_column_names_from_sql sql + expected = ['[col1]', '[col2]','[col3]','[col4]','[col5]'] + assert_equal expected, result + end + + def test_parse_column_names_OLD_AND_BROKEN + sql = 'INSERT INTO [schema].[table] ([col1],[col2],[col3],[col4],[col5]) VALUES ' + result = Class.new.extend(ActiveRecord::Import::SQLServerAdapter).parse_column_names_from_sql_OLD sql + + #Notice that there is a trailing ) here, which is bad! + expected = ['[col1]','[col2]','[col3]','[col4]','[col5])'] + assert_equal expected, result + end + + +end \ No newline at end of file