Skip to content

Commit b7e81c1

Browse files
committed
[ty] diagnostic on overridden __setattr__ and __delattr__ in frozen dataclasses
astral-sh/ty#111
1 parent 285d641 commit b7e81c1

File tree

7 files changed

+273
-139
lines changed

7 files changed

+273
-139
lines changed

crates/ty/docs/rules.md

Lines changed: 106 additions & 74 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ frozen_instance = MyFrozenClass(1)
443443
frozen_instance.x = 2 # error: [invalid-assignment]
444444
```
445445

446-
If `__setattr__()` or `__delattr__()` is defined in the class, we should emit a diagnostic.
446+
If `__setattr__()` or `__delattr__()` is defined in the class, a diagnostic is emitted.
447447

448448
```py
449449
from dataclasses import dataclass
@@ -452,10 +452,10 @@ from dataclasses import dataclass
452452
class MyFrozenClass:
453453
x: int
454454

455-
# TODO: Emit a diagnostic here
455+
# error: [cannot-overwrite-attribute] "Cannot overwrite attribute `__setattr__` in class `MyFrozenClass`"
456456
def __setattr__(self, name: str, value: object) -> None: ...
457457

458-
# TODO: Emit a diagnostic here
458+
# error: [cannot-overwrite-attribute] "Cannot overwrite attribute `__delattr__` in class `MyFrozenClass`"
459459
def __delattr__(self, name: str) -> None: ...
460460
```
461461

crates/ty_python_semantic/src/types/class.rs

Lines changed: 124 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use crate::semantic_index::{
1919
use crate::types::bound_super::BoundSuperError;
2020
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
2121
use crate::types::context::InferContext;
22-
use crate::types::diagnostic::{INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD};
22+
use crate::types::diagnostic::{
23+
CANNOT_OVERWRITE_ATTRIBUTE, INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD,
24+
};
2325
use crate::types::enums::enum_metadata;
2426
use crate::types::function::{DataclassTransformerParams, KnownFunction};
2527
use crate::types::generics::{
@@ -1923,6 +1925,68 @@ impl<'db> ClassLiteral<'db> {
19231925
Some(typed_dict_params_from_class_def(class_stmt))
19241926
}
19251927

1928+
/// Returns a tuple containing both dataclass params and dataclass transform params
1929+
fn merged_dataclass_params(
1930+
self,
1931+
db: &'db dyn Db,
1932+
field_policy: CodeGeneratorKind<'db>,
1933+
) -> (Option<DataclassParams<'db>>, Option<DataclassParams<'db>>) {
1934+
let dataclass_params = self.dataclass_params(db);
1935+
1936+
let mut transformer_params =
1937+
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
1938+
Some(DataclassParams::from_transformer_params(
1939+
db,
1940+
transformer_params,
1941+
))
1942+
} else {
1943+
None
1944+
};
1945+
1946+
// Dataclass transformer flags can be overwritten using class arguments.
1947+
if let Some(transformer_params) = transformer_params.as_mut() {
1948+
if let Some(class_def) = self.definition(db).kind(db).as_class() {
1949+
let module = parsed_module(db, self.file(db)).load(db);
1950+
1951+
if let Some(arguments) = &class_def.node(&module).arguments {
1952+
let mut flags = transformer_params.flags(db);
1953+
1954+
for keyword in &arguments.keywords {
1955+
if let Some(arg_name) = &keyword.arg {
1956+
if let Some(is_set) =
1957+
keyword.value.as_boolean_literal_expr().map(|b| b.value)
1958+
{
1959+
for (flag_name, flag) in DATACLASS_FLAGS {
1960+
if arg_name.as_str() == *flag_name {
1961+
flags.set(*flag, is_set);
1962+
}
1963+
}
1964+
}
1965+
}
1966+
}
1967+
1968+
*transformer_params =
1969+
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
1970+
}
1971+
}
1972+
}
1973+
1974+
(dataclass_params, transformer_params)
1975+
}
1976+
1977+
/// Checks if the given dataclass parameter flag is set for this class.
1978+
/// This checks both the `dataclass_params` and `transformer_params`.
1979+
fn has_dataclass_param(
1980+
self,
1981+
db: &'db dyn Db,
1982+
field_policy: CodeGeneratorKind<'db>,
1983+
param: DataclassFlags,
1984+
) -> bool {
1985+
let (dataclass_params, transformer_params) = self.merged_dataclass_params(db, field_policy);
1986+
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
1987+
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
1988+
}
1989+
19261990
/// Return the explicit `metaclass` of this class, if one is defined.
19271991
///
19281992
/// ## Note
@@ -2332,53 +2396,8 @@ impl<'db> ClassLiteral<'db> {
23322396
inherited_generic_context: Option<GenericContext<'db>>,
23332397
name: &str,
23342398
) -> Option<Type<'db>> {
2335-
let dataclass_params = self.dataclass_params(db);
2336-
23372399
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
23382400

2339-
let mut transformer_params =
2340-
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
2341-
Some(DataclassParams::from_transformer_params(
2342-
db,
2343-
transformer_params,
2344-
))
2345-
} else {
2346-
None
2347-
};
2348-
2349-
// Dataclass transformer flags can be overwritten using class arguments.
2350-
if let Some(transformer_params) = transformer_params.as_mut() {
2351-
if let Some(class_def) = self.definition(db).kind(db).as_class() {
2352-
let module = parsed_module(db, self.file(db)).load(db);
2353-
2354-
if let Some(arguments) = &class_def.node(&module).arguments {
2355-
let mut flags = transformer_params.flags(db);
2356-
2357-
for keyword in &arguments.keywords {
2358-
if let Some(arg_name) = &keyword.arg {
2359-
if let Some(is_set) =
2360-
keyword.value.as_boolean_literal_expr().map(|b| b.value)
2361-
{
2362-
for (flag_name, flag) in DATACLASS_FLAGS {
2363-
if arg_name.as_str() == *flag_name {
2364-
flags.set(*flag, is_set);
2365-
}
2366-
}
2367-
}
2368-
}
2369-
}
2370-
2371-
*transformer_params =
2372-
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
2373-
}
2374-
}
2375-
}
2376-
2377-
let has_dataclass_param = |param| {
2378-
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
2379-
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
2380-
};
2381-
23822401
let instance_ty =
23832402
Type::instance(db, self.apply_optional_specialization(db, specialization));
23842403

@@ -2456,7 +2475,11 @@ impl<'db> ClassLiteral<'db> {
24562475
}
24572476

24582477
let is_kw_only = name == "__replace__"
2459-
|| kw_only.unwrap_or(has_dataclass_param(DataclassFlags::KW_ONLY));
2478+
|| kw_only.unwrap_or(self.has_dataclass_param(
2479+
db,
2480+
field_policy,
2481+
DataclassFlags::KW_ONLY,
2482+
));
24602483

24612484
// Use the alias name if provided, otherwise use the field name
24622485
let parameter_name =
@@ -2498,7 +2521,7 @@ impl<'db> ClassLiteral<'db> {
24982521

24992522
match (field_policy, name) {
25002523
(CodeGeneratorKind::DataclassLike(_), "__init__") => {
2501-
if !has_dataclass_param(DataclassFlags::INIT) {
2524+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::INIT) {
25022525
return None;
25032526
}
25042527

@@ -2513,7 +2536,7 @@ impl<'db> ClassLiteral<'db> {
25132536
signature_from_fields(vec![cls_parameter], Some(Type::none(db)))
25142537
}
25152538
(CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => {
2516-
if !has_dataclass_param(DataclassFlags::ORDER) {
2539+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::ORDER) {
25172540
return None;
25182541
}
25192542

@@ -2535,9 +2558,10 @@ impl<'db> ClassLiteral<'db> {
25352558
Some(Type::function_like_callable(db, signature))
25362559
}
25372560
(CodeGeneratorKind::DataclassLike(_), "__hash__") => {
2538-
let unsafe_hash = has_dataclass_param(DataclassFlags::UNSAFE_HASH);
2539-
let frozen = has_dataclass_param(DataclassFlags::FROZEN);
2540-
let eq = has_dataclass_param(DataclassFlags::EQ);
2561+
let unsafe_hash =
2562+
self.has_dataclass_param(db, field_policy, DataclassFlags::UNSAFE_HASH);
2563+
let frozen = self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN);
2564+
let eq = self.has_dataclass_param(db, field_policy, DataclassFlags::EQ);
25412565

25422566
if unsafe_hash || (frozen && eq) {
25432567
let signature = Signature::new(
@@ -2560,11 +2584,12 @@ impl<'db> ClassLiteral<'db> {
25602584
(CodeGeneratorKind::DataclassLike(_), "__match_args__")
25612585
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
25622586
{
2563-
if !has_dataclass_param(DataclassFlags::MATCH_ARGS) {
2587+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::MATCH_ARGS) {
25642588
return None;
25652589
}
25662590

2567-
let kw_only_default = has_dataclass_param(DataclassFlags::KW_ONLY);
2591+
let kw_only_default =
2592+
self.has_dataclass_param(db, field_policy, DataclassFlags::KW_ONLY);
25682593

25692594
let fields = self.fields(db, specialization, field_policy);
25702595
let match_args = fields
@@ -2582,8 +2607,8 @@ impl<'db> ClassLiteral<'db> {
25822607
(CodeGeneratorKind::DataclassLike(_), "__weakref__")
25832608
if Program::get(db).python_version(db) >= PythonVersion::PY311 =>
25842609
{
2585-
if !has_dataclass_param(DataclassFlags::WEAKREF_SLOT)
2586-
|| !has_dataclass_param(DataclassFlags::SLOTS)
2610+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::WEAKREF_SLOT)
2611+
|| !self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
25872612
{
25882613
return None;
25892614
}
@@ -2625,7 +2650,7 @@ impl<'db> ClassLiteral<'db> {
26252650
signature_from_fields(vec![self_parameter], Some(instance_ty))
26262651
}
26272652
(CodeGeneratorKind::DataclassLike(_), "__setattr__") => {
2628-
if has_dataclass_param(DataclassFlags::FROZEN) {
2653+
if self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN) {
26292654
let signature = Signature::new(
26302655
Parameters::new(
26312656
db,
@@ -2646,11 +2671,12 @@ impl<'db> ClassLiteral<'db> {
26462671
(CodeGeneratorKind::DataclassLike(_), "__slots__")
26472672
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
26482673
{
2649-
has_dataclass_param(DataclassFlags::SLOTS).then(|| {
2650-
let fields = self.fields(db, specialization, field_policy);
2651-
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2652-
Type::heterogeneous_tuple(db, slots)
2653-
})
2674+
self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
2675+
.then(|| {
2676+
let fields = self.fields(db, specialization, field_policy);
2677+
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2678+
Type::heterogeneous_tuple(db, slots)
2679+
})
26542680
}
26552681
(CodeGeneratorKind::TypedDict, "__setitem__") => {
26562682
let fields = self.fields(db, specialization, field_policy);
@@ -3036,6 +3062,42 @@ impl<'db> ClassLiteral<'db> {
30363062
.collect()
30373063
}
30383064

3065+
pub(crate) fn validate_members(self, context: &InferContext<'db, '_>) {
3066+
let db = context.db();
3067+
let Some(field_policy) = CodeGeneratorKind::from_class(db, self, None) else {
3068+
return;
3069+
};
3070+
let class_body_scope = self.body_scope(db);
3071+
let table = place_table(db, class_body_scope);
3072+
let use_def = use_def_map(db, class_body_scope);
3073+
for (symbol_id, declarations) in use_def.all_end_of_scope_symbol_declarations() {
3074+
let result = place_from_declarations(db, declarations.clone());
3075+
let attr = result.ignore_conflicting_declarations();
3076+
let symbol = table.symbol(symbol_id);
3077+
let name = symbol.name();
3078+
if let Some(Type::FunctionLiteral(literal)) = attr.place.ignore_possibly_undefined()
3079+
&& matches!(name.as_str(), "__setattr__" | "__delattr__")
3080+
{
3081+
if let Some(CodeGeneratorKind::DataclassLike(_)) =
3082+
CodeGeneratorKind::from_class(db, self, None)
3083+
&& self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN)
3084+
{
3085+
if let Some(builder) = context.report_lint(
3086+
&CANNOT_OVERWRITE_ATTRIBUTE,
3087+
literal.node(db, context.file(), context.module()),
3088+
) {
3089+
let mut diagnostic = builder.into_diagnostic(format_args!(
3090+
"Cannot overwrite attribute `{}` in class `{}`",
3091+
name,
3092+
self.name(db)
3093+
));
3094+
diagnostic.info(name);
3095+
}
3096+
}
3097+
}
3098+
}
3099+
}
3100+
30393101
/// Returns a list of all annotated attributes defined in the body of this class. This is similar
30403102
/// to the `__annotations__` attribute at runtime, but also contains default values.
30413103
///

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
5151
registry.register_lint(&AMBIGUOUS_PROTOCOL_MEMBER);
5252
registry.register_lint(&CALL_NON_CALLABLE);
5353
registry.register_lint(&POSSIBLY_MISSING_IMPLICIT_CALL);
54+
registry.register_lint(&CANNOT_OVERWRITE_ATTRIBUTE);
5455
registry.register_lint(&CONFLICTING_ARGUMENT_FORMS);
5556
registry.register_lint(&CONFLICTING_DECLARATIONS);
5657
registry.register_lint(&CONFLICTING_METACLASS);
@@ -393,6 +394,32 @@ declare_lint! {
393394
}
394395
}
395396

397+
declare_lint! {
398+
/// ## What it does
399+
/// Checks for dataclass definitions that have both `frozen=True` and a custom `__setattr__`
400+
/// method defined.
401+
///
402+
/// ## Why is this bad?
403+
/// Frozen dataclasses synthesize `__setattr__` and `__delattr__` methods which raise a
404+
/// `FrozenInstanceError` to emulate immutability.
405+
///
406+
/// Overriding either of these methods raises a runtime error.
407+
///
408+
/// ## Examples
409+
/// ```python
410+
/// from dataclasses import dataclass
411+
///
412+
/// @dataclass(frozen=True)
413+
/// class A:
414+
/// def __setattr__(self, name: str, value: object) -> None: ...
415+
/// ```
416+
pub(crate) static CANNOT_OVERWRITE_ATTRIBUTE = {
417+
summary: "detects dataclasses with `frozen=True` that have a custom `__setattr__` or `__delattr__` implementation",
418+
status: LintStatus::preview("1.0.0"),
419+
default_level: Level::Error,
420+
}
421+
}
422+
396423
declare_lint! {
397424
/// ## What it does
398425
/// Checks for classes definitions which will fail at runtime due to

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
990990
if let Some(protocol) = class.into_protocol_class(self.db()) {
991991
protocol.validate_members(&self.context);
992992
}
993+
994+
class.validate_members(&self.context);
993995
}
994996
}
995997

crates/ty_server/tests/e2e/snapshots/e2e__commands__debug_command.snap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Settings: Settings {
3131
"ambiguous-protocol-member": Warning (Default),
3232
"byte-string-type-annotation": Error (Default),
3333
"call-non-callable": Error (Default),
34+
"cannot-overwrite-attribute": Error (Default),
3435
"conflicting-argument-forms": Error (Default),
3536
"conflicting-declarations": Error (Default),
3637
"conflicting-metaclass": Error (Default),

ty.schema.json

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)