Backend optimisation #16

5 changed files with 104 additions and 31 deletions
.gitignore (3 changed lines, vendored)
@@ -1,4 +1,5 @@
venv/
__pycache__
__pycache__/
.idea/
.vscode/
token.txt
							
								
								
									
py/api.py (62 changed lines)
@@ -1,27 +1,52 @@
import requests
import json
from transformers import AutoTokenizer, LlamaForCausalLM

from gradio_client import Client
import os
from mistralai import Mistral

class API:
    # This method processes a message via transformers. (NOT FINISHED!)
    @staticmethod
    def process_text_transformers(prompt, model):
        model = LlamaForCausalLM.from_pretrained(model)
        tokenizer = AutoTokenizer.from_pretrained(model)
    def process_text_mistralai(prompt, model, system):
        with open("token.txt", "r") as f:
            token = f.readlines()[0].strip()

        inputs = tokenizer(prompt, return_tensors="pt")
        api_key = token

        generate_ids = model.generate(inputs.input_ids, max_length=30)
        return tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        client = Mistral(api_key=api_key)

        chat_response = client.chat.complete(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }, {
                    "role": "system",
                    "content": system,
                },
            ]
        )
        return chat_response.choices[0].message.content
    @staticmethod
    def process_text_gradio(prompt, model, system):
        client = Client(model)
        result = client.predict(
            message=prompt,
            system_message=system,
            max_tokens=512,
            temperature=0.7,
            top_p=0.95,
            api_name="/chat"
        )
        return result;

    # This method processes a message via ollama
    @staticmethod
    def process_text_local(prompt, model):
    def process_text_local(prompt, model, system):
        ollama_url = "http://localhost:11434"

        response = requests.post(
            f"{ollama_url}/api/generate", json={"model": model, "prompt": prompt}
            f"{ollama_url}/api/generate", json={"model": model, "prompt": prompt, "system": system}
        )

        if response.status_code == 200:
@@ -37,17 +62,20 @@ class API:
            return "Error: " + response.text

    # This method sends a message to a certain AI.
    def send_message(self, message, model):

    def send_message(self, message, model, system):
        if model == 1:
            answer = self.process_text_local(message, "phi3.5")
            answer = self.process_text_local(message, "phi3.5", system)
        elif model == 2:
            answer = self.process_text_local(message, "gemma2:2b")
            answer = self.process_text_local(message, "gemma2:9b", system)
        elif model == 3:
            answer = self.process_text_local(message, "qwen2:0.5b")
            answer = self.process_text_local(message, "codegemma:2b", system)
        elif model == 4:
            answer = self.process_text_local(message, "codegemma:2b")
            answer = self.process_text_gradio(message, "PatrickPluto/InterstellarAIChatbot", system)
        elif model == 5:
            answer = self.process_text_transformers(message, "meta-llama/Meta-Llama-3.1-8B")
            answer = self.process_text_mistralai(message, "mistral-large-latest", system)
        elif model == 6:
            answer = self.process_text_mistralai(message, "codestral-latest", system)
        else:
            return "Invalid choice"
        return answer
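Review aid, not part of the PR: a minimal sketch of how the reworked interface could be exercised. It assumes the new py/api.py above is importable, that token.txt holds a valid Mistral API key (needed by the mistralai branches), and that an Ollama server is listening on localhost:11434 for the local models; the model indices mirror the send_message branches in the diff.

# Smoke test sketch for the new send_message(message, model, system) signature.
from api import API

api = API()
system = "You are a helpful assistant."

# Model 1 routes to the local Ollama model "phi3.5" (requires a running Ollama server).
print(api.send_message("Summarise HTTP status codes in one sentence.", 1, system))

# Model 5 routes to Mistral's "mistral-large-latest" (requires a key in token.txt).
print(api.send_message("Summarise HTTP status codes in one sentence.", 5, system))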
							
								
								
									
py/install.sh (17 changed lines, new executable file)
@@ -0,0 +1,17 @@
#!/bin/bash

python3 -m venv venv
source venv/bin/activate
pip install flask
pip install SpeechRecognition
pip install pyaudio
pip install pocketsphinx
pip install sentencepiece
pip install pyqt5
pip install pyqtwebengine
pip install gradio_client
pip install mistralai

ollama pull phi3.5
ollama pull codegemma:2b
ollama pull gemma2:9b
							
								
								
									
py/venv.sh (10 changed lines, file deleted)
@@ -1,10 +0,0 @@
#!/bin/bash

virtualenv venv
source venv/bin/activate
pip install transformers
pip install torch
pip install flask
pip install SpeechRecognition
pip install pyaudio
pip install pocketsphinx
							
								
								
									
py/web_flask.py (43 changed lines, normal file → executable file)
@@ -1,6 +1,14 @@
#!venv/bin/python

from flask import Flask, request, render_template
from api import API
from voice_recognition import Voice
import sys
import threading
from PyQt5.QtCore import *
from PyQt5.QtWebEngineWidgets import *
from PyQt5.QtWidgets import *


APP = Flask(__name__)
api = API()
@@ -13,6 +21,8 @@ messages = []
@APP.route('/', methods=['GET', 'POST'])
def index():
    global messages
    system_prompt = 'You are a helpful assistant.'
    system = 'Your system prompt is: \"'+system_prompt+'\" The following is your chat log so far: \n'

    if request.method == 'POST':
        option = request.form['option']
@@ -25,7 +35,12 @@ def index():
        elif option == "chat":
            messages.append(f"User: {user_message}")

        ai_response = "AI: " + api.send_message(user_message, 3)
        for line in messages:
            system += line + '\n'

        system += "The chat log is now finished."

        ai_response = "AI: " + api.send_message(user_message, 5, system)
        messages.append(ai_response)

    return render_template('index.html', messages=messages)
@@ -42,6 +57,28 @@ def contact():


if __name__ == '__main__':
    APP.run(debug=True)
    qapp = QApplication(sys.argv)

# This is a comment --> test if this creates a merge conflict
    view = QWebEngineView()

    view.setGeometry(100, 100, 1280, 720)
    view.setWindowTitle("InterstellarAI")

    view.setUrl(QUrl("http://localhost:5000"))

    view.show()

    def run_flask():
        APP.run()

    def stop_flask():
        thread.join()
        qapp.quit()

    thread = threading.Thread(target=run_flask)
    thread.daemon = True
    thread.start()

    qapp.aboutToQuit.connect(stop_flask)

    sys.exit(qapp.exec_())
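The main functional change in index() is the system string, which now carries the system prompt plus the running chat log into every send_message call. The sketch below isolates that string assembly so the final prompt format is easy to inspect while reviewing; build_system is a hypothetical helper written only for illustration and does not exist in this PR.

# Illustration only: mirrors the concatenation done inline in index() above.
def build_system(system_prompt, messages):
    system = 'Your system prompt is: "' + system_prompt + '" The following is your chat log so far: \n'
    for line in messages:
        system += line + '\n'
    system += "The chat log is now finished."
    return system

print(build_system("You are a helpful assistant.", ["User: Hi", "AI: Hello there!"]))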