Diffstat:
 -rw-r--r--  autoload/vim_ai_config.vim  10
 -rw-r--r--  py/chat.py                   4
 -rw-r--r--  py/complete.py               5
 -rw-r--r--  py/types.py                 17
 -rw-r--r--  py/utils.py                 20
 5 files changed, 53 insertions(+), 3 deletions(-)
diff --git a/autoload/vim_ai_config.vim b/autoload/vim_ai_config.vim
index d0ee972..6da02ba 100644
--- a/autoload/vim_ai_config.vim
+++ b/autoload/vim_ai_config.vim
@@ -26,7 +26,11 @@ let g:vim_ai_complete_default = {
\ },
\ "ui": {
\ "paste_mode": 1,
-\ },
+\ },
+\ "provider": {
+\ "name": "openai.complete",
+\ "class": "OpenAIComplete"
+\ }
\}
let g:vim_ai_edit_default = {
\ "prompt": "",
@@ -94,6 +98,10 @@ let g:vim_ai_chat_default = {
\ "force_new_chat": 0,
\ "paste_mode": 1,
\ },
+\ "provider": {
+\ "name": "openai.complete",
+\ "class": "OpenAIChat"
+\ }
\}
if !exists("g:vim_ai_open_chat_presets")
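
The new "provider" block in both defaults tells the Python side which provider module to load and which class to instantiate. Seen from Python, it arrives as a plain dict under config['provider']; a rough sketch of that shape, using a hypothetical third-party "ollama" provider purely for illustration:

    # Provider settings as the Python code receives them (assumed shape, based
    # on the defaults above; "ollama"/"OllamaChat" are made-up examples).
    builtin_provider = {
        "name": "openai.complete",   # "<provider>.<module>", split by load_provider()
        "class": "OpenAIComplete",
    }
    external_provider = {
        "name": "ollama.chat",       # would resolve to ../vim-ai-ollama/py/chat.py
        "class": "OllamaChat",       # class looked up in globals() after py3file
    }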
diff --git a/py/chat.py b/py/chat.py
index 6861b51..810fc46 100644
--- a/py/chat.py
+++ b/py/chat.py
@@ -63,8 +63,10 @@ def run_ai_chat(context):
            print('Answering...')
            vim.command("redraw")
-            text_chunks = make_chat_text_chunks(messages, options)
+            provider_class = load_provider(config['provider'])
+            provider = provider_class(config)
+            text_chunks = provider.request(messages)
            render_text_chunks(text_chunks)
            vim.command("normal! a\n\n>>> user\n\n")
diff --git a/py/complete.py b/py/complete.py
index 8078dea..765d6e0 100644
--- a/py/complete.py
+++ b/py/complete.py
@@ -42,7 +42,10 @@ def run_ai_completition(context):
        if prompt:
            print('Completing...')
            vim.command("redraw")
-            text_chunks = engines[engine](prompt)
+            provider_class = load_provider(config['provider'])
+            provider = provider_class(config)
+            messages = parse_chat_messages(f">>> user\n\n{prompt}".strip())
+            text_chunks = provider.request(messages)
            render_text_chunks(text_chunks)
            clear_echo_message()
    except BaseException as error:
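
Completion now goes through the same chat-message path: the prompt is wrapped in a ">>> user" section and parsed into message dicts before being handed to the provider. A small sketch of the message shape this presumably produces, with field names taken from py/types.py below (the exact parsing lives in parse_chat_messages in py/utils.py and is not reproduced here):

    # Hypothetical prompt, wrapped the same way run_ai_completition() does it.
    prompt = "Write a haiku about Vim"
    chat_text = f">>> user\n\n{prompt}".strip()

    # Assumed result of parse_chat_messages(chat_text): one user message.
    messages = [
        {"role": "user", "content": prompt, "type": "text"},
    ]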
diff --git a/py/types.py b/py/types.py
new file mode 100644
index 0000000..819afe5
--- /dev/null
+++ b/py/types.py
@@ -0,0 +1,17 @@
+from collections.abc import Sequence, Mapping, Generator
+from typing import TypedDict, Protocol
+
+
+class Message(TypedDict):
+    role: str
+    content: str
+    type: str
+
+
+class AIProvider(Protocol):
+
+    def __init__(self, config: Mapping[str, str]) -> None:
+        pass
+
+    def request(self, messages: Sequence[Message]) -> Generator[str, None, None]:
+        pass
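
Taken together, these two types spell out the provider contract: a provider is constructed with the resolved config and turns a sequence of messages into a stream of text chunks. A minimal, self-contained sketch of a class that would satisfy the protocol (EchoProvider and its behaviour are invented here for illustration; a real provider would call an actual API):

    from collections.abc import Generator, Mapping, Sequence

    class EchoProvider:
        """Toy AIProvider: streams the last user message back, word by word."""

        def __init__(self, config: Mapping[str, str]) -> None:
            self.config = config

        def request(self, messages: Sequence[dict]) -> Generator[str, None, None]:
            # Yield chunks the way a real completion API would stream them.
            last = messages[-1]["content"] if messages else ""
            for word in last.split():
                yield word + " "

    provider = EchoProvider({})
    print("".join(provider.request(
        [{"role": "user", "content": "hello from vim-ai", "type": "text"}])))

Dropped into a hypothetical ../vim-ai-echo/py/chat.py, such a class could then be selected with provider name "echo.chat" and class "EchoProvider".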
diff --git a/py/utils.py b/py/utils.py
index 8cd204f..2147785 100644
--- a/py/utils.py
+++ b/py/utils.py
@@ -359,3 +359,23 @@ def save_b64_to_file(path, b64_data):
    f = open(path, "wb")
    f.write(base64.b64decode(b64_data))
    f.close()
+
+def load_provider(provider):
+    provider_name, provider_module = provider["name"].split(".")
+    if provider_name != "openai":
+        provider_path = os.path.join(f"{plugin_root}",
+                                     "..",
+                                     f"vim-ai-{provider_name}",
+                                     "py",
+                                     f"{provider_module}.py")
+    else:
+        return openai_request
+    vim.command(f"py3file {provider_path}")
+    try:
+        provider_class = globals()[provider["class"]]
+    except KeyError as error:
+        printDebug("[load-provider] provider: {}", error)
+        raise KeyError(f"provider not found: {provider['class']}") from error
+    return provider_class
+
+
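
For anything other than the built-in openai provider, load_provider() locates the implementation in a sibling plugin directory and sources it into the Python namespace with py3file. A standalone sketch of the path it would build, with plugin_root and the provider name being assumed example values:

    import os

    # Assumed install location; in the plugin, plugin_root points at vim-ai itself.
    plugin_root = "/home/user/.vim/plugged/vim-ai"

    name = "ollama.chat"   # hypothetical value of provider["name"]
    provider_name, provider_module = name.split(".")
    provider_path = os.path.join(plugin_root, "..",
                                 f"vim-ai-{provider_name}", "py",
                                 f"{provider_module}.py")
    print(provider_path)   # /home/user/.vim/plugged/vim-ai/../vim-ai-ollama/py/chat.py

After py3file has executed that file, the configured "class" name is simply looked up in globals(), which is why the provider class has to be defined at module top level.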