diff --git a/.gitignore b/.gitignore index fcb02fe89..0591e8f63 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .idea -dhat-heap.json +.cursor +dhat-heap.json \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a7e0c7df..3fc6625c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ local a = {} ``` +### 🔧 Changed +- **Class Method Completion**: When a function call jumps, if there are multiple declarations, It will then attempt to return the most matching definition along with all actual code declarations, rather than returning all definitions. ### 🐛 Fixed - **Enum Variable Parameter Issue**: Fixed a crash issue when checking enum variable as parameter diff --git a/crates/emmylua_code_analysis/locales/lint.yml b/crates/emmylua_code_analysis/locales/lint.yml index 48c84d7d2..43c11f8fa 100644 --- a/crates/emmylua_code_analysis/locales/lint.yml +++ b/crates/emmylua_code_analysis/locales/lint.yml @@ -252,4 +252,12 @@ Cannot use `...` outside a vararg function.: "the string template type must be a string constant": en: "the string template type must be a string constant" zh_CN: "字符串模板类型必须是字符串常量" - zh_HK: "字串模板類型必須是字串常量" \ No newline at end of file + zh_HK: "字串模板類型必須是字串常量" +"Cannot cast `%{original}` to `%{target}`. %{reason}": + en: "Cannot cast `%{original}` to `%{target}`. %{reason}" + zh_CN: "不能将 `%{original}` 转换为 `%{target}`。%{reason}" + zh_HK: "不能將 `%{original}` 轉換為 `%{target}`。%{reason}" +"type recursion": + en: "type recursion" + zh_CN: "类型递归" + zh_HK: "類型遞歸" \ No newline at end of file diff --git a/crates/emmylua_code_analysis/resources/schema.json b/crates/emmylua_code_analysis/resources/schema.json index 2fc6520eb..9ee0aea58 100644 --- a/crates/emmylua_code_analysis/resources/schema.json +++ b/crates/emmylua_code_analysis/resources/schema.json @@ -35,6 +35,7 @@ "autoRequireFunction": "require", "autoRequireNamingConvention": "keep", "autoRequireSeparator": ".", + "baseFunctionIncludesName": true, "callSnippet": false, "enable": true, "postfix": "@" @@ -76,6 +77,7 @@ "enable": true, "indexHint": true, "localHint": true, + "metaCallHint": true, "overrideHint": true, "paramHint": true }, @@ -95,6 +97,16 @@ } ] }, + "inlineValues": { + "default": { + "enable": true + }, + "allOf": [ + { + "$ref": "#/definitions/EmmyrcInlineValues" + } + ] + }, "references": { "default": { "enable": true, @@ -511,6 +523,13 @@ "enum": [ "generic-constraint-mismatch" ] + }, + { + "description": "cast-type-mismatch", + "type": "string", + "enum": [ + "cast-type-mismatch" + ] } ] }, @@ -594,6 +613,11 @@ "default": ".", "type": "string" }, + "baseFunctionIncludesName": { + "description": "Whether to include the name in the base function completion. effect: `function () end` -> `function name() end`.", + "default": true, + "type": "boolean" + }, "callSnippet": { "description": "Whether to use call snippets in completions.", "default": false, @@ -710,6 +734,13 @@ "enum": [ "camel-case" ] + }, + { + "description": "When returning class definition, use class name, otherwise keep original name.", + "type": "string", + "enum": [ + "keep-class" + ] } ] }, @@ -741,6 +772,11 @@ "default": true, "type": "boolean" }, + "metaCallHint": { + "description": "Whether to enable meta __call operator hints.", + "default": true, + "type": "boolean" + }, "overrideHint": { "description": "Whether to enable override hints.", "default": true, @@ -753,6 +789,16 @@ } } }, + "EmmyrcInlineValues": { + "type": "object", + "properties": { + "enable": { + "description": "Whether to enable inline values.", + "default": true, + "type": "boolean" + } + } + }, "EmmyrcLuaVersion": { "oneOf": [ { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/exprs.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/exprs.rs index 927050e7c..97585133d 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/exprs.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/exprs.rs @@ -1,6 +1,7 @@ use emmylua_parser::{ - LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaIndexExpr, - LuaIndexKey, LuaLiteralExpr, LuaLiteralToken, LuaNameExpr, LuaTableExpr, LuaVarExpr, + LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaClosureExpr, LuaDocTagCast, LuaExpr, + LuaIndexExpr, LuaIndexKey, LuaLiteralExpr, LuaLiteralToken, LuaNameExpr, LuaTableExpr, + LuaVarExpr, }; use crate::{ @@ -50,6 +51,9 @@ pub fn analyze_name_expr(analyzer: &mut DeclAnalyzer, expr: LuaNameExpr) -> Opti } pub fn analyze_index_expr(analyzer: &mut DeclAnalyzer, index_expr: LuaIndexExpr) -> Option<()> { + if index_expr.ancestors::().next().is_some() { + return Some(()); + } let index_key = index_expr.get_index_key()?; let key = match index_key { LuaIndexKey::Name(name) => LuaMemberKey::Name(name.get_name_text().to_string().into()), diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs index 7ac397c24..b0b190848 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs @@ -66,7 +66,7 @@ pub fn analyze_field(analyzer: &mut DocAnalyzer, tag: LuaDocTagField) -> Option< for desc in tag.get_descriptions() { let mut desc_text = desc.get_description_text().to_string(); if !desc_text.is_empty() { - let text = preprocess_description(&mut desc_text); + let text = preprocess_description(&mut desc_text, Some(&property_owner)); if !description.is_empty() { description.push_str("\n\n"); } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index f5d527ac0..c536191fb 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaDocBinaryType, LuaDocFuncType, LuaDocGenericType, - LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, - LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, - LuaTypeUnaryOperator, LuaVarExpr, + LuaAst, LuaAstNode, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, + LuaDocGenericType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, + LuaDocStrTplType, LuaDocType, LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, + LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, LuaVarExpr, }; use rowan::TextRange; use smol_str::SmolStr; @@ -588,7 +588,8 @@ fn infer_multi_line_union_type( }; let description = if let Some(description) = field.get_description() { - let description_text = preprocess_description(&description.get_description_text()); + let description_text = + preprocess_description(&description.get_description_text(), None); if !description_text.is_empty() { Some(description_text) } else { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs index 806897170..d837a379a 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs @@ -11,7 +11,7 @@ use super::AnalyzeContext; use crate::{ db_index::{DbIndex, LuaTypeDeclId}, profile::Profile, - FileId, + FileId, LuaSemanticDeclId, }; use emmylua_parser::{LuaAstNode, LuaComment, LuaSyntaxNode}; use file_generic_index::FileGenericIndex; @@ -44,8 +44,10 @@ fn analyze_comment(analyzer: &mut DocAnalyzer) -> Option<()> { } let owenr = get_owner_id(analyzer)?; - let comment_description = - preprocess_description(&comment.get_description()?.get_description_text()); + let comment_description = preprocess_description( + &comment.get_description()?.get_description_text(), + Some(&owenr), + ); analyzer.db.get_property_index_mut().add_description( analyzer.file_id, owenr, @@ -90,9 +92,16 @@ impl<'a> DocAnalyzer<'a> { } } -pub fn preprocess_description(mut description: &str) -> String { - if description.starts_with(['#', '@']) { - description = description.trim_start_matches(|c| c == '#' || c == '@'); +pub fn preprocess_description(mut description: &str, owner: Option<&LuaSemanticDeclId>) -> String { + let has_remove_start_char = if let Some(owner) = owner { + !matches!(owner, LuaSemanticDeclId::Signature(_)) + } else { + true + }; + if has_remove_start_char { + if description.starts_with(['#', '@']) { + description = description.trim_start_matches(|c| c == '#' || c == '@'); + } } let mut result = String::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index 5c5a058c1..afea5e39d 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -74,8 +74,16 @@ fn add_description_for_type_decl( ) { let mut description_text = String::new(); + // let comment = analyzer.comment.clone(); + // if let Some(description) = comment.get_description() { + // let description = preprocess_description(&description.get_description_text(), None); + // if !description.is_empty() { + // description_text.push_str(&description); + // } + // } + for description in descriptions { - let description = preprocess_description(&description.get_description_text()); + let description = preprocess_description(&description.get_description_text(), None); if !description.is_empty() { if !description_text.is_empty() { description_text.push_str("\n\n"); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index a315290b7..bdd2e4a94 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -116,7 +116,7 @@ pub fn analyze_param(analyzer: &mut DocAnalyzer, tag: LuaDocTagParam) -> Option< } let description = if let Some(des) = tag.get_description() { - Some(preprocess_description(&des.get_description_text())) + Some(preprocess_description(&des.get_description_text(), None)) } else { None }; @@ -155,7 +155,7 @@ pub fn analyze_param(analyzer: &mut DocAnalyzer, tag: LuaDocTagParam) -> Option< pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Option<()> { let description = if let Some(des) = tag.get_description() { - Some(preprocess_description(&des.get_description_text())) + Some(preprocess_description(&des.get_description_text(), None)) } else { None }; @@ -383,7 +383,7 @@ pub fn analyze_other(analyzer: &mut DocAnalyzer, other: LuaDocTagOther) -> Optio let owner = get_owner_id(analyzer)?; let tag_name = other.get_tag_name()?; let description = if let Some(des) = other.get_description() { - let description = preprocess_description(&des.get_description_text()); + let description = preprocess_description(&des.get_description_text(), None); format!("@*{}* {}", tag_name, description) } else { format!("@*{}*", tag_name) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/build_flow_tree.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/build_flow_tree.rs index 817e9f305..df6885e42 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/build_flow_tree.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/build_flow_tree.rs @@ -1,15 +1,16 @@ use std::collections::HashMap; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaBreakStat, LuaChunk, LuaDocTagCast, LuaGotoStat, - LuaIndexExpr, LuaLabelStat, LuaLoopStat, LuaNameExpr, LuaStat, LuaSyntaxKind, LuaTokenKind, - PathTrait, + LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaBreakStat, LuaChunk, LuaComment, LuaDocTagCast, + LuaExpr, LuaGotoStat, LuaIndexExpr, LuaLabelStat, LuaLoopStat, LuaNameExpr, LuaStat, + LuaSyntaxKind, LuaTokenKind, PathTrait, }; use rowan::{TextRange, TextSize, WalkEvent}; use smol_str::SmolStr; use crate::{ - AnalyzeError, DbIndex, DiagnosticCode, FileId, LuaDeclId, LuaFlowId, LuaVarRefId, LuaVarRefNode, + AnalyzeError, DbIndex, DiagnosticCode, FileId, InFiled, LuaDeclId, LuaFlowId, LuaVarRefId, + LuaVarRefNode, }; use super::flow_node::{BlockId, FlowNode}; @@ -285,26 +286,59 @@ fn build_cast_flow( file_id: FileId, tag_cast: LuaDocTagCast, ) -> Option<()> { - let name_token = tag_cast.get_name_token()?; - let decl_tree = db.get_decl_index().get_decl_tree(&file_id)?; - let text = name_token.get_name_text(); - if let Some(decl) = decl_tree.find_local_decl(text, name_token.get_position()) { - let decl_id = decl.get_id(); - builder.add_flow_node( - LuaVarRefId::DeclId(decl_id), - LuaVarRefNode::CastRef(tag_cast.clone()), - ); - } else { - let ref_id = LuaVarRefId::Name(SmolStr::new(text)); - if db - .get_decl_index() - .get_decl(&LuaDeclId::new(file_id, name_token.get_position())) - .is_none() - { - builder.add_flow_node(ref_id, LuaVarRefNode::CastRef(tag_cast.clone())); + match tag_cast.get_key_expr() { + Some(target_expr) => { + let text = match &target_expr { + LuaExpr::NameExpr(name_expr) => name_expr.get_name_text()?, + LuaExpr::IndexExpr(index_expr) => index_expr.get_access_path()?, + _ => { + return None; + } + }; + + let decl_tree = db.get_decl_index().get_decl_tree(&file_id)?; + if let Some(decl) = decl_tree.find_local_decl(&text, target_expr.get_position()) { + let decl_id = decl.get_id(); + builder.add_flow_node( + LuaVarRefId::DeclId(decl_id), + LuaVarRefNode::CastRef(tag_cast.clone()), + ); + } else { + let ref_id = LuaVarRefId::Name(SmolStr::new(text)); + if db + .get_decl_index() + .get_decl(&LuaDeclId::new(file_id, target_expr.get_position())) + .is_none() + { + builder.add_flow_node(ref_id, LuaVarRefNode::CastRef(tag_cast.clone())); + } + } } - } + None => { + // 没有指定名称, 则附加到最近的表达式上 + let comment = tag_cast.get_parent::()?; + let mut left_token = comment.syntax().first_token()?.prev_token()?; + if left_token.kind() == LuaTokenKind::TkWhitespace.into() { + left_token = left_token.prev_token()?; + } + let mut ast_node = left_token.parent()?; + loop { + if LuaExpr::can_cast(ast_node.kind().into()) { + break; + } else if LuaBlock::can_cast(ast_node.kind().into()) { + return None; + } + ast_node = ast_node.parent()?; + } + let expr = LuaExpr::cast(ast_node)?; + let in_filed_syntax_id = InFiled::new(file_id, expr.get_syntax_id()); + builder.add_flow_node( + LuaVarRefId::SyntaxId(in_filed_syntax_id), + LuaVarRefNode::CastRef(tag_cast.clone()), + ); + } + } Some(()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index f6734ffc0..a7c19c770 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -1,5 +1,5 @@ use emmylua_parser::{ - BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr, + BinaryOperator, LuaAssignStat, LuaAstNode, LuaAstToken, LuaExpr, LuaFuncStat, LuaIndexExpr, LuaLocalFuncStat, LuaLocalStat, LuaTableField, LuaVarExpr, PathTrait, }; @@ -542,7 +542,8 @@ pub fn try_add_class_default_call( decl_id.into(), LuaOperatorMetaMethod::Call, analyzer.file_id, - index_expr.get_range(), + // 必须指向名称, 使用 index_expr 的完整范围不会跳转到函数上 + index_expr.get_name_token()?.syntax().text_range(), OperatorFunction::DefaultCall(signature_id), ); analyzer.db.get_operator_index_mut().add_operator(operator); diff --git a/crates/emmylua_code_analysis/src/config/configs/completion.rs b/crates/emmylua_code_analysis/src/config/configs/completion.rs index 820b73cb6..b90928a1a 100644 --- a/crates/emmylua_code_analysis/src/config/configs/completion.rs +++ b/crates/emmylua_code_analysis/src/config/configs/completion.rs @@ -29,6 +29,9 @@ pub struct EmmyrcCompletion { /// The postfix trigger used in completions. #[serde(default = "default_postfix")] pub postfix: String, + /// Whether to include the name in the base function completion. effect: `function () end` -> `function name() end`. + #[serde(default = "default_true")] + pub base_function_includes_name: bool, } impl Default for EmmyrcCompletion { @@ -41,6 +44,7 @@ impl Default for EmmyrcCompletion { call_snippet: false, auto_require_separator: default_auto_require_separator(), postfix: default_postfix(), + base_function_includes_name: default_true(), } } } @@ -72,6 +76,8 @@ pub enum EmmyrcFilenameConvention { PascalCase, /// Convert the filename to camelCase. CamelCase, + /// When returning class definition, use class name, otherwise keep original name. + KeepClass, } impl Default for EmmyrcFilenameConvention { diff --git a/crates/emmylua_code_analysis/src/config/configs/inlayhint.rs b/crates/emmylua_code_analysis/src/config/configs/inlayhint.rs index 67ac57d08..05f2e0d37 100644 --- a/crates/emmylua_code_analysis/src/config/configs/inlayhint.rs +++ b/crates/emmylua_code_analysis/src/config/configs/inlayhint.rs @@ -20,6 +20,9 @@ pub struct EmmyrcInlayHint { /// Whether to enable override hints. #[serde(default = "default_true")] pub override_hint: bool, + /// Whether to enable meta __call operator hints. + #[serde(default = "default_true")] + pub meta_call_hint: bool, } impl Default for EmmyrcInlayHint { @@ -30,6 +33,7 @@ impl Default for EmmyrcInlayHint { index_hint: default_true(), local_hint: default_true(), override_hint: default_true(), + meta_call_hint: default_true(), } } } diff --git a/crates/emmylua_code_analysis/src/config/configs/inline_values.rs b/crates/emmylua_code_analysis/src/config/configs/inline_values.rs new file mode 100644 index 000000000..f46cefb75 --- /dev/null +++ b/crates/emmylua_code_analysis/src/config/configs/inline_values.rs @@ -0,0 +1,22 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, JsonSchema, Clone)] +#[serde(rename_all = "camelCase")] +pub struct EmmyrcInlineValues { + /// Whether to enable inline values. + #[serde(default = "default_true")] + pub enable: bool, +} + +impl Default for EmmyrcInlineValues { + fn default() -> Self { + Self { + enable: default_true(), + } + } +} + +fn default_true() -> bool { + true +} diff --git a/crates/emmylua_code_analysis/src/config/configs/mod.rs b/crates/emmylua_code_analysis/src/config/configs/mod.rs index edda9d142..e67e207a7 100644 --- a/crates/emmylua_code_analysis/src/config/configs/mod.rs +++ b/crates/emmylua_code_analysis/src/config/configs/mod.rs @@ -5,6 +5,7 @@ mod diagnostics; mod document_color; mod hover; mod inlayhint; +mod inline_values; mod references; mod resource; mod runtime; @@ -20,6 +21,7 @@ pub use diagnostics::EmmyrcDiagnostic; pub use document_color::EmmyrcDocumentColor; pub use hover::EmmyrcHover; pub use inlayhint::EmmyrcInlayHint; +pub use inline_values::EmmyrcInlineValues; pub use references::EmmyrcReference; pub use resource::EmmyrcResource; pub use runtime::{EmmyrcLuaVersion, EmmyrcRuntime}; diff --git a/crates/emmylua_code_analysis/src/config/mod.rs b/crates/emmylua_code_analysis/src/config/mod.rs index d12ee0f60..4a9f380bb 100644 --- a/crates/emmylua_code_analysis/src/config/mod.rs +++ b/crates/emmylua_code_analysis/src/config/mod.rs @@ -13,8 +13,8 @@ pub use configs::EmmyrcLuaVersion; use configs::{EmmyrcCodeAction, EmmyrcDocumentColor}; use configs::{ EmmyrcCodeLen, EmmyrcCompletion, EmmyrcDiagnostic, EmmyrcHover, EmmyrcInlayHint, - EmmyrcReference, EmmyrcResource, EmmyrcRuntime, EmmyrcSemanticToken, EmmyrcSignature, - EmmyrcStrict, EmmyrcWorkspace, + EmmyrcInlineValues, EmmyrcReference, EmmyrcResource, EmmyrcRuntime, EmmyrcSemanticToken, + EmmyrcSignature, EmmyrcStrict, EmmyrcWorkspace, }; use emmylua_parser::{LuaLanguageLevel, ParserConfig, SpecialFunction}; use regex::Regex; @@ -56,6 +56,8 @@ pub struct Emmyrc { pub document_color: EmmyrcDocumentColor, #[serde(default)] pub code_action: EmmyrcCodeAction, + #[serde(default)] + pub inline_values: EmmyrcInlineValues, } impl Emmyrc { diff --git a/crates/emmylua_code_analysis/src/db_index/flow/flow_chain.rs b/crates/emmylua_code_analysis/src/db_index/flow/flow_chain.rs index be6bb2519..e601f4ee6 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/flow_chain.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/flow_chain.rs @@ -40,6 +40,10 @@ impl LuaFlowChain { }) .map(|assert| &assert.type_assert) } + + pub fn get_all_type_asserts(&self) -> impl Iterator { + self.type_asserts.iter().map(|assert| &assert.type_assert) + } } #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] diff --git a/crates/emmylua_code_analysis/src/db_index/flow/flow_var_ref_id.rs b/crates/emmylua_code_analysis/src/db_index/flow/flow_var_ref_id.rs index d395fb7bc..37654887e 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/flow_var_ref_id.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/flow_var_ref_id.rs @@ -1,13 +1,14 @@ -use emmylua_parser::{LuaAstNode, LuaDocTagCast, LuaVarExpr}; +use emmylua_parser::{LuaAstNode, LuaDocTagCast, LuaSyntaxId, LuaVarExpr}; use rowan::{TextRange, TextSize}; use smol_str::SmolStr; -use crate::LuaDeclId; +use crate::{InFiled, LuaDeclId}; #[derive(Debug, Eq, PartialEq, Clone, Hash)] pub enum LuaVarRefId { DeclId(LuaDeclId), Name(SmolStr), + SyntaxId(InFiled), } #[derive(Debug, Eq, PartialEq, Clone, Hash)] diff --git a/crates/emmylua_code_analysis/src/db_index/module/mod.rs b/crates/emmylua_code_analysis/src/db_index/module/mod.rs index 56411947b..a26a92a22 100644 --- a/crates/emmylua_code_analysis/src/db_index/module/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/module/mod.rs @@ -305,7 +305,7 @@ impl LuaModuleIndex { self.file_module_map.values().collect() } - fn extract_module_path(&self, path: &str) -> Option<(String, WorkspaceId)> { + pub fn extract_module_path(&self, path: &str) -> Option<(String, WorkspaceId)> { let path = Path::new(path); let mut matched_module_path: Option<(String, WorkspaceId)> = None; for workspace in &self.workspaces { @@ -339,7 +339,7 @@ impl LuaModuleIndex { module_path.to_string() } - fn match_pattern(&self, path: &str) -> Option { + pub fn match_pattern(&self, path: &str) -> Option { for pattern in &self.module_patterns { if let Some(captures) = pattern.captures(path) { if let Some(matched) = captures.get(1) { diff --git a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs index c3847a8e8..f274db0dd 100644 --- a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs +++ b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs @@ -164,6 +164,10 @@ impl LuaOperator { position: self.range.start(), } } + + pub fn get_range(&self) -> TextRange { + self.range + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] diff --git a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs index d5565031d..4d6539a15 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs @@ -123,13 +123,29 @@ impl LuaSignature { } if let Some(param_info) = self.get_param_info_by_id(0) { - if param_info.type_ref.is_self_infer() { + let param_type = ¶m_info.type_ref; + if param_type.is_self_infer() { return true; } match owner_type { - Some(owner_type) => semantic_model - .type_check(owner_type, ¶m_info.type_ref) - .is_ok(), + Some(owner_type) => { + // 一些类型不应该被视为 method + match (owner_type, param_type) { + (LuaType::Ref(_) | LuaType::Def(_), _) => { + if param_type.is_any() + || param_type.is_table() + || param_type.is_class_tpl() + { + return false; + } + } + _ => {} + } + + semantic_model + .type_check(owner_type, ¶m_info.type_ref) + .is_ok() + } None => param_info.name == "self", } } else { diff --git a/crates/emmylua_code_analysis/src/db_index/type/mod.rs b/crates/emmylua_code_analysis/src/db_index/type/mod.rs index f5df8c954..7994ef4dd 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/mod.rs @@ -207,6 +207,15 @@ impl LuaTypeIndex { self.full_name_type_map.values().collect() } + pub fn get_file_namespaces(&self) -> Vec { + self.file_namespace + .values() + .cloned() + .collect::>() + .into_iter() + .collect() + } + pub fn get_type_decl_mut(&mut self, decl_id: &LuaTypeDeclId) -> Option<&mut LuaTypeDecl> { self.full_name_type_map.get_mut(decl_id) } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types.rs b/crates/emmylua_code_analysis/src/db_index/type/types.rs index 57d5ff94b..38f058348 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types.rs @@ -561,7 +561,19 @@ impl LuaFunctionType { return true; } match owner_type { - Some(owner_type) => semantic_model.type_check(owner_type, t).is_ok(), + Some(owner_type) => { + // 一些类型不应该被视为 method + match (owner_type, t) { + (LuaType::Ref(_) | LuaType::Def(_), _) => { + if t.is_any() || t.is_table() || t.is_class_tpl() { + return false; + } + } + _ => {} + } + + semantic_model.type_check(owner_type, t).is_ok() + } None => name == "self", } } @@ -934,6 +946,7 @@ impl VariadicType { } } + /// 获取可变参数的最小长度, 如果可变参数是无限长度, 则返回 None pub fn get_min_len(&self) -> Option { match self { VariadicType::Base(_) => None, @@ -955,6 +968,7 @@ impl VariadicType { } } + /// 获取可变参数的最大长度, 如果可变参数是无限长度, 则返回 None pub fn get_max_len(&self) -> Option { match self { VariadicType::Base(_) => None, diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs index e278814fa..eff80e95a 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs @@ -5,8 +5,8 @@ use emmylua_parser::{ use rowan::TextRange; use crate::{ - DiagnosticCode, LuaDeclExtra, LuaDeclId, LuaSemanticDeclId, LuaType, LuaTypeCache, - SemanticDeclLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, + infer_index_expr, DiagnosticCode, LuaDeclExtra, LuaDeclId, LuaMemberKey, LuaSemanticDeclId, + LuaType, SemanticDeclLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, }; use super::{humanize_lint_type, Checker, DiagnosticContext}; @@ -37,8 +37,7 @@ fn check_assign_stat( assign: &LuaAssignStat, ) -> Option<()> { let (vars, exprs) = assign.get_var_and_expr_list(); - let value_types = - semantic_model.infer_multi_value_adjusted_expression_types(&exprs, Some(vars.len())); + let value_types = semantic_model.infer_expr_list_types(&exprs, Some(vars.len())); for (idx, var) in vars.iter().enumerate() { match var { @@ -76,7 +75,7 @@ fn check_name_expr( rowan::NodeOrToken::Node(name_expr.syntax().clone()), SemanticDeclLevel::default(), )?; - let origin_type = match semantic_decl { + let source_type = match semantic_decl { LuaSemanticDeclId::LuaDecl(decl_id) => { let decl = semantic_model .get_db() @@ -106,12 +105,18 @@ fn check_name_expr( context, semantic_model, name_expr.get_range(), - origin_type.clone(), - value_type, + source_type.as_ref(), + &value_type, false, ); if let Some(expr) = expr { - handle_value_is_table_expr(context, semantic_model, origin_type, &expr); + check_table_expr( + context, + semantic_model, + source_type.as_ref(), + Some(&value_type), + &expr, + ); } Some(()) @@ -124,41 +129,30 @@ fn check_index_expr( expr: Option, value_type: LuaType, ) -> Option<()> { - let semantic_info = - semantic_model.get_semantic_info(rowan::NodeOrToken::Node(index_expr.syntax().clone()))?; - let mut typ = None; - match semantic_info.semantic_decl { - // 如果是已显示定义的成员, 我们不能获取其经过类型缩窄后的类型 - Some(LuaSemanticDeclId::Member(member_id)) => { - let type_cache = semantic_model - .get_db() - .get_type_index() - .get_type_cache(&member_id.into()); - if let Some(type_cache) = type_cache { - match type_cache { - LuaTypeCache::DocType(ty) => { - typ = Some(ty.clone()); - } - _ => {} - } - } - } - _ => {} - } - if typ.is_none() { - typ = Some(semantic_info.typ); - } + let source_type = infer_index_expr( + semantic_model.get_db(), + &mut semantic_model.get_config().borrow_mut(), + index_expr.clone(), + false, + ) + .ok(); check_assign_type_mismatch( context, semantic_model, index_expr.get_range(), - typ.clone(), - value_type, + source_type.as_ref(), + &value_type, true, ); if let Some(expr) = expr { - handle_value_is_table_expr(context, semantic_model, typ, &expr); + check_table_expr( + context, + semantic_model, + source_type.as_ref(), + Some(&value_type), + &expr, + ); } Some(()) } @@ -170,8 +164,7 @@ fn check_local_stat( ) -> Option<()> { let vars = local.get_local_name_list().collect::>(); let value_exprs = local.get_value_exprs().collect::>(); - let value_types = - semantic_model.infer_multi_value_adjusted_expression_types(&value_exprs, Some(vars.len())); + let value_types = semantic_model.infer_expr_list_types(&value_exprs, Some(vars.len())); for (idx, var) in vars.iter().enumerate() { let name_token = var.get_name_token()?; @@ -181,90 +174,181 @@ fn check_local_stat( .get_decl_index() .get_decl(&decl_id)? .get_range(); - let name_type = semantic_model + let var_type = semantic_model .get_db() .get_type_index() .get_type_cache(&decl_id.into()) .map(|cache| cache.as_type().clone())?; + let value_type = value_types.get(idx)?.0.clone(); check_assign_type_mismatch( context, semantic_model, range, - Some(name_type.clone()), - value_types.get(idx)?.0.clone(), + Some(&var_type), + &value_type, false, ); if let Some(expr) = value_exprs.get(idx).map(|expr| expr) { - handle_value_is_table_expr(context, semantic_model, Some(name_type), &expr); + check_table_expr( + context, + semantic_model, + Some(&var_type), + Some(&value_type), + &expr, + ); } } Some(()) } -// 处理 value_expr 是 TableExpr 的情况, 但不会处理 `local a = { x = 1 }, local v = a` -fn handle_value_is_table_expr( +fn check_table_expr( context: &mut DiagnosticContext, semantic_model: &SemanticModel, - table_type: Option, - value_expr: &LuaExpr, + table_type: Option<&LuaType>, // 记录的类型 + expr_type: Option<&LuaType>, // 实际表达式推导出的类型 + table_expr: &LuaExpr, ) -> Option<()> { - let table_type = table_type?; - let fields = LuaTableExpr::cast(value_expr.syntax().clone())? - .get_fields() - .collect::>(); - if fields.len() > 50 { - // 如果字段过多, 则不进行类型检查 + // 需要进行一些过滤 + if table_type == expr_type { return Some(()); } + let table_type = table_type?; + match table_type { + LuaType::Def(_) => return Some(()), + _ => {} + } + if let Some(table_expr) = LuaTableExpr::cast(table_expr.syntax().clone()) { + check_table_expr_content(context, semantic_model, table_type, &table_expr); + } + Some(()) +} - for field in fields { - if field.is_value_field() { - continue; +// 处理 value_expr 是 TableExpr 的情况, 但不会处理 `local a = { x = 1 }, local v = a` +fn check_table_expr_content( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + table_type: &LuaType, + table_expr: &LuaTableExpr, +) -> Option<()> { + const MAX_CHECK_COUNT: usize = 250; + let mut check_count = 0; + + let fields = table_expr.get_fields().collect::>(); + + for (idx, field) in fields.iter().enumerate() { + check_count += 1; + if check_count > MAX_CHECK_COUNT { + return Some(()); } + let Some(value_expr) = field.get_value_expr() else { + continue; + }; - let field_key = field.get_field_key(); - if let Some(field_key) = field_key { - let member_key = semantic_model.get_member_key(&field_key)?; - let source_type = match semantic_model.infer_member_type(&table_type, &member_key) { - Ok(typ) => Some(typ), - Err(_) => { + let expr_type = semantic_model + .infer_expr(value_expr.clone()) + .unwrap_or(LuaType::Any); + + // 位于的最后的 TableFieldValue 允许接受函数调用返回的多值, 而且返回的值必然会从下标 1 开始覆盖掉所有索引字段. + if field.is_value_field() && idx == fields.len() - 1 { + match &expr_type { + LuaType::Variadic(_) => { + // 解开可变参数 + let expr_types = + semantic_model.infer_expr_list_types(&vec![value_expr.clone()], None); + check_table_last_variadic_type( + context, + semantic_model, + table_type, + &expr_types, + ); continue; } - }; - - let expr = field.get_value_expr(); - if let Some(expr) = expr { - let expr_type = semantic_model.infer_expr(expr).unwrap_or(LuaType::Any); + _ => {} + } + } - let allow_nil = match table_type { - LuaType::Array(_) => true, - _ => false, - }; + let Some(field_key) = field.get_field_key() else { + continue; + }; + let Some(member_key) = semantic_model.get_member_key(&field_key) else { + continue; + }; + let source_type = match semantic_model.infer_member_type(&table_type, &member_key) { + Ok(typ) => typ, + Err(_) => { + continue; + } + }; - check_assign_type_mismatch( - context, - semantic_model, - field.get_range(), - source_type, - expr_type, - allow_nil, - ); + if source_type.is_table() || source_type.is_custom_type() { + if let Some(table_expr) = LuaTableExpr::cast(value_expr.syntax().clone()) { + // 检查子表 + check_table_expr_content(context, semantic_model, &source_type, &table_expr); + continue; } } + + let allow_nil = match table_type { + LuaType::Array(_) => true, + _ => false, + }; + + check_assign_type_mismatch( + context, + semantic_model, + field.get_range(), + Some(&source_type), + &expr_type, + allow_nil, + ); } Some(()) } +fn check_table_last_variadic_type( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + table_type: &LuaType, + expr_types: &[(LuaType, TextRange)], +) -> Option<()> { + for (idx, (expr_type, range)) in expr_types.iter().enumerate() { + // 此时的值必然是从下标 1 开始递增的 + let member_key = LuaMemberKey::Integer(idx as i64 + 1); + let source_type = semantic_model + .infer_member_type(&table_type, &member_key) + .ok()?; + + let expr_type = if let LuaType::Variadic(variadic) = expr_type { + let Some(typ) = variadic.get_type(idx) else { + continue; + }; + typ.clone() + } else { + expr_type.clone() + }; + + check_assign_type_mismatch( + context, + semantic_model, + *range, + Some(&source_type), + &expr_type, + false, + ); + } + Some(()) +} + fn check_assign_type_mismatch( context: &mut DiagnosticContext, semantic_model: &SemanticModel, range: TextRange, - source_type: Option, - value_type: LuaType, + source_type: Option<&LuaType>, + value_type: &LuaType, allow_nil: bool, ) -> Option<()> { - let source_type = source_type.unwrap_or(LuaType::Any); + let source_type = source_type.unwrap_or(&LuaType::Any); // 如果一致, 则不进行类型检查 if source_type == value_type { return Some(()); @@ -281,9 +365,7 @@ fn check_assign_type_mismatch( (LuaType::Def(_), _) => return Some(()), // 此时检查交给 table_field (LuaType::Ref(_) | LuaType::Tuple(_), LuaType::TableConst(_)) => return Some(()), - // 如果源类型是nil, 则不进行类型检查 (LuaType::Nil, _) => return Some(()), - // // fix issue #196 (LuaType::Ref(_), LuaType::Instance(instance)) => { if instance.get_base().is_table() { return Some(()); @@ -317,35 +399,25 @@ fn add_type_check_diagnostic( let db = semantic_model.get_db(); match result { Ok(_) => return, - Err(reason) => match reason { - TypeCheckFailReason::TypeNotMatchWithReason(reason) => { - context.add_diagnostic( - DiagnosticCode::AssignTypeMismatch, - range, - t!( - "Cannot assign `%{value}` to `%{source}`. %{reason}", - value = humanize_lint_type(db, &value_type), - source = humanize_lint_type(db, &source_type), - reason = reason - ) - .to_string(), - None, - ); - } - _ => { - context.add_diagnostic( - DiagnosticCode::AssignTypeMismatch, - range, - t!( - "Cannot assign `%{value}` to `%{source}`. %{reason}", - value = humanize_lint_type(db, &value_type), - source = humanize_lint_type(db, &source_type), - reason = "" - ) - .to_string(), - None, - ); - } - }, + Err(reason) => { + let reason_message = match reason { + TypeCheckFailReason::TypeNotMatchWithReason(reason) => reason, + TypeCheckFailReason::TypeRecursion => t!("type recursion").to_string(), + _ => "".to_string(), + }; + + context.add_diagnostic( + DiagnosticCode::AssignTypeMismatch, + range, + t!( + "Cannot assign `%{value}` to `%{source}`. %{reason}", + value = humanize_lint_type(db, &value_type), + source = humanize_lint_type(db, &source_type), + reason = reason_message + ) + .to_string(), + None, + ); + } } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs new file mode 100644 index 000000000..c73398bae --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs @@ -0,0 +1,272 @@ +use emmylua_parser::{LuaAst, LuaAstNode, LuaDocTagCast}; +use itertools::Itertools; +use rowan::TextRange; +use std::collections::HashSet; + +use crate::diagnostic::checker::generic::infer_doc_type::infer_doc_type; +use crate::{ + get_real_type, DbIndex, DiagnosticCode, LuaType, LuaUnionType, SemanticModel, + TypeCheckFailReason, TypeCheckResult, +}; + +use super::{humanize_lint_type, Checker, DiagnosticContext}; + +pub struct CastTypeMismatchChecker; + +impl Checker for CastTypeMismatchChecker { + const CODES: &[DiagnosticCode] = &[DiagnosticCode::CastTypeMismatch]; + + fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { + // dbg!(&semantic_model.get_root()); + for node in semantic_model.get_root().descendants::() { + if let LuaAst::LuaDocTagCast(cast_tag) = node { + check_cast_tag(context, semantic_model, &cast_tag); + } + } + } +} + +fn check_cast_tag( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + cast_tag: &LuaDocTagCast, +) -> Option<()> { + let key_expr = cast_tag.get_key_expr()?; + let origin_type = { + let typ = semantic_model.infer_expr(key_expr).ok()?; + expand_type(semantic_model.get_db(), &typ).unwrap_or(typ) + }; + + // 检查每个 cast 操作类型 + for op_type in cast_tag.get_op_types() { + // 如果具有操作符, 则不检查 + if let Some(_) = op_type.get_op() { + continue; + } + if let Some(target_doc_type) = op_type.get_type() { + let target_type = { + let typ = infer_doc_type(semantic_model, &target_doc_type); + expand_type(semantic_model.get_db(), &typ).unwrap_or(typ) + }; + check_cast_compatibility( + context, + semantic_model, + op_type.get_range(), + &origin_type, + &target_type, + ); + } + } + + Some(()) +} + +fn check_cast_compatibility( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + range: TextRange, + origin_type: &LuaType, + target_type: &LuaType, +) -> Option<()> { + if origin_type == target_type { + return Some(()); + } + + // 检查是否可以从原始类型转换为目标类型 + let result = match origin_type { + LuaType::Union(union_type) => { + for member_type in union_type.get_types() { + // 不检查 nil 类型 + if member_type.is_nil() { + continue; + } + if cast_type_check(semantic_model, member_type, target_type, 0).is_ok() { + return Some(()); + } + } + Err(TypeCheckFailReason::TypeNotMatch) + } + _ => cast_type_check(semantic_model, origin_type, target_type, 0), + }; + + if !result.is_ok() { + add_cast_type_mismatch_diagnostic( + context, + semantic_model, + range, + origin_type, + target_type, + result, + ); + } + + Some(()) +} + +fn add_cast_type_mismatch_diagnostic( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + range: TextRange, + origin_type: &LuaType, + target_type: &LuaType, + result: TypeCheckResult, +) { + let db = semantic_model.get_db(); + match result { + Ok(_) => return, + Err(reason) => { + let reason_message = match reason { + TypeCheckFailReason::TypeNotMatchWithReason(reason) => reason, + TypeCheckFailReason::TypeNotMatch | TypeCheckFailReason::DonotCheck => { + "".to_string() + } + TypeCheckFailReason::TypeRecursion => t!("type recursion").to_string(), + }; + + context.add_diagnostic( + DiagnosticCode::CastTypeMismatch, + range, + t!( + "Cannot cast `%{original}` to `%{target}`. %{reason}", + original = humanize_lint_type(db, origin_type), + target = humanize_lint_type(db, target_type), + reason = reason_message + ) + .to_string(), + None, + ); + } + } +} + +fn cast_type_check( + semantic_model: &SemanticModel, + origin_type: &LuaType, + target_type: &LuaType, + recursion_depth: u32, +) -> TypeCheckResult { + const MAX_RECURSION_DEPTH: u32 = 100; + if recursion_depth >= MAX_RECURSION_DEPTH { + return Err(TypeCheckFailReason::TypeRecursion); + } + + if origin_type == target_type { + return Ok(()); + } + + // cast 规则非常宽松 + match (origin_type, target_type) { + (LuaType::Any | LuaType::Nil, _) => Ok(()), + (LuaType::Number, LuaType::Integer) => Ok(()), + (LuaType::Userdata, target_type) + if target_type.is_table() || target_type.is_custom_type() => + { + Ok(()) + } + (_, LuaType::Union(union)) => { + // 通常来说这个的原始类型为 alias / enum-field 的集合 + for member_type in union.get_types() { + match cast_type_check( + semantic_model, + origin_type, + member_type, + recursion_depth + 1, + ) { + Ok(_) => {} + Err(reason) => { + return Err(reason); + } + } + } + return Ok(()); + } + _ => { + if origin_type.is_table() { + if target_type.is_table() || target_type.is_custom_type() { + return Ok(()); + } + } else if origin_type.is_custom_type() { + if target_type.is_table() { + return Ok(()); + } + } else if origin_type.is_string() { + if target_type.is_string() { + return Ok(()); + } + } else if origin_type.is_number() { + if target_type.is_number() { + return Ok(()); + } + } + + semantic_model.type_check(target_type, origin_type) + } + } +} + +fn expand_type(db: &DbIndex, typ: &LuaType) -> Option { + let mut visited = HashSet::new(); + expand_type_recursive(db, typ, &mut visited) +} + +fn expand_type_recursive( + db: &DbIndex, + typ: &LuaType, + visited: &mut HashSet, +) -> Option { + // TODO: 优化性能 + // 防止无限递归, 性能很有问题, 但 @cast 使用频率不高, 这是可以接受的 + if visited.contains(typ) { + return Some(typ.clone()); + } + visited.insert(typ.clone()); + + // 展开类型, 如果具有多种类型将尽量返回 union + match get_real_type(db, &typ).unwrap_or(&typ) { + LuaType::Ref(id) | LuaType::Def(id) => { + let type_decl = db.get_type_index().get_type_decl(id)?; + if type_decl.is_enum() { + if let Some(typ) = type_decl.get_enum_field_type(db) { + return expand_type_recursive(db, &typ, visited); + } + }; + } + LuaType::Instance(inst) => { + let base = inst.get_base(); + return Some(base.clone()); + } + LuaType::MultiLineUnion(multi_union) => { + let union = multi_union.to_union(); + return expand_type_recursive(db, &union, visited); + } + LuaType::Union(union_type) => { + // 递归展开 union 中的每个类型 + let mut expanded_types = Vec::new(); + for inner_type in union_type.get_types() { + if let Some(expanded) = expand_type_recursive(db, inner_type, visited) { + match expanded { + LuaType::Union(inner_union) => { + // 如果展开后还是 union,则将其成员类型添加到结果中 + expanded_types.extend(inner_union.get_types().iter().cloned()); + } + _ => { + expanded_types.push(expanded); + } + } + } else { + expanded_types.push(inner_type.clone()); + } + } + + let expanded_types = expanded_types.into_iter().unique().collect::>(); + return match expanded_types.len() { + 0 => Some(LuaType::Unknown), + 1 => Some(expanded_types[0].clone().into()), + _ => Some(LuaType::Union(LuaUnionType::new(expanded_types).into())), + }; + } + LuaType::TypeGuard(_) => return Some(LuaType::Boolean), + _ => {} + } + Some(typ.clone()) +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs index 05fcd6995..8e8658917 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use emmylua_parser::{ LuaAst, LuaAstNode, LuaElseIfClauseStat, LuaForRangeStat, LuaForStat, LuaIfStat, LuaIndexExpr, - LuaIndexKey, LuaRepeatStat, LuaSyntaxKind, LuaVarExpr, LuaWhileStat, + LuaIndexKey, LuaRepeatStat, LuaSyntaxKind, LuaTokenKind, LuaVarExpr, LuaWhileStat, }; use crate::{ @@ -69,17 +69,6 @@ fn check_index_expr( let index_key = index_expr.get_index_key()?; - // 检查是否为判断语句 - if matches!(code, DiagnosticCode::UndefinedField) { - if is_in_conditional_statement(index_expr) { - return Some(()); - } - } - - if is_in_conditional_statement(index_expr) { - return Some(()); - } - if is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key, code).is_some() { return Some(()); } @@ -154,6 +143,30 @@ fn is_valid_member( _ => {} } + // 如果位于检查语句中, 则可以做一些宽泛的检查 + if matches!(code, DiagnosticCode::UndefinedField) && in_conditional_statement(index_expr) { + for child in index_expr.syntax().children_with_tokens() { + if child.kind() == LuaTokenKind::TkLeftBracket.into() { + // 此时为 [] 访问, 大部分类型都可以直接通行 + match prefix_typ { + LuaType::Ref(id) | LuaType::Def(id) => { + if let Some(decl) = + semantic_model.get_db().get_type_index().get_type_decl(&id) + { + // enum 仍然需要检查 + if decl.is_enum() { + break; + } else { + return Some(()); + } + } + } + _ => return Some(()), + } + } + } + } + // 检查 member_info let need_add_diagnostic = match semantic_model.get_semantic_info(index_expr.syntax().clone().into()) { @@ -365,9 +378,9 @@ fn get_key_types(typ: &LuaType) -> HashSet { /// * `node` - 要检查的AST节点 /// /// # 返回值 -/// * `true` - 如果节点位于判断语句的条件表达式中 -/// * `false` - 如果节点不在判断语句的条件表达式中 -fn is_in_conditional_statement(node: &T) -> bool { +/// * `true` - 节点位于判断语句的条件表达式中 +/// * `false` - 节点不在判断语句的条件表达式中 +fn in_conditional_statement(node: &T) -> bool { let node_range = node.get_range(); // 遍历所有祖先节点,查找条件语句 diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index d22aadf09..8474d523d 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -1,7 +1,7 @@ use emmylua_parser::{LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaDocTagType}; use rowan::TextRange; -use crate::diagnostic::checker::generic::infer_type::infer_type; +use crate::diagnostic::checker::generic::infer_doc_type::infer_doc_type; use crate::diagnostic::checker::param_type_check::get_call_source_type; use crate::{ humanize_type, DiagnosticCode, GenericTplId, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, @@ -40,7 +40,7 @@ fn check_doc_tag_type( ) -> Option<()> { let type_list = doc_tag_type.get_type_list(); for doc_type in type_list { - let type_ref = infer_type(semantic_model, &doc_type); + let type_ref = infer_doc_type(semantic_model, &doc_type); let generic_type = match type_ref { LuaType::Generic(generic_type) => generic_type, _ => continue, @@ -85,8 +85,7 @@ fn check_call_expr( let mut params = signature.get_type_params(); let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); - let mut arg_infos = - semantic_model.infer_multi_value_adjusted_expression_types(&arg_exprs, None); + let mut arg_infos = semantic_model.infer_expr_list_types(&arg_exprs, None); match (call_expr.is_colon_call(), signature.is_colon_define) { (true, true) | (false, false) => {} (false, true) => { diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_type.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs similarity index 91% rename from crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_type.rs rename to crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs index e5fc9ae38..15f42d45f 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_type.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs @@ -1,7 +1,7 @@ use emmylua_parser::{ - LuaAstNode, LuaDocBinaryType, LuaDocFuncType, LuaDocGenericType, LuaDocMultiLineUnionType, - LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType, - LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaAstNode, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericType, + LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, + LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, }; use rowan::TextRange; @@ -14,7 +14,7 @@ use crate::{ VariadicType, }; -pub fn infer_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaType { +pub fn infer_doc_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaType { match node { LuaDocType::Name(name_type) => { if let Some(name) = name_type.get_name_text() { @@ -28,7 +28,7 @@ pub fn infer_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaType } LuaDocType::Nullable(nullable_type) => { if let Some(inner_type) = nullable_type.get_type() { - let t = infer_type(semantic_model, &inner_type); + let t = infer_doc_type(semantic_model, &inner_type); if t.is_unknown() { return LuaType::Unknown; } @@ -42,7 +42,7 @@ pub fn infer_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaType } LuaDocType::Array(array_type) => { if let Some(inner_type) = array_type.get_type() { - let t = infer_type(semantic_model, &inner_type); + let t = infer_doc_type(semantic_model, &inner_type); if t.is_unknown() { return LuaType::Unknown; } @@ -74,7 +74,7 @@ pub fn infer_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaType LuaDocType::Tuple(tuple_type) => { let mut types = Vec::new(); for type_node in tuple_type.get_types() { - let t = infer_type(semantic_model, &type_node); + let t = infer_doc_type(semantic_model, &type_node); if t.is_unknown() { return LuaType::Unknown; } @@ -199,7 +199,7 @@ fn infer_generic_type(semantic_model: &SemanticModel, generic_type: &LuaDocGener let mut generic_params = Vec::new(); if let Some(generic_decl_list) = generic_type.get_generic_types() { for param in generic_decl_list.get_types() { - let param_type = infer_type(semantic_model, ¶m); + let param_type = infer_doc_type(semantic_model, ¶m); if param_type.is_unknown() { return LuaType::Unknown; } @@ -224,7 +224,7 @@ fn infer_special_generic_type( let mut types = Vec::new(); if let Some(generic_decl_list) = generic_type.get_generic_types() { for param in generic_decl_list.get_types() { - let param_type = infer_type(semantic_model, ¶m); + let param_type = infer_doc_type(semantic_model, ¶m); types.push(param_type); } } @@ -232,7 +232,7 @@ fn infer_special_generic_type( } "namespace" => { let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; - let first_param = infer_type(semantic_model, &first_doc_param_type); + let first_param = infer_doc_type(semantic_model, &first_doc_param_type); if let LuaType::DocStringConst(ns_str) = first_param { return Some(LuaType::Namespace(ns_str)); } @@ -240,7 +240,7 @@ fn infer_special_generic_type( "std.Select" => { let mut params = Vec::new(); for param in generic_type.get_generic_types()?.get_types() { - let param_type = infer_type(semantic_model, ¶m); + let param_type = infer_doc_type(semantic_model, ¶m); params.push(param_type); } return Some(LuaType::Call( @@ -250,7 +250,7 @@ fn infer_special_generic_type( "std.Unpack" => { let mut params = Vec::new(); for param in generic_type.get_generic_types()?.get_types() { - let param_type = infer_type(semantic_model, ¶m); + let param_type = infer_doc_type(semantic_model, ¶m); params.push(param_type); } return Some(LuaType::Call( @@ -260,7 +260,7 @@ fn infer_special_generic_type( "std.RawGet" => { let mut params = Vec::new(); for param in generic_type.get_generic_types()?.get_types() { - let param_type = infer_type(semantic_model, ¶m); + let param_type = infer_doc_type(semantic_model, ¶m); params.push(param_type); } return Some(LuaType::Call( @@ -269,7 +269,7 @@ fn infer_special_generic_type( } "TypeGuard" => { let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; - let first_param = infer_type(semantic_model, &first_doc_param_type); + let first_param = infer_doc_type(semantic_model, &first_doc_param_type); return Some(LuaType::TypeGuard(first_param.into())); } @@ -281,8 +281,8 @@ fn infer_special_generic_type( fn infer_binary_type(semantic_model: &SemanticModel, binary_type: &LuaDocBinaryType) -> LuaType { if let Some((left, right)) = binary_type.get_types() { - let left_type = infer_type(semantic_model, &left); - let right_type = infer_type(semantic_model, &right); + let left_type = infer_doc_type(semantic_model, &left); + let right_type = infer_doc_type(semantic_model, &right); if left_type.is_unknown() { return right_type; } @@ -370,7 +370,7 @@ fn infer_binary_type(semantic_model: &SemanticModel, binary_type: &LuaDocBinaryT fn infer_unary_type(semantic_model: &SemanticModel, unary_type: &LuaDocUnaryType) -> LuaType { if let Some(base_type) = unary_type.get_type() { - let base = infer_type(semantic_model, &base_type); + let base = infer_doc_type(semantic_model, &base_type); if base.is_unknown() { return LuaType::Unknown; } @@ -409,7 +409,7 @@ fn infer_func_type(semantic_model: &SemanticModel, func: &LuaDocFuncType) -> Lua let nullable = param.is_nullable(); let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_type(semantic_model, &type_ref); + let mut typ = infer_doc_type(semantic_model, &type_ref); if nullable && !typ.is_nullable() { typ = TypeOps::Union.apply(semantic_model.get_db(), &typ, &LuaType::Nil); } @@ -426,7 +426,7 @@ fn infer_func_type(semantic_model: &SemanticModel, func: &LuaDocFuncType) -> Lua for return_type in return_type_list.get_return_type_list() { let (_, typ) = return_type.get_name_and_type(); if let Some(typ) = typ { - let t = infer_type(semantic_model, &typ); + let t = infer_doc_type(semantic_model, &typ); return_types.push(t); } else { return_types.push(LuaType::Unknown); @@ -468,7 +468,7 @@ fn infer_object_type(semantic_model: &SemanticModel, object_type: &LuaDocObjectT LuaIndexAccessKey::String(str.get_value().to_string().into()) } LuaDocObjectFieldKey::Type(t) => { - LuaIndexAccessKey::Type(infer_type(semantic_model, &t)) + LuaIndexAccessKey::Type(infer_doc_type(semantic_model, &t)) } } } else { @@ -476,7 +476,7 @@ fn infer_object_type(semantic_model: &SemanticModel, object_type: &LuaDocObjectT }; let mut type_ref = if let Some(type_ref) = field.get_type() { - infer_type(semantic_model, &type_ref) + infer_doc_type(semantic_model, &type_ref) } else { LuaType::Unknown }; @@ -518,7 +518,7 @@ fn infer_variadic_type( variadic_type: &LuaDocVariadicType, ) -> Option { let inner_type = variadic_type.get_type()?; - let base = infer_type(semantic_model, &inner_type); + let base = infer_doc_type(semantic_model, &inner_type); let variadic = VariadicType::Base(base.clone()); Some(LuaType::Variadic(variadic.into())) } @@ -530,7 +530,7 @@ fn infer_multi_line_union_type( let mut union_members = Vec::new(); for field in multi_union.get_fields() { let alias_member_type = if let Some(field_type) = field.get_type() { - let type_ref = infer_type(semantic_model, &field_type); + let type_ref = infer_doc_type(semantic_model, &field_type); if type_ref.is_unknown() { continue; } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs index 9d3526474..14743b0af 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs @@ -1,2 +1,2 @@ pub mod generic_constraint_mismatch; -mod infer_type; +pub mod infer_doc_type; diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs index 81905b1b4..791e686b7 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs @@ -2,6 +2,7 @@ mod access_invisible; mod analyze_error; mod assign_type_mismatch; mod await_in_sync; +mod cast_type_mismatch; mod check_field; mod check_param_count; mod check_return_count; @@ -96,6 +97,7 @@ pub fn check_file(context: &mut DiagnosticContext, semantic_model: &SemanticMode context, semantic_model, ); + run_check::(context, semantic_model); run_check::( context, @@ -275,11 +277,12 @@ pub fn get_return_stats(closure_expr: &LuaClosureExpr) -> impl Iterator String { match typ { - LuaType::Ref(type_decl_id) => type_decl_id.get_simple_name().to_string(), - LuaType::Generic(generic_type) => generic_type - .get_base_type_id() - .get_simple_name() - .to_string(), + // TODO: 应该仅去掉命名空间 + // LuaType::Ref(type_decl_id) => type_decl_id.get_simple_name().to_string(), + // LuaType::Generic(generic_type) => generic_type + // .get_base_type_id() + // .get_simple_name() + // .to_string(), LuaType::IntegerConst(_) => "integer".to_string(), LuaType::FloatConst(_) => "number".to_string(), LuaType::BooleanConst(_) => "boolean".to_string(), diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs index f1e536102..9bbc69b81 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs @@ -36,7 +36,7 @@ fn check_call_expr( let mut params = func.get_params().to_vec(); let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); let (mut arg_types, mut arg_ranges): (Vec, Vec) = semantic_model - .infer_multi_value_adjusted_expression_types(&arg_exprs, None) + .infer_expr_list_types(&arg_exprs, None) .into_iter() .unzip(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs index 86fac146e..15db135a7 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs @@ -55,10 +55,8 @@ fn check_return_stat( return_stat: &LuaReturnStat, ) -> Option<()> { let (return_expr_types, return_expr_ranges) = { - let infos = semantic_model.infer_multi_value_adjusted_expression_types( - &return_stat.get_expr_list().collect::>(), - None, - ); + let infos = semantic_model + .infer_expr_list_types(&return_stat.get_expr_list().collect::>(), None); let mut return_expr_types = infos.iter().map(|(typ, _)| typ.clone()).collect::>(); // 解决 setmetatable 的返回值类型问题 let setmetatable_index = has_setmetatable(semantic_model, return_stat); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/unbalanced_assignments.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/unbalanced_assignments.rs index 103b218d2..081a8d52e 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/unbalanced_assignments.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/unbalanced_assignments.rs @@ -59,8 +59,7 @@ fn check_unbalanced_assignment( return Some(()); } - let value_types = - semantic_model.infer_multi_value_adjusted_expression_types(value_exprs, Some(vars.len())); + let value_types = semantic_model.infer_expr_list_types(value_exprs, Some(vars.len())); if let Some(last_type) = value_types.last() { if check_last(&last_type.0) { return Some(()); diff --git a/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs b/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs index 35de9ae6c..fe3a3d616 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs @@ -92,6 +92,8 @@ pub enum DiagnosticCode { DuplicateIndex, /// generic-constraint-mismatch GenericConstraintMismatch, + /// cast-type-mismatch + CastTypeMismatch, #[serde(other)] None, diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs index 5f149ea94..cb6f97446 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs @@ -858,4 +858,68 @@ return t "# )); } + + #[test] + fn test_nesting_table_field_1() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class T1 + ---@field x T2 + + ---@class T2 + ---@field xx number + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@type T1 + local t = { + x = { + xx = "", + } + } + "# + )); + } + + #[test] + fn test_nesting_table_field_2() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class T1 + ---@field x number + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@type T1 + local t = { + x = { + xx = "", + } + } + "# + )); + } + + #[test] + fn test_issue_525() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@type table + local lines + for lnum = 1, #lines do + if lines[lnum] == true then + lines[lnum] = '' + end + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs new file mode 100644 index 000000000..c76db8f65 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs @@ -0,0 +1,236 @@ +#[cfg(test)] +mod tests { + use crate::DiagnosticCode; + use crate::VirtualWorkspace; + + #[test] + fn test_valid_cast() { + let mut ws = VirtualWorkspace::new(); + let code = r#" +---@cast a number +---@cast a.field string +---@cast A.b.c.d boolean +---@cast -? + "#; + + assert!(ws.check_code_for(DiagnosticCode::CastTypeMismatch, code)); + } + + #[test] + fn test_invalid_cast() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@type string|boolean + A = "1" + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@cast A number + "# + )); + } + + #[test] + fn test_valid_cast_from_union_to_member() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type string|number|boolean + local value + + ---@cast value string + "# + )); + } + + #[test] + fn test_invalid_cast_to_non_member() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type string|boolean + local value + + ---@cast value table + "# + )); + } + + #[test] + fn test_cast_with_nil() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type string? + local value + + ---@cast value string + "# + )); + } + + #[test] + fn test_cast_same_type() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type string + local value + + ---@cast value string + "# + )); + } + + #[test] + fn test_cast_multiple_operations() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type string|boolean + local value + + ---@cast value +number, -boolean + "# + )); + } + + #[test] + fn test_cast_class_types() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Animal + ---@class Dog : Animal + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type Animal + local pet + + ---@cast pet Dog + "# + )); + } + + #[test] + fn test_cast_invalid_class_types() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Animal + ---@class Car + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type Animal + local pet + + ---@cast pet Car + "# + )); + } + + #[test] + fn test_cast_1() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Animal.Dog + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@type any + local pet + + ---@cast pet Animal.Dog + "# + )); + } + + #[test] + fn test_cast_alias_1() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@alias KV.SupportType + ---| boolean + ---| integer + ---| number + ---| string + + + ---@param value KV.SupportType + ---@return any + ---@return string + local function get_py_value_and_type(value) + local tp = type(value) + if tp == 'number' then + ---@cast value number + return value, math.type(value) + end + return value, tp + end + "# + )); + } + + #[test] + fn test_cast_alias_2() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@alias KeyAlias + ---| "a" # 2010001 + ---| "b" # 2010002 + + ---@type string + local key + + ---@cast key KeyAlias + "# + )); + + assert!(!ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@alias IdAlias + ---| 2010001 + ---| 2010002 + + ---@type string + local key + + ---@cast key IdAlias + "# + )); + + assert!(!ws.check_code_for( + DiagnosticCode::CastTypeMismatch, + r#" + ---@alias IdAndKeyAlias IdAlias|KeyAlias + + ---@type string + local key + + ---@cast key IdAndKeyAlias + "# + )); + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs index 4aa3edfd5..29cb8a599 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/mod.rs @@ -1,6 +1,7 @@ mod access_invisible_test; mod assign_type_mismatch_test; mod await_in_sync_test; +mod cast_type_mismatch_test; mod check_return_count_test; mod code_style; mod disable_line_test; diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs index 68aa51a52..7729ecdee 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs @@ -55,4 +55,26 @@ mod test { "# )); } + + #[test] + fn test_cast() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Cast1 + ---@field get fun(self: self, a: number): Cast1? + ---@field get2 fun(self: self, a: number): Cast1? + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::NeedCheckNil, + r#" + ---@type Cast1 + local A + + local a = A:get(1) --[[@cast -?]] + :get2(2) + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs index 5f36f356d..31afce6b8 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs @@ -657,4 +657,82 @@ mod test { "# )); } + + #[test] + fn test_if_custom_type_1() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@enum Flags + Flags = { + b = 1 + } + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + + if Flags.a then + end + "# + )); + + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + + if Flags['a'] then + end + "# + )); + } + + #[test] + fn test_if_custom_type_2() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Flags + ---@field a number + Flags = {} + "#, + ); + + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + if Flags.b then + end + "# + )); + + assert!(ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + if Flags["b"] then + end + "# + )); + + assert!(ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + ---@type string + local a + if Flags[a] then + end + "# + )); + + assert!(ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + ---@type string + local c + if Flags[c] then + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/lib.rs b/crates/emmylua_code_analysis/src/lib.rs index aae05b98d..1348e779c 100644 --- a/crates/emmylua_code_analysis/src/lib.rs +++ b/crates/emmylua_code_analysis/src/lib.rs @@ -233,6 +233,28 @@ impl EmmyLuaAnalysis { self.compilation.update_index(lib_file_ids); self.compilation.update_index(main_file_ids); } + + /// 清理文件系统中不再存在的文件 + pub fn cleanup_nonexistent_files(&mut self) { + let mut files_to_remove = Vec::new(); + + // 获取所有当前在VFS中的文件 + let vfs = self.compilation.get_db().get_vfs(); + for file_id in vfs.get_all_file_ids() { + if let Some(path) = vfs.get_file_path(&file_id) { + if !path.exists() { + if let Some(uri) = file_path_to_uri(path) { + files_to_remove.push(uri); + } + } + } + } + + // 移除不存在的文件 + for uri in files_to_remove { + self.remove_file_by_uri(&uri); + } + } } unsafe impl Send for EmmyLuaAnalysis {} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 809a0e3f0..a40a5bc6f 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -10,8 +10,8 @@ use super::{ }, InferFailReason, InferResult, }; -use crate::semantic::generic::instantiate_doc_function; use crate::semantic::infer_expr; +use crate::{semantic::generic::instantiate_doc_function, LuaVarRefId}; use crate::{ CacheEntry, CacheKey, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignatureId, LuaType, LuaTypeDeclId, LuaUnionType, @@ -625,17 +625,28 @@ pub fn infer_call_expr( let prefix_expr = call_expr.get_prefix_expr().ok_or(InferFailReason::None)?; let prefix_type = infer_expr(db, cache, prefix_expr)?; - - Ok(infer_call_expr_func( + let mut ret_type = infer_call_expr_func( db, cache, - call_expr, + call_expr.clone(), prefix_type, &mut InferGuard::new(), None, )? .get_ret() - .clone()) + .clone(); + + let file_id = cache.get_file_id(); + let var_ref_id = LuaVarRefId::SyntaxId(InFiled::new(file_id, call_expr.get_syntax_id())); + let flow_chain = db.get_flow_index().get_flow_chain(file_id, var_ref_id); + if let Some(flow_chain) = flow_chain { + let root = call_expr.get_root(); + for type_assert in flow_chain.get_all_type_asserts() { + ret_type = type_assert.tighten_type(db, cache, &root, ret_type)?; + } + } + + Ok(ret_type) } fn check_can_infer( diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs index c7383b699..439a7b3fe 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs @@ -28,6 +28,7 @@ pub fn infer_index_expr( db: &DbIndex, cache: &mut LuaInferCache, index_expr: LuaIndexExpr, + pass_flow: bool, ) -> InferResult { let prefix_expr = index_expr.get_prefix_expr().ok_or(InferFailReason::None)?; let prefix_type = infer_expr(db, cache, prefix_expr)?; @@ -41,7 +42,16 @@ pub fn infer_index_expr( &mut InferGuard::new(), ) { Ok(member_type) => { - return infer_member_type_pass_flow(db, cache, index_expr, &prefix_type, member_type); + if pass_flow { + return infer_member_type_pass_flow( + db, + cache, + index_expr, + &prefix_type, + member_type, + ); + } + return Ok(member_type); } Err(InferFailReason::FieldNotFound) => InferFailReason::FieldNotFound, Err(err) => return Err(err), @@ -55,7 +65,16 @@ pub fn infer_index_expr( &mut InferGuard::new(), ) { Ok(member_type) => { - return infer_member_type_pass_flow(db, cache, index_expr, &prefix_type, member_type) + if pass_flow { + return infer_member_type_pass_flow( + db, + cache, + index_expr, + &prefix_type, + member_type, + ); + } + return Ok(member_type); } Err(InferFailReason::FieldNotFound) => {} Err(err) => return Err(err), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs index 4554dfdc8..39cf029dd 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs @@ -230,8 +230,17 @@ fn infer_table_field_type_by_parent( ) -> InferResult { let member_id = LuaMemberId::new(field.get_syntax_id(), cache.get_file_id()); if let Some(type_cache) = db.get_type_index().get_type_cache(&member_id.into()) { - match type_cache.as_type() { + let typ = type_cache.as_type(); + match typ { LuaType::TableConst(_) => {} + LuaType::Tuple(tuple) => { + let types = tuple.get_types(); + // 这种情况下缓存的类型可能是不精确的 + if tuple.is_infer_resolve() && types.len() == 1 && types[0].is_unknown() { + } else { + return Ok(typ.clone()); + } + } typ => return Ok(typ.clone()), } } else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index 781177934..9ddd91bf0 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -17,7 +17,7 @@ use infer_binary::infer_binary_expr; use infer_call::infer_call_expr; pub use infer_call::infer_call_expr_func; pub use infer_fail_reason::InferFailReason; -use infer_index::infer_index_expr; +pub use infer_index::infer_index_expr; use infer_name::infer_name_expr; pub use infer_name::{find_self_decl_or_member_id, infer_param}; use infer_table::infer_table_expr; @@ -76,7 +76,7 @@ pub fn infer_expr(db: &DbIndex, cache: &mut LuaInferCache, expr: LuaExpr) -> Inf paren_expr.get_expr().ok_or(InferFailReason::None)?, ), LuaExpr::NameExpr(name_expr) => infer_name_expr(db, cache, name_expr), - LuaExpr::IndexExpr(index_expr) => infer_index_expr(db, cache, index_expr), + LuaExpr::IndexExpr(index_expr) => infer_index_expr(db, cache, index_expr, true), }; match &result_type { @@ -170,7 +170,7 @@ fn get_custom_type_operator( } } -pub fn infer_multi_value_adjusted_expression_types( +pub fn infer_expr_list_types( db: &DbIndex, cache: &mut LuaInferCache, exprs: &[LuaExpr], diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs index dad1cef96..6fc860db8 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs @@ -1,6 +1,10 @@ +use std::sync::Arc; + +use smol_str::SmolStr; + use crate::{ - DbIndex, InferFailReason, InferGuard, LuaMemberKey, LuaMemberOwner, LuaObjectType, - LuaTupleType, LuaType, LuaTypeDeclId, + check_type_compact, DbIndex, InferFailReason, InferGuard, LuaMemberKey, LuaMemberOwner, + LuaObjectType, LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, }; use super::{get_buildin_type_map_type_id, RawGetMemberTypeResult}; @@ -40,6 +44,10 @@ fn infer_raw_member_type_guard( } LuaType::Tuple(tuple) => infer_tuple_raw_member_type(tuple, member_key), LuaType::Object(object) => infer_object_raw_member_type(object, member_key), + LuaType::Array(array_type) => infer_array_raw_member_type(db, array_type, member_key), + LuaType::TableGeneric(table_generic) => { + infer_table_generic_raw_member_type(db, table_generic, member_key) + } // other do not support now _ => Err(InferFailReason::None), } @@ -141,3 +149,49 @@ fn infer_object_raw_member_type( Err(InferFailReason::FieldNotFound) } + +fn infer_array_raw_member_type( + db: &DbIndex, + array_type: &LuaType, + member_key: &LuaMemberKey, +) -> RawGetMemberTypeResult { + let typ = if db.get_emmyrc().strict.array_index { + TypeOps::Union.apply(db, array_type, &LuaType::Nil) + } else { + array_type.clone() + }; + match member_key { + LuaMemberKey::Integer(_) => Ok(typ), + LuaMemberKey::ExprType(member_type) => { + if member_type.is_integer() { + Ok(typ) + } else { + Err(InferFailReason::FieldNotFound) + } + } + _ => Err(InferFailReason::FieldNotFound), + } +} + +fn infer_table_generic_raw_member_type( + db: &DbIndex, + table_params: &Arc>, + member_key: &LuaMemberKey, +) -> RawGetMemberTypeResult { + if table_params.len() != 2 { + return Err(InferFailReason::None); + } + let key_type = &table_params[0]; + let value_type = &table_params[1]; + let access_key_type = match member_key { + LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), + LuaMemberKey::Name(name) => LuaType::StringConst(SmolStr::new(name.as_str()).into()), + LuaMemberKey::ExprType(expr_type) => expr_type.clone(), + LuaMemberKey::None => return Err(InferFailReason::FieldNotFound), + }; + if check_type_compact(db, key_type, &access_key_type).is_ok() { + return Ok(value_type.clone()); + } + + Err(InferFailReason::FieldNotFound) +} diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index cf14099d0..49101b655 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -19,9 +19,11 @@ use emmylua_parser::{ LuaCallExpr, LuaChunk, LuaExpr, LuaIndexKey, LuaParseError, LuaSyntaxNode, LuaSyntaxToken, LuaTableExpr, }; -use infer::{infer_bind_value_type, infer_multi_value_adjusted_expression_types}; +pub use infer::infer_index_expr; +use infer::{infer_bind_value_type, infer_expr_list_types}; pub use infer::{infer_table_field_value_should_be, infer_table_should_be}; use lsp_types::Uri; +pub use member::find_index_operations; pub use member::get_member_map; pub use member::LuaMemberInfo; use member::{find_member_origin_owner, find_members}; @@ -151,13 +153,13 @@ impl<'a> SemanticModel<'a> { .ok() } - /// 获取赋值时所有右值类型或调用时所有参数类型或返回时所有返回值类型 - pub fn infer_multi_value_adjusted_expression_types( + /// 推断表达式列表类型, 位于最后的表达式会触发多值推断 + pub fn infer_expr_list_types( &self, exprs: &[LuaExpr], var_count: Option, ) -> Vec<(LuaType, TextRange)> { - infer_multi_value_adjusted_expression_types( + infer_expr_list_types( self.db, &mut self.infer_cache.borrow_mut(), exprs, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index bd7f91168..5647e4f4b 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -334,27 +334,33 @@ fn check_variadic_type_compact( check_guard: TypeCheckGuard, ) -> TypeCheckResult { match &source_type { - VariadicType::Base(source_base) => { - if let LuaType::Variadic(compact_variadic) = compact_type { - match compact_variadic.deref() { - VariadicType::Base(compact_base) => { - if source_base == compact_base { - return Ok(()); - } + VariadicType::Base(source_base) => match compact_type { + LuaType::Variadic(compact_variadic) => match compact_variadic.deref() { + VariadicType::Base(compact_base) => { + if source_base == compact_base { + return Ok(()); } - VariadicType::Multi(compact_multi) => { - for compact_type in compact_multi { - check_simple_type_compact( - db, - source_base, - compact_type, - check_guard.next_level()?, - )?; - } + } + VariadicType::Multi(compact_multi) => { + for compact_type in compact_multi { + check_simple_type_compact( + db, + source_base, + compact_type, + check_guard.next_level()?, + )?; } } + }, + _ => { + check_simple_type_compact( + db, + source_base, + compact_type, + check_guard.next_level()?, + )?; } - } + }, VariadicType::Multi(_) => {} } diff --git a/crates/emmylua_ls/locales/action/zh_CN.yaml b/crates/emmylua_ls/locales/action/zh_CN.yaml index 28e52ac74..3b51c2b42 100644 --- a/crates/emmylua_ls/locales/action/zh_CN.yaml +++ b/crates/emmylua_ls/locales/action/zh_CN.yaml @@ -7,4 +7,11 @@ Disable all diagnostics in current file (%{name}): | Disable all diagnostics in current project (%{name}): | 在此项目禁用诊断 (%{name}) +use cast to remove nil: | + 使用 cast 移除 nil +Do you want to modify the require path?: | + 你想要修改 `require` 的路径吗? + +Modify: | + 修改 diff --git a/crates/emmylua_ls/src/context/client.rs b/crates/emmylua_ls/src/context/client.rs index 41ede7848..1b3ec6d12 100644 --- a/crates/emmylua_ls/src/context/client.rs +++ b/crates/emmylua_ls/src/context/client.rs @@ -5,8 +5,9 @@ use std::{ use lsp_server::{Connection, Message, Notification, RequestId, Response}; use lsp_types::{ - ApplyWorkspaceEditParams, ApplyWorkspaceEditResponse, ConfigurationParams, - PublishDiagnosticsParams, RegistrationParams, ShowMessageParams, UnregistrationParams, + ApplyWorkspaceEditParams, ApplyWorkspaceEditResponse, ConfigurationParams, MessageActionItem, + PublishDiagnosticsParams, RegistrationParams, ShowMessageParams, ShowMessageRequestParams, + UnregistrationParams, }; use serde::de::DeserializeOwned; use tokio::{ @@ -118,6 +119,23 @@ impl ClientProxy { self.send_notification("window/showMessage", message); } + pub async fn show_message_request( + &self, + params: ShowMessageRequestParams, + cancel_token: CancellationToken, + ) -> Option { + let request_id = self.next_id(); + let response = self + .send_request( + request_id, + "window/showMessageRequest", + params, + cancel_token, + ) + .await?; + serde_json::from_value(response.result?).ok() + } + pub fn publish_diagnostics(&self, params: PublishDiagnosticsParams) { self.send_notification("textDocument/publishDiagnostics", params); } diff --git a/crates/emmylua_ls/src/context/file_diagnostic.rs b/crates/emmylua_ls/src/context/file_diagnostic.rs index c162004c4..5f708cfe4 100644 --- a/crates/emmylua_ls/src/context/file_diagnostic.rs +++ b/crates/emmylua_ls/src/context/file_diagnostic.rs @@ -84,6 +84,16 @@ impl FileDiagnostic { } } + /// 清除指定文件的诊断信息 + pub async fn clear_file_diagnostics(&self, uri: lsp_types::Uri) { + let diagnostic_param = lsp_types::PublishDiagnosticsParams { + uri, + diagnostics: vec![], + version: None, + }; + self.client.publish_diagnostics(diagnostic_param); + } + pub async fn add_workspace_diagnostic_task( &self, client_id: ClientId, diff --git a/crates/emmylua_ls/src/context/workspace_manager.rs b/crates/emmylua_ls/src/context/workspace_manager.rs index faf37dced..4d3bd2013 100644 --- a/crates/emmylua_ls/src/context/workspace_manager.rs +++ b/crates/emmylua_ls/src/context/workspace_manager.rs @@ -159,6 +159,10 @@ impl WorkspaceManager { } let mut analysis = analysis.write().await; + + // 在重新索引之前清理不存在的文件 + analysis.cleanup_nonexistent_files(); + analysis.reindex(); file_diagnostic .add_workspace_diagnostic_task(client_id, 500, true) diff --git a/crates/emmylua_ls/src/handlers/code_actions/actions/build_fix_code.rs b/crates/emmylua_ls/src/handlers/code_actions/actions/build_fix_code.rs new file mode 100644 index 000000000..911484433 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/code_actions/actions/build_fix_code.rs @@ -0,0 +1,54 @@ +use std::collections::HashMap; + +use emmylua_code_analysis::SemanticModel; +use emmylua_parser::{LuaAstNode, LuaExpr}; +use lsp_types::{CodeAction, CodeActionKind, CodeActionOrCommand, Range, TextEdit, WorkspaceEdit}; +use rowan::{NodeOrToken, TokenAtOffset}; + +pub fn build_need_check_nil( + semantic_model: &SemanticModel, + actions: &mut Vec, + range: Range, +) -> Option<()> { + let document = semantic_model.get_document(); + let offset = document.get_offset(range.end.line as usize, range.end.character as usize)?; + let root = semantic_model.get_root(); + let token = match root.syntax().token_at_offset(offset.into()) { + TokenAtOffset::Single(token) => token, + TokenAtOffset::Between(_, token) => token, + _ => return None, + }; + // 取上一个token的父节点 + let node_or_token = token.prev_sibling_or_token()?; + match node_or_token { + NodeOrToken::Node(node) => match node { + expr_node if LuaExpr::can_cast(expr_node.kind().into()) => { + let expr = LuaExpr::cast(expr_node)?; + let range = expr.syntax().text_range(); + let mut lsp_range = document.to_lsp_range(range)?; + // 将范围缩小到最尾部的字符 + lsp_range.start.line = lsp_range.end.line; + lsp_range.start.character = lsp_range.end.character; + + let text_edit = TextEdit { + range: lsp_range, + new_text: "--[[@cast -?]]".to_string(), + }; + + actions.push(CodeActionOrCommand::CodeAction(CodeAction { + title: t!("use cast to remove nil").to_string(), + kind: Some(CodeActionKind::QUICKFIX), + edit: Some(WorkspaceEdit { + changes: Some(HashMap::from([(document.get_uri(), vec![text_edit])])), + ..Default::default() + }), + ..Default::default() + })); + } + _ => {} + }, + _ => {} + } + + Some(()) +} diff --git a/crates/emmylua_ls/src/handlers/code_actions/actions/mod.rs b/crates/emmylua_ls/src/handlers/code_actions/actions/mod.rs index 0b03f49e6..6a227ad95 100644 --- a/crates/emmylua_ls/src/handlers/code_actions/actions/mod.rs +++ b/crates/emmylua_ls/src/handlers/code_actions/actions/mod.rs @@ -1,3 +1,5 @@ mod build_disable_code; +mod build_fix_code; pub use build_disable_code::*; +pub use build_fix_code::*; diff --git a/crates/emmylua_ls/src/handlers/code_actions/build_actions.rs b/crates/emmylua_ls/src/handlers/code_actions/build_actions.rs index 82bf34b28..3a0517372 100644 --- a/crates/emmylua_ls/src/handlers/code_actions/build_actions.rs +++ b/crates/emmylua_ls/src/handlers/code_actions/build_actions.rs @@ -6,7 +6,10 @@ use lsp_types::{ NumberOrString, Range, WorkspaceEdit, }; -use crate::handlers::command::{make_disable_code_command, DisableAction}; +use crate::handlers::{ + code_actions::actions::build_need_check_nil, + command::{make_disable_code_command, DisableAction}, +}; use super::actions::{build_disable_file_changes, build_disable_next_line_changes}; @@ -29,7 +32,13 @@ pub fn build_actions( if let Some(code) = diagnostic.code { if let NumberOrString::String(action_string) = code { if let Some(diagnostic_code) = DiagnosticCode::from_str(&action_string).ok() { - add_fix_code_action(&mut actions, diagnostic_code, file_id, diagnostic.range); + add_fix_code_action( + &semantic_model, + &mut actions, + diagnostic_code, + file_id, + diagnostic.range, + ); add_disable_code_action( &semantic_model, &mut actions, @@ -51,12 +60,16 @@ pub fn build_actions( #[allow(unused_variables)] fn add_fix_code_action( + semantic_model: &SemanticModel, actions: &mut Vec, diagnostic_code: DiagnosticCode, file_id: FileId, range: Range, ) -> Option<()> { - Some(()) + match diagnostic_code { + DiagnosticCode::NeedCheckNil => build_need_check_nil(semantic_model, actions, range), + _ => Some(()), + } } fn add_disable_code_action( diff --git a/crates/emmylua_ls/src/handlers/code_actions/mod.rs b/crates/emmylua_ls/src/handlers/code_actions/mod.rs index b329a9a2b..8813be67d 100644 --- a/crates/emmylua_ls/src/handlers/code_actions/mod.rs +++ b/crates/emmylua_ls/src/handlers/code_actions/mod.rs @@ -2,9 +2,10 @@ mod actions; mod build_actions; use build_actions::build_actions; +use emmylua_code_analysis::{EmmyLuaAnalysis, FileId}; use lsp_types::{ ClientCapabilities, CodeActionParams, CodeActionProviderCapability, CodeActionResponse, - ServerCapabilities, + Diagnostic, ServerCapabilities, }; use tokio_util::sync::CancellationToken; @@ -22,6 +23,14 @@ pub async fn on_code_action_handler( let diagnostics = params.context.diagnostics; let analysis = context.analysis.read().await; let file_id = analysis.get_file_id(&uri)?; + code_action(&analysis, file_id, diagnostics) +} + +pub fn code_action( + analysis: &EmmyLuaAnalysis, + file_id: FileId, + diagnostics: Vec, +) -> Option { let mut semantic_model = analysis.compilation.get_semantic_model(file_id)?; build_actions(&mut semantic_model, diagnostics) diff --git a/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs b/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs index 5d6121200..6dbc7427c 100644 --- a/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs +++ b/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs @@ -21,6 +21,7 @@ impl CommandSpec for AutoRequireCommand { let add_to: FileId = serde_json::from_value(args.get(0)?.clone()).ok()?; let need_require_file_id: FileId = serde_json::from_value(args.get(1)?.clone()).ok()?; let position: Position = serde_json::from_value(args.get(2)?.clone()).ok()?; + let member_name: String = serde_json::from_value(args.get(3)?.clone()).ok()?; let analysis = context.analysis.read().await; let semantic_model = analysis.compilation.get_semantic_model(add_to)?; @@ -32,7 +33,7 @@ impl CommandSpec for AutoRequireCommand { let require_like_func = &emmyrc.runtime.require_like_function; let auto_require_func = emmyrc.completion.auto_require_function.clone(); let file_conversion = emmyrc.completion.auto_require_naming_convention; - let local_name = module_name_convert(&module_info.name, file_conversion); + let local_name = module_name_convert(&module_info, file_conversion); let require_separator = emmyrc.completion.auto_require_separator.clone(); let full_module_path = match require_separator.as_str() { "." | "" => module_info.full_module_name.clone(), @@ -42,8 +43,19 @@ impl CommandSpec for AutoRequireCommand { }; let require_str = format!( - "local {} = {}(\"{}\")", - local_name, auto_require_func, full_module_path + "local {} = {}(\"{}\"){}", + if member_name.is_empty() { + local_name + } else { + member_name.clone() + }, + auto_require_func, + full_module_path, + if !member_name.is_empty() { + format!(".{}", member_name) + } else { + "".to_string() + } ); let document = semantic_model.get_document(); let offset = document.get_offset(position.line as usize, position.character as usize)?; @@ -157,11 +169,13 @@ pub fn make_auto_require( add_to: FileId, need_require_file_id: FileId, position: Position, + member_name: Option, ) -> Command { let args = vec![ serde_json::to_value(add_to).unwrap(), serde_json::to_value(need_require_file_id).unwrap(), serde_json::to_value(position).unwrap(), + serde_json::to_value(member_name.unwrap_or_default()).unwrap(), ]; Command { diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs index 72eb0adf0..95579f38c 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs @@ -1,9 +1,13 @@ use emmylua_code_analysis::{DbIndex, LuaMemberInfo, LuaMemberKey, LuaSemanticDeclId, LuaType}; -use emmylua_parser::LuaTokenKind; +use emmylua_parser::{ + LuaAssignStat, LuaAstNode, LuaAstToken, LuaFuncStat, LuaGeneralToken, LuaIndexExpr, + LuaParenExpr, LuaTokenKind, +}; use lsp_types::CompletionItem; use crate::handlers::completion::{ completion_builder::CompletionBuilder, completion_data::CompletionData, + providers::get_function_remove_nil, }; use super::{ @@ -54,14 +58,14 @@ pub fn add_member_completion( }, }; - let display = get_call_show(builder.semantic_model.get_db(), &member_info.typ, status) - .unwrap_or(CallDisplay::None); - let typ = member_info.typ; - if status == CompletionTriggerStatus::Colon && !typ.is_function() { + let remove_nil_type = + get_function_remove_nil(&builder.semantic_model.get_db(), &typ).unwrap_or(typ); + if status == CompletionTriggerStatus::Colon && !remove_nil_type.is_function() { return None; } + // 附加数据, 用于在`resolve`时进一步处理 let completion_data = if let Some(id) = &property_owner { if let Some(index) = member_info.overload_index { CompletionData::from_overload( @@ -81,10 +85,12 @@ pub fn add_member_completion( None }; + let call_display = get_call_show(builder.semantic_model.get_db(), &remove_nil_type, status) + .unwrap_or(CallDisplay::None); // 紧靠着 label 显示的描述 - let detail = get_detail(builder, &typ, display); + let detail = get_detail(builder, &remove_nil_type, call_display); // 在`detail`更右侧, 且不紧靠着`detail`显示 - let description = get_description(builder, &typ); + let description = get_description(builder, &remove_nil_type); let deprecated = if let Some(id) = &property_owner { Some(is_deprecated(builder, id.clone())) @@ -94,7 +100,7 @@ pub fn add_member_completion( let mut completion_item = CompletionItem { label: label.clone(), - kind: Some(get_completion_kind(&typ)), + kind: Some(get_completion_kind(&remove_nil_type)), data: completion_data, label_details: Some(lsp_types::CompletionItemLabelDetails { detail, @@ -116,6 +122,20 @@ pub fn add_member_completion( new_text: "".to_string(), }]); } + // 对于函数的定义时的特殊处理 + if matches!( + status, + CompletionTriggerStatus::Dot | CompletionTriggerStatus::Colon + ) && (builder.trigger_token.kind() == LuaTokenKind::TkDot.into() + || builder.trigger_token.kind() == LuaTokenKind::TkColon.into()) + { + resolve_function_params( + builder, + &mut completion_item, + &remove_nil_type, + call_display, + ); + } builder.add_completion_item(completion_item)?; @@ -123,12 +143,12 @@ pub fn add_member_completion( add_signature_overloads( builder, property_owner, - &typ, - display, + &remove_nil_type, + call_display, deprecated, label, function_overload_count, - )?; + ); Some(()) } @@ -137,52 +157,50 @@ fn add_signature_overloads( builder: &mut CompletionBuilder, property_owner: &Option, typ: &LuaType, - display: CallDisplay, + call_display: CallDisplay, deprecated: Option, label: String, - function_overload_count: Option, + overload_count: Option, ) -> Option<()> { - if let LuaType::Signature(signature_id) = typ { - let overloads = builder - .semantic_model - .get_db() - .get_signature_index() - .get(&signature_id)? - .overloads - .clone(); - - overloads - .into_iter() - .enumerate() - .for_each(|(index, overload)| { - let typ = LuaType::DocFunction(overload); - let description = get_description(builder, &typ); - let detail = get_detail(builder, &typ, display); - let data = if let Some(id) = &property_owner { - CompletionData::from_overload( - builder, - id.clone().into(), - index, - function_overload_count, - ) - } else { - None - }; - let completion_item = CompletionItem { - label: label.clone(), - kind: Some(get_completion_kind(&typ)), - data, - label_details: Some(lsp_types::CompletionItemLabelDetails { - detail, - description, - }), - deprecated, - ..Default::default() - }; - - builder.add_completion_item(completion_item); - }); + let signature_id = match typ { + LuaType::Signature(signature_id) => signature_id, + _ => return None, }; + + let overloads = builder + .semantic_model + .get_db() + .get_signature_index() + .get(&signature_id)? + .overloads + .clone(); + + overloads + .into_iter() + .enumerate() + .for_each(|(index, overload)| { + let typ = LuaType::DocFunction(overload); + let description = get_description(builder, &typ); + let detail = get_detail(builder, &typ, call_display); + let data = if let Some(id) = &property_owner { + CompletionData::from_overload(builder, id.clone().into(), index, overload_count) + } else { + None + }; + let completion_item = CompletionItem { + label: label.clone(), + kind: Some(get_completion_kind(&typ)), + data, + label_details: Some(lsp_types::CompletionItemLabelDetails { + detail, + description, + }), + deprecated, + ..Default::default() + }; + + builder.add_completion_item(completion_item); + }); Some(()) } @@ -212,3 +230,85 @@ fn get_call_show( _ => Some(CallDisplay::None), } } + +/// 在定义函数时, 是否需要补全参数列表, 只补全原类型为`docfunction`的函数 +/// ```lua +/// ---@class A +/// ---@field on_add fun(self: A, a: string, b: string) +/// +/// ---@type A +/// local a +/// function a:() end +/// ``` +fn resolve_function_params( + builder: &mut CompletionBuilder, + completion_item: &mut CompletionItem, + typ: &LuaType, + call_display: CallDisplay, +) -> Option<()> { + // 目前仅允许`completion_item.label`存在值时触发 + if completion_item.insert_text.is_some() || completion_item.text_edit.is_some() { + return None; + } + let new_text = get_resolve_function_params_str(&typ, call_display)?; + let index_expr = LuaIndexExpr::cast(builder.trigger_token.parent()?)?; + let func_stat = index_expr.get_parent::()?; + // 从 ast 解析 + if func_stat.get_closure().is_some() { + return None; + } + let next_sibling = func_stat.syntax().next_sibling()?; + let assign_stat = LuaAssignStat::cast(next_sibling)?; + let paren_expr = assign_stat.child::()?; + // 如果 ast 中包含了参数, 则不补全 + if let Some(_) = paren_expr.get_expr() { + return None; + } + let left_paren = paren_expr.token::()?; + if left_paren.get_token_kind() != LuaTokenKind::TkLeftParen.into() { + return None; + } + // 可能不稳定! 因为 completion_item.label 先被应用, 然后再应用本项, 此时 range 发生了改变 + let document = builder.semantic_model.get_document(); + // 先取得左括号位置 + let add_range = left_paren.syntax().text_range(); + let mut lsp_add_range = document.to_lsp_range(add_range)?; + // 必须要移动一位字符, 不能与 label 的插入位置重复 + lsp_add_range.start.character += 1; + if new_text.is_empty() { + return None; + } + + completion_item.additional_text_edits = Some(vec![lsp_types::TextEdit { + range: lsp_add_range, + new_text: new_text, + }]); + + Some(()) +} + +fn get_resolve_function_params_str(typ: &LuaType, display: CallDisplay) -> Option { + match typ { + LuaType::DocFunction(f) => { + let mut params_str = f + .get_params() + .iter() + .map(|param| param.0.clone()) + .collect::>(); + + match display { + CallDisplay::AddSelf => { + params_str.insert(0, "self".to_string()); + } + CallDisplay::RemoveFirst => { + if !params_str.is_empty() { + params_str.remove(0); + } + } + _ => {} + } + Some(format!("{}", params_str.join(", "))) + } + _ => None, + } +} diff --git a/crates/emmylua_ls/src/handlers/completion/mod.rs b/crates/emmylua_ls/src/handlers/completion/mod.rs index 8b1a72e16..92b40d84b 100644 --- a/crates/emmylua_ls/src/handlers/completion/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/mod.rs @@ -115,6 +115,7 @@ pub fn completion_resolve( .get_semantic_model(completion_data.field_id); if let Some(semantic_model) = semantic_model { resolve_completion( + &analysis.compilation, &semantic_model, db, &mut completion_item, diff --git a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs index e737ca86a..127d74102 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::{EmmyrcFilenameConvention, ModuleInfo}; +use emmylua_code_analysis::{EmmyrcFilenameConvention, LuaType, ModuleInfo}; use emmylua_parser::{LuaAstNode, LuaNameExpr}; use lsp_types::{CompletionItem, Position}; @@ -7,7 +7,7 @@ use crate::{ command::make_auto_require, completion::{completion_builder::CompletionBuilder, completion_data::CompletionData}, }, - util::module_name_convert, + util::{key_name_convert, module_name_convert}, }; pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> { @@ -68,8 +68,16 @@ fn add_module_completion_item( position: Position, completions: &mut Vec, ) -> Option<()> { - let completion_name = module_name_convert(&module_info.name, file_conversion); + let completion_name = module_name_convert(module_info, file_conversion); if !completion_name.to_lowercase().starts_with(prefix) { + try_add_member_completion_items( + builder, + prefix, + module_info, + file_conversion, + position, + completions, + ); return None; } @@ -94,6 +102,7 @@ fn add_module_completion_item( builder.semantic_model.get_file_id(), module_info.file_id, position, + None, )), data, ..Default::default() @@ -103,3 +112,79 @@ fn add_module_completion_item( Some(()) } + +fn try_add_member_completion_items( + builder: &CompletionBuilder, + prefix: &str, + module_info: &ModuleInfo, + file_conversion: EmmyrcFilenameConvention, + position: Position, + completions: &mut Vec, +) -> Option<()> { + if let Some(export_type) = &module_info.export_type { + match export_type { + LuaType::TableConst(_) | LuaType::Def(_) => { + let member_infos = builder.semantic_model.get_member_infos(export_type)?; + for member_info in member_infos { + let key_name = key_name_convert( + &member_info.key.to_path(), + &member_info.typ, + file_conversion, + ); + match member_info.typ { + LuaType::Ref(_) | LuaType::Def(_) => {} + LuaType::Signature(_) => {} + _ => { + continue; + } + } + + if key_name.to_lowercase().starts_with(prefix) { + if builder.env_duplicate_name.contains(&key_name) { + continue; + } + + let data = if let Some(property_owner_id) = &member_info.property_owner_id { + let is_visible = builder.semantic_model.is_semantic_visible( + builder.trigger_token.clone(), + property_owner_id.clone(), + ); + if !is_visible { + continue; + } + CompletionData::from_property_owner_id( + builder, + property_owner_id.clone(), + None, + ) + } else { + None + }; + + let completion_item = CompletionItem { + label: key_name, + kind: Some(lsp_types::CompletionItemKind::MODULE), + label_details: Some(lsp_types::CompletionItemLabelDetails { + detail: Some(format!(" (in {})", module_info.full_module_name)), + ..Default::default() + }), + command: Some(make_auto_require( + "", + builder.semantic_model.get_file_id(), + module_info.file_id, + position, + Some(member_info.key.to_path().to_string()), + )), + data, + ..Default::default() + }; + + completions.push(completion_item); + } + } + } + _ => {} + } + } + Some(()) +} diff --git a/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs index 0a726acdf..944aef766 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs @@ -32,6 +32,12 @@ pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> { DocCompletionExpected::ClassAttr => { add_tag_class_attr_completion(builder); } + DocCompletionExpected::Namespace => { + add_tag_namespace_completion(builder); + } + DocCompletionExpected::Using => { + add_tag_using_completion(builder); + } } builder.stop_here(); @@ -66,6 +72,8 @@ fn get_doc_completion_expected(trigger_token: &LuaSyntaxToken) -> Option None, } } + LuaTokenKind::TkTagNamespace => Some(DocCompletionExpected::Namespace), + LuaTokenKind::TkTagUsing => Some(DocCompletionExpected::Using), LuaTokenKind::TkComma => { let parent = left_token.parent()?; match parent.kind().into() { @@ -112,6 +120,8 @@ enum DocCompletionExpected { DiagnosticAction, DiagnosticCode, ClassAttr, + Namespace, + Using, } fn add_tag_param_name_completion(builder: &mut CompletionBuilder) -> Option<()> { @@ -253,3 +263,46 @@ fn add_tag_class_attr_completion(builder: &mut CompletionBuilder) { builder.add_completion_item(completion_item); } } + +fn add_tag_namespace_completion(builder: &mut CompletionBuilder) { + let type_index = builder.semantic_model.get_db().get_type_index(); + let file_id = builder.semantic_model.get_file_id(); + if type_index.get_file_namespace(&file_id).is_some() { + return; + } + let mut namespaces = type_index.get_file_namespaces(); + + namespaces.sort(); + + for (sorted_index, namespace) in namespaces.iter().enumerate() { + let completion_item = CompletionItem { + label: namespace.clone(), + kind: Some(lsp_types::CompletionItemKind::MODULE), + sort_text: Some(format!("{:03}", sorted_index)), + ..Default::default() + }; + builder.add_completion_item(completion_item); + } +} + +fn add_tag_using_completion(builder: &mut CompletionBuilder) { + let type_index = builder.semantic_model.get_db().get_type_index(); + let file_id = builder.semantic_model.get_file_id(); + let current_namespace = type_index.get_file_namespace(&file_id); + let mut namespaces = type_index.get_file_namespaces(); + if let Some(current_namespace) = current_namespace { + namespaces.retain(|namespace| namespace != current_namespace); + } + namespaces.sort(); + + for (sorted_index, namespace) in namespaces.iter().enumerate() { + let completion_item = CompletionItem { + label: format!("using {}", namespace), + kind: Some(lsp_types::CompletionItemKind::MODULE), + sort_text: Some(format!("{:03}", sorted_index)), + insert_text: Some(format!("{}", namespace)), + ..Default::default() + }; + builder.add_completion_item(completion_item); + } +} diff --git a/crates/emmylua_ls/src/handlers/completion/providers/env_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/env_provider.rs index 4b63ee654..70f3d5e4b 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/env_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/env_provider.rs @@ -1,7 +1,9 @@ use std::collections::HashSet; use emmylua_code_analysis::{LuaFlowId, LuaSignatureId, LuaType, LuaVarRefId}; -use emmylua_parser::{LuaAst, LuaAstNode, LuaCallArgList, LuaClosureExpr, LuaParamList}; +use emmylua_parser::{ + LuaAst, LuaAstNode, LuaCallArgList, LuaClosureExpr, LuaParamList, LuaTokenKind, +}; use lsp_types::{CompletionItem, CompletionItemKind, CompletionTriggerKind}; use crate::handlers::completion::{ @@ -56,6 +58,10 @@ fn check_can_add_completion(builder: &CompletionBuilder) -> Option<()> { } } else if builder.trigger_kind == CompletionTriggerKind::INVOKED { let parent = builder.trigger_token.parent()?; + let prev_token = builder.trigger_token.prev_token()?; + if prev_token.kind() == LuaTokenKind::TkTagUsing.into() { + return None; + } // 即时是主动触发, 也不允许在函数定义的参数列表中添加 if trigger_text == "(" { if LuaParamList::can_cast(parent.kind().into()) { diff --git a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs index 69dc2a594..5bda5088f 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs @@ -1,12 +1,14 @@ +use std::sync::Arc; + use emmylua_code_analysis::{ - DbIndex, InferGuard, LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, - LuaMultiLineUnion, LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, - LuaUnionType, RenderLevel, SemanticDeclLevel, + get_real_type, DbIndex, InferGuard, LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, + LuaMemberOwner, LuaMultiLineUnion, LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, + LuaTypeDeclId, LuaUnionType, RenderLevel, SemanticDeclLevel, }; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaClosureExpr, LuaComment, - LuaDocTagParam, LuaLiteralExpr, LuaLiteralToken, LuaNameToken, LuaParamList, LuaStat, - LuaSyntaxId, LuaSyntaxKind, LuaSyntaxToken, LuaTokenKind, LuaVarExpr, + LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaClosureExpr, + LuaComment, LuaDocTagParam, LuaLiteralExpr, LuaLiteralToken, LuaNameToken, LuaParamList, + LuaStat, LuaSyntaxId, LuaSyntaxKind, LuaSyntaxToken, LuaTokenKind, LuaVarExpr, }; use itertools::Itertools; use lsp_types::{CompletionItem, Documentation}; @@ -46,6 +48,33 @@ fn get_token_should_type(builder: &mut CompletionBuilder) -> Option } return infer_param_list(builder, LuaParamList::cast(parent_node)?); } + LuaSyntaxKind::Block => { + /* + 补全以下形式: + ```lua + ---@class A + ---@field func fun(a: string) + + ---@type A + local a + + a.func = + ``` + */ + let prev_token = token.prev_token()?; + let assign_stat = LuaAssignStat::cast(prev_token.parent()?)?; + let (vars, exprs) = assign_stat.get_var_and_expr_list(); + if vars.len() != 1 || !exprs.is_empty() { + return None; + } + let var = vars.first()?; + let var_type = builder.semantic_model.infer_expr(var.clone().into()).ok()?; + let real_type = get_real_type(&builder.semantic_model.get_db(), &var_type)?; + return Some(vec![get_function_remove_nil( + &builder.semantic_model.get_db(), + &real_type, + )?]); + } _ => {} } @@ -76,7 +105,6 @@ pub fn dispatch_type( LuaType::StrTplRef(key) => { add_str_tpl_ref_completion(builder, &key); } - _ => {} } @@ -467,8 +495,8 @@ fn add_lambda_completion(builder: &mut CompletionBuilder, func: &LuaFunctionType .iter() .map(|p| p.0.clone()) .collect::>(); - let label = format!("function ({}) end", params_str.join(", ")); - let insert_text = format!("function ({})\n\t$0\nend", params_str.join(", ")); + let label = format!("function({}) end", params_str.join(", ")); + let insert_text = format!("function({})\n\t$0\nend", params_str.join(", ")); let completion_item = CompletionItem { label, @@ -750,3 +778,40 @@ fn get_str_tpl_ref_extend_type( _ => None, } } + +/// 确保所有成员均为 function 或者 nil, 然后返回 function 的联合类型, 如果非 function 则返回 None +pub fn get_function_remove_nil(db: &DbIndex, typ: &LuaType) -> Option { + match typ { + LuaType::Union(union_typ) => { + let mut new_types = Vec::new(); + for member in union_typ.get_types().iter() { + match member { + _ if member.is_function() => { + new_types.push(member.clone()); + } + _ if member.is_custom_type() => { + let real_type = get_real_type(db, member)?; + if real_type.is_function() { + new_types.push(real_type.clone()); + } + } + _ if member.is_nil() => { + continue; + } + _ => { + return None; + } + } + } + match new_types.len() { + 0 => None, + 1 => Some(new_types[0].clone()), + _ => Some(LuaType::Union(Arc::new(LuaUnionType::new(new_types)))), + } + } + _ if typ.is_function() => { + return Some(typ.clone()); + } + _ => None, + } +} diff --git a/crates/emmylua_ls/src/handlers/completion/providers/keywords_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/keywords_provider.rs index 93e962dad..48e51f9b5 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/keywords_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/keywords_provider.rs @@ -78,14 +78,29 @@ fn add_stat_keyword_completions( continue; } + let (label_detail, insert_text) = + if matches!(keyword_info.label, "function" | "local function") + && !base_function_includes_name(builder) + { + ( + keyword_info.detail.replace("name", ""), + keyword_info.insert_text.replace("name", ""), + ) + } else { + ( + keyword_info.detail.to_string(), + keyword_info.insert_text.to_string(), + ) + }; + let item = CompletionItem { label: keyword_info.label.to_string(), kind: Some(keyword_info.kind), label_details: Some(CompletionItemLabelDetails { - detail: Some(keyword_info.detail.to_string()), + detail: Some(label_detail), ..CompletionItemLabelDetails::default() }), - insert_text: Some(keyword_info.insert_text.to_string()), + insert_text: Some(insert_text), insert_text_format: Some(InsertTextFormat::SNIPPET), insert_text_mode: Some(InsertTextMode::ADJUST_INDENTATION), ..CompletionItem::default() @@ -141,3 +156,11 @@ fn add_function_keyword_completions(builder: &mut CompletionBuilder) -> Option<( Some(()) } + +fn base_function_includes_name(builder: &CompletionBuilder) -> bool { + builder + .semantic_model + .get_emmyrc() + .completion + .base_function_includes_name +} diff --git a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs index 33e90b6ec..639208117 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs @@ -38,6 +38,7 @@ pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> { { return None; } + let member_info_map = builder.semantic_model.get_member_info_map(&prefix_type)?; for (_, member_infos) in member_info_map.iter() { add_resolve_member_infos(builder, &member_infos, completion_status); @@ -52,7 +53,7 @@ fn add_resolve_member_infos( completion_status: CompletionTriggerStatus, ) -> Option<()> { if member_infos.len() == 1 { - let function_count = count_function_overloads( + let overload_count = count_function_overloads( builder.semantic_model.get_db(), &member_infos.iter().map(|info| info).collect::>(), ); @@ -61,7 +62,7 @@ fn add_resolve_member_infos( builder, member_info.clone(), completion_status, - function_count, + overload_count, ); return Some(()); } @@ -170,7 +171,7 @@ fn count_function_overloads(db: &DbIndex, member_infos: &Vec<&LuaMemberInfo>) -> _ => {} } } - if count > 1 { + if count >= 1 { count -= 1; } if count == 0 { diff --git a/crates/emmylua_ls/src/handlers/completion/providers/mod.rs b/crates/emmylua_ls/src/handlers/completion/providers/mod.rs index a73c51798..2a4760a13 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/mod.rs @@ -12,10 +12,9 @@ mod module_path_provider; mod postfix_provider; mod table_field_provider; -use emmylua_code_analysis::DbIndex; -use emmylua_code_analysis::LuaType; use emmylua_parser::LuaAstToken; use emmylua_parser::LuaStringToken; +pub use function_provider::get_function_remove_nil; use rowan::TextRange; use super::completion_builder::CompletionBuilder; @@ -75,30 +74,3 @@ fn get_text_edit_range_in_string( lsp_range } - -pub fn get_real_type<'a>(db: &'a DbIndex, compact_type: &'a LuaType) -> Option<&'a LuaType> { - get_real_type_with_depth(db, compact_type, 0) -} - -fn get_real_type_with_depth<'a>( - db: &'a DbIndex, - compact_type: &'a LuaType, - depth: u32, -) -> Option<&'a LuaType> { - const MAX_RECURSION_DEPTH: u32 = 100; - - if depth >= MAX_RECURSION_DEPTH { - return Some(compact_type); - } - - match compact_type { - LuaType::Ref(type_decl_id) => { - let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; - if type_decl.is_alias() { - return get_real_type_with_depth(db, type_decl.get_alias_ref()?, depth + 1); - } - Some(compact_type) - } - _ => Some(compact_type), - } -} diff --git a/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs index 2a8133e1f..3961775b3 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; -use emmylua_code_analysis::{LuaMemberInfo, LuaMemberKey, LuaType}; -use emmylua_parser::{LuaAst, LuaAstNode, LuaTableExpr, LuaTableField}; +use emmylua_code_analysis::{get_real_type, LuaMemberInfo, LuaMemberKey, LuaType}; +use emmylua_parser::{LuaAst, LuaAstNode, LuaKind, LuaTableExpr, LuaTableField, LuaTokenKind}; use lsp_types::{CompletionItem, InsertTextFormat, InsertTextMode}; use rowan::NodeOrToken; @@ -11,8 +11,6 @@ use crate::handlers::completion::{ completion_data::CompletionData, }; -use super::get_real_type; - pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> { add_table_field_key_completion(builder); add_table_field_value_completion(builder); @@ -22,9 +20,25 @@ pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> { fn add_table_field_key_completion(builder: &mut CompletionBuilder) -> Option<()> { if !can_add_key_completion(builder) { - return Some(()); + return None; + } + // 出现以下情况则代表是补全 value + let prev_token = builder.trigger_token.prev_token()?; + if builder.trigger_token.kind() == LuaKind::Token(LuaTokenKind::TkWhitespace) + && prev_token.kind() == LuaKind::Token(LuaTokenKind::TkAssign) + { + return None; } - let table_expr = get_table_expr(builder)?; + + let node = LuaAst::cast(builder.trigger_token.parent()?)?; + let table_expr = match node { + LuaAst::LuaTableExpr(table_expr) => Some(table_expr), + LuaAst::LuaNameExpr(name_expr) => name_expr + .get_parent::()? + .get_parent::(), + _ => None, + }?; + let table_type = builder .semantic_model .infer_table_should_be(table_expr.clone())?; @@ -67,18 +81,6 @@ fn can_add_key_completion(builder: &mut CompletionBuilder) -> bool { true } -fn get_table_expr(builder: &mut CompletionBuilder) -> Option { - let node = LuaAst::cast(builder.trigger_token.parent()?)?; - - match node { - LuaAst::LuaTableExpr(table_expr) => Some(table_expr), - LuaAst::LuaNameExpr(name_expr) => name_expr - .get_parent::()? - .get_parent::(), - _ => None, - } -} - fn add_field_key_completion( builder: &mut CompletionBuilder, member_info: LuaMemberInfo, diff --git a/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs b/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs index 5593ce6d7..5a4cc02e9 100644 --- a/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::{DbIndex, SemanticModel}; +use emmylua_code_analysis::{DbIndex, LuaCompilation, SemanticModel}; use lsp_types::{CompletionItem, Documentation, MarkedString, MarkupContent}; use crate::{ @@ -9,6 +9,7 @@ use crate::{ use super::completion_data::{CompletionData, CompletionDataType}; pub fn resolve_completion( + compilation: &LuaCompilation, semantic_model: &SemanticModel, db: &DbIndex, completion_item: &mut CompletionItem, @@ -18,7 +19,8 @@ pub fn resolve_completion( // todo: resolve completion match completion_data.typ { CompletionDataType::PropertyOwnerId(property_id) => { - let hover_builder = build_hover_content_for_completion(semantic_model, db, property_id); + let hover_builder = + build_hover_content_for_completion(compilation, semantic_model, db, property_id); if let Some(mut hover_builder) = hover_builder { update_function_signature_info( &mut hover_builder, @@ -32,7 +34,8 @@ pub fn resolve_completion( } } CompletionDataType::Overload((property_id, index)) => { - let hover_builder = build_hover_content_for_completion(semantic_model, db, property_id); + let hover_builder = + build_hover_content_for_completion(compilation, semantic_model, db, property_id); if let Some(mut hover_builder) = hover_builder { update_function_signature_info( &mut hover_builder, diff --git a/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs b/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs index 5bd04f8e1..0decc44ba 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs @@ -1,20 +1,24 @@ use std::str::FromStr; use emmylua_code_analysis::{ - LuaDeclId, LuaMemberId, LuaMemberInfo, LuaMemberKey, LuaSemanticDeclId, LuaType, LuaTypeDeclId, - SemanticDeclLevel, SemanticModel, + LuaCompilation, LuaDeclId, LuaMemberId, LuaMemberInfo, LuaMemberKey, LuaSemanticDeclId, + LuaType, LuaTypeDeclId, SemanticDeclLevel, SemanticModel, }; use emmylua_parser::{ - LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr, LuaStringToken, LuaSyntaxToken, - LuaTableExpr, LuaTableField, + LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr, LuaReturnStat, LuaStringToken, + LuaSyntaxToken, LuaTableExpr, LuaTableField, }; use itertools::Itertools; use lsp_types::{GotoDefinitionResponse, Location, Position, Range, Uri}; -use crate::handlers::hover::find_all_same_named_members; +use crate::handlers::{ + definition::goto_function::find_call_match_function, + hover::{find_all_same_named_members, find_member_origin_owner}, +}; pub fn goto_def_definition( semantic_model: &SemanticModel, + compilation: &LuaCompilation, property_owner: LuaSemanticDeclId, trigger_token: &LuaSyntaxToken, ) -> Option { @@ -29,15 +33,9 @@ pub fn goto_def_definition( } } } - match property_owner { LuaSemanticDeclId::LuaDecl(decl_id) => { - let decl = semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; - let document = semantic_model.get_document_by_file_id(decl_id.file_id)?; - let location = document.to_lsp_location(decl.get_range())?; + let location = get_decl_location(semantic_model, &decl_id)?; return Some(GotoDefinitionResponse::Scalar(location)); } LuaSemanticDeclId::Member(member_id) => { @@ -47,6 +45,43 @@ pub fn goto_def_definition( )?; let mut locations: Vec = Vec::new(); + // 如果是函数调用, 则尝试寻找最匹配的定义 + if let Some(match_members) = + find_call_match_function(semantic_model, trigger_token, &same_named_members) + { + for member in match_members { + if let LuaSemanticDeclId::Member(member_id) = member { + if let Some(true) = should_trace_member(semantic_model, &member_id) { + // 尝试搜索这个成员最原始的定义 + match find_member_origin_owner(compilation, semantic_model, member_id) { + Some(LuaSemanticDeclId::Member(member_id)) => { + if let Some(location) = + get_member_location(semantic_model, &member_id) + { + locations.push(location); + continue; + } + } + Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { + if let Some(location) = + get_decl_location(semantic_model, &decl_id) + { + locations.push(location); + continue; + } + } + _ => {} + } + } + if let Some(location) = get_member_location(semantic_model, &member_id) { + locations.push(location); + } + } + } + if !locations.is_empty() { + return Some(GotoDefinitionResponse::Array(locations)); + } + } // 添加原始成员的位置 for member in same_named_members { @@ -75,7 +110,7 @@ pub fn goto_def_definition( */ if let Some(table_field_infos) = - find_table_member_definition(semantic_model, trigger_token, &member_id) + find_instance_table_member(semantic_model, trigger_token, &member_id) { for table_field_info in table_field_infos { if let Some(LuaSemanticDeclId::Member(table_member_id)) = @@ -199,7 +234,7 @@ pub fn goto_str_tpl_ref_definition( None } -pub fn find_table_member_definition( +pub fn find_instance_table_member( semantic_model: &SemanticModel, trigger_token: &LuaSyntaxToken, member_id: &LuaMemberId, @@ -277,6 +312,22 @@ fn find_member_in_table_const( ) } +/// 是否对 member 启动追踪 +fn should_trace_member(semantic_model: &SemanticModel, member_id: &LuaMemberId) -> Option { + let root = semantic_model + .get_db() + .get_vfs() + .get_syntax_tree(&member_id.file_id)? + .get_red_root(); + let node = member_id.get_syntax_id().to_node_from_root(&root)?; + let parent = node.parent()?.parent()?; + // 如果成员在返回语句中, 则需要追踪 + if LuaReturnStat::can_cast(parent.kind().into()) { + return Some(true); + } + None +} + fn get_member_location( semantic_model: &SemanticModel, member_id: &LuaMemberId, @@ -284,3 +335,13 @@ fn get_member_location( let document = semantic_model.get_document_by_file_id(member_id.file_id)?; document.to_lsp_location(member_id.get_syntax_id().get_range()) } + +fn get_decl_location(semantic_model: &SemanticModel, decl_id: &LuaDeclId) -> Option { + let decl = semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id)?; + let document = semantic_model.get_document_by_file_id(decl_id.file_id)?; + let location = document.to_lsp_location(decl.get_range())?; + Some(location) +} diff --git a/crates/emmylua_ls/src/handlers/definition/goto_function.rs b/crates/emmylua_ls/src/handlers/definition/goto_function.rs new file mode 100644 index 000000000..44f2d8406 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/definition/goto_function.rs @@ -0,0 +1,117 @@ +use emmylua_code_analysis::{ + instantiate_func_generic, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaType, + SemanticModel, +}; +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken}; +use std::sync::Arc; + +pub fn find_call_match_function( + semantic_model: &SemanticModel, + trigger_token: &LuaSyntaxToken, + semantic_decls: &Vec, +) -> Option> { + let call_expr = LuaCallExpr::cast(trigger_token.parent()?.parent()?)?; + let call_function = get_call_function(semantic_model, &call_expr)?; + let mut result = Vec::new(); + let member_decls: Vec<_> = semantic_decls + .iter() + .filter_map(|decl| match decl { + LuaSemanticDeclId::Member(member_id) => Some((decl, member_id)), + _ => None, + }) + .collect(); + + let mut has_match = false; + for (decl, member_id) in member_decls { + let typ = semantic_model.get_type(member_id.clone().into()); + match typ { + LuaType::DocFunction(func) => { + if compare_function_types(semantic_model, &call_function, &func, &call_expr)? { + result.push(decl.clone()); + has_match = true; + } + } + LuaType::Signature(signature_id) => { + let signature = semantic_model + .get_db() + .get_signature_index() + .get(&signature_id)?; + let functions = get_signature_functions(signature); + + if functions.iter().any(|func| { + compare_function_types(semantic_model, &call_function, func, &call_expr) + .unwrap_or(false) + }) { + has_match = true; + } + // 此处为降低优先级, 因为如果返回多个选项, 那么 vscode 会默认指向最后的选项 + result.insert(0, decl.clone()); + } + _ => continue, + } + } + + if !has_match { + return None; + } + + match result.len() { + 0 => None, + _ => Some(result), + } +} + +/// 获取最匹配的函数(并不能确保完全匹配) +fn get_call_function( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, +) -> Option> { + let func = semantic_model.infer_call_expr_func(call_expr.clone(), None); + if let Some(func) = func { + let call_expr_args_count = call_expr.get_args_count(); + if let Some(mut call_expr_args_count) = call_expr_args_count { + let func_params_count = func.get_params().len(); + if !func.is_colon_define() && call_expr.is_colon_call() { + // 不是冒号定义的函数, 但是是冒号调用 + call_expr_args_count += 1; + } + if call_expr_args_count == func_params_count { + return Some(func); + } + } + } + None +} + +fn get_signature_functions(signature: &LuaSignature) -> Vec> { + let mut functions = Vec::new(); + functions.push(signature.to_doc_func_type()); + functions.extend( + signature + .overloads + .iter() + .map(|overload| Arc::clone(overload)), + ); + functions +} + +/// 比较函数类型是否匹配, 会处理泛型情况 +pub fn compare_function_types( + semantic_model: &SemanticModel, + call_function: &LuaFunctionType, + func: &Arc, + call_expr: &LuaCallExpr, +) -> Option { + if func.contain_tpl() { + let instantiated_func = instantiate_func_generic( + semantic_model.get_db(), + &mut semantic_model.get_config().borrow_mut(), + func, + call_expr.clone(), + ) + .ok()?; + Some(call_function == &instantiated_func) + } else { + Some(call_function == func.as_ref()) + } +} diff --git a/crates/emmylua_ls/src/handlers/definition/goto_module_file.rs b/crates/emmylua_ls/src/handlers/definition/goto_module_file.rs index 31ccc9b66..d7faee8cc 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_module_file.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_module_file.rs @@ -18,6 +18,12 @@ pub fn goto_module_file( let file_id = founded_module.file_id; let document = semantic_model.get_document_by_file_id(file_id)?; let uri = document.get_uri(); + // 确保目标文件存在 + let file_path = document.get_file_path(); + if !file_path.try_exists().unwrap_or(false) { + return None; + } + let lsp_range = document.get_document_lsp_range(); Some(GotoDefinitionResponse::Scalar(Location { diff --git a/crates/emmylua_ls/src/handlers/definition/mod.rs b/crates/emmylua_ls/src/handlers/definition/mod.rs index 5383cf06a..b8ee1157f 100644 --- a/crates/emmylua_ls/src/handlers/definition/mod.rs +++ b/crates/emmylua_ls/src/handlers/definition/mod.rs @@ -1,5 +1,6 @@ mod goto_def_definition; mod goto_doc_see; +mod goto_function; mod goto_module_file; use emmylua_code_analysis::{EmmyLuaAnalysis, FileId, SemanticDeclLevel}; @@ -9,6 +10,7 @@ use emmylua_parser::{ pub use goto_def_definition::goto_def_definition; use goto_def_definition::goto_str_tpl_ref_definition; pub use goto_doc_see::goto_doc_see; +pub use goto_function::compare_function_types; pub use goto_module_file::goto_module_file; use lsp_types::{ ClientCapabilities, GotoDefinitionParams, GotoDefinitionResponse, OneOf, Position, @@ -67,7 +69,12 @@ pub fn definition( if let Some(semantic_decl) = semantic_model.find_decl(token.clone().into(), SemanticDeclLevel::default()) { - return goto_def_definition(&semantic_model, semantic_decl, &token); + return goto_def_definition( + &semantic_model, + &analysis.compilation, + semantic_decl, + &token, + ); } else if let Some(string_token) = LuaStringToken::cast(token.clone()) { if let Some(module_response) = goto_module_file(&semantic_model, string_token.clone()) { return Some(module_response); diff --git a/crates/emmylua_ls/src/handlers/document_symbol/builder.rs b/crates/emmylua_ls/src/handlers/document_symbol/builder.rs index a401011ec..88860a376 100644 --- a/crates/emmylua_ls/src/handlers/document_symbol/builder.rs +++ b/crates/emmylua_ls/src/handlers/document_symbol/builder.rs @@ -82,12 +82,17 @@ impl<'a> DocumentSymbolBuilder<'a> { let id = root.get_syntax_id(); let lua_symbol = self.document_symbols.get(&id).unwrap(); let lsp_range = self.document.to_lsp_range(lua_symbol.range).unwrap(); + let lsp_selection_range = lua_symbol + .selection_range + .and_then(|range| self.document.to_lsp_range(range)) + .unwrap_or_else(|| lsp_range.clone()); + let mut document_symbol = DocumentSymbol { name: lua_symbol.name.clone(), detail: lua_symbol.detail.clone(), kind: lua_symbol.kind, - range: lsp_range.clone(), - selection_range: lsp_range, + range: lsp_range, + selection_range: lsp_selection_range, children: None, tags: None, deprecated: None, @@ -107,6 +112,11 @@ impl<'a> DocumentSymbolBuilder<'a> { for child in &symbol.children { let child_symbol = self.document_symbols.get(child)?; let lsp_range = self.document.to_lsp_range(child_symbol.range)?; + let lsp_selection_range = child_symbol + .selection_range + .and_then(|range| self.document.to_lsp_range(range)) + .unwrap_or_else(|| lsp_range.clone()); + let child_symbol_name = if child_symbol.name.is_empty() { "(empty)".to_string() } else { @@ -117,8 +127,8 @@ impl<'a> DocumentSymbolBuilder<'a> { name: child_symbol_name, detail: child_symbol.detail.clone(), kind: child_symbol.kind, - range: lsp_range.clone(), - selection_range: lsp_range, + range: lsp_range, + selection_range: lsp_selection_range, children: None, tags: None, deprecated: None, @@ -196,6 +206,7 @@ pub struct LuaSymbol { detail: Option, kind: SymbolKind, range: TextRange, + selection_range: Option, children: Vec, } @@ -206,6 +217,24 @@ impl LuaSymbol { detail, kind, range, + selection_range: None, + children: Vec::new(), + } + } + + pub fn with_selection_range( + name: String, + detail: Option, + kind: SymbolKind, + range: TextRange, + selection_range: TextRange, + ) -> Self { + Self { + name, + detail, + kind, + range, + selection_range: Some(selection_range), children: Vec::new(), } } diff --git a/crates/emmylua_ls/src/handlers/document_symbol/stats.rs b/crates/emmylua_ls/src/handlers/document_symbol/stats.rs index f51804755..6bd238fdc 100644 --- a/crates/emmylua_ls/src/handlers/document_symbol/stats.rs +++ b/crates/emmylua_ls/src/handlers/document_symbol/stats.rs @@ -134,11 +134,16 @@ pub fn build_local_func_stat_symbol( let decl = builder.get_decl(&decl_id)?; let typ = builder.get_type(decl_id.into()); let desc = builder.get_symbol_kind_and_detail(Some(&typ)); - let symbol = LuaSymbol::new( + + let full_range = local_func.get_range(); + let name_range = decl.get_range(); + + let symbol = LuaSymbol::with_selection_range( decl.get_name().to_string(), desc.1, desc.0, - decl.get_range(), + full_range, + name_range, ); builder.add_node_symbol(local_func.syntax().clone(), symbol); @@ -156,7 +161,11 @@ pub fn build_func_stat_symbol( let signature_id = LuaSignatureId::from_closure(file_id, &closure); let func_ty = LuaType::Signature(signature_id); let desc = builder.get_symbol_kind_and_detail(Some(&func_ty)); - let symbol = LuaSymbol::new(name, desc.1, desc.0, func.get_range()); + + let full_range = func.get_range(); + let name_range = func_name.get_range(); + + let symbol = LuaSymbol::with_selection_range(name, desc.1, desc.0, full_range, name_range); builder.add_node_symbol(func.syntax().clone(), symbol); Some(()) diff --git a/crates/emmylua_ls/src/handlers/hover/build_hover.rs b/crates/emmylua_ls/src/handlers/hover/build_hover.rs index 392672dfb..9804ace98 100644 --- a/crates/emmylua_ls/src/handlers/hover/build_hover.rs +++ b/crates/emmylua_ls/src/handlers/hover/build_hover.rs @@ -1,8 +1,8 @@ use std::collections::HashSet; use emmylua_code_analysis::{ - DbIndex, LuaDeclId, LuaDocument, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaSignatureId, - LuaType, LuaTypeDeclId, RenderLevel, SemanticInfo, SemanticModel, + DbIndex, LuaCompilation, LuaDeclId, LuaDocument, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, + LuaSignatureId, LuaType, LuaTypeDeclId, RenderLevel, SemanticInfo, SemanticModel, }; use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaExpr, LuaSyntaxToken}; use lsp_types::{Hover, HoverContents, MarkedString, MarkupContent}; @@ -11,8 +11,8 @@ use emmylua_code_analysis::humanize_type; use crate::handlers::hover::{ find_origin::replace_semantic_type, - function_humanize::is_function, - hover_humanize::{hover_function_type, hover_humanize_type}, + function_humanize::{hover_function_type, is_function}, + hover_humanize::hover_humanize_type, }; use super::{ @@ -22,6 +22,7 @@ use super::{ }; pub fn build_semantic_info_hover( + compilation: &LuaCompilation, semantic_model: &SemanticModel, db: &DbIndex, document: &LuaDocument, @@ -33,6 +34,7 @@ pub fn build_semantic_info_hover( return build_hover_without_property(db, document, token, typ); } let hover_builder = build_hover_content( + compilation, semantic_model, db, Some(typ), @@ -64,6 +66,7 @@ fn build_hover_without_property( } pub fn build_hover_content_for_completion<'a>( + compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, db: &DbIndex, property_id: LuaSemanticDeclId, @@ -77,10 +80,19 @@ pub fn build_hover_content_for_completion<'a>( } _ => None, }; - build_hover_content(semantic_model, db, typ, property_id, true, None) + build_hover_content( + compilation, + semantic_model, + db, + typ, + property_id, + true, + None, + ) } fn build_hover_content<'a>( + compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, db: &DbIndex, typ: Option, @@ -88,7 +100,7 @@ fn build_hover_content<'a>( is_completion: bool, token: Option, ) -> Option> { - let mut builder = HoverBuilder::new(semantic_model, token, is_completion); + let mut builder = HoverBuilder::new(compilation, semantic_model, token, is_completion); match property_id { LuaSemanticDeclId::LuaDecl(decl_id) => { let typ = typ?; @@ -114,8 +126,9 @@ fn build_decl_hover( ) -> Option<()> { let decl = db.get_decl_index().get_decl(&decl_id)?; - let mut semantic_decls = find_decl_origin_owners(&builder.semantic_model, decl_id) - .get_types(&builder.semantic_model); + let mut semantic_decls = + find_decl_origin_owners(builder.compilation, &builder.semantic_model, decl_id) + .get_types(&builder.semantic_model); replace_semantic_type(&mut semantic_decls, &typ); // 处理类型签名 if is_function(&typ) { @@ -195,8 +208,13 @@ fn build_member_hover( member_id: LuaMemberId, ) -> Option<()> { let member = db.get_member_index().get_member(&member_id)?; - let mut semantic_decls = find_member_origin_owners(&builder.semantic_model, member_id) - .get_types(&builder.semantic_model); + let mut semantic_decls = find_member_origin_owners( + builder.compilation, + &builder.semantic_model, + member_id, + true, + ) + .get_types(&builder.semantic_model); replace_semantic_type(&mut semantic_decls, &typ); let member_name = match member.get_key() { LuaMemberKey::Name(name) => name.to_string(), diff --git a/crates/emmylua_ls/src/handlers/hover/find_origin.rs b/crates/emmylua_ls/src/handlers/hover/find_origin.rs index ea67b9ef0..b716da063 100644 --- a/crates/emmylua_ls/src/handlers/hover/find_origin.rs +++ b/crates/emmylua_ls/src/handlers/hover/find_origin.rs @@ -1,7 +1,8 @@ use std::collections::HashSet; use emmylua_code_analysis::{ - LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, + LuaCompilation, LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaType, SemanticDeclLevel, + SemanticModel, }; use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaSyntaxKind, LuaTableExpr, LuaTableField}; @@ -42,6 +43,7 @@ impl DeclOriginResult { } pub fn find_decl_origin_owners( + compilation: &LuaCompilation, semantic_model: &SemanticModel, decl_id: LuaDeclId, ) -> DeclOriginResult { @@ -63,7 +65,7 @@ pub fn find_decl_origin_owners( let semantic_decl = semantic_model.find_decl(node.into(), SemanticDeclLevel::default()); match semantic_decl { Some(LuaSemanticDeclId::Member(member_id)) => { - find_member_origin_owners(semantic_model, member_id) + find_member_origin_owners(compilation, semantic_model, member_id, true) } Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) @@ -75,14 +77,16 @@ pub fn find_decl_origin_owners( } } -pub fn find_member_origin_owners( +pub fn find_member_origin_owners<'a>( + compilation: &'a LuaCompilation, semantic_model: &SemanticModel, member_id: LuaMemberId, + find_all: bool, ) -> DeclOriginResult { const MAX_ITERATIONS: usize = 50; let mut visited_members = HashSet::new(); - let mut current_owner = resolve_member_owner(semantic_model, &member_id); + let mut current_owner = resolve_member_owner(compilation, semantic_model, &member_id); let mut final_owner = current_owner.clone(); let mut iteration_count = 0; @@ -94,7 +98,7 @@ pub fn find_member_origin_owners( visited_members.insert(current_member_id.clone()); iteration_count += 1; - match resolve_member_owner(semantic_model, current_member_id) { + match resolve_member_owner(compilation, semantic_model, current_member_id) { Some(next_owner) => { final_owner = Some(next_owner.clone()); current_owner = Some(next_owner); @@ -107,6 +111,12 @@ pub fn find_member_origin_owners( final_owner = Some(LuaSemanticDeclId::Member(member_id)); } + if !find_all { + return DeclOriginResult::Single( + final_owner.unwrap_or_else(|| LuaSemanticDeclId::Member(member_id)), + ); + } + // 如果存在多个同名成员, 则返回多个成员 if let Some(same_named_members) = find_all_same_named_members(semantic_model, &final_owner) { if same_named_members.len() > 1 { @@ -118,10 +128,11 @@ pub fn find_member_origin_owners( } pub fn find_member_origin_owner( + compilation: &LuaCompilation, semantic_model: &SemanticModel, member_id: LuaMemberId, ) -> Option { - find_member_origin_owners(semantic_model, member_id).get_first() + find_member_origin_owners(compilation, semantic_model, member_id, false).get_first() } pub fn find_all_same_named_members( @@ -163,14 +174,18 @@ pub fn find_all_same_named_members( } fn resolve_member_owner( + compilation: &LuaCompilation, semantic_model: &SemanticModel, member_id: &LuaMemberId, ) -> Option { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); + // 通常来说, 即使需要跨文件也一般只会跨一个文件, 所有不需要缓存 + let semantic_model = if member_id.file_id == semantic_model.get_file_id() { + semantic_model + } else { + &compilation.get_semantic_model(member_id.file_id)? + }; + + let root = semantic_model.get_root().syntax(); let current_node = member_id.get_syntax_id().to_node_from_root(&root)?; match member_id.get_syntax_id().get_kind() { LuaSyntaxKind::TableFieldAssign => { @@ -178,7 +193,7 @@ fn resolve_member_owner( let table_field = LuaTableField::cast(current_node.clone())?; // 如果表是类, 那么通过类型推断获取 owner if let Some(owner_id) = - resolve_table_field_through_type_inference(semantic_model, &table_field) + resolve_table_field_through_type_inference(&semantic_model, &table_field) { return Some(owner_id); } @@ -262,11 +277,17 @@ pub fn replace_semantic_type( } // 判断是否存在泛型, 如果有任意类型不匹配我们就认为存在泛型 + let mut has_generic = false; + let type_set: HashSet<_> = type_vec.iter().collect(); for (_, typ) in semantic_decls.iter() { - if !type_vec.iter().any(|t| *t == typ) { + if !type_set.contains(&typ) { + has_generic = true; break; } } + if !has_generic { + return; + } // 替换`semantic_decls`中的类型 for (i, (_, typ)) in semantic_decls.iter_mut().enumerate() { diff --git a/crates/emmylua_ls/src/handlers/hover/function_humanize.rs b/crates/emmylua_ls/src/handlers/hover/function_humanize.rs index 9e7476705..4dc527c8c 100644 --- a/crates/emmylua_ls/src/handlers/hover/function_humanize.rs +++ b/crates/emmylua_ls/src/handlers/hover/function_humanize.rs @@ -1,6 +1,739 @@ -use emmylua_code_analysis::{LuaMember, LuaSignatureId, LuaType, SemanticModel}; +use std::collections::HashSet; + +use emmylua_code_analysis::{ + humanize_type, DbIndex, LuaDocReturnInfo, LuaFunctionType, LuaMember, LuaMemberKey, + LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaSignatureId, LuaType, RenderLevel, + SemanticModel, +}; use emmylua_parser::{LuaAstNode, LuaDocTagField, LuaDocType}; +use crate::handlers::hover::{ + hover_humanize::{ + extract_description_from_property_owner, extract_owner_name_from_element, + hover_humanize_type, DescriptionInfo, + }, + infer_prefix_global_name, HoverBuilder, +}; + +#[derive(Debug, Clone)] +struct HoverFunctionInfo { + type_description: String, + overloads: Option>, + description: Option, + is_call_function: bool, +} + +pub fn hover_function_type( + builder: &mut HoverBuilder, + db: &DbIndex, + semantic_decls: &[(LuaSemanticDeclId, LuaType)], +) -> Option<()> { + let (name, is_local) = { + let Some((semantic_decl, _)) = semantic_decls.first() else { + return None; + }; + match semantic_decl { + LuaSemanticDeclId::LuaDecl(id) => { + let decl = db.get_decl_index().get_decl(&id)?; + (decl.get_name().to_string(), decl.is_local()) + } + LuaSemanticDeclId::Member(id) => { + let member = db.get_member_index().get_member(&id)?; + (member.get_key().to_path(), false) + } + _ => { + return None; + } + } + }; + + let call_function = builder.get_call_function(); + // 已处理过的 semantic_decl_id, 用于解决`test_issue_499_3` + let mut handled_semantic_decl_ids = HashSet::new(); + let mut type_descs: Vec = Vec::with_capacity(semantic_decls.len()); + // 记录已处理过的类型,用于在 Union 中跳过重复类型. + // 这是为了解决最后一个类型可能是前面所有类型的联合类型的情况 + let mut processed_types = HashSet::new(); + + for (semantic_decl_id, typ) in semantic_decls { + let is_new = handled_semantic_decl_ids.insert(semantic_decl_id); + let mut function_info = HoverFunctionInfo { + type_description: String::new(), + overloads: None, + description: if is_new { + extract_description_from_property_owner(&builder.semantic_model, semantic_decl_id) + } else { + None + }, + is_call_function: false, + }; + + let function_member = match semantic_decl_id { + LuaSemanticDeclId::Member(id) => { + let member = db.get_member_index().get_member(&id)?; + // 以 @field 定义的 function 描述信息绑定的 id 并不是 member, 需要特殊处理 + if is_new && function_info.description.is_none() { + if let Some(signature_id) = + try_extract_signature_id_from_field(builder.semantic_model, &member) + { + function_info.description = extract_description_from_property_owner( + &builder.semantic_model, + &LuaSemanticDeclId::Signature(signature_id), + ); + } + } + Some(member) + } + _ => None, + }; + + // 如果当前类型是 Union,传入已处理的类型集合 + let result = match typ { + LuaType::Union(_) => process_single_function_type_with_exclusions( + builder, + db, + typ, + function_member, + &name, + is_local, + call_function.as_ref(), + &processed_types, + ), + _ => { + // 记录非 Union 类型 + processed_types.insert(typ.clone()); + process_single_function_type( + builder, + db, + typ, + function_member, + &name, + is_local, + call_function.as_ref(), + ) + } + }; + + match result { + ProcessFunctionTypeResult::Single(mut info) => { + // 合并描述信息 + if function_info.description.is_some() && info.description.is_none() { + info.description = function_info.description; + } + function_info = info; + } + ProcessFunctionTypeResult::Multiple(infos) => { + // 对于 Union 类型,将每个子类型的结果都添加到 type_descs 中 + let infos_len = infos.len(); + for (index, mut info) in infos.into_iter().enumerate() { + // 合并描述信息,只有最后一个才设置描述 + if function_info.description.is_some() + && info.description.is_none() + && index == infos_len - 1 + { + info.description = function_info.description.clone(); + } + if info.is_call_function { + type_descs.clear(); + type_descs.push(info); + break; + } else { + type_descs.push(info); + } + } + continue; + } + ProcessFunctionTypeResult::Skip => { + continue; + } + } + + if function_info.is_call_function { + type_descs.clear(); + type_descs.push(function_info); + break; + } else { + type_descs.push(function_info); + } + } + + // 此时是函数调用且具有完全匹配的签名, 那么只需要显示对应的签名, 不需要显示重载 + if let Some(info) = type_descs.first() { + if info.is_call_function { + builder.signature_overload = None; + builder.set_type_description(info.type_description.clone()); + + builder.add_description_from_info(info.description.clone()); + return Some(()); + } + } + + // 去重 + type_descs.dedup_by_key(|info| info.type_description.clone()); + + // 需要显示重载的情况 + match type_descs.len() { + 0 => { + return None; + } + 1 => { + builder.set_type_description(type_descs[0].type_description.clone()); + builder.add_description_from_info(type_descs[0].description.clone()); + } + _ => { + // 将最后一个作为 type_description + let main_type = type_descs.pop()?; + builder.set_type_description(main_type.type_description.clone()); + builder.add_description_from_info(main_type.description.clone()); + + for type_desc in type_descs { + builder.add_signature_overload(type_desc.type_description); + if let Some(overloads) = type_desc.overloads { + for overload in overloads { + builder.add_signature_overload(overload); + } + } + builder.add_description_from_info(type_desc.description); + } + } + } + + Some(()) +} + +fn hover_doc_function_type( + builder: &HoverBuilder, + db: &DbIndex, + lua_func: &LuaFunctionType, + owner_member: Option<&LuaMember>, + func_name: &str, +) -> String { + let async_label = if lua_func.is_async() { "async " } else { "" }; + let mut is_method = lua_func.is_colon_define(); + let mut type_label = "function "; + // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类方法 + let full_name = if let Some(owner_member) = owner_member { + let global_name = infer_prefix_global_name(builder.semantic_model, owner_member); + let mut name = String::new(); + let parent_owner = db + .get_member_index() + .get_current_owner(&owner_member.get_id()); + if let Some(parent_owner) = parent_owner { + match parent_owner { + LuaMemberOwner::Type(type_decl_id) => { + // 如果是全局定义, 则使用定义时的名称 + if let Some(global_name) = global_name { + name.push_str(global_name); + } else { + name.push_str(type_decl_id.get_simple_name()); + } + if owner_member.is_field() { + type_label = "(field) "; + } + is_method = lua_func.is_method( + builder.semantic_model, + Some(&LuaType::Ref(type_decl_id.clone())), + ); + } + LuaMemberOwner::Element(element_id) => { + if let Some(owner_name) = + extract_owner_name_from_element(builder.semantic_model, element_id) + { + name.push_str(&owner_name); + } + } + _ => {} + } + } + + if is_method { + type_label = "(method) "; + name.push_str(":"); + } else { + name.push_str("."); + } + if let LuaMemberKey::Name(n) = owner_member.get_key() { + name.push_str(n.as_str()); + } + name + } else { + func_name.to_string() + }; + + let params = lua_func + .get_params() + .iter() + .enumerate() + .map(|(index, param)| { + let name = param.0.clone(); + if index == 0 && is_method { + "".to_string() + } else if let Some(ty) = ¶m.1 { + format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Normal)) + } else { + name.to_string() + } + }) + .filter(|s| !s.is_empty()) + .collect::>() + .join(", "); + + let ret_detail = { + let ret_type = lua_func.get_ret(); + match ret_type { + LuaType::Nil => "".to_string(), + _ => { + format!(" -> {}", humanize_type(db, ret_type, RenderLevel::Simple)) + } + } + }; + format_function_type(type_label, async_label, full_name, params, ret_detail) +} + +struct HoverSignatureResult { + type_description: String, + overloads: Option>, + call_function: Option, +} + +fn hover_signature_type( + builder: &mut HoverBuilder, + db: &DbIndex, + signature_id: LuaSignatureId, + owner_member: Option<&LuaMember>, + func_name: &str, + is_local: bool, + call_function: Option<&LuaFunctionType>, +) -> Option { + let signature = db.get_signature_index().get(&signature_id)?; + + let mut is_method = signature.is_colon_define; + let mut self_real_type = LuaType::SelfInfer; + let mut type_label = "function "; + // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类定义的内容 + let full_name = if let Some(owner_member) = owner_member { + let global_name = infer_prefix_global_name(builder.semantic_model, owner_member); + let mut name = String::new(); + let parent_owner = db + .get_member_index() + .get_current_owner(&owner_member.get_id()); + match parent_owner { + Some(LuaMemberOwner::Type(type_decl_id)) => { + self_real_type = LuaType::Ref(type_decl_id.clone()); + // 如果是全局定义, 则使用定义时的名称 + if let Some(global_name) = global_name { + name.push_str(global_name); + } else { + name.push_str(type_decl_id.get_simple_name()); + } + if owner_member.is_field() { + type_label = "(field) "; + } + // `field`定义的function也被视为`signature`, 因此这里需要额外处理 + is_method = signature.is_method(builder.semantic_model, Some(&self_real_type)); + if is_method { + type_label = "(method) "; + name.push_str(":"); + } else { + name.push_str("."); + } + } + Some(LuaMemberOwner::Element(element_id)) => { + if let Some(owner_name) = + extract_owner_name_from_element(builder.semantic_model, element_id) + { + name.push_str(&owner_name); + name.push_str("."); + } + } + _ => {} + } + if let LuaMemberKey::Name(n) = owner_member.get_key() { + name.push_str(n.as_str()); + } + name + } else { + if is_local { + type_label = "local function "; + } + func_name.to_string() + }; + + // 构建 signature + let signature_info: String = { + let async_label = db + .get_signature_index() + .get(&signature_id) + .map(|signature| if signature.is_async { "async " } else { "" }) + .unwrap_or(""); + let params = signature + .get_type_params() + .iter() + .enumerate() + .map(|(index, param)| { + let name = param.0.clone(); + if index == 0 && !signature.is_colon_define && is_method { + "".to_string() + } else if let Some(ty) = ¶m.1 { + format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) + } else { + name + } + }) + .filter(|s| !s.is_empty()) + .collect::>() + .join(", "); + let rets = build_signature_rets(builder, signature, builder.is_completion, None); + let result = format_function_type(type_label, async_label, full_name.clone(), params, rets); + // 由于 @field 定义的`docfunction`会被视为`signature`, 因此这里额外处理 + if let Some(call_function) = call_function { + if call_function.get_params() == signature.get_type_params() { + // 如果具有完全匹配的签名, 那么将其设置为当前签名, 且不显示重载 + return Some(HoverSignatureResult { + type_description: result, + overloads: None, + call_function: Some(call_function.clone()), + }); + } + } + result + }; + // 构建所有重载 + let overloads: Vec = { + let mut overloads = Vec::new(); + for (_, overload) in signature.overloads.iter().enumerate() { + let async_label = if overload.is_async() { "async " } else { "" }; + let params = overload + .get_params() + .iter() + .enumerate() + .map(|(index, param)| { + let name = param.0.clone(); + if index == 0 + && param.1.is_some() + && overload.is_method(builder.semantic_model, Some(&self_real_type)) + { + "".to_string() + } else if let Some(ty) = ¶m.1 { + format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) + } else { + name + } + }) + .filter(|s| !s.is_empty()) + .collect::>() + .join(", "); + let rets = + build_signature_rets(builder, signature, builder.is_completion, Some(overload)); + let result = + format_function_type(type_label, async_label, full_name.clone(), params, rets); + + if let Some(call_function) = call_function { + if *call_function == **overload { + // 如果具有完全匹配的签名, 那么将其设置为当前签名, 且不显示重载 + return Some(HoverSignatureResult { + type_description: result, + overloads: None, + call_function: Some(call_function.clone()), + }); + } + }; + overloads.push(result); + } + overloads + }; + + Some(HoverSignatureResult { + type_description: signature_info, + overloads: Some(overloads), + call_function: None, + }) +} + +fn build_signature_rets( + builder: &mut HoverBuilder, + signature: &LuaSignature, + is_completion: bool, + overload: Option<&LuaFunctionType>, +) -> String { + let db = builder.semantic_model.get_db(); + let mut result = String::new(); + // overload 的返回值固定为单行 + let overload_rets_string = if let Some(overload) = overload { + let ret_type = overload.get_ret(); + match ret_type { + LuaType::Nil => "".to_string(), + _ => { + format!(" -> {}", humanize_type(db, ret_type, RenderLevel::Simple)) + } + } + } else { + "".to_string() + }; + + if is_completion { + let rets = if !overload_rets_string.is_empty() { + overload_rets_string + } else { + let rets = &signature.return_docs; + if rets.is_empty() || signature.get_return_type().is_nil() { + "".to_string() + } else { + format!( + " -> {}", + rets.iter() + .enumerate() + .map(|(i, ret)| build_signature_ret_type(builder, ret, i)) + .collect::>() + .join(", ") + ) + } + }; + result.push_str(rets.as_str()); + return result; + } + + let rets = if !overload_rets_string.is_empty() { + overload_rets_string + } else { + let rets = &signature.return_docs; + if rets.is_empty() || signature.get_return_type().is_nil() { + "".to_string() + } else { + let mut rets_string_multiline = String::new(); + rets_string_multiline.push_str("\n"); + + for (i, ret) in rets.iter().enumerate() { + let type_text = build_signature_ret_type(builder, ret, i); + let prefix = if i == 0 { + "-> ".to_string() + } else { + format!("{}. ", i + 1) + }; + let name = ret.name.clone().unwrap_or_default(); + + rets_string_multiline.push_str(&format!( + " {}{}{}\n", + prefix, + if !name.is_empty() { + format!("{}: ", name) + } else { + "".to_string() + }, + type_text, + )); + } + rets_string_multiline + } + }; + result.push_str(rets.as_str()); + result +} + +fn build_signature_ret_type( + builder: &mut HoverBuilder, + ret_info: &LuaDocReturnInfo, + i: usize, +) -> String { + let type_expansion_count = builder.get_type_expansion_count(); + let type_text = hover_humanize_type(builder, &ret_info.type_ref, Some(RenderLevel::Simple)); + if builder.get_type_expansion_count() > type_expansion_count { + // 重新设置`type_expansion` + if let Some(pop_type_expansion) = + builder.pop_type_expansion(type_expansion_count, builder.get_type_expansion_count()) + { + let mut new_type_expansion = format!("return #{}", i + 1); + let mut seen = HashSet::new(); + for type_expansion in pop_type_expansion { + for line in type_expansion.lines().skip(1) { + if seen.insert(line.to_string()) { + new_type_expansion.push('\n'); + new_type_expansion.push_str(line); + } + } + } + builder.add_type_expansion(new_type_expansion); + } + }; + type_text +} + +fn format_function_type( + type_label: &str, + async_label: &str, + full_name: String, + params: String, + rets: String, +) -> String { + let prefix = if type_label.starts_with("function") { + format!("{}{}", async_label, type_label) + } else { + format!("{}{}", type_label, async_label) + }; + format!("{}{}({}){}", prefix, full_name, params, rets) +} + +#[derive(Debug, Clone)] +enum ProcessFunctionTypeResult { + Single(HoverFunctionInfo), + Multiple(Vec), + Skip, +} + +fn process_single_function_type( + builder: &mut HoverBuilder, + db: &DbIndex, + typ: &LuaType, + function_member: Option<&LuaMember>, + name: &str, + is_local: bool, + call_function: Option<&LuaFunctionType>, +) -> ProcessFunctionTypeResult { + match typ { + LuaType::Function => ProcessFunctionTypeResult::Single(HoverFunctionInfo { + type_description: format!("function {}()", name), + overloads: None, + description: None, + is_call_function: false, + }), + LuaType::DocFunction(lua_func) => { + let type_description = + hover_doc_function_type(builder, db, &lua_func, function_member, &name); + let is_call_function = if let Some(call_function) = call_function { + call_function.get_params() == lua_func.get_params() + } else { + false + }; + + ProcessFunctionTypeResult::Single(HoverFunctionInfo { + type_description, + overloads: None, + description: None, + is_call_function, + }) + } + LuaType::Signature(signature_id) => { + let signature_result = hover_signature_type( + builder, + db, + signature_id.clone(), + function_member, + name, + is_local, + call_function, + ) + .unwrap_or_else(|| HoverSignatureResult { + type_description: format!("function {}", name), + overloads: None, + call_function: None, + }); + + let is_call_function = signature_result.call_function.is_some(); + + ProcessFunctionTypeResult::Single(HoverFunctionInfo { + type_description: signature_result.type_description, + overloads: signature_result.overloads, + description: None, + is_call_function, + }) + } + LuaType::Union(union) => { + let mut results = Vec::new(); + for union_type in union.get_types() { + match process_single_function_type( + builder, + db, + union_type, + function_member, + name, + is_local, + call_function, + ) { + ProcessFunctionTypeResult::Single(info) => { + results.push(info); + } + ProcessFunctionTypeResult::Multiple(infos) => { + results.extend(infos); + } + ProcessFunctionTypeResult::Skip => {} + } + } + + if results.is_empty() { + ProcessFunctionTypeResult::Skip + } else { + ProcessFunctionTypeResult::Multiple(results) + } + } + _ => ProcessFunctionTypeResult::Single(HoverFunctionInfo { + type_description: format!("function {}", name), + overloads: None, + description: None, + is_call_function: false, + }), + } +} + +fn process_single_function_type_with_exclusions( + builder: &mut HoverBuilder, + db: &DbIndex, + typ: &LuaType, + function_member: Option<&LuaMember>, + name: &str, + is_local: bool, + call_function: Option<&LuaFunctionType>, + processed_types: &HashSet, +) -> ProcessFunctionTypeResult { + match typ { + LuaType::Union(union) => { + let mut results = Vec::new(); + for union_type in union.get_types() { + // 跳过已经处理过的类型 + if processed_types.contains(union_type) { + continue; + } + + match process_single_function_type_with_exclusions( + builder, + db, + union_type, + function_member, + name, + is_local, + call_function, + processed_types, + ) { + ProcessFunctionTypeResult::Single(info) => { + results.push(info); + } + ProcessFunctionTypeResult::Multiple(infos) => { + results.extend(infos); + } + ProcessFunctionTypeResult::Skip => {} + } + } + + if results.is_empty() { + ProcessFunctionTypeResult::Skip + } else { + ProcessFunctionTypeResult::Multiple(results) + } + } + _ => { + // 对于非 Union 类型,直接调用原函数 + process_single_function_type( + builder, + db, + typ, + function_member, + name, + is_local, + call_function, + ) + } + } +} + pub fn is_function(typ: &LuaType) -> bool { typ.is_function() || match &typ { diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index 684cafc34..cddb8d528 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -1,5 +1,6 @@ use emmylua_code_analysis::{ - LuaFunctionType, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticModel, + LuaCompilation, LuaFunctionType, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, + SemanticModel, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken}; use lsp_types::{Hover, HoverContents, MarkedString, MarkupContent}; @@ -23,22 +24,25 @@ pub struct HoverBuilder<'a> { /// Type expansion, often used for alias types pub type_expansion: Option>, /// see - pub see_content: Option, + see_content: Option, /// other - pub other_content: Option, + other_content: Option, pub is_completion: bool, trigger_token: Option, pub semantic_model: &'a SemanticModel<'a>, + pub compilation: &'a LuaCompilation, } impl<'a> HoverBuilder<'a> { pub fn new( + compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, token: Option, is_completion: bool, ) -> Self { Self { + compilation, semantic_model, type_description: MarkedString::String("".to_string()), location_path: None, @@ -284,6 +288,9 @@ impl<'a> HoverBuilder<'a> { result.push_str(&description_content); result.push_str(&expansion); + // 清除空白字符 + result = result.trim().to_string(); + Some(Hover { contents: HoverContents::Markup(MarkupContent { kind: lsp_types::MarkupKind::Markdown, diff --git a/crates/emmylua_ls/src/handlers/hover/hover_humanize.rs b/crates/emmylua_ls/src/handlers/hover/hover_humanize.rs index ca5fc573b..7e3cd048a 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_humanize.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_humanize.rs @@ -1,16 +1,12 @@ -use std::collections::HashSet; - -use crate::handlers::hover::function_humanize::try_extract_signature_id_from_field; - use super::std_hover::{hover_std_description, is_std}; use emmylua_code_analysis::{ - format_union_type, DbIndex, LuaDocReturnInfo, LuaFunctionType, LuaMember, LuaMemberKey, - LuaMemberOwner, LuaMultiLineUnion, LuaSemanticDeclId, LuaSignature, LuaSignatureId, LuaType, - LuaUnionType, RenderLevel, SemanticDeclLevel, SemanticModel, + format_union_type, DbIndex, InFiled, LuaMember, LuaMemberOwner, LuaMultiLineUnion, + LuaSemanticDeclId, LuaType, LuaUnionType, RenderLevel, SemanticDeclLevel, SemanticModel, }; use emmylua_code_analysis::humanize_type; -use emmylua_parser::{LuaAstNode, LuaIndexExpr, LuaSyntaxKind}; +use emmylua_parser::{LuaAstNode, LuaExpr, LuaIndexExpr, LuaStat, LuaSyntaxId, LuaSyntaxKind}; +use rowan::TextRange; use super::hover_builder::HoverBuilder; @@ -27,581 +23,6 @@ pub fn hover_const_type(db: &DbIndex, typ: &LuaType) -> String { } } -#[derive(Debug, Clone)] -struct HoverFunctionInfo { - type_description: String, - overloads: Option>, - description: Option, - is_call_function: bool, -} - -pub fn hover_function_type( - builder: &mut HoverBuilder, - db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], -) -> Option<()> { - let (name, is_local) = { - let Some((semantic_decl, _)) = semantic_decls.first() else { - return None; - }; - match semantic_decl { - LuaSemanticDeclId::LuaDecl(id) => { - let decl = db.get_decl_index().get_decl(&id)?; - (decl.get_name().to_string(), decl.is_local()) - } - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(&id)?; - (member.get_key().to_path(), false) - } - _ => { - return None; - } - } - }; - - let call_function = builder.get_call_function(); - // 已处理过的 semantic_decl_id, 用于解决`test_issue_499_3` - let mut handled_semantic_decl_ids = HashSet::new(); - let mut type_descs: Vec = Vec::with_capacity(semantic_decls.len()); - - for (semantic_decl_id, typ) in semantic_decls { - let is_new = handled_semantic_decl_ids.insert(semantic_decl_id); - let mut function_info = HoverFunctionInfo { - type_description: String::new(), - overloads: None, - description: if is_new { - extract_description_from_property_owner(&builder.semantic_model, semantic_decl_id) - } else { - None - }, - is_call_function: false, - }; - - let function_member = match semantic_decl_id { - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(&id)?; - // 以 @field 定义的 function 描述信息绑定的 id 并不是 member, 需要特殊处理 - if is_new && function_info.description.is_none() { - if let Some(signature_id) = - try_extract_signature_id_from_field(builder.semantic_model, &member) - { - function_info.description = extract_description_from_property_owner( - &builder.semantic_model, - &LuaSemanticDeclId::Signature(signature_id), - ); - } - } - Some(member) - } - _ => None, - }; - - match typ { - LuaType::Function => { - function_info.type_description = format!("function {}()", name); - } - LuaType::DocFunction(lua_func) => { - function_info.type_description = - hover_doc_function_type(builder, db, &lua_func, function_member, &name); - if let Some(call_function) = &call_function { - if call_function.get_params() == lua_func.get_params() { - function_info.is_call_function = true; - } - } - } - LuaType::Signature(signature_id) => { - let signature_result = hover_signature_type( - builder, - db, - signature_id.clone(), - function_member, - &name, - is_local, - ) - .unwrap_or_else(|| HoverSignatureResult { - type_description: format!("function {}", name), - overloads: None, - call_function: None, - }); - function_info.type_description = signature_result.type_description; - function_info.overloads = signature_result.overloads; - - if let Some(_) = signature_result.call_function { - function_info.is_call_function = true; - } - } - LuaType::Union(_) => { - continue; - } - _ => { - function_info.type_description = format!("function {}", name); - } - }; - if function_info.is_call_function { - type_descs.clear(); - type_descs.push(function_info); - break; - } else { - type_descs.push(function_info); - } - } - - // 此时是函数调用且具有完全匹配的签名, 那么只需要显示对应的签名, 不需要显示重载 - if let Some(info) = type_descs.first() { - if info.is_call_function { - builder.signature_overload = None; - builder.set_type_description(info.type_description.clone()); - - builder.add_description_from_info(info.description.clone()); - return Some(()); - } - } - - // 去重 - type_descs.dedup_by_key(|info| info.type_description.clone()); - - // 需要显示重载的情况 - match type_descs.len() { - 0 => { - return None; - } - 1 => { - builder.set_type_description(type_descs[0].type_description.clone()); - builder.add_description_from_info(type_descs[0].description.clone()); - } - _ => { - // 将最后一个作为 type_description - let main_type = type_descs.pop()?; - builder.set_type_description(main_type.type_description.clone()); - builder.add_description_from_info(main_type.description.clone()); - - for type_desc in type_descs { - builder.add_signature_overload(type_desc.type_description); - if let Some(overloads) = type_desc.overloads { - for overload in overloads { - builder.add_signature_overload(overload); - } - } - builder.add_description_from_info(type_desc.description); - } - } - } - - Some(()) -} - -// fn hover_union_function_type( -// builder: &mut HoverBuilder, -// db: &DbIndex, -// union: &LuaUnionType, -// function_member: Option<&LuaMember>, -// func_name: &str, -// ) { -// // 泛型处理 -// if let Some(call) = builder.get_call_function() { -// builder.set_type_description(hover_doc_function_type( -// builder, -// db, -// &call, -// function_member, -// func_name, -// )); -// return; -// } -// let mut overloads = Vec::new(); - -// let types = union.get_types(); -// for typ in types { -// match typ { -// LuaType::DocFunction(lua_func) => { -// overloads.push(hover_doc_function_type( -// builder, -// db, -// &lua_func, -// function_member, -// func_name, -// )); -// } -// LuaType::Signature(signature_id) => { -// if let Some((type_description, signature_overloads)) = hover_signature_type( -// builder, -// db, -// signature_id.clone(), -// function_member, -// func_name, -// false, -// true, -// ) { -// if let Some(signature_overloads) = signature_overloads { -// for overload in signature_overloads { -// overloads.push(overload); -// } -// } -// overloads.push(type_description); -// } -// } -// _ => {} -// } -// } -// // 将最后一个作为 type_description -// if let Some(type_description) = overloads.pop() { -// builder.set_type_description(type_description); -// for overload in overloads { -// builder.add_signature_overload(overload); -// } -// } -// } - -fn hover_doc_function_type( - builder: &HoverBuilder, - db: &DbIndex, - lua_func: &LuaFunctionType, - owner_member: Option<&LuaMember>, - func_name: &str, -) -> String { - let async_label = if lua_func.is_async() { "async " } else { "" }; - let mut is_method = lua_func.is_colon_define(); - let mut type_label = "function "; - // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类方法 - let full_name = if let Some(owner_member) = owner_member { - let global_name = infer_prefix_global_name(builder.semantic_model, owner_member); - let mut name = String::new(); - let parent_owner = db - .get_member_index() - .get_current_owner(&owner_member.get_id()); - if let Some(LuaMemberOwner::Type(type_decl_id)) = parent_owner { - // 如果是全局定义, 则使用定义时的名称 - if let Some(global_name) = global_name { - name.push_str(global_name); - } else { - name.push_str(type_decl_id.get_simple_name()); - } - if owner_member.is_field() { - type_label = "(field) "; - } - is_method = lua_func.is_method( - builder.semantic_model, - Some(&LuaType::Ref(type_decl_id.clone())), - ); - } - - if is_method { - type_label = "(method) "; - name.push_str(":"); - } else { - name.push_str("."); - } - if let LuaMemberKey::Name(n) = owner_member.get_key() { - name.push_str(n.as_str()); - } - name - } else { - func_name.to_string() - }; - - let params = lua_func - .get_params() - .iter() - .enumerate() - .map(|(index, param)| { - let name = param.0.clone(); - if index == 0 && is_method { - "".to_string() - } else if let Some(ty) = ¶m.1 { - format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Normal)) - } else { - name.to_string() - } - }) - .filter(|s| !s.is_empty()) - .collect::>() - .join(", "); - - let ret_detail = { - let ret_type = lua_func.get_ret(); - match ret_type { - LuaType::Nil => "".to_string(), - _ => { - format!(" -> {}", humanize_type(db, ret_type, RenderLevel::Simple)) - } - } - }; - format_function_type(type_label, async_label, full_name, params, ret_detail) -} - -struct HoverSignatureResult { - type_description: String, - overloads: Option>, - call_function: Option, -} - -fn hover_signature_type( - builder: &mut HoverBuilder, - db: &DbIndex, - signature_id: LuaSignatureId, - owner_member: Option<&LuaMember>, - func_name: &str, - is_local: bool, -) -> Option { - let signature = db.get_signature_index().get(&signature_id)?; - - let call_function = builder.get_call_function(); - let mut is_method = signature.is_colon_define; - let mut self_real_type = LuaType::SelfInfer; - - let mut type_label = "function "; - // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类定义的内容 - let full_name = if let Some(owner_member) = owner_member { - let global_name = infer_prefix_global_name(builder.semantic_model, owner_member); - let mut name = String::new(); - let parent_owner = db - .get_member_index() - .get_current_owner(&owner_member.get_id()); - if let Some(LuaMemberOwner::Type(type_decl_id)) = parent_owner { - self_real_type = LuaType::Ref(type_decl_id.clone()); - // 如果是全局定义, 则使用定义时的名称 - if let Some(global_name) = global_name { - name.push_str(global_name); - } else { - name.push_str(type_decl_id.get_simple_name()); - } - if owner_member.is_field() { - type_label = "(field) "; - } - // `field`定义的function也被视为`signature`, 因此这里需要额外处理 - is_method = signature.is_method(builder.semantic_model, Some(&self_real_type)); - if is_method { - type_label = "(method) "; - name.push_str(":"); - } else { - name.push_str("."); - } - } - if let LuaMemberKey::Name(n) = owner_member.get_key() { - name.push_str(n.as_str()); - } - name - } else { - if is_local { - type_label = "local function "; - } - func_name.to_string() - }; - - // 构建 signature - let signature_info: String = { - let async_label = db - .get_signature_index() - .get(&signature_id) - .map(|signature| if signature.is_async { "async " } else { "" }) - .unwrap_or(""); - let params = signature - .get_type_params() - .iter() - .enumerate() - .map(|(index, param)| { - let name = param.0.clone(); - if index == 0 && !signature.is_colon_define && is_method { - "".to_string() - } else if let Some(ty) = ¶m.1 { - format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) - } else { - name - } - }) - .filter(|s| !s.is_empty()) - .collect::>() - .join(", "); - let rets = build_signature_rets(builder, signature, builder.is_completion, None); - let result = format_function_type(type_label, async_label, full_name.clone(), params, rets); - // 由于 @field 定义的`docfunction`会被视为`signature`, 因此这里额外处理 - if let Some(call_function) = &call_function { - if call_function.get_params() == signature.get_type_params() { - // 如果具有完全匹配的签名, 那么将其设置为当前签名, 且不显示重载 - return Some(HoverSignatureResult { - type_description: result, - overloads: None, - call_function: Some(call_function.clone()), - }); - } - } - result - }; - // 构建所有重载 - let overloads: Vec = { - let mut overloads = Vec::new(); - for (_, overload) in signature.overloads.iter().enumerate() { - let async_label = if overload.is_async() { "async " } else { "" }; - let params = overload - .get_params() - .iter() - .enumerate() - .map(|(index, param)| { - let name = param.0.clone(); - if index == 0 - && param.1.is_some() - && overload.is_method(builder.semantic_model, Some(&self_real_type)) - { - "".to_string() - } else if let Some(ty) = ¶m.1 { - format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) - } else { - name - } - }) - .filter(|s| !s.is_empty()) - .collect::>() - .join(", "); - let rets = - build_signature_rets(builder, signature, builder.is_completion, Some(overload)); - let result = - format_function_type(type_label, async_label, full_name.clone(), params, rets); - - if let Some(call_function) = &call_function { - if *call_function == **overload { - // 如果具有完全匹配的签名, 那么将其设置为当前签名, 且不显示重载 - return Some(HoverSignatureResult { - type_description: result, - overloads: None, - call_function: Some(call_function.clone()), - }); - } - }; - overloads.push(result); - } - overloads - }; - - Some(HoverSignatureResult { - type_description: signature_info, - overloads: Some(overloads), - call_function: None, - }) -} - -fn build_signature_rets( - builder: &mut HoverBuilder, - signature: &LuaSignature, - is_completion: bool, - overload: Option<&LuaFunctionType>, -) -> String { - let db = builder.semantic_model.get_db(); - let mut result = String::new(); - // overload 的返回值固定为单行 - let overload_rets_string = if let Some(overload) = overload { - let ret_type = overload.get_ret(); - match ret_type { - LuaType::Nil => "".to_string(), - _ => { - format!(" -> {}", humanize_type(db, ret_type, RenderLevel::Simple)) - } - } - } else { - "".to_string() - }; - - if is_completion { - let rets = if !overload_rets_string.is_empty() { - overload_rets_string - } else { - let rets = &signature.return_docs; - if rets.is_empty() || signature.get_return_type().is_nil() { - "".to_string() - } else { - format!( - " -> {}", - rets.iter() - .enumerate() - .map(|(i, ret)| build_signature_ret_type(builder, ret, i)) - .collect::>() - .join(", ") - ) - } - }; - result.push_str(rets.as_str()); - return result; - } - - let rets = if !overload_rets_string.is_empty() { - overload_rets_string - } else { - let rets = &signature.return_docs; - if rets.is_empty() || signature.get_return_type().is_nil() { - "".to_string() - } else { - let mut rets_string_multiline = String::new(); - rets_string_multiline.push_str("\n"); - - for (i, ret) in rets.iter().enumerate() { - let type_text = build_signature_ret_type(builder, ret, i); - let prefix = if i == 0 { - "-> ".to_string() - } else { - format!("{}. ", i + 1) - }; - let name = ret.name.clone().unwrap_or_default(); - - rets_string_multiline.push_str(&format!( - " {}{}{}\n", - prefix, - if !name.is_empty() { - format!("{}: ", name) - } else { - "".to_string() - }, - type_text, - )); - } - rets_string_multiline - } - }; - result.push_str(rets.as_str()); - result -} - -fn build_signature_ret_type( - builder: &mut HoverBuilder, - ret_info: &LuaDocReturnInfo, - i: usize, -) -> String { - let type_expansion_count = builder.get_type_expansion_count(); - let type_text = hover_humanize_type(builder, &ret_info.type_ref, Some(RenderLevel::Simple)); - if builder.get_type_expansion_count() > type_expansion_count { - // 重新设置`type_expansion` - if let Some(pop_type_expansion) = - builder.pop_type_expansion(type_expansion_count, builder.get_type_expansion_count()) - { - let mut new_type_expansion = format!("return #{}", i + 1); - let mut seen = HashSet::new(); - for type_expansion in pop_type_expansion { - for line in type_expansion.lines().skip(1) { - if seen.insert(line.to_string()) { - new_type_expansion.push('\n'); - new_type_expansion.push_str(line); - } - } - } - builder.add_type_expansion(new_type_expansion); - } - }; - type_text -} - -fn format_function_type( - type_label: &str, - async_label: &str, - full_name: String, - params: String, - rets: String, -) -> String { - let prefix = if type_label.starts_with("function") { - format!("{}{}", async_label, type_label) - } else { - format!("{}{}", type_label, async_label) - }; - format!("{}{}({}){}", prefix, full_name, params, rets) -} - pub fn hover_humanize_type( builder: &mut HoverBuilder, ty: &LuaType, @@ -801,3 +222,31 @@ pub fn extract_description_from_property_owner( Some(result) } } + +/// 从 element_id 中提取所有者名称 +pub fn extract_owner_name_from_element( + semantic_model: &SemanticModel, + element_id: &InFiled, +) -> Option { + let root = semantic_model + .get_db() + .get_vfs() + .get_syntax_tree(&element_id.file_id)? + .get_red_root(); + + // 通过 TextRange 找到对应的 AST 节点 + let node = LuaSyntaxId::to_node_at_range(&root, element_id.value)?; + let stat = LuaStat::cast(node.clone().parent()?)?; + match stat { + LuaStat::LocalStat(local_stat) => { + let value = LuaExpr::cast(node)?; + let local_name = local_stat.get_local_name_by_value(value); + if let Some(local_name) = local_name { + return Some(local_name.get_name_token()?.get_name_text().to_string()); + } + } + _ => {} + } + + None +} diff --git a/crates/emmylua_ls/src/handlers/hover/mod.rs b/crates/emmylua_ls/src/handlers/hover/mod.rs index ed081d72b..b82a8b62a 100644 --- a/crates/emmylua_ls/src/handlers/hover/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/mod.rs @@ -76,7 +76,14 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let semantic_info = semantic_model.get_semantic_info(token.clone().into())?; let db = semantic_model.get_db(); let document = semantic_model.get_document(); - build_semantic_info_hover(&semantic_model, db, &document, token, semantic_info) + build_semantic_info_hover( + &analysis.compilation, + &semantic_model, + db, + &document, + token, + semantic_info, + ) } } } diff --git a/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs b/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs index 2c60e3f8c..9b5b20a50 100644 --- a/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs +++ b/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs @@ -5,7 +5,8 @@ use emmylua_code_analysis::{ SemanticDeclLevel, SemanticModel, }; use emmylua_parser::{ - LuaAstNode, LuaDocTagField, LuaIndexExpr, LuaStat, LuaSyntaxNode, LuaSyntaxToken, LuaTableField, + LuaAstNode, LuaDocTagField, LuaExpr, LuaIndexExpr, LuaStat, LuaSyntaxNode, LuaSyntaxToken, + LuaTableField, }; use lsp_types::Location; @@ -56,7 +57,7 @@ pub fn search_member_implementations( let mut semantic_cache = HashMap::new(); - let property_owner = find_member_origin_owner(semantic_model, member_id) + let property_owner = find_member_origin_owner(compilation, semantic_model, member_id) .unwrap_or(LuaSemanticDeclId::Member(member_id)); for in_filed_syntax_id in index_references { let semantic_model = @@ -69,35 +70,46 @@ pub fn search_member_implementations( }; let root = semantic_model.get_root(); let node = in_filed_syntax_id.value.to_node_from_root(root.syntax())?; + if let Some(is_signature) = check_member_reference(&semantic_model, node.clone()) { + if !semantic_model.is_reference_to( + node, + property_owner.clone(), + SemanticDeclLevel::default(), + ) { + continue; + } - if check_member_reference(&semantic_model, node.clone()).is_none() { - continue; - } - - if !semantic_model.is_reference_to( - node, - property_owner.clone(), - SemanticDeclLevel::default(), - ) { - continue; + let document = semantic_model.get_document(); + let range = in_filed_syntax_id.value.get_range(); + let location = document.to_lsp_location(range)?; + // 由于允许函数声明重载, 所以需要将签名放在前面 + if is_signature { + result.insert(0, location); + } else { + result.push(location); + } } - - let document = semantic_model.get_document(); - let range = in_filed_syntax_id.value.get_range(); - let location = document.to_lsp_location(range)?; - result.push(location); } Some(()) } /// 检查成员引用是否符合实现 -fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) -> Option<()> { +fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) -> Option { match &node { expr_node if LuaIndexExpr::can_cast(expr_node.kind().into()) => { let expr = LuaIndexExpr::cast(expr_node.clone())?; let prefix_type = semantic_model .infer_expr(expr.get_prefix_expr()?.into()) .ok()?; + let mut is_signature = false; + if let Some(current_type) = semantic_model + .infer_expr(LuaExpr::IndexExpr(expr.clone())) + .ok() + { + if current_type.is_signature() { + is_signature = true; + } + } // TODO: 需要实现更复杂的逻辑, 即当为`Ref`时, 针对指定的实例定义到其实现 /* ---@class A @@ -123,7 +135,7 @@ fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) - let stat = expr.ancestors::().next()?; match stat { LuaStat::FuncStat(_) => { - return Some(()); + return Some(is_signature); } LuaStat::AssignStat(assign_stat) => { // 判断是否在左侧 @@ -134,7 +146,7 @@ fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) - .text_range() .contains(node.text_range().start()) { - return Some(()); + return Some(is_signature); } } return None; @@ -145,12 +157,12 @@ fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) - } } tag_field_node if LuaDocTagField::can_cast(tag_field_node.kind().into()) => { - return Some(()); + return Some(false); } table_field_node if LuaTableField::can_cast(table_field_node.kind().into()) => { let table_field = LuaTableField::cast(table_field_node.clone())?; if table_field.is_assign_field() { - return Some(()); + return Some(false); } else { return None; } @@ -158,8 +170,9 @@ fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) - _ => {} } - Some(()) + Some(false) } + pub fn search_type_implementations( semantic_model: &SemanticModel, compilation: &LuaCompilation, @@ -202,9 +215,23 @@ pub fn search_decl_implementations( if decl.is_local() { let document = semantic_model.get_document(); + let decl_refs = semantic_model + .get_db() + .get_reference_index() + .get_decl_references(&decl_id.file_id, &decl_id)?; + let range = decl.get_range(); let location = document.to_lsp_location(range)?; result.push(location); + + for decl_ref in decl_refs { + if decl_ref.is_write { + if let Some(location) = document.to_lsp_location(decl_ref.range) { + result.push(location); + } + } + } + return Some(()); } else { let name = decl.get_name(); diff --git a/crates/emmylua_ls/src/handlers/inlay_hint/build_function_hint.rs b/crates/emmylua_ls/src/handlers/inlay_hint/build_function_hint.rs index 69e1d7c3f..610c61679 100644 --- a/crates/emmylua_ls/src/handlers/inlay_hint/build_function_hint.rs +++ b/crates/emmylua_ls/src/handlers/inlay_hint/build_function_hint.rs @@ -1,6 +1,9 @@ use std::collections::HashMap; -use emmylua_code_analysis::{humanize_type, LuaSignatureId, LuaType, RenderLevel, SemanticModel}; +use emmylua_code_analysis::{ + format_union_type, humanize_type, LuaSignatureId, LuaType, LuaUnionType, RenderLevel, + SemanticModel, +}; use emmylua_parser::{LuaAstNode, LuaClosureExpr}; use itertools::Itertools; use lsp_types::{InlayHint, InlayHintKind, InlayHintLabel, InlayHintLabelPart, Location}; @@ -45,10 +48,7 @@ pub fn build_closure_hint( let mut label_parts = build_label_parts(semantic_model, &typ); // 为空时添加默认值 if label_parts.is_empty() { - let typ_desc = format!( - ": {}", - humanize_type(semantic_model.get_db(), &typ, RenderLevel::Simple) - ); + let typ_desc = format!(": {}", hint_humanize_type(semantic_model, &typ)); label_parts.push(InlayHintLabelPart { value: typ_desc, location: Some( @@ -134,7 +134,7 @@ fn get_part(semantic_model: &SemanticModel, typ: &LuaType) -> Option { - let value = humanize_type(semantic_model.get_db(), typ, RenderLevel::Simple); + let value = hint_humanize_type(semantic_model, typ); let location = get_type_location(semantic_model, typ); return Some(InlayHintLabelPart { value, @@ -147,7 +147,7 @@ fn get_part(semantic_model: &SemanticModel, typ: &LuaType) -> Option Option { match typ { - LuaType::Ref(id) => { + LuaType::Ref(id) | LuaType::Def(id) => { let type_decl = semantic_model .get_db() .get_type_index() @@ -183,3 +183,34 @@ fn get_base_type_location(semantic_model: &SemanticModel, name: &str) -> Option< let lsp_range = document.to_lsp_range(location.range)?; Some(Location::new(document.get_uri(), lsp_range)) } + +fn hint_humanize_type(semantic_model: &SemanticModel, typ: &LuaType) -> String { + match typ { + LuaType::Ref(id) | LuaType::Def(id) => { + let namespace = semantic_model + .get_db() + .get_type_index() + .get_file_namespace(&semantic_model.get_file_id()); + if let Some(namespace) = namespace { + // 如果 id 最前面是 namespace, 那么移除 + let id_name = id.get_name(); + let namespace_prefix = format!("{}.", namespace); + if id_name.starts_with(&namespace_prefix) { + id_name[namespace_prefix.len()..].to_string() + } else { + id_name.to_string() + } + } else { + id.get_name().to_string() + } + } + LuaType::Union(union) => hint_humanize_union_type(semantic_model, union), + _ => humanize_type(semantic_model.get_db(), typ, RenderLevel::Simple), + } +} + +fn hint_humanize_union_type(semantic_model: &SemanticModel, union: &LuaUnionType) -> String { + format_union_type(union, RenderLevel::Simple, |ty, _| { + hint_humanize_type(semantic_model, ty) + }) +} diff --git a/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs b/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs index ca3879faa..e6577462a 100644 --- a/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs +++ b/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs @@ -1,8 +1,9 @@ use std::collections::HashMap; +use std::sync::Arc; use emmylua_code_analysis::{ - FileId, InferGuard, LuaFunctionType, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaType, - SemanticModel, + FileId, InferGuard, LuaFunctionType, LuaMemberId, LuaMemberKey, LuaOperatorId, + LuaOperatorMetaMethod, LuaSemanticDeclId, LuaType, SemanticModel, }; use emmylua_parser::{ LuaAst, LuaAstNode, LuaCallExpr, LuaExpr, LuaFuncStat, LuaIndexExpr, LuaLocalFuncStat, @@ -14,6 +15,7 @@ use rowan::NodeOrToken; use rowan::TokenAtOffset; +use crate::handlers::definition::compare_function_types; use crate::handlers::inlay_hint::build_function_hint::{build_closure_hint, build_label_parts}; pub fn build_inlay_hints(semantic_model: &SemanticModel) -> Option> { @@ -26,7 +28,8 @@ pub fn build_inlay_hints(semantic_model: &SemanticModel) -> Option { build_call_expr_param_hint(semantic_model, &mut result, call_expr.clone()); - build_call_expr_await_hint(semantic_model, &mut result, call_expr); + build_call_expr_await_hint(semantic_model, &mut result, call_expr.clone()); + build_call_expr_meta_call_hint(semantic_model, &mut result, call_expr); } LuaAst::LuaLocalName(local_name) => { build_local_name_hint(semantic_model, &mut result, local_name); @@ -344,7 +347,7 @@ fn build_local_name_hint( position: lsp_range.end, text_edits: None, tooltip: None, - padding_left: Some(true), + padding_left: None, padding_right: None, data: None, }; @@ -458,3 +461,154 @@ fn get_override_lsp_location( let lsp_range = document.to_lsp_location(range)?; Some(lsp_range) } + +fn build_call_expr_meta_call_hint( + semantic_model: &SemanticModel, + result: &mut Vec, + call_expr: LuaCallExpr, +) -> Option<()> { + if !semantic_model.get_emmyrc().hint.meta_call_hint { + return Some(()); + } + + let prefix_expr = call_expr.get_prefix_expr()?; + let semantic_info = + semantic_model.get_semantic_info(NodeOrToken::Node(prefix_expr.syntax().clone()))?; + + match &semantic_info.typ { + LuaType::Ref(id) | LuaType::Def(id) => { + let decl = semantic_model.get_db().get_type_index().get_type_decl(id)?; + if !decl.is_class() { + return Some(()); + } + + let call_operator_ids = semantic_model + .get_db() + .get_operator_index() + .get_operators(&id.clone().into(), LuaOperatorMetaMethod::Call)?; + + set_meta_call_part( + semantic_model, + result, + call_operator_ids, + call_expr, + semantic_info.typ, + )?; + } + _ => {} + } + Some(()) +} + +fn set_meta_call_part( + semantic_model: &SemanticModel, + result: &mut Vec, + operator_ids: &Vec, + call_expr: LuaCallExpr, + target_type: LuaType, +) -> Option<()> { + let (operator_id, call_func) = + find_match_meta_call_operator_id(semantic_model, operator_ids, call_expr.clone())?; + + let operator = semantic_model + .get_db() + .get_operator_index() + .get_operator(&operator_id)?; + + let location = { + let range = operator.get_range(); + let document = semantic_model.get_document_by_file_id(operator.get_file_id())?; + let lsp_range = document.to_lsp_range(range)?; + Location::new(document.get_uri(), lsp_range) + }; + + let document = semantic_model.get_document(); + let parent = call_expr.syntax().parent()?; + + // 如果是 `Class(...)` 且调用返回值是 Class 类型, 则显示 `new` 提示 + let hint_new = { + LuaStat::can_cast(parent.kind().into()) + && !matches!(call_expr.get_prefix_expr()?, LuaExpr::CallExpr(_)) + && semantic_model + .type_check(call_func.get_ret(), &target_type) + .is_ok() + }; + + let (value, hint_range, padding_right) = if hint_new { + ("new".to_string(), call_expr.get_range(), Some(true)) + } else { + ( + "⚡".to_string(), + call_expr.get_prefix_expr()?.get_range(), + None, + ) + }; + + let hint_position = { + let lsp_range = document.to_lsp_range(hint_range)?; + if hint_new { + lsp_range.start + } else { + lsp_range.end + } + }; + + let part = InlayHintLabelPart { + value, + location: Some(location), + ..Default::default() + }; + + let hint = InlayHint { + kind: Some(InlayHintKind::TYPE), + label: InlayHintLabel::LabelParts(vec![part]), + position: hint_position, + text_edits: None, + tooltip: None, + padding_left: None, + padding_right, + data: None, + }; + + result.push(hint); + Some(()) +} + +fn find_match_meta_call_operator_id( + semantic_model: &SemanticModel, + operator_ids: &Vec, + call_expr: LuaCallExpr, +) -> Option<(LuaOperatorId, Arc)> { + let call_func = semantic_model.infer_call_expr_func(call_expr.clone(), None)?; + if operator_ids.len() == 1 { + return Some((operator_ids.first().cloned()?, call_func)); + } + for operator_id in operator_ids { + let operator = semantic_model + .get_db() + .get_operator_index() + .get_operator(operator_id)?; + let operator_func = { + let operator_type = operator.get_operator_func(semantic_model.get_db()); + match operator_type { + LuaType::DocFunction(func) => func, + LuaType::Signature(signature_id) => { + let signature = semantic_model + .get_db() + .get_signature_index() + .get(&signature_id)?; + signature.to_doc_func_type() + } + _ => return None, + } + }; + let is_match = + compare_function_types(semantic_model, &call_func, &operator_func, &call_expr) + .unwrap_or(false); + + if is_match { + return Some((operator_id.clone(), operator_func)); + } + } + operator_ids.first().cloned().map(|id| (id, call_func)) +} diff --git a/crates/emmylua_ls/src/handlers/inline_values/mod.rs b/crates/emmylua_ls/src/handlers/inline_values/mod.rs index e6df03179..acd0d1e9e 100644 --- a/crates/emmylua_ls/src/handlers/inline_values/mod.rs +++ b/crates/emmylua_ls/src/handlers/inline_values/mod.rs @@ -19,6 +19,9 @@ pub async fn on_inline_values_handler( let analysis = context.analysis.read().await; let file_id = analysis.get_file_id(&uri)?; let mut semantic_model = analysis.compilation.get_semantic_model(file_id)?; + if !semantic_model.get_emmyrc().inline_values.enable { + return None; + } build_inline_values(&mut semantic_model, stop_position) } diff --git a/crates/emmylua_ls/src/handlers/mod.rs b/crates/emmylua_ls/src/handlers/mod.rs index 7064d23ba..1af8bc0e3 100644 --- a/crates/emmylua_ls/src/handlers/mod.rs +++ b/crates/emmylua_ls/src/handlers/mod.rs @@ -29,6 +29,7 @@ mod signature_helper; mod test; mod test_lib; mod text_document; +mod workspace; mod workspace_symbol; pub use initialized::{init_analysis, initialized_handler, ClientConfig}; @@ -127,6 +128,7 @@ pub fn server_capabilities(client_capabilities: &ClientCapabilities) -> ServerCa &mut server_capabilities, client_capabilities, ); + register::(&mut server_capabilities, client_capabilities); // register::( // &mut server_capabilities, // client_capabilities, diff --git a/crates/emmylua_ls/src/handlers/notification_handler.rs b/crates/emmylua_ls/src/handlers/notification_handler.rs index 9b79a4670..1e8412210 100644 --- a/crates/emmylua_ls/src/handlers/notification_handler.rs +++ b/crates/emmylua_ls/src/handlers/notification_handler.rs @@ -5,14 +5,17 @@ use lsp_server::Notification; use lsp_types::{ notification::{ Cancel, DidChangeConfiguration, DidChangeTextDocument, DidChangeWatchedFiles, - DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument, + DidCloseTextDocument, DidOpenTextDocument, DidRenameFiles, DidSaveTextDocument, Notification as lsp_notification, SetTrace, }, CancelParams, NumberOrString, }; use serde::de::DeserializeOwned; -use crate::context::{ServerContext, ServerContextSnapshot}; +use crate::{ + context::{ServerContext, ServerContextSnapshot}, + handlers::workspace::on_did_rename_files_handler, +}; use super::{ configuration::on_did_change_configuration, @@ -37,6 +40,7 @@ pub async fn on_notification_handler( .on_parallel::(on_did_change_watched_files) .on_parallel::(on_set_trace) .on_parallel::(on_did_change_configuration) + .on_parallel::(on_did_rename_files_handler) .finish(); Ok(()) diff --git a/crates/emmylua_ls/src/handlers/references/reference_seacher.rs b/crates/emmylua_ls/src/handlers/references/reference_seacher.rs index dd16f04e7..8627a48bc 100644 --- a/crates/emmylua_ls/src/handlers/references/reference_seacher.rs +++ b/crates/emmylua_ls/src/handlers/references/reference_seacher.rs @@ -54,6 +54,10 @@ pub fn search_decl_references( .get_reference_index() .get_decl_references(&decl_id.file_id, &decl_id)?; let document = semantic_model.get_document(); + // 加入自己 + if let Some(location) = document.to_lsp_location(decl.get_range()) { + result.push(location); + } for decl_ref in decl_refs { let location = document.to_lsp_location(decl_ref.range.clone())?; result.push(location); diff --git a/crates/emmylua_ls/src/handlers/rename/rename_member.rs b/crates/emmylua_ls/src/handlers/rename/rename_member.rs index 81a9aa66f..2dbfe194c 100644 --- a/crates/emmylua_ls/src/handlers/rename/rename_member.rs +++ b/crates/emmylua_ls/src/handlers/rename/rename_member.rs @@ -25,7 +25,7 @@ pub fn rename_member_references( .get_reference_index() .get_index_references(&key)?; - let property_owner = find_member_origin_owner(semantic_model, member_id) + let property_owner = find_member_origin_owner(compilation, semantic_model, member_id) .unwrap_or(LuaSemanticDeclId::Member(member_id)); let mut semantic_cache = HashMap::new(); for in_filed_syntax_id in index_references { diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 612fc1a3b..bce810206 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -1,10 +1,11 @@ use emmylua_code_analysis::{ - LuaMemberId, LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, + LuaDecl, LuaDeclExtra, LuaMemberId, LuaMemberOwner, LuaSemanticDeclId, LuaType, LuaTypeDeclId, + SemanticDeclLevel, SemanticModel, }; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaAstToken, LuaDocFieldKey, LuaDocObjectFieldKey, LuaExpr, - LuaGeneralToken, LuaLiteralToken, LuaNameToken, LuaSyntaxNode, LuaSyntaxToken, LuaTokenKind, - LuaVarExpr, + LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaDocFieldKey, LuaDocObjectFieldKey, LuaExpr, + LuaGeneralToken, LuaKind, LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, + LuaSyntaxToken, LuaTokenKind, LuaVarExpr, }; use lsp_types::{SemanticToken, SemanticTokenModifier, SemanticTokenType}; use rowan::NodeOrToken; @@ -179,23 +180,7 @@ fn build_tokens_semantic_token( builder.push(token, SemanticTokenType::KEYWORD); } LuaTokenKind::TkDocStart => { - let range = token.text_range(); - // find '@' - let text = token.text(); - let mut start = 0; - for (i, c) in text.char_indices() { - if c == '@' { - start = i; - break; - } - } - let position = u32::from(range.start()) + start as u32; - builder.push_at_position( - position.into(), - 1, - SemanticTokenType::KEYWORD, - SemanticTokenModifier::DOCUMENTATION, - ); + render_doc_at(builder, &token); } _ => {} } @@ -273,8 +258,39 @@ fn build_node_semantic_token( } } LuaAst::LuaDocTagCast(doc_cast) => { - let name = doc_cast.get_name_token()?; - builder.push(name.syntax(), SemanticTokenType::VARIABLE); + if let Some(target_expr) = doc_cast.get_key_expr() { + match target_expr { + LuaExpr::NameExpr(name_expr) => { + builder.push( + name_expr.get_name_token()?.syntax(), + SemanticTokenType::VARIABLE, + ); + } + LuaExpr::IndexExpr(index_expr) => { + let position = index_expr.syntax().text_range().start(); + let len = index_expr.syntax().text_range().len(); + builder.push_at_position( + position.into(), + len.into(), + SemanticTokenType::VARIABLE, + None, + ); + } + _ => {} + } + } + if let Some(NodeOrToken::Token(token)) = doc_cast.syntax().prev_sibling_or_token() { + if token.kind() == LuaKind::Token(LuaTokenKind::TkDocLongStart) { + render_doc_at(builder, &token); + } + } + } + LuaAst::LuaDocTagAs(doc_as) => { + if let Some(NodeOrToken::Token(token)) = doc_as.syntax().prev_sibling_or_token() { + if token.kind() == LuaKind::Token(LuaTokenKind::TkDocLongStart) { + render_doc_at(builder, &token); + } + } } LuaAst::LuaDocTagGeneric(doc_generic) => { let type_parameter_list = doc_generic.get_generic_decl_list()?; @@ -447,6 +463,10 @@ fn build_node_semantic_token( builder.push(name.syntax(), SemanticTokenType::FUNCTION); return Some(()); } + if decl_type.is_def() { + builder.push(name.syntax(), SemanticTokenType::CLASS); + return Some(()); + } let owner_id = semantic_model .get_db() @@ -526,19 +546,6 @@ fn build_node_semantic_token( Some(()) } -fn is_class_def(semantic_model: &SemanticModel, node: LuaSyntaxNode) -> Option<()> { - let semantic_decl = semantic_model.find_decl(node.into(), SemanticDeclLevel::default())?; - if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl { - let decl_type = semantic_model.get_type(decl_id.into()); - match decl_type { - LuaType::Def(_) => Some(()), - _ => None, - } - } else { - None - } -} - // 处理`local a = class``local a = class.method/field` fn handle_name_node( semantic_model: &SemanticModel, @@ -554,14 +561,9 @@ fn handle_name_node( ); return Some(()); } - if is_class_def(semantic_model, node.clone()).is_some() { - builder.push(name_token.syntax(), SemanticTokenType::CLASS); - return Some(()); - } let semantic_decl = - semantic_model.find_decl(node.clone().into(), SemanticDeclLevel::Trace(50))?; - + semantic_model.find_decl(node.clone().into(), SemanticDeclLevel::default())?; match semantic_decl { LuaSemanticDeclId::Member(member_id) => { let decl_type = semantic_model.get_type(member_id.into()); @@ -572,8 +574,30 @@ fn handle_name_node( } LuaSemanticDeclId::LuaDecl(decl_id) => { + let decl = semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id)?; let decl_type = semantic_model.get_type(decl_id.into()); + let (token_type, modifier) = match decl_type { + LuaType::Def(_) => (SemanticTokenType::CLASS, None), + LuaType::Ref(ref_id) => { + if let Some(is_require) = + check_ref_is_require_def(semantic_model, &decl, &ref_id) + { + if is_require { + ( + SemanticTokenType::CLASS, + Some(SemanticTokenModifier::READONLY), + ) + } else { + (SemanticTokenType::VARIABLE, None) + } + } else { + (SemanticTokenType::VARIABLE, None) + } + } LuaType::Signature(signature) => { let is_meta = semantic_model .get_db() @@ -584,17 +608,38 @@ fn handle_name_node( is_meta.then_some(SemanticTokenModifier::DEFAULT_LIBRARY), ) } - _ => { - let decl = semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; - if decl.is_param() { - (SemanticTokenType::PARAMETER, None) + LuaType::DocFunction(_) => (SemanticTokenType::FUNCTION, None), + LuaType::Union(union) => { + if union.get_types().iter().any(|typ| typ.is_function()) { + (SemanticTokenType::FUNCTION, None) } else { - (SemanticTokenType::VARIABLE, None) + if decl.is_param() { + (SemanticTokenType::PARAMETER, None) + } else { + (SemanticTokenType::VARIABLE, None) + } } } + _ => match &decl.extra { + LuaDeclExtra::Param { + idx, signature_id, .. + } => { + let signature = semantic_model + .get_db() + .get_signature_index() + .get(&signature_id)?; + if let Some(param_info) = signature.get_param_info_by_id(*idx) { + if param_info.type_ref.is_function() { + (SemanticTokenType::FUNCTION, None) + } else { + (SemanticTokenType::PARAMETER, None) + } + } else { + (SemanticTokenType::VARIABLE, None) + } + } + _ => (SemanticTokenType::VARIABLE, None), + }, }; if let Some(modifier) = modifier { @@ -611,3 +656,69 @@ fn handle_name_node( builder.push(name_token.syntax(), SemanticTokenType::VARIABLE); Some(()) } + +fn render_doc_at(builder: &mut SemanticBuilder, token: &LuaSyntaxToken) { + let range = token.text_range(); + // find '@' + let text = token.text(); + let mut start = 0; + for (i, c) in text.char_indices() { + if c == '@' { + start = i; + break; + } + } + let position = u32::from(range.start()) + start as u32; + builder.push_at_position( + position.into(), + 1, + SemanticTokenType::KEYWORD, + Some(SemanticTokenModifier::DOCUMENTATION), + ); +} + +// 检查导入语句是否是类定义 +fn check_ref_is_require_def( + semantic_model: &SemanticModel, + decl: &LuaDecl, + ref_id: &LuaTypeDeclId, +) -> Option { + let value_syntax_id = decl.get_value_syntax_id()?; + if value_syntax_id.get_kind() != LuaSyntaxKind::RequireCallExpr { + return None; + } + let node = semantic_model + .get_db() + .get_vfs() + .get_syntax_tree(&decl.get_file_id()) + .and_then(|tree| { + let root = tree.get_red_root(); + semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl.get_id()) + .and_then(|decl| decl.get_value_syntax_id()) + .and_then(|syntax_id| syntax_id.to_node_from_root(&root)) + })?; + let call_expr = LuaCallExpr::cast(node)?; + let arg_list = call_expr.get_args_list()?; + let first_arg = arg_list.get_args().next()?; + let require_path_type = semantic_model.infer_expr(first_arg.clone()).ok()?; + let module_path: String = match &require_path_type { + LuaType::StringConst(module_path) => module_path.as_ref().to_string(), + _ => { + return None; + } + }; + let module_info = semantic_model + .get_db() + .get_module_index() + .find_module(&module_path)?; + match &module_info.export_type { + Some(ty) => match ty { + LuaType::Def(id) => Some(id == ref_id), + _ => Some(false), + }, + None => None, + } +} diff --git a/crates/emmylua_ls/src/handlers/semantic_token/mod.rs b/crates/emmylua_ls/src/handlers/semantic_token/mod.rs index 9d2f2d97a..843072954 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/mod.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/mod.rs @@ -1,8 +1,9 @@ mod build_semantic_tokens; mod semantic_token_builder; -use crate::context::ServerContextSnapshot; +use crate::context::{ClientId, ServerContextSnapshot}; use build_semantic_tokens::build_semantic_tokens; +use emmylua_code_analysis::{EmmyLuaAnalysis, FileId}; use lsp_types::{ ClientCapabilities, SemanticTokens, SemanticTokensFullOptions, SemanticTokensLegend, SemanticTokensOptions, SemanticTokensParams, SemanticTokensResult, @@ -26,8 +27,15 @@ pub async fn on_semantic_token_handler( let client_id = config_manager.client_config.client_id; let _ = config_manager; let file_id = analysis.get_file_id(&uri)?; - let mut semantic_model = analysis.compilation.get_semantic_model(file_id)?; + semantic_token(&analysis, file_id, client_id) +} +pub fn semantic_token( + analysis: &EmmyLuaAnalysis, + file_id: FileId, + client_id: ClientId, +) -> Option { + let mut semantic_model = analysis.compilation.get_semantic_model(file_id)?; if !semantic_model.get_emmyrc().semantic_tokens.enable { return None; } diff --git a/crates/emmylua_ls/src/handlers/semantic_token/semantic_token_builder.rs b/crates/emmylua_ls/src/handlers/semantic_token/semantic_token_builder.rs index 582e294e7..af6d87680 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/semantic_token_builder.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/semantic_token_builder.rs @@ -173,7 +173,7 @@ impl<'a> SemanticBuilder<'a> { position: TextSize, length: u32, ty: SemanticTokenType, - modifiers: SemanticTokenModifier, + modifiers: Option, ) -> Option<()> { let lsp_position = self.document.to_lsp_position(position)?; let start_line = lsp_position.line; @@ -186,7 +186,7 @@ impl<'a> SemanticBuilder<'a> { col: start_col as u32, length, typ: *self.type_to_id.get(&ty)?, - modifiers: 1 << *self.modifier_to_id.get(&modifiers)?, + modifiers: modifiers.map_or(0, |m| 1 << *self.modifier_to_id.get(&m).unwrap_or(&0)), }), ); Some(()) diff --git a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs index cfa943bb7..3cec89a67 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ - DbIndex, InFiled, LuaFunctionType, LuaInstanceType, LuaOperatorMetaMethod, LuaOperatorOwner, - LuaSignatureId, LuaType, LuaTypeDeclId, RenderLevel, SemanticModel, + DbIndex, InFiled, LuaCompilation, LuaFunctionType, LuaInstanceType, LuaOperatorMetaMethod, + LuaOperatorOwner, LuaSignatureId, LuaType, LuaTypeDeclId, RenderLevel, SemanticModel, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind}; use lsp_types::{ @@ -15,12 +15,13 @@ use super::signature_helper_builder::SignatureHelperBuilder; pub fn build_signature_helper( semantic_model: &SemanticModel, + compilation: &LuaCompilation, call_expr: LuaCallExpr, token: LuaSyntaxToken, ) -> Option { let prefix_expr = call_expr.get_prefix_expr()?; let prefix_expr_type = semantic_model.infer_expr(prefix_expr.clone()).ok()?; - let builder = SignatureHelperBuilder::new(semantic_model, call_expr.clone()); + let builder = SignatureHelperBuilder::new(compilation, semantic_model, call_expr.clone()); let colon_call = call_expr.is_colon_call(); let current_idx = get_current_param_index(&call_expr, &token)?; let help = match prefix_expr_type { diff --git a/crates/emmylua_ls/src/handlers/signature_helper/mod.rs b/crates/emmylua_ls/src/handlers/signature_helper/mod.rs index 641e44a9f..d4ca5213d 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/mod.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/mod.rs @@ -63,7 +63,7 @@ pub fn signature_help( match node.kind().into() { LuaSyntaxKind::CallArgList => { let call_expr = LuaCallExpr::cast(node.parent()?)?; - build_signature_helper(&mut semantic_model, call_expr, token) + build_signature_helper(&mut semantic_model, &analysis.compilation, call_expr, token) } // todo LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::DocTypeList => None, @@ -90,7 +90,7 @@ pub fn signature_help( match node.kind().into() { LuaSyntaxKind::CallArgList => { let call_expr = LuaCallExpr::cast(node.parent()?)?; - build_signature_helper(&mut semantic_model, call_expr, token) + build_signature_helper(&mut semantic_model, &analysis.compilation, call_expr, token) } // todo LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::DocTypeList => None, diff --git a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs index a631643df..0125faf3a 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs @@ -1,5 +1,6 @@ use emmylua_code_analysis::{ - FileId, LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, + FileId, LuaCompilation, LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, + SemanticModel, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr}; use lsp_types::{Documentation, MarkupContent, MarkupKind, ParameterInformation, ParameterLabel}; @@ -14,6 +15,8 @@ use super::build_signature_helper::{build_function_label, generate_param_label}; #[derive(Debug)] pub struct SignatureHelperBuilder<'a> { pub semantic_model: &'a SemanticModel<'a>, + pub compilation: &'a LuaCompilation, + pub call_expr: LuaCallExpr, pub prefix_name: Option, pub function_name: String, @@ -24,8 +27,13 @@ pub struct SignatureHelperBuilder<'a> { } impl<'a> SignatureHelperBuilder<'a> { - pub fn new(semantic_model: &'a SemanticModel<'a>, call_expr: LuaCallExpr) -> Self { + pub fn new( + compilation: &'a LuaCompilation, + semantic_model: &'a SemanticModel<'a>, + call_expr: LuaCallExpr, + ) -> Self { let mut builder = Self { + compilation, semantic_model, call_expr, prefix_name: None, @@ -69,7 +77,8 @@ impl<'a> SignatureHelperBuilder<'a> { // 推断为来源 semantic_decl = match semantic_decl { Some(LuaSemanticDeclId::Member(member_id)) => { - find_member_origin_owner(semantic_model, member_id).or(semantic_decl) + find_member_origin_owner(self.compilation, semantic_model, member_id) + .or(semantic_decl) } Some(LuaSemanticDeclId::LuaDecl(_)) => semantic_decl, _ => None, diff --git a/crates/emmylua_ls/src/handlers/test/code_actions_test.rs b/crates/emmylua_ls/src/handlers/test/code_actions_test.rs new file mode 100644 index 000000000..a81480a87 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/test/code_actions_test.rs @@ -0,0 +1,29 @@ +#[cfg(test)] +mod tests { + + use crate::handlers::test_lib::ProviderVirtualWorkspace; + + #[test] + fn test_1() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class Cast1 + ---@field get fun(self: self, a: number): Cast1? + "#, + ); + + let actions = ws + .check_code_action( + r#" + ---@type Cast1 + local A + + local _a = A:get(1):get(2):get(3) + "#, + ) + .unwrap(); + // 6 个禁用 + 2 个修复 + assert_eq!(actions.len(), 8); + } +} diff --git a/crates/emmylua_ls/src/handlers/test/completion_test.rs b/crates/emmylua_ls/src/handlers/test/completion_test.rs index c3204776f..7a4d38660 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_test.rs @@ -1,6 +1,9 @@ #[cfg(test)] mod tests { + use std::{ops::Deref, sync::Arc}; + + use emmylua_code_analysis::EmmyrcFilenameConvention; use lsp_types::{CompletionItemKind, CompletionTriggerKind}; use crate::handlers::test_lib::{ProviderVirtualWorkspace, VirtualCompletionItem}; @@ -752,4 +755,284 @@ mod tests { CompletionTriggerKind::TRIGGER_CHARACTER, )); } + + #[test] + fn test_issue_502() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@param a { foo: { bar: number } } + function buz(a) end + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + buz({ + foo = { + b + } + }) + "#, + vec![VirtualCompletionItem { + label: "bar = ".to_string(), + kind: CompletionItemKind::PROPERTY, + ..Default::default() + },], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + } + + #[test] + fn test_class_function_1() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class C1 + ---@field on_add fun(a: string, b: string) + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + ---@type C1 + local c1 + + c1.on_add = + "#, + vec![VirtualCompletionItem { + label: "function(a, b) end".to_string(), + kind: CompletionItemKind::FUNCTION, + ..Default::default() + },], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + } + + #[test] + fn test_class_function_2() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class C1 + ---@field on_add fun(self: C1, a: string, b: string) + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + ---@type C1 + local c1 + + function c1:() + + end + "#, + vec![VirtualCompletionItem { + label: "on_add".to_string(), + kind: CompletionItemKind::FUNCTION, + label_detail: Some("(a, b)".to_string()), + },], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + } + + #[test] + fn test_class_function_3() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class (partial) SkillMutator + ---@field on_add? fun(self: self, owner: string) + + ---@class (partial) SkillMutator.A + ---@field on_add? fun(self: self, owner: string) + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + ---@class (partial) SkillMutator.A + local a + a.on_add = + "#, + vec![VirtualCompletionItem { + label: "function(self, owner) end".to_string(), + kind: CompletionItemKind::FUNCTION, + ..Default::default() + },], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + } + + #[test] + fn test_class_function_4() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class (partial) SkillMutator + ---@field on_add? fun(self: self, owner: string) + + ---@class (partial) SkillMutator.A + ---@field on_add? fun(self: self, owner: string) + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + ---@class (partial) SkillMutator.A + local a + function a:() + + end + + "#, + vec![VirtualCompletionItem { + label: "on_add".to_string(), + kind: CompletionItemKind::FUNCTION, + label_detail: Some("(owner)".to_string()), + },], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + } + + #[test] + fn test_auto_require() { + let mut ws = ProviderVirtualWorkspace::new(); + let mut emmyrc = ws.analysis.emmyrc.deref().clone(); + emmyrc.completion.auto_require_naming_convention = EmmyrcFilenameConvention::KeepClass; + ws.analysis.update_config(Arc::new(emmyrc)); + ws.def_file( + "map.lua", + r#" + ---@class Map + local Map = {} + + return Map + "#, + ); + assert!(ws.check_completion( + r#" + ma + "#, + vec![VirtualCompletionItem { + label: "Map".to_string(), + kind: CompletionItemKind::MODULE, + label_detail: Some(" (in map)".to_string()), + },], + )); + } + + #[test] + fn test_auto_require_table_field() { + let mut ws = ProviderVirtualWorkspace::new(); + let mut emmyrc = ws.analysis.emmyrc.deref().clone(); + emmyrc.completion.auto_require_naming_convention = EmmyrcFilenameConvention::KeepClass; + ws.analysis.update_config(Arc::new(emmyrc)); + ws.def_file( + "aaaa.lua", + r#" + local export = {} + + ---@enum MapName + export.MapName = { + A = 1, + B = 2, + } + + return export + "#, + ); + assert!(ws.check_completion( + r#" + mapn + "#, + vec![VirtualCompletionItem { + label: "MapName".to_string(), + kind: CompletionItemKind::MODULE, + label_detail: Some(" (in aaaa)".to_string()), + },], + )); + } + + #[test] + fn test_field_is_alias_function() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@alias ProxyHandler.Setter fun(raw: any) + + ---@class ProxyHandler + ---@field set? ProxyHandler.Setter + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + ---@class MHandler: ProxyHandler + local MHandler + + MHandler.set = + + "#, + vec![VirtualCompletionItem { + label: "function(raw) end".to_string(), + kind: CompletionItemKind::FUNCTION, + ..Default::default() + },], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + } + + #[test] + fn test_namespace_base() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@namespace Reactive + "#, + ); + ws.def( + r#" + ---@namespace AlienSignals + "#, + ); + assert!(ws.check_completion_with_kind( + r#" + ---@namespace + + "#, + vec![ + VirtualCompletionItem { + label: "AlienSignals".to_string(), + kind: CompletionItemKind::MODULE, + ..Default::default() + }, + VirtualCompletionItem { + label: "Reactive".to_string(), + kind: CompletionItemKind::MODULE, + ..Default::default() + }, + ], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + + assert!(ws.check_completion_with_kind( + r#" + ---@namespace Reactive + ---@namespace + + "#, + vec![], + CompletionTriggerKind::TRIGGER_CHARACTER, + )); + + assert!(ws.check_completion_with_kind( + r#" + ---@namespace Reactive + ---@using + + "#, + vec![VirtualCompletionItem { + label: "using AlienSignals".to_string(), + kind: CompletionItemKind::MODULE, + ..Default::default() + },], + CompletionTriggerKind::INVOKED, + )); + } } diff --git a/crates/emmylua_ls/src/handlers/test/definition_test.rs b/crates/emmylua_ls/src/handlers/test/definition_test.rs index 38d03e733..49872ac3a 100644 --- a/crates/emmylua_ls/src/handlers/test/definition_test.rs +++ b/crates/emmylua_ls/src/handlers/test/definition_test.rs @@ -1,5 +1,7 @@ #[cfg(test)] mod tests { + use lsp_types::GotoDefinitionResponse; + use crate::handlers::test_lib::ProviderVirtualWorkspace; #[test] @@ -72,4 +74,103 @@ mod tests { "#, ); } + + #[test] + fn test_goto_overload() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class Goto1 + ---@class Goto2 + ---@class Goto3 + + ---@class T + ---@field func fun(a:Goto1) # 1 + ---@field func fun(a:Goto2) # 2 + ---@field func fun(a:Goto3) # 3 + local T = {} + + function T:func(a) + end + "#, + ); + + { + let result = ws + .check_definition( + r#" + ---@type Goto2 + local Goto2 + + ---@type T + local t + t.func(Goto2) + "#, + ) + .unwrap(); + match result { + GotoDefinitionResponse::Array(array) => { + assert_eq!(array.len(), 2); + } + _ => { + panic!("expect array"); + } + } + } + + { + let result = ws + .check_definition( + r#" + ---@type T + local t + t.func() + "#, + ) + .unwrap(); + match result { + GotoDefinitionResponse::Array(array) => { + assert_eq!(array.len(), 4); + } + _ => { + panic!("expect array"); + } + } + } + } + + #[test] + fn test_goto_return_field() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def_file( + "test.lua", + r#" + local function test() + + end + + return { + test = test, + } + "#, + ); + let result = ws + .check_definition( + r#" + local t = require("test") + local test = t.test + test() + "#, + ) + .unwrap(); + match result { + GotoDefinitionResponse::Array(locations) => { + assert_eq!(locations.len(), 1); + assert_eq!(locations[0].range.start.line, 1); + } + _ => { + panic!("expect scalar"); + } + } + } } diff --git a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs index d55184ea0..d8ef171aa 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs @@ -18,7 +18,7 @@ mod tests { local delete4 = delete3 "#, VirtualHoverResult { - value: "\n```lua\nlocal function delete(a: number)\n -> a: number\n\n```\n\n---\n\n@*param* `a` — 参数a\n\n\n\n@*return* `a` — 返回值a\n\n\n".to_string(), + value: "```lua\nlocal function delete(a: number)\n -> a: number\n\n```\n\n---\n\n@*param* `a` — 参数a\n\n\n\n@*return* `a` — 返回值a".to_string(), }, )); @@ -38,7 +38,7 @@ mod tests { } "#, VirtualHoverResult { - value: "\n```lua\nlocal function delete(a: number)\n -> a: number\n\n```\n\n---\n\n删除\n\n@*param* `a` — 参数a\n\n\n\n@*return* `a` — 返回值a\n\n\n".to_string(), + value: "```lua\nlocal function delete(a: number)\n -> a: number\n\n```\n\n---\n\n删除\n\n@*param* `a` — 参数a\n\n\n\n@*return* `a` — 返回值a".to_string(), }, )); @@ -61,7 +61,7 @@ mod tests { } "#, VirtualHoverResult { - value: "\n```lua\nlocal function delete(a: number)\n -> a: number\n\n```\n\n---\n\n@*param* `a` — 参数a\n\n\n\n@*return* `a` — 返回值a\n\n\n".to_string(), + value: "```lua\nlocal function delete(a: number)\n -> a: number\n\n```\n\n---\n\n@*param* `a` — 参数a\n\n\n\n@*return* `a` — 返回值a".to_string(), }, )); } @@ -97,7 +97,7 @@ mod tests { local local_b = local_a "#, VirtualHoverResult { - value: "\n```lua\n(method) Game:add(key: string, value: string)\n -> ret: number\n\n```\n\n---\n\n说明\n\n@*param* `key` — 参数key\n\n@*param* `value` — 参数value\n\n\n\n@*return* `ret` — 返回值\n\n\n" .to_string(), + value: "```lua\n(method) Game:add(key: string, value: string)\n -> ret: number\n\n```\n\n---\n\n说明\n\n@*param* `key` — 参数key\n\n@*param* `value` — 参数value\n\n\n\n@*return* `ret` — 返回值".to_string(), }, )); } @@ -122,7 +122,7 @@ mod tests { local event = test3.event "#, VirtualHoverResult { - value: "\n```lua\n(method) Test3:event(event: \"B\", key: string)\n```\n\n  in class `Hover.Test3`\n\n---\n\n---\n\n```lua\n(method) Test3:event(event: \"A\", key: string)\n```\n".to_string(), + value: "```lua\n(method) Test3:event(event: \"B\", key: string)\n```\n\n  in class `Hover.Test3`\n\n---\n\n---\n\n```lua\n(method) Test3:event(event: \"A\", key: string)\n```".to_string(), }, )); } @@ -160,7 +160,7 @@ mod tests { ---@field event fun(self: self, event: "游戏-http返回"): Trigger "#, VirtualHoverResult { - value: "\n```lua\n(method) GameA:event(event_type: EventTypeA, ...: any)\n -> Trigger\n\n```\n\n---\n\n注册引擎事件\n\n---\n\n```lua\n(method) GameA:event(event: \"游戏-初始化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-追帧完成\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-逻辑不同步\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-地形预设加载完成\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-结束\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-暂停\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-恢复\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-昼夜变化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"区域-进入\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"区域-离开\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-http返回\") -> Trigger\n```\n".to_string(), + value: "```lua\n(method) GameA:event(event_type: EventTypeA, ...: any)\n -> Trigger\n\n```\n\n---\n\n注册引擎事件\n\n---\n\n```lua\n(method) GameA:event(event: \"游戏-初始化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-追帧完成\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-逻辑不同步\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-地形预设加载完成\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-结束\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-暂停\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-恢复\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-昼夜变化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"区域-进入\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"区域-离开\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-http返回\") -> Trigger\n```".to_string(), }, )); } @@ -179,7 +179,7 @@ mod tests { end "#, VirtualHoverResult { - value: "\n```lua\nfunction ClosureTest.e(a: string, b: number)\n```\n\n---\n\n---\n\n```lua\n(field) ClosureTest.e(a: string, b: number)\n```\n".to_string(), + value: "```lua\nfunction ClosureTest.e(a: string, b: number)\n```\n\n---\n\n---\n\n```lua\n(field) ClosureTest.e(a: string, b: number)\n```".to_string(), }, )); } @@ -200,7 +200,7 @@ mod tests { } "#, VirtualHoverResult { - value: "\n```lua\n(method) T:func()\n```\n\n---\n\n注释注释\n".to_string(), + value: "```lua\n(method) T:func()\n```\n\n---\n\n注释注释".to_string(), }, )); } @@ -219,7 +219,7 @@ mod tests { } "#, VirtualHoverResult { - value: "\n```lua\n(field) a: string = \"a\"\n```\n\n---\n\n注释注释a\n".to_string(), + value: "```lua\n(field) a: string = \"a\"\n```\n\n---\n\n注释注释a".to_string(), }, )); } @@ -239,8 +239,7 @@ mod tests { } "#, VirtualHoverResult { - value: "\n```lua\n(field) T.func(self: string)\n```\n\n---\n\n注释注释\n" - .to_string(), + value: "```lua\n(field) T.func(self: string)\n```\n\n---\n\n注释注释".to_string(), }, )); } @@ -261,7 +260,7 @@ mod tests { } "#, VirtualHoverResult { - value: "\n```lua\n(field) T.func(a: (string|number))\n```\n\n---\n\n注释1\n\n注释2\n\n---\n\n```lua\n(field) T.func(a: string)\n```\n\n```lua\n(field) T.func(a: number)\n```\n" + value: "```lua\n(field) T.func(a: (string|number))\n```\n\n---\n\n注释1\n\n注释2\n\n---\n\n```lua\n(field) T.func(a: string)\n```\n\n```lua\n(field) T.func(a: number)\n```" .to_string(), }, )); @@ -285,7 +284,7 @@ mod tests { t.func(1) "#, VirtualHoverResult { - value: "\n```lua\n(field) T.func(a: number)\n```\n\n---\n\n注释2\n".to_string(), + value: "```lua\n(field) T.func(a: number)\n```\n\n---\n\n注释2".to_string(), }, )); } @@ -307,7 +306,91 @@ mod tests { local abc = t.func "#, VirtualHoverResult { - value: "\n```lua\n(field) T.func(a: number)\n```\n\n---\n\n注释2\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: string)\n```\n".to_string(), + value: "```lua\n(field) T.func(a: number)\n```\n\n---\n\n注释2\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: string)\n```".to_string(), + }, + )); + } + + #[test] + fn test_first_generic() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_hover( + r#" + ---@class Reactive + local M + + ---@generic T: table + ---@param target T + ---@return T + function M.reactive(target) + end + + "#, + VirtualHoverResult { + value: "```lua\nfunction Reactive.reactive(target: T)\n -> T\n\n```".to_string(), + }, + )); + } + + #[test] + fn test_table_field_function() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_hover( + r#" + local export = {} + ---@type fun() + export.NOOP = function() end + + "#, + VirtualHoverResult { + value: "```lua\nfunction export.NOOP()\n```".to_string(), + }, + )); + } + + #[test] + fn test_return_union_function() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_hover( + r#" + ---@generic T + ---@param initialValue? T + ---@return (fun(): T) | (fun(value: T)) + local function signal(initialValue) + end + + ---测试 + local count = signal(1) + "#, + VirtualHoverResult { + value: "```lua\nfunction count(value: 1)\n```\n\n---\n\n测试\n\n---\n\n```lua\nfunction count() -> 1\n```".to_string(), + }, + )); + } + + #[test] + fn test_require_function() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def_file( + "test.lua", + r#" + + ---测试 + local function signal() + end + + return { + signal = signal + } + "#, + ); + assert!(ws.check_hover( + r#" + local test = require("test") + local signal = test.signal + "#, + VirtualHoverResult { + value: "```lua\nlocal function signal()\n```\n\n---\n\n测试".to_string(), }, )); } diff --git a/crates/emmylua_ls/src/handlers/test/hover_test.rs b/crates/emmylua_ls/src/handlers/test/hover_test.rs index cc8465fc3..7cd08662c 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_test.rs @@ -13,7 +13,9 @@ mod tests { ---@field c boolean "#, VirtualHoverResult { - value: "\n```lua\n(class) A {\n a: number,\n b: string,\n c: boolean,\n}\n```\n".to_string(), + value: + "```lua\n(class) A {\n a: number,\n b: string,\n c: boolean,\n}\n```" + .to_string(), }, )); } @@ -35,7 +37,7 @@ mod tests { m1.x = {} "#, VirtualHoverResult { - value: "\n```lua\n(field) x: integer = 1\n```\n".to_string(), + value: "```lua\n(field) x: integer = 1\n```".to_string(), }, )); @@ -58,7 +60,7 @@ mod tests { end "#, VirtualHoverResult { - value: "\n```lua\n(field) right: Node\n```\n".to_string(), + value: "```lua\n(field) right: Node\n```".to_string(), }, )); @@ -80,7 +82,7 @@ mod tests { end "#, VirtualHoverResult { - value: "\n```lua\nlocal node: Node1 {\n x: number,\n}\n```\n".to_string(), + value: "```lua\nlocal node: Node1 {\n x: number,\n}\n```".to_string(), }, )); } @@ -99,7 +101,7 @@ mod tests { local d = a.a "#, VirtualHoverResult { - value: "\n```lua\n(field) a: number?\n```\n".to_string(), + value: "```lua\n(field) a: number?\n```".to_string(), }, )); } @@ -114,7 +116,7 @@ mod tests { end "#, VirtualHoverResult { - value: "\n```lua\nlocal function f(a, b)\n```\n".to_string(), + value: "```lua\nlocal function f(a, b)\n```".to_string(), }, )); } @@ -133,7 +135,7 @@ mod tests { data.pulse "#, VirtualHoverResult { - value: "\n```lua\n(field) pulse: number?\n```\n\n  in class `Buff.AddData`\n\n---\n\n心跳周期\n".to_string(), + value: "```lua\n(field) pulse: number?\n```\n\n  in class `Buff.AddData`\n\n---\n\n心跳周期".to_string(), }, )); } @@ -154,7 +156,7 @@ mod tests { end "#, VirtualHoverResult { - value: "\n```lua\n(field) _cfg: number\n```\n".to_string(), + value: "```lua\n(field) _cfg: number\n```".to_string(), }, )); @@ -176,7 +178,52 @@ mod tests { end "#, VirtualHoverResult { - value: "\n```lua\n(field) _cfg: number\n```\n".to_string(), + value: "```lua\n(field) _cfg: number\n```".to_string(), + }, + )); + } + + #[test] + fn test_signature_desc() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_hover( + r#" + -- # A + local function abc() + end + "#, + VirtualHoverResult { + value: "```lua\nlocal function abc()\n```\n\n---\n\n# A".to_string(), + }, + )); + } + + #[test] + fn test_class_desc() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_hover( + r#" + ---A1 + ---@class ABC + ---A2 + "#, + VirtualHoverResult { + value: "```lua\n(class) ABC\n```\n\n---\n\nA1".to_string(), + }, + )); + } + + #[test] + fn test_alias_desc() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_hover( + r#" + ---@alias TesAlias + ---| 'A' # A1 + ---| 'B' # A2 + "#, + VirtualHoverResult { + value: "```lua\n(alias) TesAlias = (\"A\"|\"B\")\n | \"A\" -- A1\n | \"B\" -- A2\n\n```".to_string(), }, )); } diff --git a/crates/emmylua_ls/src/handlers/test/implementation_test.rs b/crates/emmylua_ls/src/handlers/test/implementation_test.rs index 4819309dd..2330b5adc 100644 --- a/crates/emmylua_ls/src/handlers/test/implementation_test.rs +++ b/crates/emmylua_ls/src/handlers/test/implementation_test.rs @@ -128,4 +128,24 @@ mod tests { 2, )); } + + #[test] + fn test_separation_of_define_and_impl() { + let mut ws = ProviderVirtualWorkspace::new(); + assert!(ws.check_implementation( + r#" + local abc + + abc = function() + end + + local _a = abc + local _b = abc() + + abc = function() + end + "#, + 3, + )); + } } diff --git a/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs b/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs index 79c724f62..4e7e83ae7 100644 --- a/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs +++ b/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs @@ -1,5 +1,7 @@ #[cfg(test)] mod tests { + use std::{ops::Deref, sync::Arc}; + use lsp_types::{InlayHint, InlayHintLabel, Location, Position, Range}; use crate::handlers::test_lib::ProviderVirtualWorkspace; @@ -91,4 +93,70 @@ mod tests { .unwrap(); assert!(result.is_empty()); } + + #[test] + fn test_meta_call_hint() { + let mut ws = ProviderVirtualWorkspace::new(); + let result = ws + .check_inlay_hint( + r#" + ---@class Hint1 + ---@overload fun(a: string): Hint1 + local Hint1 + + local a = Hint1("a") + "#, + ) + .unwrap(); + assert!(result.len() == 4); + } + + #[test] + fn test_class_def_var_hint() { + let mut ws = ProviderVirtualWorkspace::new(); + let result = ws + .check_inlay_hint( + r#" + ---@class Hint.1 + ---@overload fun(a: integer): Hint.1 + local Hint1 + "#, + ) + .unwrap(); + assert!(result.len() == 1); + } + + #[test] + fn test_class_call_hint() { + let mut ws = ProviderVirtualWorkspace::new(); + let mut emmyrc = ws.analysis.get_emmyrc().deref().clone(); + emmyrc.runtime.class_default_call.function_name = "__init".to_string(); + emmyrc.runtime.class_default_call.force_non_colon = true; + emmyrc.runtime.class_default_call.force_return_self = true; + ws.analysis.update_config(Arc::new(emmyrc)); + + let result = ws + .check_inlay_hint( + r#" + ---@class MyClass + local A + + function A:__init(a) + end + + A() + "#, + ) + .unwrap(); + assert!(result.len() == 2); + + let location = match &result.get(1).unwrap().label { + InlayHintLabel::LabelParts(parts) => parts.first().unwrap().location.as_ref().unwrap(), + InlayHintLabel::String(_) => panic!(), + }; + assert_eq!( + location.range, + Range::new(Position::new(4, 27), Position::new(4, 33)) + ); + } } diff --git a/crates/emmylua_ls/src/handlers/test/mod.rs b/crates/emmylua_ls/src/handlers/test/mod.rs index 458de9f63..8b1cbb20c 100644 --- a/crates/emmylua_ls/src/handlers/test/mod.rs +++ b/crates/emmylua_ls/src/handlers/test/mod.rs @@ -1,3 +1,4 @@ +mod code_actions_test; mod completion_resolve_test; mod completion_test; mod definition_test; @@ -5,4 +6,5 @@ mod hover_function_test; mod hover_test; mod implementation_test; mod inlay_hint_test; +mod semantic_token_test; mod signature_helper_test; diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs new file mode 100644 index 000000000..6fb4008c1 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -0,0 +1,26 @@ +#[cfg(test)] +mod tests { + + use crate::handlers::test_lib::ProviderVirtualWorkspace; + + #[test] + fn test_1() { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class Cast1 + ---@field get fun(self: self, a: number): Cast1? + "#, + ); + + ws.check_semantic_token( + r#" + ---@type Cast1 + local A + + local _a = A:get(1) --[[@cast -?]]:get(2) + "#, + ) + .unwrap(); + } +} diff --git a/crates/emmylua_ls/src/handlers/test_lib/mod.rs b/crates/emmylua_ls/src/handlers/test_lib/mod.rs index b32812e2e..2153b243f 100644 --- a/crates/emmylua_ls/src/handlers/test_lib/mod.rs +++ b/crates/emmylua_ls/src/handlers/test_lib/mod.rs @@ -1,16 +1,18 @@ use emmylua_code_analysis::{EmmyLuaAnalysis, FileId, VirtualUrlGenerator}; use lsp_types::{ - CompletionItemKind, CompletionResponse, CompletionTriggerKind, GotoDefinitionResponse, Hover, - HoverContents, InlayHint, MarkupContent, Position, SignatureHelpContext, - SignatureHelpTriggerKind, + CodeActionResponse, CompletionItemKind, CompletionResponse, CompletionTriggerKind, + GotoDefinitionResponse, Hover, HoverContents, InlayHint, MarkupContent, Position, + SemanticTokensResult, SignatureHelpContext, SignatureHelpTriggerKind, }; use tokio_util::sync::CancellationToken; use crate::{ context::ClientId, handlers::{ + code_actions::code_action, completion::{completion, completion_resolve}, inlay_hint::inlay_hint, + semantic_token::semantic_token, signature_helper::signature_help, }, }; @@ -143,7 +145,7 @@ impl ProviderVirtualWorkspace { let HoverContents::Markup(MarkupContent { kind, value }) = contents else { return false; }; - dbg!(&value); + // dbg!(&value); if value != expect.value { return false; } @@ -185,7 +187,7 @@ impl ProviderVirtualWorkspace { CompletionResponse::Array(items) => items, CompletionResponse::List(list) => list.items, }; - dbg!(&items); + // dbg!(&items); if items.len() != expect.len() { return false; } @@ -247,36 +249,31 @@ impl ProviderVirtualWorkspace { }; let file_id = self.def(&content); let result = implementation(&self.analysis, file_id, position); - dbg!(&result); let Some(result) = result else { return false; }; let GotoDefinitionResponse::Array(implementations) = result else { return false; }; - dbg!(&implementations.len()); if implementations.len() == len { return true; } false } - pub fn check_definition(&mut self, block_str: &str) -> bool { + pub fn check_definition(&mut self, block_str: &str) -> Option { let content = Self::handle_file_content(block_str); let Some((content, position)) = content else { - return false; + return None; }; let file_id = self.def(&content); - let result = super::definition::definition(&self.analysis, file_id, position); - dbg!(&result); + let result: Option = + super::definition::definition(&self.analysis, file_id, position); let Some(result) = result else { - return false; + return None; }; - match result { - GotoDefinitionResponse::Scalar(_) => true, - GotoDefinitionResponse::Array(_) => true, - GotoDefinitionResponse::Link(_) => true, - } + // dbg!(&result); + Some(result) } pub fn check_signature_helper( @@ -315,6 +312,32 @@ impl ProviderVirtualWorkspace { pub fn check_inlay_hint(&mut self, block_str: &str) -> Option> { let file_id = self.def(&block_str); let result = inlay_hint(&self.analysis, file_id); + dbg!(&result); return result; } + + pub fn check_code_action(&mut self, block_str: &str) -> Option { + let file_id = self.def(block_str); + let result = self + .analysis + .diagnose_file(file_id, CancellationToken::new()); + let Some(diagnostics) = result else { + return None; + }; + let result = code_action(&self.analysis, file_id, diagnostics); + // dbg!(&result); + result + } + + pub fn check_semantic_token(&mut self, block_str: &str) -> Option { + let file_id = self.def(block_str); + let result = semantic_token(&self.analysis, file_id, ClientId::VSCode); + let Some(result) = result else { + return None; + }; + + let data = serde_json::to_string(&result).unwrap(); + dbg!(&data); + Some(result) + } } diff --git a/crates/emmylua_ls/src/handlers/text_document/text_document_handler.rs b/crates/emmylua_ls/src/handlers/text_document/text_document_handler.rs index dfc911564..e9c04a50e 100644 --- a/crates/emmylua_ls/src/handlers/text_document/text_document_handler.rs +++ b/crates/emmylua_ls/src/handlers/text_document/text_document_handler.rs @@ -1,9 +1,9 @@ -use std::time::Duration; - +use emmylua_code_analysis::uri_to_file_path; use lsp_types::{ DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams, DidSaveTextDocumentParams, }; +use std::time::Duration; use crate::context::ServerContextSnapshot; @@ -101,13 +101,30 @@ pub async fn on_did_close_document( context: ServerContextSnapshot, params: DidCloseTextDocumentParams, ) -> Option<()> { + let uri = ¶ms.text_document.uri; let mut workspace = context.workspace_manager.write().await; workspace .current_open_files .remove(¶ms.text_document.uri); drop(workspace); + + // 如果关闭后文件不存在, 则移除 + if let Some(file_path) = uri_to_file_path(uri) { + if !file_path.exists() { + let mut mut_analysis = context.analysis.write().await; + mut_analysis.remove_file_by_uri(uri); + drop(mut_analysis); + + context + .file_diagnostic + .clear_file_diagnostics(uri.clone()) + .await; + + return Some(()); + } + } + let analysis = context.analysis.read().await; - let uri = ¶ms.text_document.uri; let file_id = analysis.get_file_id(uri)?; let module_info = analysis .compilation @@ -118,6 +135,12 @@ pub async fn on_did_close_document( drop(analysis); let mut mut_analysis = context.analysis.write().await; mut_analysis.remove_file_by_uri(uri); + drop(mut_analysis); + // 发送空诊断消息以清除客户端显示的诊断 + context + .file_diagnostic + .clear_file_diagnostics(uri.clone()) + .await; } Some(()) diff --git a/crates/emmylua_ls/src/handlers/text_document/watched_file_handler.rs b/crates/emmylua_ls/src/handlers/text_document/watched_file_handler.rs index 0797aaf4a..0482fa6f4 100644 --- a/crates/emmylua_ls/src/handlers/text_document/watched_file_handler.rs +++ b/crates/emmylua_ls/src/handlers/text_document/watched_file_handler.rs @@ -20,6 +20,11 @@ pub async fn on_did_change_watched_files( Some(WatchedFileType::Lua) => { if file_event.typ == FileChangeType::DELETED { analysis.remove_file_by_uri(&file_event.uri); + // 发送空诊断消息以清除客户端显示的诊断 + context + .file_diagnostic + .clear_file_diagnostics(file_event.uri) + .await; continue; } diff --git a/crates/emmylua_ls/src/handlers/workspace/did_rename_files.rs b/crates/emmylua_ls/src/handlers/workspace/did_rename_files.rs new file mode 100644 index 000000000..a86c4b88b --- /dev/null +++ b/crates/emmylua_ls/src/handlers/workspace/did_rename_files.rs @@ -0,0 +1,291 @@ +use std::{ + collections::HashMap, + path::{Path, PathBuf}, + str::FromStr, +}; + +use emmylua_code_analysis::{ + file_path_to_uri, read_file_with_encoding, uri_to_file_path, FileId, LuaCompilation, + LuaModuleIndex, LuaType, SemanticModel, WorkspaceId, +}; +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaIndexExpr}; +use lsp_types::{ + ApplyWorkspaceEditParams, FileRename, MessageActionItem, MessageType, RenameFilesParams, + ShowMessageRequestParams, TextEdit, Uri, WorkspaceEdit, +}; +use tokio_util::sync::CancellationToken; +use walkdir::WalkDir; + +use crate::{context::ServerContextSnapshot, handlers::ClientConfig}; + +pub async fn on_did_rename_files_handler( + context: ServerContextSnapshot, + params: RenameFilesParams, +) -> Option<()> { + let mut all_renames: Vec = vec![]; + + let analysis = context.analysis.read().await; + + let module_index = analysis.compilation.get_db().get_module_index(); + for file_rename in params.files { + let FileRename { old_uri, new_uri } = file_rename; + + let old_uri = Uri::from_str(&old_uri).ok()?; + let new_uri = Uri::from_str(&new_uri).ok()?; + + let old_path = uri_to_file_path(&old_uri)?; + let new_path = uri_to_file_path(&new_uri)?; + + // 提取重命名信息 + let rename_info = collect_rename_info(&old_uri, &new_uri, &module_index); + if let Some(rename_info) = rename_info { + all_renames.push(rename_info.clone()); + } else { + // 有可能是目录重命名, 需要收集目录下所有 lua 文件 + if let Some(collected_renames) = + collect_directory_lua_files(&old_path, &new_path, &module_index) + { + all_renames.extend(collected_renames); + } + } + } + + // 如果有重命名的文件, 弹窗询问用户是否要修改require路径 + if !all_renames.is_empty() { + drop(analysis); + // 更新 + let mut analysis = context.analysis.write().await; + let encoding = &analysis.get_emmyrc().workspace.encoding; + for rename in all_renames.iter() { + analysis.remove_file_by_uri(&rename.old_uri); + if let Some(new_path) = uri_to_file_path(&rename.new_uri) { + if let Some(text) = read_file_with_encoding(&new_path, encoding) { + analysis.update_file_by_uri(&rename.new_uri, Some(text)); + } + } + } + drop(analysis); + + let analysis = context.analysis.read().await; + if let Some(changes) = try_modify_require_path(&analysis.compilation, &all_renames) { + drop(analysis); + if changes.is_empty() { + return Some(()); + } + + let client = context.client.clone(); + + let show_message_params = ShowMessageRequestParams { + typ: MessageType::INFO, + message: t!("Do you want to modify the require path?").to_string(), + actions: Some(vec![MessageActionItem { + title: t!("Modify").to_string(), + properties: HashMap::new(), + }]), + }; + + // 发送弹窗请求 + let cancel_token = CancellationToken::new(); + if let Some(selected_action) = client + .show_message_request(show_message_params, cancel_token) + .await + { + let cancel_token = CancellationToken::new(); + if selected_action.title == t!("Modify") { + client + .apply_edit( + ApplyWorkspaceEditParams { + edit: WorkspaceEdit { + changes: Some(changes), + document_changes: None, + change_annotations: None, + }, + label: None, + }, + cancel_token, + ) + .await?; + } + } + } + } + + Some(()) +} + +#[derive(Debug, PartialEq, Eq, Clone)] +struct RenameInfo { + old_uri: Uri, + new_uri: Uri, + old_module_path: String, + new_module_path: String, + workspace_id: WorkspaceId, +} + +fn collect_rename_info( + old_uri: &Uri, + new_uri: &Uri, + module_index: &LuaModuleIndex, +) -> Option { + let (mut old_module_path, workspace_id) = + module_index.extract_module_path(uri_to_file_path(&old_uri)?.to_str()?)?; + old_module_path = old_module_path.replace(['\\', '/'], "."); + + let (mut new_module_path, _) = + module_index.extract_module_path(uri_to_file_path(&new_uri)?.to_str()?)?; + new_module_path = new_module_path.replace(['\\', '/'], "."); + + Some(RenameInfo { + old_uri: old_uri.clone(), + new_uri: new_uri.clone(), + old_module_path, + new_module_path, + workspace_id, + }) +} + +/// 收集目录重命名后所有的Lua文件 +fn collect_directory_lua_files( + old_path: &PathBuf, + new_path: &PathBuf, + module_index: &LuaModuleIndex, +) -> Option> { + // 检查新路径是否是目录(旧路径已经不存在了) + if !new_path.is_dir() { + return None; + } + + let mut renames = vec![]; + + // 遍历新目录下的所有Lua文件 + for entry in WalkDir::new(new_path) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| e.file_type().is_file()) + { + let new_file_path = entry.path(); + + // 计算在新目录中的相对路径 + if let Ok(relative_path) = new_file_path.strip_prefix(new_path) { + // 根据目录重命名推算出对应的旧文件路径 + let old_file_path = old_path.join(relative_path); + + // 转换为URI + if let (Some(old_file_uri), Some(new_file_uri)) = ( + file_path_to_uri(&old_file_path), + file_path_to_uri(&new_file_path.to_path_buf()), + ) { + let rename_info = collect_rename_info(&old_file_uri, &new_file_uri, module_index); + if let Some(rename_info) = rename_info { + renames.push(rename_info); + } + } + } + } + + if renames.is_empty() { + None + } else { + Some(renames) + } +} + +#[allow(unused)] +/// 检查文件路径是否是Lua文件 +fn is_lua_file(file_path: &Path, client_config: &ClientConfig) -> bool { + let file_name = file_path.to_string_lossy(); + + if file_name.ends_with(".lua") { + return true; + } + + // 检查客户端配置的扩展名 + for extension in &client_config.extensions { + if file_name.ends_with(extension) { + return true; + } + } + + false +} + +fn try_modify_require_path( + compilation: &LuaCompilation, + renames: &Vec, +) -> Option>> { + let mut changes: HashMap> = HashMap::new(); + for file_id in compilation.get_db().get_vfs().get_all_file_ids() { + if compilation.get_db().get_module_index().is_std(&file_id) { + continue; + } + + if let Some(semantic_model) = compilation.get_semantic_model(file_id) { + for call_expr in semantic_model.get_root().descendants::() { + if call_expr.is_require() { + try_convert(&semantic_model, call_expr, renames, &mut changes, file_id); + } + } + } + } + Some(changes) +} + +fn try_convert( + semantic_model: &SemanticModel, + call_expr: LuaCallExpr, + renames: &Vec, + changes: &mut HashMap>, + current_file_id: FileId, // 当前文件id +) -> Option<()> { + if let Some(_) = call_expr.get_parent::() { + return None; + } + + let args_list = call_expr.get_args_list()?; + let arg_expr = args_list.get_args().next()?; + let ty = semantic_model + .infer_expr(arg_expr.clone()) + .unwrap_or(LuaType::Any); + let name = if let LuaType::StringConst(s) = ty { + s + } else { + return None; + }; + let emmyrc = semantic_model.get_emmyrc(); + let separator = &emmyrc.completion.auto_require_separator; + let strict_require_path = emmyrc.strict.require_path; + // 转换为标准导入语法 + let normalized_path = name.replace(separator, "."); + + for rename in renames { + let is_matched = if strict_require_path { + rename.old_module_path == normalized_path + } else { + rename.old_module_path.ends_with(&normalized_path) + }; + + if is_matched { + let range = arg_expr.syntax().text_range(); + let lsp_range = semantic_model.get_document().to_lsp_range(range)?; + + let current_uri = semantic_model + .get_db() + .get_vfs() + .get_uri(¤t_file_id)?; + + let full_module_path = match separator.as_str() { + "." | "" => rename.new_module_path.clone(), + _ => rename.new_module_path.replace(".", &separator), + }; + + changes.entry(current_uri).or_insert(vec![]).push(TextEdit { + range: lsp_range, + new_text: format!("'{}'", full_module_path), + }); + + return Some(()); + } + } + + Some(()) +} diff --git a/crates/emmylua_ls/src/handlers/workspace/mod.rs b/crates/emmylua_ls/src/handlers/workspace/mod.rs new file mode 100644 index 000000000..f686eed45 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/workspace/mod.rs @@ -0,0 +1,35 @@ +mod did_rename_files; + +pub use did_rename_files::on_did_rename_files_handler; +use lsp_types::{ + ClientCapabilities, FileOperationFilter, FileOperationPattern, FileOperationPatternOptions, + FileOperationRegistrationOptions, ServerCapabilities, + WorkspaceFileOperationsServerCapabilities, WorkspaceServerCapabilities, +}; + +use crate::handlers::RegisterCapabilities; + +pub struct WorkspaceCapabilities; + +impl RegisterCapabilities for WorkspaceCapabilities { + fn register_capabilities(server_capabilities: &mut ServerCapabilities, _: &ClientCapabilities) { + server_capabilities.workspace = Some(WorkspaceServerCapabilities { + file_operations: Some(WorkspaceFileOperationsServerCapabilities { + did_rename: Some(FileOperationRegistrationOptions { + filters: vec![FileOperationFilter { + scheme: Some(String::from("file")), + pattern: FileOperationPattern { + glob: "**/*".to_string(), + matches: None, + options: Some(FileOperationPatternOptions { + ignore_case: Some(true), + }), + }, + }], + }), + ..Default::default() + }), + ..Default::default() + }); + } +} diff --git a/crates/emmylua_ls/src/util/mod.rs b/crates/emmylua_ls/src/util/mod.rs index 3c7f1d8e0..f3b01aefd 100644 --- a/crates/emmylua_ls/src/util/mod.rs +++ b/crates/emmylua_ls/src/util/mod.rs @@ -1,5 +1,5 @@ mod module_name_convert; mod time_cancel_token; -pub use module_name_convert::module_name_convert; +pub use module_name_convert::{key_name_convert, module_name_convert}; pub use time_cancel_token::time_cancel_token; diff --git a/crates/emmylua_ls/src/util/module_name_convert.rs b/crates/emmylua_ls/src/util/module_name_convert.rs index 6fa6e6063..be21b1d46 100644 --- a/crates/emmylua_ls/src/util/module_name_convert.rs +++ b/crates/emmylua_ls/src/util/module_name_convert.rs @@ -1,7 +1,10 @@ -use emmylua_code_analysis::EmmyrcFilenameConvention; +use emmylua_code_analysis::{EmmyrcFilenameConvention, LuaType, ModuleInfo}; -pub fn module_name_convert(name: &str, file_conversion: EmmyrcFilenameConvention) -> String { - let mut module_name = name.to_string(); +pub fn module_name_convert( + module_info: &ModuleInfo, + file_conversion: EmmyrcFilenameConvention, +) -> String { + let mut module_name = module_info.name.to_string(); match file_conversion { EmmyrcFilenameConvention::SnakeCase => { @@ -14,11 +17,44 @@ pub fn module_name_convert(name: &str, file_conversion: EmmyrcFilenameConvention module_name = to_pascal_case(&module_name); } EmmyrcFilenameConvention::Keep => {} + EmmyrcFilenameConvention::KeepClass => { + if let Some(export_type) = &module_info.export_type { + if let LuaType::Def(id) = export_type { + module_name = id.get_simple_name().to_string(); + } + } + } } module_name } +pub fn key_name_convert( + key: &str, + typ: &LuaType, + file_conversion: EmmyrcFilenameConvention, +) -> String { + let mut key_name = key.to_string(); + match file_conversion { + EmmyrcFilenameConvention::SnakeCase => { + key_name = to_snake_case(&key_name); + } + EmmyrcFilenameConvention::CamelCase => { + key_name = to_camel_case(&key_name); + } + EmmyrcFilenameConvention::PascalCase => { + key_name = to_pascal_case(&key_name); + } + EmmyrcFilenameConvention::Keep => {} + EmmyrcFilenameConvention::KeepClass => { + if let LuaType::Def(id) = typ { + key_name = id.get_simple_name().to_string(); + } + } + } + key_name +} + fn to_snake_case(s: &str) -> String { let mut result = String::new(); for (i, ch) in s.chars().enumerate() { diff --git a/crates/emmylua_parser/locales/app.yml b/crates/emmylua_parser/locales/app.yml index 839e58551..1f9f90e14 100644 --- a/crates/emmylua_parser/locales/app.yml +++ b/crates/emmylua_parser/locales/app.yml @@ -188,4 +188,19 @@ colon accessor must be followed by a function call or table constructor or strin en: colon accessor must be followed by a function call or table constructor or string literal zh_CN: 冒号访问器后必须跟随函数调用、表构造或字符串字面量 zh_HK: 冒號訪問器後必須跟隨函數調用、表構造或字符串字面量 - zh_TW: 冒號存取器後必須跟隨函數呼叫、表建構或字串字面量 \ No newline at end of file + zh_TW: 冒號存取器後必須跟隨函數呼叫、表建構或字串字面量 +expected '}' to close table: + en: expected '}' to close table + zh_CN: 期望 '}' 关闭表 + zh_HK: 期望 '}' 關閉表 + zh_TW: 期望 '}' 關閉表 +expected ']': + en: expected ']' + zh_CN: 期望 ']' + zh_HK: 期望 ']' + zh_TW: 期望 ']' +expected '=': + en: expected '=' + zh_CN: 期望 '=' + zh_HK: 期望 '=' + zh_TW: 期望 '=' \ No newline at end of file diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 71def21e0..736d39aee 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -418,11 +418,20 @@ fn parse_tag_overload(p: &mut LuaDocParser) -> ParseResult { // ---@cast a +? // ---@cast a +string, -number fn parse_tag_cast(p: &mut LuaDocParser) -> ParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_state(LuaDocLexerState::CastExpr); let m = p.mark(LuaSyntaxKind::DocTagCast); p.bump(); - expect_token(p, LuaTokenKind::TkName)?; + if p.current_token() == LuaTokenKind::TkName { + match parse_cast_expr(p) { + Ok(_) => {} + Err(e) => { + return Err(e); + } + } + } + + // 切换回正常状态 parse_op_type(p)?; while p.current_token() == LuaTokenKind::TkComma { p.bump(); @@ -434,6 +443,25 @@ fn parse_tag_cast(p: &mut LuaDocParser) -> ParseResult { Ok(m.complete(p)) } +fn parse_cast_expr(p: &mut LuaDocParser) -> ParseResult { + let m = p.mark(LuaSyntaxKind::NameExpr); + p.bump(); + let mut cm = m.complete(p); + // 处理多级字段访问 + while p.current_token() == LuaTokenKind::TkDot { + let index_m = cm.precede(p, LuaSyntaxKind::IndexExpr); + p.bump(); + if p.current_token() == LuaTokenKind::TkName { + p.bump(); + } else { + // 找不到也不报错 + } + cm = index_m.complete(p); + } + + Ok(cm) +} + // +, -, +?, fn parse_op_type(p: &mut LuaDocParser) -> ParseResult { p.set_state(LuaDocLexerState::Normal); diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index 36a6e5a43..1d0eddfa6 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -1101,7 +1101,8 @@ Syntax(Chunk)@0..169 Syntax(DocTagCast)@13..26 Token(TkTagCast)@13..17 "cast" Token(TkWhitespace)@17..18 " " - Token(TkName)@18..19 "a" + Syntax(NameExpr)@18..19 + Token(TkName)@18..19 "a" Token(TkWhitespace)@19..20 " " Syntax(DocOpType)@20..26 Syntax(TypeName)@20..26 @@ -1112,7 +1113,8 @@ Syntax(Chunk)@0..169 Syntax(DocTagCast)@39..53 Token(TkTagCast)@39..43 "cast" Token(TkWhitespace)@43..44 " " - Token(TkName)@44..45 "b" + Syntax(NameExpr)@44..45 + Token(TkName)@44..45 "b" Token(TkWhitespace)@45..46 " " Syntax(DocOpType)@46..53 Token(TkPlus)@46..47 "+" @@ -1124,7 +1126,8 @@ Syntax(Chunk)@0..169 Syntax(DocTagCast)@66..80 Token(TkTagCast)@66..70 "cast" Token(TkWhitespace)@70..71 " " - Token(TkName)@71..72 "c" + Syntax(NameExpr)@71..72 + Token(TkName)@71..72 "c" Token(TkWhitespace)@72..73 " " Syntax(DocOpType)@73..80 Token(TkMinus)@73..74 "-" @@ -1136,7 +1139,8 @@ Syntax(Chunk)@0..169 Syntax(DocTagCast)@93..102 Token(TkTagCast)@93..97 "cast" Token(TkWhitespace)@97..98 " " - Token(TkName)@98..99 "d" + Syntax(NameExpr)@98..99 + Token(TkName)@98..99 "d" Token(TkWhitespace)@99..100 " " Syntax(DocOpType)@100..102 Token(TkPlus)@100..101 "+" @@ -1147,7 +1151,8 @@ Syntax(Chunk)@0..169 Syntax(DocTagCast)@115..124 Token(TkTagCast)@115..119 "cast" Token(TkWhitespace)@119..120 " " - Token(TkName)@120..121 "e" + Syntax(NameExpr)@120..121 + Token(TkName)@120..121 "e" Token(TkWhitespace)@121..122 " " Syntax(DocOpType)@122..124 Token(TkMinus)@122..123 "-" @@ -1158,7 +1163,8 @@ Syntax(Chunk)@0..169 Syntax(DocTagCast)@137..160 Token(TkTagCast)@137..141 "cast" Token(TkWhitespace)@141..142 " " - Token(TkName)@142..143 "f" + Syntax(NameExpr)@142..143 + Token(TkName)@142..143 "f" Token(TkWhitespace)@143..144 " " Syntax(DocOpType)@144..151 Token(TkPlus)@144..145 "+" @@ -1750,6 +1756,98 @@ Syntax(Chunk)@0..51 assert_ast_eq!(code, result); } + #[test] + fn test_cast_expr() { + let code = r#" +---@cast a number +---@cast a.field string +---@cast A.b.c.d boolean +---@cast -? + "#; + let result = r#" +Syntax(Chunk)@0..88 + Syntax(Block)@0..88 + Token(TkEndOfLine)@0..1 "\n" + Syntax(Comment)@1..79 + Token(TkDocStart)@1..5 "---@" + Syntax(DocTagCast)@5..18 + Token(TkTagCast)@5..9 "cast" + Token(TkWhitespace)@9..10 " " + Syntax(NameExpr)@10..11 + Token(TkName)@10..11 "a" + Token(TkWhitespace)@11..12 " " + Syntax(DocOpType)@12..18 + Syntax(TypeName)@12..18 + Token(TkName)@12..18 "number" + Token(TkEndOfLine)@18..19 "\n" + Token(TkDocStart)@19..23 "---@" + Syntax(DocTagCast)@23..42 + Token(TkTagCast)@23..27 "cast" + Token(TkWhitespace)@27..28 " " + Syntax(IndexExpr)@28..35 + Syntax(NameExpr)@28..29 + Token(TkName)@28..29 "a" + Token(TkDot)@29..30 "." + Token(TkName)@30..35 "field" + Token(TkWhitespace)@35..36 " " + Syntax(DocOpType)@36..42 + Syntax(TypeName)@36..42 + Token(TkName)@36..42 "string" + Token(TkEndOfLine)@42..43 "\n" + Token(TkDocStart)@43..47 "---@" + Syntax(DocTagCast)@47..67 + Token(TkTagCast)@47..51 "cast" + Token(TkWhitespace)@51..52 " " + Syntax(IndexExpr)@52..59 + Syntax(IndexExpr)@52..57 + Syntax(IndexExpr)@52..55 + Syntax(NameExpr)@52..53 + Token(TkName)@52..53 "A" + Token(TkDot)@53..54 "." + Token(TkName)@54..55 "b" + Token(TkDot)@55..56 "." + Token(TkName)@56..57 "c" + Token(TkDot)@57..58 "." + Token(TkName)@58..59 "d" + Token(TkWhitespace)@59..60 " " + Syntax(DocOpType)@60..67 + Syntax(TypeName)@60..67 + Token(TkName)@60..67 "boolean" + Token(TkEndOfLine)@67..68 "\n" + Token(TkDocStart)@68..72 "---@" + Syntax(DocTagCast)@72..79 + Token(TkTagCast)@72..76 "cast" + Token(TkWhitespace)@76..77 " " + Syntax(DocOpType)@77..79 + Token(TkMinus)@77..78 "-" + Token(TkDocQuestion)@78..79 "?" + Token(TkEndOfLine)@79..80 "\n" + Token(TkWhitespace)@80..88 " " + "#; + + assert_ast_eq!(code, result); + } + + #[test] + fn test_multi_level_cast() { + let code = r#" + ---@cast obj.a.b.c.d string + "#; + // Note: The exact line numbers may vary, but the structure should be correct + let tree = LuaParser::parse(code, ParserConfig::default()); + let result = format!("{:#?}", tree.get_red_root()); + + // Verify that we have the correct nested structure + assert!(result.contains("IndexExpr")); + assert!(result.contains("NameExpr")); + assert!(result.contains("TkDot")); + assert!(result.contains("obj")); + assert!(result.contains("string")); + + // Print the actual result for debugging + println!("Actual AST structure:\n{}", result); + } + #[test] fn test_compact_luals_param() { let code = r#" diff --git a/crates/emmylua_parser/src/grammar/lua/expr.rs b/crates/emmylua_parser/src/grammar/lua/expr.rs index 46cf2d805..502a769f2 100644 --- a/crates/emmylua_parser/src/grammar/lua/expr.rs +++ b/crates/emmylua_parser/src/grammar/lua/expr.rs @@ -132,15 +132,20 @@ fn parse_table_expr(p: &mut LuaParser) -> ParseResult { return Ok(m.complete(p)); } - let mut cm = parse_field(p)?; - match cm.kind { - LuaSyntaxKind::TableFieldAssign => { - m.set_kind(p, LuaSyntaxKind::TableObjectExpr); - } - LuaSyntaxKind::TableFieldValue => { - m.set_kind(p, LuaSyntaxKind::TableArrayExpr); + match parse_field_with_recovery(p) { + Ok(cm) => match cm.kind { + LuaSyntaxKind::TableFieldAssign => { + m.set_kind(p, LuaSyntaxKind::TableObjectExpr); + } + LuaSyntaxKind::TableFieldValue => { + m.set_kind(p, LuaSyntaxKind::TableArrayExpr); + } + _ => {} + }, + Err(_) => { + // 即使字段解析失败, 我们也不中止解析 + recover_to_table_boundary(p); } - _ => {} } while p.current_token() == LuaTokenKind::TkComma @@ -150,42 +155,170 @@ fn parse_table_expr(p: &mut LuaParser) -> ParseResult { if p.current_token() == LuaTokenKind::TkRightBrace { break; } - cm = parse_field(p)?; - if cm.kind == LuaSyntaxKind::TableFieldAssign { - m.set_kind(p, LuaSyntaxKind::TableObjectExpr); + + match parse_field_with_recovery(p) { + Ok(cm) => { + if cm.kind == LuaSyntaxKind::TableFieldAssign { + m.set_kind(p, LuaSyntaxKind::TableObjectExpr); + } + } + Err(_) => { + // 即使字段解析失败, 我们也不中止解析 + recover_to_table_boundary(p); + if p.current_token() == LuaTokenKind::TkRightBrace { + break; + } + } + } + } + + // 处理闭合括号 + if p.current_token() == LuaTokenKind::TkRightBrace { + p.bump(); + } else { + // 表可能是错的, 但可以继续尝试解析 + let mut found_brace = false; + let mut brace_count = 1; // 我们已经在表中 + let mut lookahead_count = 0; + const MAX_LOOKAHEAD: usize = 50; // 限制令牌数避免无休止的解析 + + while p.current_token() != LuaTokenKind::TkEof && lookahead_count < MAX_LOOKAHEAD { + match p.current_token() { + LuaTokenKind::TkRightBrace => { + brace_count -= 1; + if brace_count == 0 { + p.bump(); // 消费闭合括号 + found_brace = true; + break; + } + p.bump(); + } + LuaTokenKind::TkLeftBrace => { + brace_count += 1; + p.bump(); + } + // 如果遇到则认为已经是表的边界 + LuaTokenKind::TkLocal + | LuaTokenKind::TkFunction + | LuaTokenKind::TkIf + | LuaTokenKind::TkWhile + | LuaTokenKind::TkFor + | LuaTokenKind::TkReturn => { + break; + } + _ => { + p.bump(); + } + } + lookahead_count += 1; + } + + if !found_brace { + // 没有找到闭合括号, 报告错误 + p.push_error(LuaParseError::syntax_error_from( + &t!("expected '}' to close table"), + p.current_token_range(), + )); } } - expect_token(p, LuaTokenKind::TkRightBrace)?; Ok(m.complete(p)) } -fn parse_field(p: &mut LuaParser) -> ParseResult { +fn parse_field_with_recovery(p: &mut LuaParser) -> ParseResult { let mut m = p.mark(LuaSyntaxKind::TableFieldValue); - - if p.current_token() == LuaTokenKind::TkLeftBracket { - m.set_kind(p, LuaSyntaxKind::TableFieldAssign); - p.bump(); - parse_expr(p)?; - expect_token(p, LuaTokenKind::TkRightBracket)?; - expect_token(p, LuaTokenKind::TkAssign)?; - parse_expr(p)?; - } else if p.current_token() == LuaTokenKind::TkName { - if p.peek_next_token() == LuaTokenKind::TkAssign { + // 即使字段解析失败, 我们也不会中止解析 + match p.current_token() { + LuaTokenKind::TkLeftBracket => { m.set_kind(p, LuaSyntaxKind::TableFieldAssign); p.bump(); - p.bump(); - parse_expr(p)?; - } else { - parse_expr(p)?; + match parse_expr(p) { + Ok(_) => {} + Err(err) => { + p.push_error(err); + // 找到边界 + while !matches!( + p.current_token(), + LuaTokenKind::TkRightBracket + | LuaTokenKind::TkAssign + | LuaTokenKind::TkComma + | LuaTokenKind::TkSemicolon + | LuaTokenKind::TkRightBrace + | LuaTokenKind::TkEof + ) { + p.bump(); + } + } + } + if p.current_token() == LuaTokenKind::TkRightBracket { + p.bump(); + } else { + p.push_error(LuaParseError::syntax_error_from( + &t!("expected ']'"), + p.current_token_range(), + )); + } + if p.current_token() == LuaTokenKind::TkAssign { + p.bump(); + } else { + p.push_error(LuaParseError::syntax_error_from( + &t!("expected '='"), + p.current_token_range(), + )); + } + match parse_expr(p) { + Ok(_) => {} + Err(err) => { + p.push_error(err); + } + } } - } else { - parse_expr(p)?; + LuaTokenKind::TkName => { + if p.peek_next_token() == LuaTokenKind::TkAssign { + m.set_kind(p, LuaSyntaxKind::TableFieldAssign); + p.bump(); // consume name + p.bump(); // consume '=' + match parse_expr(p) { + Ok(_) => {} + Err(err) => { + p.push_error(err); + } + } + } else { + match parse_expr(p) { + Ok(_) => {} + Err(err) => { + p.push_error(err); + } + } + } + } + // 一些表示`table`实际上已经结束的令牌 + LuaTokenKind::TkEof | LuaTokenKind::TkLocal => {} + _ => match parse_expr(p) { + Ok(_) => {} + Err(err) => { + p.push_error(err); + } + }, } Ok(m.complete(p)) } +fn recover_to_table_boundary(p: &mut LuaParser) { + // 跳过直到找到表边界或字段分隔符 + while !matches!( + p.current_token(), + LuaTokenKind::TkComma + | LuaTokenKind::TkSemicolon + | LuaTokenKind::TkRightBrace + | LuaTokenKind::TkEof + ) { + p.bump(); + } +} + fn parse_suffixed_expr(p: &mut LuaParser) -> ParseResult { let mut cm = match p.current_token() { LuaTokenKind::TkName => parse_name_or_special_function(p)?, diff --git a/crates/emmylua_parser/src/grammar/lua/test.rs b/crates/emmylua_parser/src/grammar/lua/test.rs index ab168aae6..edcbf2e19 100644 --- a/crates/emmylua_parser/src/grammar/lua/test.rs +++ b/crates/emmylua_parser/src/grammar/lua/test.rs @@ -1150,4 +1150,62 @@ Syntax(Chunk)@0..12 ParserConfig::with_level(LuaLanguageLevel::Lua55) ); } + + #[test] + fn test_wrong_table_expr() { + let code = r#" + local _A = { + a = , + b = , + c = , + } + "#; + let result = r#" +Syntax(Chunk)@0..94 + Syntax(Block)@0..94 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(LocalStat)@9..85 + Token(TkLocal)@9..14 "local" + Token(TkWhitespace)@14..15 " " + Syntax(LocalName)@15..17 + Token(TkName)@15..17 "_A" + Token(TkWhitespace)@17..18 " " + Token(TkAssign)@18..19 "=" + Token(TkWhitespace)@19..20 " " + Syntax(TableObjectExpr)@20..85 + Token(TkLeftBrace)@20..21 "{" + Token(TkEndOfLine)@21..22 "\n" + Token(TkWhitespace)@22..34 " " + Syntax(TableFieldAssign)@34..37 + Token(TkName)@34..35 "a" + Token(TkWhitespace)@35..36 " " + Token(TkAssign)@36..37 "=" + Token(TkWhitespace)@37..38 " " + Token(TkComma)@38..39 "," + Token(TkEndOfLine)@39..40 "\n" + Token(TkWhitespace)@40..52 " " + Syntax(TableFieldAssign)@52..55 + Token(TkName)@52..53 "b" + Token(TkWhitespace)@53..54 " " + Token(TkAssign)@54..55 "=" + Token(TkWhitespace)@55..56 " " + Token(TkComma)@56..57 "," + Token(TkEndOfLine)@57..58 "\n" + Token(TkWhitespace)@58..70 " " + Syntax(TableFieldAssign)@70..73 + Token(TkName)@70..71 "c" + Token(TkWhitespace)@71..72 " " + Token(TkAssign)@72..73 "=" + Token(TkWhitespace)@73..74 " " + Token(TkComma)@74..75 "," + Token(TkEndOfLine)@75..76 "\n" + Token(TkWhitespace)@76..84 " " + Token(TkRightBrace)@84..85 "}" + Token(TkEndOfLine)@85..86 "\n" + Token(TkWhitespace)@86..94 " " + "#; + + assert_ast_eq!(code, result); + } } diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index 167afe4dc..d5fa5d075 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -26,6 +26,7 @@ pub enum LuaDocLexerState { Version, Source, NormalDescription, + CastExpr, } impl LuaDocLexer<'_> { @@ -70,6 +71,7 @@ impl LuaDocLexer<'_> { LuaDocLexerState::Version => self.lex_version(), LuaDocLexerState::Source => self.lex_source(), LuaDocLexerState::NormalDescription => self.lex_normal_description(), + LuaDocLexerState::CastExpr => self.lex_cast_expr(), } } @@ -536,6 +538,26 @@ impl LuaDocLexer<'_> { } } } + + fn lex_cast_expr(&mut self) -> LuaTokenKind { + let reader = self.reader.as_mut().unwrap(); + match reader.current_char() { + ch if is_doc_whitespace(ch) => { + reader.eat_while(is_doc_whitespace); + LuaTokenKind::TkWhitespace + } + '.' => { + reader.bump(); + LuaTokenKind::TkDot + } + ch if is_name_start(ch) => { + reader.bump(); + reader.eat_while(is_name_continue); + LuaTokenKind::TkName + } + _ => self.lex_normal(), + } + } } fn to_tag(text: &str) -> LuaTokenKind { diff --git a/crates/emmylua_parser/src/parser/lua_doc_parser.rs b/crates/emmylua_parser/src/parser/lua_doc_parser.rs index 815c5b78d..3fbfe22ef 100644 --- a/crates/emmylua_parser/src/parser/lua_doc_parser.rs +++ b/crates/emmylua_parser/src/parser/lua_doc_parser.rs @@ -98,6 +98,11 @@ impl LuaDocParser<'_, '_> { self.eat_current_and_lex_next(); } } + LuaDocLexerState::CastExpr => { + while matches!(self.current_token, LuaTokenKind::TkWhitespace) { + self.eat_current_and_lex_next(); + } + } LuaDocLexerState::Init => { while matches!( self.current_token, @@ -204,6 +209,11 @@ impl LuaDocParser<'_, '_> { self.current_token = LuaTokenKind::TkDocTrivia; } } + LuaDocLexerState::Normal => { + if self.lexer.state == LuaDocLexerState::CastExpr { + self.re_calc_cast_type(); + } + } _ => {} } @@ -229,6 +239,31 @@ impl LuaDocParser<'_, '_> { self.bump(); } + fn re_calc_cast_type(&mut self) { + if self.lexer.is_invalid() { + return; + } + + // cast key 的解析是可以以`.`分割的, 但 `type` 不能以`.`分割必须视为一个整体, 因此我们需要回退 + let read_range = self.current_token_range; + let origin_token_range = self.tokens[self.origin_token_index].range; + let origin_token_kind = self.tokens[self.origin_token_index].kind; + let new_range = SourceRange { + start_offset: read_range.start_offset, + length: origin_token_range.end_offset() - read_range.start_offset, + }; + self.lexer.reset(origin_token_kind, new_range); + + self.lexer.state = LuaDocLexerState::Normal; + + let token = self.lex_token(); + self.current_token = token.kind; + + if !token.range.is_empty() { + self.current_token_range = token.range; + } + } + pub fn bump_to_end(&mut self) { self.set_state(LuaDocLexerState::Trivia); self.eat_current_and_lex_next(); diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index 14da91dcf..e36475748 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -2,8 +2,9 @@ use crate::{ kind::LuaSyntaxKind, syntax::{traits::LuaAstNode, LuaDocDescriptionOwner}, BinaryOperator, LuaAstChildren, LuaAstToken, LuaAstTokenChildren, LuaBinaryOpToken, - LuaDocVersionNumberToken, LuaDocVisibilityToken, LuaGeneralToken, LuaKind, LuaNameToken, - LuaNumberToken, LuaPathToken, LuaStringToken, LuaSyntaxNode, LuaTokenKind, LuaVersionCondition, + LuaDocVersionNumberToken, LuaDocVisibilityToken, LuaExpr, LuaGeneralToken, LuaKind, + LuaNameToken, LuaNumberToken, LuaPathToken, LuaStringToken, LuaSyntaxNode, LuaTokenKind, + LuaVersionCondition, }; use super::{ @@ -993,8 +994,8 @@ impl LuaDocTagCast { self.children() } - pub fn get_name_token(&self) -> Option { - self.token() + pub fn get_key_expr(&self) -> Option { + self.child() } } diff --git a/crates/emmylua_parser/src/syntax/node/doc/types.rs b/crates/emmylua_parser/src/syntax/node/doc/types.rs index 76ae6c6e9..d0fbe2ae7 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/types.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/types.rs @@ -1,9 +1,10 @@ use crate::{ - LuaAstChildren, LuaAstNode, LuaAstToken, LuaDocTypeBinaryToken, LuaDocTypeUnaryToken, - LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, LuaTokenKind, + LuaAstChildren, LuaAstNode, LuaAstToken, LuaDocDescriptionOwner, LuaDocTypeBinaryToken, + LuaDocTypeUnaryToken, LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, + LuaTokenKind, }; -use super::{LuaDocDescription, LuaDocObjectField, LuaDocTypeList}; +use super::{LuaDocObjectField, LuaDocTypeList}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum LuaDocType { @@ -717,12 +718,10 @@ impl LuaAstNode for LuaDocOneLineField { } } +impl LuaDocDescriptionOwner for LuaDocOneLineField {} + impl LuaDocOneLineField { pub fn get_type(&self) -> Option { self.child() } - - pub fn get_description(&self) -> Option { - self.child() - } } diff --git a/crates/emmylua_parser/src/syntax/node/lua/mod.rs b/crates/emmylua_parser/src/syntax/node/lua/mod.rs index 93553d899..021e501d9 100644 --- a/crates/emmylua_parser/src/syntax/node/lua/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/lua/mod.rs @@ -249,10 +249,12 @@ impl LuaAstNode for LuaTableField { impl LuaCommentOwner for LuaTableField {} impl LuaTableField { + /// TableFieldAssign: { a = "a" } pub fn is_assign_field(&self) -> bool { self.syntax().kind() == LuaSyntaxKind::TableFieldAssign.into() } + /// TableFieldValue: { "a" } pub fn is_value_field(&self) -> bool { self.syntax().kind() == LuaSyntaxKind::TableFieldValue.into() }