From a8f9ba27bdd60d77be356edbec7eb5bf74e8ef5d Mon Sep 17 00:00:00 2001
From: Patrick_Pluto <patrick_pluto@noreply.codeberg.org>
Date: Mon, 23 Sep 2024 11:01:39 +0200
Subject: [PATCH] Fixes and Multithreading

---
 py/ai.py            | 34 +++++++++++++++++++++++++++-------
 py/api.py           | 22 ++++++++++++++++------
 py/requirements.txt |  1 +
 3 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/py/ai.py b/py/ai.py
index 0a57b57..bfde713 100644
--- a/py/ai.py
+++ b/py/ai.py
@@ -1,4 +1,5 @@
 from mistralai import Mistral
+from openai import OpenAI
 import ollama
 
 
@@ -12,15 +13,15 @@ class AI:
             options={"temperature": 0.5},
         )
 
-        return_class.ai_response[access_token] = ""
+        with return_class.ai_response_lock:
+            return_class.ai_response[access_token] = ""
 
         for chunk in stream:
-            return_class.ai_response[access_token] += chunk['message']['content']
+            with return_class.ai_response_lock:
+                return_class.ai_response[access_token] += chunk['message']['content']
 
     @staticmethod
-    def process_mistralai(model, messages, return_class, access_token):
-        with open("api_key.txt", 'r') as f:
-            api_key = f.read().strip()
+    def process_mistralai(model, messages, return_class, access_token, api_key):
 
         client = Mistral(api_key=api_key)
 
@@ -29,7 +30,26 @@ class AI:
             messages=messages
         )
 
-        return_class.ai_response[access_token] = ""
+        with return_class.ai_response_lock:
+            return_class.ai_response[access_token] = ""
 
         for chunk in stream_response:
-            return_class.ai_response[access_token] += chunk.data.choices[0].delta.content
+            with return_class.ai_response_lock:
+                return_class.ai_response[access_token] += chunk.data.choices[0].delta.content or ""
+
+    @staticmethod
+    def process_openai(model, messages, return_class, access_token, api_key):
+
+        client = OpenAI(api_key=api_key)
+
+        stream_response = client.chat.completions.create(
+            model=model,
+            messages=messages, stream=True
+        )
+
+        with return_class.ai_response_lock:
+            return_class.ai_response[access_token] = ""
+
+        for chunk in stream_response:
+            with return_class.ai_response_lock:
+                return_class.ai_response[access_token] += chunk.choices[0].delta.content or ""
\ No newline at end of file
diff --git a/py/api.py b/py/api.py
index b52870f..386cd4f 100644
--- a/py/api.py
+++ b/py/api.py
@@ -1,6 +1,7 @@
 from flask import Flask, request, jsonify
 from flask_cors import CORS
 import secrets
+import threading
 from ai import AI
 from db import DB
 from OpenSSL import crypto
@@ -8,12 +9,13 @@ from OpenSSL import crypto
 
 class API:
     def __init__(self):
-        self.crypt_size = 4096
+        self.crypt_size = 4096
         self.app = Flask(__name__)
         self.ai_response = {}
         self.ai = AI()
         self.db = DB()
         self.db.load_database()
+        self.ai_response_lock = threading.Lock()
         CORS(self.app)
 
     def run(self):
@@ -34,10 +36,18 @@ class API:
                 return jsonify({'status': 401, 'error': 'Invalid access token'})
 
             if model_type == "local":
-                self.ai.process_local(ai_model, messages, self, access_token)
-            if model_type == "mistral":
-                self.ai.process_mistralai(ai_model, messages, self, access_token)
-            return jsonify({'status': 200})
+                thread = threading.Thread(target=self.ai.process_local, args=(ai_model, messages, self, access_token))
+                thread.start()
+                thread.join()
+                return jsonify({'status': 200})
+            elif model_type == "mistral":
+                api_key = data.get('api_key')
+                thread = threading.Thread(target=self.ai.process_mistralai, args=(ai_model, messages, self, access_token, api_key))
+                thread.start()
+                thread.join()
+                return jsonify({'status': 200})
+
+            return jsonify({'status': 400, 'error': 'Invalid AI model type'})
 
         @self.app.route('/interstellar/api/ai_get', methods=['GET'])
         def get_ai():
@@ -95,7 +105,7 @@ class API:
             f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
 
         ssl_context = ("cert.pem", "key.pem")
-        self.app.run(debug=True, host='0.0.0.0', port=5000, ssl_context=ssl_context)
+        self.app.run(debug=True, host='0.0.0.0', port=5000, ssl_context=ssl_context)
 
 
 if __name__ == '__main__':
diff --git a/py/requirements.txt b/py/requirements.txt
index 144c571..3c2be3b 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -2,4 +2,5 @@ flask
 flask-cors
 ollama
 mistralai
+openai
 pyOpenSSL
\ No newline at end of file