summaryrefslogtreecommitdiff
path: root/py
diff options
context:
space:
mode:
authorMax Resnick <max@ofmax.li>2025-01-28 22:27:47 -0800
committerMax Resnick <max@ofmax.li>2025-01-31 22:45:23 -0800
commitc8162fc43a748a97aab4647be2f5cdf50bc739ea (patch)
tree0099688a7a09af76a2ae978790c5e013b59bcb9c /py
parent380d5cdd9538c2522dfc8d03a8a261760bb0439a (diff)
downloadvim-ai-c8162fc43a748a97aab4647be2f5cdf50bc739ea.tar.gz
chore: rebase and fix up conflicts
Diffstat (limited to '')
-rw-r--r--py/chat.py4
-rw-r--r--py/complete.py5
-rw-r--r--py/types.py17
-rw-r--r--py/utils.py20
4 files changed, 44 insertions, 2 deletions
diff --git a/py/chat.py b/py/chat.py
index 6861b51..810fc46 100644
--- a/py/chat.py
+++ b/py/chat.py
@@ -63,8 +63,10 @@ def run_ai_chat(context):
print('Answering...')
vim.command("redraw")
+ provider_class = load_provider(config['provider'])
+ provider = provider_class(config)
+ text_chunks = provider.request(messages)
- text_chunks = make_chat_text_chunks(messages, options)
render_text_chunks(text_chunks)
vim.command("normal! a\n\n>>> user\n\n")
diff --git a/py/complete.py b/py/complete.py
index 8078dea..765d6e0 100644
--- a/py/complete.py
+++ b/py/complete.py
@@ -42,7 +42,10 @@ def run_ai_completition(context):
if prompt:
print('Completing...')
vim.command("redraw")
- text_chunks = engines[engine](prompt)
+ provider_class = load_provider(config['provider'])
+ provider = provider_class(config)
+ messages = parse_chat_messages(f">>> user\n\n{prompt}".strip())
+ text_chunks = provider.request(messages)
render_text_chunks(text_chunks)
clear_echo_message()
except BaseException as error:
diff --git a/py/types.py b/py/types.py
new file mode 100644
index 0000000..819afe5
--- /dev/null
+++ b/py/types.py
@@ -0,0 +1,17 @@
+from collections.abc import Sequence, Mapping
+from typing import TypedDict, Protocol
+
+
+class AIProvider(Protocol):
+
+ def __init__(self, config: Mapping[str, str]) -> None:
+ pass
+
+ def request(self, messages: Sequence[Message]) -> Generator[str]:
+ pass
+
+class Message(TypedDict):
+ role: str
+ content: str
+ type: str
+
diff --git a/py/utils.py b/py/utils.py
index 8cd204f..2147785 100644
--- a/py/utils.py
+++ b/py/utils.py
@@ -359,3 +359,23 @@ def save_b64_to_file(path, b64_data):
f = open(path, "wb")
f.write(base64.b64decode(b64_data))
f.close()
+
+def load_provider(provider):
+ provider_name, provider_module = provider["name"].split(".")
+ if provider_name != "openai":
+ provider_path = os.path.join(f"{plugin_root}",
+ "..",
+ f"vim-ai-{provider_name}",
+ "py",
+ f"{provider_module}.py")
+ else:
+ return openai_request
+ vim.command(f"py3file {provider_path}")
+ try:
+ provider_class = globals()[provider["class"]]
+ except KeyError as error:
+ printDebug("[load-provider] provider: {}", error)
+ raise KeyError(error.message, "provider not found")
+ return provider_class
+
+