Multithreadding #8

Merged
sageTheDm merged 2 commits from React-Group/interstellar_ai:main into main 2024-09-23 11:56:31 +02:00
3 changed files with 44 additions and 13 deletions

View file

@ -1,4 +1,5 @@
from mistralai import Mistral
from openai import OpenAI
import ollama
@ -12,15 +13,15 @@ class AI:
options={"temperature": 0.5},
)
with return_class.ai_response_lock:
return_class.ai_response[access_token] = ""
for chunk in stream:
with return_class.ai_response_lock:
return_class.ai_response[access_token] += chunk['message']['content']
@staticmethod
def process_mistralai(model, messages, return_class, access_token):
with open("api_key.txt", 'r') as f:
api_key = f.read().strip()
def process_mistralai(model, messages, return_class, access_token, api_key):
client = Mistral(api_key=api_key)
@ -29,7 +30,26 @@ class AI:
messages=messages
)
with return_class.ai_response_lock:
return_class.ai_response[access_token] = ""
for chunk in stream_response:
with return_class.ai_response_lock:
return_class.ai_response[access_token] += chunk.data.choices[0].delta.content
@staticmethod
def process_openai(model, messages, return_class, access_token, api_key):
client = OpenAI(api_key=api_key)
stream_response = client.chat.completions.create(
model=model,
messages=messages
)
with return_class.ai_response_lock:
return_class.ai_response[access_token] = ""
for chunk in stream_response:
with return_class.ai_response_lock:
return_class.ai_response[access_token] += chunk.choices[0].delta.content

View file

@ -1,6 +1,7 @@
from flask import Flask, request, jsonify
from flask_cors import CORS
import secrets
import threading
from ai import AI
from db import DB
from OpenSSL import crypto
@ -8,12 +9,13 @@ from OpenSSL import crypto
class API:
def __init__(self):
self.crypt_size = 4096
self.crypt_size = 1
self.app = Flask(__name__)
self.ai_response = {}
self.ai = AI()
self.db = DB()
self.db.load_database()
self.ai_response_lock = threading.Lock()
CORS(self.app)
def run(self):
@ -34,10 +36,18 @@ class API:
return jsonify({'status': 401, 'error': 'Invalid access token'})
if model_type == "local":
self.ai.process_local(ai_model, messages, self, access_token)
if model_type == "mistral":
self.ai.process_mistralai(ai_model, messages, self, access_token)
thread = threading.Thread(target=self.ai.process_local, args=(ai_model, messages, self, access_token))
thread.start()
thread.join()
return jsonify({'status': 200})
elif model_type == "mistral":
api_key = data.get('api_key')
thread = threading.Thread(target=self.ai.process_mistralai, args=(ai_model, messages, self, access_token, api_key))
thread.start()
thread.join()
return jsonify({'status': 200})
return jsonify({'status': 401, 'error': 'Invalid AI model type'})
@self.app.route('/interstellar/api/ai_get', methods=['GET'])
def get_ai():
@ -95,7 +105,7 @@ class API:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
ssl_context = ("cert.pem", "key.pem")
self.app.run(debug=True, host='0.0.0.0', port=5000, ssl_context=ssl_context)
self.app.run(debug=True, host='0.0.0.0', port=5000)
if __name__ == '__main__':

View file

@ -2,4 +2,5 @@ flask
flask-cors
ollama
mistralai
openai
pyOpenSSL