summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--autoload/vim_ai.vim16
-rw-r--r--plugin/vim-ai.vim6
-rw-r--r--py/context.py3
-rw-r--r--py/roles.py10
-rw-r--r--tests/context_test.py11
-rw-r--r--tests/resources/roles.ini3
-rw-r--r--tests/roles_test.py16
7 files changed, 53 insertions, 12 deletions
diff --git a/autoload/vim_ai.vim b/autoload/vim_ai.vim
index 03ac978..f6cbed5 100644
--- a/autoload/vim_ai.vim
+++ b/autoload/vim_ai.vim
@@ -299,9 +299,21 @@ function! vim_ai#AIRedoRun() abort
endif
endfunction
-function! vim_ai#RoleCompletion(A,L,P) abort
+function! s:RoleCompletion(A, command_type) abort
call s:ImportPythonModules()
- let l:role_list = py3eval("load_ai_role_names()")
+ let l:role_list = py3eval("load_ai_role_names(unwrap('a:command_type'))")
call map(l:role_list, '"/" . v:val')
return filter(l:role_list, 'v:val =~ "^' . a:A . '"')
endfunction
+
+function! vim_ai#RoleCompletionComplete(A,L,P) abort
+ return s:RoleCompletion(a:A, 'complete')
+endfunction
+
+function! vim_ai#RoleCompletionEdit(A,L,P) abort
+ return s:RoleCompletion(a:A, 'edit')
+endfunction
+
+function! vim_ai#RoleCompletionChat(A,L,P) abort
+ return s:RoleCompletion(a:A, 'chat')
+endfunction
diff --git a/plugin/vim-ai.vim b/plugin/vim-ai.vim
index d27f312..648578a 100644
--- a/plugin/vim-ai.vim
+++ b/plugin/vim-ai.vim
@@ -4,8 +4,8 @@ if !has('python3')
finish
endif
-command! -range -nargs=? -complete=customlist,vim_ai#RoleCompletion AI <line1>,<line2>call vim_ai#AIRun(<range>, {}, <q-args>)
-command! -range -nargs=? -complete=customlist,vim_ai#RoleCompletion AIEdit <line1>,<line2>call vim_ai#AIEditRun(<range>, {}, <q-args>)
-command! -range -nargs=? -complete=customlist,vim_ai#RoleCompletion AIChat <line1>,<line2>call vim_ai#AIChatRun(<range>, {}, <q-args>)
+command! -range -nargs=? -complete=customlist,vim_ai#RoleCompletionComplete AI <line1>,<line2>call vim_ai#AIRun(<range>, {}, <q-args>)
+command! -range -nargs=? -complete=customlist,vim_ai#RoleCompletionEdit AIEdit <line1>,<line2>call vim_ai#AIEditRun(<range>, {}, <q-args>)
+command! -range -nargs=? -complete=customlist,vim_ai#RoleCompletionChat AIChat <line1>,<line2>call vim_ai#AIChatRun(<range>, {}, <q-args>)
command! -nargs=? AINewChat call vim_ai#AINewChatRun(<f-args>)
command! AIRedo call vim_ai#AIRedoRun()
diff --git a/py/context.py b/py/context.py
index f1eacd6..87c3a14 100644
--- a/py/context.py
+++ b/py/context.py
@@ -92,7 +92,8 @@ def load_role_config(role):
enhance_roles_with_custom_function(roles)
- if not role in roles:
+ postfixes = ["", ".complete", ".edit", ".chat"]
+ if not any([f"{role}{postfix}" in roles for postfix in postfixes]):
raise Exception(f"Role `{role}` not found")
if is_deprecated_role_syntax(roles, role):
diff --git a/py/roles.py b/py/roles.py
index 37e5b4d..692ac6c 100644
--- a/py/roles.py
+++ b/py/roles.py
@@ -6,7 +6,7 @@ if "PYTEST_VERSION" in os.environ:
roles_py_imported = True
-def load_ai_role_names():
+def load_ai_role_names(command_type):
roles_config_path = os.path.expanduser(vim.eval("g:vim_ai_roles_config_file"))
if not os.path.exists(roles_config_path):
raise Exception(f"Role config file does not exist: {roles_config_path}")
@@ -16,6 +16,10 @@ def load_ai_role_names():
enhance_roles_with_custom_function(roles)
- role_names = [name for name in roles.sections() if not '.' in name]
+ role_names = set()
+ for name in roles.sections():
+ parts = name.split('.')
+ if len(parts) == 1 or parts[-1] == command_type:
+ role_names.add(parts[0])
- return role_names
+ return list(role_names)
diff --git a/tests/context_test.py b/tests/context_test.py
index 1e179f1..9f7d004 100644
--- a/tests/context_test.py
+++ b/tests/context_test.py
@@ -115,6 +115,17 @@ def test_multiple_role_configs():
assert 'https://localhost/chat' == actual_config['options']['endpoint_url']
assert 'simple role prompt:\nhello' == actual_prompt
+def test_chat_only_role():
+ context = make_ai_context({
+ 'config_default': default_config,
+ 'config_extension': {},
+ 'user_instruction': '/chat-only-role',
+ 'user_selection': '',
+ 'command_type': 'chat',
+ })
+ actual_config = context['config']
+ assert 'preset_tab' == actual_config['options']['open_chat_command']
+
def test_user_prompt():
assert 'fix grammar: helo word' == make_prompt( '', 'fix grammar: helo word', '', '')
assert 'fix grammar:\nhelo word' == make_prompt( '', 'fix grammar', 'helo word', '')
diff --git a/tests/resources/roles.ini b/tests/resources/roles.ini
index 16335d3..450df1d 100644
--- a/tests/resources/roles.ini
+++ b/tests/resources/roles.ini
@@ -15,6 +15,9 @@ config.options.endpoint_url = https://localhost/complete
config.engine = complete
config.options.endpoint_url = https://localhost/edit
+[chat-only-role.chat]
+config.options.open_chat_command = preset_tab
+
[deprecated-test-role-simple]
prompt = simple role prompt
[deprecated-test-role-simple.options]
diff --git a/tests/roles_test.py b/tests/roles_test.py
index 3230329..ac5525d 100644
--- a/tests/roles_test.py
+++ b/tests/roles_test.py
@@ -1,10 +1,20 @@
from roles import load_ai_role_names
def test_role_completion():
- role_names = load_ai_role_names()
- assert role_names == [
+ role_names = load_ai_role_names('complete')
+ assert set(role_names) == {
'test-role-simple',
'test-role',
'deprecated-test-role-simple',
'deprecated-test-role',
- ]
+ }
+
+def test_role_chat_only():
+ role_names = load_ai_role_names('chat')
+ assert set(role_names) == {
+ 'test-role-simple',
+ 'test-role',
+ 'chat-only-role',
+ 'deprecated-test-role-simple',
+ 'deprecated-test-role',
+ }