diff options
| -rw-r--r-- | autoload/vim_ai_config.vim | 10 | ||||
| -rw-r--r-- | py/chat.py | 4 | ||||
| -rw-r--r-- | py/complete.py | 5 | ||||
| -rw-r--r-- | py/types.py | 17 | ||||
| -rw-r--r-- | py/utils.py | 20 |
5 files changed, 53 insertions, 3 deletions
diff --git a/autoload/vim_ai_config.vim b/autoload/vim_ai_config.vim index d0ee972..6da02ba 100644 --- a/autoload/vim_ai_config.vim +++ b/autoload/vim_ai_config.vim @@ -26,7 +26,11 @@ let g:vim_ai_complete_default = { \ }, \ "ui": { \ "paste_mode": 1, -\ }, +\ }, +\ "provider": { +\ "name": "openai.complete", +\ "class": "OpenAIComplete" +\ } \} let g:vim_ai_edit_default = { \ "prompt": "", @@ -94,6 +98,10 @@ let g:vim_ai_chat_default = { \ "force_new_chat": 0, \ "paste_mode": 1, \ }, +\ "provider": { +\ "name": "openai.complete", +\ "class": "OpenAIChat" +\ } \} if !exists("g:vim_ai_open_chat_presets") @@ -63,8 +63,10 @@ def run_ai_chat(context): print('Answering...') vim.command("redraw") + provider_class = load_provider(config['provider']) + provider = provider_class(config) + text_chunks = provider.request(messages) - text_chunks = make_chat_text_chunks(messages, options) render_text_chunks(text_chunks) vim.command("normal! a\n\n>>> user\n\n") diff --git a/py/complete.py b/py/complete.py index 8078dea..765d6e0 100644 --- a/py/complete.py +++ b/py/complete.py @@ -42,7 +42,10 @@ def run_ai_completition(context): if prompt: print('Completing...') vim.command("redraw") - text_chunks = engines[engine](prompt) + provider_class = load_provider(config['provider']) + provider = provider_class(config) + messages = parse_chat_messages(f">>> user\n\n{prompt}".strip()) + text_chunks = provider.request(messages) render_text_chunks(text_chunks) clear_echo_message() except BaseException as error: diff --git a/py/types.py b/py/types.py new file mode 100644 index 0000000..819afe5 --- /dev/null +++ b/py/types.py @@ -0,0 +1,17 @@ +from collections.abc import Sequence, Mapping +from typing import TypedDict, Protocol + + +class AIProvider(Protocol): + + def __init__(self, config: Mapping[str, str]) -> None: + pass + + def request(self, messages: Sequence[Message]) -> Generator[str]: + pass + +class Message(TypedDict): + role: str + content: str + type: str + diff --git a/py/utils.py b/py/utils.py index 8cd204f..2147785 100644 --- a/py/utils.py +++ b/py/utils.py @@ -359,3 +359,23 @@ def save_b64_to_file(path, b64_data): f = open(path, "wb") f.write(base64.b64decode(b64_data)) f.close() + +def load_provider(provider): + provider_name, provider_module = provider["name"].split(".") + if provider_name != "openai": + provider_path = os.path.join(f"{plugin_root}", + "..", + f"vim-ai-{provider_name}", + "py", + f"{provider_module}.py") + else: + return openai_request + vim.command(f"py3file {provider_path}") + try: + provider_class = globals()[provider["class"]] + except KeyError as error: + printDebug("[load-provider] provider: {}", error) + raise KeyError(error.message, "provider not found") + return provider_class + + |