Skip to content

Commit 395c608

Browse files
committed
Fixed #12460 -- Improved inspectdb handling of special field names
Thanks mihail lukin for the report and elijahr and kgibula for their work on the patch.
1 parent 10d3207 commit 395c608

File tree

3 files changed

+109
-51
lines changed

3 files changed

+109
-51
lines changed

django/core/management/commands/inspectdb.py

Lines changed: 78 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import unicode_literals
2+
13
import keyword
24
from optparse import make_option
35

@@ -31,6 +33,7 @@ def handle_inspection(self, options):
3133
table_name_filter = options.get('table_name_filter')
3234

3335
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
36+
strip_prefix = lambda s: s.startswith("u'") and s[1:] or s
3437

3538
cursor = connection.cursor()
3639
yield "# This is an auto-generated Django model module."
@@ -41,6 +44,7 @@ def handle_inspection(self, options):
4144
yield "#"
4245
yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
4346
yield "# into your database."
47+
yield "from __future__ import unicode_literals"
4448
yield ''
4549
yield 'from %s import models' % self.db_module
4650
yield ''
@@ -59,16 +63,19 @@ def handle_inspection(self, options):
5963
indexes = connection.introspection.get_indexes(cursor, table_name)
6064
except NotImplementedError:
6165
indexes = {}
66+
used_column_names = [] # Holds column names used in the table so far
6267
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
63-
column_name = row[0]
64-
att_name = column_name.lower()
6568
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
6669
extra_params = {} # Holds Field parameters such as 'db_column'.
70+
column_name = row[0]
71+
is_relation = i in relations
72+
73+
att_name, params, notes = self.normalize_col_name(
74+
column_name, used_column_names, is_relation)
75+
extra_params.update(params)
76+
comment_notes.extend(notes)
6777

68-
# If the column name can't be used verbatim as a Python
69-
# attribute, set the "db_column" for this Field.
70-
if ' ' in att_name or '-' in att_name or keyword.iskeyword(att_name) or column_name != att_name:
71-
extra_params['db_column'] = column_name
78+
used_column_names.append(att_name)
7279

7380
# Add primary_key and unique, if necessary.
7481
if column_name in indexes:
@@ -77,30 +84,12 @@ def handle_inspection(self, options):
7784
elif indexes[column_name]['unique']:
7885
extra_params['unique'] = True
7986

80-
# Modify the field name to make it Python-compatible.
81-
if ' ' in att_name:
82-
att_name = att_name.replace(' ', '_')
83-
comment_notes.append('Field renamed to remove spaces.')
84-
85-
if '-' in att_name:
86-
att_name = att_name.replace('-', '_')
87-
comment_notes.append('Field renamed to remove dashes.')
88-
89-
if column_name != att_name:
90-
comment_notes.append('Field name made lowercase.')
91-
92-
if i in relations:
87+
if is_relation:
9388
rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])
94-
9589
if rel_to in known_models:
9690
field_type = 'ForeignKey(%s' % rel_to
9791
else:
9892
field_type = "ForeignKey('%s'" % rel_to
99-
100-
if att_name.endswith('_id'):
101-
att_name = att_name[:-3]
102-
else:
103-
extra_params['db_column'] = column_name
10493
else:
10594
# Calling `get_field_type` to get the field type string and any
10695
# additional paramters and notes.
@@ -110,16 +99,6 @@ def handle_inspection(self, options):
11099

111100
field_type += '('
112101

113-
if keyword.iskeyword(att_name):
114-
att_name += '_field'
115-
comment_notes.append('Field renamed because it was a Python reserved word.')
116-
117-
if att_name[0].isdigit():
118-
att_name = 'number_%s' % att_name
119-
extra_params['db_column'] = six.text_type(column_name)
120-
comment_notes.append("Field renamed because it wasn't a "
121-
"valid Python identifier.")
122-
123102
# Don't output 'id = meta.AutoField(primary_key=True)', because
124103
# that's assumed if it doesn't exist.
125104
if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
@@ -136,14 +115,74 @@ def handle_inspection(self, options):
136115
if extra_params:
137116
if not field_desc.endswith('('):
138117
field_desc += ', '
139-
field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()])
118+
field_desc += ', '.join([
119+
'%s=%s' % (k, strip_prefix(repr(v)))
120+
for k, v in extra_params.items()])
140121
field_desc += ')'
141122
if comment_notes:
142123
field_desc += ' # ' + ' '.join(comment_notes)
143124
yield ' %s' % field_desc
144125
for meta_line in self.get_meta(table_name):
145126
yield meta_line
146127

128+
def normalize_col_name(self, col_name, used_column_names, is_relation):
129+
"""
130+
Modify the column name to make it Python-compatible as a field name
131+
"""
132+
field_params = {}
133+
field_notes = []
134+
135+
new_name = col_name.lower()
136+
if new_name != col_name:
137+
field_notes.append('Field name made lowercase.')
138+
139+
if is_relation:
140+
if new_name.endswith('_id'):
141+
new_name = new_name[:-3]
142+
else:
143+
field_params['db_column'] = col_name
144+
145+
if ' ' in new_name:
146+
new_name = new_name.replace(' ', '_')
147+
field_notes.append('Field renamed to remove spaces.')
148+
149+
if '-' in new_name:
150+
new_name = new_name.replace('-', '_')
151+
field_notes.append('Field renamed to remove dashes.')
152+
153+
if new_name.find('__') >= 0:
154+
while new_name.find('__') >= 0:
155+
new_name = new_name.replace('__', '_')
156+
field_notes.append("Field renamed because it contained more than one '_' in a row.")
157+
158+
if new_name.startswith('_'):
159+
new_name = 'field%s' % new_name
160+
field_notes.append("Field renamed because it started with '_'.")
161+
162+
if new_name.endswith('_'):
163+
new_name = '%sfield' % new_name
164+
field_notes.append("Field renamed because it ended with '_'.")
165+
166+
if keyword.iskeyword(new_name):
167+
new_name += '_field'
168+
field_notes.append('Field renamed because it was a Python reserved word.')
169+
170+
if new_name[0].isdigit():
171+
new_name = 'number_%s' % new_name
172+
field_notes.append("Field renamed because it wasn't a valid Python identifier.")
173+
174+
if new_name in used_column_names:
175+
num = 0
176+
while '%s_%d' % (new_name, num) in used_column_names:
177+
num += 1
178+
new_name = '%s_%d' % (new_name, num)
179+
field_notes.append('Field renamed because of name conflict.')
180+
181+
if col_name != new_name and field_notes:
182+
field_params['db_column'] = col_name
183+
184+
return new_name, field_params, field_notes
185+
147186
def get_field_type(self, connection, table_name, row):
148187
"""
149188
Given the database connection, the table name, and the cursor row
@@ -181,6 +220,6 @@ def get_meta(self, table_name):
181220
to construct the inner Meta class for the model corresponding
182221
to the given database table name.
183222
"""
184-
return [' class Meta:',
185-
' db_table = %r' % table_name,
186-
'']
223+
return [" class Meta:",
224+
" db_table = '%s'" % table_name,
225+
""]

tests/regressiontests/inspectdb/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@ class DigitsInColumnName(models.Model):
1919
all_digits = models.CharField(max_length=11, db_column='123')
2020
leading_digit = models.CharField(max_length=11, db_column='4extra')
2121
leading_digits = models.CharField(max_length=11, db_column='45extra')
22+
23+
class UnderscoresInColumnName(models.Model):
24+
field = models.IntegerField(db_column='field')
25+
field_field_0 = models.IntegerField(db_column='Field_')
26+
field_field_1 = models.IntegerField(db_column='Field__')
27+
field_field_2 = models.IntegerField(db_column='__field')

tests/regressiontests/inspectdb/tests.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import unicode_literals
2+
13
from django.core.management import call_command
24
from django.test import TestCase, skipUnlessDBFeature
35
from django.utils.six import StringIO
@@ -17,7 +19,6 @@ def test_stealth_table_name_filter_option(self):
1719
# the Django test suite, check that one of its tables hasn't been
1820
# inspected
1921
self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message)
20-
out.close()
2122

2223
@skipUnlessDBFeature('can_introspect_foreign_keys')
2324
def test_attribute_name_not_python_keyword(self):
@@ -27,15 +28,16 @@ def test_attribute_name_not_python_keyword(self):
2728
call_command('inspectdb',
2829
table_name_filter=lambda tn:tn.startswith('inspectdb_'),
2930
stdout=out)
31+
output = out.getvalue()
3032
error_message = "inspectdb generated an attribute name which is a python keyword"
31-
self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", out.getvalue(), msg=error_message)
33+
self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", output, msg=error_message)
3234
# As InspectdbPeople model is defined after InspectdbMessage, it should be quoted
33-
self.assertIn("from_field = models.ForeignKey('InspectdbPeople')", out.getvalue())
35+
self.assertIn("from_field = models.ForeignKey('InspectdbPeople', db_column='from_id')",
36+
output)
3437
self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)",
35-
out.getvalue())
38+
output)
3639
self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)",
37-
out.getvalue())
38-
out.close()
40+
output)
3941

4042
def test_digits_column_name_introspection(self):
4143
"""Introspection of column names consist/start with digits (#16536/#17676)"""
@@ -45,13 +47,24 @@ def test_digits_column_name_introspection(self):
4547
call_command('inspectdb',
4648
table_name_filter=lambda tn:tn.startswith('inspectdb_'),
4749
stdout=out)
50+
output = out.getvalue()
4851
error_message = "inspectdb generated a model field name which is a number"
49-
self.assertNotIn(" 123 = models.CharField", out.getvalue(), msg=error_message)
50-
self.assertIn("number_123 = models.CharField", out.getvalue())
52+
self.assertNotIn(" 123 = models.CharField", output, msg=error_message)
53+
self.assertIn("number_123 = models.CharField", output)
5154

5255
error_message = "inspectdb generated a model field name which starts with a digit"
53-
self.assertNotIn(" 4extra = models.CharField", out.getvalue(), msg=error_message)
54-
self.assertIn("number_4extra = models.CharField", out.getvalue())
56+
self.assertNotIn(" 4extra = models.CharField", output, msg=error_message)
57+
self.assertIn("number_4extra = models.CharField", output)
58+
59+
self.assertNotIn(" 45extra = models.CharField", output, msg=error_message)
60+
self.assertIn("number_45extra = models.CharField", output)
5561

56-
self.assertNotIn(" 45extra = models.CharField", out.getvalue(), msg=error_message)
57-
self.assertIn("number_45extra = models.CharField", out.getvalue())
62+
def test_underscores_column_name_introspection(self):
63+
"""Introspection of column names containing underscores (#12460)"""
64+
out = StringIO()
65+
call_command('inspectdb', stdout=out)
66+
output = out.getvalue()
67+
self.assertIn("field = models.IntegerField()", output)
68+
self.assertIn("field_field = models.IntegerField(db_column='Field_')", output)
69+
self.assertIn("field_field_0 = models.IntegerField(db_column='Field__')", output)
70+
self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)

0 commit comments

Comments
 (0)