summaryrefslogtreecommitdiff
path: root/py
diff options
context:
space:
mode:
authorMartin Bielik <mx.bielik@gmail.com>2023-10-21 19:04:57 +0200
committerGitHub <noreply@github.com>2023-10-21 19:04:57 +0200
commitccf981974232636d08a48094bf9dad3812a43cc8 (patch)
tree611803cd36432c6e2cb68edf4cf3f81d502c4615 /py
parent8f8083ba0eed23150020b698e74d9302f7212c5d (diff)
parent55c4e2ec836e48552b52fb4b7878f7b50f67b53b (diff)
downloadvim-ai-ccf981974232636d08a48094bf9dad3812a43cc8.tar.gz
Merge pull request #59 from madox2/base-url-config
Custom APIs, closes #55, closes #51
Diffstat (limited to 'py')
-rw-r--r--py/chat.py3
-rw-r--r--py/complete.py6
-rw-r--r--py/utils.py20
3 files changed, 20 insertions, 9 deletions
diff --git a/py/chat.py b/py/chat.py
index 93d31bf..6d88015 100644
--- a/py/chat.py
+++ b/py/chat.py
@@ -69,7 +69,8 @@ try:
**openai_options
}
printDebug("[chat] request: {}", request)
- response = openai_request('https://api.openai.com/v1/chat/completions', request, http_options)
+ url = config_options['endpoint_url']
+ response = openai_request(url, request, http_options)
def map_chunk(resp):
printDebug("[chat] response: {}", resp)
return resp['choices'][0]['delta'].get('content', '')
diff --git a/py/complete.py b/py/complete.py
index c8d45fe..8386c09 100644
--- a/py/complete.py
+++ b/py/complete.py
@@ -17,7 +17,8 @@ def complete_engine(prompt):
**openai_options
}
printDebug("[engine-complete] request: {}", request)
- response = openai_request('https://api.openai.com/v1/completions', request, http_options)
+ url = config_options['endpoint_url']
+ response = openai_request(url, request, http_options)
def map_chunk(resp):
printDebug("[engine-complete] response: {}", resp)
return resp['choices'][0].get('text', '')
@@ -35,7 +36,8 @@ def chat_engine(prompt):
**openai_options
}
printDebug("[engine-chat] request: {}", request)
- response = openai_request('https://api.openai.com/v1/chat/completions', request, http_options)
+ url = config_options['endpoint_url']
+ response = openai_request(url, request, http_options)
def map_chunk(resp):
printDebug("[engine-chat] response: {}", resp)
return resp['choices'][0]['delta'].get('content', '')
diff --git a/py/utils.py b/py/utils.py
index 2e1f975..e5203bd 100644
--- a/py/utils.py
+++ b/py/utils.py
@@ -14,6 +14,9 @@ import traceback
is_debugging = vim.eval("g:vim_ai_debug") == "1"
debug_log_file = vim.eval("g:vim_ai_debug_log_file")
+class KnownError(Exception):
+ pass
+
def load_api_key():
config_file_path = os.path.join(os.path.expanduser("~"), ".config/openai.token")
api_key_param_value = os.getenv("OPENAI_API_KEY")
@@ -24,7 +27,7 @@ def load_api_key():
pass
if not api_key_param_value:
- raise Exception("Missing OpenAI API key")
+ raise KnownError("Missing OpenAI API key")
# The text is in format of "<api key>,<org id>" and the
# <org id> part is optional
@@ -56,6 +59,7 @@ def make_openai_options(options):
def make_http_options(options):
return {
'request_timeout': float(options['request_timeout']),
+ 'enable_auth': bool(int(options['enable_auth'])),
}
def render_text_chunks(chunks):
@@ -130,16 +134,18 @@ def printDebug(text, *args):
OPENAI_RESP_DATA_PREFIX = 'data: '
OPENAI_RESP_DONE = '[DONE]'
-(OPENAI_API_KEY, OPENAI_ORG_ID) = load_api_key()
def openai_request(url, data, options):
+ enable_auth=options['enable_auth']
headers = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {OPENAI_API_KEY}"
}
+ if enable_auth:
+ (OPENAI_API_KEY, OPENAI_ORG_ID) = load_api_key()
+ headers['Authorization'] = f"Bearer {OPENAI_API_KEY}"
- if OPENAI_ORG_ID is not None:
- headers["OpenAI-Organization"] = f"{OPENAI_ORG_ID}"
+ if OPENAI_ORG_ID is not None:
+ headers["OpenAI-Organization"] = f"{OPENAI_ORG_ID}"
request_timeout=options['request_timeout']
req = urllib.request.Request(
@@ -153,7 +159,7 @@ def openai_request(url, data, options):
line = line_bytes.decode("utf-8", errors="replace")
if line.startswith(OPENAI_RESP_DATA_PREFIX):
line_data = line[len(OPENAI_RESP_DATA_PREFIX):-1]
- if line_data == OPENAI_RESP_DONE:
+ if line_data.strip() == OPENAI_RESP_DONE:
pass
else:
openai_obj = json.loads(line_data)
@@ -183,6 +189,8 @@ def handle_completion_error(error):
elif status_code == 429:
msg += ' (Hint: verify that your billing plan is "Pay as you go")'
print_info_message(msg)
+ elif isinstance(error, KnownError):
+ print_info_message(str(error))
else:
raise error