Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 106 additions & 74 deletions crates/ty/docs/rules.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ frozen_instance = MyFrozenClass(1)
frozen_instance.x = 2 # error: [invalid-assignment]
```

If `__setattr__()` or `__delattr__()` is defined in the class, we should emit a diagnostic.
If `__setattr__()` or `__delattr__()` is defined in the class, a diagnostic is emitted.

```py
from dataclasses import dataclass
Expand All @@ -452,10 +452,10 @@ from dataclasses import dataclass
class MyFrozenClass:
x: int

# TODO: Emit a diagnostic here
# error: [cannot-overwrite-attribute] "Cannot overwrite attribute `__setattr__` in class `MyFrozenClass`"
def __setattr__(self, name: str, value: object) -> None: ...

# TODO: Emit a diagnostic here
# error: [cannot-overwrite-attribute] "Cannot overwrite attribute `__delattr__` in class `MyFrozenClass`"
def __delattr__(self, name: str) -> None: ...
```

Expand Down
187 changes: 125 additions & 62 deletions crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use crate::semantic_index::{
use crate::types::bound_super::BoundSuperError;
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::context::InferContext;
use crate::types::diagnostic::{INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD};
use crate::types::diagnostic::{
CANNOT_OVERWRITE_ATTRIBUTE, INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD,
};
use crate::types::enums::enum_metadata;
use crate::types::function::{DataclassTransformerParams, KnownFunction};
use crate::types::generics::{
Expand Down Expand Up @@ -1923,6 +1925,69 @@ impl<'db> ClassLiteral<'db> {
Some(typed_dict_params_from_class_def(class_stmt))
}

/// Returns dataclass params for this class, sourced from both dataclass params and dataclass
/// transform params
fn merged_dataclass_params(
self,
db: &'db dyn Db,
field_policy: CodeGeneratorKind<'db>,
) -> (Option<DataclassParams<'db>>, Option<DataclassParams<'db>>) {
let dataclass_params = self.dataclass_params(db);

let mut transformer_params =
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
Some(DataclassParams::from_transformer_params(
db,
transformer_params,
))
} else {
None
};

// Dataclass transformer flags can be overwritten using class arguments.
if let Some(transformer_params) = transformer_params.as_mut() {
if let Some(class_def) = self.definition(db).kind(db).as_class() {
let module = parsed_module(db, self.file(db)).load(db);

if let Some(arguments) = &class_def.node(&module).arguments {
let mut flags = transformer_params.flags(db);

for keyword in &arguments.keywords {
if let Some(arg_name) = &keyword.arg {
if let Some(is_set) =
keyword.value.as_boolean_literal_expr().map(|b| b.value)
{
for (flag_name, flag) in DATACLASS_FLAGS {
if arg_name.as_str() == *flag_name {
flags.set(*flag, is_set);
}
}
}
}
}

*transformer_params =
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
}
}
}

(dataclass_params, transformer_params)
}

/// Checks if the given dataclass parameter flag is set for this class.
/// This checks both the `dataclass_params` and `transformer_params`.
fn has_dataclass_param(
self,
db: &'db dyn Db,
field_policy: CodeGeneratorKind<'db>,
param: DataclassFlags,
) -> bool {
let (dataclass_params, transformer_params) = self.merged_dataclass_params(db, field_policy);
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
}

/// Return the explicit `metaclass` of this class, if one is defined.
///
/// ## Note
Expand Down Expand Up @@ -2332,53 +2397,8 @@ impl<'db> ClassLiteral<'db> {
inherited_generic_context: Option<GenericContext<'db>>,
name: &str,
) -> Option<Type<'db>> {
let dataclass_params = self.dataclass_params(db);

let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;

let mut transformer_params =
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
Some(DataclassParams::from_transformer_params(
db,
transformer_params,
))
} else {
None
};

// Dataclass transformer flags can be overwritten using class arguments.
if let Some(transformer_params) = transformer_params.as_mut() {
if let Some(class_def) = self.definition(db).kind(db).as_class() {
let module = parsed_module(db, self.file(db)).load(db);

if let Some(arguments) = &class_def.node(&module).arguments {
let mut flags = transformer_params.flags(db);

for keyword in &arguments.keywords {
if let Some(arg_name) = &keyword.arg {
if let Some(is_set) =
keyword.value.as_boolean_literal_expr().map(|b| b.value)
{
for (flag_name, flag) in DATACLASS_FLAGS {
if arg_name.as_str() == *flag_name {
flags.set(*flag, is_set);
}
}
}
}
}

*transformer_params =
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
}
}
}

let has_dataclass_param = |param| {
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
};

let instance_ty =
Type::instance(db, self.apply_optional_specialization(db, specialization));

Expand Down Expand Up @@ -2456,7 +2476,11 @@ impl<'db> ClassLiteral<'db> {
}

let is_kw_only = name == "__replace__"
|| kw_only.unwrap_or(has_dataclass_param(DataclassFlags::KW_ONLY));
|| kw_only.unwrap_or(self.has_dataclass_param(
db,
field_policy,
DataclassFlags::KW_ONLY,
));

// Use the alias name if provided, otherwise use the field name
let parameter_name =
Expand Down Expand Up @@ -2498,7 +2522,7 @@ impl<'db> ClassLiteral<'db> {

match (field_policy, name) {
(CodeGeneratorKind::DataclassLike(_), "__init__") => {
if !has_dataclass_param(DataclassFlags::INIT) {
if !self.has_dataclass_param(db, field_policy, DataclassFlags::INIT) {
return None;
}

Expand All @@ -2513,7 +2537,7 @@ impl<'db> ClassLiteral<'db> {
signature_from_fields(vec![cls_parameter], Some(Type::none(db)))
}
(CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => {
if !has_dataclass_param(DataclassFlags::ORDER) {
if !self.has_dataclass_param(db, field_policy, DataclassFlags::ORDER) {
return None;
}

Expand All @@ -2535,9 +2559,10 @@ impl<'db> ClassLiteral<'db> {
Some(Type::function_like_callable(db, signature))
}
(CodeGeneratorKind::DataclassLike(_), "__hash__") => {
let unsafe_hash = has_dataclass_param(DataclassFlags::UNSAFE_HASH);
let frozen = has_dataclass_param(DataclassFlags::FROZEN);
let eq = has_dataclass_param(DataclassFlags::EQ);
let unsafe_hash =
self.has_dataclass_param(db, field_policy, DataclassFlags::UNSAFE_HASH);
let frozen = self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN);
let eq = self.has_dataclass_param(db, field_policy, DataclassFlags::EQ);

if unsafe_hash || (frozen && eq) {
let signature = Signature::new(
Expand All @@ -2560,11 +2585,12 @@ impl<'db> ClassLiteral<'db> {
(CodeGeneratorKind::DataclassLike(_), "__match_args__")
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
{
if !has_dataclass_param(DataclassFlags::MATCH_ARGS) {
if !self.has_dataclass_param(db, field_policy, DataclassFlags::MATCH_ARGS) {
return None;
}

let kw_only_default = has_dataclass_param(DataclassFlags::KW_ONLY);
let kw_only_default =
self.has_dataclass_param(db, field_policy, DataclassFlags::KW_ONLY);

let fields = self.fields(db, specialization, field_policy);
let match_args = fields
Expand All @@ -2582,8 +2608,8 @@ impl<'db> ClassLiteral<'db> {
(CodeGeneratorKind::DataclassLike(_), "__weakref__")
if Program::get(db).python_version(db) >= PythonVersion::PY311 =>
{
if !has_dataclass_param(DataclassFlags::WEAKREF_SLOT)
|| !has_dataclass_param(DataclassFlags::SLOTS)
if !self.has_dataclass_param(db, field_policy, DataclassFlags::WEAKREF_SLOT)
|| !self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
{
return None;
}
Expand Down Expand Up @@ -2625,7 +2651,7 @@ impl<'db> ClassLiteral<'db> {
signature_from_fields(vec![self_parameter], Some(instance_ty))
}
(CodeGeneratorKind::DataclassLike(_), "__setattr__") => {
if has_dataclass_param(DataclassFlags::FROZEN) {
if self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN) {
let signature = Signature::new(
Parameters::new(
db,
Expand All @@ -2646,11 +2672,12 @@ impl<'db> ClassLiteral<'db> {
(CodeGeneratorKind::DataclassLike(_), "__slots__")
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
{
has_dataclass_param(DataclassFlags::SLOTS).then(|| {
let fields = self.fields(db, specialization, field_policy);
let slots = fields.keys().map(|name| Type::string_literal(db, name));
Type::heterogeneous_tuple(db, slots)
})
self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
.then(|| {
let fields = self.fields(db, specialization, field_policy);
let slots = fields.keys().map(|name| Type::string_literal(db, name));
Type::heterogeneous_tuple(db, slots)
})
}
(CodeGeneratorKind::TypedDict, "__setitem__") => {
let fields = self.fields(db, specialization, field_policy);
Expand Down Expand Up @@ -3036,6 +3063,42 @@ impl<'db> ClassLiteral<'db> {
.collect()
}

pub(crate) fn validate_members(self, context: &InferContext<'db, '_>) {
let db = context.db();
let Some(field_policy) = CodeGeneratorKind::from_class(db, self, None) else {
return;
};
let class_body_scope = self.body_scope(db);
let table = place_table(db, class_body_scope);
let use_def = use_def_map(db, class_body_scope);
for (symbol_id, declarations) in use_def.all_end_of_scope_symbol_declarations() {
let result = place_from_declarations(db, declarations.clone());
let attr = result.ignore_conflicting_declarations();
let symbol = table.symbol(symbol_id);
let name = symbol.name();
if let Some(Type::FunctionLiteral(literal)) = attr.place.ignore_possibly_undefined()
&& matches!(name.as_str(), "__setattr__" | "__delattr__")
{
if let Some(CodeGeneratorKind::DataclassLike(_)) =
CodeGeneratorKind::from_class(db, self, None)
&& self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN)
{
if let Some(builder) = context.report_lint(
&CANNOT_OVERWRITE_ATTRIBUTE,
literal.node(db, context.file(), context.module()),
) {
let mut diagnostic = builder.into_diagnostic(format_args!(
"Cannot overwrite attribute `{}` in class `{}`",
name,
self.name(db)
));
diagnostic.info(name);
}
}
}
}
}

/// Returns a list of all annotated attributes defined in the body of this class. This is similar
/// to the `__annotations__` attribute at runtime, but also contains default values.
///
Expand Down
27 changes: 27 additions & 0 deletions crates/ty_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&AMBIGUOUS_PROTOCOL_MEMBER);
registry.register_lint(&CALL_NON_CALLABLE);
registry.register_lint(&POSSIBLY_MISSING_IMPLICIT_CALL);
registry.register_lint(&CANNOT_OVERWRITE_ATTRIBUTE);
registry.register_lint(&CONFLICTING_ARGUMENT_FORMS);
registry.register_lint(&CONFLICTING_DECLARATIONS);
registry.register_lint(&CONFLICTING_METACLASS);
Expand Down Expand Up @@ -393,6 +394,32 @@ declare_lint! {
}
}

declare_lint! {
/// ## What it does
/// Checks for dataclass definitions that have both `frozen=True` and a custom `__setattr__`
/// method defined.
///
/// ## Why is this bad?
/// Frozen dataclasses synthesize `__setattr__` and `__delattr__` methods which raise a
/// `FrozenInstanceError` to emulate immutability.
///
/// Overriding either of these methods raises a runtime error.
///
/// ## Examples
/// ```python
/// from dataclasses import dataclass
///
/// @dataclass(frozen=True)
/// class A:
/// def __setattr__(self, name: str, value: object) -> None: ...
/// ```
pub(crate) static CANNOT_OVERWRITE_ATTRIBUTE = {
summary: "detects dataclasses with `frozen=True` that have a custom `__setattr__` or `__delattr__` implementation",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}

declare_lint! {
/// ## What it does
/// Checks for classes definitions which will fail at runtime due to
Expand Down
2 changes: 2 additions & 0 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(protocol) = class.into_protocol_class(self.db()) {
protocol.validate_members(&self.context);
}

class.validate_members(&self.context);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Settings: Settings {
"ambiguous-protocol-member": Warning (Default),
"byte-string-type-annotation": Error (Default),
"call-non-callable": Error (Default),
"cannot-overwrite-attribute": Error (Default),
"conflicting-argument-forms": Error (Default),
"conflicting-declarations": Error (Default),
"conflicting-metaclass": Error (Default),
Expand Down
10 changes: 10 additions & 0 deletions ty.schema.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading